Compare commits
	
		
			6 Commits
		
	
	
		
			version/20
			...
			policies/p
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| b3883f7fbf | |||
| 87c6b0128a | |||
| b243c97916 | |||
| 3f66527521 | |||
| 2f7c258657 | |||
| 917c90374f | 
| @ -1,5 +1,5 @@ | |||||||
| [bumpversion] | [bumpversion] | ||||||
| current_version = 2025.2.4 | current_version = 2024.12.3 | ||||||
| tag = True | tag = True | ||||||
| commit = True | commit = True | ||||||
| parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))? | parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))? | ||||||
|  | |||||||
							
								
								
									
										2
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								Makefile
									
									
									
									
									
								
							| @ -21,7 +21,7 @@ pg_name := $(shell python -m authentik.lib.config postgresql.name 2>/dev/null) | |||||||
| CODESPELL_ARGS = -D - -D .github/codespell-dictionary.txt \ | CODESPELL_ARGS = -D - -D .github/codespell-dictionary.txt \ | ||||||
| 		-I .github/codespell-words.txt \ | 		-I .github/codespell-words.txt \ | ||||||
| 		-S 'web/src/locales/**' \ | 		-S 'web/src/locales/**' \ | ||||||
| 		-S 'website/docs/developer-docs/api/reference/**' \ | 		-S 'website/developer-docs/api/reference/**' \ | ||||||
| 		-S '**/node_modules/**' \ | 		-S '**/node_modules/**' \ | ||||||
| 		-S '**/dist/**' \ | 		-S '**/dist/**' \ | ||||||
| 		$(PY_SOURCES) \ | 		$(PY_SOURCES) \ | ||||||
|  | |||||||
| @ -20,8 +20,8 @@ Even if the issue is not a CVE, we still greatly appreciate your help in hardeni | |||||||
|  |  | ||||||
| | Version   | Supported | | | Version   | Supported | | ||||||
| | --------- | --------- | | | --------- | --------- | | ||||||
|  | | 2024.10.x | ✅        | | ||||||
| | 2024.12.x | ✅        | | | 2024.12.x | ✅        | | ||||||
| | 2025.2.x  | ✅        | |  | ||||||
|  |  | ||||||
| ## Reporting a Vulnerability | ## Reporting a Vulnerability | ||||||
|  |  | ||||||
|  | |||||||
| @ -2,7 +2,7 @@ | |||||||
|  |  | ||||||
| from os import environ | from os import environ | ||||||
|  |  | ||||||
| __version__ = "2025.2.4" | __version__ = "2024.12.3" | ||||||
| ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" | ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| @ -59,7 +59,7 @@ class SystemInfoSerializer(PassiveSerializer): | |||||||
|             if not isinstance(value, str): |             if not isinstance(value, str): | ||||||
|                 continue |                 continue | ||||||
|             actual_value = value |             actual_value = value | ||||||
|             if raw_session is not None and raw_session in actual_value: |             if raw_session in actual_value: | ||||||
|                 actual_value = actual_value.replace( |                 actual_value = actual_value.replace( | ||||||
|                     raw_session, SafeExceptionReporterFilter.cleansed_substitute |                     raw_session, SafeExceptionReporterFilter.cleansed_substitute | ||||||
|                 ) |                 ) | ||||||
|  | |||||||
| @ -50,6 +50,7 @@ from authentik.enterprise.providers.microsoft_entra.models import ( | |||||||
|     MicrosoftEntraProviderGroup, |     MicrosoftEntraProviderGroup, | ||||||
|     MicrosoftEntraProviderUser, |     MicrosoftEntraProviderUser, | ||||||
| ) | ) | ||||||
|  | from authentik.enterprise.providers.rac.models import ConnectionToken | ||||||
| from authentik.enterprise.providers.ssf.models import StreamEvent | from authentik.enterprise.providers.ssf.models import StreamEvent | ||||||
| from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import ( | from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import ( | ||||||
|     EndpointDevice, |     EndpointDevice, | ||||||
| @ -71,7 +72,6 @@ from authentik.providers.oauth2.models import ( | |||||||
|     DeviceToken, |     DeviceToken, | ||||||
|     RefreshToken, |     RefreshToken, | ||||||
| ) | ) | ||||||
| from authentik.providers.rac.models import ConnectionToken |  | ||||||
| from authentik.providers.scim.models import SCIMProviderGroup, SCIMProviderUser | from authentik.providers.scim.models import SCIMProviderGroup, SCIMProviderUser | ||||||
| from authentik.rbac.models import Role | from authentik.rbac.models import Role | ||||||
| from authentik.sources.scim.models import SCIMSourceGroup, SCIMSourceUser | from authentik.sources.scim.models import SCIMSourceGroup, SCIMSourceUser | ||||||
|  | |||||||
| @ -4,7 +4,6 @@ from json import loads | |||||||
|  |  | ||||||
| from django.db.models import Prefetch | from django.db.models import Prefetch | ||||||
| from django.http import Http404 | from django.http import Http404 | ||||||
| from django.utils.translation import gettext as _ |  | ||||||
| from django_filters.filters import CharFilter, ModelMultipleChoiceFilter | from django_filters.filters import CharFilter, ModelMultipleChoiceFilter | ||||||
| from django_filters.filterset import FilterSet | from django_filters.filterset import FilterSet | ||||||
| from drf_spectacular.utils import ( | from drf_spectacular.utils import ( | ||||||
| @ -82,37 +81,9 @@ class GroupSerializer(ModelSerializer): | |||||||
|         if not self.instance or not parent: |         if not self.instance or not parent: | ||||||
|             return parent |             return parent | ||||||
|         if str(parent.group_uuid) == str(self.instance.group_uuid): |         if str(parent.group_uuid) == str(self.instance.group_uuid): | ||||||
|             raise ValidationError(_("Cannot set group as parent of itself.")) |             raise ValidationError("Cannot set group as parent of itself.") | ||||||
|         return parent |         return parent | ||||||
|  |  | ||||||
|     def validate_is_superuser(self, superuser: bool): |  | ||||||
|         """Ensure that the user creating this group has permissions to set the superuser flag""" |  | ||||||
|         request: Request = self.context.get("request", None) |  | ||||||
|         if not request: |  | ||||||
|             return superuser |  | ||||||
|         # If we're updating an instance, and the state hasn't changed, we don't need to check perms |  | ||||||
|         if self.instance and superuser == self.instance.is_superuser: |  | ||||||
|             return superuser |  | ||||||
|         user: User = request.user |  | ||||||
|         perm = ( |  | ||||||
|             "authentik_core.enable_group_superuser" |  | ||||||
|             if superuser |  | ||||||
|             else "authentik_core.disable_group_superuser" |  | ||||||
|         ) |  | ||||||
|         has_perm = user.has_perm(perm) |  | ||||||
|         if self.instance and not has_perm: |  | ||||||
|             has_perm = user.has_perm(perm, self.instance) |  | ||||||
|         if not has_perm: |  | ||||||
|             raise ValidationError( |  | ||||||
|                 _( |  | ||||||
|                     ( |  | ||||||
|                         "User does not have permission to set " |  | ||||||
|                         "superuser status to {superuser_status}." |  | ||||||
|                     ).format_map({"superuser_status": superuser}) |  | ||||||
|                 ) |  | ||||||
|             ) |  | ||||||
|         return superuser |  | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|         model = Group |         model = Group | ||||||
|         fields = [ |         fields = [ | ||||||
|  | |||||||
| @ -1,14 +1,13 @@ | |||||||
| """User API Views""" | """User API Views""" | ||||||
|  |  | ||||||
| from datetime import timedelta | from datetime import timedelta | ||||||
| from importlib import import_module |  | ||||||
| from json import loads | from json import loads | ||||||
| from typing import Any | from typing import Any | ||||||
|  |  | ||||||
| from django.conf import settings |  | ||||||
| from django.contrib.auth import update_session_auth_hash | from django.contrib.auth import update_session_auth_hash | ||||||
| from django.contrib.auth.models import Permission | from django.contrib.auth.models import Permission | ||||||
| from django.contrib.sessions.backends.base import SessionBase | from django.contrib.sessions.backends.cache import KEY_PREFIX | ||||||
|  | from django.core.cache import cache | ||||||
| from django.db.models.functions import ExtractHour | from django.db.models.functions import ExtractHour | ||||||
| from django.db.transaction import atomic | from django.db.transaction import atomic | ||||||
| from django.db.utils import IntegrityError | from django.db.utils import IntegrityError | ||||||
| @ -92,7 +91,6 @@ from authentik.stages.email.tasks import send_mails | |||||||
| from authentik.stages.email.utils import TemplateEmailMessage | from authentik.stages.email.utils import TemplateEmailMessage | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
| SessionStore: SessionBase = import_module(settings.SESSION_ENGINE).SessionStore |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class UserGroupSerializer(ModelSerializer): | class UserGroupSerializer(ModelSerializer): | ||||||
| @ -375,7 +373,7 @@ class UsersFilter(FilterSet): | |||||||
|         method="filter_attributes", |         method="filter_attributes", | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     is_superuser = BooleanFilter(field_name="ak_groups", method="filter_is_superuser") |     is_superuser = BooleanFilter(field_name="ak_groups", lookup_expr="is_superuser") | ||||||
|     uuid = UUIDFilter(field_name="uuid") |     uuid = UUIDFilter(field_name="uuid") | ||||||
|  |  | ||||||
|     path = CharFilter(field_name="path") |     path = CharFilter(field_name="path") | ||||||
| @ -393,11 +391,6 @@ class UsersFilter(FilterSet): | |||||||
|         queryset=Group.objects.all().order_by("name"), |         queryset=Group.objects.all().order_by("name"), | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     def filter_is_superuser(self, queryset, name, value): |  | ||||||
|         if value: |  | ||||||
|             return queryset.filter(ak_groups__is_superuser=True).distinct() |  | ||||||
|         return queryset.exclude(ak_groups__is_superuser=True).distinct() |  | ||||||
|  |  | ||||||
|     def filter_attributes(self, queryset, name, value): |     def filter_attributes(self, queryset, name, value): | ||||||
|         """Filter attributes by query args""" |         """Filter attributes by query args""" | ||||||
|         try: |         try: | ||||||
| @ -776,8 +769,7 @@ class UserViewSet(UsedByMixin, ModelViewSet): | |||||||
|         if not instance.is_active: |         if not instance.is_active: | ||||||
|             sessions = AuthenticatedSession.objects.filter(user=instance) |             sessions = AuthenticatedSession.objects.filter(user=instance) | ||||||
|             session_ids = sessions.values_list("session_key", flat=True) |             session_ids = sessions.values_list("session_key", flat=True) | ||||||
|             for session in session_ids: |             cache.delete_many(f"{KEY_PREFIX}{session}" for session in session_ids) | ||||||
|                 SessionStore(session).delete() |  | ||||||
|             sessions.delete() |             sessions.delete() | ||||||
|             LOGGER.debug("Deleted user's sessions", user=instance.username) |             LOGGER.debug("Deleted user's sessions", user=instance.username) | ||||||
|         return response |         return response | ||||||
|  | |||||||
| @ -1,26 +0,0 @@ | |||||||
| # Generated by Django 5.0.11 on 2025-01-30 23:55 |  | ||||||
|  |  | ||||||
| from django.db import migrations |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Migration(migrations.Migration): |  | ||||||
|  |  | ||||||
|     dependencies = [ |  | ||||||
|         ("authentik_core", "0042_authenticatedsession_authentik_c_expires_08251d_idx_and_more"), |  | ||||||
|     ] |  | ||||||
|  |  | ||||||
|     operations = [ |  | ||||||
|         migrations.AlterModelOptions( |  | ||||||
|             name="group", |  | ||||||
|             options={ |  | ||||||
|                 "permissions": [ |  | ||||||
|                     ("add_user_to_group", "Add user to group"), |  | ||||||
|                     ("remove_user_from_group", "Remove user from group"), |  | ||||||
|                     ("enable_group_superuser", "Enable superuser status"), |  | ||||||
|                     ("disable_group_superuser", "Disable superuser status"), |  | ||||||
|                 ], |  | ||||||
|                 "verbose_name": "Group", |  | ||||||
|                 "verbose_name_plural": "Groups", |  | ||||||
|             }, |  | ||||||
|         ), |  | ||||||
|     ] |  | ||||||
| @ -204,8 +204,6 @@ class Group(SerializerModel, AttributesMixin): | |||||||
|         permissions = [ |         permissions = [ | ||||||
|             ("add_user_to_group", _("Add user to group")), |             ("add_user_to_group", _("Add user to group")), | ||||||
|             ("remove_user_from_group", _("Remove user from group")), |             ("remove_user_from_group", _("Remove user from group")), | ||||||
|             ("enable_group_superuser", _("Enable superuser status")), |  | ||||||
|             ("disable_group_superuser", _("Disable superuser status")), |  | ||||||
|         ] |         ] | ||||||
|  |  | ||||||
|     def __str__(self): |     def __str__(self): | ||||||
|  | |||||||
| @ -1,10 +1,7 @@ | |||||||
| """authentik core signals""" | """authentik core signals""" | ||||||
|  |  | ||||||
| from importlib import import_module |  | ||||||
|  |  | ||||||
| from django.conf import settings |  | ||||||
| from django.contrib.auth.signals import user_logged_in, user_logged_out | from django.contrib.auth.signals import user_logged_in, user_logged_out | ||||||
| from django.contrib.sessions.backends.base import SessionBase | from django.contrib.sessions.backends.cache import KEY_PREFIX | ||||||
| from django.core.cache import cache | from django.core.cache import cache | ||||||
| from django.core.signals import Signal | from django.core.signals import Signal | ||||||
| from django.db.models import Model | from django.db.models import Model | ||||||
| @ -28,7 +25,6 @@ password_changed = Signal() | |||||||
| login_failed = Signal() | login_failed = Signal() | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
| SessionStore: SessionBase = import_module(settings.SESSION_ENGINE).SessionStore |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(post_save, sender=Application) | @receiver(post_save, sender=Application) | ||||||
| @ -64,7 +60,8 @@ def user_logged_out_session(sender, request: HttpRequest, user: User, **_): | |||||||
| @receiver(pre_delete, sender=AuthenticatedSession) | @receiver(pre_delete, sender=AuthenticatedSession) | ||||||
| def authenticated_session_delete(sender: type[Model], instance: "AuthenticatedSession", **_): | def authenticated_session_delete(sender: type[Model], instance: "AuthenticatedSession", **_): | ||||||
|     """Delete session when authenticated session is deleted""" |     """Delete session when authenticated session is deleted""" | ||||||
|     SessionStore(instance.session_key).delete() |     cache_key = f"{KEY_PREFIX}{instance.session_key}" | ||||||
|  |     cache.delete(cache_key) | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(pre_save) | @receiver(pre_save) | ||||||
|  | |||||||
| @ -35,7 +35,8 @@ from authentik.flows.planner import ( | |||||||
|     FlowPlanner, |     FlowPlanner, | ||||||
| ) | ) | ||||||
| from authentik.flows.stage import StageView | from authentik.flows.stage import StageView | ||||||
| from authentik.flows.views.executor import NEXT_ARG_NAME, SESSION_KEY_GET | from authentik.flows.views.executor import NEXT_ARG_NAME, SESSION_KEY_GET, SESSION_KEY_PLAN | ||||||
|  | from authentik.lib.utils.urls import redirect_with_qs | ||||||
| from authentik.lib.views import bad_request_message | from authentik.lib.views import bad_request_message | ||||||
| from authentik.policies.denied import AccessDeniedResponse | from authentik.policies.denied import AccessDeniedResponse | ||||||
| from authentik.policies.utils import delete_none_values | from authentik.policies.utils import delete_none_values | ||||||
| @ -46,9 +47,8 @@ from authentik.stages.user_write.stage import PLAN_CONTEXT_USER_PATH | |||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
| PLAN_CONTEXT_SOURCE_GROUPS = "source_groups" |  | ||||||
| SESSION_KEY_SOURCE_FLOW_STAGES = "authentik/flows/source_flow_stages" |  | ||||||
| SESSION_KEY_OVERRIDE_FLOW_TOKEN = "authentik/flows/source_override_flow_token"  # nosec | SESSION_KEY_OVERRIDE_FLOW_TOKEN = "authentik/flows/source_override_flow_token"  # nosec | ||||||
|  | PLAN_CONTEXT_SOURCE_GROUPS = "source_groups" | ||||||
|  |  | ||||||
|  |  | ||||||
| class MessageStage(StageView): | class MessageStage(StageView): | ||||||
| @ -219,17 +219,9 @@ class SourceFlowManager: | |||||||
|             } |             } | ||||||
|         ) |         ) | ||||||
|         flow_context.update(self.policy_context) |         flow_context.update(self.policy_context) | ||||||
|         flow_context.setdefault(PLAN_CONTEXT_REDIRECT, final_redirect) |  | ||||||
|  |  | ||||||
|         if not flow: |  | ||||||
|             # We only check for the flow token here if we don't have a flow, otherwise we rely on |  | ||||||
|             # SESSION_KEY_SOURCE_FLOW_STAGES to delegate the usage of this token and dynamically add |  | ||||||
|             # stages that deal with this token to return to another flow |  | ||||||
|         if SESSION_KEY_OVERRIDE_FLOW_TOKEN in self.request.session: |         if SESSION_KEY_OVERRIDE_FLOW_TOKEN in self.request.session: | ||||||
|             token: FlowToken = self.request.session.get(SESSION_KEY_OVERRIDE_FLOW_TOKEN) |             token: FlowToken = self.request.session.get(SESSION_KEY_OVERRIDE_FLOW_TOKEN) | ||||||
|                 self._logger.info( |             self._logger.info("Replacing source flow with overridden flow", flow=token.flow.slug) | ||||||
|                     "Replacing source flow with overridden flow", flow=token.flow.slug |  | ||||||
|                 ) |  | ||||||
|             plan = token.plan |             plan = token.plan | ||||||
|             plan.context[PLAN_CONTEXT_IS_RESTORED] = token |             plan.context[PLAN_CONTEXT_IS_RESTORED] = token | ||||||
|             plan.context.update(flow_context) |             plan.context.update(flow_context) | ||||||
| @ -238,9 +230,17 @@ class SourceFlowManager: | |||||||
|             if stages: |             if stages: | ||||||
|                 for stage in stages: |                 for stage in stages: | ||||||
|                     plan.append_stage(stage) |                     plan.append_stage(stage) | ||||||
|                 redirect = plan.to_redirect(self.request, token.flow) |             self.request.session[SESSION_KEY_PLAN] = plan | ||||||
|  |             flow_slug = token.flow.slug | ||||||
|             token.delete() |             token.delete() | ||||||
|                 return redirect |             return redirect_with_qs( | ||||||
|  |                 "authentik_core:if-flow", | ||||||
|  |                 self.request.GET, | ||||||
|  |                 flow_slug=flow_slug, | ||||||
|  |             ) | ||||||
|  |         flow_context.setdefault(PLAN_CONTEXT_REDIRECT, final_redirect) | ||||||
|  |  | ||||||
|  |         if not flow: | ||||||
|             return bad_request_message( |             return bad_request_message( | ||||||
|                 self.request, |                 self.request, | ||||||
|                 _("Configured flow does not exist."), |                 _("Configured flow does not exist."), | ||||||
| @ -259,8 +259,6 @@ class SourceFlowManager: | |||||||
|         if stages: |         if stages: | ||||||
|             for stage in stages: |             for stage in stages: | ||||||
|                 plan.append_stage(stage) |                 plan.append_stage(stage) | ||||||
|         for stage in self.request.session.get(SESSION_KEY_SOURCE_FLOW_STAGES, []): |  | ||||||
|             plan.append_stage(stage) |  | ||||||
|         return plan.to_redirect(self.request, flow) |         return plan.to_redirect(self.request, flow) | ||||||
|  |  | ||||||
|     def handle_auth( |     def handle_auth( | ||||||
| @ -297,8 +295,6 @@ class SourceFlowManager: | |||||||
|         # When request isn't authenticated we jump straight to auth |         # When request isn't authenticated we jump straight to auth | ||||||
|         if not self.request.user.is_authenticated: |         if not self.request.user.is_authenticated: | ||||||
|             return self.handle_auth(connection) |             return self.handle_auth(connection) | ||||||
|         # When an override flow token exists we actually still use a flow for link |  | ||||||
|         # to continue the existing flow we came from |  | ||||||
|         if SESSION_KEY_OVERRIDE_FLOW_TOKEN in self.request.session: |         if SESSION_KEY_OVERRIDE_FLOW_TOKEN in self.request.session: | ||||||
|             return self._prepare_flow(None, connection) |             return self._prepare_flow(None, connection) | ||||||
|         connection.save() |         connection.save() | ||||||
|  | |||||||
| @ -67,8 +67,6 @@ def clean_expired_models(self: SystemTask): | |||||||
|                 raise ImproperlyConfigured( |                 raise ImproperlyConfigured( | ||||||
|                     "Invalid session_storage setting, allowed values are db and cache" |                     "Invalid session_storage setting, allowed values are db and cache" | ||||||
|                 ) |                 ) | ||||||
|     if CONFIG.get("session_storage", "cache") == "db": |  | ||||||
|         DBSessionStore.clear_expired() |  | ||||||
|     LOGGER.debug("Expired sessions", model=AuthenticatedSession, amount=amount) |     LOGGER.debug("Expired sessions", model=AuthenticatedSession, amount=amount) | ||||||
|  |  | ||||||
|     messages.append(f"Expired {amount} {AuthenticatedSession._meta.verbose_name_plural}") |     messages.append(f"Expired {amount} {AuthenticatedSession._meta.verbose_name_plural}") | ||||||
|  | |||||||
| @ -11,7 +11,6 @@ | |||||||
|         build: "{{ build }}", |         build: "{{ build }}", | ||||||
|         api: { |         api: { | ||||||
|             base: "{{ base_url }}", |             base: "{{ base_url }}", | ||||||
|             relBase: "{{ base_url_rel }}", |  | ||||||
|         }, |         }, | ||||||
|     }; |     }; | ||||||
|     window.addEventListener("DOMContentLoaded", function () { |     window.addEventListener("DOMContentLoaded", function () { | ||||||
|  | |||||||
| @ -8,8 +8,6 @@ | |||||||
|     <head> |     <head> | ||||||
|         <meta charset="UTF-8"> |         <meta charset="UTF-8"> | ||||||
|         <meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1"> |         <meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1"> | ||||||
|         {# Darkreader breaks the site regardless of theme as its not compatible with webcomponents, and we default to a dark theme based on preferred colour-scheme #} |  | ||||||
|         <meta name="darkreader-lock"> |  | ||||||
|         <title>{% block title %}{% trans title|default:brand.branding_title %}{% endblock %}</title> |         <title>{% block title %}{% trans title|default:brand.branding_title %}{% endblock %}</title> | ||||||
|         <link rel="icon" href="{{ brand.branding_favicon_url }}"> |         <link rel="icon" href="{{ brand.branding_favicon_url }}"> | ||||||
|         <link rel="shortcut icon" href="{{ brand.branding_favicon_url }}"> |         <link rel="shortcut icon" href="{{ brand.branding_favicon_url }}"> | ||||||
|  | |||||||
| @ -4,7 +4,7 @@ from django.urls.base import reverse | |||||||
| from guardian.shortcuts import assign_perm | from guardian.shortcuts import assign_perm | ||||||
| from rest_framework.test import APITestCase | from rest_framework.test import APITestCase | ||||||
|  |  | ||||||
| from authentik.core.models import Group | from authentik.core.models import Group, User | ||||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_user | from authentik.core.tests.utils import create_test_admin_user, create_test_user | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
|  |  | ||||||
| @ -14,7 +14,7 @@ class TestGroupsAPI(APITestCase): | |||||||
|  |  | ||||||
|     def setUp(self) -> None: |     def setUp(self) -> None: | ||||||
|         self.login_user = create_test_user() |         self.login_user = create_test_user() | ||||||
|         self.user = create_test_user() |         self.user = User.objects.create(username="test-user") | ||||||
|  |  | ||||||
|     def test_list_with_users(self): |     def test_list_with_users(self): | ||||||
|         """Test listing with users""" |         """Test listing with users""" | ||||||
| @ -109,57 +109,3 @@ class TestGroupsAPI(APITestCase): | |||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
|         self.assertEqual(res.status_code, 400) |         self.assertEqual(res.status_code, 400) | ||||||
|  |  | ||||||
|     def test_superuser_no_perm(self): |  | ||||||
|         """Test creating a superuser group without permission""" |  | ||||||
|         assign_perm("authentik_core.add_group", self.login_user) |  | ||||||
|         self.client.force_login(self.login_user) |  | ||||||
|         res = self.client.post( |  | ||||||
|             reverse("authentik_api:group-list"), |  | ||||||
|             data={"name": generate_id(), "is_superuser": True}, |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(res.status_code, 400) |  | ||||||
|         self.assertJSONEqual( |  | ||||||
|             res.content, |  | ||||||
|             {"is_superuser": ["User does not have permission to set superuser status to True."]}, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def test_superuser_update_no_perm(self): |  | ||||||
|         """Test updating a superuser group without permission""" |  | ||||||
|         group = Group.objects.create(name=generate_id(), is_superuser=True) |  | ||||||
|         assign_perm("view_group", self.login_user, group) |  | ||||||
|         assign_perm("change_group", self.login_user, group) |  | ||||||
|         self.client.force_login(self.login_user) |  | ||||||
|         res = self.client.patch( |  | ||||||
|             reverse("authentik_api:group-detail", kwargs={"pk": group.pk}), |  | ||||||
|             data={"is_superuser": False}, |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(res.status_code, 400) |  | ||||||
|         self.assertJSONEqual( |  | ||||||
|             res.content, |  | ||||||
|             {"is_superuser": ["User does not have permission to set superuser status to False."]}, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def test_superuser_update_no_change(self): |  | ||||||
|         """Test updating a superuser group without permission |  | ||||||
|         and without changing the superuser status""" |  | ||||||
|         group = Group.objects.create(name=generate_id(), is_superuser=True) |  | ||||||
|         assign_perm("view_group", self.login_user, group) |  | ||||||
|         assign_perm("change_group", self.login_user, group) |  | ||||||
|         self.client.force_login(self.login_user) |  | ||||||
|         res = self.client.patch( |  | ||||||
|             reverse("authentik_api:group-detail", kwargs={"pk": group.pk}), |  | ||||||
|             data={"name": generate_id(), "is_superuser": True}, |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(res.status_code, 200) |  | ||||||
|  |  | ||||||
|     def test_superuser_create(self): |  | ||||||
|         """Test creating a superuser group with permission""" |  | ||||||
|         assign_perm("authentik_core.add_group", self.login_user) |  | ||||||
|         assign_perm("authentik_core.enable_group_superuser", self.login_user) |  | ||||||
|         self.client.force_login(self.login_user) |  | ||||||
|         res = self.client.post( |  | ||||||
|             reverse("authentik_api:group-list"), |  | ||||||
|             data={"name": generate_id(), "is_superuser": True}, |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(res.status_code, 201) |  | ||||||
|  | |||||||
| @ -1,7 +1,6 @@ | |||||||
| """Test Users API""" | """Test Users API""" | ||||||
|  |  | ||||||
| from datetime import datetime | from datetime import datetime | ||||||
| from json import loads |  | ||||||
|  |  | ||||||
| from django.contrib.sessions.backends.cache import KEY_PREFIX | from django.contrib.sessions.backends.cache import KEY_PREFIX | ||||||
| from django.core.cache import cache | from django.core.cache import cache | ||||||
| @ -16,11 +15,7 @@ from authentik.core.models import ( | |||||||
|     User, |     User, | ||||||
|     UserTypes, |     UserTypes, | ||||||
| ) | ) | ||||||
| from authentik.core.tests.utils import ( | from authentik.core.tests.utils import create_test_admin_user, create_test_brand, create_test_flow | ||||||
|     create_test_admin_user, |  | ||||||
|     create_test_brand, |  | ||||||
|     create_test_flow, |  | ||||||
| ) |  | ||||||
| from authentik.flows.models import FlowDesignation | from authentik.flows.models import FlowDesignation | ||||||
| from authentik.lib.generators import generate_id, generate_key | from authentik.lib.generators import generate_id, generate_key | ||||||
| from authentik.stages.email.models import EmailStage | from authentik.stages.email.models import EmailStage | ||||||
| @ -46,32 +41,6 @@ class TestUsersAPI(APITestCase): | |||||||
|         ) |         ) | ||||||
|         self.assertEqual(response.status_code, 200) |         self.assertEqual(response.status_code, 200) | ||||||
|  |  | ||||||
|     def test_filter_is_superuser(self): |  | ||||||
|         """Test API filtering by superuser status""" |  | ||||||
|         self.client.force_login(self.admin) |  | ||||||
|         # Test superuser |  | ||||||
|         response = self.client.get( |  | ||||||
|             reverse("authentik_api:user-list"), |  | ||||||
|             data={ |  | ||||||
|                 "is_superuser": True, |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 200) |  | ||||||
|         body = loads(response.content) |  | ||||||
|         self.assertEqual(len(body["results"]), 1) |  | ||||||
|         self.assertEqual(body["results"][0]["username"], self.admin.username) |  | ||||||
|         # Test non-superuser |  | ||||||
|         response = self.client.get( |  | ||||||
|             reverse("authentik_api:user-list"), |  | ||||||
|             data={ |  | ||||||
|                 "is_superuser": False, |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 200) |  | ||||||
|         body = loads(response.content) |  | ||||||
|         self.assertEqual(len(body["results"]), 1, body) |  | ||||||
|         self.assertEqual(body["results"][0]["username"], self.user.username) |  | ||||||
|  |  | ||||||
|     def test_list_with_groups(self): |     def test_list_with_groups(self): | ||||||
|         """Test listing with groups""" |         """Test listing with groups""" | ||||||
|         self.client.force_login(self.admin) |         self.client.force_login(self.admin) | ||||||
|  | |||||||
| @ -55,7 +55,7 @@ class RedirectToAppLaunch(View): | |||||||
|             ) |             ) | ||||||
|         except FlowNonApplicableException: |         except FlowNonApplicableException: | ||||||
|             raise Http404 from None |             raise Http404 from None | ||||||
|         plan.append_stage(in_memory_stage(RedirectToAppStage)) |         plan.insert_stage(in_memory_stage(RedirectToAppStage)) | ||||||
|         return plan.to_redirect(request, flow) |         return plan.to_redirect(request, flow) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| @ -53,7 +53,6 @@ class InterfaceView(TemplateView): | |||||||
|         kwargs["build"] = get_build_hash() |         kwargs["build"] = get_build_hash() | ||||||
|         kwargs["url_kwargs"] = self.kwargs |         kwargs["url_kwargs"] = self.kwargs | ||||||
|         kwargs["base_url"] = self.request.build_absolute_uri(CONFIG.get("web.path", "/")) |         kwargs["base_url"] = self.request.build_absolute_uri(CONFIG.get("web.path", "/")) | ||||||
|         kwargs["base_url_rel"] = CONFIG.get("web.path", "/") |  | ||||||
|         return super().get_context_data(**kwargs) |         return super().get_context_data(**kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| @ -97,8 +97,6 @@ class EnterpriseAuditMiddleware(AuditMiddleware): | |||||||
|         thread_kwargs: dict | None = None, |         thread_kwargs: dict | None = None, | ||||||
|         **_, |         **_, | ||||||
|     ): |     ): | ||||||
|         if not self.enabled: |  | ||||||
|             return super().post_save_handler(request, sender, instance, created, thread_kwargs, **_) |  | ||||||
|         if not should_log_model(instance): |         if not should_log_model(instance): | ||||||
|             return None |             return None | ||||||
|         thread_kwargs = {} |         thread_kwargs = {} | ||||||
| @ -124,8 +122,6 @@ class EnterpriseAuditMiddleware(AuditMiddleware): | |||||||
|     ): |     ): | ||||||
|         thread_kwargs = {} |         thread_kwargs = {} | ||||||
|         m2m_field = None |         m2m_field = None | ||||||
|         if not self.enabled: |  | ||||||
|             return super().m2m_changed_handler(request, sender, instance, action, thread_kwargs) |  | ||||||
|         # For the audit log we don't care about `pre_` or `post_` so we trim that part off |         # For the audit log we don't care about `pre_` or `post_` so we trim that part off | ||||||
|         _, _, action_direction = action.partition("_") |         _, _, action_direction = action.partition("_") | ||||||
|         # resolve the "through" model to an actual field |         # resolve the "through" model to an actual field | ||||||
|  | |||||||
| @ -6,12 +6,13 @@ from rest_framework.viewsets import GenericViewSet | |||||||
| from authentik.core.api.groups import GroupMemberSerializer | from authentik.core.api.groups import GroupMemberSerializer | ||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
| from authentik.core.api.utils import ModelSerializer | from authentik.core.api.utils import ModelSerializer | ||||||
| from authentik.providers.rac.api.endpoints import EndpointSerializer | from authentik.enterprise.api import EnterpriseRequiredMixin | ||||||
| from authentik.providers.rac.api.providers import RACProviderSerializer | from authentik.enterprise.providers.rac.api.endpoints import EndpointSerializer | ||||||
| from authentik.providers.rac.models import ConnectionToken | from authentik.enterprise.providers.rac.api.providers import RACProviderSerializer | ||||||
|  | from authentik.enterprise.providers.rac.models import ConnectionToken | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class ConnectionTokenSerializer(ModelSerializer): | class ConnectionTokenSerializer(EnterpriseRequiredMixin, ModelSerializer): | ||||||
|     """ConnectionToken Serializer""" |     """ConnectionToken Serializer""" | ||||||
| 
 | 
 | ||||||
|     provider_obj = RACProviderSerializer(source="provider", read_only=True) |     provider_obj = RACProviderSerializer(source="provider", read_only=True) | ||||||
| @ -14,9 +14,10 @@ from structlog.stdlib import get_logger | |||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
| from authentik.core.api.utils import ModelSerializer | from authentik.core.api.utils import ModelSerializer | ||||||
| from authentik.core.models import Provider | from authentik.core.models import Provider | ||||||
|  | from authentik.enterprise.api import EnterpriseRequiredMixin | ||||||
|  | from authentik.enterprise.providers.rac.api.providers import RACProviderSerializer | ||||||
|  | from authentik.enterprise.providers.rac.models import Endpoint | ||||||
| from authentik.policies.engine import PolicyEngine | from authentik.policies.engine import PolicyEngine | ||||||
| from authentik.providers.rac.api.providers import RACProviderSerializer |  | ||||||
| from authentik.providers.rac.models import Endpoint |  | ||||||
| from authentik.rbac.filters import ObjectFilter | from authentik.rbac.filters import ObjectFilter | ||||||
| 
 | 
 | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
| @ -27,7 +28,7 @@ def user_endpoint_cache_key(user_pk: str) -> str: | |||||||
|     return f"goauthentik.io/providers/rac/endpoint_access/{user_pk}" |     return f"goauthentik.io/providers/rac/endpoint_access/{user_pk}" | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class EndpointSerializer(ModelSerializer): | class EndpointSerializer(EnterpriseRequiredMixin, ModelSerializer): | ||||||
|     """Endpoint Serializer""" |     """Endpoint Serializer""" | ||||||
| 
 | 
 | ||||||
|     provider_obj = RACProviderSerializer(source="provider", read_only=True) |     provider_obj = RACProviderSerializer(source="provider", read_only=True) | ||||||
| @ -10,7 +10,7 @@ from rest_framework.viewsets import ModelViewSet | |||||||
| from authentik.core.api.property_mappings import PropertyMappingSerializer | from authentik.core.api.property_mappings import PropertyMappingSerializer | ||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
| from authentik.core.api.utils import JSONDictField | from authentik.core.api.utils import JSONDictField | ||||||
| from authentik.providers.rac.models import RACPropertyMapping | from authentik.enterprise.providers.rac.models import RACPropertyMapping | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class RACPropertyMappingSerializer(PropertyMappingSerializer): | class RACPropertyMappingSerializer(PropertyMappingSerializer): | ||||||
| @ -5,10 +5,11 @@ from rest_framework.viewsets import ModelViewSet | |||||||
| 
 | 
 | ||||||
| from authentik.core.api.providers import ProviderSerializer | from authentik.core.api.providers import ProviderSerializer | ||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
| from authentik.providers.rac.models import RACProvider | from authentik.enterprise.api import EnterpriseRequiredMixin | ||||||
|  | from authentik.enterprise.providers.rac.models import RACProvider | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class RACProviderSerializer(ProviderSerializer): | class RACProviderSerializer(EnterpriseRequiredMixin, ProviderSerializer): | ||||||
|     """RACProvider Serializer""" |     """RACProvider Serializer""" | ||||||
| 
 | 
 | ||||||
|     outpost_set = ListField(child=CharField(), read_only=True, source="outpost_set.all") |     outpost_set = ListField(child=CharField(), read_only=True, source="outpost_set.all") | ||||||
							
								
								
									
										14
									
								
								authentik/enterprise/providers/rac/apps.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								authentik/enterprise/providers/rac/apps.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,14 @@ | |||||||
|  | """RAC app config""" | ||||||
|  |  | ||||||
|  | from authentik.enterprise.apps import EnterpriseConfig | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class AuthentikEnterpriseProviderRAC(EnterpriseConfig): | ||||||
|  |     """authentik enterprise rac app config""" | ||||||
|  |  | ||||||
|  |     name = "authentik.enterprise.providers.rac" | ||||||
|  |     label = "authentik_providers_rac" | ||||||
|  |     verbose_name = "authentik Enterprise.Providers.RAC" | ||||||
|  |     default = True | ||||||
|  |     mountpoint = "" | ||||||
|  |     ws_mountpoint = "authentik.enterprise.providers.rac.urls" | ||||||
| @ -7,22 +7,22 @@ from channels.generic.websocket import AsyncWebsocketConsumer | |||||||
| from django.http.request import QueryDict | from django.http.request import QueryDict | ||||||
| from structlog.stdlib import BoundLogger, get_logger | from structlog.stdlib import BoundLogger, get_logger | ||||||
| 
 | 
 | ||||||
|  | from authentik.enterprise.providers.rac.models import ConnectionToken, RACProvider | ||||||
| from authentik.outposts.consumer import OUTPOST_GROUP_INSTANCE | from authentik.outposts.consumer import OUTPOST_GROUP_INSTANCE | ||||||
| from authentik.outposts.models import Outpost, OutpostState, OutpostType | from authentik.outposts.models import Outpost, OutpostState, OutpostType | ||||||
| from authentik.providers.rac.models import ConnectionToken, RACProvider |  | ||||||
| 
 | 
 | ||||||
| # Global broadcast group, which messages are sent to when the outpost connects back | # Global broadcast group, which messages are sent to when the outpost connects back | ||||||
| # to authentik for a specific connection | # to authentik for a specific connection | ||||||
| # The `RACClientConsumer` consumer adds itself to this group on connection, | # The `RACClientConsumer` consumer adds itself to this group on connection, | ||||||
| # and removes itself once it has been assigned a specific outpost channel | # and removes itself once it has been assigned a specific outpost channel | ||||||
| RAC_CLIENT_GROUP = "group_rac_client" | RAC_CLIENT_GROUP = "group_enterprise_rac_client" | ||||||
| # A group for all connections in a given authentik session ID | # A group for all connections in a given authentik session ID | ||||||
| # A disconnect message is sent to this group when the session expires/is deleted | # A disconnect message is sent to this group when the session expires/is deleted | ||||||
| RAC_CLIENT_GROUP_SESSION = "group_rac_client_%(session)s" | RAC_CLIENT_GROUP_SESSION = "group_enterprise_rac_client_%(session)s" | ||||||
| # A group for all connections with a specific token, which in almost all cases | # A group for all connections with a specific token, which in almost all cases | ||||||
| # is just one connection, however this is used to disconnect the connection | # is just one connection, however this is used to disconnect the connection | ||||||
| # when the token is deleted | # when the token is deleted | ||||||
| RAC_CLIENT_GROUP_TOKEN = "group_rac_token_%(token)s"  # nosec | RAC_CLIENT_GROUP_TOKEN = "group_enterprise_rac_token_%(token)s"  # nosec | ||||||
| 
 | 
 | ||||||
| # Step 1: Client connects to this websocket endpoint | # Step 1: Client connects to this websocket endpoint | ||||||
| # Step 2: We prepare all the connection args for Guac | # Step 2: We prepare all the connection args for Guac | ||||||
| @ -3,7 +3,7 @@ | |||||||
| from channels.exceptions import ChannelFull | from channels.exceptions import ChannelFull | ||||||
| from channels.generic.websocket import AsyncWebsocketConsumer | from channels.generic.websocket import AsyncWebsocketConsumer | ||||||
| 
 | 
 | ||||||
| from authentik.providers.rac.consumer_client import RAC_CLIENT_GROUP | from authentik.enterprise.providers.rac.consumer_client import RAC_CLIENT_GROUP | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class RACOutpostConsumer(AsyncWebsocketConsumer): | class RACOutpostConsumer(AsyncWebsocketConsumer): | ||||||
| @ -74,7 +74,7 @@ class RACProvider(Provider): | |||||||
| 
 | 
 | ||||||
|     @property |     @property | ||||||
|     def serializer(self) -> type[Serializer]: |     def serializer(self) -> type[Serializer]: | ||||||
|         from authentik.providers.rac.api.providers import RACProviderSerializer |         from authentik.enterprise.providers.rac.api.providers import RACProviderSerializer | ||||||
| 
 | 
 | ||||||
|         return RACProviderSerializer |         return RACProviderSerializer | ||||||
| 
 | 
 | ||||||
| @ -100,7 +100,7 @@ class Endpoint(SerializerModel, PolicyBindingModel): | |||||||
| 
 | 
 | ||||||
|     @property |     @property | ||||||
|     def serializer(self) -> type[Serializer]: |     def serializer(self) -> type[Serializer]: | ||||||
|         from authentik.providers.rac.api.endpoints import EndpointSerializer |         from authentik.enterprise.providers.rac.api.endpoints import EndpointSerializer | ||||||
| 
 | 
 | ||||||
|         return EndpointSerializer |         return EndpointSerializer | ||||||
| 
 | 
 | ||||||
| @ -129,7 +129,7 @@ class RACPropertyMapping(PropertyMapping): | |||||||
| 
 | 
 | ||||||
|     @property |     @property | ||||||
|     def serializer(self) -> type[Serializer]: |     def serializer(self) -> type[Serializer]: | ||||||
|         from authentik.providers.rac.api.property_mappings import ( |         from authentik.enterprise.providers.rac.api.property_mappings import ( | ||||||
|             RACPropertyMappingSerializer, |             RACPropertyMappingSerializer, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
| @ -4,17 +4,18 @@ from asgiref.sync import async_to_sync | |||||||
| from channels.layers import get_channel_layer | from channels.layers import get_channel_layer | ||||||
| from django.contrib.auth.signals import user_logged_out | from django.contrib.auth.signals import user_logged_out | ||||||
| from django.core.cache import cache | from django.core.cache import cache | ||||||
| from django.db.models.signals import post_delete, post_save, pre_delete | from django.db.models import Model | ||||||
|  | from django.db.models.signals import post_save, pre_delete | ||||||
| from django.dispatch import receiver | from django.dispatch import receiver | ||||||
| from django.http import HttpRequest | from django.http import HttpRequest | ||||||
| 
 | 
 | ||||||
| from authentik.core.models import User | from authentik.core.models import User | ||||||
| from authentik.providers.rac.api.endpoints import user_endpoint_cache_key | from authentik.enterprise.providers.rac.api.endpoints import user_endpoint_cache_key | ||||||
| from authentik.providers.rac.consumer_client import ( | from authentik.enterprise.providers.rac.consumer_client import ( | ||||||
|     RAC_CLIENT_GROUP_SESSION, |     RAC_CLIENT_GROUP_SESSION, | ||||||
|     RAC_CLIENT_GROUP_TOKEN, |     RAC_CLIENT_GROUP_TOKEN, | ||||||
| ) | ) | ||||||
| from authentik.providers.rac.models import ConnectionToken, Endpoint | from authentik.enterprise.providers.rac.models import ConnectionToken, Endpoint | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @receiver(user_logged_out) | @receiver(user_logged_out) | ||||||
| @ -45,8 +46,12 @@ def pre_delete_connection_token_disconnect(sender, instance: ConnectionToken, ** | |||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @receiver([post_save, post_delete], sender=Endpoint) | @receiver(post_save, sender=Endpoint) | ||||||
| def post_save_post_delete_endpoint(**_): | def post_save_endpoint(sender: type[Model], instance, created: bool, **_): | ||||||
|     """Clear user's endpoint cache upon endpoint creation or deletion""" |     """Clear user's endpoint cache upon endpoint creation""" | ||||||
|  |     if not created:  # pragma: no cover | ||||||
|  |         return | ||||||
|  | 
 | ||||||
|  |     # Delete user endpoint cache | ||||||
|     keys = cache.keys(user_endpoint_cache_key("*")) |     keys = cache.keys(user_endpoint_cache_key("*")) | ||||||
|     cache.delete_many(keys) |     cache.delete_many(keys) | ||||||
| @ -3,7 +3,7 @@ | |||||||
| {% load authentik_core %} | {% load authentik_core %} | ||||||
| 
 | 
 | ||||||
| {% block head %} | {% block head %} | ||||||
| <script src="{% versioned_script 'dist/rac/index-%v.js' %}" type="module"></script> | <script src="{% versioned_script 'dist/enterprise/rac/index-%v.js' %}" type="module"></script> | ||||||
| <meta name="theme-color" content="#18191a" media="(prefers-color-scheme: dark)"> | <meta name="theme-color" content="#18191a" media="(prefers-color-scheme: dark)"> | ||||||
| <meta name="theme-color" content="#ffffff" media="(prefers-color-scheme: light)"> | <meta name="theme-color" content="#ffffff" media="(prefers-color-scheme: light)"> | ||||||
| <link rel="icon" href="{{ tenant.branding_favicon_url }}"> | <link rel="icon" href="{{ tenant.branding_favicon_url }}"> | ||||||
| @ -1,9 +1,16 @@ | |||||||
| """Test RAC Provider""" | """Test RAC Provider""" | ||||||
| 
 | 
 | ||||||
|  | from datetime import timedelta | ||||||
|  | from time import mktime | ||||||
|  | from unittest.mock import MagicMock, patch | ||||||
|  | 
 | ||||||
| from django.urls import reverse | from django.urls import reverse | ||||||
|  | from django.utils.timezone import now | ||||||
| from rest_framework.test import APITestCase | from rest_framework.test import APITestCase | ||||||
| 
 | 
 | ||||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_flow | from authentik.core.tests.utils import create_test_admin_user, create_test_flow | ||||||
|  | from authentik.enterprise.license import LicenseKey | ||||||
|  | from authentik.enterprise.models import License | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -13,8 +20,21 @@ class TestAPI(APITestCase): | |||||||
|     def setUp(self) -> None: |     def setUp(self) -> None: | ||||||
|         self.user = create_test_admin_user() |         self.user = create_test_admin_user() | ||||||
| 
 | 
 | ||||||
|  |     @patch( | ||||||
|  |         "authentik.enterprise.license.LicenseKey.validate", | ||||||
|  |         MagicMock( | ||||||
|  |             return_value=LicenseKey( | ||||||
|  |                 aud="", | ||||||
|  |                 exp=int(mktime((now() + timedelta(days=3000)).timetuple())), | ||||||
|  |                 name=generate_id(), | ||||||
|  |                 internal_users=100, | ||||||
|  |                 external_users=100, | ||||||
|  |             ) | ||||||
|  |         ), | ||||||
|  |     ) | ||||||
|     def test_create(self): |     def test_create(self): | ||||||
|         """Test creation of RAC Provider""" |         """Test creation of RAC Provider""" | ||||||
|  |         License.objects.create(key=generate_id()) | ||||||
|         self.client.force_login(self.user) |         self.client.force_login(self.user) | ||||||
|         response = self.client.post( |         response = self.client.post( | ||||||
|             reverse("authentik_api:racprovider-list"), |             reverse("authentik_api:racprovider-list"), | ||||||
| @ -5,10 +5,10 @@ from rest_framework.test import APITestCase | |||||||
| 
 | 
 | ||||||
| from authentik.core.models import Application | from authentik.core.models import Application | ||||||
| from authentik.core.tests.utils import create_test_admin_user | from authentik.core.tests.utils import create_test_admin_user | ||||||
|  | from authentik.enterprise.providers.rac.models import Endpoint, Protocols, RACProvider | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| from authentik.policies.dummy.models import DummyPolicy | from authentik.policies.dummy.models import DummyPolicy | ||||||
| from authentik.policies.models import PolicyBinding | from authentik.policies.models import PolicyBinding | ||||||
| from authentik.providers.rac.models import Endpoint, Protocols, RACProvider |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class TestEndpointsAPI(APITestCase): | class TestEndpointsAPI(APITestCase): | ||||||
| @ -4,14 +4,14 @@ from django.test import TransactionTestCase | |||||||
| 
 | 
 | ||||||
| from authentik.core.models import Application, AuthenticatedSession | from authentik.core.models import Application, AuthenticatedSession | ||||||
| from authentik.core.tests.utils import create_test_admin_user | from authentik.core.tests.utils import create_test_admin_user | ||||||
| from authentik.lib.generators import generate_id | from authentik.enterprise.providers.rac.models import ( | ||||||
| from authentik.providers.rac.models import ( |  | ||||||
|     ConnectionToken, |     ConnectionToken, | ||||||
|     Endpoint, |     Endpoint, | ||||||
|     Protocols, |     Protocols, | ||||||
|     RACPropertyMapping, |     RACPropertyMapping, | ||||||
|     RACProvider, |     RACProvider, | ||||||
| ) | ) | ||||||
|  | from authentik.lib.generators import generate_id | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class TestModels(TransactionTestCase): | class TestModels(TransactionTestCase): | ||||||
| @ -1,17 +1,23 @@ | |||||||
| """RAC Views tests""" | """RAC Views tests""" | ||||||
| 
 | 
 | ||||||
|  | from datetime import timedelta | ||||||
| from json import loads | from json import loads | ||||||
|  | from time import mktime | ||||||
|  | from unittest.mock import MagicMock, patch | ||||||
| 
 | 
 | ||||||
| from django.urls import reverse | from django.urls import reverse | ||||||
|  | from django.utils.timezone import now | ||||||
| from rest_framework.test import APITestCase | from rest_framework.test import APITestCase | ||||||
| 
 | 
 | ||||||
| from authentik.core.models import Application | from authentik.core.models import Application | ||||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_flow | from authentik.core.tests.utils import create_test_admin_user, create_test_flow | ||||||
|  | from authentik.enterprise.license import LicenseKey | ||||||
|  | from authentik.enterprise.models import License | ||||||
|  | from authentik.enterprise.providers.rac.models import Endpoint, Protocols, RACProvider | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| from authentik.policies.denied import AccessDeniedResponse | from authentik.policies.denied import AccessDeniedResponse | ||||||
| from authentik.policies.dummy.models import DummyPolicy | from authentik.policies.dummy.models import DummyPolicy | ||||||
| from authentik.policies.models import PolicyBinding | from authentik.policies.models import PolicyBinding | ||||||
| from authentik.providers.rac.models import Endpoint, Protocols, RACProvider |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class TestRACViews(APITestCase): | class TestRACViews(APITestCase): | ||||||
| @ -33,8 +39,21 @@ class TestRACViews(APITestCase): | |||||||
|             provider=self.provider, |             provider=self.provider, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|  |     @patch( | ||||||
|  |         "authentik.enterprise.license.LicenseKey.validate", | ||||||
|  |         MagicMock( | ||||||
|  |             return_value=LicenseKey( | ||||||
|  |                 aud="", | ||||||
|  |                 exp=int(mktime((now() + timedelta(days=3000)).timetuple())), | ||||||
|  |                 name=generate_id(), | ||||||
|  |                 internal_users=100, | ||||||
|  |                 external_users=100, | ||||||
|  |             ) | ||||||
|  |         ), | ||||||
|  |     ) | ||||||
|     def test_no_policy(self): |     def test_no_policy(self): | ||||||
|         """Test request""" |         """Test request""" | ||||||
|  |         License.objects.create(key=generate_id()) | ||||||
|         self.client.force_login(self.user) |         self.client.force_login(self.user) | ||||||
|         response = self.client.get( |         response = self.client.get( | ||||||
|             reverse( |             reverse( | ||||||
| @ -51,6 +70,18 @@ class TestRACViews(APITestCase): | |||||||
|         final_response = self.client.get(next_url) |         final_response = self.client.get(next_url) | ||||||
|         self.assertEqual(final_response.status_code, 200) |         self.assertEqual(final_response.status_code, 200) | ||||||
| 
 | 
 | ||||||
|  |     @patch( | ||||||
|  |         "authentik.enterprise.license.LicenseKey.validate", | ||||||
|  |         MagicMock( | ||||||
|  |             return_value=LicenseKey( | ||||||
|  |                 aud="", | ||||||
|  |                 exp=int(mktime((now() + timedelta(days=3000)).timetuple())), | ||||||
|  |                 name=generate_id(), | ||||||
|  |                 internal_users=100, | ||||||
|  |                 external_users=100, | ||||||
|  |             ) | ||||||
|  |         ), | ||||||
|  |     ) | ||||||
|     def test_app_deny(self): |     def test_app_deny(self): | ||||||
|         """Test request (deny on app level)""" |         """Test request (deny on app level)""" | ||||||
|         PolicyBinding.objects.create( |         PolicyBinding.objects.create( | ||||||
| @ -58,6 +89,7 @@ class TestRACViews(APITestCase): | |||||||
|             policy=DummyPolicy.objects.create(name="deny", result=False, wait_min=1, wait_max=2), |             policy=DummyPolicy.objects.create(name="deny", result=False, wait_min=1, wait_max=2), | ||||||
|             order=0, |             order=0, | ||||||
|         ) |         ) | ||||||
|  |         License.objects.create(key=generate_id()) | ||||||
|         self.client.force_login(self.user) |         self.client.force_login(self.user) | ||||||
|         response = self.client.get( |         response = self.client.get( | ||||||
|             reverse( |             reverse( | ||||||
| @ -67,6 +99,18 @@ class TestRACViews(APITestCase): | |||||||
|         ) |         ) | ||||||
|         self.assertIsInstance(response, AccessDeniedResponse) |         self.assertIsInstance(response, AccessDeniedResponse) | ||||||
| 
 | 
 | ||||||
|  |     @patch( | ||||||
|  |         "authentik.enterprise.license.LicenseKey.validate", | ||||||
|  |         MagicMock( | ||||||
|  |             return_value=LicenseKey( | ||||||
|  |                 aud="", | ||||||
|  |                 exp=int(mktime((now() + timedelta(days=3000)).timetuple())), | ||||||
|  |                 name=generate_id(), | ||||||
|  |                 internal_users=100, | ||||||
|  |                 external_users=100, | ||||||
|  |             ) | ||||||
|  |         ), | ||||||
|  |     ) | ||||||
|     def test_endpoint_deny(self): |     def test_endpoint_deny(self): | ||||||
|         """Test request (deny on endpoint level)""" |         """Test request (deny on endpoint level)""" | ||||||
|         PolicyBinding.objects.create( |         PolicyBinding.objects.create( | ||||||
| @ -74,6 +118,7 @@ class TestRACViews(APITestCase): | |||||||
|             policy=DummyPolicy.objects.create(name="deny", result=False, wait_min=1, wait_max=2), |             policy=DummyPolicy.objects.create(name="deny", result=False, wait_min=1, wait_max=2), | ||||||
|             order=0, |             order=0, | ||||||
|         ) |         ) | ||||||
|  |         License.objects.create(key=generate_id()) | ||||||
|         self.client.force_login(self.user) |         self.client.force_login(self.user) | ||||||
|         response = self.client.get( |         response = self.client.get( | ||||||
|             reverse( |             reverse( | ||||||
| @ -4,14 +4,14 @@ from channels.auth import AuthMiddleware | |||||||
| from channels.sessions import CookieMiddleware | from channels.sessions import CookieMiddleware | ||||||
| from django.urls import path | from django.urls import path | ||||||
| 
 | 
 | ||||||
|  | from authentik.enterprise.providers.rac.api.connection_tokens import ConnectionTokenViewSet | ||||||
|  | from authentik.enterprise.providers.rac.api.endpoints import EndpointViewSet | ||||||
|  | from authentik.enterprise.providers.rac.api.property_mappings import RACPropertyMappingViewSet | ||||||
|  | from authentik.enterprise.providers.rac.api.providers import RACProviderViewSet | ||||||
|  | from authentik.enterprise.providers.rac.consumer_client import RACClientConsumer | ||||||
|  | from authentik.enterprise.providers.rac.consumer_outpost import RACOutpostConsumer | ||||||
|  | from authentik.enterprise.providers.rac.views import RACInterface, RACStartView | ||||||
| from authentik.outposts.channels import TokenOutpostMiddleware | from authentik.outposts.channels import TokenOutpostMiddleware | ||||||
| from authentik.providers.rac.api.connection_tokens import ConnectionTokenViewSet |  | ||||||
| from authentik.providers.rac.api.endpoints import EndpointViewSet |  | ||||||
| from authentik.providers.rac.api.property_mappings import RACPropertyMappingViewSet |  | ||||||
| from authentik.providers.rac.api.providers import RACProviderViewSet |  | ||||||
| from authentik.providers.rac.consumer_client import RACClientConsumer |  | ||||||
| from authentik.providers.rac.consumer_outpost import RACOutpostConsumer |  | ||||||
| from authentik.providers.rac.views import RACInterface, RACStartView |  | ||||||
| from authentik.root.asgi_middleware import SessionMiddleware | from authentik.root.asgi_middleware import SessionMiddleware | ||||||
| from authentik.root.middleware import ChannelsLoggingMiddleware | from authentik.root.middleware import ChannelsLoggingMiddleware | ||||||
| 
 | 
 | ||||||
| @ -10,6 +10,8 @@ from django.utils.translation import gettext as _ | |||||||
| 
 | 
 | ||||||
| from authentik.core.models import Application, AuthenticatedSession | from authentik.core.models import Application, AuthenticatedSession | ||||||
| from authentik.core.views.interface import InterfaceView | from authentik.core.views.interface import InterfaceView | ||||||
|  | from authentik.enterprise.policy import EnterprisePolicyAccessView | ||||||
|  | from authentik.enterprise.providers.rac.models import ConnectionToken, Endpoint, RACProvider | ||||||
| from authentik.events.models import Event, EventAction | from authentik.events.models import Event, EventAction | ||||||
| from authentik.flows.challenge import RedirectChallenge | from authentik.flows.challenge import RedirectChallenge | ||||||
| from authentik.flows.exceptions import FlowNonApplicableException | from authentik.flows.exceptions import FlowNonApplicableException | ||||||
| @ -18,11 +20,9 @@ from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, FlowPlanner | |||||||
| from authentik.flows.stage import RedirectStage | from authentik.flows.stage import RedirectStage | ||||||
| from authentik.lib.utils.time import timedelta_from_string | from authentik.lib.utils.time import timedelta_from_string | ||||||
| from authentik.policies.engine import PolicyEngine | from authentik.policies.engine import PolicyEngine | ||||||
| from authentik.policies.views import PolicyAccessView |  | ||||||
| from authentik.providers.rac.models import ConnectionToken, Endpoint, RACProvider |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class RACStartView(PolicyAccessView): | class RACStartView(EnterprisePolicyAccessView): | ||||||
|     """Start a RAC connection by checking access and creating a connection token""" |     """Start a RAC connection by checking access and creating a connection token""" | ||||||
| 
 | 
 | ||||||
|     endpoint: Endpoint |     endpoint: Endpoint | ||||||
| @ -46,7 +46,7 @@ class RACStartView(PolicyAccessView): | |||||||
|             ) |             ) | ||||||
|         except FlowNonApplicableException: |         except FlowNonApplicableException: | ||||||
|             raise Http404 from None |             raise Http404 from None | ||||||
|         plan.append_stage( |         plan.insert_stage( | ||||||
|             in_memory_stage( |             in_memory_stage( | ||||||
|                 RACFinalStage, |                 RACFinalStage, | ||||||
|                 application=self.application, |                 application=self.application, | ||||||
| @ -16,6 +16,7 @@ TENANT_APPS = [ | |||||||
|     "authentik.enterprise.audit", |     "authentik.enterprise.audit", | ||||||
|     "authentik.enterprise.providers.google_workspace", |     "authentik.enterprise.providers.google_workspace", | ||||||
|     "authentik.enterprise.providers.microsoft_entra", |     "authentik.enterprise.providers.microsoft_entra", | ||||||
|  |     "authentik.enterprise.providers.rac", | ||||||
|     "authentik.enterprise.providers.ssf", |     "authentik.enterprise.providers.ssf", | ||||||
|     "authentik.enterprise.stages.authenticator_endpoint_gdtc", |     "authentik.enterprise.stages.authenticator_endpoint_gdtc", | ||||||
|     "authentik.enterprise.stages.source", |     "authentik.enterprise.stages.source", | ||||||
|  | |||||||
| @ -9,16 +9,13 @@ from django.utils.timezone import now | |||||||
| from guardian.shortcuts import get_anonymous_user | from guardian.shortcuts import get_anonymous_user | ||||||
|  |  | ||||||
| from authentik.core.models import Source, User | from authentik.core.models import Source, User | ||||||
| from authentik.core.sources.flow_manager import ( | from authentik.core.sources.flow_manager import SESSION_KEY_OVERRIDE_FLOW_TOKEN | ||||||
|     SESSION_KEY_OVERRIDE_FLOW_TOKEN, |  | ||||||
|     SESSION_KEY_SOURCE_FLOW_STAGES, |  | ||||||
| ) |  | ||||||
| from authentik.core.types import UILoginButton | from authentik.core.types import UILoginButton | ||||||
| from authentik.enterprise.stages.source.models import SourceStage | from authentik.enterprise.stages.source.models import SourceStage | ||||||
| from authentik.flows.challenge import Challenge, ChallengeResponse | from authentik.flows.challenge import Challenge, ChallengeResponse | ||||||
| from authentik.flows.models import FlowToken, in_memory_stage | from authentik.flows.models import FlowToken | ||||||
| from authentik.flows.planner import PLAN_CONTEXT_IS_RESTORED | from authentik.flows.planner import PLAN_CONTEXT_IS_RESTORED | ||||||
| from authentik.flows.stage import ChallengeStageView, StageView | from authentik.flows.stage import ChallengeStageView | ||||||
| from authentik.lib.utils.time import timedelta_from_string | from authentik.lib.utils.time import timedelta_from_string | ||||||
|  |  | ||||||
| PLAN_CONTEXT_RESUME_TOKEN = "resume_token"  # nosec | PLAN_CONTEXT_RESUME_TOKEN = "resume_token"  # nosec | ||||||
| @ -52,7 +49,6 @@ class SourceStageView(ChallengeStageView): | |||||||
|     def get_challenge(self, *args, **kwargs) -> Challenge: |     def get_challenge(self, *args, **kwargs) -> Challenge: | ||||||
|         resume_token = self.create_flow_token() |         resume_token = self.create_flow_token() | ||||||
|         self.request.session[SESSION_KEY_OVERRIDE_FLOW_TOKEN] = resume_token |         self.request.session[SESSION_KEY_OVERRIDE_FLOW_TOKEN] = resume_token | ||||||
|         self.request.session[SESSION_KEY_SOURCE_FLOW_STAGES] = [in_memory_stage(SourceStageFinal)] |  | ||||||
|         return self.login_button.challenge |         return self.login_button.challenge | ||||||
|  |  | ||||||
|     def create_flow_token(self) -> FlowToken: |     def create_flow_token(self) -> FlowToken: | ||||||
| @ -81,19 +77,3 @@ class SourceStageView(ChallengeStageView): | |||||||
|  |  | ||||||
|     def challenge_valid(self, response: ChallengeResponse) -> HttpResponse: |     def challenge_valid(self, response: ChallengeResponse) -> HttpResponse: | ||||||
|         return self.executor.stage_ok() |         return self.executor.stage_ok() | ||||||
|  |  | ||||||
|  |  | ||||||
| class SourceStageFinal(StageView): |  | ||||||
|     """Dynamic stage injected in the source flow manager. This is injected in the |  | ||||||
|     flow the source flow manager picks (authentication or enrollment), and will run at the end. |  | ||||||
|     This stage uses the override flow token to resume execution of the initial flow the |  | ||||||
|     source stage is bound to.""" |  | ||||||
|  |  | ||||||
|     def dispatch(self, *args, **kwargs): |  | ||||||
|         token: FlowToken = self.request.session.get(SESSION_KEY_OVERRIDE_FLOW_TOKEN) |  | ||||||
|         self.logger.info("Replacing source flow with overridden flow", flow=token.flow.slug) |  | ||||||
|         plan = token.plan |  | ||||||
|         plan.context[PLAN_CONTEXT_IS_RESTORED] = token |  | ||||||
|         response = plan.to_redirect(self.request, token.flow) |  | ||||||
|         token.delete() |  | ||||||
|         return response |  | ||||||
|  | |||||||
| @ -4,8 +4,7 @@ from django.urls import reverse | |||||||
|  |  | ||||||
| from authentik.core.tests.utils import create_test_flow, create_test_user | from authentik.core.tests.utils import create_test_flow, create_test_user | ||||||
| from authentik.enterprise.stages.source.models import SourceStage | from authentik.enterprise.stages.source.models import SourceStage | ||||||
| from authentik.enterprise.stages.source.stage import SourceStageFinal | from authentik.flows.models import FlowDesignation, FlowStageBinding, FlowToken | ||||||
| from authentik.flows.models import FlowDesignation, FlowStageBinding, FlowToken, in_memory_stage |  | ||||||
| from authentik.flows.planner import PLAN_CONTEXT_IS_RESTORED, FlowPlan | from authentik.flows.planner import PLAN_CONTEXT_IS_RESTORED, FlowPlan | ||||||
| from authentik.flows.tests import FlowTestCase | from authentik.flows.tests import FlowTestCase | ||||||
| from authentik.flows.views.executor import SESSION_KEY_PLAN | from authentik.flows.views.executor import SESSION_KEY_PLAN | ||||||
| @ -88,7 +87,6 @@ class TestSourceStage(FlowTestCase): | |||||||
|         self.assertIsNotNone(flow_token) |         self.assertIsNotNone(flow_token) | ||||||
|         session = self.client.session |         session = self.client.session | ||||||
|         plan: FlowPlan = session[SESSION_KEY_PLAN] |         plan: FlowPlan = session[SESSION_KEY_PLAN] | ||||||
|         plan.insert_stage(in_memory_stage(SourceStageFinal), index=0) |  | ||||||
|         plan.context[PLAN_CONTEXT_IS_RESTORED] = flow_token |         plan.context[PLAN_CONTEXT_IS_RESTORED] = flow_token | ||||||
|         session[SESSION_KEY_PLAN] = plan |         session[SESSION_KEY_PLAN] = plan | ||||||
|         session.save() |         session.save() | ||||||
| @ -98,6 +96,4 @@ class TestSourceStage(FlowTestCase): | |||||||
|             reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), follow=True |             reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), follow=True | ||||||
|         ) |         ) | ||||||
|         self.assertEqual(response.status_code, 200) |         self.assertEqual(response.status_code, 200) | ||||||
|         self.assertStageRedirects( |         self.assertStageRedirects(response, reverse("authentik_core:root-redirect")) | ||||||
|             response, reverse("authentik_core:if-flow", kwargs={"flow_slug": flow.slug}) |  | ||||||
|         ) |  | ||||||
|  | |||||||
| @ -76,10 +76,10 @@ class FlowPlan: | |||||||
|         self.bindings.append(binding) |         self.bindings.append(binding) | ||||||
|         self.markers.append(marker or StageMarker()) |         self.markers.append(marker or StageMarker()) | ||||||
|  |  | ||||||
|     def insert_stage(self, stage: Stage, marker: StageMarker | None = None, index=1): |     def insert_stage(self, stage: Stage, marker: StageMarker | None = None): | ||||||
|         """Insert stage into plan, as immediate next stage""" |         """Insert stage into plan, as immediate next stage""" | ||||||
|         self.bindings.insert(index, FlowStageBinding(stage=stage, order=0)) |         self.bindings.insert(1, FlowStageBinding(stage=stage, order=0)) | ||||||
|         self.markers.insert(index, marker or StageMarker()) |         self.markers.insert(1, marker or StageMarker()) | ||||||
|  |  | ||||||
|     def redirect(self, destination: str): |     def redirect(self, destination: str): | ||||||
|         """Insert a redirect stage as next stage""" |         """Insert a redirect stage as next stage""" | ||||||
|  | |||||||
| @ -282,14 +282,16 @@ class ConfigLoader: | |||||||
|  |  | ||||||
|     def get_optional_int(self, path: str, default=None) -> int | None: |     def get_optional_int(self, path: str, default=None) -> int | None: | ||||||
|         """Wrapper for get that converts value into int or None if set""" |         """Wrapper for get that converts value into int or None if set""" | ||||||
|         value = self.get(path, UNSET) |         value = self.get(path, default) | ||||||
|         if value is UNSET: |         if value is UNSET: | ||||||
|             return default |             return default | ||||||
|         try: |         try: | ||||||
|             return int(value) |             return int(value) | ||||||
|         except (ValueError, TypeError) as exc: |         except (ValueError, TypeError) as exc: | ||||||
|             if value is None or (isinstance(value, str) and value.lower() == "null"): |             if value is None or (isinstance(value, str) and value.lower() == "null"): | ||||||
|                 return None |                 return default | ||||||
|  |             if value is UNSET: | ||||||
|  |                 return default | ||||||
|             self.log("warning", "Failed to parse config as int", path=path, exc=str(exc)) |             self.log("warning", "Failed to parse config as int", path=path, exc=str(exc)) | ||||||
|             return default |             return default | ||||||
|  |  | ||||||
| @ -370,9 +372,9 @@ def django_db_config(config: ConfigLoader | None = None) -> dict: | |||||||
|                 "sslcert": config.get("postgresql.sslcert"), |                 "sslcert": config.get("postgresql.sslcert"), | ||||||
|                 "sslkey": config.get("postgresql.sslkey"), |                 "sslkey": config.get("postgresql.sslkey"), | ||||||
|             }, |             }, | ||||||
|             "CONN_MAX_AGE": config.get_optional_int("postgresql.conn_max_age", 0), |             "CONN_MAX_AGE": CONFIG.get_optional_int("postgresql.conn_max_age", 0), | ||||||
|             "CONN_HEALTH_CHECKS": config.get_bool("postgresql.conn_health_checks", False), |             "CONN_HEALTH_CHECKS": CONFIG.get_bool("postgresql.conn_health_checks", False), | ||||||
|             "DISABLE_SERVER_SIDE_CURSORS": config.get_bool( |             "DISABLE_SERVER_SIDE_CURSORS": CONFIG.get_bool( | ||||||
|                 "postgresql.disable_server_side_cursors", False |                 "postgresql.disable_server_side_cursors", False | ||||||
|             ), |             ), | ||||||
|             "TEST": { |             "TEST": { | ||||||
| @ -381,8 +383,8 @@ def django_db_config(config: ConfigLoader | None = None) -> dict: | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     conn_max_age = config.get_optional_int("postgresql.conn_max_age", UNSET) |     conn_max_age = CONFIG.get_optional_int("postgresql.conn_max_age", UNSET) | ||||||
|     disable_server_side_cursors = config.get_bool("postgresql.disable_server_side_cursors", UNSET) |     disable_server_side_cursors = CONFIG.get_bool("postgresql.disable_server_side_cursors", UNSET) | ||||||
|     if config.get_bool("postgresql.use_pgpool", False): |     if config.get_bool("postgresql.use_pgpool", False): | ||||||
|         db["default"]["DISABLE_SERVER_SIDE_CURSORS"] = True |         db["default"]["DISABLE_SERVER_SIDE_CURSORS"] = True | ||||||
|         if disable_server_side_cursors is not UNSET: |         if disable_server_side_cursors is not UNSET: | ||||||
|  | |||||||
| @ -158,18 +158,6 @@ class TestConfig(TestCase): | |||||||
|             test_obj = Test() |             test_obj = Test() | ||||||
|             dumps(test_obj, indent=4, cls=AttrEncoder) |             dumps(test_obj, indent=4, cls=AttrEncoder) | ||||||
|  |  | ||||||
|     def test_get_optional_int(self): |  | ||||||
|         config = ConfigLoader() |  | ||||||
|         self.assertEqual(config.get_optional_int("foo", 21), 21) |  | ||||||
|         self.assertEqual(config.get_optional_int("foo"), None) |  | ||||||
|         config.set("foo", "21") |  | ||||||
|         self.assertEqual(config.get_optional_int("foo"), 21) |  | ||||||
|         self.assertEqual(config.get_optional_int("foo", 0), 21) |  | ||||||
|         self.assertEqual(config.get_optional_int("foo", "null"), 21) |  | ||||||
|         config.set("foo", "null") |  | ||||||
|         self.assertEqual(config.get_optional_int("foo"), None) |  | ||||||
|         self.assertEqual(config.get_optional_int("foo", 21), None) |  | ||||||
|  |  | ||||||
|     @mock.patch.dict(environ, check_deprecations_env_vars) |     @mock.patch.dict(environ, check_deprecations_env_vars) | ||||||
|     def test_check_deprecations(self): |     def test_check_deprecations(self): | ||||||
|         """Test config key re-write for deprecated env vars""" |         """Test config key re-write for deprecated env vars""" | ||||||
| @ -233,16 +221,6 @@ class TestConfig(TestCase): | |||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def test_db_conn_max_age(self): |  | ||||||
|         """Test DB conn_max_age Config""" |  | ||||||
|         config = ConfigLoader() |  | ||||||
|         config.set("postgresql.conn_max_age", "null") |  | ||||||
|         conf = django_db_config(config) |  | ||||||
|         self.assertEqual( |  | ||||||
|             conf["default"]["CONN_MAX_AGE"], |  | ||||||
|             None, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def test_db_read_replicas(self): |     def test_db_read_replicas(self): | ||||||
|         """Test read replicas""" |         """Test read replicas""" | ||||||
|         config = ConfigLoader() |         config = ConfigLoader() | ||||||
|  | |||||||
| @ -1,54 +0,0 @@ | |||||||
| """Email utility functions""" |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def mask_email(email: str | None) -> str | None: |  | ||||||
|     """Mask email address for privacy |  | ||||||
|  |  | ||||||
|     Args: |  | ||||||
|         email: Email address to mask |  | ||||||
|     Returns: |  | ||||||
|         Masked email address or None if input is None |  | ||||||
|     Example: |  | ||||||
|         mask_email("myname@company.org") |  | ||||||
|         'm*****@c******.org' |  | ||||||
|     """ |  | ||||||
|     if not email: |  | ||||||
|         return None |  | ||||||
|  |  | ||||||
|     # Basic email format validation |  | ||||||
|     if email.count("@") != 1: |  | ||||||
|         raise ValueError("Invalid email format: Must contain exactly one '@' symbol") |  | ||||||
|  |  | ||||||
|     local, domain = email.split("@") |  | ||||||
|     if not local or not domain: |  | ||||||
|         raise ValueError("Invalid email format: Local and domain parts cannot be empty") |  | ||||||
|  |  | ||||||
|     domain_parts = domain.split(".") |  | ||||||
|     if len(domain_parts) < 2:  # noqa: PLR2004 |  | ||||||
|         raise ValueError("Invalid email format: Domain must contain at least one dot") |  | ||||||
|  |  | ||||||
|     limit = 2 |  | ||||||
|  |  | ||||||
|     # Mask local part (keep first char) |  | ||||||
|     if len(local) <= limit: |  | ||||||
|         masked_local = "*" * len(local) |  | ||||||
|     else: |  | ||||||
|         masked_local = local[0] + "*" * (len(local) - 1) |  | ||||||
|  |  | ||||||
|     # Mask each domain part except the last one (TLD) |  | ||||||
|     masked_domain_parts = [] |  | ||||||
|     for _i, part in enumerate(domain_parts[:-1]):  # Process all parts except TLD |  | ||||||
|         if not part:  # Check for empty parts (consecutive dots) |  | ||||||
|             raise ValueError("Invalid email format: Domain parts cannot be empty") |  | ||||||
|         if len(part) <= limit: |  | ||||||
|             masked_part = "*" * len(part) |  | ||||||
|         else: |  | ||||||
|             masked_part = part[0] + "*" * (len(part) - 1) |  | ||||||
|         masked_domain_parts.append(masked_part) |  | ||||||
|  |  | ||||||
|     # Add TLD unchanged |  | ||||||
|     if not domain_parts[-1]:  # Check if TLD is empty |  | ||||||
|         raise ValueError("Invalid email format: TLD cannot be empty") |  | ||||||
|     masked_domain_parts.append(domain_parts[-1]) |  | ||||||
|  |  | ||||||
|     return f"{masked_local}@{'.'.join(masked_domain_parts)}" |  | ||||||
| @ -19,6 +19,7 @@ from authentik.core.api.used_by import UsedByMixin | |||||||
| from authentik.core.api.utils import JSONDictField, ModelSerializer, PassiveSerializer | from authentik.core.api.utils import JSONDictField, ModelSerializer, PassiveSerializer | ||||||
| from authentik.core.models import Provider | from authentik.core.models import Provider | ||||||
| from authentik.enterprise.license import LicenseKey | from authentik.enterprise.license import LicenseKey | ||||||
|  | from authentik.enterprise.providers.rac.models import RACProvider | ||||||
| from authentik.lib.utils.time import timedelta_from_string, timedelta_string_validator | from authentik.lib.utils.time import timedelta_from_string, timedelta_string_validator | ||||||
| from authentik.outposts.api.service_connections import ServiceConnectionSerializer | from authentik.outposts.api.service_connections import ServiceConnectionSerializer | ||||||
| from authentik.outposts.apps import MANAGED_OUTPOST, MANAGED_OUTPOST_NAME | from authentik.outposts.apps import MANAGED_OUTPOST, MANAGED_OUTPOST_NAME | ||||||
| @ -30,7 +31,6 @@ from authentik.outposts.models import ( | |||||||
| ) | ) | ||||||
| from authentik.providers.ldap.models import LDAPProvider | from authentik.providers.ldap.models import LDAPProvider | ||||||
| from authentik.providers.proxy.models import ProxyProvider | from authentik.providers.proxy.models import ProxyProvider | ||||||
| from authentik.providers.rac.models import RACProvider |  | ||||||
| from authentik.providers.radius.models import RadiusProvider | from authentik.providers.radius.models import RadiusProvider | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| @ -18,6 +18,8 @@ from kubernetes.config.kube_config import KUBE_CONFIG_DEFAULT_LOCATION | |||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
| from yaml import safe_load | from yaml import safe_load | ||||||
|  |  | ||||||
|  | from authentik.enterprise.providers.rac.controllers.docker import RACDockerController | ||||||
|  | from authentik.enterprise.providers.rac.controllers.kubernetes import RACKubernetesController | ||||||
| from authentik.events.models import TaskStatus | from authentik.events.models import TaskStatus | ||||||
| from authentik.events.system_tasks import SystemTask, prefill_task | from authentik.events.system_tasks import SystemTask, prefill_task | ||||||
| from authentik.lib.config import CONFIG | from authentik.lib.config import CONFIG | ||||||
| @ -39,8 +41,6 @@ from authentik.providers.ldap.controllers.docker import LDAPDockerController | |||||||
| from authentik.providers.ldap.controllers.kubernetes import LDAPKubernetesController | from authentik.providers.ldap.controllers.kubernetes import LDAPKubernetesController | ||||||
| from authentik.providers.proxy.controllers.docker import ProxyDockerController | from authentik.providers.proxy.controllers.docker import ProxyDockerController | ||||||
| from authentik.providers.proxy.controllers.kubernetes import ProxyKubernetesController | from authentik.providers.proxy.controllers.kubernetes import ProxyKubernetesController | ||||||
| from authentik.providers.rac.controllers.docker import RACDockerController |  | ||||||
| from authentik.providers.rac.controllers.kubernetes import RACKubernetesController |  | ||||||
| from authentik.providers.radius.controllers.docker import RadiusDockerController | from authentik.providers.radius.controllers.docker import RadiusDockerController | ||||||
| from authentik.providers.radius.controllers.kubernetes import RadiusKubernetesController | from authentik.providers.radius.controllers.kubernetes import RadiusKubernetesController | ||||||
| from authentik.root.celery import CELERY_APP | from authentik.root.celery import CELERY_APP | ||||||
|  | |||||||
| @ -1,11 +1,26 @@ | |||||||
| """Expression Policy API""" | """Expression Policy API""" | ||||||
|  |  | ||||||
|  | from drf_spectacular.utils import OpenApiResponse, extend_schema | ||||||
|  | from guardian.shortcuts import get_objects_for_user | ||||||
|  | from rest_framework.decorators import action | ||||||
|  | from rest_framework.fields import CharField | ||||||
|  | from rest_framework.request import Request | ||||||
|  | from rest_framework.response import Response | ||||||
| from rest_framework.viewsets import ModelViewSet | from rest_framework.viewsets import ModelViewSet | ||||||
|  | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
|  | from authentik.events.logs import LogEventSerializer, capture_logs | ||||||
|  | from authentik.policies.api.exec import PolicyTestResultSerializer, PolicyTestSerializer | ||||||
| from authentik.policies.api.policies import PolicySerializer | from authentik.policies.api.policies import PolicySerializer | ||||||
| from authentik.policies.expression.evaluator import PolicyEvaluator | from authentik.policies.expression.evaluator import PolicyEvaluator | ||||||
| from authentik.policies.expression.models import ExpressionPolicy | from authentik.policies.expression.models import ExpressionPolicy | ||||||
|  | from authentik.policies.models import PolicyBinding | ||||||
|  | from authentik.policies.process import PolicyProcess | ||||||
|  | from authentik.policies.types import PolicyRequest | ||||||
|  | from authentik.rbac.decorators import permission_required | ||||||
|  |  | ||||||
|  | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  |  | ||||||
| class ExpressionPolicySerializer(PolicySerializer): | class ExpressionPolicySerializer(PolicySerializer): | ||||||
| @ -30,3 +45,50 @@ class ExpressionPolicyViewSet(UsedByMixin, ModelViewSet): | |||||||
|     filterset_fields = "__all__" |     filterset_fields = "__all__" | ||||||
|     ordering = ["name"] |     ordering = ["name"] | ||||||
|     search_fields = ["name"] |     search_fields = ["name"] | ||||||
|  |  | ||||||
|  |     class ExpressionPolicyTestSerializer(PolicyTestSerializer): | ||||||
|  |         """Expression policy test serializer""" | ||||||
|  |  | ||||||
|  |         expression = CharField() | ||||||
|  |  | ||||||
|  |     @permission_required("authentik_policies.view_policy") | ||||||
|  |     @extend_schema( | ||||||
|  |         request=ExpressionPolicyTestSerializer(), | ||||||
|  |         responses={ | ||||||
|  |             200: PolicyTestResultSerializer(), | ||||||
|  |             400: OpenApiResponse(description="Invalid parameters"), | ||||||
|  |         }, | ||||||
|  |     ) | ||||||
|  |     @action(detail=True, pagination_class=None, filter_backends=[], methods=["POST"]) | ||||||
|  |     def test(self, request: Request, pk: str) -> Response: | ||||||
|  |         """Test policy""" | ||||||
|  |         policy = self.get_object() | ||||||
|  |         test_params = self.ExpressionPolicyTestSerializer(data=request.data) | ||||||
|  |         if not test_params.is_valid(): | ||||||
|  |             return Response(test_params.errors, status=400) | ||||||
|  |  | ||||||
|  |         # User permission check, only allow policy testing for users that are readable | ||||||
|  |         users = get_objects_for_user(request.user, "authentik_core.view_user").filter( | ||||||
|  |             pk=test_params.validated_data["user"].pk | ||||||
|  |         ) | ||||||
|  |         if not users.exists(): | ||||||
|  |             return Response(status=400) | ||||||
|  |  | ||||||
|  |         policy.expression = test_params.validated_data["expression"] | ||||||
|  |  | ||||||
|  |         p_request = PolicyRequest(users.first()) | ||||||
|  |         p_request.debug = True | ||||||
|  |         p_request.set_http_request(self.request) | ||||||
|  |         p_request.context = test_params.validated_data.get("context", {}) | ||||||
|  |  | ||||||
|  |         proc = PolicyProcess(PolicyBinding(policy=policy), p_request, None) | ||||||
|  |         with capture_logs() as logs: | ||||||
|  |             result = proc.execute() | ||||||
|  |         log_messages = [] | ||||||
|  |         for log in logs: | ||||||
|  |             if log.attributes.get("process", "") == "PolicyProcess": | ||||||
|  |                 continue | ||||||
|  |             log_messages.append(LogEventSerializer(log).data) | ||||||
|  |         result.log_messages = log_messages | ||||||
|  |         response = PolicyTestResultSerializer(result) | ||||||
|  |         return Response(response.data) | ||||||
|  | |||||||
| @ -42,12 +42,6 @@ class GeoIPPolicySerializer(CountryFieldMixin, PolicySerializer): | |||||||
|             "asns", |             "asns", | ||||||
|             "countries", |             "countries", | ||||||
|             "countries_obj", |             "countries_obj", | ||||||
|             "check_history_distance", |  | ||||||
|             "history_max_distance_km", |  | ||||||
|             "distance_tolerance_km", |  | ||||||
|             "history_login_count", |  | ||||||
|             "check_impossible_travel", |  | ||||||
|             "impossible_tolerance_km", |  | ||||||
|         ] |         ] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| @ -1,43 +0,0 @@ | |||||||
| # Generated by Django 5.0.10 on 2025-01-02 20:40 |  | ||||||
|  |  | ||||||
| from django.db import migrations, models |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Migration(migrations.Migration): |  | ||||||
|  |  | ||||||
|     dependencies = [ |  | ||||||
|         ("authentik_policies_geoip", "0001_initial"), |  | ||||||
|     ] |  | ||||||
|  |  | ||||||
|     operations = [ |  | ||||||
|         migrations.AddField( |  | ||||||
|             model_name="geoippolicy", |  | ||||||
|             name="check_history_distance", |  | ||||||
|             field=models.BooleanField(default=False), |  | ||||||
|         ), |  | ||||||
|         migrations.AddField( |  | ||||||
|             model_name="geoippolicy", |  | ||||||
|             name="check_impossible_travel", |  | ||||||
|             field=models.BooleanField(default=False), |  | ||||||
|         ), |  | ||||||
|         migrations.AddField( |  | ||||||
|             model_name="geoippolicy", |  | ||||||
|             name="distance_tolerance_km", |  | ||||||
|             field=models.PositiveIntegerField(default=50), |  | ||||||
|         ), |  | ||||||
|         migrations.AddField( |  | ||||||
|             model_name="geoippolicy", |  | ||||||
|             name="history_login_count", |  | ||||||
|             field=models.PositiveIntegerField(default=5), |  | ||||||
|         ), |  | ||||||
|         migrations.AddField( |  | ||||||
|             model_name="geoippolicy", |  | ||||||
|             name="history_max_distance_km", |  | ||||||
|             field=models.PositiveBigIntegerField(default=100), |  | ||||||
|         ), |  | ||||||
|         migrations.AddField( |  | ||||||
|             model_name="geoippolicy", |  | ||||||
|             name="impossible_tolerance_km", |  | ||||||
|             field=models.PositiveIntegerField(default=100), |  | ||||||
|         ), |  | ||||||
|     ] |  | ||||||
| @ -4,21 +4,15 @@ from itertools import chain | |||||||
|  |  | ||||||
| from django.contrib.postgres.fields import ArrayField | from django.contrib.postgres.fields import ArrayField | ||||||
| from django.db import models | from django.db import models | ||||||
| from django.utils.timezone import now |  | ||||||
| from django.utils.translation import gettext as _ | from django.utils.translation import gettext as _ | ||||||
| from django_countries.fields import CountryField | from django_countries.fields import CountryField | ||||||
| from geopy import distance |  | ||||||
| from rest_framework.serializers import BaseSerializer | from rest_framework.serializers import BaseSerializer | ||||||
|  |  | ||||||
| from authentik.events.context_processors.geoip import GeoIPDict |  | ||||||
| from authentik.events.models import Event, EventAction |  | ||||||
| from authentik.policies.exceptions import PolicyException | from authentik.policies.exceptions import PolicyException | ||||||
| from authentik.policies.geoip.exceptions import GeoIPNotFoundException | from authentik.policies.geoip.exceptions import GeoIPNotFoundException | ||||||
| from authentik.policies.models import Policy | from authentik.policies.models import Policy | ||||||
| from authentik.policies.types import PolicyRequest, PolicyResult | from authentik.policies.types import PolicyRequest, PolicyResult | ||||||
|  |  | ||||||
| MAX_DISTANCE_HOUR_KM = 1000 |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class GeoIPPolicy(Policy): | class GeoIPPolicy(Policy): | ||||||
|     """Ensure the user satisfies requirements of geography or network topology, based on IP |     """Ensure the user satisfies requirements of geography or network topology, based on IP | ||||||
| @ -27,15 +21,6 @@ class GeoIPPolicy(Policy): | |||||||
|     asns = ArrayField(models.IntegerField(), blank=True, default=list) |     asns = ArrayField(models.IntegerField(), blank=True, default=list) | ||||||
|     countries = CountryField(multiple=True, blank=True) |     countries = CountryField(multiple=True, blank=True) | ||||||
|  |  | ||||||
|     distance_tolerance_km = models.PositiveIntegerField(default=50) |  | ||||||
|  |  | ||||||
|     check_history_distance = models.BooleanField(default=False) |  | ||||||
|     history_max_distance_km = models.PositiveBigIntegerField(default=100) |  | ||||||
|     history_login_count = models.PositiveIntegerField(default=5) |  | ||||||
|  |  | ||||||
|     check_impossible_travel = models.BooleanField(default=False) |  | ||||||
|     impossible_tolerance_km = models.PositiveIntegerField(default=100) |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def serializer(self) -> type[BaseSerializer]: |     def serializer(self) -> type[BaseSerializer]: | ||||||
|         from authentik.policies.geoip.api import GeoIPPolicySerializer |         from authentik.policies.geoip.api import GeoIPPolicySerializer | ||||||
| @ -52,27 +37,21 @@ class GeoIPPolicy(Policy): | |||||||
|         - the client IP is advertised by an autonomous system with ASN in the `asns` |         - the client IP is advertised by an autonomous system with ASN in the `asns` | ||||||
|         - the client IP is geolocated in a country of `countries` |         - the client IP is geolocated in a country of `countries` | ||||||
|         """ |         """ | ||||||
|         static_results: list[PolicyResult] = [] |         results: list[PolicyResult] = [] | ||||||
|         dynamic_results: list[PolicyResult] = [] |  | ||||||
|  |  | ||||||
|         if self.asns: |         if self.asns: | ||||||
|             static_results.append(self.passes_asn(request)) |             results.append(self.passes_asn(request)) | ||||||
|         if self.countries: |         if self.countries: | ||||||
|             static_results.append(self.passes_country(request)) |             results.append(self.passes_country(request)) | ||||||
|  |  | ||||||
|         if self.check_history_distance or self.check_impossible_travel: |         if not results: | ||||||
|             dynamic_results.append(self.passes_distance(request)) |  | ||||||
|  |  | ||||||
|         if not static_results and not dynamic_results: |  | ||||||
|             return PolicyResult(True) |             return PolicyResult(True) | ||||||
|  |  | ||||||
|         passing = any(r.passing for r in static_results) and all(r.passing for r in dynamic_results) |         passing = any(r.passing for r in results) | ||||||
|         messages = chain( |         messages = chain(*[r.messages for r in results]) | ||||||
|             *[r.messages for r in static_results], *[r.messages for r in dynamic_results] |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         result = PolicyResult(passing, *messages) |         result = PolicyResult(passing, *messages) | ||||||
|         result.source_results = list(chain(static_results, dynamic_results)) |         result.source_results = results | ||||||
|  |  | ||||||
|         return result |         return result | ||||||
|  |  | ||||||
| @ -94,7 +73,7 @@ class GeoIPPolicy(Policy): | |||||||
|  |  | ||||||
|     def passes_country(self, request: PolicyRequest) -> PolicyResult: |     def passes_country(self, request: PolicyRequest) -> PolicyResult: | ||||||
|         # This is not a single get chain because `request.context` can contain `{ "geoip": None }`. |         # This is not a single get chain because `request.context` can contain `{ "geoip": None }`. | ||||||
|         geoip_data: GeoIPDict | None = request.context.get("geoip") |         geoip_data = request.context.get("geoip") | ||||||
|         country = geoip_data.get("country") if geoip_data else None |         country = geoip_data.get("country") if geoip_data else None | ||||||
|  |  | ||||||
|         if not country: |         if not country: | ||||||
| @ -108,42 +87,6 @@ class GeoIPPolicy(Policy): | |||||||
|  |  | ||||||
|         return PolicyResult(True) |         return PolicyResult(True) | ||||||
|  |  | ||||||
|     def passes_distance(self, request: PolicyRequest) -> PolicyResult: |  | ||||||
|         """Check if current policy execution is out of distance range compared |  | ||||||
|         to previous authentication requests""" |  | ||||||
|         # Get previous login event and GeoIP data |  | ||||||
|         previous_logins = Event.objects.filter( |  | ||||||
|             action=EventAction.LOGIN, user__pk=request.user.pk, context__geo__isnull=False |  | ||||||
|         ).order_by("-created")[: self.history_login_count] |  | ||||||
|         _now = now() |  | ||||||
|         geoip_data: GeoIPDict | None = request.context.get("geoip") |  | ||||||
|         if not geoip_data: |  | ||||||
|             return PolicyResult(False) |  | ||||||
|         for previous_login in previous_logins: |  | ||||||
|             previous_login_geoip: GeoIPDict = previous_login.context["geo"] |  | ||||||
|  |  | ||||||
|             # Figure out distance |  | ||||||
|             dist = distance.geodesic( |  | ||||||
|                 (previous_login_geoip["lat"], previous_login_geoip["long"]), |  | ||||||
|                 (geoip_data["lat"], geoip_data["long"]), |  | ||||||
|             ) |  | ||||||
|             if self.check_history_distance and dist.km >= ( |  | ||||||
|                 self.history_max_distance_km + self.distance_tolerance_km |  | ||||||
|             ): |  | ||||||
|                 return PolicyResult( |  | ||||||
|                     False, _("Distance from previous authentication is larger than threshold.") |  | ||||||
|                 ) |  | ||||||
|             # Check if distance between `previous_login` and now is more |  | ||||||
|             # than max distance per hour times the amount of hours since the previous login |  | ||||||
|             # (round down to the lowest closest time of hours) |  | ||||||
|             # clamped to be at least 1 hour |  | ||||||
|             rel_time_hours = max(int((_now - previous_login.created).total_seconds() / 3600), 1) |  | ||||||
|             if self.check_impossible_travel and dist.km >= ( |  | ||||||
|                 (MAX_DISTANCE_HOUR_KM * rel_time_hours) + self.distance_tolerance_km |  | ||||||
|             ): |  | ||||||
|                 return PolicyResult(False, _("Distance is further than possible.")) |  | ||||||
|         return PolicyResult(True) |  | ||||||
|  |  | ||||||
|     class Meta(Policy.PolicyMeta): |     class Meta(Policy.PolicyMeta): | ||||||
|         verbose_name = _("GeoIP Policy") |         verbose_name = _("GeoIP Policy") | ||||||
|         verbose_name_plural = _("GeoIP Policies") |         verbose_name_plural = _("GeoIP Policies") | ||||||
|  | |||||||
| @ -1,10 +1,8 @@ | |||||||
| """geoip policy tests""" | """geoip policy tests""" | ||||||
|  |  | ||||||
| from django.test import TestCase | from django.test import TestCase | ||||||
|  | from guardian.shortcuts import get_anonymous_user | ||||||
|  |  | ||||||
| from authentik.core.tests.utils import create_test_user |  | ||||||
| from authentik.events.models import Event, EventAction |  | ||||||
| from authentik.events.utils import get_user |  | ||||||
| from authentik.policies.engine import PolicyRequest, PolicyResult | from authentik.policies.engine import PolicyRequest, PolicyResult | ||||||
| from authentik.policies.exceptions import PolicyException | from authentik.policies.exceptions import PolicyException | ||||||
| from authentik.policies.geoip.exceptions import GeoIPNotFoundException | from authentik.policies.geoip.exceptions import GeoIPNotFoundException | ||||||
| @ -16,8 +14,8 @@ class TestGeoIPPolicy(TestCase): | |||||||
|  |  | ||||||
|     def setUp(self): |     def setUp(self): | ||||||
|         super().setUp() |         super().setUp() | ||||||
|         self.user = create_test_user() |  | ||||||
|         self.request = PolicyRequest(self.user) |         self.request = PolicyRequest(get_anonymous_user()) | ||||||
|  |  | ||||||
|         self.context_disabled_geoip = {} |         self.context_disabled_geoip = {} | ||||||
|         self.context_unknown_ip = {"asn": None, "geoip": None} |         self.context_unknown_ip = {"asn": None, "geoip": None} | ||||||
| @ -128,70 +126,3 @@ class TestGeoIPPolicy(TestCase): | |||||||
|         result: PolicyResult = policy.passes(self.request) |         result: PolicyResult = policy.passes(self.request) | ||||||
|  |  | ||||||
|         self.assertTrue(result.passing) |         self.assertTrue(result.passing) | ||||||
|  |  | ||||||
|     def test_history(self): |  | ||||||
|         """Test history checks""" |  | ||||||
|         Event.objects.create( |  | ||||||
|             action=EventAction.LOGIN, |  | ||||||
|             user=get_user(self.user), |  | ||||||
|             context={ |  | ||||||
|                 # Random location in Canada |  | ||||||
|                 "geo": {"lat": 55.868351, "long": -104.441011}, |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|         # Random location in Poland |  | ||||||
|         self.request.context["geoip"] = {"lat": 50.950613, "long": 20.363679} |  | ||||||
|  |  | ||||||
|         policy = GeoIPPolicy.objects.create(check_history_distance=True) |  | ||||||
|  |  | ||||||
|         result: PolicyResult = policy.passes(self.request) |  | ||||||
|         self.assertFalse(result.passing) |  | ||||||
|  |  | ||||||
|     def test_history_no_data(self): |  | ||||||
|         """Test history checks (with no geoip data in context)""" |  | ||||||
|         Event.objects.create( |  | ||||||
|             action=EventAction.LOGIN, |  | ||||||
|             user=get_user(self.user), |  | ||||||
|             context={ |  | ||||||
|                 # Random location in Canada |  | ||||||
|                 "geo": {"lat": 55.868351, "long": -104.441011}, |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         policy = GeoIPPolicy.objects.create(check_history_distance=True) |  | ||||||
|  |  | ||||||
|         result: PolicyResult = policy.passes(self.request) |  | ||||||
|         self.assertFalse(result.passing) |  | ||||||
|  |  | ||||||
|     def test_history_impossible_travel(self): |  | ||||||
|         """Test history checks""" |  | ||||||
|         Event.objects.create( |  | ||||||
|             action=EventAction.LOGIN, |  | ||||||
|             user=get_user(self.user), |  | ||||||
|             context={ |  | ||||||
|                 # Random location in Canada |  | ||||||
|                 "geo": {"lat": 55.868351, "long": -104.441011}, |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|         # Random location in Poland |  | ||||||
|         self.request.context["geoip"] = {"lat": 50.950613, "long": 20.363679} |  | ||||||
|  |  | ||||||
|         policy = GeoIPPolicy.objects.create(check_impossible_travel=True) |  | ||||||
|  |  | ||||||
|         result: PolicyResult = policy.passes(self.request) |  | ||||||
|         self.assertFalse(result.passing) |  | ||||||
|  |  | ||||||
|     def test_history_no_geoip(self): |  | ||||||
|         """Test history checks (previous login with no geoip data)""" |  | ||||||
|         Event.objects.create( |  | ||||||
|             action=EventAction.LOGIN, |  | ||||||
|             user=get_user(self.user), |  | ||||||
|             context={}, |  | ||||||
|         ) |  | ||||||
|         # Random location in Poland |  | ||||||
|         self.request.context["geoip"] = {"lat": 50.950613, "long": 20.363679} |  | ||||||
|  |  | ||||||
|         policy = GeoIPPolicy.objects.create(check_history_distance=True) |  | ||||||
|  |  | ||||||
|         result: PolicyResult = policy.passes(self.request) |  | ||||||
|         self.assertFalse(result.passing) |  | ||||||
|  | |||||||
| @ -148,10 +148,10 @@ class PasswordPolicy(Policy): | |||||||
|             user_inputs.append(request.user.email) |             user_inputs.append(request.user.email) | ||||||
|         if request.http_request: |         if request.http_request: | ||||||
|             user_inputs.append(request.http_request.brand.branding_title) |             user_inputs.append(request.http_request.brand.branding_title) | ||||||
|         # Only calculate result for the first 72 characters, as with over 100 char |         # Only calculate result for the first 100 characters, as with over 100 char | ||||||
|         # long passwords we can be reasonably sure that they'll surpass the score anyways |         # long passwords we can be reasonably sure that they'll surpass the score anyways | ||||||
|         # See https://github.com/dropbox/zxcvbn#runtime-latency |         # See https://github.com/dropbox/zxcvbn#runtime-latency | ||||||
|         results = zxcvbn(password[:72], user_inputs) |         results = zxcvbn(password[:100], user_inputs) | ||||||
|         LOGGER.debug("password failed", check="zxcvbn", score=results["score"]) |         LOGGER.debug("password failed", check="zxcvbn", score=results["score"]) | ||||||
|         result = PolicyResult(results["score"] > self.zxcvbn_score_threshold) |         result = PolicyResult(results["score"] > self.zxcvbn_score_threshold) | ||||||
|         if not result.passing: |         if not result.passing: | ||||||
|  | |||||||
| @ -71,7 +71,7 @@ class CodeValidatorView(PolicyAccessView): | |||||||
|         except FlowNonApplicableException: |         except FlowNonApplicableException: | ||||||
|             LOGGER.warning("Flow not applicable to user") |             LOGGER.warning("Flow not applicable to user") | ||||||
|             return None |             return None | ||||||
|         plan.append_stage(in_memory_stage(OAuthDeviceCodeFinishStage)) |         plan.insert_stage(in_memory_stage(OAuthDeviceCodeFinishStage)) | ||||||
|         return plan.to_redirect(self.request, self.token.provider.authorization_flow) |         return plan.to_redirect(self.request, self.token.provider.authorization_flow) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| @ -34,5 +34,5 @@ class EndSessionView(PolicyAccessView): | |||||||
|                 PLAN_CONTEXT_APPLICATION: self.application, |                 PLAN_CONTEXT_APPLICATION: self.application, | ||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
|         plan.append_stage(in_memory_stage(SessionEndStage)) |         plan.insert_stage(in_memory_stage(SessionEndStage)) | ||||||
|         return plan.to_redirect(self.request, self.flow) |         return plan.to_redirect(self.request, self.flow) | ||||||
|  | |||||||
| @ -36,17 +36,17 @@ class IngressReconciler(KubernetesObjectReconciler[V1Ingress]): | |||||||
|     def reconciler_name() -> str: |     def reconciler_name() -> str: | ||||||
|         return "ingress" |         return "ingress" | ||||||
|  |  | ||||||
|     def _check_annotations(self, current: V1Ingress, reference: V1Ingress): |     def _check_annotations(self, reference: V1Ingress): | ||||||
|         """Check that all annotations *we* set are correct""" |         """Check that all annotations *we* set are correct""" | ||||||
|         for key, value in reference.metadata.annotations.items(): |         for key, value in self.get_ingress_annotations().items(): | ||||||
|             if key not in current.metadata.annotations: |             if key not in reference.metadata.annotations: | ||||||
|                 raise NeedsUpdate() |                 raise NeedsUpdate() | ||||||
|             if current.metadata.annotations[key] != value: |             if reference.metadata.annotations[key] != value: | ||||||
|                 raise NeedsUpdate() |                 raise NeedsUpdate() | ||||||
|  |  | ||||||
|     def reconcile(self, current: V1Ingress, reference: V1Ingress): |     def reconcile(self, current: V1Ingress, reference: V1Ingress): | ||||||
|         super().reconcile(current, reference) |         super().reconcile(current, reference) | ||||||
|         self._check_annotations(current, reference) |         self._check_annotations(reference) | ||||||
|         # Create a list of all expected host and tls hosts |         # Create a list of all expected host and tls hosts | ||||||
|         expected_hosts = [] |         expected_hosts = [] | ||||||
|         expected_hosts_tls = [] |         expected_hosts_tls = [] | ||||||
|  | |||||||
| @ -1,14 +0,0 @@ | |||||||
| """RAC app config""" |  | ||||||
|  |  | ||||||
| from authentik.blueprints.apps import ManagedAppConfig |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class AuthentikProviderRAC(ManagedAppConfig): |  | ||||||
|     """authentik rac app config""" |  | ||||||
|  |  | ||||||
|     name = "authentik.providers.rac" |  | ||||||
|     label = "authentik_providers_rac" |  | ||||||
|     verbose_name = "authentik Providers.RAC" |  | ||||||
|     default = True |  | ||||||
|     mountpoint = "" |  | ||||||
|     ws_mountpoint = "authentik.providers.rac.urls" |  | ||||||
| @ -61,7 +61,7 @@ class SAMLSLOView(PolicyAccessView): | |||||||
|                 PLAN_CONTEXT_APPLICATION: self.application, |                 PLAN_CONTEXT_APPLICATION: self.application, | ||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
|         plan.append_stage(in_memory_stage(SessionEndStage)) |         plan.insert_stage(in_memory_stage(SessionEndStage)) | ||||||
|         return plan.to_redirect(self.request, self.flow) |         return plan.to_redirect(self.request, self.flow) | ||||||
|  |  | ||||||
|     def post(self, request: HttpRequest, application_slug: str) -> HttpResponse: |     def post(self, request: HttpRequest, application_slug: str) -> HttpResponse: | ||||||
|  | |||||||
| @ -243,7 +243,6 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]): | |||||||
|             if user.value not in users_should: |             if user.value not in users_should: | ||||||
|                 users_to_remove.append(user.value) |                 users_to_remove.append(user.value) | ||||||
|         # Check users that should be in the group and add them |         # Check users that should be in the group and add them | ||||||
|         if current_group.members is not None: |  | ||||||
|         for user in users_should: |         for user in users_should: | ||||||
|             if len([x for x in current_group.members if x.value == user]) < 1: |             if len([x for x in current_group.members if x.value == user]) < 1: | ||||||
|                 users_to_add.append(user) |                 users_to_add.append(user) | ||||||
|  | |||||||
| @ -1,12 +1,10 @@ | |||||||
| """User client""" | """User client""" | ||||||
|  |  | ||||||
| from django.db import transaction |  | ||||||
| from django.utils.http import urlencode |  | ||||||
| from pydantic import ValidationError | from pydantic import ValidationError | ||||||
|  |  | ||||||
| from authentik.core.models import User | from authentik.core.models import User | ||||||
| from authentik.lib.sync.mapper import PropertyMappingManager | from authentik.lib.sync.mapper import PropertyMappingManager | ||||||
| from authentik.lib.sync.outgoing.exceptions import ObjectExistsSyncException, StopSync | from authentik.lib.sync.outgoing.exceptions import StopSync | ||||||
| from authentik.policies.utils import delete_none_values | from authentik.policies.utils import delete_none_values | ||||||
| from authentik.providers.scim.clients.base import SCIMClient | from authentik.providers.scim.clients.base import SCIMClient | ||||||
| from authentik.providers.scim.clients.schema import SCIM_USER_SCHEMA | from authentik.providers.scim.clients.schema import SCIM_USER_SCHEMA | ||||||
| @ -57,8 +55,6 @@ class SCIMUserClient(SCIMClient[User, SCIMProviderUser, SCIMUserSchema]): | |||||||
|     def create(self, user: User): |     def create(self, user: User): | ||||||
|         """Create user from scratch and create a connection object""" |         """Create user from scratch and create a connection object""" | ||||||
|         scim_user = self.to_schema(user, None) |         scim_user = self.to_schema(user, None) | ||||||
|         with transaction.atomic(): |  | ||||||
|             try: |  | ||||||
|         response = self._request( |         response = self._request( | ||||||
|             "POST", |             "POST", | ||||||
|             "/Users", |             "/Users", | ||||||
| @ -67,25 +63,10 @@ class SCIMUserClient(SCIMClient[User, SCIMProviderUser, SCIMUserSchema]): | |||||||
|                 exclude_unset=True, |                 exclude_unset=True, | ||||||
|             ), |             ), | ||||||
|         ) |         ) | ||||||
|             except ObjectExistsSyncException as exc: |  | ||||||
|                 if not self._config.filter.supported: |  | ||||||
|                     raise exc |  | ||||||
|                 users = self._request( |  | ||||||
|                     "GET", f"/Users?{urlencode({'filter': f'userName eq {scim_user.userName}'})}" |  | ||||||
|                 ) |  | ||||||
|                 users_res = users.get("Resources", []) |  | ||||||
|                 if len(users_res) < 1: |  | ||||||
|                     raise exc |  | ||||||
|                 return SCIMProviderUser.objects.create( |  | ||||||
|                     provider=self.provider, user=user, scim_id=users_res[0]["id"] |  | ||||||
|                 ) |  | ||||||
|             else: |  | ||||||
|         scim_id = response.get("id") |         scim_id = response.get("id") | ||||||
|         if not scim_id or scim_id == "": |         if not scim_id or scim_id == "": | ||||||
|             raise StopSync("SCIM Response with missing or invalid `id`") |             raise StopSync("SCIM Response with missing or invalid `id`") | ||||||
|                 return SCIMProviderUser.objects.create( |         return SCIMProviderUser.objects.create(provider=self.provider, user=user, scim_id=scim_id) | ||||||
|                     provider=self.provider, user=user, scim_id=scim_id |  | ||||||
|                 ) |  | ||||||
|  |  | ||||||
|     def update(self, user: User, connection: SCIMProviderUser): |     def update(self, user: User, connection: SCIMProviderUser): | ||||||
|         """Update existing user""" |         """Update existing user""" | ||||||
|  | |||||||
| @ -2,7 +2,7 @@ | |||||||
|  |  | ||||||
| from django.apps import apps | from django.apps import apps | ||||||
| from django.contrib.auth.models import Permission | from django.contrib.auth.models import Permission | ||||||
| from django.db.models import QuerySet | from django.db.models import Q, QuerySet | ||||||
| from django_filters.filters import ModelChoiceFilter | from django_filters.filters import ModelChoiceFilter | ||||||
| from django_filters.filterset import FilterSet | from django_filters.filterset import FilterSet | ||||||
| from django_filters.rest_framework import DjangoFilterBackend | from django_filters.rest_framework import DjangoFilterBackend | ||||||
| @ -18,6 +18,7 @@ from rest_framework.filters import OrderingFilter, SearchFilter | |||||||
| from rest_framework.permissions import IsAuthenticated | from rest_framework.permissions import IsAuthenticated | ||||||
| from rest_framework.viewsets import ReadOnlyModelViewSet | from rest_framework.viewsets import ReadOnlyModelViewSet | ||||||
|  |  | ||||||
|  | from authentik.blueprints.v1.importer import excluded_models | ||||||
| from authentik.core.api.utils import ModelSerializer, PassiveSerializer | from authentik.core.api.utils import ModelSerializer, PassiveSerializer | ||||||
| from authentik.core.models import User | from authentik.core.models import User | ||||||
| from authentik.lib.validators import RequiredTogetherValidator | from authentik.lib.validators import RequiredTogetherValidator | ||||||
| @ -105,13 +106,13 @@ class RBACPermissionViewSet(ReadOnlyModelViewSet): | |||||||
|     ] |     ] | ||||||
|  |  | ||||||
|     def get_queryset(self) -> QuerySet: |     def get_queryset(self) -> QuerySet: | ||||||
|         return ( |         query = Q() | ||||||
|             Permission.objects.all() |         for model in excluded_models(): | ||||||
|             .select_related("content_type") |             query |= Q( | ||||||
|             .filter( |                 content_type__app_label=model._meta.app_label, | ||||||
|                 content_type__app_label__startswith="authentik", |                 content_type__model=model._meta.model_name, | ||||||
|             ) |  | ||||||
|             ) |             ) | ||||||
|  |         return Permission.objects.all().select_related("content_type").exclude(query) | ||||||
|  |  | ||||||
|  |  | ||||||
| class PermissionAssignSerializer(PassiveSerializer): | class PermissionAssignSerializer(PassiveSerializer): | ||||||
|  | |||||||
| @ -87,7 +87,6 @@ TENANT_APPS = [ | |||||||
|     "authentik.providers.ldap", |     "authentik.providers.ldap", | ||||||
|     "authentik.providers.oauth2", |     "authentik.providers.oauth2", | ||||||
|     "authentik.providers.proxy", |     "authentik.providers.proxy", | ||||||
|     "authentik.providers.rac", |  | ||||||
|     "authentik.providers.radius", |     "authentik.providers.radius", | ||||||
|     "authentik.providers.saml", |     "authentik.providers.saml", | ||||||
|     "authentik.providers.scim", |     "authentik.providers.scim", | ||||||
| @ -101,7 +100,6 @@ TENANT_APPS = [ | |||||||
|     "authentik.sources.scim", |     "authentik.sources.scim", | ||||||
|     "authentik.stages.authenticator", |     "authentik.stages.authenticator", | ||||||
|     "authentik.stages.authenticator_duo", |     "authentik.stages.authenticator_duo", | ||||||
|     "authentik.stages.authenticator_email", |  | ||||||
|     "authentik.stages.authenticator_sms", |     "authentik.stages.authenticator_sms", | ||||||
|     "authentik.stages.authenticator_static", |     "authentik.stages.authenticator_static", | ||||||
|     "authentik.stages.authenticator_totp", |     "authentik.stages.authenticator_totp", | ||||||
|  | |||||||
| @ -68,6 +68,8 @@ class OAuth2Client(BaseOAuthClient): | |||||||
|             error_desc = self.get_request_arg("error_description", None) |             error_desc = self.get_request_arg("error_description", None) | ||||||
|             return {"error": error_desc or error or _("No token received.")} |             return {"error": error_desc or error or _("No token received.")} | ||||||
|         args = { |         args = { | ||||||
|  |             "client_id": self.get_client_id(), | ||||||
|  |             "client_secret": self.get_client_secret(), | ||||||
|             "redirect_uri": callback, |             "redirect_uri": callback, | ||||||
|             "code": code, |             "code": code, | ||||||
|             "grant_type": "authorization_code", |             "grant_type": "authorization_code", | ||||||
|  | |||||||
| @ -28,7 +28,7 @@ def update_well_known_jwks(self: SystemTask): | |||||||
|             LOGGER.warning("Failed to update well_known", source=source, exc=exc, text=text) |             LOGGER.warning("Failed to update well_known", source=source, exc=exc, text=text) | ||||||
|             messages.append(f"Failed to update OIDC configuration for {source.slug}") |             messages.append(f"Failed to update OIDC configuration for {source.slug}") | ||||||
|             continue |             continue | ||||||
|         config: dict = well_known_config.json() |         config = well_known_config.json() | ||||||
|         try: |         try: | ||||||
|             dirty = False |             dirty = False | ||||||
|             source_attr_key = ( |             source_attr_key = ( | ||||||
| @ -40,9 +40,7 @@ def update_well_known_jwks(self: SystemTask): | |||||||
|             for source_attr, config_key in source_attr_key: |             for source_attr, config_key in source_attr_key: | ||||||
|                 # Check if we're actually changing anything to only |                 # Check if we're actually changing anything to only | ||||||
|                 # save when something has changed |                 # save when something has changed | ||||||
|                 if config_key not in config: |                 if getattr(source, source_attr, "") != config[config_key]: | ||||||
|                     continue |  | ||||||
|                 if getattr(source, source_attr, "") != config.get(config_key, ""): |  | ||||||
|                     dirty = True |                     dirty = True | ||||||
|                 setattr(source, source_attr, config[config_key]) |                 setattr(source, source_attr, config[config_key]) | ||||||
|         except (IndexError, KeyError) as exc: |         except (IndexError, KeyError) as exc: | ||||||
|  | |||||||
| @ -2,7 +2,6 @@ | |||||||
|  |  | ||||||
| from typing import Any | from typing import Any | ||||||
|  |  | ||||||
| from requests import RequestException |  | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient | from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient | ||||||
| @ -22,35 +21,10 @@ class AzureADOAuthRedirect(OAuthRedirect): | |||||||
|         } |         } | ||||||
|  |  | ||||||
|  |  | ||||||
| class AzureADClient(UserprofileHeaderAuthClient): |  | ||||||
|     """Fetch AzureAD group information""" |  | ||||||
|  |  | ||||||
|     def get_profile_info(self, token): |  | ||||||
|         profile_data = super().get_profile_info(token) |  | ||||||
|         if "https://graph.microsoft.com/GroupMember.Read.All" not in self.source.additional_scopes: |  | ||||||
|             return profile_data |  | ||||||
|         group_response = self.session.request( |  | ||||||
|             "get", |  | ||||||
|             "https://graph.microsoft.com/v1.0/me/memberOf", |  | ||||||
|             headers={"Authorization": f"{token['token_type']} {token['access_token']}"}, |  | ||||||
|         ) |  | ||||||
|         try: |  | ||||||
|             group_response.raise_for_status() |  | ||||||
|         except RequestException as exc: |  | ||||||
|             LOGGER.warning( |  | ||||||
|                 "Unable to fetch user profile", |  | ||||||
|                 exc=exc, |  | ||||||
|                 response=exc.response.text if exc.response else str(exc), |  | ||||||
|             ) |  | ||||||
|             return None |  | ||||||
|         profile_data["raw_groups"] = group_response.json() |  | ||||||
|         return profile_data |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class AzureADOAuthCallback(OpenIDConnectOAuth2Callback): | class AzureADOAuthCallback(OpenIDConnectOAuth2Callback): | ||||||
|     """AzureAD OAuth2 Callback""" |     """AzureAD OAuth2 Callback""" | ||||||
|  |  | ||||||
|     client_class = AzureADClient |     client_class = UserprofileHeaderAuthClient | ||||||
|  |  | ||||||
|     def get_user_id(self, info: dict[str, str]) -> str: |     def get_user_id(self, info: dict[str, str]) -> str: | ||||||
|         # Default try to get `id` for the Graph API endpoint |         # Default try to get `id` for the Graph API endpoint | ||||||
| @ -79,24 +53,8 @@ class AzureADType(SourceType): | |||||||
|  |  | ||||||
|     def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]: |     def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]: | ||||||
|         mail = info.get("mail", None) or info.get("otherMails", [None])[0] |         mail = info.get("mail", None) or info.get("otherMails", [None])[0] | ||||||
|         # Format group info |  | ||||||
|         groups = [] |  | ||||||
|         group_id_dict = {} |  | ||||||
|         for group in info.get("raw_groups", {}).get("value", []): |  | ||||||
|             if group["@odata.type"] != "#microsoft.graph.group": |  | ||||||
|                 continue |  | ||||||
|             groups.append(group["id"]) |  | ||||||
|             group_id_dict[group["id"]] = group |  | ||||||
|         info["raw_groups"] = group_id_dict |  | ||||||
|         return { |         return { | ||||||
|             "username": info.get("userPrincipalName"), |             "username": info.get("userPrincipalName"), | ||||||
|             "email": mail, |             "email": mail, | ||||||
|             "name": info.get("displayName"), |             "name": info.get("displayName"), | ||||||
|             "groups": groups, |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|     def get_base_group_properties(self, source, group_id, **kwargs): |  | ||||||
|         raw_group = kwargs["info"]["raw_groups"][group_id] |  | ||||||
|         return { |  | ||||||
|             "name": raw_group["displayName"], |  | ||||||
|         } |         } | ||||||
|  | |||||||
| @ -1,85 +0,0 @@ | |||||||
| """AuthenticatorEmailStage API Views""" |  | ||||||
|  |  | ||||||
| from rest_framework import mixins |  | ||||||
| from rest_framework.viewsets import GenericViewSet, ModelViewSet |  | ||||||
|  |  | ||||||
| from authentik.core.api.groups import GroupMemberSerializer |  | ||||||
| from authentik.core.api.used_by import UsedByMixin |  | ||||||
| from authentik.core.api.utils import ModelSerializer |  | ||||||
| from authentik.flows.api.stages import StageSerializer |  | ||||||
| from authentik.stages.authenticator_email.models import AuthenticatorEmailStage, EmailDevice |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class AuthenticatorEmailStageSerializer(StageSerializer): |  | ||||||
|     """AuthenticatorEmailStage Serializer""" |  | ||||||
|  |  | ||||||
|     class Meta: |  | ||||||
|         model = AuthenticatorEmailStage |  | ||||||
|         fields = StageSerializer.Meta.fields + [ |  | ||||||
|             "configure_flow", |  | ||||||
|             "friendly_name", |  | ||||||
|             "use_global_settings", |  | ||||||
|             "host", |  | ||||||
|             "port", |  | ||||||
|             "username", |  | ||||||
|             "password", |  | ||||||
|             "use_tls", |  | ||||||
|             "use_ssl", |  | ||||||
|             "timeout", |  | ||||||
|             "from_address", |  | ||||||
|             "subject", |  | ||||||
|             "token_expiry", |  | ||||||
|             "template", |  | ||||||
|         ] |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class AuthenticatorEmailStageViewSet(UsedByMixin, ModelViewSet): |  | ||||||
|     """AuthenticatorEmailStage Viewset""" |  | ||||||
|  |  | ||||||
|     queryset = AuthenticatorEmailStage.objects.all() |  | ||||||
|     serializer_class = AuthenticatorEmailStageSerializer |  | ||||||
|     filterset_fields = "__all__" |  | ||||||
|     ordering = ["name"] |  | ||||||
|     search_fields = ["name"] |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class EmailDeviceSerializer(ModelSerializer): |  | ||||||
|     """Serializer for email authenticator devices""" |  | ||||||
|  |  | ||||||
|     user = GroupMemberSerializer(read_only=True) |  | ||||||
|  |  | ||||||
|     class Meta: |  | ||||||
|         model = EmailDevice |  | ||||||
|         fields = ["name", "pk", "email", "user"] |  | ||||||
|         depth = 2 |  | ||||||
|         extra_kwargs = { |  | ||||||
|             "email": {"read_only": True}, |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class EmailDeviceViewSet( |  | ||||||
|     mixins.RetrieveModelMixin, |  | ||||||
|     mixins.UpdateModelMixin, |  | ||||||
|     mixins.DestroyModelMixin, |  | ||||||
|     UsedByMixin, |  | ||||||
|     mixins.ListModelMixin, |  | ||||||
|     GenericViewSet, |  | ||||||
| ): |  | ||||||
|     """Viewset for email authenticator devices""" |  | ||||||
|  |  | ||||||
|     queryset = EmailDevice.objects.all() |  | ||||||
|     serializer_class = EmailDeviceSerializer |  | ||||||
|     search_fields = ["name"] |  | ||||||
|     filterset_fields = ["name"] |  | ||||||
|     ordering = ["name"] |  | ||||||
|     owner_field = "user" |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class EmailAdminDeviceViewSet(ModelViewSet): |  | ||||||
|     """Viewset for email authenticator devices (for admins)""" |  | ||||||
|  |  | ||||||
|     queryset = EmailDevice.objects.all() |  | ||||||
|     serializer_class = EmailDeviceSerializer |  | ||||||
|     search_fields = ["name"] |  | ||||||
|     filterset_fields = ["name"] |  | ||||||
|     ordering = ["name"] |  | ||||||
| @ -1,12 +0,0 @@ | |||||||
| """Email Authenticator""" |  | ||||||
|  |  | ||||||
| from authentik.blueprints.apps import ManagedAppConfig |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class AuthentikStageAuthenticatorEmailConfig(ManagedAppConfig): |  | ||||||
|     """Email Authenticator App config""" |  | ||||||
|  |  | ||||||
|     name = "authentik.stages.authenticator_email" |  | ||||||
|     label = "authentik_stages_authenticator_email" |  | ||||||
|     verbose_name = "authentik Stages.Authenticator.Email" |  | ||||||
|     default = True |  | ||||||
| @ -1,132 +0,0 @@ | |||||||
| # Generated by Django 5.0.10 on 2025-01-27 20:05 |  | ||||||
|  |  | ||||||
| import django.db.models.deletion |  | ||||||
| import django.utils.timezone |  | ||||||
| from django.conf import settings |  | ||||||
| from django.db import migrations, models |  | ||||||
|  |  | ||||||
| import authentik.lib.utils.time |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Migration(migrations.Migration): |  | ||||||
|  |  | ||||||
|     initial = True |  | ||||||
|  |  | ||||||
|     dependencies = [ |  | ||||||
|         ("authentik_flows", "0027_auto_20231028_1424"), |  | ||||||
|         migrations.swappable_dependency(settings.AUTH_USER_MODEL), |  | ||||||
|     ] |  | ||||||
|  |  | ||||||
|     operations = [ |  | ||||||
|         migrations.CreateModel( |  | ||||||
|             name="AuthenticatorEmailStage", |  | ||||||
|             fields=[ |  | ||||||
|                 ( |  | ||||||
|                     "stage_ptr", |  | ||||||
|                     models.OneToOneField( |  | ||||||
|                         auto_created=True, |  | ||||||
|                         on_delete=django.db.models.deletion.CASCADE, |  | ||||||
|                         parent_link=True, |  | ||||||
|                         primary_key=True, |  | ||||||
|                         serialize=False, |  | ||||||
|                         to="authentik_flows.stage", |  | ||||||
|                     ), |  | ||||||
|                 ), |  | ||||||
|                 ("friendly_name", models.TextField(null=True)), |  | ||||||
|                 ( |  | ||||||
|                     "use_global_settings", |  | ||||||
|                     models.BooleanField( |  | ||||||
|                         default=False, |  | ||||||
|                         help_text="When enabled, global Email connection settings will be used and connection settings below will be ignored.", |  | ||||||
|                     ), |  | ||||||
|                 ), |  | ||||||
|                 ("host", models.TextField(default="localhost")), |  | ||||||
|                 ("port", models.IntegerField(default=25)), |  | ||||||
|                 ("username", models.TextField(blank=True, default="")), |  | ||||||
|                 ("password", models.TextField(blank=True, default="")), |  | ||||||
|                 ("use_tls", models.BooleanField(default=False)), |  | ||||||
|                 ("use_ssl", models.BooleanField(default=False)), |  | ||||||
|                 ("timeout", models.IntegerField(default=10)), |  | ||||||
|                 ( |  | ||||||
|                     "from_address", |  | ||||||
|                     models.EmailField(default="system@authentik.local", max_length=254), |  | ||||||
|                 ), |  | ||||||
|                 ( |  | ||||||
|                     "token_expiry", |  | ||||||
|                     models.TextField( |  | ||||||
|                         default="minutes=30", |  | ||||||
|                         help_text="Time the token sent is valid (Format: hours=3,minutes=17,seconds=300).", |  | ||||||
|                         validators=[authentik.lib.utils.time.timedelta_string_validator], |  | ||||||
|                     ), |  | ||||||
|                 ), |  | ||||||
|                 ("subject", models.TextField(default="authentik Sign-in code")), |  | ||||||
|                 ("template", models.TextField(default="email/email_otp.html")), |  | ||||||
|                 ( |  | ||||||
|                     "configure_flow", |  | ||||||
|                     models.ForeignKey( |  | ||||||
|                         blank=True, |  | ||||||
|                         help_text="Flow used by an authenticated user to configure this Stage. If empty, user will not be able to configure this stage.", |  | ||||||
|                         null=True, |  | ||||||
|                         on_delete=django.db.models.deletion.SET_NULL, |  | ||||||
|                         to="authentik_flows.flow", |  | ||||||
|                     ), |  | ||||||
|                 ), |  | ||||||
|             ], |  | ||||||
|             options={ |  | ||||||
|                 "verbose_name": "Email Authenticator Setup Stage", |  | ||||||
|                 "verbose_name_plural": "Email Authenticator Setup Stages", |  | ||||||
|             }, |  | ||||||
|             bases=("authentik_flows.stage", models.Model), |  | ||||||
|         ), |  | ||||||
|         migrations.CreateModel( |  | ||||||
|             name="EmailDevice", |  | ||||||
|             fields=[ |  | ||||||
|                 ( |  | ||||||
|                     "id", |  | ||||||
|                     models.AutoField( |  | ||||||
|                         auto_created=True, primary_key=True, serialize=False, verbose_name="ID" |  | ||||||
|                     ), |  | ||||||
|                 ), |  | ||||||
|                 ("created", models.DateTimeField(auto_now_add=True)), |  | ||||||
|                 ("last_updated", models.DateTimeField(auto_now=True)), |  | ||||||
|                 ( |  | ||||||
|                     "name", |  | ||||||
|                     models.CharField( |  | ||||||
|                         help_text="The human-readable name of this device.", max_length=64 |  | ||||||
|                     ), |  | ||||||
|                 ), |  | ||||||
|                 ( |  | ||||||
|                     "confirmed", |  | ||||||
|                     models.BooleanField(default=True, help_text="Is this device ready for use?"), |  | ||||||
|                 ), |  | ||||||
|                 ("token", models.CharField(blank=True, max_length=16, null=True)), |  | ||||||
|                 ( |  | ||||||
|                     "valid_until", |  | ||||||
|                     models.DateTimeField( |  | ||||||
|                         default=django.utils.timezone.now, |  | ||||||
|                         help_text="The timestamp of the moment of expiry of the saved token.", |  | ||||||
|                     ), |  | ||||||
|                 ), |  | ||||||
|                 ("email", models.EmailField(max_length=254)), |  | ||||||
|                 ("last_used", models.DateTimeField(auto_now=True)), |  | ||||||
|                 ( |  | ||||||
|                     "stage", |  | ||||||
|                     models.ForeignKey( |  | ||||||
|                         on_delete=django.db.models.deletion.CASCADE, |  | ||||||
|                         to="authentik_stages_authenticator_email.authenticatoremailstage", |  | ||||||
|                     ), |  | ||||||
|                 ), |  | ||||||
|                 ( |  | ||||||
|                     "user", |  | ||||||
|                     models.ForeignKey( |  | ||||||
|                         on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL |  | ||||||
|                     ), |  | ||||||
|                 ), |  | ||||||
|             ], |  | ||||||
|             options={ |  | ||||||
|                 "verbose_name": "Email Device", |  | ||||||
|                 "verbose_name_plural": "Email Devices", |  | ||||||
|                 "unique_together": {("user", "email")}, |  | ||||||
|             }, |  | ||||||
|         ), |  | ||||||
|     ] |  | ||||||
| @ -1,176 +0,0 @@ | |||||||
| from django.contrib.auth import get_user_model |  | ||||||
| from django.core.mail.backends.base import BaseEmailBackend |  | ||||||
| from django.core.mail.backends.smtp import EmailBackend |  | ||||||
| from django.db import models |  | ||||||
| from django.template import TemplateSyntaxError |  | ||||||
| from django.utils.translation import gettext_lazy as _ |  | ||||||
| from django.views import View |  | ||||||
| from rest_framework.serializers import BaseSerializer |  | ||||||
|  |  | ||||||
| from authentik.core.types import UserSettingSerializer |  | ||||||
| from authentik.events.models import Event, EventAction |  | ||||||
| from authentik.flows.exceptions import StageInvalidException |  | ||||||
| from authentik.flows.models import ConfigurableStage, FriendlyNamedStage, Stage |  | ||||||
| from authentik.lib.config import CONFIG |  | ||||||
| from authentik.lib.models import SerializerModel |  | ||||||
| from authentik.lib.utils.errors import exception_to_string |  | ||||||
| from authentik.lib.utils.time import timedelta_string_validator |  | ||||||
| from authentik.stages.authenticator.models import SideChannelDevice |  | ||||||
| from authentik.stages.email.utils import TemplateEmailMessage |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class EmailTemplates(models.TextChoices): |  | ||||||
|     """Templates used for rendering the Email""" |  | ||||||
|  |  | ||||||
|     EMAIL_OTP = ( |  | ||||||
|         "email/email_otp.html", |  | ||||||
|         _("Email OTP"), |  | ||||||
|     )  # nosec |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class AuthenticatorEmailStage(ConfigurableStage, FriendlyNamedStage, Stage): |  | ||||||
|     """Use Email-based authentication instead of authenticator-based.""" |  | ||||||
|  |  | ||||||
|     use_global_settings = models.BooleanField( |  | ||||||
|         default=False, |  | ||||||
|         help_text=_( |  | ||||||
|             "When enabled, global Email connection settings will be used and " |  | ||||||
|             "connection settings below will be ignored." |  | ||||||
|         ), |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     host = models.TextField(default="localhost") |  | ||||||
|     port = models.IntegerField(default=25) |  | ||||||
|     username = models.TextField(default="", blank=True) |  | ||||||
|     password = models.TextField(default="", blank=True) |  | ||||||
|     use_tls = models.BooleanField(default=False) |  | ||||||
|     use_ssl = models.BooleanField(default=False) |  | ||||||
|     timeout = models.IntegerField(default=10) |  | ||||||
|     from_address = models.EmailField(default="system@authentik.local") |  | ||||||
|  |  | ||||||
|     token_expiry = models.TextField( |  | ||||||
|         default="minutes=30", |  | ||||||
|         validators=[timedelta_string_validator], |  | ||||||
|         help_text=_("Time the token sent is valid (Format: hours=3,minutes=17,seconds=300)."), |  | ||||||
|     ) |  | ||||||
|     subject = models.TextField(default="authentik Sign-in code") |  | ||||||
|     template = models.TextField(default=EmailTemplates.EMAIL_OTP) |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def serializer(self) -> type[BaseSerializer]: |  | ||||||
|         from authentik.stages.authenticator_email.api import AuthenticatorEmailStageSerializer |  | ||||||
|  |  | ||||||
|         return AuthenticatorEmailStageSerializer |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def view(self) -> type[View]: |  | ||||||
|         from authentik.stages.authenticator_email.stage import AuthenticatorEmailStageView |  | ||||||
|  |  | ||||||
|         return AuthenticatorEmailStageView |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def component(self) -> str: |  | ||||||
|         return "ak-stage-authenticator-email-form" |  | ||||||
|  |  | ||||||
|     def ui_user_settings(self) -> UserSettingSerializer | None: |  | ||||||
|         return UserSettingSerializer( |  | ||||||
|             data={ |  | ||||||
|                 "title": self.friendly_name or str(self._meta.verbose_name), |  | ||||||
|                 "component": "ak-user-settings-authenticator-email", |  | ||||||
|             } |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def backend_class(self) -> type[BaseEmailBackend]: |  | ||||||
|         """Get the email backend class to use""" |  | ||||||
|         return EmailBackend |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def backend(self) -> BaseEmailBackend: |  | ||||||
|         """Get fully configured Email Backend instance""" |  | ||||||
|         if self.use_global_settings: |  | ||||||
|             CONFIG.refresh("email.password") |  | ||||||
|             return self.backend_class( |  | ||||||
|                 host=CONFIG.get("email.host"), |  | ||||||
|                 port=CONFIG.get_int("email.port"), |  | ||||||
|                 username=CONFIG.get("email.username"), |  | ||||||
|                 password=CONFIG.get("email.password"), |  | ||||||
|                 use_tls=CONFIG.get_bool("email.use_tls", False), |  | ||||||
|                 use_ssl=CONFIG.get_bool("email.use_ssl", False), |  | ||||||
|                 timeout=CONFIG.get_int("email.timeout"), |  | ||||||
|             ) |  | ||||||
|         return self.backend_class( |  | ||||||
|             host=self.host, |  | ||||||
|             port=self.port, |  | ||||||
|             username=self.username, |  | ||||||
|             password=self.password, |  | ||||||
|             use_tls=self.use_tls, |  | ||||||
|             use_ssl=self.use_ssl, |  | ||||||
|             timeout=self.timeout, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def send(self, device: "EmailDevice"): |  | ||||||
|         # Lazy import here to avoid circular import |  | ||||||
|         from authentik.stages.email.tasks import send_mails |  | ||||||
|  |  | ||||||
|         # Compose the message using templates |  | ||||||
|         message = device._compose_email() |  | ||||||
|         return send_mails(device.stage, message) |  | ||||||
|  |  | ||||||
|     def __str__(self): |  | ||||||
|         return f"Email Authenticator Stage {self.name}" |  | ||||||
|  |  | ||||||
|     class Meta: |  | ||||||
|         verbose_name = _("Email Authenticator Setup Stage") |  | ||||||
|         verbose_name_plural = _("Email Authenticator Setup Stages") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class EmailDevice(SerializerModel, SideChannelDevice): |  | ||||||
|     """Email Device""" |  | ||||||
|  |  | ||||||
|     user = models.ForeignKey(get_user_model(), on_delete=models.CASCADE) |  | ||||||
|     email = models.EmailField() |  | ||||||
|     stage = models.ForeignKey(AuthenticatorEmailStage, on_delete=models.CASCADE) |  | ||||||
|     last_used = models.DateTimeField(auto_now=True) |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def serializer(self) -> type[BaseSerializer]: |  | ||||||
|         from authentik.stages.authenticator_email.api import EmailDeviceSerializer |  | ||||||
|  |  | ||||||
|         return EmailDeviceSerializer |  | ||||||
|  |  | ||||||
|     def _compose_email(self) -> TemplateEmailMessage: |  | ||||||
|         try: |  | ||||||
|             pending_user = self.user |  | ||||||
|             stage = self.stage |  | ||||||
|             email = self.email |  | ||||||
|  |  | ||||||
|             message = TemplateEmailMessage( |  | ||||||
|                 subject=_(stage.subject), |  | ||||||
|                 to=[(pending_user.name, email)], |  | ||||||
|                 template_name=stage.template, |  | ||||||
|                 template_context={ |  | ||||||
|                     "user": pending_user, |  | ||||||
|                     "expires": self.valid_until, |  | ||||||
|                     "token": self.token, |  | ||||||
|                 }, |  | ||||||
|             ) |  | ||||||
|             return message |  | ||||||
|         except TemplateSyntaxError as exc: |  | ||||||
|             Event.new( |  | ||||||
|                 EventAction.CONFIGURATION_ERROR, |  | ||||||
|                 message=_("Exception occurred while rendering E-mail template"), |  | ||||||
|                 error=exception_to_string(exc), |  | ||||||
|                 template=stage.template, |  | ||||||
|             ).from_http(self.request) |  | ||||||
|             raise StageInvalidException from exc |  | ||||||
|  |  | ||||||
|     def __str__(self): |  | ||||||
|         if not self.pk: |  | ||||||
|             return "New Email Device" |  | ||||||
|         return f"Email Device for {self.user_id}" |  | ||||||
|  |  | ||||||
|     class Meta: |  | ||||||
|         verbose_name = _("Email Device") |  | ||||||
|         verbose_name_plural = _("Email Devices") |  | ||||||
|         unique_together = (("user", "email"),) |  | ||||||
| @ -1,177 +0,0 @@ | |||||||
| """Email Setup stage""" |  | ||||||
|  |  | ||||||
| from django.db.models import Q |  | ||||||
| from django.http import HttpRequest, HttpResponse |  | ||||||
| from django.http.request import QueryDict |  | ||||||
| from django.template.exceptions import TemplateSyntaxError |  | ||||||
| from django.utils.translation import gettext_lazy as _ |  | ||||||
| from rest_framework.exceptions import ValidationError |  | ||||||
| from rest_framework.fields import BooleanField, CharField, IntegerField |  | ||||||
|  |  | ||||||
| from authentik.events.models import Event, EventAction |  | ||||||
| from authentik.flows.challenge import ( |  | ||||||
|     Challenge, |  | ||||||
|     ChallengeResponse, |  | ||||||
|     WithUserInfoChallenge, |  | ||||||
| ) |  | ||||||
| from authentik.flows.exceptions import StageInvalidException |  | ||||||
| from authentik.flows.stage import ChallengeStageView |  | ||||||
| from authentik.lib.utils.email import mask_email |  | ||||||
| from authentik.lib.utils.errors import exception_to_string |  | ||||||
| from authentik.lib.utils.time import timedelta_from_string |  | ||||||
| from authentik.stages.authenticator_email.models import ( |  | ||||||
|     AuthenticatorEmailStage, |  | ||||||
|     EmailDevice, |  | ||||||
| ) |  | ||||||
| from authentik.stages.email.tasks import send_mails |  | ||||||
| from authentik.stages.email.utils import TemplateEmailMessage |  | ||||||
| from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT |  | ||||||
|  |  | ||||||
| SESSION_KEY_EMAIL_DEVICE = "authentik/stages/authenticator_email/email_device" |  | ||||||
| PLAN_CONTEXT_EMAIL = "email" |  | ||||||
| PLAN_CONTEXT_EMAIL_SENT = "email_sent" |  | ||||||
| PLAN_CONTEXT_EMAIL_OVERRIDE = "email" |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class AuthenticatorEmailChallenge(WithUserInfoChallenge): |  | ||||||
|     """Authenticator Email Setup challenge""" |  | ||||||
|  |  | ||||||
|     # Set to true if no previous prompt stage set the email |  | ||||||
|     # this stage will also check prompt_data.email |  | ||||||
|     email = CharField(default=None, allow_blank=True, allow_null=True) |  | ||||||
|     email_required = BooleanField(default=True) |  | ||||||
|     component = CharField(default="ak-stage-authenticator-email") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class AuthenticatorEmailChallengeResponse(ChallengeResponse): |  | ||||||
|     """Authenticator Email Challenge response, device is set by get_response_instance""" |  | ||||||
|  |  | ||||||
|     device: EmailDevice |  | ||||||
|  |  | ||||||
|     code = IntegerField(required=False) |  | ||||||
|     email = CharField(required=False) |  | ||||||
|  |  | ||||||
|     component = CharField(default="ak-stage-authenticator-email") |  | ||||||
|  |  | ||||||
|     def validate(self, attrs: dict) -> dict: |  | ||||||
|         """Check""" |  | ||||||
|         if "code" not in attrs: |  | ||||||
|             if "email" not in attrs: |  | ||||||
|                 raise ValidationError("email required") |  | ||||||
|             self.device.email = attrs["email"] |  | ||||||
|             self.stage.validate_and_send(attrs["email"]) |  | ||||||
|             return super().validate(attrs) |  | ||||||
|         if not self.device.verify_token(str(attrs["code"])): |  | ||||||
|             raise ValidationError(_("Code does not match")) |  | ||||||
|         self.device.confirmed = True |  | ||||||
|         return super().validate(attrs) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class AuthenticatorEmailStageView(ChallengeStageView): |  | ||||||
|     """Authenticator Email Setup stage""" |  | ||||||
|  |  | ||||||
|     response_class = AuthenticatorEmailChallengeResponse |  | ||||||
|  |  | ||||||
|     def validate_and_send(self, email: str): |  | ||||||
|         """Validate email and send message""" |  | ||||||
|         pending_user = self.get_pending_user() |  | ||||||
|  |  | ||||||
|         stage: AuthenticatorEmailStage = self.executor.current_stage |  | ||||||
|         if EmailDevice.objects.filter(Q(email=email), stage=stage.pk).exists(): |  | ||||||
|             raise ValidationError(_("Invalid email")) |  | ||||||
|  |  | ||||||
|         device: EmailDevice = self.request.session[SESSION_KEY_EMAIL_DEVICE] |  | ||||||
|  |  | ||||||
|         try: |  | ||||||
|             message = TemplateEmailMessage( |  | ||||||
|                 subject=_(stage.subject), |  | ||||||
|                 to=[(pending_user.name, email)], |  | ||||||
|                 language=pending_user.locale(self.request), |  | ||||||
|                 template_name=stage.template, |  | ||||||
|                 template_context={ |  | ||||||
|                     "user": pending_user, |  | ||||||
|                     "expires": device.valid_until, |  | ||||||
|                     "token": device.token, |  | ||||||
|                 }, |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|             send_mails(stage, message) |  | ||||||
|         except TemplateSyntaxError as exc: |  | ||||||
|             Event.new( |  | ||||||
|                 EventAction.CONFIGURATION_ERROR, |  | ||||||
|                 message=_("Exception occurred while rendering E-mail template"), |  | ||||||
|                 error=exception_to_string(exc), |  | ||||||
|                 template=stage.template, |  | ||||||
|             ).from_http(self.request) |  | ||||||
|             raise StageInvalidException from exc |  | ||||||
|  |  | ||||||
|     def _has_email(self) -> str | None: |  | ||||||
|         context = self.executor.plan.context |  | ||||||
|  |  | ||||||
|         # Check user's email attribute |  | ||||||
|         user = self.get_pending_user() |  | ||||||
|         if user.email: |  | ||||||
|             self.logger.debug("got email from user attributes") |  | ||||||
|             return user.email |  | ||||||
|         # Check plan context for email |  | ||||||
|         if PLAN_CONTEXT_EMAIL in context.get(PLAN_CONTEXT_PROMPT, {}): |  | ||||||
|             self.logger.debug("got email from plan context") |  | ||||||
|             return context.get(PLAN_CONTEXT_PROMPT, {}).get(PLAN_CONTEXT_EMAIL) |  | ||||||
|         # Check device for email |  | ||||||
|         if SESSION_KEY_EMAIL_DEVICE in self.request.session: |  | ||||||
|             self.logger.debug("got email from device in session") |  | ||||||
|             device: EmailDevice = self.request.session[SESSION_KEY_EMAIL_DEVICE] |  | ||||||
|             if device.email == "": |  | ||||||
|                 return None |  | ||||||
|             return device.email |  | ||||||
|         return None |  | ||||||
|  |  | ||||||
|     def get_challenge(self, *args, **kwargs) -> Challenge: |  | ||||||
|         email = self._has_email() |  | ||||||
|         return AuthenticatorEmailChallenge( |  | ||||||
|             data={ |  | ||||||
|                 "email": mask_email(email), |  | ||||||
|                 "email_required": email is None, |  | ||||||
|             } |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def get_response_instance(self, data: QueryDict) -> ChallengeResponse: |  | ||||||
|         response = super().get_response_instance(data) |  | ||||||
|         response.device = self.request.session[SESSION_KEY_EMAIL_DEVICE] |  | ||||||
|         return response |  | ||||||
|  |  | ||||||
|     def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: |  | ||||||
|         user = self.get_pending_user() |  | ||||||
|  |  | ||||||
|         stage: AuthenticatorEmailStage = self.executor.current_stage |  | ||||||
|         if SESSION_KEY_EMAIL_DEVICE not in self.request.session: |  | ||||||
|             device = EmailDevice(user=user, confirmed=False, stage=stage, name="Email Device") |  | ||||||
|             valid_secs: int = timedelta_from_string(stage.token_expiry).total_seconds() |  | ||||||
|             device.generate_token(valid_secs=valid_secs, commit=False) |  | ||||||
|             self.request.session[SESSION_KEY_EMAIL_DEVICE] = device |  | ||||||
|             if email := self._has_email(): |  | ||||||
|                 device.email = email |  | ||||||
|                 try: |  | ||||||
|                     self.validate_and_send(email) |  | ||||||
|                 except ValidationError as exc: |  | ||||||
|                     # We had an email given already (at this point only possible from flow |  | ||||||
|                     # context), but an error occurred while sending (most likely) |  | ||||||
|                     # due to a duplicate device, so delete the email we got given, reset the state |  | ||||||
|                     # (ish) and retry |  | ||||||
|                     device.email = "" |  | ||||||
|                     self.executor.plan.context.get(PLAN_CONTEXT_PROMPT, {}).pop( |  | ||||||
|                         PLAN_CONTEXT_EMAIL, None |  | ||||||
|                     ) |  | ||||||
|                     self.request.session.pop(SESSION_KEY_EMAIL_DEVICE, None) |  | ||||||
|                     self.logger.warning("failed to send email to pre-set address", exc=exc) |  | ||||||
|                     return self.get(request, *args, **kwargs) |  | ||||||
|         return super().get(request, *args, **kwargs) |  | ||||||
|  |  | ||||||
|     def challenge_valid(self, response: ChallengeResponse) -> HttpResponse: |  | ||||||
|         """Email Token is validated by challenge""" |  | ||||||
|         device: EmailDevice = self.request.session[SESSION_KEY_EMAIL_DEVICE] |  | ||||||
|         if not device.confirmed: |  | ||||||
|             return self.challenge_invalid(response) |  | ||||||
|         device.save() |  | ||||||
|         del self.request.session[SESSION_KEY_EMAIL_DEVICE] |  | ||||||
|         return self.executor.stage_ok() |  | ||||||
| @ -1,44 +0,0 @@ | |||||||
| {% extends "email/base.html" %} |  | ||||||
|  |  | ||||||
| {% load i18n %} |  | ||||||
| {% load humanize %} |  | ||||||
|  |  | ||||||
| {% block content %} |  | ||||||
| <tr> |  | ||||||
|   <td align="center"> |  | ||||||
|     <h1> |  | ||||||
|       {% blocktrans with username=user.username %} |  | ||||||
|       Hi {{ username }}, |  | ||||||
|       {% endblocktrans %} |  | ||||||
|     </h1> |  | ||||||
|   </td> |  | ||||||
| </tr> |  | ||||||
| <tr> |  | ||||||
|   <td align="center"> |  | ||||||
|     <table border="0"> |  | ||||||
|       <tr> |  | ||||||
|         <td align="center" style="max-width: 300px; padding: 20px 0; color: #212124;"> |  | ||||||
|           {% blocktrans %} |  | ||||||
|           Email MFA code. |  | ||||||
|           {% endblocktrans %} |  | ||||||
|         </td> |  | ||||||
|       </tr> |  | ||||||
|       <tr> |  | ||||||
|         <td align="center" class="btn btn-primary"> |  | ||||||
|           {{ token }} |  | ||||||
|         </td> |  | ||||||
|       </tr> |  | ||||||
|     </table> |  | ||||||
|   </td> |  | ||||||
| </tr> |  | ||||||
| {% endblock %} |  | ||||||
|  |  | ||||||
| {% block sub_content %} |  | ||||||
| <tr> |  | ||||||
|   <td style="padding: 20px; font-size: 12px; color: #212124;" align="center"> |  | ||||||
|     {% blocktrans with expires=expires|timeuntil %} |  | ||||||
|     If you did not request this code, please ignore this email. The code above is valid for {{ expires }}. |  | ||||||
|     {% endblocktrans %} |  | ||||||
|   </td> |  | ||||||
| </tr> |  | ||||||
| {% endblock %} |  | ||||||
| @ -1,13 +0,0 @@ | |||||||
| {% load i18n %}{% load humanize %}{% autoescape off %}{% blocktrans with username=user.username %}Hi {{ username }},{% endblocktrans %} |  | ||||||
|  |  | ||||||
| {% blocktrans %} |  | ||||||
| Email MFA code |  | ||||||
| {% endblocktrans %} |  | ||||||
| {{ token }} |  | ||||||
| {% blocktrans with expires=expires|timeuntil %} |  | ||||||
| If you did not request this code, please ignore this email. The code above is valid for {{ expires }}. |  | ||||||
| {% endblocktrans %} |  | ||||||
|  |  | ||||||
| --  |  | ||||||
| Powered by goauthentik.io. |  | ||||||
| {% endautoescape %} |  | ||||||
| @ -1,342 +0,0 @@ | |||||||
| """Test Email Authenticator API""" |  | ||||||
|  |  | ||||||
| from datetime import timedelta |  | ||||||
| from unittest.mock import MagicMock, PropertyMock, patch |  | ||||||
|  |  | ||||||
| from django.core import mail |  | ||||||
| from django.core.mail.backends.smtp import EmailBackend |  | ||||||
| from django.db.utils import IntegrityError |  | ||||||
| from django.template.exceptions import TemplateDoesNotExist |  | ||||||
| from django.urls import reverse |  | ||||||
| from django.utils.timezone import now |  | ||||||
|  |  | ||||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_flow, create_test_user |  | ||||||
| from authentik.flows.models import FlowStageBinding |  | ||||||
| from authentik.flows.tests import FlowTestCase |  | ||||||
| from authentik.lib.config import CONFIG |  | ||||||
| from authentik.lib.utils.email import mask_email |  | ||||||
| from authentik.stages.authenticator_email.api import ( |  | ||||||
|     AuthenticatorEmailStageSerializer, |  | ||||||
|     EmailDeviceSerializer, |  | ||||||
| ) |  | ||||||
| from authentik.stages.authenticator_email.models import AuthenticatorEmailStage, EmailDevice |  | ||||||
| from authentik.stages.authenticator_email.stage import ( |  | ||||||
|     SESSION_KEY_EMAIL_DEVICE, |  | ||||||
| ) |  | ||||||
| from authentik.stages.email.utils import TemplateEmailMessage |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestAuthenticatorEmailStage(FlowTestCase): |  | ||||||
|     """Test Email Authenticator stage""" |  | ||||||
|  |  | ||||||
|     def setUp(self): |  | ||||||
|         super().setUp() |  | ||||||
|         self.flow = create_test_flow() |  | ||||||
|         self.user = create_test_admin_user() |  | ||||||
|         self.user_noemail = create_test_user(email="") |  | ||||||
|         self.stage = AuthenticatorEmailStage.objects.create( |  | ||||||
|             name="email-authenticator", |  | ||||||
|             use_global_settings=True, |  | ||||||
|             from_address="test@authentik.local", |  | ||||||
|             configure_flow=self.flow, |  | ||||||
|             token_expiry="minutes=30", |  | ||||||
|         )  # nosec |  | ||||||
|         self.binding = FlowStageBinding.objects.create(target=self.flow, stage=self.stage, order=0) |  | ||||||
|         self.device = EmailDevice.objects.create( |  | ||||||
|             user=self.user, |  | ||||||
|             stage=self.stage, |  | ||||||
|             email="test@authentik.local", |  | ||||||
|         ) |  | ||||||
|         self.client.force_login(self.user) |  | ||||||
|  |  | ||||||
|     def test_device_str(self): |  | ||||||
|         """Test string representation of device""" |  | ||||||
|         self.assertEqual(str(self.device), f"Email Device for {self.user.pk}") |  | ||||||
|         # Test unsaved device |  | ||||||
|         unsaved_device = EmailDevice( |  | ||||||
|             user=self.user, |  | ||||||
|             stage=self.stage, |  | ||||||
|             email="test@authentik.local", |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(str(unsaved_device), "New Email Device") |  | ||||||
|  |  | ||||||
|     def test_stage_str(self): |  | ||||||
|         """Test string representation of stage""" |  | ||||||
|         self.assertEqual(str(self.stage), f"Email Authenticator Stage {self.stage.name}") |  | ||||||
|  |  | ||||||
|     def test_token_lifecycle(self): |  | ||||||
|         """Test token generation, validation and expiry""" |  | ||||||
|         # Initially no token |  | ||||||
|         self.assertIsNone(self.device.token) |  | ||||||
|  |  | ||||||
|         # Generate token |  | ||||||
|         self.device.generate_token() |  | ||||||
|         token = self.device.token |  | ||||||
|         self.assertIsNotNone(token) |  | ||||||
|         self.assertIsNotNone(self.device.valid_until) |  | ||||||
|         self.assertTrue(self.device.valid_until > now()) |  | ||||||
|  |  | ||||||
|         # Verify invalid token |  | ||||||
|         self.assertFalse(self.device.verify_token("000000")) |  | ||||||
|  |  | ||||||
|         # Verify correct token (should clear token after verification) |  | ||||||
|         self.assertTrue(self.device.verify_token(token)) |  | ||||||
|         self.assertIsNone(self.device.token) |  | ||||||
|  |  | ||||||
|     def test_stage_no_prefill(self): |  | ||||||
|         """Test stage without prefilled email""" |  | ||||||
|         self.client.force_login(self.user_noemail) |  | ||||||
|         with patch( |  | ||||||
|             "authentik.stages.authenticator_email.models.AuthenticatorEmailStage.backend_class", |  | ||||||
|             PropertyMock(return_value=EmailBackend), |  | ||||||
|         ): |  | ||||||
|             response = self.client.get( |  | ||||||
|                 reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), |  | ||||||
|             ) |  | ||||||
|             self.assertStageResponse( |  | ||||||
|                 response, |  | ||||||
|                 self.flow, |  | ||||||
|                 self.user_noemail, |  | ||||||
|                 component="ak-stage-authenticator-email", |  | ||||||
|                 email_required=True, |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|     def test_stage_submit(self): |  | ||||||
|         """Test stage email submission""" |  | ||||||
|         # Initialize the flow |  | ||||||
|         response = self.client.get( |  | ||||||
|             reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), |  | ||||||
|         ) |  | ||||||
|         self.assertStageResponse( |  | ||||||
|             response, |  | ||||||
|             self.flow, |  | ||||||
|             self.user, |  | ||||||
|             component="ak-stage-authenticator-email", |  | ||||||
|             email_required=False, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         # Test email submission with locmem backend |  | ||||||
|         def mock_send_mails(stage, *messages): |  | ||||||
|             """Mock send_mails to send directly""" |  | ||||||
|             for message in messages: |  | ||||||
|                 message.send() |  | ||||||
|  |  | ||||||
|         with ( |  | ||||||
|             patch( |  | ||||||
|                 "authentik.stages.authenticator_email.models.AuthenticatorEmailStage.backend_class", |  | ||||||
|                 return_value=EmailBackend, |  | ||||||
|             ), |  | ||||||
|             patch( |  | ||||||
|                 "authentik.stages.authenticator_email.stage.send_mails", |  | ||||||
|                 side_effect=mock_send_mails, |  | ||||||
|             ), |  | ||||||
|         ): |  | ||||||
|             response = self.client.post( |  | ||||||
|                 reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), |  | ||||||
|                 data={"component": "ak-stage-authenticator-email", "email": "test@example.com"}, |  | ||||||
|             ) |  | ||||||
|             self.assertEqual(response.status_code, 200) |  | ||||||
|             self.assertEqual(len(mail.outbox), 1) |  | ||||||
|             sent_mail = mail.outbox[0] |  | ||||||
|             self.assertEqual(sent_mail.subject, self.stage.subject) |  | ||||||
|             self.assertEqual(sent_mail.to, [f"{self.user} <test@example.com>"]) |  | ||||||
|             # Get from_address from global email config to test if global settings are being used |  | ||||||
|             from_address_global = CONFIG.get("email.from") |  | ||||||
|             self.assertEqual(sent_mail.from_email, from_address_global) |  | ||||||
|  |  | ||||||
|         self.assertStageResponse( |  | ||||||
|             response, |  | ||||||
|             self.flow, |  | ||||||
|             self.user, |  | ||||||
|             component="ak-stage-authenticator-email", |  | ||||||
|             response_errors={}, |  | ||||||
|             email_required=False, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def test_email_template(self): |  | ||||||
|         """Test email template rendering""" |  | ||||||
|         self.device.generate_token() |  | ||||||
|         message = self.device._compose_email() |  | ||||||
|  |  | ||||||
|         self.assertIsInstance(message, TemplateEmailMessage) |  | ||||||
|         self.assertEqual(message.subject, self.stage.subject) |  | ||||||
|         self.assertEqual(message.to, [f"{self.user.name} <{self.device.email}>"]) |  | ||||||
|         self.assertTrue(self.device.token in message.body) |  | ||||||
|  |  | ||||||
|     def test_duplicate_email(self): |  | ||||||
|         """Test attempting to use same email twice""" |  | ||||||
|         email = "test2@authentik.local" |  | ||||||
|         # First device |  | ||||||
|         EmailDevice.objects.create( |  | ||||||
|             user=self.user, |  | ||||||
|             stage=self.stage, |  | ||||||
|             email=email, |  | ||||||
|         ) |  | ||||||
|         # Attempt to create second device with same email |  | ||||||
|         with self.assertRaises(IntegrityError): |  | ||||||
|             EmailDevice.objects.create( |  | ||||||
|                 user=self.user, |  | ||||||
|                 stage=self.stage, |  | ||||||
|                 email=email, |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|     def test_token_expiry(self): |  | ||||||
|         """Test token expiration behavior""" |  | ||||||
|         self.device.generate_token() |  | ||||||
|         token = self.device.token |  | ||||||
|         # Set token as expired |  | ||||||
|         self.device.valid_until = now() - timedelta(minutes=1) |  | ||||||
|         self.device.save() |  | ||||||
|         # Verify expired token fails |  | ||||||
|         self.assertFalse(self.device.verify_token(token)) |  | ||||||
|  |  | ||||||
|     def test_template_errors(self): |  | ||||||
|         """Test handling of template errors""" |  | ||||||
|         self.stage.template = "{% invalid template %}" |  | ||||||
|         with self.assertRaises(TemplateDoesNotExist): |  | ||||||
|             self.stage.send(self.device) |  | ||||||
|  |  | ||||||
|     def test_challenge_response_validation(self): |  | ||||||
|         """Test challenge response validation""" |  | ||||||
|         # Initialize the flow |  | ||||||
|         self.client.force_login(self.user_noemail) |  | ||||||
|         with patch( |  | ||||||
|             "authentik.stages.authenticator_email.models.AuthenticatorEmailStage.backend_class", |  | ||||||
|             PropertyMock(return_value=EmailBackend), |  | ||||||
|         ): |  | ||||||
|             response = self.client.get( |  | ||||||
|                 reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|             # Test missing code and email |  | ||||||
|             response = self.client.post( |  | ||||||
|                 reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), |  | ||||||
|                 data={"component": "ak-stage-authenticator-email"}, |  | ||||||
|             ) |  | ||||||
|             self.assertIn("email required", str(response.content)) |  | ||||||
|  |  | ||||||
|             # Test invalid code |  | ||||||
|             response = self.client.post( |  | ||||||
|                 reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), |  | ||||||
|                 data={"component": "ak-stage-authenticator-email", "code": "000000"}, |  | ||||||
|             ) |  | ||||||
|             self.assertIn("Code does not match", str(response.content)) |  | ||||||
|  |  | ||||||
|             # Test valid code |  | ||||||
|             self.client.force_login(self.user) |  | ||||||
|             response = self.client.get( |  | ||||||
|                 reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), |  | ||||||
|             ) |  | ||||||
|             device = self.device |  | ||||||
|             token = device.token |  | ||||||
|             response = self.client.post( |  | ||||||
|                 reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), |  | ||||||
|                 data={"component": "ak-stage-authenticator-email", "code": token}, |  | ||||||
|             ) |  | ||||||
|             self.assertEqual(response.status_code, 200) |  | ||||||
|             self.assertTrue(device.confirmed) |  | ||||||
|  |  | ||||||
|     def test_challenge_generation(self): |  | ||||||
|         """Test challenge generation""" |  | ||||||
|         # Test with masked email |  | ||||||
|         with patch( |  | ||||||
|             "authentik.stages.authenticator_email.models.AuthenticatorEmailStage.backend_class", |  | ||||||
|             PropertyMock(return_value=EmailBackend), |  | ||||||
|         ): |  | ||||||
|             response = self.client.get( |  | ||||||
|                 reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), |  | ||||||
|             ) |  | ||||||
|             self.assertStageResponse( |  | ||||||
|                 response, |  | ||||||
|                 self.flow, |  | ||||||
|                 self.user, |  | ||||||
|                 component="ak-stage-authenticator-email", |  | ||||||
|                 email_required=False, |  | ||||||
|             ) |  | ||||||
|             masked_email = mask_email(self.user.email) |  | ||||||
|             self.assertEqual(masked_email, response.json()["email"]) |  | ||||||
|  |  | ||||||
|             # Test without email |  | ||||||
|             self.client.force_login(self.user_noemail) |  | ||||||
|             response = self.client.get( |  | ||||||
|                 reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), |  | ||||||
|             ) |  | ||||||
|             self.assertStageResponse( |  | ||||||
|                 response, |  | ||||||
|                 self.flow, |  | ||||||
|                 self.user_noemail, |  | ||||||
|                 component="ak-stage-authenticator-email", |  | ||||||
|                 email_required=True, |  | ||||||
|             ) |  | ||||||
|             self.assertIsNone(response.json()["email"]) |  | ||||||
|  |  | ||||||
|     def test_session_management(self): |  | ||||||
|         """Test session device management""" |  | ||||||
|         # Test device creation in session |  | ||||||
|         with patch( |  | ||||||
|             "authentik.stages.authenticator_email.models.AuthenticatorEmailStage.backend_class", |  | ||||||
|             PropertyMock(return_value=EmailBackend), |  | ||||||
|         ): |  | ||||||
|             # Delete any existing devices for this test |  | ||||||
|             EmailDevice.objects.filter(user=self.user).delete() |  | ||||||
|  |  | ||||||
|             response = self.client.get( |  | ||||||
|                 reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), |  | ||||||
|             ) |  | ||||||
|             self.assertIn(SESSION_KEY_EMAIL_DEVICE, self.client.session) |  | ||||||
|             device = self.client.session[SESSION_KEY_EMAIL_DEVICE] |  | ||||||
|             self.assertIsInstance(device, EmailDevice) |  | ||||||
|             self.assertFalse(device.confirmed) |  | ||||||
|             self.assertEqual(device.user, self.user) |  | ||||||
|  |  | ||||||
|             # Test device confirmation and cleanup |  | ||||||
|             device.confirmed = True |  | ||||||
|             device.email = "new_test@authentik.local"  # Use a different email |  | ||||||
|             self.client.session[SESSION_KEY_EMAIL_DEVICE] = device |  | ||||||
|             self.client.session.save() |  | ||||||
|             response = self.client.post( |  | ||||||
|                 reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), |  | ||||||
|                 data={"component": "ak-stage-authenticator-email", "code": device.token}, |  | ||||||
|             ) |  | ||||||
|             self.assertEqual(response.status_code, 200) |  | ||||||
|             self.assertTrue(device.confirmed) |  | ||||||
|             # Get a fresh session to check if the key was removed |  | ||||||
|             session = self.client.session |  | ||||||
|             session.save() |  | ||||||
|             session.load() |  | ||||||
|             self.assertNotIn(SESSION_KEY_EMAIL_DEVICE, session) |  | ||||||
|  |  | ||||||
|     def test_model_properties_and_methods(self): |  | ||||||
|         """Test model properties""" |  | ||||||
|         device = self.device |  | ||||||
|         stage = self.stage |  | ||||||
|  |  | ||||||
|         self.assertEqual(stage.serializer, AuthenticatorEmailStageSerializer) |  | ||||||
|         self.assertIsInstance(stage.backend, EmailBackend) |  | ||||||
|         self.assertEqual(device.serializer, EmailDeviceSerializer) |  | ||||||
|  |  | ||||||
|         # Test AuthenticatorEmailStage send method |  | ||||||
|         with patch( |  | ||||||
|             "authentik.stages.authenticator_email.models.AuthenticatorEmailStage.backend_class", |  | ||||||
|             return_value=EmailBackend, |  | ||||||
|         ): |  | ||||||
|             self.device.generate_token() |  | ||||||
|             # Test EmailDevice _compose_email method |  | ||||||
|             message = self.device._compose_email() |  | ||||||
|             self.assertIsInstance(message, TemplateEmailMessage) |  | ||||||
|             self.assertEqual(message.subject, self.stage.subject) |  | ||||||
|             self.assertEqual(message.to, [f"{self.user.name} <{self.device.email}>"]) |  | ||||||
|             self.assertTrue(self.device.token in message.body) |  | ||||||
|             # Test AuthenticatorEmailStage send method |  | ||||||
|             self.stage.send(device) |  | ||||||
|  |  | ||||||
|     def test_email_tasks(self): |  | ||||||
|  |  | ||||||
|         email_send_mock = MagicMock() |  | ||||||
|         with patch( |  | ||||||
|             "authentik.stages.email.tasks.send_mails", |  | ||||||
|             email_send_mock, |  | ||||||
|         ): |  | ||||||
|             # Test AuthenticatorEmailStage send method |  | ||||||
|             self.stage.send(self.device) |  | ||||||
|             email_send_mock.assert_called_once() |  | ||||||
| @ -1,17 +0,0 @@ | |||||||
| """API URLs""" |  | ||||||
|  |  | ||||||
| from authentik.stages.authenticator_email.api import ( |  | ||||||
|     AuthenticatorEmailStageViewSet, |  | ||||||
|     EmailAdminDeviceViewSet, |  | ||||||
|     EmailDeviceViewSet, |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| api_urlpatterns = [ |  | ||||||
|     ("authenticators/email", EmailDeviceViewSet), |  | ||||||
|     ( |  | ||||||
|         "authenticators/admin/email", |  | ||||||
|         EmailAdminDeviceViewSet, |  | ||||||
|         "admin-emaildevice", |  | ||||||
|     ), |  | ||||||
|     ("stages/authenticator/email", AuthenticatorEmailStageViewSet), |  | ||||||
| ] |  | ||||||
| @ -26,13 +26,10 @@ from authentik.events.middleware import audit_ignore | |||||||
| from authentik.events.models import Event, EventAction | from authentik.events.models import Event, EventAction | ||||||
| from authentik.flows.stage import StageView | from authentik.flows.stage import StageView | ||||||
| from authentik.flows.views.executor import SESSION_KEY_APPLICATION_PRE | from authentik.flows.views.executor import SESSION_KEY_APPLICATION_PRE | ||||||
| from authentik.lib.utils.email import mask_email |  | ||||||
| from authentik.lib.utils.time import timedelta_from_string |  | ||||||
| from authentik.root.middleware import ClientIPMiddleware | from authentik.root.middleware import ClientIPMiddleware | ||||||
| from authentik.stages.authenticator import match_token | from authentik.stages.authenticator import match_token | ||||||
| from authentik.stages.authenticator.models import Device | from authentik.stages.authenticator.models import Device | ||||||
| from authentik.stages.authenticator_duo.models import AuthenticatorDuoStage, DuoDevice | from authentik.stages.authenticator_duo.models import AuthenticatorDuoStage, DuoDevice | ||||||
| from authentik.stages.authenticator_email.models import EmailDevice |  | ||||||
| from authentik.stages.authenticator_sms.models import SMSDevice | from authentik.stages.authenticator_sms.models import SMSDevice | ||||||
| from authentik.stages.authenticator_validate.models import AuthenticatorValidateStage, DeviceClasses | from authentik.stages.authenticator_validate.models import AuthenticatorValidateStage, DeviceClasses | ||||||
| from authentik.stages.authenticator_webauthn.models import UserVerification, WebAuthnDevice | from authentik.stages.authenticator_webauthn.models import UserVerification, WebAuthnDevice | ||||||
| @ -57,8 +54,6 @@ def get_challenge_for_device( | |||||||
|     """Generate challenge for a single device""" |     """Generate challenge for a single device""" | ||||||
|     if isinstance(device, WebAuthnDevice): |     if isinstance(device, WebAuthnDevice): | ||||||
|         return get_webauthn_challenge(request, stage, device) |         return get_webauthn_challenge(request, stage, device) | ||||||
|     if isinstance(device, EmailDevice): |  | ||||||
|         return {"email": mask_email(device.email)} |  | ||||||
|     # Code-based challenges have no hints |     # Code-based challenges have no hints | ||||||
|     return {} |     return {} | ||||||
|  |  | ||||||
| @ -108,8 +103,6 @@ def select_challenge(request: HttpRequest, device: Device): | |||||||
|     """Callback when the user selected a challenge in the frontend.""" |     """Callback when the user selected a challenge in the frontend.""" | ||||||
|     if isinstance(device, SMSDevice): |     if isinstance(device, SMSDevice): | ||||||
|         select_challenge_sms(request, device) |         select_challenge_sms(request, device) | ||||||
|     elif isinstance(device, EmailDevice): |  | ||||||
|         select_challenge_email(request, device) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def select_challenge_sms(request: HttpRequest, device: SMSDevice): | def select_challenge_sms(request: HttpRequest, device: SMSDevice): | ||||||
| @ -118,13 +111,6 @@ def select_challenge_sms(request: HttpRequest, device: SMSDevice): | |||||||
|     device.stage.send(device.token, device) |     device.stage.send(device.token, device) | ||||||
|  |  | ||||||
|  |  | ||||||
| def select_challenge_email(request: HttpRequest, device: EmailDevice): |  | ||||||
|     """Send Email""" |  | ||||||
|     valid_secs: int = timedelta_from_string(device.stage.token_expiry).total_seconds() |  | ||||||
|     device.generate_token(valid_secs=valid_secs) |  | ||||||
|     device.stage.send(device) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def validate_challenge_code(code: str, stage_view: StageView, user: User) -> Device: | def validate_challenge_code(code: str, stage_view: StageView, user: User) -> Device: | ||||||
|     """Validate code-based challenges. We test against every device, on purpose, as |     """Validate code-based challenges. We test against every device, on purpose, as | ||||||
|     the user mustn't choose between totp and static devices.""" |     the user mustn't choose between totp and static devices.""" | ||||||
|  | |||||||
| @ -1,37 +0,0 @@ | |||||||
| # Generated by Django 5.0.10 on 2025-01-16 02:48 |  | ||||||
|  |  | ||||||
| import authentik.stages.authenticator_validate.models |  | ||||||
| import django.contrib.postgres.fields |  | ||||||
| from django.db import migrations, models |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Migration(migrations.Migration): |  | ||||||
|  |  | ||||||
|     dependencies = [ |  | ||||||
|         ( |  | ||||||
|             "authentik_stages_authenticator_validate", |  | ||||||
|             "0013_authenticatorvalidatestage_webauthn_allowed_device_types", |  | ||||||
|         ), |  | ||||||
|     ] |  | ||||||
|  |  | ||||||
|     operations = [ |  | ||||||
|         migrations.AlterField( |  | ||||||
|             model_name="authenticatorvalidatestage", |  | ||||||
|             name="device_classes", |  | ||||||
|             field=django.contrib.postgres.fields.ArrayField( |  | ||||||
|                 base_field=models.TextField( |  | ||||||
|                     choices=[ |  | ||||||
|                         ("static", "Static"), |  | ||||||
|                         ("totp", "TOTP"), |  | ||||||
|                         ("webauthn", "WebAuthn"), |  | ||||||
|                         ("duo", "Duo"), |  | ||||||
|                         ("sms", "SMS"), |  | ||||||
|                         ("email", "Email"), |  | ||||||
|                     ] |  | ||||||
|                 ), |  | ||||||
|                 default=authentik.stages.authenticator_validate.models.default_device_classes, |  | ||||||
|                 help_text="Device classes which can be used to authenticate", |  | ||||||
|                 size=None, |  | ||||||
|             ), |  | ||||||
|         ), |  | ||||||
|     ] |  | ||||||
| @ -20,7 +20,6 @@ class DeviceClasses(models.TextChoices): | |||||||
|     WEBAUTHN = "webauthn", _("WebAuthn") |     WEBAUTHN = "webauthn", _("WebAuthn") | ||||||
|     DUO = "duo", _("Duo") |     DUO = "duo", _("Duo") | ||||||
|     SMS = "sms", _("SMS") |     SMS = "sms", _("SMS") | ||||||
|     EMAIL = "email", _("Email") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def default_device_classes() -> list: | def default_device_classes() -> list: | ||||||
| @ -31,7 +30,6 @@ def default_device_classes() -> list: | |||||||
|         DeviceClasses.WEBAUTHN, |         DeviceClasses.WEBAUTHN, | ||||||
|         DeviceClasses.DUO, |         DeviceClasses.DUO, | ||||||
|         DeviceClasses.SMS, |         DeviceClasses.SMS, | ||||||
|         DeviceClasses.EMAIL, |  | ||||||
|     ] |     ] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| @ -23,7 +23,6 @@ from authentik.flows.stage import ChallengeStageView | |||||||
| from authentik.lib.utils.time import timedelta_from_string | from authentik.lib.utils.time import timedelta_from_string | ||||||
| from authentik.stages.authenticator import devices_for_user | from authentik.stages.authenticator import devices_for_user | ||||||
| from authentik.stages.authenticator.models import Device | from authentik.stages.authenticator.models import Device | ||||||
| from authentik.stages.authenticator_email.models import EmailDevice |  | ||||||
| from authentik.stages.authenticator_sms.models import SMSDevice | from authentik.stages.authenticator_sms.models import SMSDevice | ||||||
| from authentik.stages.authenticator_validate.challenge import ( | from authentik.stages.authenticator_validate.challenge import ( | ||||||
|     DeviceChallenge, |     DeviceChallenge, | ||||||
| @ -85,9 +84,7 @@ class AuthenticatorValidationChallengeResponse(ChallengeResponse): | |||||||
|  |  | ||||||
|     def validate_code(self, code: str) -> str: |     def validate_code(self, code: str) -> str: | ||||||
|         """Validate code-based response, raise error if code isn't allowed""" |         """Validate code-based response, raise error if code isn't allowed""" | ||||||
|         self._challenge_allowed( |         self._challenge_allowed([DeviceClasses.TOTP, DeviceClasses.STATIC, DeviceClasses.SMS]) | ||||||
|             [DeviceClasses.TOTP, DeviceClasses.STATIC, DeviceClasses.SMS, DeviceClasses.EMAIL] |  | ||||||
|         ) |  | ||||||
|         self.device = validate_challenge_code(code, self.stage, self.stage.get_pending_user()) |         self.device = validate_challenge_code(code, self.stage, self.stage.get_pending_user()) | ||||||
|         return code |         return code | ||||||
|  |  | ||||||
| @ -120,17 +117,12 @@ class AuthenticatorValidationChallengeResponse(ChallengeResponse): | |||||||
|         if not allowed: |         if not allowed: | ||||||
|             raise ValidationError("invalid challenge selected") |             raise ValidationError("invalid challenge selected") | ||||||
|  |  | ||||||
|         device_class = challenge.get("device_class", "") |         if challenge.get("device_class", "") != "sms": | ||||||
|         if device_class == "sms": |             return challenge | ||||||
|         devices = SMSDevice.objects.filter(pk=int(challenge.get("device_uid", "0"))) |         devices = SMSDevice.objects.filter(pk=int(challenge.get("device_uid", "0"))) | ||||||
|         if not devices.exists(): |         if not devices.exists(): | ||||||
|             raise ValidationError("invalid challenge selected") |             raise ValidationError("invalid challenge selected") | ||||||
|         select_challenge(self.stage.request, devices.first()) |         select_challenge(self.stage.request, devices.first()) | ||||||
|         elif device_class == "email": |  | ||||||
|             devices = EmailDevice.objects.filter(pk=int(challenge.get("device_uid", "0"))) |  | ||||||
|             if not devices.exists(): |  | ||||||
|                 raise ValidationError("invalid challenge selected") |  | ||||||
|             select_challenge(self.stage.request, devices.first()) |  | ||||||
|         return challenge |         return challenge | ||||||
|  |  | ||||||
|     def validate_selected_stage(self, stage_pk: str) -> str: |     def validate_selected_stage(self, stage_pk: str) -> str: | ||||||
|  | |||||||
| @ -1,183 +0,0 @@ | |||||||
| """Test validator stage for Email devices""" |  | ||||||
|  |  | ||||||
| from django.test.client import RequestFactory |  | ||||||
| from django.urls.base import reverse |  | ||||||
|  |  | ||||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_flow |  | ||||||
| from authentik.flows.models import FlowStageBinding, NotConfiguredAction |  | ||||||
| from authentik.flows.tests import FlowTestCase |  | ||||||
| from authentik.lib.generators import generate_id |  | ||||||
| from authentik.lib.utils.email import mask_email |  | ||||||
| from authentik.stages.authenticator_email.models import AuthenticatorEmailStage, EmailDevice |  | ||||||
| from authentik.stages.authenticator_validate.models import AuthenticatorValidateStage, DeviceClasses |  | ||||||
| from authentik.stages.identification.models import IdentificationStage, UserFields |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class AuthenticatorValidateStageEmailTests(FlowTestCase): |  | ||||||
|     """Test validator stage for Email devices""" |  | ||||||
|  |  | ||||||
|     def setUp(self) -> None: |  | ||||||
|         self.user = create_test_admin_user() |  | ||||||
|         self.request_factory = RequestFactory() |  | ||||||
|         # Create email authenticator stage |  | ||||||
|         self.stage = AuthenticatorEmailStage.objects.create( |  | ||||||
|             name="email-authenticator", |  | ||||||
|             use_global_settings=True, |  | ||||||
|             from_address="test@authentik.local", |  | ||||||
|         ) |  | ||||||
|         # Create identification stage |  | ||||||
|         self.ident_stage = IdentificationStage.objects.create( |  | ||||||
|             name=generate_id(), |  | ||||||
|             user_fields=[UserFields.USERNAME], |  | ||||||
|         ) |  | ||||||
|         # Create validation stage |  | ||||||
|         self.validate_stage = AuthenticatorValidateStage.objects.create( |  | ||||||
|             name=generate_id(), |  | ||||||
|             device_classes=[DeviceClasses.EMAIL], |  | ||||||
|         ) |  | ||||||
|         # Create flow with both stages |  | ||||||
|         self.flow = create_test_flow() |  | ||||||
|         FlowStageBinding.objects.create(target=self.flow, stage=self.ident_stage, order=0) |  | ||||||
|         FlowStageBinding.objects.create(target=self.flow, stage=self.validate_stage, order=1) |  | ||||||
|  |  | ||||||
|     def _identify_user(self): |  | ||||||
|         """Helper to identify user in flow""" |  | ||||||
|         response = self.client.post( |  | ||||||
|             reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), |  | ||||||
|             {"uid_field": self.user.username}, |  | ||||||
|             follow=True, |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 200) |  | ||||||
|         return response |  | ||||||
|  |  | ||||||
|     def _send_challenge(self, device): |  | ||||||
|         """Helper to send challenge for device""" |  | ||||||
|         response = self.client.post( |  | ||||||
|             reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), |  | ||||||
|             { |  | ||||||
|                 "component": "ak-stage-authenticator-validate", |  | ||||||
|                 "selected_challenge": { |  | ||||||
|                     "device_class": "email", |  | ||||||
|                     "device_uid": str(device.pk), |  | ||||||
|                     "challenge": {}, |  | ||||||
|                     "last_used": device.last_used.isoformat() if device.last_used else None, |  | ||||||
|                 }, |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 200) |  | ||||||
|         return response |  | ||||||
|  |  | ||||||
|     def test_happy_path(self): |  | ||||||
|         """Test validator stage with valid code""" |  | ||||||
|         # Create a device for our user |  | ||||||
|         device = EmailDevice.objects.create( |  | ||||||
|             user=self.user, |  | ||||||
|             confirmed=True, |  | ||||||
|             stage=self.stage, |  | ||||||
|             email="xx@0.co", |  | ||||||
|         )  # Short email for testing purposes |  | ||||||
|  |  | ||||||
|         # First identify the user |  | ||||||
|         self._identify_user() |  | ||||||
|  |  | ||||||
|         # Send the challenge |  | ||||||
|         response = self._send_challenge(device) |  | ||||||
|         response_data = self.assertStageResponse( |  | ||||||
|             response, |  | ||||||
|             flow=self.flow, |  | ||||||
|             component="ak-stage-authenticator-validate", |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         # Get the device challenge from the response and verify it matches |  | ||||||
|         device_challenge = response_data["device_challenges"][0] |  | ||||||
|         self.assertEqual(device_challenge["device_class"], "email") |  | ||||||
|         self.assertEqual(device_challenge["device_uid"], str(device.pk)) |  | ||||||
|         self.assertEqual(device_challenge["challenge"], {"email": mask_email(device.email)}) |  | ||||||
|  |  | ||||||
|         # Generate a token for the device |  | ||||||
|         device.generate_token() |  | ||||||
|  |  | ||||||
|         # Submit the valid code |  | ||||||
|         response = self.client.post( |  | ||||||
|             reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), |  | ||||||
|             {"component": "ak-stage-authenticator-validate", "code": device.token}, |  | ||||||
|         ) |  | ||||||
|         # Should redirect to root since this is the last stage |  | ||||||
|         self.assertStageRedirects(response, "/") |  | ||||||
|  |  | ||||||
|     def test_no_device(self): |  | ||||||
|         """Test validator stage without configured device""" |  | ||||||
|         configuration_stage = AuthenticatorEmailStage.objects.create( |  | ||||||
|             name=generate_id(), |  | ||||||
|             use_global_settings=True, |  | ||||||
|             from_address="test@authentik.local", |  | ||||||
|         ) |  | ||||||
|         stage = AuthenticatorValidateStage.objects.create( |  | ||||||
|             name=generate_id(), |  | ||||||
|             not_configured_action=NotConfiguredAction.CONFIGURE, |  | ||||||
|             device_classes=[DeviceClasses.EMAIL], |  | ||||||
|         ) |  | ||||||
|         stage.configuration_stages.set([configuration_stage]) |  | ||||||
|         flow = create_test_flow() |  | ||||||
|         FlowStageBinding.objects.create(target=flow, stage=stage, order=2) |  | ||||||
|  |  | ||||||
|         response = self.client.post( |  | ||||||
|             reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), |  | ||||||
|             {"component": "ak-stage-authenticator-validate"}, |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 200) |  | ||||||
|         response_data = self.assertStageResponse( |  | ||||||
|             response, |  | ||||||
|             flow=flow, |  | ||||||
|             component="ak-stage-authenticator-validate", |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response_data["configuration_stages"], []) |  | ||||||
|         self.assertEqual(response_data["device_challenges"], []) |  | ||||||
|         self.assertEqual( |  | ||||||
|             response_data["response_errors"], |  | ||||||
|             {"non_field_errors": [{"code": "invalid", "string": "Empty response"}]}, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def test_invalid_code(self): |  | ||||||
|         """Test validator stage with invalid code""" |  | ||||||
|         # Create a device for our user |  | ||||||
|         device = EmailDevice.objects.create( |  | ||||||
|             user=self.user, |  | ||||||
|             confirmed=True, |  | ||||||
|             stage=self.stage, |  | ||||||
|             email="test@authentik.local", |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         # First identify the user |  | ||||||
|         self._identify_user() |  | ||||||
|  |  | ||||||
|         # Send the challenge |  | ||||||
|         self._send_challenge(device) |  | ||||||
|  |  | ||||||
|         # Generate a token for the device |  | ||||||
|         device.generate_token() |  | ||||||
|  |  | ||||||
|         # Try invalid code and verify error message |  | ||||||
|         response = self.client.post( |  | ||||||
|             reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), |  | ||||||
|             {"component": "ak-stage-authenticator-validate", "code": "invalid"}, |  | ||||||
|         ) |  | ||||||
|         response_data = self.assertStageResponse( |  | ||||||
|             response, |  | ||||||
|             flow=self.flow, |  | ||||||
|             component="ak-stage-authenticator-validate", |  | ||||||
|         ) |  | ||||||
|         self.assertEqual( |  | ||||||
|             response_data["response_errors"], |  | ||||||
|             { |  | ||||||
|                 "code": [ |  | ||||||
|                     { |  | ||||||
|                         "code": "invalid", |  | ||||||
|                         "string": ( |  | ||||||
|                             "Invalid Token. Please ensure the time on your device " |  | ||||||
|                             "is accurate and try again." |  | ||||||
|                         ), |  | ||||||
|                     } |  | ||||||
|                 ], |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
| @ -12,31 +12,18 @@ from structlog.stdlib import get_logger | |||||||
|  |  | ||||||
| from authentik.events.models import Event, EventAction, TaskStatus | from authentik.events.models import Event, EventAction, TaskStatus | ||||||
| from authentik.events.system_tasks import SystemTask | from authentik.events.system_tasks import SystemTask | ||||||
| from authentik.lib.utils.reflection import class_to_path, path_to_class |  | ||||||
| from authentik.root.celery import CELERY_APP | from authentik.root.celery import CELERY_APP | ||||||
| from authentik.stages.authenticator_email.models import AuthenticatorEmailStage |  | ||||||
| from authentik.stages.email.models import EmailStage | from authentik.stages.email.models import EmailStage | ||||||
| from authentik.stages.email.utils import logo_data | from authentik.stages.email.utils import logo_data | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  |  | ||||||
| def send_mails( | def send_mails(stage: EmailStage, *messages: list[EmailMultiAlternatives]): | ||||||
|     stage: EmailStage | AuthenticatorEmailStage, *messages: list[EmailMultiAlternatives] |     """Wrapper to convert EmailMessage to dict and send it from worker""" | ||||||
| ): |  | ||||||
|     """Wrapper to convert EmailMessage to dict and send it from worker |  | ||||||
|  |  | ||||||
|     Args: |  | ||||||
|         stage: Either an EmailStage or AuthenticatorEmailStage instance |  | ||||||
|         messages: List of email messages to send |  | ||||||
|     Returns: |  | ||||||
|         Celery group promise for the email sending tasks |  | ||||||
|     """ |  | ||||||
|     tasks = [] |     tasks = [] | ||||||
|     # Use the class path instead of the class itself for serialization |  | ||||||
|     stage_class_path = class_to_path(stage.__class__) |  | ||||||
|     for message in messages: |     for message in messages: | ||||||
|         tasks.append(send_mail.s(message.__dict__, stage_class_path, str(stage.pk))) |         tasks.append(send_mail.s(message.__dict__, str(stage.pk))) | ||||||
|     lazy_group = group(*tasks) |     lazy_group = group(*tasks) | ||||||
|     promise = lazy_group() |     promise = lazy_group() | ||||||
|     return promise |     return promise | ||||||
| @ -60,29 +47,23 @@ def get_email_body(email: EmailMultiAlternatives) -> str: | |||||||
|     retry_backoff=True, |     retry_backoff=True, | ||||||
|     base=SystemTask, |     base=SystemTask, | ||||||
| ) | ) | ||||||
| def send_mail( | def send_mail(self: SystemTask, message: dict[Any, Any], email_stage_pk: str | None = None): | ||||||
|     self: SystemTask, |  | ||||||
|     message: dict[Any, Any], |  | ||||||
|     stage_class_path: str | None = None, |  | ||||||
|     email_stage_pk: str | None = None, |  | ||||||
| ): |  | ||||||
|     """Send Email for Email Stage. Retries are scheduled automatically.""" |     """Send Email for Email Stage. Retries are scheduled automatically.""" | ||||||
|     self.save_on_success = False |     self.save_on_success = False | ||||||
|     message_id = make_msgid(domain=DNS_NAME) |     message_id = make_msgid(domain=DNS_NAME) | ||||||
|     self.set_uid(slugify(message_id.replace(".", "_").replace("@", "_"))) |     self.set_uid(slugify(message_id.replace(".", "_").replace("@", "_"))) | ||||||
|     try: |     try: | ||||||
|         if not stage_class_path or not email_stage_pk: |         if not email_stage_pk: | ||||||
|             stage = EmailStage(use_global_settings=True) |             stage: EmailStage = EmailStage(use_global_settings=True) | ||||||
|         else: |         else: | ||||||
|             stage_class = path_to_class(stage_class_path) |             stages = EmailStage.objects.filter(pk=email_stage_pk) | ||||||
|             stages = stage_class.objects.filter(pk=email_stage_pk) |  | ||||||
|             if not stages.exists(): |             if not stages.exists(): | ||||||
|                 self.set_status( |                 self.set_status( | ||||||
|                     TaskStatus.WARNING, |                     TaskStatus.WARNING, | ||||||
|                     "Email stage does not exist anymore. Discarding message.", |                     "Email stage does not exist anymore. Discarding message.", | ||||||
|                 ) |                 ) | ||||||
|                 return |                 return | ||||||
|             stage: EmailStage | AuthenticatorEmailStage = stages.first() |             stage: EmailStage = stages.first() | ||||||
|         try: |         try: | ||||||
|             backend = stage.backend |             backend = stage.backend | ||||||
|         except ValueError as exc: |         except ValueError as exc: | ||||||
| @ -104,13 +85,6 @@ def send_mail( | |||||||
|         # can't be converted to json) |         # can't be converted to json) | ||||||
|         message_object.attach(logo_data()) |         message_object.attach(logo_data()) | ||||||
|  |  | ||||||
|         if ( |  | ||||||
|             message_object.to |  | ||||||
|             and isinstance(message_object.to[0], str) |  | ||||||
|             and "=?utf-8?" in message_object.to[0] |  | ||||||
|         ): |  | ||||||
|             message_object.to = [message_object.to[0].split("<")[-1].replace(">", "")] |  | ||||||
|  |  | ||||||
|         LOGGER.debug("Sending mail", to=message_object.to) |         LOGGER.debug("Sending mail", to=message_object.to) | ||||||
|         backend.send_messages([message_object]) |         backend.send_messages([message_object]) | ||||||
|         Event.new( |         Event.new( | ||||||
|  | |||||||
| @ -8,7 +8,7 @@ from django.core.mail.backends.locmem import EmailBackend | |||||||
| from django.urls import reverse | from django.urls import reverse | ||||||
|  |  | ||||||
| from authentik.core.models import User | from authentik.core.models import User | ||||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_flow, create_test_user | from authentik.core.tests.utils import create_test_admin_user, create_test_flow | ||||||
| from authentik.events.models import Event, EventAction | from authentik.events.models import Event, EventAction | ||||||
| from authentik.flows.markers import StageMarker | from authentik.flows.markers import StageMarker | ||||||
| from authentik.flows.models import FlowDesignation, FlowStageBinding | from authentik.flows.models import FlowDesignation, FlowStageBinding | ||||||
| @ -67,67 +67,6 @@ class TestEmailStageSending(FlowTestCase): | |||||||
|             self.assertEqual(event.context["to_email"], [f"{self.user.name} <{self.user.email}>"]) |             self.assertEqual(event.context["to_email"], [f"{self.user.name} <{self.user.email}>"]) | ||||||
|             self.assertEqual(event.context["from_email"], "system@authentik.local") |             self.assertEqual(event.context["from_email"], "system@authentik.local") | ||||||
|  |  | ||||||
|     def test_newlines_long_name(self): |  | ||||||
|         """Test with pending user""" |  | ||||||
|         plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) |  | ||||||
|         long_user = create_test_user() |  | ||||||
|         long_user.name = "Test User\r\n Many Words\r\n" |  | ||||||
|         long_user.save() |  | ||||||
|         plan.context[PLAN_CONTEXT_PENDING_USER] = long_user |  | ||||||
|         session = self.client.session |  | ||||||
|         session[SESSION_KEY_PLAN] = plan |  | ||||||
|         session.save() |  | ||||||
|         Event.objects.filter(action=EventAction.EMAIL_SENT).delete() |  | ||||||
|  |  | ||||||
|         url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}) |  | ||||||
|         with patch( |  | ||||||
|             "authentik.stages.email.models.EmailStage.backend_class", |  | ||||||
|             PropertyMock(return_value=EmailBackend), |  | ||||||
|         ): |  | ||||||
|             response = self.client.post(url) |  | ||||||
|             self.assertEqual(response.status_code, 200) |  | ||||||
|             self.assertStageResponse( |  | ||||||
|                 response, |  | ||||||
|                 self.flow, |  | ||||||
|                 response_errors={ |  | ||||||
|                     "non_field_errors": [{"string": "email-sent", "code": "email-sent"}] |  | ||||||
|                 }, |  | ||||||
|             ) |  | ||||||
|             self.assertEqual(len(mail.outbox), 1) |  | ||||||
|             self.assertEqual(mail.outbox[0].subject, "authentik") |  | ||||||
|             self.assertEqual(mail.outbox[0].to, [f"Test User   Many Words   <{long_user.email}>"]) |  | ||||||
|  |  | ||||||
|     def test_utf8_name(self): |  | ||||||
|         """Test with pending user""" |  | ||||||
|         plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) |  | ||||||
|         utf8_user = create_test_user() |  | ||||||
|         utf8_user.name = "Cirilo ЉМНЊ el cirilico И̂ӢЙӤ " |  | ||||||
|         utf8_user.email = "cyrillic@authentik.local" |  | ||||||
|         utf8_user.save() |  | ||||||
|         plan.context[PLAN_CONTEXT_PENDING_USER] = utf8_user |  | ||||||
|         session = self.client.session |  | ||||||
|         session[SESSION_KEY_PLAN] = plan |  | ||||||
|         session.save() |  | ||||||
|         Event.objects.filter(action=EventAction.EMAIL_SENT).delete() |  | ||||||
|  |  | ||||||
|         url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}) |  | ||||||
|         with patch( |  | ||||||
|             "authentik.stages.email.models.EmailStage.backend_class", |  | ||||||
|             PropertyMock(return_value=EmailBackend), |  | ||||||
|         ): |  | ||||||
|             response = self.client.post(url) |  | ||||||
|             self.assertEqual(response.status_code, 200) |  | ||||||
|             self.assertStageResponse( |  | ||||||
|                 response, |  | ||||||
|                 self.flow, |  | ||||||
|                 response_errors={ |  | ||||||
|                     "non_field_errors": [{"string": "email-sent", "code": "email-sent"}] |  | ||||||
|                 }, |  | ||||||
|             ) |  | ||||||
|             self.assertEqual(len(mail.outbox), 1) |  | ||||||
|             self.assertEqual(mail.outbox[0].subject, "authentik") |  | ||||||
|             self.assertEqual(mail.outbox[0].to, [f"{utf8_user.email}"]) |  | ||||||
|  |  | ||||||
|     def test_pending_fake_user(self): |     def test_pending_fake_user(self): | ||||||
|         """Test with pending (fake) user""" |         """Test with pending (fake) user""" | ||||||
|         self.flow.designation = FlowDesignation.RECOVERY |         self.flow.designation = FlowDesignation.RECOVERY | ||||||
|  | |||||||
| @ -1,58 +0,0 @@ | |||||||
| """Test email stage tasks""" |  | ||||||
|  |  | ||||||
| from unittest.mock import patch |  | ||||||
|  |  | ||||||
| from django.core.mail import EmailMultiAlternatives |  | ||||||
| from django.test import TestCase |  | ||||||
|  |  | ||||||
| from authentik.core.tests.utils import create_test_admin_user |  | ||||||
| from authentik.lib.utils.reflection import class_to_path |  | ||||||
| from authentik.stages.authenticator_email.models import AuthenticatorEmailStage |  | ||||||
| from authentik.stages.email.models import EmailStage |  | ||||||
| from authentik.stages.email.tasks import get_email_body, send_mails |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestEmailTasks(TestCase): |  | ||||||
|     """Test email stage tasks""" |  | ||||||
|  |  | ||||||
|     def setUp(self): |  | ||||||
|         self.user = create_test_admin_user() |  | ||||||
|         self.stage = EmailStage.objects.create( |  | ||||||
|             name="test-email", |  | ||||||
|             use_global_settings=True, |  | ||||||
|         ) |  | ||||||
|         self.auth_stage = AuthenticatorEmailStage.objects.create( |  | ||||||
|             name="test-auth-email", |  | ||||||
|             use_global_settings=True, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def test_get_email_body_html(self): |  | ||||||
|         """Test get_email_body with HTML alternative""" |  | ||||||
|         message = EmailMultiAlternatives() |  | ||||||
|         message.body = "plain text" |  | ||||||
|         message.attach_alternative("<p>html content</p>", "text/html") |  | ||||||
|         self.assertEqual(get_email_body(message), "<p>html content</p>") |  | ||||||
|  |  | ||||||
|     def test_get_email_body_plain(self): |  | ||||||
|         """Test get_email_body with plain text only""" |  | ||||||
|         message = EmailMultiAlternatives() |  | ||||||
|         message.body = "plain text" |  | ||||||
|         self.assertEqual(get_email_body(message), "plain text") |  | ||||||
|  |  | ||||||
|     def test_send_mails_email_stage(self): |  | ||||||
|         """Test send_mails with EmailStage""" |  | ||||||
|         message = EmailMultiAlternatives() |  | ||||||
|         with patch("authentik.stages.email.tasks.send_mail") as mock_send: |  | ||||||
|             send_mails(self.stage, message) |  | ||||||
|             mock_send.s.assert_called_once_with( |  | ||||||
|                 message.__dict__, class_to_path(EmailStage), str(self.stage.pk) |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|     def test_send_mails_authenticator_stage(self): |  | ||||||
|         """Test send_mails with AuthenticatorEmailStage""" |  | ||||||
|         message = EmailMultiAlternatives() |  | ||||||
|         with patch("authentik.stages.email.tasks.send_mail") as mock_send: |  | ||||||
|             send_mails(self.auth_stage, message) |  | ||||||
|             mock_send.s.assert_called_once_with( |  | ||||||
|                 message.__dict__, class_to_path(AuthenticatorEmailStage), str(self.auth_stage.pk) |  | ||||||
|             ) |  | ||||||
| @ -32,14 +32,7 @@ class TemplateEmailMessage(EmailMultiAlternatives): | |||||||
|         sanitized_to = [] |         sanitized_to = [] | ||||||
|         # Ensure that all recipients are valid |         # Ensure that all recipients are valid | ||||||
|         for recipient_name, recipient_email in to: |         for recipient_name, recipient_email in to: | ||||||
|             # Remove any newline characters from name and email before sanitizing |             sanitized_to.append(sanitize_address((recipient_name, recipient_email), "utf-8")) | ||||||
|             clean_name = ( |  | ||||||
|                 recipient_name.replace("\n", " ").replace("\r", " ") if recipient_name else "" |  | ||||||
|             ) |  | ||||||
|             clean_email = ( |  | ||||||
|                 recipient_email.replace("\n", "").replace("\r", "") if recipient_email else "" |  | ||||||
|             ) |  | ||||||
|             sanitized_to.append(sanitize_address((clean_name, clean_email), "utf-8")) |  | ||||||
|         super().__init__(to=sanitized_to, **kwargs) |         super().__init__(to=sanitized_to, **kwargs) | ||||||
|         if not template_name: |         if not template_name: | ||||||
|             return |             return | ||||||
|  | |||||||
| @ -1,30 +0,0 @@ | |||||||
| version: 1 |  | ||||||
| metadata: |  | ||||||
|   labels: |  | ||||||
|     blueprints.goauthentik.io/instantiate: "false" |  | ||||||
|   name: Example - Email MFA setup flow |  | ||||||
| entries: |  | ||||||
| - attrs: |  | ||||||
|     designation: stage_configuration |  | ||||||
|     name: Default Email Authenticator Flow |  | ||||||
|     title: Setup Email Two-Factor Authentication |  | ||||||
|     authentication: require_authenticated |  | ||||||
|   identifiers: |  | ||||||
|     slug: default-authenticator-email-setup |  | ||||||
|   model: authentik_flows.flow |  | ||||||
|   id: flow |  | ||||||
| - attrs: |  | ||||||
|     configure_flow: !KeyOf flow |  | ||||||
|     friendly_name: Email Authenticator |  | ||||||
|     use_global_settings: true |  | ||||||
|     token_expiry: minutes=30 |  | ||||||
|     subject: authentik Sign-in code |  | ||||||
|   identifiers: |  | ||||||
|     name: default-authenticator-email-setup |  | ||||||
|   id: default-authenticator-email-setup |  | ||||||
|   model: authentik_stages_authenticator_email.authenticatoremailstage |  | ||||||
| - identifiers: |  | ||||||
|     order: 0 |  | ||||||
|     stage: !KeyOf default-authenticator-email-setup |  | ||||||
|     target: !KeyOf flow |  | ||||||
|   model: authentik_flows.flowstagebinding |  | ||||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -10,7 +10,6 @@ import ( | |||||||
|  |  | ||||||
| 	"goauthentik.io/internal/common" | 	"goauthentik.io/internal/common" | ||||||
| 	"goauthentik.io/internal/config" | 	"goauthentik.io/internal/config" | ||||||
| 	"goauthentik.io/internal/constants" |  | ||||||
| 	"goauthentik.io/internal/debug" | 	"goauthentik.io/internal/debug" | ||||||
| 	"goauthentik.io/internal/outpost/ak" | 	"goauthentik.io/internal/outpost/ak" | ||||||
| 	"goauthentik.io/internal/outpost/ak/healthcheck" | 	"goauthentik.io/internal/outpost/ak/healthcheck" | ||||||
| @ -26,7 +25,6 @@ Required environment variables: | |||||||
|  |  | ||||||
| var rootCmd = &cobra.Command{ | var rootCmd = &cobra.Command{ | ||||||
| 	Long: helpMessage, | 	Long: helpMessage, | ||||||
| 	Version: constants.FullVersion(), |  | ||||||
| 	PersistentPreRun: func(cmd *cobra.Command, args []string) { | 	PersistentPreRun: func(cmd *cobra.Command, args []string) { | ||||||
| 		log.SetLevel(log.DebugLevel) | 		log.SetLevel(log.DebugLevel) | ||||||
| 		log.SetFormatter(&log.JSONFormatter{ | 		log.SetFormatter(&log.JSONFormatter{ | ||||||
|  | |||||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user
	