root: reformat to 100 line width
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
		| @ -23,9 +23,7 @@ def get_events_per_1h(**filter_kwargs) -> list[dict[str, int]]: | |||||||
|     date_from = now() - timedelta(days=1) |     date_from = now() - timedelta(days=1) | ||||||
|     result = ( |     result = ( | ||||||
|         Event.objects.filter(created__gte=date_from, **filter_kwargs) |         Event.objects.filter(created__gte=date_from, **filter_kwargs) | ||||||
|         .annotate( |         .annotate(age=ExpressionWrapper(now() - F("created"), output_field=DurationField())) | ||||||
|             age=ExpressionWrapper(now() - F("created"), output_field=DurationField()) |  | ||||||
|         ) |  | ||||||
|         .annotate(age_hours=ExtractHour("age")) |         .annotate(age_hours=ExtractHour("age")) | ||||||
|         .values("age_hours") |         .values("age_hours") | ||||||
|         .annotate(count=Count("pk")) |         .annotate(count=Count("pk")) | ||||||
| @ -37,8 +35,7 @@ def get_events_per_1h(**filter_kwargs) -> list[dict[str, int]]: | |||||||
|     for hour in range(0, -24, -1): |     for hour in range(0, -24, -1): | ||||||
|         results.append( |         results.append( | ||||||
|             { |             { | ||||||
|                 "x_cord": time.mktime((_now + timedelta(hours=hour)).timetuple()) |                 "x_cord": time.mktime((_now + timedelta(hours=hour)).timetuple()) * 1000, | ||||||
|                 * 1000, |  | ||||||
|                 "y_cord": data[hour * -1], |                 "y_cord": data[hour * -1], | ||||||
|             } |             } | ||||||
|         ) |         ) | ||||||
|  | |||||||
| @ -61,9 +61,7 @@ class SystemSerializer(PassiveSerializer): | |||||||
|         return { |         return { | ||||||
|             "python_version": python_version, |             "python_version": python_version, | ||||||
|             "gunicorn_version": ".".join(str(x) for x in gunicorn_version), |             "gunicorn_version": ".".join(str(x) for x in gunicorn_version), | ||||||
|             "environment": "kubernetes" |             "environment": "kubernetes" if SERVICE_HOST_ENV_NAME in os.environ else "compose", | ||||||
|             if SERVICE_HOST_ENV_NAME in os.environ |  | ||||||
|             else "compose", |  | ||||||
|             "architecture": platform.machine(), |             "architecture": platform.machine(), | ||||||
|             "platform": platform.platform(), |             "platform": platform.platform(), | ||||||
|             "uname": " ".join(platform.uname()), |             "uname": " ".join(platform.uname()), | ||||||
|  | |||||||
| @ -92,10 +92,7 @@ class TaskViewSet(ViewSet): | |||||||
|             task_func.delay(*task.task_call_args, **task.task_call_kwargs) |             task_func.delay(*task.task_call_args, **task.task_call_kwargs) | ||||||
|             messages.success( |             messages.success( | ||||||
|                 self.request, |                 self.request, | ||||||
|                 _( |                 _("Successfully re-scheduled Task %(name)s!" % {"name": task.task_name}), | ||||||
|                     "Successfully re-scheduled Task %(name)s!" |  | ||||||
|                     % {"name": task.task_name} |  | ||||||
|                 ), |  | ||||||
|             ) |             ) | ||||||
|             return Response(status=204) |             return Response(status=204) | ||||||
|         except ImportError:  # pragma: no cover |         except ImportError:  # pragma: no cover | ||||||
|  | |||||||
| @ -41,9 +41,7 @@ class VersionSerializer(PassiveSerializer): | |||||||
|  |  | ||||||
|     def get_outdated(self, instance) -> bool: |     def get_outdated(self, instance) -> bool: | ||||||
|         """Check if we're running the latest version""" |         """Check if we're running the latest version""" | ||||||
|         return parse(self.get_version_current(instance)) < parse( |         return parse(self.get_version_current(instance)) < parse(self.get_version_latest(instance)) | ||||||
|             self.get_version_latest(instance) |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class VersionView(APIView): | class VersionView(APIView): | ||||||
|  | |||||||
| @ -17,9 +17,7 @@ class WorkerView(APIView): | |||||||
|  |  | ||||||
|     permission_classes = [IsAdminUser] |     permission_classes = [IsAdminUser] | ||||||
|  |  | ||||||
|     @extend_schema( |     @extend_schema(responses=inline_serializer("Workers", fields={"count": IntegerField()})) | ||||||
|         responses=inline_serializer("Workers", fields={"count": IntegerField()}) |  | ||||||
|     ) |  | ||||||
|     def get(self, request: Request) -> Response: |     def get(self, request: Request) -> Response: | ||||||
|         """Get currently connected worker count.""" |         """Get currently connected worker count.""" | ||||||
|         count = len(CELERY_APP.control.ping(timeout=0.5)) |         count = len(CELERY_APP.control.ping(timeout=0.5)) | ||||||
|  | |||||||
| @ -37,18 +37,14 @@ def _set_prom_info(): | |||||||
| def update_latest_version(self: MonitoredTask): | def update_latest_version(self: MonitoredTask): | ||||||
|     """Update latest version info""" |     """Update latest version info""" | ||||||
|     try: |     try: | ||||||
|         response = get( |         response = get("https://api.github.com/repos/goauthentik/authentik/releases/latest") | ||||||
|             "https://api.github.com/repos/goauthentik/authentik/releases/latest" |  | ||||||
|         ) |  | ||||||
|         response.raise_for_status() |         response.raise_for_status() | ||||||
|         data = response.json() |         data = response.json() | ||||||
|         tag_name = data.get("tag_name") |         tag_name = data.get("tag_name") | ||||||
|         upstream_version = tag_name.split("/")[1] |         upstream_version = tag_name.split("/")[1] | ||||||
|         cache.set(VERSION_CACHE_KEY, upstream_version, VERSION_CACHE_TIMEOUT) |         cache.set(VERSION_CACHE_KEY, upstream_version, VERSION_CACHE_TIMEOUT) | ||||||
|         self.set_status( |         self.set_status( | ||||||
|             TaskResult( |             TaskResult(TaskResultStatus.SUCCESSFUL, ["Successfully updated latest Version"]) | ||||||
|                 TaskResultStatus.SUCCESSFUL, ["Successfully updated latest Version"] |  | ||||||
|             ) |  | ||||||
|         ) |         ) | ||||||
|         _set_prom_info() |         _set_prom_info() | ||||||
|         # Check if upstream version is newer than what we're running, |         # Check if upstream version is newer than what we're running, | ||||||
|  | |||||||
| @ -27,9 +27,7 @@ class TestAdminAPI(TestCase): | |||||||
|         response = self.client.get(reverse("authentik_api:admin_system_tasks-list")) |         response = self.client.get(reverse("authentik_api:admin_system_tasks-list")) | ||||||
|         self.assertEqual(response.status_code, 200) |         self.assertEqual(response.status_code, 200) | ||||||
|         body = loads(response.content) |         body = loads(response.content) | ||||||
|         self.assertTrue( |         self.assertTrue(any(task["task_name"] == "clean_expired_models" for task in body)) | ||||||
|             any(task["task_name"] == "clean_expired_models" for task in body) |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def test_tasks_single(self): |     def test_tasks_single(self): | ||||||
|         """Test Task API (read single)""" |         """Test Task API (read single)""" | ||||||
| @ -45,9 +43,7 @@ class TestAdminAPI(TestCase): | |||||||
|         self.assertEqual(body["status"], TaskResultStatus.SUCCESSFUL.name) |         self.assertEqual(body["status"], TaskResultStatus.SUCCESSFUL.name) | ||||||
|         self.assertEqual(body["task_name"], "clean_expired_models") |         self.assertEqual(body["task_name"], "clean_expired_models") | ||||||
|         response = self.client.get( |         response = self.client.get( | ||||||
|             reverse( |             reverse("authentik_api:admin_system_tasks-detail", kwargs={"pk": "qwerqwer"}) | ||||||
|                 "authentik_api:admin_system_tasks-detail", kwargs={"pk": "qwerqwer"} |  | ||||||
|             ) |  | ||||||
|         ) |         ) | ||||||
|         self.assertEqual(response.status_code, 404) |         self.assertEqual(response.status_code, 404) | ||||||
|  |  | ||||||
|  | |||||||
| @ -7,9 +7,7 @@ from rest_framework.response import Response | |||||||
| from rest_framework.viewsets import ModelViewSet | from rest_framework.viewsets import ModelViewSet | ||||||
|  |  | ||||||
|  |  | ||||||
| def permission_required( | def permission_required(perm: Optional[str] = None, other_perms: Optional[list[str]] = None): | ||||||
|     perm: Optional[str] = None, other_perms: Optional[list[str]] = None |  | ||||||
| ): |  | ||||||
|     """Check permissions for a single custom action""" |     """Check permissions for a single custom action""" | ||||||
|  |  | ||||||
|     def wrapper_outter(func: Callable): |     def wrapper_outter(func: Callable): | ||||||
|  | |||||||
| @ -63,9 +63,7 @@ def postprocess_schema_responses(result, generator, **kwargs):  # noqa: W0613 | |||||||
|             method["responses"].setdefault("400", validation_error.ref) |             method["responses"].setdefault("400", validation_error.ref) | ||||||
|             method["responses"].setdefault("403", generic_error.ref) |             method["responses"].setdefault("403", generic_error.ref) | ||||||
|  |  | ||||||
|     result["components"] = generator.registry.build( |     result["components"] = generator.registry.build(spectacular_settings.APPEND_COMPONENTS) | ||||||
|         spectacular_settings.APPEND_COMPONENTS |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     # This is a workaround for authentik/stages/prompt/stage.py |     # This is a workaround for authentik/stages/prompt/stage.py | ||||||
|     # since the serializer PromptChallengeResponse |     # since the serializer PromptChallengeResponse | ||||||
|  | |||||||
| @ -16,17 +16,13 @@ class TestAPIAuth(TestCase): | |||||||
|  |  | ||||||
|     def test_valid_basic(self): |     def test_valid_basic(self): | ||||||
|         """Test valid token""" |         """Test valid token""" | ||||||
|         token = Token.objects.create( |         token = Token.objects.create(intent=TokenIntents.INTENT_API, user=get_anonymous_user()) | ||||||
|             intent=TokenIntents.INTENT_API, user=get_anonymous_user() |  | ||||||
|         ) |  | ||||||
|         auth = b64encode(f":{token.key}".encode()).decode() |         auth = b64encode(f":{token.key}".encode()).decode() | ||||||
|         self.assertEqual(bearer_auth(f"Basic {auth}".encode()), token.user) |         self.assertEqual(bearer_auth(f"Basic {auth}".encode()), token.user) | ||||||
|  |  | ||||||
|     def test_valid_bearer(self): |     def test_valid_bearer(self): | ||||||
|         """Test valid token""" |         """Test valid token""" | ||||||
|         token = Token.objects.create( |         token = Token.objects.create(intent=TokenIntents.INTENT_API, user=get_anonymous_user()) | ||||||
|             intent=TokenIntents.INTENT_API, user=get_anonymous_user() |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(bearer_auth(f"Bearer {token.key}".encode()), token.user) |         self.assertEqual(bearer_auth(f"Bearer {token.key}".encode()), token.user) | ||||||
|  |  | ||||||
|     def test_invalid_type(self): |     def test_invalid_type(self): | ||||||
|  | |||||||
| @ -52,20 +52,12 @@ from authentik.policies.reputation.api import ( | |||||||
| from authentik.providers.ldap.api import LDAPOutpostConfigViewSet, LDAPProviderViewSet | from authentik.providers.ldap.api import LDAPOutpostConfigViewSet, LDAPProviderViewSet | ||||||
| from authentik.providers.oauth2.api.provider import OAuth2ProviderViewSet | from authentik.providers.oauth2.api.provider import OAuth2ProviderViewSet | ||||||
| from authentik.providers.oauth2.api.scope import ScopeMappingViewSet | from authentik.providers.oauth2.api.scope import ScopeMappingViewSet | ||||||
| from authentik.providers.oauth2.api.tokens import ( | from authentik.providers.oauth2.api.tokens import AuthorizationCodeViewSet, RefreshTokenViewSet | ||||||
|     AuthorizationCodeViewSet, | from authentik.providers.proxy.api import ProxyOutpostConfigViewSet, ProxyProviderViewSet | ||||||
|     RefreshTokenViewSet, |  | ||||||
| ) |  | ||||||
| from authentik.providers.proxy.api import ( |  | ||||||
|     ProxyOutpostConfigViewSet, |  | ||||||
|     ProxyProviderViewSet, |  | ||||||
| ) |  | ||||||
| from authentik.providers.saml.api import SAMLPropertyMappingViewSet, SAMLProviderViewSet | from authentik.providers.saml.api import SAMLPropertyMappingViewSet, SAMLProviderViewSet | ||||||
| from authentik.sources.ldap.api import LDAPPropertyMappingViewSet, LDAPSourceViewSet | from authentik.sources.ldap.api import LDAPPropertyMappingViewSet, LDAPSourceViewSet | ||||||
| from authentik.sources.oauth.api.source import OAuthSourceViewSet | from authentik.sources.oauth.api.source import OAuthSourceViewSet | ||||||
| from authentik.sources.oauth.api.source_connection import ( | from authentik.sources.oauth.api.source_connection import UserOAuthSourceConnectionViewSet | ||||||
|     UserOAuthSourceConnectionViewSet, |  | ||||||
| ) |  | ||||||
| from authentik.sources.plex.api import PlexSourceViewSet | from authentik.sources.plex.api import PlexSourceViewSet | ||||||
| from authentik.sources.saml.api import SAMLSourceViewSet | from authentik.sources.saml.api import SAMLSourceViewSet | ||||||
| from authentik.stages.authenticator_duo.api import ( | from authentik.stages.authenticator_duo.api import ( | ||||||
| @ -83,9 +75,7 @@ from authentik.stages.authenticator_totp.api import ( | |||||||
|     TOTPAdminDeviceViewSet, |     TOTPAdminDeviceViewSet, | ||||||
|     TOTPDeviceViewSet, |     TOTPDeviceViewSet, | ||||||
| ) | ) | ||||||
| from authentik.stages.authenticator_validate.api import ( | from authentik.stages.authenticator_validate.api import AuthenticatorValidateStageViewSet | ||||||
|     AuthenticatorValidateStageViewSet, |  | ||||||
| ) |  | ||||||
| from authentik.stages.authenticator_webauthn.api import ( | from authentik.stages.authenticator_webauthn.api import ( | ||||||
|     AuthenticateWebAuthnStageViewSet, |     AuthenticateWebAuthnStageViewSet, | ||||||
|     WebAuthnAdminDeviceViewSet, |     WebAuthnAdminDeviceViewSet, | ||||||
| @ -122,9 +112,7 @@ router.register("core/tenants", TenantViewSet) | |||||||
| router.register("outposts/instances", OutpostViewSet) | router.register("outposts/instances", OutpostViewSet) | ||||||
| router.register("outposts/service_connections/all", ServiceConnectionViewSet) | router.register("outposts/service_connections/all", ServiceConnectionViewSet) | ||||||
| router.register("outposts/service_connections/docker", DockerServiceConnectionViewSet) | router.register("outposts/service_connections/docker", DockerServiceConnectionViewSet) | ||||||
| router.register( | router.register("outposts/service_connections/kubernetes", KubernetesServiceConnectionViewSet) | ||||||
|     "outposts/service_connections/kubernetes", KubernetesServiceConnectionViewSet |  | ||||||
| ) |  | ||||||
| router.register("outposts/proxy", ProxyOutpostConfigViewSet) | router.register("outposts/proxy", ProxyOutpostConfigViewSet) | ||||||
| router.register("outposts/ldap", LDAPOutpostConfigViewSet) | router.register("outposts/ldap", LDAPOutpostConfigViewSet) | ||||||
|  |  | ||||||
| @ -184,9 +172,7 @@ router.register( | |||||||
|     StaticAdminDeviceViewSet, |     StaticAdminDeviceViewSet, | ||||||
|     basename="admin-staticdevice", |     basename="admin-staticdevice", | ||||||
| ) | ) | ||||||
| router.register( | router.register("authenticators/admin/totp", TOTPAdminDeviceViewSet, basename="admin-totpdevice") | ||||||
|     "authenticators/admin/totp", TOTPAdminDeviceViewSet, basename="admin-totpdevice" |  | ||||||
| ) |  | ||||||
| router.register( | router.register( | ||||||
|     "authenticators/admin/webauthn", |     "authenticators/admin/webauthn", | ||||||
|     WebAuthnAdminDeviceViewSet, |     WebAuthnAdminDeviceViewSet, | ||||||
|  | |||||||
| @ -147,9 +147,7 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet): | |||||||
|         """Custom list method that checks Policy based access instead of guardian""" |         """Custom list method that checks Policy based access instead of guardian""" | ||||||
|         should_cache = request.GET.get("search", "") == "" |         should_cache = request.GET.get("search", "") == "" | ||||||
|  |  | ||||||
|         superuser_full_list = ( |         superuser_full_list = str(request.GET.get("superuser_full_list", "false")).lower() == "true" | ||||||
|             str(request.GET.get("superuser_full_list", "false")).lower() == "true" |  | ||||||
|         ) |  | ||||||
|         if superuser_full_list and request.user.is_superuser: |         if superuser_full_list and request.user.is_superuser: | ||||||
|             return super().list(request) |             return super().list(request) | ||||||
|  |  | ||||||
| @ -240,9 +238,7 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet): | |||||||
|         app.save() |         app.save() | ||||||
|         return Response({}) |         return Response({}) | ||||||
|  |  | ||||||
|     @permission_required( |     @permission_required("authentik_core.view_application", ["authentik_events.view_event"]) | ||||||
|         "authentik_core.view_application", ["authentik_events.view_event"] |  | ||||||
|     ) |  | ||||||
|     @extend_schema(responses={200: CoordinateSerializer(many=True)}) |     @extend_schema(responses={200: CoordinateSerializer(many=True)}) | ||||||
|     @action(detail=True, pagination_class=None, filter_backends=[]) |     @action(detail=True, pagination_class=None, filter_backends=[]) | ||||||
|     # pylint: disable=unused-argument |     # pylint: disable=unused-argument | ||||||
|  | |||||||
| @ -68,9 +68,7 @@ class AuthenticatedSessionSerializer(ModelSerializer): | |||||||
|         """Get parsed user agent""" |         """Get parsed user agent""" | ||||||
|         return user_agent_parser.Parse(instance.last_user_agent) |         return user_agent_parser.Parse(instance.last_user_agent) | ||||||
|  |  | ||||||
|     def get_geo_ip( |     def get_geo_ip(self, instance: AuthenticatedSession) -> Optional[GeoIPDict]:  # pragma: no cover | ||||||
|         self, instance: AuthenticatedSession |  | ||||||
|     ) -> Optional[GeoIPDict]:  # pragma: no cover |  | ||||||
|         """Get parsed user agent""" |         """Get parsed user agent""" | ||||||
|         return GEOIP_READER.city_dict(instance.last_ip) |         return GEOIP_READER.city_dict(instance.last_ip) | ||||||
|  |  | ||||||
|  | |||||||
| @ -15,11 +15,7 @@ from rest_framework.viewsets import GenericViewSet | |||||||
|  |  | ||||||
| from authentik.api.decorators import permission_required | from authentik.api.decorators import permission_required | ||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
| from authentik.core.api.utils import ( | from authentik.core.api.utils import MetaNameSerializer, PassiveSerializer, TypeCreateSerializer | ||||||
|     MetaNameSerializer, |  | ||||||
|     PassiveSerializer, |  | ||||||
|     TypeCreateSerializer, |  | ||||||
| ) |  | ||||||
| from authentik.core.expression import PropertyMappingEvaluator | from authentik.core.expression import PropertyMappingEvaluator | ||||||
| from authentik.core.models import PropertyMapping | from authentik.core.models import PropertyMapping | ||||||
| from authentik.lib.utils.reflection import all_subclasses | from authentik.lib.utils.reflection import all_subclasses | ||||||
| @ -141,9 +137,7 @@ class PropertyMappingViewSet( | |||||||
|                 self.request, |                 self.request, | ||||||
|                 **test_params.validated_data.get("context", {}), |                 **test_params.validated_data.get("context", {}), | ||||||
|             ) |             ) | ||||||
|             response_data["result"] = dumps( |             response_data["result"] = dumps(result, indent=(4 if format_result else None)) | ||||||
|                 result, indent=(4 if format_result else None) |  | ||||||
|             ) |  | ||||||
|         except Exception as exc:  # pylint: disable=broad-except |         except Exception as exc:  # pylint: disable=broad-except | ||||||
|             response_data["result"] = str(exc) |             response_data["result"] = str(exc) | ||||||
|             response_data["successful"] = False |             response_data["successful"] = False | ||||||
|  | |||||||
| @ -93,9 +93,7 @@ class SourceViewSet( | |||||||
|     @action(detail=False, pagination_class=None, filter_backends=[]) |     @action(detail=False, pagination_class=None, filter_backends=[]) | ||||||
|     def user_settings(self, request: Request) -> Response: |     def user_settings(self, request: Request) -> Response: | ||||||
|         """Get all sources the user can configure""" |         """Get all sources the user can configure""" | ||||||
|         _all_sources: Iterable[Source] = Source.objects.filter( |         _all_sources: Iterable[Source] = Source.objects.filter(enabled=True).select_subclasses() | ||||||
|             enabled=True |  | ||||||
|         ).select_subclasses() |  | ||||||
|         matching_sources: list[UserSettingSerializer] = [] |         matching_sources: list[UserSettingSerializer] = [] | ||||||
|         for source in _all_sources: |         for source in _all_sources: | ||||||
|             user_settings = source.ui_user_settings |             user_settings = source.ui_user_settings | ||||||
|  | |||||||
| @ -70,9 +70,7 @@ class TokenViewSet(UsedByMixin, ModelViewSet): | |||||||
|         serializer.save( |         serializer.save( | ||||||
|             user=self.request.user, |             user=self.request.user, | ||||||
|             intent=TokenIntents.INTENT_API, |             intent=TokenIntents.INTENT_API, | ||||||
|             expiring=self.request.user.attributes.get( |             expiring=self.request.user.attributes.get(USER_ATTRIBUTE_TOKEN_EXPIRING, True), | ||||||
|                 USER_ATTRIBUTE_TOKEN_EXPIRING, True |  | ||||||
|             ), |  | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     @permission_required("authentik_core.view_token_key") |     @permission_required("authentik_core.view_token_key") | ||||||
| @ -89,7 +87,5 @@ class TokenViewSet(UsedByMixin, ModelViewSet): | |||||||
|         token: Token = self.get_object() |         token: Token = self.get_object() | ||||||
|         if token.is_expired: |         if token.is_expired: | ||||||
|             raise Http404 |             raise Http404 | ||||||
|         Event.new(EventAction.SECRET_VIEW, secret=token).from_http(  # noqa # nosec |         Event.new(EventAction.SECRET_VIEW, secret=token).from_http(request)  # noqa # nosec | ||||||
|             request |  | ||||||
|         ) |  | ||||||
|         return Response(TokenViewSerializer({"key": token.key}).data) |         return Response(TokenViewSerializer({"key": token.key}).data) | ||||||
|  | |||||||
| @ -79,9 +79,7 @@ class UsedByMixin: | |||||||
|             ).all(): |             ).all(): | ||||||
|                 # Only merge shadows on first object |                 # Only merge shadows on first object | ||||||
|                 if first_object: |                 if first_object: | ||||||
|                     shadows += getattr( |                     shadows += getattr(manager.model._meta, "authentik_used_by_shadows", []) | ||||||
|                         manager.model._meta, "authentik_used_by_shadows", [] |  | ||||||
|                     ) |  | ||||||
|                 first_object = False |                 first_object = False | ||||||
|                 serializer = UsedBySerializer( |                 serializer = UsedBySerializer( | ||||||
|                     data={ |                     data={ | ||||||
|  | |||||||
| @ -26,10 +26,7 @@ from authentik.api.decorators import permission_required | |||||||
| from authentik.core.api.groups import GroupSerializer | from authentik.core.api.groups import GroupSerializer | ||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
| from authentik.core.api.utils import LinkSerializer, PassiveSerializer, is_dict | from authentik.core.api.utils import LinkSerializer, PassiveSerializer, is_dict | ||||||
| from authentik.core.middleware import ( | from authentik.core.middleware import SESSION_IMPERSONATE_ORIGINAL_USER, SESSION_IMPERSONATE_USER | ||||||
|     SESSION_IMPERSONATE_ORIGINAL_USER, |  | ||||||
|     SESSION_IMPERSONATE_USER, |  | ||||||
| ) |  | ||||||
| from authentik.core.models import Token, TokenIntents, User | from authentik.core.models import Token, TokenIntents, User | ||||||
| from authentik.events.models import EventAction | from authentik.events.models import EventAction | ||||||
| from authentik.tenants.models import Tenant | from authentik.tenants.models import Tenant | ||||||
| @ -87,17 +84,13 @@ class UserMetricsSerializer(PassiveSerializer): | |||||||
|     def get_logins_failed_per_1h(self, _): |     def get_logins_failed_per_1h(self, _): | ||||||
|         """Get failed logins per hour for the last 24 hours""" |         """Get failed logins per hour for the last 24 hours""" | ||||||
|         user = self.context["user"] |         user = self.context["user"] | ||||||
|         return get_events_per_1h( |         return get_events_per_1h(action=EventAction.LOGIN_FAILED, context__username=user.username) | ||||||
|             action=EventAction.LOGIN_FAILED, context__username=user.username |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     @extend_schema_field(CoordinateSerializer(many=True)) |     @extend_schema_field(CoordinateSerializer(many=True)) | ||||||
|     def get_authorizations_per_1h(self, _): |     def get_authorizations_per_1h(self, _): | ||||||
|         """Get failed logins per hour for the last 24 hours""" |         """Get failed logins per hour for the last 24 hours""" | ||||||
|         user = self.context["user"] |         user = self.context["user"] | ||||||
|         return get_events_per_1h( |         return get_events_per_1h(action=EventAction.AUTHORIZE_APPLICATION, user__pk=user.pk) | ||||||
|             action=EventAction.AUTHORIZE_APPLICATION, user__pk=user.pk |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class UsersFilter(FilterSet): | class UsersFilter(FilterSet): | ||||||
| @ -154,9 +147,7 @@ class UserViewSet(UsedByMixin, ModelViewSet): | |||||||
|     # pylint: disable=invalid-name |     # pylint: disable=invalid-name | ||||||
|     def me(self, request: Request) -> Response: |     def me(self, request: Request) -> Response: | ||||||
|         """Get information about current user""" |         """Get information about current user""" | ||||||
|         serializer = SessionUserSerializer( |         serializer = SessionUserSerializer(data={"user": UserSerializer(request.user).data}) | ||||||
|             data={"user": UserSerializer(request.user).data} |  | ||||||
|         ) |  | ||||||
|         if SESSION_IMPERSONATE_USER in request._request.session: |         if SESSION_IMPERSONATE_USER in request._request.session: | ||||||
|             serializer.initial_data["original"] = UserSerializer( |             serializer.initial_data["original"] = UserSerializer( | ||||||
|                 request._request.session[SESSION_IMPERSONATE_ORIGINAL_USER] |                 request._request.session[SESSION_IMPERSONATE_ORIGINAL_USER] | ||||||
|  | |||||||
| @ -3,20 +3,14 @@ from typing import Any | |||||||
|  |  | ||||||
| from django.db.models import Model | from django.db.models import Model | ||||||
| from rest_framework.fields import CharField, IntegerField | from rest_framework.fields import CharField, IntegerField | ||||||
| from rest_framework.serializers import ( | from rest_framework.serializers import Serializer, SerializerMethodField, ValidationError | ||||||
|     Serializer, |  | ||||||
|     SerializerMethodField, |  | ||||||
|     ValidationError, |  | ||||||
| ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def is_dict(value: Any): | def is_dict(value: Any): | ||||||
|     """Ensure a value is a dictionary, useful for JSONFields""" |     """Ensure a value is a dictionary, useful for JSONFields""" | ||||||
|     if isinstance(value, dict): |     if isinstance(value, dict): | ||||||
|         return |         return | ||||||
|     raise ValidationError( |     raise ValidationError("Value must be a dictionary, and not have any duplicate keys.") | ||||||
|         "Value must be a dictionary, and not have any duplicate keys." |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class PassiveSerializer(Serializer): | class PassiveSerializer(Serializer): | ||||||
| @ -25,9 +19,7 @@ class PassiveSerializer(Serializer): | |||||||
|     def create(self, validated_data: dict) -> Model:  # pragma: no cover |     def create(self, validated_data: dict) -> Model:  # pragma: no cover | ||||||
|         return Model() |         return Model() | ||||||
|  |  | ||||||
|     def update( |     def update(self, instance: Model, validated_data: dict) -> Model:  # pragma: no cover | ||||||
|         self, instance: Model, validated_data: dict |  | ||||||
|     ) -> Model:  # pragma: no cover |  | ||||||
|         return Model() |         return Model() | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|  | |||||||
| @ -38,9 +38,7 @@ class Migration(migrations.Migration): | |||||||
|                 ("password", models.CharField(max_length=128, verbose_name="password")), |                 ("password", models.CharField(max_length=128, verbose_name="password")), | ||||||
|                 ( |                 ( | ||||||
|                     "last_login", |                     "last_login", | ||||||
|                     models.DateTimeField( |                     models.DateTimeField(blank=True, null=True, verbose_name="last login"), | ||||||
|                         blank=True, null=True, verbose_name="last login" |  | ||||||
|                     ), |  | ||||||
|                 ), |                 ), | ||||||
|                 ( |                 ( | ||||||
|                     "is_superuser", |                     "is_superuser", | ||||||
| @ -53,35 +51,25 @@ class Migration(migrations.Migration): | |||||||
|                 ( |                 ( | ||||||
|                     "username", |                     "username", | ||||||
|                     models.CharField( |                     models.CharField( | ||||||
|                         error_messages={ |                         error_messages={"unique": "A user with that username already exists."}, | ||||||
|                             "unique": "A user with that username already exists." |  | ||||||
|                         }, |  | ||||||
|                         help_text="Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only.", |                         help_text="Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only.", | ||||||
|                         max_length=150, |                         max_length=150, | ||||||
|                         unique=True, |                         unique=True, | ||||||
|                         validators=[ |                         validators=[django.contrib.auth.validators.UnicodeUsernameValidator()], | ||||||
|                             django.contrib.auth.validators.UnicodeUsernameValidator() |  | ||||||
|                         ], |  | ||||||
|                         verbose_name="username", |                         verbose_name="username", | ||||||
|                     ), |                     ), | ||||||
|                 ), |                 ), | ||||||
|                 ( |                 ( | ||||||
|                     "first_name", |                     "first_name", | ||||||
|                     models.CharField( |                     models.CharField(blank=True, max_length=30, verbose_name="first name"), | ||||||
|                         blank=True, max_length=30, verbose_name="first name" |  | ||||||
|                     ), |  | ||||||
|                 ), |                 ), | ||||||
|                 ( |                 ( | ||||||
|                     "last_name", |                     "last_name", | ||||||
|                     models.CharField( |                     models.CharField(blank=True, max_length=150, verbose_name="last name"), | ||||||
|                         blank=True, max_length=150, verbose_name="last name" |  | ||||||
|                     ), |  | ||||||
|                 ), |                 ), | ||||||
|                 ( |                 ( | ||||||
|                     "email", |                     "email", | ||||||
|                     models.EmailField( |                     models.EmailField(blank=True, max_length=254, verbose_name="email address"), | ||||||
|                         blank=True, max_length=254, verbose_name="email address" |  | ||||||
|                     ), |  | ||||||
|                 ), |                 ), | ||||||
|                 ( |                 ( | ||||||
|                     "is_staff", |                     "is_staff", | ||||||
| @ -217,9 +205,7 @@ class Migration(migrations.Migration): | |||||||
|                 ), |                 ), | ||||||
|                 ( |                 ( | ||||||
|                     "expires", |                     "expires", | ||||||
|                     models.DateTimeField( |                     models.DateTimeField(default=authentik.core.models.default_token_duration), | ||||||
|                         default=authentik.core.models.default_token_duration |  | ||||||
|                     ), |  | ||||||
|                 ), |                 ), | ||||||
|                 ("expiring", models.BooleanField(default=True)), |                 ("expiring", models.BooleanField(default=True)), | ||||||
|                 ("description", models.TextField(blank=True, default="")), |                 ("description", models.TextField(blank=True, default="")), | ||||||
| @ -306,9 +292,7 @@ class Migration(migrations.Migration): | |||||||
|                 ("name", models.TextField(help_text="Application's display Name.")), |                 ("name", models.TextField(help_text="Application's display Name.")), | ||||||
|                 ( |                 ( | ||||||
|                     "slug", |                     "slug", | ||||||
|                     models.SlugField( |                     models.SlugField(help_text="Internal application name, used in URLs."), | ||||||
|                         help_text="Internal application name, used in URLs." |  | ||||||
|                     ), |  | ||||||
|                 ), |                 ), | ||||||
|                 ("skip_authorization", models.BooleanField(default=False)), |                 ("skip_authorization", models.BooleanField(default=False)), | ||||||
|                 ("meta_launch_url", models.URLField(blank=True, default="")), |                 ("meta_launch_url", models.URLField(blank=True, default="")), | ||||||
|  | |||||||
| @ -17,9 +17,7 @@ def create_default_user(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | |||||||
|         username="akadmin", email="root@localhost", name="authentik Default Admin" |         username="akadmin", email="root@localhost", name="authentik Default Admin" | ||||||
|     ) |     ) | ||||||
|     if "TF_BUILD" in environ or "AK_ADMIN_PASS" in environ or settings.TEST: |     if "TF_BUILD" in environ or "AK_ADMIN_PASS" in environ or settings.TEST: | ||||||
|         akadmin.set_password( |         akadmin.set_password(environ.get("AK_ADMIN_PASS", "akadmin"), signal=False)  # noqa # nosec | ||||||
|             environ.get("AK_ADMIN_PASS", "akadmin"), signal=False |  | ||||||
|         )  # noqa # nosec |  | ||||||
|     else: |     else: | ||||||
|         akadmin.set_unusable_password() |         akadmin.set_unusable_password() | ||||||
|     akadmin.save() |     akadmin.save() | ||||||
|  | |||||||
| @ -13,8 +13,6 @@ class Migration(migrations.Migration): | |||||||
|         migrations.AlterField( |         migrations.AlterField( | ||||||
|             model_name="source", |             model_name="source", | ||||||
|             name="slug", |             name="slug", | ||||||
|             field=models.SlugField( |             field=models.SlugField(help_text="Internal source name, used in URLs.", unique=True), | ||||||
|                 help_text="Internal source name, used in URLs.", unique=True |  | ||||||
|             ), |  | ||||||
|         ), |         ), | ||||||
|     ] |     ] | ||||||
|  | |||||||
| @ -13,8 +13,6 @@ class Migration(migrations.Migration): | |||||||
|         migrations.AlterField( |         migrations.AlterField( | ||||||
|             model_name="user", |             model_name="user", | ||||||
|             name="first_name", |             name="first_name", | ||||||
|             field=models.CharField( |             field=models.CharField(blank=True, max_length=150, verbose_name="first name"), | ||||||
|                 blank=True, max_length=150, verbose_name="first name" |  | ||||||
|             ), |  | ||||||
|         ), |         ), | ||||||
|     ] |     ] | ||||||
|  | |||||||
| @ -40,9 +40,7 @@ class Migration(migrations.Migration): | |||||||
|         migrations.AlterField( |         migrations.AlterField( | ||||||
|             model_name="user", |             model_name="user", | ||||||
|             name="pb_groups", |             name="pb_groups", | ||||||
|             field=models.ManyToManyField( |             field=models.ManyToManyField(related_name="users", to="authentik_core.Group"), | ||||||
|                 related_name="users", to="authentik_core.Group" |  | ||||||
|             ), |  | ||||||
|         ), |         ), | ||||||
|         migrations.AddField( |         migrations.AddField( | ||||||
|             model_name="group", |             model_name="group", | ||||||
|  | |||||||
| @ -42,9 +42,7 @@ class Migration(migrations.Migration): | |||||||
|         ), |         ), | ||||||
|         migrations.AddIndex( |         migrations.AddIndex( | ||||||
|             model_name="token", |             model_name="token", | ||||||
|             index=models.Index( |             index=models.Index(fields=["identifier"], name="authentik_co_identif_1a34a8_idx"), | ||||||
|                 fields=["identifier"], name="authentik_co_identif_1a34a8_idx" |  | ||||||
|             ), |  | ||||||
|         ), |         ), | ||||||
|         migrations.RunPython(set_default_token_key), |         migrations.RunPython(set_default_token_key), | ||||||
|     ] |     ] | ||||||
|  | |||||||
| @ -17,8 +17,6 @@ class Migration(migrations.Migration): | |||||||
|         migrations.AddField( |         migrations.AddField( | ||||||
|             model_name="application", |             model_name="application", | ||||||
|             name="meta_icon", |             name="meta_icon", | ||||||
|             field=models.FileField( |             field=models.FileField(blank=True, default="", upload_to="application-icons/"), | ||||||
|                 blank=True, default="", upload_to="application-icons/" |  | ||||||
|             ), |  | ||||||
|         ), |         ), | ||||||
|     ] |     ] | ||||||
|  | |||||||
| @ -25,9 +25,7 @@ class Migration(migrations.Migration): | |||||||
|         ), |         ), | ||||||
|         migrations.AddIndex( |         migrations.AddIndex( | ||||||
|             model_name="token", |             model_name="token", | ||||||
|             index=models.Index( |             index=models.Index(fields=["identifier"], name="authentik_c_identif_d9d032_idx"), | ||||||
|                 fields=["identifier"], name="authentik_c_identif_d9d032_idx" |  | ||||||
|             ), |  | ||||||
|         ), |         ), | ||||||
|         migrations.AddIndex( |         migrations.AddIndex( | ||||||
|             model_name="token", |             model_name="token", | ||||||
|  | |||||||
| @ -32,16 +32,12 @@ class Migration(migrations.Migration): | |||||||
|             fields=[ |             fields=[ | ||||||
|                 ( |                 ( | ||||||
|                     "expires", |                     "expires", | ||||||
|                     models.DateTimeField( |                     models.DateTimeField(default=authentik.core.models.default_token_duration), | ||||||
|                         default=authentik.core.models.default_token_duration |  | ||||||
|                     ), |  | ||||||
|                 ), |                 ), | ||||||
|                 ("expiring", models.BooleanField(default=True)), |                 ("expiring", models.BooleanField(default=True)), | ||||||
|                 ( |                 ( | ||||||
|                     "uuid", |                     "uuid", | ||||||
|                     models.UUIDField( |                     models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False), | ||||||
|                         default=uuid.uuid4, primary_key=True, serialize=False |  | ||||||
|                     ), |  | ||||||
|                 ), |                 ), | ||||||
|                 ("session_key", models.CharField(max_length=40)), |                 ("session_key", models.CharField(max_length=40)), | ||||||
|                 ("last_ip", models.TextField()), |                 ("last_ip", models.TextField()), | ||||||
|  | |||||||
| @ -13,8 +13,6 @@ class Migration(migrations.Migration): | |||||||
|         migrations.AlterField( |         migrations.AlterField( | ||||||
|             model_name="application", |             model_name="application", | ||||||
|             name="meta_icon", |             name="meta_icon", | ||||||
|             field=models.FileField( |             field=models.FileField(default=None, null=True, upload_to="application-icons/"), | ||||||
|                 default=None, null=True, upload_to="application-icons/" |  | ||||||
|             ), |  | ||||||
|         ), |         ), | ||||||
|     ] |     ] | ||||||
|  | |||||||
| @ -154,9 +154,7 @@ class User(GuardianUserMixin, AbstractUser): | |||||||
|                 ("s", "158"), |                 ("s", "158"), | ||||||
|                 ("r", "g"), |                 ("r", "g"), | ||||||
|             ] |             ] | ||||||
|             gravatar_url = ( |             gravatar_url = f"{GRAVATAR_URL}/avatar/{mail_hash}?{urlencode(parameters, doseq=True)}" | ||||||
|                 f"{GRAVATAR_URL}/avatar/{mail_hash}?{urlencode(parameters, doseq=True)}" |  | ||||||
|             ) |  | ||||||
|             return escape(gravatar_url) |             return escape(gravatar_url) | ||||||
|         return mode % { |         return mode % { | ||||||
|             "username": self.username, |             "username": self.username, | ||||||
| @ -186,9 +184,7 @@ class Provider(SerializerModel): | |||||||
|         related_name="provider_authorization", |         related_name="provider_authorization", | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     property_mappings = models.ManyToManyField( |     property_mappings = models.ManyToManyField("PropertyMapping", default=None, blank=True) | ||||||
|         "PropertyMapping", default=None, blank=True |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     objects = InheritanceManager() |     objects = InheritanceManager() | ||||||
|  |  | ||||||
| @ -218,9 +214,7 @@ class Application(PolicyBindingModel): | |||||||
|     add custom fields and other properties""" |     add custom fields and other properties""" | ||||||
|  |  | ||||||
|     name = models.TextField(help_text=_("Application's display Name.")) |     name = models.TextField(help_text=_("Application's display Name.")) | ||||||
|     slug = models.SlugField( |     slug = models.SlugField(help_text=_("Internal application name, used in URLs."), unique=True) | ||||||
|         help_text=_("Internal application name, used in URLs."), unique=True |  | ||||||
|     ) |  | ||||||
|     provider = models.OneToOneField( |     provider = models.OneToOneField( | ||||||
|         "Provider", null=True, blank=True, default=None, on_delete=models.SET_DEFAULT |         "Provider", null=True, blank=True, default=None, on_delete=models.SET_DEFAULT | ||||||
|     ) |     ) | ||||||
| @ -244,9 +238,7 @@ class Application(PolicyBindingModel): | |||||||
|         it is returned as-is""" |         it is returned as-is""" | ||||||
|         if not self.meta_icon: |         if not self.meta_icon: | ||||||
|             return None |             return None | ||||||
|         if self.meta_icon.name.startswith("http") or self.meta_icon.name.startswith( |         if self.meta_icon.name.startswith("http") or self.meta_icon.name.startswith("/static"): | ||||||
|             "/static" |  | ||||||
|         ): |  | ||||||
|             return self.meta_icon.name |             return self.meta_icon.name | ||||||
|         return self.meta_icon.url |         return self.meta_icon.url | ||||||
|  |  | ||||||
| @ -301,14 +293,10 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel): | |||||||
|     """Base Authentication source, i.e. an OAuth Provider, SAML Remote or LDAP Server""" |     """Base Authentication source, i.e. an OAuth Provider, SAML Remote or LDAP Server""" | ||||||
|  |  | ||||||
|     name = models.TextField(help_text=_("Source's display Name.")) |     name = models.TextField(help_text=_("Source's display Name.")) | ||||||
|     slug = models.SlugField( |     slug = models.SlugField(help_text=_("Internal source name, used in URLs."), unique=True) | ||||||
|         help_text=_("Internal source name, used in URLs."), unique=True |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     enabled = models.BooleanField(default=True) |     enabled = models.BooleanField(default=True) | ||||||
|     property_mappings = models.ManyToManyField( |     property_mappings = models.ManyToManyField("PropertyMapping", default=None, blank=True) | ||||||
|         "PropertyMapping", default=None, blank=True |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     authentication_flow = models.ForeignKey( |     authentication_flow = models.ForeignKey( | ||||||
|         Flow, |         Flow, | ||||||
| @ -481,9 +469,7 @@ class PropertyMapping(SerializerModel, ManagedModel): | |||||||
|         """Get serializer for this model""" |         """Get serializer for this model""" | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     def evaluate( |     def evaluate(self, user: Optional[User], request: Optional[HttpRequest], **kwargs) -> Any: | ||||||
|         self, user: Optional[User], request: Optional[HttpRequest], **kwargs |  | ||||||
|     ) -> Any: |  | ||||||
|         """Evaluate `self.expression` using `**kwargs` as Context.""" |         """Evaluate `self.expression` using `**kwargs` as Context.""" | ||||||
|         from authentik.core.expression import PropertyMappingEvaluator |         from authentik.core.expression import PropertyMappingEvaluator | ||||||
|  |  | ||||||
| @ -522,9 +508,7 @@ class AuthenticatedSession(ExpiringModel): | |||||||
|     last_used = models.DateTimeField(auto_now=True) |     last_used = models.DateTimeField(auto_now=True) | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def from_request( |     def from_request(request: HttpRequest, user: User) -> Optional["AuthenticatedSession"]: | ||||||
|         request: HttpRequest, user: User |  | ||||||
|     ) -> Optional["AuthenticatedSession"]: |  | ||||||
|         """Create a new session from a http request""" |         """Create a new session from a http request""" | ||||||
|         if not hasattr(request, "session") or not request.session.session_key: |         if not hasattr(request, "session") or not request.session.session_key: | ||||||
|             return None |             return None | ||||||
|  | |||||||
| @ -14,9 +14,7 @@ from prometheus_client import Gauge | |||||||
| # Arguments: user: User, password: str | # Arguments: user: User, password: str | ||||||
| password_changed = Signal() | password_changed = Signal() | ||||||
|  |  | ||||||
| GAUGE_MODELS = Gauge( | GAUGE_MODELS = Gauge("authentik_models", "Count of various objects", ["model_name", "app"]) | ||||||
|     "authentik_models", "Count of various objects", ["model_name", "app"] |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||||
|     from authentik.core.models import AuthenticatedSession, User |     from authentik.core.models import AuthenticatedSession, User | ||||||
| @ -60,15 +58,11 @@ def user_logged_out_session(sender, request: HttpRequest, user: "User", **_): | |||||||
|     """Delete AuthenticatedSession if it exists""" |     """Delete AuthenticatedSession if it exists""" | ||||||
|     from authentik.core.models import AuthenticatedSession |     from authentik.core.models import AuthenticatedSession | ||||||
|  |  | ||||||
|     AuthenticatedSession.objects.filter( |     AuthenticatedSession.objects.filter(session_key=request.session.session_key).delete() | ||||||
|         session_key=request.session.session_key |  | ||||||
|     ).delete() |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(pre_delete) | @receiver(pre_delete) | ||||||
| def authenticated_session_delete( | def authenticated_session_delete(sender: Type[Model], instance: "AuthenticatedSession", **_): | ||||||
|     sender: Type[Model], instance: "AuthenticatedSession", **_ |  | ||||||
| ): |  | ||||||
|     """Delete session when authenticated session is deleted""" |     """Delete session when authenticated session is deleted""" | ||||||
|     from authentik.core.models import AuthenticatedSession |     from authentik.core.models import AuthenticatedSession | ||||||
|  |  | ||||||
|  | |||||||
| @ -11,16 +11,8 @@ from django.urls import reverse | |||||||
| from django.utils.translation import gettext as _ | from django.utils.translation import gettext as _ | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.core.models import ( | from authentik.core.models import Source, SourceUserMatchingModes, User, UserSourceConnection | ||||||
|     Source, | from authentik.core.sources.stage import PLAN_CONTEXT_SOURCES_CONNECTION, PostUserEnrollmentStage | ||||||
|     SourceUserMatchingModes, |  | ||||||
|     User, |  | ||||||
|     UserSourceConnection, |  | ||||||
| ) |  | ||||||
| from authentik.core.sources.stage import ( |  | ||||||
|     PLAN_CONTEXT_SOURCES_CONNECTION, |  | ||||||
|     PostUserEnrollmentStage, |  | ||||||
| ) |  | ||||||
| from authentik.events.models import Event, EventAction | from authentik.events.models import Event, EventAction | ||||||
| from authentik.flows.models import Flow, Stage, in_memory_stage | from authentik.flows.models import Flow, Stage, in_memory_stage | ||||||
| from authentik.flows.planner import ( | from authentik.flows.planner import ( | ||||||
| @ -76,9 +68,7 @@ class SourceFlowManager: | |||||||
|     # pylint: disable=too-many-return-statements |     # pylint: disable=too-many-return-statements | ||||||
|     def get_action(self, **kwargs) -> tuple[Action, Optional[UserSourceConnection]]: |     def get_action(self, **kwargs) -> tuple[Action, Optional[UserSourceConnection]]: | ||||||
|         """decide which action should be taken""" |         """decide which action should be taken""" | ||||||
|         new_connection = self.connection_type( |         new_connection = self.connection_type(source=self.source, identifier=self.identifier) | ||||||
|             source=self.source, identifier=self.identifier |  | ||||||
|         ) |  | ||||||
|         # When request is authenticated, always link |         # When request is authenticated, always link | ||||||
|         if self.request.user.is_authenticated: |         if self.request.user.is_authenticated: | ||||||
|             new_connection.user = self.request.user |             new_connection.user = self.request.user | ||||||
| @ -113,9 +103,7 @@ class SourceFlowManager: | |||||||
|             SourceUserMatchingModes.USERNAME_DENY, |             SourceUserMatchingModes.USERNAME_DENY, | ||||||
|         ]: |         ]: | ||||||
|             if not self.enroll_info.get("username", None): |             if not self.enroll_info.get("username", None): | ||||||
|                 self._logger.warning( |                 self._logger.warning("Refusing to use none username", source=self.source) | ||||||
|                     "Refusing to use none username", source=self.source |  | ||||||
|                 ) |  | ||||||
|                 return Action.DENY, None |                 return Action.DENY, None | ||||||
|             query = Q(username__exact=self.enroll_info.get("username", None)) |             query = Q(username__exact=self.enroll_info.get("username", None)) | ||||||
|         self._logger.debug("trying to link with existing user", query=query) |         self._logger.debug("trying to link with existing user", query=query) | ||||||
| @ -229,10 +217,7 @@ class SourceFlowManager: | |||||||
|         """Login user and redirect.""" |         """Login user and redirect.""" | ||||||
|         messages.success( |         messages.success( | ||||||
|             self.request, |             self.request, | ||||||
|             _( |             _("Successfully authenticated with %(source)s!" % {"source": self.source.name}), | ||||||
|                 "Successfully authenticated with %(source)s!" |  | ||||||
|                 % {"source": self.source.name} |  | ||||||
|             ), |  | ||||||
|         ) |         ) | ||||||
|         flow_kwargs = {PLAN_CONTEXT_PENDING_USER: connection.user} |         flow_kwargs = {PLAN_CONTEXT_PENDING_USER: connection.user} | ||||||
|         return self._handle_login_flow(self.source.authentication_flow, **flow_kwargs) |         return self._handle_login_flow(self.source.authentication_flow, **flow_kwargs) | ||||||
| @ -270,10 +255,7 @@ class SourceFlowManager: | |||||||
|         """User was not authenticated and previous request was not authenticated.""" |         """User was not authenticated and previous request was not authenticated.""" | ||||||
|         messages.success( |         messages.success( | ||||||
|             self.request, |             self.request, | ||||||
|             _( |             _("Successfully authenticated with %(source)s!" % {"source": self.source.name}), | ||||||
|                 "Successfully authenticated with %(source)s!" |  | ||||||
|                 % {"source": self.source.name} |  | ||||||
|             ), |  | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         # We run the Flow planner here so we can pass the Pending user in the context |         # We run the Flow planner here so we can pass the Pending user in the context | ||||||
|  | |||||||
| @ -27,9 +27,7 @@ def clean_expired_models(self: MonitoredTask): | |||||||
|     for cls in ExpiringModel.__subclasses__(): |     for cls in ExpiringModel.__subclasses__(): | ||||||
|         cls: ExpiringModel |         cls: ExpiringModel | ||||||
|         objects = ( |         objects = ( | ||||||
|             cls.objects.all() |             cls.objects.all().exclude(expiring=False).exclude(expiring=True, expires__gt=now()) | ||||||
|             .exclude(expiring=False) |  | ||||||
|             .exclude(expiring=True, expires__gt=now()) |  | ||||||
|         ) |         ) | ||||||
|         for obj in objects: |         for obj in objects: | ||||||
|             obj.expire_action() |             obj.expire_action() | ||||||
|  | |||||||
| @ -17,9 +17,7 @@ class TestApplicationsAPI(APITestCase): | |||||||
|         self.denied = Application.objects.create(name="denied", slug="denied") |         self.denied = Application.objects.create(name="denied", slug="denied") | ||||||
|         PolicyBinding.objects.create( |         PolicyBinding.objects.create( | ||||||
|             target=self.denied, |             target=self.denied, | ||||||
|             policy=DummyPolicy.objects.create( |             policy=DummyPolicy.objects.create(name="deny", result=False, wait_min=1, wait_max=2), | ||||||
|                 name="deny", result=False, wait_min=1, wait_max=2 |  | ||||||
|             ), |  | ||||||
|             order=0, |             order=0, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
| @ -33,9 +31,7 @@ class TestApplicationsAPI(APITestCase): | |||||||
|             ) |             ) | ||||||
|         ) |         ) | ||||||
|         self.assertEqual(response.status_code, 200) |         self.assertEqual(response.status_code, 200) | ||||||
|         self.assertJSONEqual( |         self.assertJSONEqual(force_str(response.content), {"messages": [], "passing": True}) | ||||||
|             force_str(response.content), {"messages": [], "passing": True} |  | ||||||
|         ) |  | ||||||
|         response = self.client.get( |         response = self.client.get( | ||||||
|             reverse( |             reverse( | ||||||
|                 "authentik_api:application-check-access", |                 "authentik_api:application-check-access", | ||||||
| @ -43,9 +39,7 @@ class TestApplicationsAPI(APITestCase): | |||||||
|             ) |             ) | ||||||
|         ) |         ) | ||||||
|         self.assertEqual(response.status_code, 200) |         self.assertEqual(response.status_code, 200) | ||||||
|         self.assertJSONEqual( |         self.assertJSONEqual(force_str(response.content), {"messages": ["dummy"], "passing": False}) | ||||||
|             force_str(response.content), {"messages": ["dummy"], "passing": False} |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def test_list(self): |     def test_list(self): | ||||||
|         """Test list operation without superuser_full_list""" |         """Test list operation without superuser_full_list""" | ||||||
|  | |||||||
| @ -46,9 +46,7 @@ class TestImpersonation(TestCase): | |||||||
|         self.client.force_login(self.other_user) |         self.client.force_login(self.other_user) | ||||||
|  |  | ||||||
|         self.client.get( |         self.client.get( | ||||||
|             reverse( |             reverse("authentik_core:impersonate-init", kwargs={"user_id": self.akadmin.pk}) | ||||||
|                 "authentik_core:impersonate-init", kwargs={"user_id": self.akadmin.pk} |  | ||||||
|             ) |  | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         response = self.client.get(reverse("authentik_api:user-me")) |         response = self.client.get(reverse("authentik_api:user-me")) | ||||||
|  | |||||||
| @ -22,9 +22,7 @@ class TestModels(TestCase): | |||||||
|  |  | ||||||
|     def test_token_expire_no_expire(self): |     def test_token_expire_no_expire(self): | ||||||
|         """Test token expiring with "expiring" set""" |         """Test token expiring with "expiring" set""" | ||||||
|         token = Token.objects.create( |         token = Token.objects.create(expires=now(), user=get_anonymous_user(), expiring=False) | ||||||
|             expires=now(), user=get_anonymous_user(), expiring=False |  | ||||||
|         ) |  | ||||||
|         sleep(0.5) |         sleep(0.5) | ||||||
|         self.assertFalse(token.is_expired) |         self.assertFalse(token.is_expired) | ||||||
|  |  | ||||||
|  | |||||||
| @ -16,9 +16,7 @@ class TestPropertyMappings(TestCase): | |||||||
|  |  | ||||||
|     def test_expression(self): |     def test_expression(self): | ||||||
|         """Test expression""" |         """Test expression""" | ||||||
|         mapping = PropertyMapping.objects.create( |         mapping = PropertyMapping.objects.create(name="test", expression="return 'test'") | ||||||
|             name="test", expression="return 'test'" |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(mapping.evaluate(None, None), "test") |         self.assertEqual(mapping.evaluate(None, None), "test") | ||||||
|  |  | ||||||
|     def test_expression_syntax(self): |     def test_expression_syntax(self): | ||||||
|  | |||||||
| @ -23,9 +23,7 @@ class TestPropertyMappingAPI(APITestCase): | |||||||
|     def test_test_call(self): |     def test_test_call(self): | ||||||
|         """Test PropertMappings's test endpoint""" |         """Test PropertMappings's test endpoint""" | ||||||
|         response = self.client.post( |         response = self.client.post( | ||||||
|             reverse( |             reverse("authentik_api:propertymapping-test", kwargs={"pk": self.mapping.pk}), | ||||||
|                 "authentik_api:propertymapping-test", kwargs={"pk": self.mapping.pk} |  | ||||||
|             ), |  | ||||||
|             data={ |             data={ | ||||||
|                 "user": self.user.pk, |                 "user": self.user.pk, | ||||||
|             }, |             }, | ||||||
|  | |||||||
| @ -4,12 +4,7 @@ from django.utils.timezone import now | |||||||
| from guardian.shortcuts import get_anonymous_user | from guardian.shortcuts import get_anonymous_user | ||||||
| from rest_framework.test import APITestCase | from rest_framework.test import APITestCase | ||||||
|  |  | ||||||
| from authentik.core.models import ( | from authentik.core.models import USER_ATTRIBUTE_TOKEN_EXPIRING, Token, TokenIntents, User | ||||||
|     USER_ATTRIBUTE_TOKEN_EXPIRING, |  | ||||||
|     Token, |  | ||||||
|     TokenIntents, |  | ||||||
|     User, |  | ||||||
| ) |  | ||||||
| from authentik.core.tasks import clean_expired_models | from authentik.core.tasks import clean_expired_models | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| @ -5,10 +5,7 @@ from django.shortcuts import get_object_or_404, redirect | |||||||
| from django.views import View | from django.views import View | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.core.middleware import ( | from authentik.core.middleware import SESSION_IMPERSONATE_ORIGINAL_USER, SESSION_IMPERSONATE_USER | ||||||
|     SESSION_IMPERSONATE_ORIGINAL_USER, |  | ||||||
|     SESSION_IMPERSONATE_USER, |  | ||||||
| ) |  | ||||||
| from authentik.core.models import User | from authentik.core.models import User | ||||||
| from authentik.events.models import Event, EventAction | from authentik.events.models import Event, EventAction | ||||||
|  |  | ||||||
| @ -21,9 +18,7 @@ class ImpersonateInitView(View): | |||||||
|     def get(self, request: HttpRequest, user_id: int) -> HttpResponse: |     def get(self, request: HttpRequest, user_id: int) -> HttpResponse: | ||||||
|         """Impersonation handler, checks permissions""" |         """Impersonation handler, checks permissions""" | ||||||
|         if not request.user.has_perm("impersonate"): |         if not request.user.has_perm("impersonate"): | ||||||
|             LOGGER.debug( |             LOGGER.debug("User attempted to impersonate without permissions", user=request.user) | ||||||
|                 "User attempted to impersonate without permissions", user=request.user |  | ||||||
|             ) |  | ||||||
|             return HttpResponse("Unauthorized", status=401) |             return HttpResponse("Unauthorized", status=401) | ||||||
|  |  | ||||||
|         user_to_be = get_object_or_404(User, pk=user_id) |         user_to_be = get_object_or_404(User, pk=user_id) | ||||||
|  | |||||||
| @ -14,9 +14,7 @@ class EndSessionView(TemplateView, PolicyAccessView): | |||||||
|     template_name = "if/end_session.html" |     template_name = "if/end_session.html" | ||||||
|  |  | ||||||
|     def resolve_provider_application(self): |     def resolve_provider_application(self): | ||||||
|         self.application = get_object_or_404( |         self.application = get_object_or_404(Application, slug=self.kwargs["application_slug"]) | ||||||
|             Application, slug=self.kwargs["application_slug"] |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def get_context_data(self, **kwargs: Any) -> dict[str, Any]: |     def get_context_data(self, **kwargs: Any) -> dict[str, Any]: | ||||||
|         context = super().get_context_data(**kwargs) |         context = super().get_context_data(**kwargs) | ||||||
|  | |||||||
| @ -10,12 +10,7 @@ from django_filters.filters import BooleanFilter | |||||||
| from drf_spectacular.types import OpenApiTypes | from drf_spectacular.types import OpenApiTypes | ||||||
| from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema | from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema | ||||||
| from rest_framework.decorators import action | from rest_framework.decorators import action | ||||||
| from rest_framework.fields import ( | from rest_framework.fields import CharField, DateTimeField, IntegerField, SerializerMethodField | ||||||
|     CharField, |  | ||||||
|     DateTimeField, |  | ||||||
|     IntegerField, |  | ||||||
|     SerializerMethodField, |  | ||||||
| ) |  | ||||||
| from rest_framework.request import Request | from rest_framework.request import Request | ||||||
| from rest_framework.response import Response | from rest_framework.response import Response | ||||||
| from rest_framework.serializers import ModelSerializer, ValidationError | from rest_framework.serializers import ModelSerializer, ValidationError | ||||||
| @ -86,9 +81,7 @@ class CertificateKeyPairSerializer(ModelSerializer): | |||||||
|                     backend=default_backend(), |                     backend=default_backend(), | ||||||
|                 ) |                 ) | ||||||
|             except (ValueError, TypeError): |             except (ValueError, TypeError): | ||||||
|                 raise ValidationError( |                 raise ValidationError("Unable to load private key (possibly encrypted?).") | ||||||
|                     "Unable to load private key (possibly encrypted?)." |  | ||||||
|                 ) |  | ||||||
|         return value |         return value | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
| @ -123,9 +116,7 @@ class CertificateGenerationSerializer(PassiveSerializer): | |||||||
|     """Certificate generation parameters""" |     """Certificate generation parameters""" | ||||||
|  |  | ||||||
|     common_name = CharField() |     common_name = CharField() | ||||||
|     subject_alt_name = CharField( |     subject_alt_name = CharField(required=False, allow_blank=True, label=_("Subject-alt name")) | ||||||
|         required=False, allow_blank=True, label=_("Subject-alt name") |  | ||||||
|     ) |  | ||||||
|     validity_days = IntegerField(initial=365) |     validity_days = IntegerField(initial=365) | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -170,9 +161,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet): | |||||||
|         builder = CertificateBuilder() |         builder = CertificateBuilder() | ||||||
|         builder.common_name = data.validated_data["common_name"] |         builder.common_name = data.validated_data["common_name"] | ||||||
|         builder.build( |         builder.build( | ||||||
|             subject_alt_names=data.validated_data.get("subject_alt_name", "").split( |             subject_alt_names=data.validated_data.get("subject_alt_name", "").split(","), | ||||||
|                 "," |  | ||||||
|             ), |  | ||||||
|             validity_days=int(data.validated_data["validity_days"]), |             validity_days=int(data.validated_data["validity_days"]), | ||||||
|         ) |         ) | ||||||
|         instance = builder.save() |         instance = builder.save() | ||||||
| @ -208,9 +197,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet): | |||||||
|                 "Content-Disposition" |                 "Content-Disposition" | ||||||
|             ] = f'attachment; filename="{certificate.name}_certificate.pem"' |             ] = f'attachment; filename="{certificate.name}_certificate.pem"' | ||||||
|             return response |             return response | ||||||
|         return Response( |         return Response(CertificateDataSerializer({"data": certificate.certificate_data}).data) | ||||||
|             CertificateDataSerializer({"data": certificate.certificate_data}).data |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     @extend_schema( |     @extend_schema( | ||||||
|         parameters=[ |         parameters=[ | ||||||
| @ -234,9 +221,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet): | |||||||
|         ).from_http(request) |         ).from_http(request) | ||||||
|         if "download" in request._request.GET: |         if "download" in request._request.GET: | ||||||
|             # Mime type from https://pki-tutorial.readthedocs.io/en/latest/mime.html |             # Mime type from https://pki-tutorial.readthedocs.io/en/latest/mime.html | ||||||
|             response = HttpResponse( |             response = HttpResponse(certificate.key_data, content_type="application/x-pem-file") | ||||||
|                 certificate.key_data, content_type="application/x-pem-file" |  | ||||||
|             ) |  | ||||||
|             response[ |             response[ | ||||||
|                 "Content-Disposition" |                 "Content-Disposition" | ||||||
|             ] = f'attachment; filename="{certificate.name}_private_key.pem"' |             ] = f'attachment; filename="{certificate.name}_private_key.pem"' | ||||||
|  | |||||||
| @ -46,9 +46,7 @@ class CertificateBuilder: | |||||||
|             public_exponent=65537, key_size=2048, backend=default_backend() |             public_exponent=65537, key_size=2048, backend=default_backend() | ||||||
|         ) |         ) | ||||||
|         self.__public_key = self.__private_key.public_key() |         self.__public_key = self.__private_key.public_key() | ||||||
|         alt_names: list[x509.GeneralName] = [ |         alt_names: list[x509.GeneralName] = [x509.DNSName(x) for x in subject_alt_names or []] | ||||||
|             x509.DNSName(x) for x in subject_alt_names or [] |  | ||||||
|         ] |  | ||||||
|         self.__builder = ( |         self.__builder = ( | ||||||
|             x509.CertificateBuilder() |             x509.CertificateBuilder() | ||||||
|             .subject_name( |             .subject_name( | ||||||
| @ -59,9 +57,7 @@ class CertificateBuilder: | |||||||
|                             self.common_name, |                             self.common_name, | ||||||
|                         ), |                         ), | ||||||
|                         x509.NameAttribute(NameOID.ORGANIZATION_NAME, "authentik"), |                         x509.NameAttribute(NameOID.ORGANIZATION_NAME, "authentik"), | ||||||
|                         x509.NameAttribute( |                         x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "Self-signed"), | ||||||
|                             NameOID.ORGANIZATIONAL_UNIT_NAME, "Self-signed" |  | ||||||
|                         ), |  | ||||||
|                     ] |                     ] | ||||||
|                 ) |                 ) | ||||||
|             ) |             ) | ||||||
| @ -77,9 +73,7 @@ class CertificateBuilder: | |||||||
|             ) |             ) | ||||||
|             .add_extension(x509.SubjectAlternativeName(alt_names), critical=True) |             .add_extension(x509.SubjectAlternativeName(alt_names), critical=True) | ||||||
|             .not_valid_before(datetime.datetime.today() - one_day) |             .not_valid_before(datetime.datetime.today() - one_day) | ||||||
|             .not_valid_after( |             .not_valid_after(datetime.datetime.today() + datetime.timedelta(days=validity_days)) | ||||||
|                 datetime.datetime.today() + datetime.timedelta(days=validity_days) |  | ||||||
|             ) |  | ||||||
|             .serial_number(int(uuid.uuid4())) |             .serial_number(int(uuid.uuid4())) | ||||||
|             .public_key(self.__public_key) |             .public_key(self.__public_key) | ||||||
|         ) |         ) | ||||||
|  | |||||||
| @ -57,9 +57,7 @@ class CertificateKeyPair(CreatedUpdatedModel): | |||||||
|         if not self._private_key and self._private_key != "": |         if not self._private_key and self._private_key != "": | ||||||
|             try: |             try: | ||||||
|                 self._private_key = load_pem_private_key( |                 self._private_key = load_pem_private_key( | ||||||
|                     str.encode( |                     str.encode("\n".join([x.strip() for x in self.key_data.split("\n")])), | ||||||
|                         "\n".join([x.strip() for x in self.key_data.split("\n")]) |  | ||||||
|                     ), |  | ||||||
|                     password=None, |                     password=None, | ||||||
|                     backend=default_backend(), |                     backend=default_backend(), | ||||||
|                 ) |                 ) | ||||||
| @ -70,25 +68,19 @@ class CertificateKeyPair(CreatedUpdatedModel): | |||||||
|     @property |     @property | ||||||
|     def fingerprint_sha256(self) -> str: |     def fingerprint_sha256(self) -> str: | ||||||
|         """Get SHA256 Fingerprint of certificate_data""" |         """Get SHA256 Fingerprint of certificate_data""" | ||||||
|         return hexlify(self.certificate.fingerprint(hashes.SHA256()), ":").decode( |         return hexlify(self.certificate.fingerprint(hashes.SHA256()), ":").decode("utf-8") | ||||||
|             "utf-8" |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def fingerprint_sha1(self) -> str: |     def fingerprint_sha1(self) -> str: | ||||||
|         """Get SHA1 Fingerprint of certificate_data""" |         """Get SHA1 Fingerprint of certificate_data""" | ||||||
|         return hexlify( |         return hexlify(self.certificate.fingerprint(hashes.SHA1()), ":").decode("utf-8")  # nosec | ||||||
|             self.certificate.fingerprint(hashes.SHA1()), ":"  # nosec |  | ||||||
|         ).decode("utf-8") |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def kid(self): |     def kid(self): | ||||||
|         """Get Key ID used for JWKS""" |         """Get Key ID used for JWKS""" | ||||||
|         return "{0}".format( |         return "{0}".format( | ||||||
|             md5(self.key_data.encode("utf-8")).hexdigest()  # nosec |             md5(self.key_data.encode("utf-8")).hexdigest() if self.key_data else "" | ||||||
|             if self.key_data |         )  # nosec | ||||||
|             else "" |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def __str__(self) -> str: |     def __str__(self) -> str: | ||||||
|         return f"Certificate-Key Pair {self.name}" |         return f"Certificate-Key Pair {self.name}" | ||||||
|  | |||||||
| @ -143,7 +143,5 @@ class EventViewSet(ModelViewSet): | |||||||
|         """Get all actions""" |         """Get all actions""" | ||||||
|         data = [] |         data = [] | ||||||
|         for value, name in EventAction.choices: |         for value, name in EventAction.choices: | ||||||
|             data.append( |             data.append({"name": name, "description": "", "component": value, "model_name": ""}) | ||||||
|                 {"name": name, "description": "", "component": value, "model_name": ""} |  | ||||||
|             ) |  | ||||||
|         return Response(TypeCreateSerializer(data, many=True).data) |         return Response(TypeCreateSerializer(data, many=True).data) | ||||||
|  | |||||||
| @ -29,12 +29,8 @@ class AuditMiddleware: | |||||||
|  |  | ||||||
|     def __call__(self, request: HttpRequest) -> HttpResponse: |     def __call__(self, request: HttpRequest) -> HttpResponse: | ||||||
|         # Connect signal for automatic logging |         # Connect signal for automatic logging | ||||||
|         if hasattr(request, "user") and getattr( |         if hasattr(request, "user") and getattr(request.user, "is_authenticated", False): | ||||||
|             request.user, "is_authenticated", False |             post_save_handler = partial(self.post_save_handler, user=request.user, request=request) | ||||||
|         ): |  | ||||||
|             post_save_handler = partial( |  | ||||||
|                 self.post_save_handler, user=request.user, request=request |  | ||||||
|             ) |  | ||||||
|             pre_delete_handler = partial( |             pre_delete_handler = partial( | ||||||
|                 self.pre_delete_handler, user=request.user, request=request |                 self.pre_delete_handler, user=request.user, request=request | ||||||
|             ) |             ) | ||||||
| @ -94,13 +90,9 @@ class AuditMiddleware: | |||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     # pylint: disable=unused-argument |     # pylint: disable=unused-argument | ||||||
|     def pre_delete_handler( |     def pre_delete_handler(user: User, request: HttpRequest, sender, instance: Model, **_): | ||||||
|         user: User, request: HttpRequest, sender, instance: Model, **_ |  | ||||||
|     ): |  | ||||||
|         """Signal handler for all object's pre_delete""" |         """Signal handler for all object's pre_delete""" | ||||||
|         if isinstance( |         if isinstance(instance, (Event, Notification, UserObjectPermission)):  # pragma: no cover | ||||||
|             instance, (Event, Notification, UserObjectPermission) |  | ||||||
|         ):  # pragma: no cover |  | ||||||
|             return |             return | ||||||
|  |  | ||||||
|         EventNewThread( |         EventNewThread( | ||||||
|  | |||||||
| @ -14,9 +14,7 @@ def convert_user_to_json(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | |||||||
|         event.delete() |         event.delete() | ||||||
|         # Because event objects cannot be updated, we have to re-create them |         # Because event objects cannot be updated, we have to re-create them | ||||||
|         event.pk = None |         event.pk = None | ||||||
|         event.user_json = ( |         event.user_json = authentik.events.models.get_user(event.user) if event.user else {} | ||||||
|             authentik.events.models.get_user(event.user) if event.user else {} |  | ||||||
|         ) |  | ||||||
|         event._state.adding = True |         event._state.adding = True | ||||||
|         event.save() |         event.save() | ||||||
|  |  | ||||||
| @ -58,7 +56,5 @@ class Migration(migrations.Migration): | |||||||
|             model_name="event", |             model_name="event", | ||||||
|             name="user", |             name="user", | ||||||
|         ), |         ), | ||||||
|         migrations.RenameField( |         migrations.RenameField(model_name="event", old_name="user_json", new_name="user"), | ||||||
|             model_name="event", old_name="user_json", new_name="user" |  | ||||||
|         ), |  | ||||||
|     ] |     ] | ||||||
|  | |||||||
| @ -11,16 +11,12 @@ def notify_configuration_error(apps: Apps, schema_editor: BaseDatabaseSchemaEdit | |||||||
|     db_alias = schema_editor.connection.alias |     db_alias = schema_editor.connection.alias | ||||||
|     Group = apps.get_model("authentik_core", "Group") |     Group = apps.get_model("authentik_core", "Group") | ||||||
|     PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding") |     PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding") | ||||||
|     EventMatcherPolicy = apps.get_model( |     EventMatcherPolicy = apps.get_model("authentik_policies_event_matcher", "EventMatcherPolicy") | ||||||
|         "authentik_policies_event_matcher", "EventMatcherPolicy" |  | ||||||
|     ) |  | ||||||
|     NotificationRule = apps.get_model("authentik_events", "NotificationRule") |     NotificationRule = apps.get_model("authentik_events", "NotificationRule") | ||||||
|     NotificationTransport = apps.get_model("authentik_events", "NotificationTransport") |     NotificationTransport = apps.get_model("authentik_events", "NotificationTransport") | ||||||
|  |  | ||||||
|     admin_group = ( |     admin_group = ( | ||||||
|         Group.objects.using(db_alias) |         Group.objects.using(db_alias).filter(name="authentik Admins", is_superuser=True).first() | ||||||
|         .filter(name="authentik Admins", is_superuser=True) |  | ||||||
|         .first() |  | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     policy, _ = EventMatcherPolicy.objects.using(db_alias).update_or_create( |     policy, _ = EventMatcherPolicy.objects.using(db_alias).update_or_create( | ||||||
| @ -32,9 +28,7 @@ def notify_configuration_error(apps: Apps, schema_editor: BaseDatabaseSchemaEdit | |||||||
|         defaults={"group": admin_group, "severity": NotificationSeverity.ALERT}, |         defaults={"group": admin_group, "severity": NotificationSeverity.ALERT}, | ||||||
|     ) |     ) | ||||||
|     trigger.transports.set( |     trigger.transports.set( | ||||||
|         NotificationTransport.objects.using(db_alias).filter( |         NotificationTransport.objects.using(db_alias).filter(name="default-email-transport") | ||||||
|             name="default-email-transport" |  | ||||||
|         ) |  | ||||||
|     ) |     ) | ||||||
|     trigger.save() |     trigger.save() | ||||||
|     PolicyBinding.objects.using(db_alias).update_or_create( |     PolicyBinding.objects.using(db_alias).update_or_create( | ||||||
| @ -50,16 +44,12 @@ def notify_update(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | |||||||
|     db_alias = schema_editor.connection.alias |     db_alias = schema_editor.connection.alias | ||||||
|     Group = apps.get_model("authentik_core", "Group") |     Group = apps.get_model("authentik_core", "Group") | ||||||
|     PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding") |     PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding") | ||||||
|     EventMatcherPolicy = apps.get_model( |     EventMatcherPolicy = apps.get_model("authentik_policies_event_matcher", "EventMatcherPolicy") | ||||||
|         "authentik_policies_event_matcher", "EventMatcherPolicy" |  | ||||||
|     ) |  | ||||||
|     NotificationRule = apps.get_model("authentik_events", "NotificationRule") |     NotificationRule = apps.get_model("authentik_events", "NotificationRule") | ||||||
|     NotificationTransport = apps.get_model("authentik_events", "NotificationTransport") |     NotificationTransport = apps.get_model("authentik_events", "NotificationTransport") | ||||||
|  |  | ||||||
|     admin_group = ( |     admin_group = ( | ||||||
|         Group.objects.using(db_alias) |         Group.objects.using(db_alias).filter(name="authentik Admins", is_superuser=True).first() | ||||||
|         .filter(name="authentik Admins", is_superuser=True) |  | ||||||
|         .first() |  | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     policy, _ = EventMatcherPolicy.objects.using(db_alias).update_or_create( |     policy, _ = EventMatcherPolicy.objects.using(db_alias).update_or_create( | ||||||
| @ -71,9 +61,7 @@ def notify_update(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | |||||||
|         defaults={"group": admin_group, "severity": NotificationSeverity.ALERT}, |         defaults={"group": admin_group, "severity": NotificationSeverity.ALERT}, | ||||||
|     ) |     ) | ||||||
|     trigger.transports.set( |     trigger.transports.set( | ||||||
|         NotificationTransport.objects.using(db_alias).filter( |         NotificationTransport.objects.using(db_alias).filter(name="default-email-transport") | ||||||
|             name="default-email-transport" |  | ||||||
|         ) |  | ||||||
|     ) |     ) | ||||||
|     trigger.save() |     trigger.save() | ||||||
|     PolicyBinding.objects.using(db_alias).update_or_create( |     PolicyBinding.objects.using(db_alias).update_or_create( | ||||||
| @ -89,16 +77,12 @@ def notify_exception(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | |||||||
|     db_alias = schema_editor.connection.alias |     db_alias = schema_editor.connection.alias | ||||||
|     Group = apps.get_model("authentik_core", "Group") |     Group = apps.get_model("authentik_core", "Group") | ||||||
|     PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding") |     PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding") | ||||||
|     EventMatcherPolicy = apps.get_model( |     EventMatcherPolicy = apps.get_model("authentik_policies_event_matcher", "EventMatcherPolicy") | ||||||
|         "authentik_policies_event_matcher", "EventMatcherPolicy" |  | ||||||
|     ) |  | ||||||
|     NotificationRule = apps.get_model("authentik_events", "NotificationRule") |     NotificationRule = apps.get_model("authentik_events", "NotificationRule") | ||||||
|     NotificationTransport = apps.get_model("authentik_events", "NotificationTransport") |     NotificationTransport = apps.get_model("authentik_events", "NotificationTransport") | ||||||
|  |  | ||||||
|     admin_group = ( |     admin_group = ( | ||||||
|         Group.objects.using(db_alias) |         Group.objects.using(db_alias).filter(name="authentik Admins", is_superuser=True).first() | ||||||
|         .filter(name="authentik Admins", is_superuser=True) |  | ||||||
|         .first() |  | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     policy_policy_exc, _ = EventMatcherPolicy.objects.using(db_alias).update_or_create( |     policy_policy_exc, _ = EventMatcherPolicy.objects.using(db_alias).update_or_create( | ||||||
| @ -114,9 +98,7 @@ def notify_exception(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | |||||||
|         defaults={"group": admin_group, "severity": NotificationSeverity.ALERT}, |         defaults={"group": admin_group, "severity": NotificationSeverity.ALERT}, | ||||||
|     ) |     ) | ||||||
|     trigger.transports.set( |     trigger.transports.set( | ||||||
|         NotificationTransport.objects.using(db_alias).filter( |         NotificationTransport.objects.using(db_alias).filter(name="default-email-transport") | ||||||
|             name="default-email-transport" |  | ||||||
|         ) |  | ||||||
|     ) |     ) | ||||||
|     trigger.save() |     trigger.save() | ||||||
|     PolicyBinding.objects.using(db_alias).update_or_create( |     PolicyBinding.objects.using(db_alias).update_or_create( | ||||||
|  | |||||||
| @ -38,9 +38,7 @@ def progress_bar( | |||||||
|  |  | ||||||
|     def print_progress_bar(iteration): |     def print_progress_bar(iteration): | ||||||
|         """Progress Bar Printing Function""" |         """Progress Bar Printing Function""" | ||||||
|         percent = ("{0:." + str(decimals) + "f}").format( |         percent = ("{0:." + str(decimals) + "f}").format(100 * (iteration / float(total))) | ||||||
|             100 * (iteration / float(total)) |  | ||||||
|         ) |  | ||||||
|         filledLength = int(length * iteration // total) |         filledLength = int(length * iteration // total) | ||||||
|         bar = fill * filledLength + "-" * (length - filledLength) |         bar = fill * filledLength + "-" * (length - filledLength) | ||||||
|         print(f"\r{prefix} |{bar}| {percent}% {suffix}", end=print_end) |         print(f"\r{prefix} |{bar}| {percent}% {suffix}", end=print_end) | ||||||
| @ -78,9 +76,7 @@ class Migration(migrations.Migration): | |||||||
|         migrations.AddField( |         migrations.AddField( | ||||||
|             model_name="event", |             model_name="event", | ||||||
|             name="expires", |             name="expires", | ||||||
|             field=models.DateTimeField( |             field=models.DateTimeField(default=authentik.events.models.default_event_duration), | ||||||
|                 default=authentik.events.models.default_event_duration |  | ||||||
|             ), |  | ||||||
|         ), |         ), | ||||||
|         migrations.AddField( |         migrations.AddField( | ||||||
|             model_name="event", |             model_name="event", | ||||||
|  | |||||||
| @ -15,9 +15,7 @@ class Migration(migrations.Migration): | |||||||
|         migrations.AddField( |         migrations.AddField( | ||||||
|             model_name="event", |             model_name="event", | ||||||
|             name="tenant", |             name="tenant", | ||||||
|             field=models.JSONField( |             field=models.JSONField(blank=True, default=authentik.events.models.default_tenant), | ||||||
|                 blank=True, default=authentik.events.models.default_tenant |  | ||||||
|             ), |  | ||||||
|         ), |         ), | ||||||
|         migrations.AlterField( |         migrations.AlterField( | ||||||
|             model_name="event", |             model_name="event", | ||||||
|  | |||||||
| @ -15,10 +15,7 @@ from requests import RequestException, post | |||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik import __version__ | from authentik import __version__ | ||||||
| from authentik.core.middleware import ( | from authentik.core.middleware import SESSION_IMPERSONATE_ORIGINAL_USER, SESSION_IMPERSONATE_USER | ||||||
|     SESSION_IMPERSONATE_ORIGINAL_USER, |  | ||||||
|     SESSION_IMPERSONATE_USER, |  | ||||||
| ) |  | ||||||
| from authentik.core.models import ExpiringModel, Group, User | from authentik.core.models import ExpiringModel, Group, User | ||||||
| from authentik.events.geo import GEOIP_READER | from authentik.events.geo import GEOIP_READER | ||||||
| from authentik.events.utils import cleanse_dict, get_user, model_to_dict, sanitize_dict | from authentik.events.utils import cleanse_dict, get_user, model_to_dict, sanitize_dict | ||||||
| @ -159,9 +156,7 @@ class Event(ExpiringModel): | |||||||
|         if hasattr(request, "user"): |         if hasattr(request, "user"): | ||||||
|             original_user = None |             original_user = None | ||||||
|             if hasattr(request, "session"): |             if hasattr(request, "session"): | ||||||
|                 original_user = request.session.get( |                 original_user = request.session.get(SESSION_IMPERSONATE_ORIGINAL_USER, None) | ||||||
|                     SESSION_IMPERSONATE_ORIGINAL_USER, None |  | ||||||
|                 ) |  | ||||||
|             self.user = get_user(request.user, original_user) |             self.user = get_user(request.user, original_user) | ||||||
|         if user: |         if user: | ||||||
|             self.user = get_user(user) |             self.user = get_user(user) | ||||||
| @ -169,9 +164,7 @@ class Event(ExpiringModel): | |||||||
|         if hasattr(request, "session"): |         if hasattr(request, "session"): | ||||||
|             if SESSION_IMPERSONATE_ORIGINAL_USER in request.session: |             if SESSION_IMPERSONATE_ORIGINAL_USER in request.session: | ||||||
|                 self.user = get_user(request.session[SESSION_IMPERSONATE_ORIGINAL_USER]) |                 self.user = get_user(request.session[SESSION_IMPERSONATE_ORIGINAL_USER]) | ||||||
|                 self.user["on_behalf_of"] = get_user( |                 self.user["on_behalf_of"] = get_user(request.session[SESSION_IMPERSONATE_USER]) | ||||||
|                     request.session[SESSION_IMPERSONATE_USER] |  | ||||||
|                 ) |  | ||||||
|         # User 255.255.255.255 as fallback if IP cannot be determined |         # User 255.255.255.255 as fallback if IP cannot be determined | ||||||
|         self.client_ip = get_client_ip(request) |         self.client_ip = get_client_ip(request) | ||||||
|         # Apply GeoIP Data, when enabled |         # Apply GeoIP Data, when enabled | ||||||
| @ -414,9 +407,7 @@ class NotificationRule(PolicyBindingModel): | |||||||
|     severity = models.TextField( |     severity = models.TextField( | ||||||
|         choices=NotificationSeverity.choices, |         choices=NotificationSeverity.choices, | ||||||
|         default=NotificationSeverity.NOTICE, |         default=NotificationSeverity.NOTICE, | ||||||
|         help_text=_( |         help_text=_("Controls which severity level the created notifications will have."), | ||||||
|             "Controls which severity level the created notifications will have." |  | ||||||
|         ), |  | ||||||
|     ) |     ) | ||||||
|     group = models.ForeignKey( |     group = models.ForeignKey( | ||||||
|         Group, |         Group, | ||||||
|  | |||||||
| @ -135,9 +135,7 @@ class MonitoredTask(Task): | |||||||
|         self._result = result |         self._result = result | ||||||
|  |  | ||||||
|     # pylint: disable=too-many-arguments |     # pylint: disable=too-many-arguments | ||||||
|     def after_return( |     def after_return(self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo): | ||||||
|         self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo |  | ||||||
|     ): |  | ||||||
|         if self._result: |         if self._result: | ||||||
|             if not self._result.uid: |             if not self._result.uid: | ||||||
|                 self._result.uid = self._uid |                 self._result.uid = self._uid | ||||||
| @ -159,9 +157,7 @@ class MonitoredTask(Task): | |||||||
|     # pylint: disable=too-many-arguments |     # pylint: disable=too-many-arguments | ||||||
|     def on_failure(self, exc, task_id, args, kwargs, einfo): |     def on_failure(self, exc, task_id, args, kwargs, einfo): | ||||||
|         if not self._result: |         if not self._result: | ||||||
|             self._result = TaskResult( |             self._result = TaskResult(status=TaskResultStatus.ERROR, messages=[str(exc)]) | ||||||
|                 status=TaskResultStatus.ERROR, messages=[str(exc)] |  | ||||||
|             ) |  | ||||||
|         if not self._result.uid: |         if not self._result.uid: | ||||||
|             self._result.uid = self._uid |             self._result.uid = self._uid | ||||||
|         TaskInfo( |         TaskInfo( | ||||||
| @ -179,8 +175,7 @@ class MonitoredTask(Task): | |||||||
|         Event.new( |         Event.new( | ||||||
|             EventAction.SYSTEM_TASK_EXCEPTION, |             EventAction.SYSTEM_TASK_EXCEPTION, | ||||||
|             message=( |             message=( | ||||||
|                 f"Task {self.__name__} encountered an error: " |                 f"Task {self.__name__} encountered an error: " "\n".join(self._result.messages) | ||||||
|                 "\n".join(self._result.messages) |  | ||||||
|             ), |             ), | ||||||
|         ).save() |         ).save() | ||||||
|         return super().on_failure(exc, task_id, args, kwargs, einfo=einfo) |         return super().on_failure(exc, task_id, args, kwargs, einfo=einfo) | ||||||
|  | |||||||
| @ -2,11 +2,7 @@ | |||||||
| from threading import Thread | from threading import Thread | ||||||
| from typing import Any, Optional | from typing import Any, Optional | ||||||
|  |  | ||||||
| from django.contrib.auth.signals import ( | from django.contrib.auth.signals import user_logged_in, user_logged_out, user_login_failed | ||||||
|     user_logged_in, |  | ||||||
|     user_logged_out, |  | ||||||
|     user_login_failed, |  | ||||||
| ) |  | ||||||
| from django.db.models.signals import post_save | from django.db.models.signals import post_save | ||||||
| from django.dispatch import receiver | from django.dispatch import receiver | ||||||
| from django.http import HttpRequest | from django.http import HttpRequest | ||||||
| @ -30,9 +26,7 @@ class EventNewThread(Thread): | |||||||
|     kwargs: dict[str, Any] |     kwargs: dict[str, Any] | ||||||
|     user: Optional[User] = None |     user: Optional[User] = None | ||||||
|  |  | ||||||
|     def __init__( |     def __init__(self, action: str, request: HttpRequest, user: Optional[User] = None, **kwargs): | ||||||
|         self, action: str, request: HttpRequest, user: Optional[User] = None, **kwargs |  | ||||||
|     ): |  | ||||||
|         super().__init__() |         super().__init__() | ||||||
|         self.action = action |         self.action = action | ||||||
|         self.request = request |         self.request = request | ||||||
| @ -68,9 +62,7 @@ def on_user_logged_out(sender, request: HttpRequest, user: User, **_): | |||||||
|  |  | ||||||
| @receiver(user_write) | @receiver(user_write) | ||||||
| # pylint: disable=unused-argument | # pylint: disable=unused-argument | ||||||
| def on_user_write( | def on_user_write(sender, request: HttpRequest, user: User, data: dict[str, Any], **kwargs): | ||||||
|     sender, request: HttpRequest, user: User, data: dict[str, Any], **kwargs |  | ||||||
| ): |  | ||||||
|     """Log User write""" |     """Log User write""" | ||||||
|     thread = EventNewThread(EventAction.USER_WRITE, request, **data) |     thread = EventNewThread(EventAction.USER_WRITE, request, **data) | ||||||
|     thread.kwargs["created"] = kwargs.get("created", False) |     thread.kwargs["created"] = kwargs.get("created", False) | ||||||
| @ -80,9 +72,7 @@ def on_user_write( | |||||||
|  |  | ||||||
| @receiver(user_login_failed) | @receiver(user_login_failed) | ||||||
| # pylint: disable=unused-argument | # pylint: disable=unused-argument | ||||||
| def on_user_login_failed( | def on_user_login_failed(sender, credentials: dict[str, str], request: HttpRequest, **_): | ||||||
|     sender, credentials: dict[str, str], request: HttpRequest, **_ |  | ||||||
| ): |  | ||||||
|     """Failed Login""" |     """Failed Login""" | ||||||
|     thread = EventNewThread(EventAction.LOGIN_FAILED, request, **credentials) |     thread = EventNewThread(EventAction.LOGIN_FAILED, request, **credentials) | ||||||
|     thread.run() |     thread.run() | ||||||
|  | |||||||
| @ -22,9 +22,7 @@ LOGGER = get_logger() | |||||||
| def event_notification_handler(event_uuid: str): | def event_notification_handler(event_uuid: str): | ||||||
|     """Start task for each trigger definition""" |     """Start task for each trigger definition""" | ||||||
|     for trigger in NotificationRule.objects.all(): |     for trigger in NotificationRule.objects.all(): | ||||||
|         event_trigger_handler.apply_async( |         event_trigger_handler.apply_async(args=[event_uuid, trigger.name], queue="authentik_events") | ||||||
|             args=[event_uuid, trigger.name], queue="authentik_events" |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task() | @CELERY_APP.task() | ||||||
| @ -43,17 +41,13 @@ def event_trigger_handler(event_uuid: str, trigger_name: str): | |||||||
|     if "policy_uuid" in event.context: |     if "policy_uuid" in event.context: | ||||||
|         policy_uuid = event.context["policy_uuid"] |         policy_uuid = event.context["policy_uuid"] | ||||||
|         if PolicyBinding.objects.filter( |         if PolicyBinding.objects.filter( | ||||||
|             target__in=NotificationRule.objects.all().values_list( |             target__in=NotificationRule.objects.all().values_list("pbm_uuid", flat=True), | ||||||
|                 "pbm_uuid", flat=True |  | ||||||
|             ), |  | ||||||
|             policy=policy_uuid, |             policy=policy_uuid, | ||||||
|         ).exists(): |         ).exists(): | ||||||
|             # If policy that caused this event to be created is attached |             # If policy that caused this event to be created is attached | ||||||
|             # to *any* NotificationRule, we return early. |             # to *any* NotificationRule, we return early. | ||||||
|             # This is the most effective way to prevent infinite loops. |             # This is the most effective way to prevent infinite loops. | ||||||
|             LOGGER.debug( |             LOGGER.debug("e(trigger): attempting to prevent infinite loop", trigger=trigger) | ||||||
|                 "e(trigger): attempting to prevent infinite loop", trigger=trigger |  | ||||||
|             ) |  | ||||||
|             return |             return | ||||||
|  |  | ||||||
|     if not trigger.group: |     if not trigger.group: | ||||||
| @ -62,9 +56,7 @@ def event_trigger_handler(event_uuid: str, trigger_name: str): | |||||||
|  |  | ||||||
|     LOGGER.debug("e(trigger): checking if trigger applies", trigger=trigger) |     LOGGER.debug("e(trigger): checking if trigger applies", trigger=trigger) | ||||||
|     try: |     try: | ||||||
|         user = ( |         user = User.objects.filter(pk=event.user.get("pk")).first() or get_anonymous_user() | ||||||
|             User.objects.filter(pk=event.user.get("pk")).first() or get_anonymous_user() |  | ||||||
|         ) |  | ||||||
|     except User.DoesNotExist: |     except User.DoesNotExist: | ||||||
|         LOGGER.warning("e(trigger): failed to get user", trigger=trigger) |         LOGGER.warning("e(trigger): failed to get user", trigger=trigger) | ||||||
|         return |         return | ||||||
| @ -99,20 +91,14 @@ def event_trigger_handler(event_uuid: str, trigger_name: str): | |||||||
|     retry_backoff=True, |     retry_backoff=True, | ||||||
|     base=MonitoredTask, |     base=MonitoredTask, | ||||||
| ) | ) | ||||||
| def notification_transport( | def notification_transport(self: MonitoredTask, notification_pk: int, transport_pk: int): | ||||||
|     self: MonitoredTask, notification_pk: int, transport_pk: int |  | ||||||
| ): |  | ||||||
|     """Send notification over specified transport""" |     """Send notification over specified transport""" | ||||||
|     self.save_on_success = False |     self.save_on_success = False | ||||||
|     try: |     try: | ||||||
|         notification: Notification = Notification.objects.filter( |         notification: Notification = Notification.objects.filter(pk=notification_pk).first() | ||||||
|             pk=notification_pk |  | ||||||
|         ).first() |  | ||||||
|         if not notification: |         if not notification: | ||||||
|             return |             return | ||||||
|         transport: NotificationTransport = NotificationTransport.objects.get( |         transport: NotificationTransport = NotificationTransport.objects.get(pk=transport_pk) | ||||||
|             pk=transport_pk |  | ||||||
|         ) |  | ||||||
|         transport.send(notification) |         transport.send(notification) | ||||||
|         self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL)) |         self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL)) | ||||||
|     except NotificationTransportError as exc: |     except NotificationTransportError as exc: | ||||||
|  | |||||||
| @ -38,7 +38,5 @@ class TestEvents(TestCase): | |||||||
|         event = Event.new("unittest", model=temp_model) |         event = Event.new("unittest", model=temp_model) | ||||||
|         event.save()  # We save to ensure nothing is un-saveable |         event.save()  # We save to ensure nothing is un-saveable | ||||||
|         model_content_type = ContentType.objects.get_for_model(temp_model) |         model_content_type = ContentType.objects.get_for_model(temp_model) | ||||||
|         self.assertEqual( |         self.assertEqual(event.context.get("model").get("app"), model_content_type.app_label) | ||||||
|             event.context.get("model").get("app"), model_content_type.app_label |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(event.context.get("model").get("pk"), temp_model.pk.hex) |         self.assertEqual(event.context.get("model").get("pk"), temp_model.pk.hex) | ||||||
|  | |||||||
| @ -81,12 +81,8 @@ class TestEventsNotifications(TestCase): | |||||||
|  |  | ||||||
|         execute_mock = MagicMock() |         execute_mock = MagicMock() | ||||||
|         passes = MagicMock(side_effect=PolicyException) |         passes = MagicMock(side_effect=PolicyException) | ||||||
|         with patch( |         with patch("authentik.policies.event_matcher.models.EventMatcherPolicy.passes", passes): | ||||||
|             "authentik.policies.event_matcher.models.EventMatcherPolicy.passes", passes |             with patch("authentik.events.models.NotificationTransport.send", execute_mock): | ||||||
|         ): |  | ||||||
|             with patch( |  | ||||||
|                 "authentik.events.models.NotificationTransport.send", execute_mock |  | ||||||
|             ): |  | ||||||
|                 Event.new(EventAction.CUSTOM_PREFIX).save() |                 Event.new(EventAction.CUSTOM_PREFIX).save() | ||||||
|         self.assertEqual(passes.call_count, 1) |         self.assertEqual(passes.call_count, 1) | ||||||
|  |  | ||||||
| @ -96,9 +92,7 @@ class TestEventsNotifications(TestCase): | |||||||
|         self.group.users.add(user2) |         self.group.users.add(user2) | ||||||
|         self.group.save() |         self.group.save() | ||||||
|  |  | ||||||
|         transport = NotificationTransport.objects.create( |         transport = NotificationTransport.objects.create(name="transport", send_once=True) | ||||||
|             name="transport", send_once=True |  | ||||||
|         ) |  | ||||||
|         NotificationRule.objects.filter(name__startswith="default").delete() |         NotificationRule.objects.filter(name__startswith="default").delete() | ||||||
|         trigger = NotificationRule.objects.create(name="trigger", group=self.group) |         trigger = NotificationRule.objects.create(name="trigger", group=self.group) | ||||||
|         trigger.transports.add(transport) |         trigger.transports.add(transport) | ||||||
|  | |||||||
| @ -14,12 +14,7 @@ from rest_framework.fields import BooleanField, FileField, ReadOnlyField | |||||||
| from rest_framework.parsers import MultiPartParser | from rest_framework.parsers import MultiPartParser | ||||||
| from rest_framework.request import Request | from rest_framework.request import Request | ||||||
| from rest_framework.response import Response | from rest_framework.response import Response | ||||||
| from rest_framework.serializers import ( | from rest_framework.serializers import CharField, ModelSerializer, Serializer, SerializerMethodField | ||||||
|     CharField, |  | ||||||
|     ModelSerializer, |  | ||||||
|     Serializer, |  | ||||||
|     SerializerMethodField, |  | ||||||
| ) |  | ||||||
| from rest_framework.viewsets import ModelViewSet | from rest_framework.viewsets import ModelViewSet | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| @ -152,11 +147,7 @@ class FlowViewSet(UsedByMixin, ModelViewSet): | |||||||
|         ], |         ], | ||||||
|     ) |     ) | ||||||
|     @extend_schema( |     @extend_schema( | ||||||
|         request={ |         request={"multipart/form-data": inline_serializer("SetIcon", fields={"file": FileField()})}, | ||||||
|             "multipart/form-data": inline_serializer( |  | ||||||
|                 "SetIcon", fields={"file": FileField()} |  | ||||||
|             ) |  | ||||||
|         }, |  | ||||||
|         responses={ |         responses={ | ||||||
|             204: OpenApiResponse(description="Successfully imported flow"), |             204: OpenApiResponse(description="Successfully imported flow"), | ||||||
|             400: OpenApiResponse(description="Bad request"), |             400: OpenApiResponse(description="Bad request"), | ||||||
| @ -221,9 +212,7 @@ class FlowViewSet(UsedByMixin, ModelViewSet): | |||||||
|             .order_by("order") |             .order_by("order") | ||||||
|         ): |         ): | ||||||
|             for p_index, policy_binding in enumerate( |             for p_index, policy_binding in enumerate( | ||||||
|                 get_objects_for_user( |                 get_objects_for_user(request.user, "authentik_policies.view_policybinding") | ||||||
|                     request.user, "authentik_policies.view_policybinding" |  | ||||||
|                 ) |  | ||||||
|                 .filter(target=stage_binding) |                 .filter(target=stage_binding) | ||||||
|                 .exclude(policy__isnull=True) |                 .exclude(policy__isnull=True) | ||||||
|                 .order_by("order") |                 .order_by("order") | ||||||
| @ -256,20 +245,14 @@ class FlowViewSet(UsedByMixin, ModelViewSet): | |||||||
|                 element: DiagramElement = body[index] |                 element: DiagramElement = body[index] | ||||||
|                 if element.type == "condition": |                 if element.type == "condition": | ||||||
|                     # Policy passes, link policy yes to next stage |                     # Policy passes, link policy yes to next stage | ||||||
|                     footer.append( |                     footer.append(f"{element.identifier}(yes, right)->{body[index + 1].identifier}") | ||||||
|                         f"{element.identifier}(yes, right)->{body[index + 1].identifier}" |  | ||||||
|                     ) |  | ||||||
|                     # Policy doesn't pass, go to stage after next stage |                     # Policy doesn't pass, go to stage after next stage | ||||||
|                     no_element = body[index + 1] |                     no_element = body[index + 1] | ||||||
|                     if no_element.type != "end": |                     if no_element.type != "end": | ||||||
|                         no_element = body[index + 2] |                         no_element = body[index + 2] | ||||||
|                     footer.append( |                     footer.append(f"{element.identifier}(no, bottom)->{no_element.identifier}") | ||||||
|                         f"{element.identifier}(no, bottom)->{no_element.identifier}" |  | ||||||
|                     ) |  | ||||||
|                 elif element.type == "operation": |                 elif element.type == "operation": | ||||||
|                     footer.append( |                     footer.append(f"{element.identifier}(bottom)->{body[index + 1].identifier}") | ||||||
|                         f"{element.identifier}(bottom)->{body[index + 1].identifier}" |  | ||||||
|                     ) |  | ||||||
|         diagram = "\n".join([str(x) for x in header + body + footer]) |         diagram = "\n".join([str(x) for x in header + body + footer]) | ||||||
|         return Response({"diagram": diagram}) |         return Response({"diagram": diagram}) | ||||||
|  |  | ||||||
|  | |||||||
| @ -95,9 +95,7 @@ class Command(BaseCommand):  # pragma: no cover | |||||||
|         """Output results human readable""" |         """Output results human readable""" | ||||||
|         total_max: int = max([max(inner) for inner in values]) |         total_max: int = max([max(inner) for inner in values]) | ||||||
|         total_min: int = min([min(inner) for inner in values]) |         total_min: int = min([min(inner) for inner in values]) | ||||||
|         total_avg = sum([sum(inner) for inner in values]) / sum( |         total_avg = sum([sum(inner) for inner in values]) / sum([len(inner) for inner in values]) | ||||||
|             [len(inner) for inner in values] |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         print(f"Version: {__version__}") |         print(f"Version: {__version__}") | ||||||
|         print(f"Processes: {len(values)}") |         print(f"Processes: {len(values)}") | ||||||
|  | |||||||
| @ -9,21 +9,15 @@ from authentik.stages.identification.models import UserFields | |||||||
| from authentik.stages.password import BACKEND_DJANGO, BACKEND_LDAP | from authentik.stages.password import BACKEND_DJANGO, BACKEND_LDAP | ||||||
|  |  | ||||||
|  |  | ||||||
| def create_default_authentication_flow( | def create_default_authentication_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | ||||||
|     apps: Apps, schema_editor: BaseDatabaseSchemaEditor |  | ||||||
| ): |  | ||||||
|     Flow = apps.get_model("authentik_flows", "Flow") |     Flow = apps.get_model("authentik_flows", "Flow") | ||||||
|     FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding") |     FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding") | ||||||
|     PasswordStage = apps.get_model("authentik_stages_password", "PasswordStage") |     PasswordStage = apps.get_model("authentik_stages_password", "PasswordStage") | ||||||
|     UserLoginStage = apps.get_model("authentik_stages_user_login", "UserLoginStage") |     UserLoginStage = apps.get_model("authentik_stages_user_login", "UserLoginStage") | ||||||
|     IdentificationStage = apps.get_model( |     IdentificationStage = apps.get_model("authentik_stages_identification", "IdentificationStage") | ||||||
|         "authentik_stages_identification", "IdentificationStage" |  | ||||||
|     ) |  | ||||||
|     db_alias = schema_editor.connection.alias |     db_alias = schema_editor.connection.alias | ||||||
|  |  | ||||||
|     identification_stage, _ = IdentificationStage.objects.using( |     identification_stage, _ = IdentificationStage.objects.using(db_alias).update_or_create( | ||||||
|         db_alias |  | ||||||
|     ).update_or_create( |  | ||||||
|         name="default-authentication-identification", |         name="default-authentication-identification", | ||||||
|         defaults={ |         defaults={ | ||||||
|             "user_fields": [UserFields.E_MAIL, UserFields.USERNAME], |             "user_fields": [UserFields.E_MAIL, UserFields.USERNAME], | ||||||
| @ -69,17 +63,13 @@ def create_default_authentication_flow( | |||||||
|     ) |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
| def create_default_invalidation_flow( | def create_default_invalidation_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | ||||||
|     apps: Apps, schema_editor: BaseDatabaseSchemaEditor |  | ||||||
| ): |  | ||||||
|     Flow = apps.get_model("authentik_flows", "Flow") |     Flow = apps.get_model("authentik_flows", "Flow") | ||||||
|     FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding") |     FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding") | ||||||
|     UserLogoutStage = apps.get_model("authentik_stages_user_logout", "UserLogoutStage") |     UserLogoutStage = apps.get_model("authentik_stages_user_logout", "UserLogoutStage") | ||||||
|     db_alias = schema_editor.connection.alias |     db_alias = schema_editor.connection.alias | ||||||
|  |  | ||||||
|     UserLogoutStage.objects.using(db_alias).update_or_create( |     UserLogoutStage.objects.using(db_alias).update_or_create(name="default-invalidation-logout") | ||||||
|         name="default-invalidation-logout" |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     flow, _ = Flow.objects.using(db_alias).update_or_create( |     flow, _ = Flow.objects.using(db_alias).update_or_create( | ||||||
|         slug="default-invalidation-flow", |         slug="default-invalidation-flow", | ||||||
|  | |||||||
| @ -15,16 +15,12 @@ PROMPT_POLICY_EXPRESSION = """# Check if we've not been given a username by the | |||||||
| return 'username' not in context.get('prompt_data', {})""" | return 'username' not in context.get('prompt_data', {})""" | ||||||
|  |  | ||||||
|  |  | ||||||
| def create_default_source_enrollment_flow( | def create_default_source_enrollment_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | ||||||
|     apps: Apps, schema_editor: BaseDatabaseSchemaEditor |  | ||||||
| ): |  | ||||||
|     Flow = apps.get_model("authentik_flows", "Flow") |     Flow = apps.get_model("authentik_flows", "Flow") | ||||||
|     FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding") |     FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding") | ||||||
|     PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding") |     PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding") | ||||||
|  |  | ||||||
|     ExpressionPolicy = apps.get_model( |     ExpressionPolicy = apps.get_model("authentik_policies_expression", "ExpressionPolicy") | ||||||
|         "authentik_policies_expression", "ExpressionPolicy" |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     PromptStage = apps.get_model("authentik_stages_prompt", "PromptStage") |     PromptStage = apps.get_model("authentik_stages_prompt", "PromptStage") | ||||||
|     Prompt = apps.get_model("authentik_stages_prompt", "Prompt") |     Prompt = apps.get_model("authentik_stages_prompt", "Prompt") | ||||||
| @ -99,16 +95,12 @@ def create_default_source_enrollment_flow( | |||||||
|     ) |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
| def create_default_source_authentication_flow( | def create_default_source_authentication_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | ||||||
|     apps: Apps, schema_editor: BaseDatabaseSchemaEditor |  | ||||||
| ): |  | ||||||
|     Flow = apps.get_model("authentik_flows", "Flow") |     Flow = apps.get_model("authentik_flows", "Flow") | ||||||
|     FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding") |     FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding") | ||||||
|     PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding") |     PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding") | ||||||
|  |  | ||||||
|     ExpressionPolicy = apps.get_model( |     ExpressionPolicy = apps.get_model("authentik_policies_expression", "ExpressionPolicy") | ||||||
|         "authentik_policies_expression", "ExpressionPolicy" |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     UserLoginStage = apps.get_model("authentik_stages_user_login", "UserLoginStage") |     UserLoginStage = apps.get_model("authentik_stages_user_login", "UserLoginStage") | ||||||
|  |  | ||||||
|  | |||||||
| @ -7,9 +7,7 @@ from django.db.backends.base.schema import BaseDatabaseSchemaEditor | |||||||
| from authentik.flows.models import FlowDesignation | from authentik.flows.models import FlowDesignation | ||||||
|  |  | ||||||
|  |  | ||||||
| def create_default_provider_authorization_flow( | def create_default_provider_authorization_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | ||||||
|     apps: Apps, schema_editor: BaseDatabaseSchemaEditor |  | ||||||
| ): |  | ||||||
|     Flow = apps.get_model("authentik_flows", "Flow") |     Flow = apps.get_model("authentik_flows", "Flow") | ||||||
|     FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding") |     FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding") | ||||||
|  |  | ||||||
|  | |||||||
| @ -32,9 +32,7 @@ def create_default_oobe_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor | |||||||
|     PromptStage = apps.get_model("authentik_stages_prompt", "PromptStage") |     PromptStage = apps.get_model("authentik_stages_prompt", "PromptStage") | ||||||
|     Prompt = apps.get_model("authentik_stages_prompt", "Prompt") |     Prompt = apps.get_model("authentik_stages_prompt", "Prompt") | ||||||
|  |  | ||||||
|     ExpressionPolicy = apps.get_model( |     ExpressionPolicy = apps.get_model("authentik_policies_expression", "ExpressionPolicy") | ||||||
|         "authentik_policies_expression", "ExpressionPolicy" |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     db_alias = schema_editor.connection.alias |     db_alias = schema_editor.connection.alias | ||||||
|  |  | ||||||
| @ -52,9 +50,7 @@ def create_default_oobe_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor | |||||||
|         name="default-oobe-prefill-user", |         name="default-oobe-prefill-user", | ||||||
|         defaults={"expression": PREFILL_POLICY_EXPRESSION}, |         defaults={"expression": PREFILL_POLICY_EXPRESSION}, | ||||||
|     ) |     ) | ||||||
|     password_usable_policy, _ = ExpressionPolicy.objects.using( |     password_usable_policy, _ = ExpressionPolicy.objects.using(db_alias).update_or_create( | ||||||
|         db_alias |  | ||||||
|     ).update_or_create( |  | ||||||
|         name="default-oobe-password-usable", |         name="default-oobe-password-usable", | ||||||
|         defaults={"expression": PW_USABLE_POLICY_EXPRESSION}, |         defaults={"expression": PW_USABLE_POLICY_EXPRESSION}, | ||||||
|     ) |     ) | ||||||
| @ -83,9 +79,7 @@ def create_default_oobe_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor | |||||||
|     prompt_stage, _ = PromptStage.objects.using(db_alias).update_or_create( |     prompt_stage, _ = PromptStage.objects.using(db_alias).update_or_create( | ||||||
|         name="default-oobe-password", |         name="default-oobe-password", | ||||||
|     ) |     ) | ||||||
|     prompt_stage.fields.set( |     prompt_stage.fields.set([prompt_header, prompt_email, password_first, password_second]) | ||||||
|         [prompt_header, prompt_email, password_first, password_second] |  | ||||||
|     ) |  | ||||||
|     prompt_stage.save() |     prompt_stage.save() | ||||||
|  |  | ||||||
|     user_write, _ = UserWriteStage.objects.using(db_alias).update_or_create( |     user_write, _ = UserWriteStage.objects.using(db_alias).update_or_create( | ||||||
|  | |||||||
| @ -138,9 +138,7 @@ class Flow(SerializerModel, PolicyBindingModel): | |||||||
|         it is returned as-is""" |         it is returned as-is""" | ||||||
|         if not self.background: |         if not self.background: | ||||||
|             return "/static/dist/assets/images/flow_background.jpg" |             return "/static/dist/assets/images/flow_background.jpg" | ||||||
|         if self.background.name.startswith("http") or self.background.name.startswith( |         if self.background.name.startswith("http") or self.background.name.startswith("/static"): | ||||||
|             "/static" |  | ||||||
|         ): |  | ||||||
|             return self.background.name |             return self.background.name | ||||||
|         return self.background.url |         return self.background.url | ||||||
|  |  | ||||||
| @ -165,9 +163,7 @@ class Flow(SerializerModel, PolicyBindingModel): | |||||||
|             if result.passing: |             if result.passing: | ||||||
|                 LOGGER.debug("with_policy: flow passing", flow=flow) |                 LOGGER.debug("with_policy: flow passing", flow=flow) | ||||||
|                 return flow |                 return flow | ||||||
|             LOGGER.warning( |             LOGGER.warning("with_policy: flow not passing", flow=flow, messages=result.messages) | ||||||
|                 "with_policy: flow not passing", flow=flow, messages=result.messages |  | ||||||
|             ) |  | ||||||
|         LOGGER.debug("with_policy: no flow found", filters=flow_filter) |         LOGGER.debug("with_policy: no flow found", filters=flow_filter) | ||||||
|         return None |         return None | ||||||
|  |  | ||||||
|  | |||||||
| @ -78,14 +78,10 @@ class FlowPlan: | |||||||
|         marker = self.markers[0] |         marker = self.markers[0] | ||||||
|  |  | ||||||
|         if marker.__class__ is not StageMarker: |         if marker.__class__ is not StageMarker: | ||||||
|             LOGGER.debug( |             LOGGER.debug("f(plan_inst): stage has marker", binding=binding, marker=marker) | ||||||
|                 "f(plan_inst): stage has marker", binding=binding, marker=marker |  | ||||||
|             ) |  | ||||||
|         marked_stage = marker.process(self, binding, http_request) |         marked_stage = marker.process(self, binding, http_request) | ||||||
|         if not marked_stage: |         if not marked_stage: | ||||||
|             LOGGER.debug( |             LOGGER.debug("f(plan_inst): marker returned none, next stage", binding=binding) | ||||||
|                 "f(plan_inst): marker returned none, next stage", binding=binding |  | ||||||
|             ) |  | ||||||
|             self.bindings.remove(binding) |             self.bindings.remove(binding) | ||||||
|             self.markers.remove(marker) |             self.markers.remove(marker) | ||||||
|             if not self.has_stages: |             if not self.has_stages: | ||||||
| @ -193,9 +189,9 @@ class FlowPlanner: | |||||||
|             if default_context: |             if default_context: | ||||||
|                 plan.context = default_context |                 plan.context = default_context | ||||||
|             # Check Flow policies |             # Check Flow policies | ||||||
|             for binding in FlowStageBinding.objects.filter( |             for binding in FlowStageBinding.objects.filter(target__pk=self.flow.pk).order_by( | ||||||
|                 target__pk=self.flow.pk |                 "order" | ||||||
|             ).order_by("order"): |             ): | ||||||
|                 binding: FlowStageBinding |                 binding: FlowStageBinding | ||||||
|                 stage = binding.stage |                 stage = binding.stage | ||||||
|                 marker = StageMarker() |                 marker = StageMarker() | ||||||
|  | |||||||
| @ -26,9 +26,7 @@ def invalidate_flow_cache(sender, instance, **_): | |||||||
|         LOGGER.debug("Invalidating Flow cache", flow=instance, len=total) |         LOGGER.debug("Invalidating Flow cache", flow=instance, len=total) | ||||||
|     if isinstance(instance, FlowStageBinding): |     if isinstance(instance, FlowStageBinding): | ||||||
|         total = delete_cache_prefix(f"{cache_key(instance.target)}*") |         total = delete_cache_prefix(f"{cache_key(instance.target)}*") | ||||||
|         LOGGER.debug( |         LOGGER.debug("Invalidating Flow cache from FlowStageBinding", binding=instance, len=total) | ||||||
|             "Invalidating Flow cache from FlowStageBinding", binding=instance, len=total |  | ||||||
|         ) |  | ||||||
|     if isinstance(instance, Stage): |     if isinstance(instance, Stage): | ||||||
|         total = 0 |         total = 0 | ||||||
|         for binding in FlowStageBinding.objects.filter(stage=instance): |         for binding in FlowStageBinding.objects.filter(stage=instance): | ||||||
|  | |||||||
| @ -42,14 +42,9 @@ class StageView(View): | |||||||
|         other things besides the form display. |         other things besides the form display. | ||||||
|  |  | ||||||
|         If no user is pending, returns request.user""" |         If no user is pending, returns request.user""" | ||||||
|         if ( |         if PLAN_CONTEXT_PENDING_USER_IDENTIFIER in self.executor.plan.context and for_display: | ||||||
|             PLAN_CONTEXT_PENDING_USER_IDENTIFIER in self.executor.plan.context |  | ||||||
|             and for_display |  | ||||||
|         ): |  | ||||||
|             return User( |             return User( | ||||||
|                 username=self.executor.plan.context.get( |                 username=self.executor.plan.context.get(PLAN_CONTEXT_PENDING_USER_IDENTIFIER), | ||||||
|                     PLAN_CONTEXT_PENDING_USER_IDENTIFIER |  | ||||||
|                 ), |  | ||||||
|                 email="", |                 email="", | ||||||
|             ) |             ) | ||||||
|         if PLAN_CONTEXT_PENDING_USER in self.executor.plan.context: |         if PLAN_CONTEXT_PENDING_USER in self.executor.plan.context: | ||||||
|  | |||||||
| @ -89,14 +89,10 @@ class TestFlowPlanner(TestCase): | |||||||
|  |  | ||||||
|         planner = FlowPlanner(flow) |         planner = FlowPlanner(flow) | ||||||
|         planner.plan(request) |         planner.plan(request) | ||||||
|         self.assertEqual( |         self.assertEqual(CACHE_MOCK.set.call_count, 1)  # Ensure plan is written to cache | ||||||
|             CACHE_MOCK.set.call_count, 1 |  | ||||||
|         )  # Ensure plan is written to cache |  | ||||||
|         planner = FlowPlanner(flow) |         planner = FlowPlanner(flow) | ||||||
|         planner.plan(request) |         planner.plan(request) | ||||||
|         self.assertEqual( |         self.assertEqual(CACHE_MOCK.set.call_count, 1)  # Ensure nothing is written to cache | ||||||
|             CACHE_MOCK.set.call_count, 1 |  | ||||||
|         )  # Ensure nothing is written to cache |  | ||||||
|         self.assertEqual(CACHE_MOCK.get.call_count, 2)  # Get is called twice |         self.assertEqual(CACHE_MOCK.get.call_count, 2)  # Get is called twice | ||||||
|  |  | ||||||
|     def test_planner_default_context(self): |     def test_planner_default_context(self): | ||||||
| @ -176,9 +172,7 @@ class TestFlowPlanner(TestCase): | |||||||
|         request.session.save() |         request.session.save() | ||||||
|  |  | ||||||
|         # Here we patch the dummy policy to evaluate to true so the stage is included |         # Here we patch the dummy policy to evaluate to true so the stage is included | ||||||
|         with patch( |         with patch("authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE): | ||||||
|             "authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE |  | ||||||
|         ): |  | ||||||
|             planner = FlowPlanner(flow) |             planner = FlowPlanner(flow) | ||||||
|             plan = planner.plan(request) |             plan = planner.plan(request) | ||||||
|  |  | ||||||
|  | |||||||
| @ -76,9 +76,7 @@ class TestFlowTransfer(TransactionTestCase): | |||||||
|             PolicyBinding.objects.create(policy=flow_policy, target=flow, order=0) |             PolicyBinding.objects.create(policy=flow_policy, target=flow, order=0) | ||||||
|  |  | ||||||
|             user_login = UserLoginStage.objects.create(name=stage_name) |             user_login = UserLoginStage.objects.create(name=stage_name) | ||||||
|             fsb = FlowStageBinding.objects.create( |             fsb = FlowStageBinding.objects.create(target=flow, stage=user_login, order=0) | ||||||
|                 target=flow, stage=user_login, order=0 |  | ||||||
|             ) |  | ||||||
|             PolicyBinding.objects.create(policy=flow_policy, target=fsb, order=0) |             PolicyBinding.objects.create(policy=flow_policy, target=fsb, order=0) | ||||||
|  |  | ||||||
|             exporter = FlowExporter(flow) |             exporter = FlowExporter(flow) | ||||||
|  | |||||||
| @ -11,12 +11,7 @@ from authentik.core.models import User | |||||||
| from authentik.flows.challenge import ChallengeTypes | from authentik.flows.challenge import ChallengeTypes | ||||||
| from authentik.flows.exceptions import FlowNonApplicableException | from authentik.flows.exceptions import FlowNonApplicableException | ||||||
| from authentik.flows.markers import ReevaluateMarker, StageMarker | from authentik.flows.markers import ReevaluateMarker, StageMarker | ||||||
| from authentik.flows.models import ( | from authentik.flows.models import Flow, FlowDesignation, FlowStageBinding, InvalidResponseAction | ||||||
|     Flow, |  | ||||||
|     FlowDesignation, |  | ||||||
|     FlowStageBinding, |  | ||||||
|     InvalidResponseAction, |  | ||||||
| ) |  | ||||||
| from authentik.flows.planner import FlowPlan, FlowPlanner | from authentik.flows.planner import FlowPlan, FlowPlanner | ||||||
| from authentik.flows.stage import PLAN_CONTEXT_PENDING_USER_IDENTIFIER, StageView | from authentik.flows.stage import PLAN_CONTEXT_PENDING_USER_IDENTIFIER, StageView | ||||||
| from authentik.flows.views import NEXT_ARG_NAME, SESSION_KEY_PLAN, FlowExecutorView | from authentik.flows.views import NEXT_ARG_NAME, SESSION_KEY_PLAN, FlowExecutorView | ||||||
| @ -61,9 +56,7 @@ class TestFlowExecutor(TestCase): | |||||||
|         ) |         ) | ||||||
|         stage = DummyStage.objects.create(name="dummy") |         stage = DummyStage.objects.create(name="dummy") | ||||||
|         binding = FlowStageBinding(target=flow, stage=stage, order=0) |         binding = FlowStageBinding(target=flow, stage=stage, order=0) | ||||||
|         plan = FlowPlan( |         plan = FlowPlan(flow_pk=flow.pk.hex + "a", bindings=[binding], markers=[StageMarker()]) | ||||||
|             flow_pk=flow.pk.hex + "a", bindings=[binding], markers=[StageMarker()] |  | ||||||
|         ) |  | ||||||
|         session = self.client.session |         session = self.client.session | ||||||
|         session[SESSION_KEY_PLAN] = plan |         session[SESSION_KEY_PLAN] = plan | ||||||
|         session.save() |         session.save() | ||||||
| @ -163,9 +156,7 @@ class TestFlowExecutor(TestCase): | |||||||
|             target=flow, stage=DummyStage.objects.create(name="dummy2"), order=1 |             target=flow, stage=DummyStage.objects.create(name="dummy2"), order=1 | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         exec_url = reverse( |         exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) | ||||||
|             "authentik_api:flow-executor", kwargs={"flow_slug": flow.slug} |  | ||||||
|         ) |  | ||||||
|         # First Request, start planning, renders form |         # First Request, start planning, renders form | ||||||
|         response = self.client.get(exec_url) |         response = self.client.get(exec_url) | ||||||
|         self.assertEqual(response.status_code, 200) |         self.assertEqual(response.status_code, 200) | ||||||
| @ -209,13 +200,9 @@ class TestFlowExecutor(TestCase): | |||||||
|         PolicyBinding.objects.create(policy=false_policy, target=binding2, order=0) |         PolicyBinding.objects.create(policy=false_policy, target=binding2, order=0) | ||||||
|  |  | ||||||
|         # Here we patch the dummy policy to evaluate to true so the stage is included |         # Here we patch the dummy policy to evaluate to true so the stage is included | ||||||
|         with patch( |         with patch("authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE): | ||||||
|             "authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE |  | ||||||
|         ): |  | ||||||
|  |  | ||||||
|             exec_url = reverse( |             exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) | ||||||
|                 "authentik_api:flow-executor", kwargs={"flow_slug": flow.slug} |  | ||||||
|             ) |  | ||||||
|             # First request, run the planner |             # First request, run the planner | ||||||
|             response = self.client.get(exec_url) |             response = self.client.get(exec_url) | ||||||
|             self.assertEqual(response.status_code, 200) |             self.assertEqual(response.status_code, 200) | ||||||
| @ -263,13 +250,9 @@ class TestFlowExecutor(TestCase): | |||||||
|         PolicyBinding.objects.create(policy=false_policy, target=binding2, order=0) |         PolicyBinding.objects.create(policy=false_policy, target=binding2, order=0) | ||||||
|  |  | ||||||
|         # Here we patch the dummy policy to evaluate to true so the stage is included |         # Here we patch the dummy policy to evaluate to true so the stage is included | ||||||
|         with patch( |         with patch("authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE): | ||||||
|             "authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE |  | ||||||
|         ): |  | ||||||
|  |  | ||||||
|             exec_url = reverse( |             exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) | ||||||
|                 "authentik_api:flow-executor", kwargs={"flow_slug": flow.slug} |  | ||||||
|             ) |  | ||||||
|             # First request, run the planner |             # First request, run the planner | ||||||
|             response = self.client.get(exec_url) |             response = self.client.get(exec_url) | ||||||
|  |  | ||||||
| @ -334,13 +317,9 @@ class TestFlowExecutor(TestCase): | |||||||
|         PolicyBinding.objects.create(policy=true_policy, target=binding2, order=0) |         PolicyBinding.objects.create(policy=true_policy, target=binding2, order=0) | ||||||
|  |  | ||||||
|         # Here we patch the dummy policy to evaluate to true so the stage is included |         # Here we patch the dummy policy to evaluate to true so the stage is included | ||||||
|         with patch( |         with patch("authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE): | ||||||
|             "authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE |  | ||||||
|         ): |  | ||||||
|  |  | ||||||
|             exec_url = reverse( |             exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) | ||||||
|                 "authentik_api:flow-executor", kwargs={"flow_slug": flow.slug} |  | ||||||
|             ) |  | ||||||
|             # First request, run the planner |             # First request, run the planner | ||||||
|             response = self.client.get(exec_url) |             response = self.client.get(exec_url) | ||||||
|  |  | ||||||
| @ -422,13 +401,9 @@ class TestFlowExecutor(TestCase): | |||||||
|         PolicyBinding.objects.create(policy=false_policy, target=binding3, order=0) |         PolicyBinding.objects.create(policy=false_policy, target=binding3, order=0) | ||||||
|  |  | ||||||
|         # Here we patch the dummy policy to evaluate to true so the stage is included |         # Here we patch the dummy policy to evaluate to true so the stage is included | ||||||
|         with patch( |         with patch("authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE): | ||||||
|             "authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE |  | ||||||
|         ): |  | ||||||
|  |  | ||||||
|             exec_url = reverse( |             exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) | ||||||
|                 "authentik_api:flow-executor", kwargs={"flow_slug": flow.slug} |  | ||||||
|             ) |  | ||||||
|             # First request, run the planner |             # First request, run the planner | ||||||
|             response = self.client.get(exec_url) |             response = self.client.get(exec_url) | ||||||
|             self.assertEqual(response.status_code, 200) |             self.assertEqual(response.status_code, 200) | ||||||
| @ -511,9 +486,7 @@ class TestFlowExecutor(TestCase): | |||||||
|         ) |         ) | ||||||
|         request.user = user |         request.user = user | ||||||
|         planner = FlowPlanner(flow) |         planner = FlowPlanner(flow) | ||||||
|         plan = planner.plan( |         plan = planner.plan(request, default_context={PLAN_CONTEXT_PENDING_USER_IDENTIFIER: ident}) | ||||||
|             request, default_context={PLAN_CONTEXT_PENDING_USER_IDENTIFIER: ident} |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         executor = FlowExecutorView() |         executor = FlowExecutorView() | ||||||
|         executor.plan = plan |         executor.plan = plan | ||||||
| @ -542,9 +515,7 @@ class TestFlowExecutor(TestCase): | |||||||
|             evaluate_on_plan=False, |             evaluate_on_plan=False, | ||||||
|             re_evaluate_policies=True, |             re_evaluate_policies=True, | ||||||
|         ) |         ) | ||||||
|         PolicyBinding.objects.create( |         PolicyBinding.objects.create(policy=reputation_policy, target=deny_binding, order=0) | ||||||
|             policy=reputation_policy, target=deny_binding, order=0 |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         # Stage 1 is an identification stage |         # Stage 1 is an identification stage | ||||||
|         ident_stage = IdentificationStage.objects.create( |         ident_stage = IdentificationStage.objects.create( | ||||||
| @ -557,9 +528,7 @@ class TestFlowExecutor(TestCase): | |||||||
|             order=1, |             order=1, | ||||||
|             invalid_response_action=InvalidResponseAction.RESTART_WITH_CONTEXT, |             invalid_response_action=InvalidResponseAction.RESTART_WITH_CONTEXT, | ||||||
|         ) |         ) | ||||||
|         exec_url = reverse( |         exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) | ||||||
|             "authentik_api:flow-executor", kwargs={"flow_slug": flow.slug} |  | ||||||
|         ) |  | ||||||
|         # First request, run the planner |         # First request, run the planner | ||||||
|         response = self.client.get(exec_url) |         response = self.client.get(exec_url) | ||||||
|         self.assertEqual(response.status_code, 200) |         self.assertEqual(response.status_code, 200) | ||||||
| @ -579,9 +548,7 @@ class TestFlowExecutor(TestCase): | |||||||
|                 "user_fields": [UserFields.E_MAIL], |                 "user_fields": [UserFields.E_MAIL], | ||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
|         response = self.client.post( |         response = self.client.post(exec_url, {"uid_field": "invalid-string"}, follow=True) | ||||||
|             exec_url, {"uid_field": "invalid-string"}, follow=True |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 200) |         self.assertEqual(response.status_code, 200) | ||||||
|         self.assertJSONEqual( |         self.assertJSONEqual( | ||||||
|             force_str(response.content), |             force_str(response.content), | ||||||
|  | |||||||
| @ -21,9 +21,7 @@ class TestHelperView(TestCase): | |||||||
|         response = self.client.get( |         response = self.client.get( | ||||||
|             reverse("authentik_flows:default-invalidation"), |             reverse("authentik_flows:default-invalidation"), | ||||||
|         ) |         ) | ||||||
|         expected_url = reverse( |         expected_url = reverse("authentik_core:if-flow", kwargs={"flow_slug": flow.slug}) | ||||||
|             "authentik_core:if-flow", kwargs={"flow_slug": flow.slug} |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 302) |         self.assertEqual(response.status_code, 302) | ||||||
|         self.assertEqual(response.url, expected_url) |         self.assertEqual(response.url, expected_url) | ||||||
|  |  | ||||||
| @ -40,8 +38,6 @@ class TestHelperView(TestCase): | |||||||
|         response = self.client.get( |         response = self.client.get( | ||||||
|             reverse("authentik_flows:default-invalidation"), |             reverse("authentik_flows:default-invalidation"), | ||||||
|         ) |         ) | ||||||
|         expected_url = reverse( |         expected_url = reverse("authentik_core:if-flow", kwargs={"flow_slug": flow.slug}) | ||||||
|             "authentik_core:if-flow", kwargs={"flow_slug": flow.slug} |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 302) |         self.assertEqual(response.status_code, 302) | ||||||
|         self.assertEqual(response.url, expected_url) |         self.assertEqual(response.url, expected_url) | ||||||
|  | |||||||
| @ -44,9 +44,7 @@ class FlowBundleEntry: | |||||||
|     attrs: dict[str, Any] |     attrs: dict[str, Any] | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def from_model( |     def from_model(model: SerializerModel, *extra_identifier_names: str) -> "FlowBundleEntry": | ||||||
|         model: SerializerModel, *extra_identifier_names: str |  | ||||||
|     ) -> "FlowBundleEntry": |  | ||||||
|         """Convert a SerializerModel instance to a Bundle Entry""" |         """Convert a SerializerModel instance to a Bundle Entry""" | ||||||
|         identifiers = { |         identifiers = { | ||||||
|             "pk": model.pk, |             "pk": model.pk, | ||||||
|  | |||||||
| @ -6,11 +6,7 @@ from uuid import UUID | |||||||
| from django.db.models import Q | from django.db.models import Q | ||||||
|  |  | ||||||
| from authentik.flows.models import Flow, FlowStageBinding, Stage | from authentik.flows.models import Flow, FlowStageBinding, Stage | ||||||
| from authentik.flows.transfer.common import ( | from authentik.flows.transfer.common import DataclassEncoder, FlowBundle, FlowBundleEntry | ||||||
|     DataclassEncoder, |  | ||||||
|     FlowBundle, |  | ||||||
|     FlowBundleEntry, |  | ||||||
| ) |  | ||||||
| from authentik.policies.models import Policy, PolicyBinding | from authentik.policies.models import Policy, PolicyBinding | ||||||
| from authentik.stages.prompt.models import PromptStage | from authentik.stages.prompt.models import PromptStage | ||||||
|  |  | ||||||
| @ -37,9 +33,7 @@ class FlowExporter: | |||||||
|  |  | ||||||
|     def walk_stages(self) -> Iterator[FlowBundleEntry]: |     def walk_stages(self) -> Iterator[FlowBundleEntry]: | ||||||
|         """Convert all stages attached to self.flow into FlowBundleEntry objects""" |         """Convert all stages attached to self.flow into FlowBundleEntry objects""" | ||||||
|         stages = ( |         stages = Stage.objects.filter(flow=self.flow).select_related().select_subclasses() | ||||||
|             Stage.objects.filter(flow=self.flow).select_related().select_subclasses() |  | ||||||
|         ) |  | ||||||
|         for stage in stages: |         for stage in stages: | ||||||
|             if isinstance(stage, PromptStage): |             if isinstance(stage, PromptStage): | ||||||
|                 pass |                 pass | ||||||
| @ -56,9 +50,7 @@ class FlowExporter: | |||||||
|         a direct foreign key to a policy.""" |         a direct foreign key to a policy.""" | ||||||
|         # Special case for PromptStage as that has a direct M2M to policy, we have to ensure |         # Special case for PromptStage as that has a direct M2M to policy, we have to ensure | ||||||
|         # all policies referenced in there we also include here |         # all policies referenced in there we also include here | ||||||
|         prompt_stages = PromptStage.objects.filter(flow=self.flow).values_list( |         prompt_stages = PromptStage.objects.filter(flow=self.flow).values_list("pk", flat=True) | ||||||
|             "pk", flat=True |  | ||||||
|         ) |  | ||||||
|         query = Q(bindings__in=self.pbm_uuids) | Q(promptstage__in=prompt_stages) |         query = Q(bindings__in=self.pbm_uuids) | Q(promptstage__in=prompt_stages) | ||||||
|         policies = Policy.objects.filter(query).select_related() |         policies = Policy.objects.filter(query).select_related() | ||||||
|         for policy in policies: |         for policy in policies: | ||||||
| @ -67,9 +59,7 @@ class FlowExporter: | |||||||
|     def walk_policy_bindings(self) -> Iterator[FlowBundleEntry]: |     def walk_policy_bindings(self) -> Iterator[FlowBundleEntry]: | ||||||
|         """Walk over all policybindings relative to us. This is run at the end of the export, as |         """Walk over all policybindings relative to us. This is run at the end of the export, as | ||||||
|         we are sure all objects exist now.""" |         we are sure all objects exist now.""" | ||||||
|         bindings = PolicyBinding.objects.filter( |         bindings = PolicyBinding.objects.filter(target__in=self.pbm_uuids).select_related() | ||||||
|             target__in=self.pbm_uuids |  | ||||||
|         ).select_related() |  | ||||||
|         for binding in bindings: |         for binding in bindings: | ||||||
|             yield FlowBundleEntry.from_model(binding, "policy", "target", "order") |             yield FlowBundleEntry.from_model(binding, "policy", "target", "order") | ||||||
|  |  | ||||||
|  | |||||||
| @ -16,11 +16,7 @@ from rest_framework.serializers import BaseSerializer, Serializer | |||||||
| from structlog.stdlib import BoundLogger, get_logger | from structlog.stdlib import BoundLogger, get_logger | ||||||
|  |  | ||||||
| from authentik.flows.models import Flow, FlowStageBinding, Stage | from authentik.flows.models import Flow, FlowStageBinding, Stage | ||||||
| from authentik.flows.transfer.common import ( | from authentik.flows.transfer.common import EntryInvalidError, FlowBundle, FlowBundleEntry | ||||||
|     EntryInvalidError, |  | ||||||
|     FlowBundle, |  | ||||||
|     FlowBundleEntry, |  | ||||||
| ) |  | ||||||
| from authentik.lib.models import SerializerModel | from authentik.lib.models import SerializerModel | ||||||
| from authentik.policies.models import Policy, PolicyBinding | from authentik.policies.models import Policy, PolicyBinding | ||||||
| from authentik.stages.prompt.models import Prompt | from authentik.stages.prompt.models import Prompt | ||||||
| @ -105,9 +101,7 @@ class FlowImporter: | |||||||
|             if isinstance(value, dict) and "pk" in value: |             if isinstance(value, dict) and "pk" in value: | ||||||
|                 del updated_identifiers[key] |                 del updated_identifiers[key] | ||||||
|                 updated_identifiers[f"{key}"] = value["pk"] |                 updated_identifiers[f"{key}"] = value["pk"] | ||||||
|         existing_models = model.objects.filter( |         existing_models = model.objects.filter(self.__query_from_identifier(updated_identifiers)) | ||||||
|             self.__query_from_identifier(updated_identifiers) |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         serializer_kwargs = {} |         serializer_kwargs = {} | ||||||
|         if existing_models.exists(): |         if existing_models.exists(): | ||||||
| @ -120,9 +114,7 @@ class FlowImporter: | |||||||
|             ) |             ) | ||||||
|             serializer_kwargs["instance"] = model_instance |             serializer_kwargs["instance"] = model_instance | ||||||
|         else: |         else: | ||||||
|             self.logger.debug( |             self.logger.debug("initialise new instance", model=model, **updated_identifiers) | ||||||
|                 "initialise new instance", model=model, **updated_identifiers |  | ||||||
|             ) |  | ||||||
|         full_data = self.__update_pks_for_attrs(entry.attrs) |         full_data = self.__update_pks_for_attrs(entry.attrs) | ||||||
|         full_data.update(updated_identifiers) |         full_data.update(updated_identifiers) | ||||||
|         serializer_kwargs["data"] = full_data |         serializer_kwargs["data"] = full_data | ||||||
|  | |||||||
| @ -38,13 +38,7 @@ from authentik.flows.challenge import ( | |||||||
|     WithUserInfoChallenge, |     WithUserInfoChallenge, | ||||||
| ) | ) | ||||||
| from authentik.flows.exceptions import EmptyFlowException, FlowNonApplicableException | from authentik.flows.exceptions import EmptyFlowException, FlowNonApplicableException | ||||||
| from authentik.flows.models import ( | from authentik.flows.models import ConfigurableStage, Flow, FlowDesignation, FlowStageBinding, Stage | ||||||
|     ConfigurableStage, |  | ||||||
|     Flow, |  | ||||||
|     FlowDesignation, |  | ||||||
|     FlowStageBinding, |  | ||||||
|     Stage, |  | ||||||
| ) |  | ||||||
| from authentik.flows.planner import ( | from authentik.flows.planner import ( | ||||||
|     PLAN_CONTEXT_PENDING_USER, |     PLAN_CONTEXT_PENDING_USER, | ||||||
|     PLAN_CONTEXT_REDIRECT, |     PLAN_CONTEXT_REDIRECT, | ||||||
| @ -155,9 +149,7 @@ class FlowExecutorView(APIView): | |||||||
|             try: |             try: | ||||||
|                 self.plan = self._initiate_plan() |                 self.plan = self._initiate_plan() | ||||||
|             except FlowNonApplicableException as exc: |             except FlowNonApplicableException as exc: | ||||||
|                 self._logger.warning( |                 self._logger.warning("f(exec): Flow not applicable to current user", exc=exc) | ||||||
|                     "f(exec): Flow not applicable to current user", exc=exc |  | ||||||
|                 ) |  | ||||||
|                 return to_stage_response(self.request, self.handle_invalid_flow(exc)) |                 return to_stage_response(self.request, self.handle_invalid_flow(exc)) | ||||||
|             except EmptyFlowException as exc: |             except EmptyFlowException as exc: | ||||||
|                 self._logger.warning("f(exec): Flow is empty", exc=exc) |                 self._logger.warning("f(exec): Flow is empty", exc=exc) | ||||||
| @ -174,9 +166,7 @@ class FlowExecutorView(APIView): | |||||||
|             # in which case we just delete the plan and invalidate everything |             # in which case we just delete the plan and invalidate everything | ||||||
|             next_binding = self.plan.next(self.request) |             next_binding = self.plan.next(self.request) | ||||||
|         except Exception as exc:  # pylint: disable=broad-except |         except Exception as exc:  # pylint: disable=broad-except | ||||||
|             self._logger.warning( |             self._logger.warning("f(exec): found incompatible flow plan, invalidating run", exc=exc) | ||||||
|                 "f(exec): found incompatible flow plan, invalidating run", exc=exc |  | ||||||
|             ) |  | ||||||
|             keys = cache.keys("flow_*") |             keys = cache.keys("flow_*") | ||||||
|             cache.delete_many(keys) |             cache.delete_many(keys) | ||||||
|             return self.stage_invalid() |             return self.stage_invalid() | ||||||
| @ -314,9 +304,7 @@ class FlowExecutorView(APIView): | |||||||
|         self.request.session[SESSION_KEY_PLAN] = plan |         self.request.session[SESSION_KEY_PLAN] = plan | ||||||
|         kwargs = self.kwargs |         kwargs = self.kwargs | ||||||
|         kwargs.update({"flow_slug": self.flow.slug}) |         kwargs.update({"flow_slug": self.flow.slug}) | ||||||
|         return redirect_with_qs( |         return redirect_with_qs("authentik_api:flow-executor", self.request.GET, **kwargs) | ||||||
|             "authentik_api:flow-executor", self.request.GET, **kwargs |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def _flow_done(self) -> HttpResponse: |     def _flow_done(self) -> HttpResponse: | ||||||
|         """User Successfully passed all stages""" |         """User Successfully passed all stages""" | ||||||
| @ -350,9 +338,7 @@ class FlowExecutorView(APIView): | |||||||
|             ) |             ) | ||||||
|             kwargs = self.kwargs |             kwargs = self.kwargs | ||||||
|             kwargs.update({"flow_slug": self.flow.slug}) |             kwargs.update({"flow_slug": self.flow.slug}) | ||||||
|             return redirect_with_qs( |             return redirect_with_qs("authentik_api:flow-executor", self.request.GET, **kwargs) | ||||||
|                 "authentik_api:flow-executor", self.request.GET, **kwargs |  | ||||||
|             ) |  | ||||||
|         # User passed all stages |         # User passed all stages | ||||||
|         self._logger.debug( |         self._logger.debug( | ||||||
|             "f(exec): User passed all stages", |             "f(exec): User passed all stages", | ||||||
| @ -408,18 +394,13 @@ class FlowErrorResponse(TemplateResponse): | |||||||
|         super().__init__(request=request, template="flows/error.html") |         super().__init__(request=request, template="flows/error.html") | ||||||
|         self.error = error |         self.error = error | ||||||
|  |  | ||||||
|     def resolve_context( |     def resolve_context(self, context: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]: | ||||||
|         self, context: Optional[dict[str, Any]] |  | ||||||
|     ) -> Optional[dict[str, Any]]: |  | ||||||
|         if not context: |         if not context: | ||||||
|             context = {} |             context = {} | ||||||
|         context["error"] = self.error |         context["error"] = self.error | ||||||
|         if self._request.user and self._request.user.is_authenticated: |         if self._request.user and self._request.user.is_authenticated: | ||||||
|             if ( |             if self._request.user.is_superuser or self._request.user.group_attributes().get( | ||||||
|                 self._request.user.is_superuser |                 USER_ATTRIBUTE_DEBUG, False | ||||||
|                 or self._request.user.group_attributes().get( |  | ||||||
|                     USER_ATTRIBUTE_DEBUG, False |  | ||||||
|                 ) |  | ||||||
|             ): |             ): | ||||||
|                 context["tb"] = "".join(format_tb(self.error.__traceback__)) |                 context["tb"] = "".join(format_tb(self.error.__traceback__)) | ||||||
|         return context |         return context | ||||||
| @ -464,9 +445,7 @@ class ToDefaultFlow(View): | |||||||
|                     flow_slug=flow.slug, |                     flow_slug=flow.slug, | ||||||
|                 ) |                 ) | ||||||
|                 del self.request.session[SESSION_KEY_PLAN] |                 del self.request.session[SESSION_KEY_PLAN] | ||||||
|         return redirect_with_qs( |         return redirect_with_qs("authentik_core:if-flow", request.GET, flow_slug=flow.slug) | ||||||
|             "authentik_core:if-flow", request.GET, flow_slug=flow.slug |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def to_stage_response(request: HttpRequest, source: HttpResponse) -> HttpResponse: | def to_stage_response(request: HttpRequest, source: HttpResponse) -> HttpResponse: | ||||||
|  | |||||||
| @ -115,9 +115,7 @@ class ConfigLoader: | |||||||
|         for key, value in os.environ.items(): |         for key, value in os.environ.items(): | ||||||
|             if not key.startswith(ENV_PREFIX): |             if not key.startswith(ENV_PREFIX): | ||||||
|                 continue |                 continue | ||||||
|             relative_key = ( |             relative_key = key.replace(f"{ENV_PREFIX}_", "", 1).replace("__", ".").lower() | ||||||
|                 key.replace(f"{ENV_PREFIX}_", "", 1).replace("__", ".").lower() |  | ||||||
|             ) |  | ||||||
|             # Recursively convert path from a.b.c into outer[a][b][c] |             # Recursively convert path from a.b.c into outer[a][b][c] | ||||||
|             current_obj = outer |             current_obj = outer | ||||||
|             dot_parts = relative_key.split(".") |             dot_parts = relative_key.split(".") | ||||||
|  | |||||||
| @ -37,15 +37,11 @@ class InheritanceAutoManager(InheritanceManager): | |||||||
|         return super().get_queryset().select_subclasses() |         return super().get_queryset().select_subclasses() | ||||||
|  |  | ||||||
|  |  | ||||||
| class InheritanceForwardManyToOneDescriptor( | class InheritanceForwardManyToOneDescriptor(models.fields.related.ForwardManyToOneDescriptor): | ||||||
|     models.fields.related.ForwardManyToOneDescriptor |  | ||||||
| ): |  | ||||||
|     """Forward ManyToOne Descriptor that selects subclass. Requires InheritanceAutoManager.""" |     """Forward ManyToOne Descriptor that selects subclass. Requires InheritanceAutoManager.""" | ||||||
|  |  | ||||||
|     def get_queryset(self, **hints): |     def get_queryset(self, **hints): | ||||||
|         return self.field.remote_field.model.objects.db_manager( |         return self.field.remote_field.model.objects.db_manager(hints=hints).select_subclasses() | ||||||
|             hints=hints |  | ||||||
|         ).select_subclasses() |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class InheritanceForeignKey(models.ForeignKey): | class InheritanceForeignKey(models.ForeignKey): | ||||||
|  | |||||||
| @ -8,11 +8,7 @@ from botocore.exceptions import BotoCoreError | |||||||
| from celery.exceptions import CeleryError | from celery.exceptions import CeleryError | ||||||
| from channels.middleware import BaseMiddleware | from channels.middleware import BaseMiddleware | ||||||
| from channels_redis.core import ChannelFull | from channels_redis.core import ChannelFull | ||||||
| from django.core.exceptions import ( | from django.core.exceptions import ImproperlyConfigured, SuspiciousOperation, ValidationError | ||||||
|     ImproperlyConfigured, |  | ||||||
|     SuspiciousOperation, |  | ||||||
|     ValidationError, |  | ||||||
| ) |  | ||||||
| from django.db import InternalError, OperationalError, ProgrammingError | from django.db import InternalError, OperationalError, ProgrammingError | ||||||
| from django.http.response import Http404 | from django.http.response import Http404 | ||||||
| from django_redis.exceptions import ConnectionInterrupted | from django_redis.exceptions import ConnectionInterrupted | ||||||
|  | |||||||
| @ -26,7 +26,5 @@ class TestEvaluator(TestCase): | |||||||
|     def test_is_group_member(self): |     def test_is_group_member(self): | ||||||
|         """Test expr_is_group_member""" |         """Test expr_is_group_member""" | ||||||
|         self.assertFalse( |         self.assertFalse( | ||||||
|             BaseEvaluator.expr_is_group_member( |             BaseEvaluator.expr_is_group_member(User.objects.get(username="akadmin"), name="test") | ||||||
|                 User.objects.get(username="akadmin"), name="test" |  | ||||||
|             ) |  | ||||||
|         ) |         ) | ||||||
|  | |||||||
| @ -1,17 +1,8 @@ | |||||||
| """Test HTTP Helpers""" | """Test HTTP Helpers""" | ||||||
| from django.test import RequestFactory, TestCase | from django.test import RequestFactory, TestCase | ||||||
|  |  | ||||||
| from authentik.core.models import ( | from authentik.core.models import USER_ATTRIBUTE_CAN_OVERRIDE_IP, Token, TokenIntents, User | ||||||
|     USER_ATTRIBUTE_CAN_OVERRIDE_IP, | from authentik.lib.utils.http import OUTPOST_REMOTE_IP_HEADER, OUTPOST_TOKEN_HEADER, get_client_ip | ||||||
|     Token, |  | ||||||
|     TokenIntents, |  | ||||||
|     User, |  | ||||||
| ) |  | ||||||
| from authentik.lib.utils.http import ( |  | ||||||
|     OUTPOST_REMOTE_IP_HEADER, |  | ||||||
|     OUTPOST_TOKEN_HEADER, |  | ||||||
|     get_client_ip, |  | ||||||
| ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestHTTP(TestCase): | class TestHTTP(TestCase): | ||||||
|  | |||||||
| @ -9,9 +9,7 @@ class TestSentry(TestCase): | |||||||
|  |  | ||||||
|     def test_error_not_sent(self): |     def test_error_not_sent(self): | ||||||
|         """Test SentryIgnoredError not sent""" |         """Test SentryIgnoredError not sent""" | ||||||
|         self.assertIsNone( |         self.assertIsNone(before_send({}, {"exc_info": (0, SentryIgnoredException(), 0)})) | ||||||
|             before_send({}, {"exc_info": (0, SentryIgnoredException(), 0)}) |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def test_error_sent(self): |     def test_error_sent(self): | ||||||
|         """Test error sent""" |         """Test error sent""" | ||||||
|  | |||||||
| @ -29,16 +29,9 @@ def _get_outpost_override_ip(request: HttpRequest) -> Optional[str]: | |||||||
|     """Get the actual remote IP when set by an outpost. Only |     """Get the actual remote IP when set by an outpost. Only | ||||||
|     allowed when the request is authenticated, by a user with USER_ATTRIBUTE_CAN_OVERRIDE_IP set |     allowed when the request is authenticated, by a user with USER_ATTRIBUTE_CAN_OVERRIDE_IP set | ||||||
|     to outpost""" |     to outpost""" | ||||||
|     from authentik.core.models import ( |     from authentik.core.models import USER_ATTRIBUTE_CAN_OVERRIDE_IP, Token, TokenIntents | ||||||
|         USER_ATTRIBUTE_CAN_OVERRIDE_IP, |  | ||||||
|         Token, |  | ||||||
|         TokenIntents, |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     if ( |     if OUTPOST_REMOTE_IP_HEADER not in request.META or OUTPOST_TOKEN_HEADER not in request.META: | ||||||
|         OUTPOST_REMOTE_IP_HEADER not in request.META |  | ||||||
|         or OUTPOST_TOKEN_HEADER not in request.META |  | ||||||
|     ): |  | ||||||
|         return None |         return None | ||||||
|     fake_ip = request.META[OUTPOST_REMOTE_IP_HEADER] |     fake_ip = request.META[OUTPOST_REMOTE_IP_HEADER] | ||||||
|     tokens = Token.filter_not_expired( |     tokens = Token.filter_not_expired( | ||||||
|  | |||||||
| @ -12,9 +12,7 @@ def managed_reconcile(self: MonitoredTask): | |||||||
|     try: |     try: | ||||||
|         ObjectManager().run() |         ObjectManager().run() | ||||||
|         self.set_status( |         self.set_status( | ||||||
|             TaskResult( |             TaskResult(TaskResultStatus.SUCCESSFUL, ["Successfully updated managed models."]) | ||||||
|                 TaskResultStatus.SUCCESSFUL, ["Successfully updated managed models."] |  | ||||||
|             ) |  | ||||||
|         ) |         ) | ||||||
|     except DatabaseError as exc: |     except DatabaseError as exc: | ||||||
|         self.set_status(TaskResult(TaskResultStatus.WARNING, [str(exc)])) |         self.set_status(TaskResult(TaskResultStatus.WARNING, [str(exc)])) | ||||||
|  | |||||||
| @ -15,12 +15,7 @@ from authentik.core.api.used_by import UsedByMixin | |||||||
| from authentik.core.api.utils import PassiveSerializer, is_dict | from authentik.core.api.utils import PassiveSerializer, is_dict | ||||||
| from authentik.core.models import Provider | from authentik.core.models import Provider | ||||||
| from authentik.outposts.api.service_connections import ServiceConnectionSerializer | from authentik.outposts.api.service_connections import ServiceConnectionSerializer | ||||||
| from authentik.outposts.models import ( | from authentik.outposts.models import Outpost, OutpostConfig, OutpostType, default_outpost_config | ||||||
|     Outpost, |  | ||||||
|     OutpostConfig, |  | ||||||
|     OutpostType, |  | ||||||
|     default_outpost_config, |  | ||||||
| ) |  | ||||||
| from authentik.providers.ldap.models import LDAPProvider | from authentik.providers.ldap.models import LDAPProvider | ||||||
| from authentik.providers.proxy.models import ProxyProvider | from authentik.providers.proxy.models import ProxyProvider | ||||||
|  |  | ||||||
|  | |||||||
| @ -15,11 +15,7 @@ from rest_framework.serializers import ModelSerializer | |||||||
| from rest_framework.viewsets import GenericViewSet, ModelViewSet | from rest_framework.viewsets import GenericViewSet, ModelViewSet | ||||||
|  |  | ||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
| from authentik.core.api.utils import ( | from authentik.core.api.utils import MetaNameSerializer, PassiveSerializer, TypeCreateSerializer | ||||||
|     MetaNameSerializer, |  | ||||||
|     PassiveSerializer, |  | ||||||
|     TypeCreateSerializer, |  | ||||||
| ) |  | ||||||
| from authentik.lib.utils.reflection import all_subclasses | from authentik.lib.utils.reflection import all_subclasses | ||||||
| from authentik.outposts.models import ( | from authentik.outposts.models import ( | ||||||
|     DockerServiceConnection, |     DockerServiceConnection, | ||||||
| @ -129,9 +125,7 @@ class KubernetesServiceConnectionSerializer(ServiceConnectionSerializer): | |||||||
|         if kubeconfig == {}: |         if kubeconfig == {}: | ||||||
|             if not self.initial_data["local"]: |             if not self.initial_data["local"]: | ||||||
|                 raise serializers.ValidationError( |                 raise serializers.ValidationError( | ||||||
|                     _( |                     _("You can only use an empty kubeconfig when connecting to a local cluster.") | ||||||
|                         "You can only use an empty kubeconfig when connecting to a local cluster." |  | ||||||
|                     ) |  | ||||||
|                 ) |                 ) | ||||||
|             # Empty kubeconfig is valid |             # Empty kubeconfig is valid | ||||||
|             return kubeconfig |             return kubeconfig | ||||||
|  | |||||||
| @ -59,9 +59,7 @@ class OutpostConsumer(AuthJsonConsumer): | |||||||
|     def connect(self): |     def connect(self): | ||||||
|         super().connect() |         super().connect() | ||||||
|         uuid = self.scope["url_route"]["kwargs"]["pk"] |         uuid = self.scope["url_route"]["kwargs"]["pk"] | ||||||
|         outpost = get_objects_for_user( |         outpost = get_objects_for_user(self.user, "authentik_outposts.view_outpost").filter(pk=uuid) | ||||||
|             self.user, "authentik_outposts.view_outpost" |  | ||||||
|         ).filter(pk=uuid) |  | ||||||
|         if not outpost.exists(): |         if not outpost.exists(): | ||||||
|             raise DenyConnection() |             raise DenyConnection() | ||||||
|         self.accept() |         self.accept() | ||||||
| @ -129,7 +127,5 @@ class OutpostConsumer(AuthJsonConsumer): | |||||||
|     def event_update(self, event): |     def event_update(self, event): | ||||||
|         """Event handler which is called by post_save signals, Send update instruction""" |         """Event handler which is called by post_save signals, Send update instruction""" | ||||||
|         self.send_json( |         self.send_json( | ||||||
|             asdict( |             asdict(WebsocketMessage(instruction=WebsocketMessageInstruction.TRIGGER_UPDATE)) | ||||||
|                 WebsocketMessage(instruction=WebsocketMessageInstruction.TRIGGER_UPDATE) |  | ||||||
|             ) |  | ||||||
|         ) |         ) | ||||||
|  | |||||||
| @ -9,11 +9,7 @@ from yaml import safe_dump | |||||||
|  |  | ||||||
| from authentik import __version__ | from authentik import __version__ | ||||||
| from authentik.outposts.controllers.base import BaseController, ControllerException | from authentik.outposts.controllers.base import BaseController, ControllerException | ||||||
| from authentik.outposts.models import ( | from authentik.outposts.models import DockerServiceConnection, Outpost, ServiceConnectionInvalid | ||||||
|     DockerServiceConnection, |  | ||||||
|     Outpost, |  | ||||||
|     ServiceConnectionInvalid, |  | ||||||
| ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class DockerController(BaseController): | class DockerController(BaseController): | ||||||
| @ -37,9 +33,7 @@ class DockerController(BaseController): | |||||||
|     def _get_env(self) -> dict[str, str]: |     def _get_env(self) -> dict[str, str]: | ||||||
|         return { |         return { | ||||||
|             "AUTHENTIK_HOST": self.outpost.config.authentik_host.lower(), |             "AUTHENTIK_HOST": self.outpost.config.authentik_host.lower(), | ||||||
|             "AUTHENTIK_INSECURE": str( |             "AUTHENTIK_INSECURE": str(self.outpost.config.authentik_host_insecure).lower(), | ||||||
|                 self.outpost.config.authentik_host_insecure |  | ||||||
|             ).lower(), |  | ||||||
|             "AUTHENTIK_TOKEN": self.outpost.token.key, |             "AUTHENTIK_TOKEN": self.outpost.token.key, | ||||||
|         } |         } | ||||||
|  |  | ||||||
| @ -141,9 +135,7 @@ class DockerController(BaseController): | |||||||
|                 .lower() |                 .lower() | ||||||
|                 != "unless-stopped" |                 != "unless-stopped" | ||||||
|             ): |             ): | ||||||
|                 self.logger.info( |                 self.logger.info("Container has mis-matched restart policy, re-creating...") | ||||||
|                     "Container has mis-matched restart policy, re-creating..." |  | ||||||
|                 ) |  | ||||||
|                 self.down() |                 self.down() | ||||||
|                 return self.up() |                 return self.up() | ||||||
|             # Check that container is healthy |             # Check that container is healthy | ||||||
| @ -157,9 +149,7 @@ class DockerController(BaseController): | |||||||
|                 if has_been_created: |                 if has_been_created: | ||||||
|                     # Since we've just created the container, give it some time to start. |                     # Since we've just created the container, give it some time to start. | ||||||
|                     # If its still not up by then, restart it |                     # If its still not up by then, restart it | ||||||
|                     self.logger.info( |                     self.logger.info("Container is unhealthy and new, giving it time to boot.") | ||||||
|                         "Container is unhealthy and new, giving it time to boot." |  | ||||||
|                     ) |  | ||||||
|                     sleep(60) |                     sleep(60) | ||||||
|                 self.logger.info("Container is unhealthy, restarting...") |                 self.logger.info("Container is unhealthy, restarting...") | ||||||
|                 container.restart() |                 container.restart() | ||||||
| @ -198,9 +188,7 @@ class DockerController(BaseController): | |||||||
|                     "ports": ports, |                     "ports": ports, | ||||||
|                     "environment": { |                     "environment": { | ||||||
|                         "AUTHENTIK_HOST": self.outpost.config.authentik_host, |                         "AUTHENTIK_HOST": self.outpost.config.authentik_host, | ||||||
|                         "AUTHENTIK_INSECURE": str( |                         "AUTHENTIK_INSECURE": str(self.outpost.config.authentik_host_insecure), | ||||||
|                             self.outpost.config.authentik_host_insecure |  | ||||||
|                         ), |  | ||||||
|                         "AUTHENTIK_TOKEN": self.outpost.token.key, |                         "AUTHENTIK_TOKEN": self.outpost.token.key, | ||||||
|                     }, |                     }, | ||||||
|                     "labels": self._get_labels(), |                     "labels": self._get_labels(), | ||||||
|  | |||||||
| @ -17,10 +17,7 @@ from kubernetes.client import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| from authentik.outposts.controllers.base import FIELD_MANAGER | from authentik.outposts.controllers.base import FIELD_MANAGER | ||||||
| from authentik.outposts.controllers.k8s.base import ( | from authentik.outposts.controllers.k8s.base import KubernetesObjectReconciler, NeedsUpdate | ||||||
|     KubernetesObjectReconciler, |  | ||||||
|     NeedsUpdate, |  | ||||||
| ) |  | ||||||
| from authentik.outposts.models import Outpost | from authentik.outposts.models import Outpost | ||||||
|  |  | ||||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||||
| @ -124,9 +121,7 @@ class DeploymentReconciler(KubernetesObjectReconciler[V1Deployment]): | |||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def delete(self, reference: V1Deployment): |     def delete(self, reference: V1Deployment): | ||||||
|         return self.api.delete_namespaced_deployment( |         return self.api.delete_namespaced_deployment(reference.metadata.name, self.namespace) | ||||||
|             reference.metadata.name, self.namespace |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def retrieve(self) -> V1Deployment: |     def retrieve(self) -> V1Deployment: | ||||||
|         return self.api.read_namespaced_deployment(self.name, self.namespace) |         return self.api.read_namespaced_deployment(self.name, self.namespace) | ||||||
|  | |||||||
| @ -5,10 +5,7 @@ from typing import TYPE_CHECKING | |||||||
| from kubernetes.client import CoreV1Api, V1Secret | from kubernetes.client import CoreV1Api, V1Secret | ||||||
|  |  | ||||||
| from authentik.outposts.controllers.base import FIELD_MANAGER | from authentik.outposts.controllers.base import FIELD_MANAGER | ||||||
| from authentik.outposts.controllers.k8s.base import ( | from authentik.outposts.controllers.k8s.base import KubernetesObjectReconciler, NeedsUpdate | ||||||
|     KubernetesObjectReconciler, |  | ||||||
|     NeedsUpdate, |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||||
|     from authentik.outposts.controllers.kubernetes import KubernetesController |     from authentik.outposts.controllers.kubernetes import KubernetesController | ||||||
| @ -38,9 +35,7 @@ class SecretReconciler(KubernetesObjectReconciler[V1Secret]): | |||||||
|         return V1Secret( |         return V1Secret( | ||||||
|             metadata=meta, |             metadata=meta, | ||||||
|             data={ |             data={ | ||||||
|                 "authentik_host": b64string( |                 "authentik_host": b64string(self.controller.outpost.config.authentik_host), | ||||||
|                     self.controller.outpost.config.authentik_host |  | ||||||
|                 ), |  | ||||||
|                 "authentik_host_insecure": b64string( |                 "authentik_host_insecure": b64string( | ||||||
|                     str(self.controller.outpost.config.authentik_host_insecure) |                     str(self.controller.outpost.config.authentik_host_insecure) | ||||||
|                 ), |                 ), | ||||||
| @ -54,9 +49,7 @@ class SecretReconciler(KubernetesObjectReconciler[V1Secret]): | |||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def delete(self, reference: V1Secret): |     def delete(self, reference: V1Secret): | ||||||
|         return self.api.delete_namespaced_secret( |         return self.api.delete_namespaced_secret(reference.metadata.name, self.namespace) | ||||||
|             reference.metadata.name, self.namespace |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def retrieve(self) -> V1Secret: |     def retrieve(self) -> V1Secret: | ||||||
|         return self.api.read_namespaced_secret(self.name, self.namespace) |         return self.api.read_namespaced_secret(self.name, self.namespace) | ||||||
|  | |||||||
| @ -4,10 +4,7 @@ from typing import TYPE_CHECKING | |||||||
| from kubernetes.client import CoreV1Api, V1Service, V1ServicePort, V1ServiceSpec | from kubernetes.client import CoreV1Api, V1Service, V1ServicePort, V1ServiceSpec | ||||||
|  |  | ||||||
| from authentik.outposts.controllers.base import FIELD_MANAGER | from authentik.outposts.controllers.base import FIELD_MANAGER | ||||||
| from authentik.outposts.controllers.k8s.base import ( | from authentik.outposts.controllers.k8s.base import KubernetesObjectReconciler, NeedsUpdate | ||||||
|     KubernetesObjectReconciler, |  | ||||||
|     NeedsUpdate, |  | ||||||
| ) |  | ||||||
| from authentik.outposts.controllers.k8s.deployment import DeploymentReconciler | from authentik.outposts.controllers.k8s.deployment import DeploymentReconciler | ||||||
|  |  | ||||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||||
| @ -58,9 +55,7 @@ class ServiceReconciler(KubernetesObjectReconciler[V1Service]): | |||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def delete(self, reference: V1Service): |     def delete(self, reference: V1Service): | ||||||
|         return self.api.delete_namespaced_service( |         return self.api.delete_namespaced_service(reference.metadata.name, self.namespace) | ||||||
|             reference.metadata.name, self.namespace |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def retrieve(self) -> V1Service: |     def retrieve(self) -> V1Service: | ||||||
|         return self.api.read_namespaced_service(self.name, self.namespace) |         return self.api.read_namespaced_service(self.name, self.namespace) | ||||||
|  | |||||||
| @ -24,9 +24,7 @@ class KubernetesController(BaseController): | |||||||
|     client: ApiClient |     client: ApiClient | ||||||
|     connection: KubernetesServiceConnection |     connection: KubernetesServiceConnection | ||||||
|  |  | ||||||
|     def __init__( |     def __init__(self, outpost: Outpost, connection: KubernetesServiceConnection) -> None: | ||||||
|         self, outpost: Outpost, connection: KubernetesServiceConnection |  | ||||||
|     ) -> None: |  | ||||||
|         super().__init__(outpost, connection) |         super().__init__(outpost, connection) | ||||||
|         self.client = connection.client() |         self.client = connection.client() | ||||||
|         self.reconcilers = { |         self.reconcilers = { | ||||||
|  | |||||||
| @ -15,9 +15,7 @@ class Migration(migrations.Migration): | |||||||
|         migrations.AddField( |         migrations.AddField( | ||||||
|             model_name="outpost", |             model_name="outpost", | ||||||
|             name="_config", |             name="_config", | ||||||
|             field=models.JSONField( |             field=models.JSONField(default=authentik.outposts.models.default_outpost_config), | ||||||
|                 default=authentik.outposts.models.default_outpost_config |  | ||||||
|             ), |  | ||||||
|         ), |         ), | ||||||
|         migrations.AddField( |         migrations.AddField( | ||||||
|             model_name="outpost", |             model_name="outpost", | ||||||
|  | |||||||
| @ -10,9 +10,7 @@ def fix_missing_token_identifier(apps: Apps, schema_editor: BaseDatabaseSchemaEd | |||||||
|     Token = apps.get_model("authentik_core", "Token") |     Token = apps.get_model("authentik_core", "Token") | ||||||
|     from authentik.outposts.models import Outpost |     from authentik.outposts.models import Outpost | ||||||
|  |  | ||||||
|     for outpost in ( |     for outpost in Outpost.objects.using(schema_editor.connection.alias).all().only("pk"): | ||||||
|         Outpost.objects.using(schema_editor.connection.alias).all().only("pk") |  | ||||||
|     ): |  | ||||||
|         user_identifier = outpost.user_identifier |         user_identifier = outpost.user_identifier | ||||||
|         users = User.objects.filter(username=user_identifier) |         users = User.objects.filter(username=user_identifier) | ||||||
|         if not users.exists(): |         if not users.exists(): | ||||||
|  | |||||||
| @ -14,9 +14,7 @@ import authentik.lib.models | |||||||
| def migrate_to_service_connection(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | def migrate_to_service_connection(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | ||||||
|     db_alias = schema_editor.connection.alias |     db_alias = schema_editor.connection.alias | ||||||
|     Outpost = apps.get_model("authentik_outposts", "Outpost") |     Outpost = apps.get_model("authentik_outposts", "Outpost") | ||||||
|     DockerServiceConnection = apps.get_model( |     DockerServiceConnection = apps.get_model("authentik_outposts", "DockerServiceConnection") | ||||||
|         "authentik_outposts", "DockerServiceConnection" |  | ||||||
|     ) |  | ||||||
|     KubernetesServiceConnection = apps.get_model( |     KubernetesServiceConnection = apps.get_model( | ||||||
|         "authentik_outposts", "KubernetesServiceConnection" |         "authentik_outposts", "KubernetesServiceConnection" | ||||||
|     ) |     ) | ||||||
| @ -25,9 +23,7 @@ def migrate_to_service_connection(apps: Apps, schema_editor: BaseDatabaseSchemaE | |||||||
|     k8s = KubernetesServiceConnection.objects.filter(local=True).first() |     k8s = KubernetesServiceConnection.objects.filter(local=True).first() | ||||||
|  |  | ||||||
|     try: |     try: | ||||||
|         for outpost in ( |         for outpost in Outpost.objects.using(db_alias).all().exclude(deployment_type="custom"): | ||||||
|             Outpost.objects.using(db_alias).all().exclude(deployment_type="custom") |  | ||||||
|         ): |  | ||||||
|             if outpost.deployment_type == "kubernetes": |             if outpost.deployment_type == "kubernetes": | ||||||
|                 outpost.service_connection = k8s |                 outpost.service_connection = k8s | ||||||
|             elif outpost.deployment_type == "docker": |             elif outpost.deployment_type == "docker": | ||||||
|  | |||||||
| @ -11,9 +11,7 @@ def remove_pb_prefix_users(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | |||||||
|     Outpost = apps.get_model("authentik_outposts", "Outpost") |     Outpost = apps.get_model("authentik_outposts", "Outpost") | ||||||
|  |  | ||||||
|     for outpost in Outpost.objects.using(alias).all(): |     for outpost in Outpost.objects.using(alias).all(): | ||||||
|         matching = User.objects.using(alias).filter( |         matching = User.objects.using(alias).filter(username=f"pb-outpost-{outpost.uuid.hex}") | ||||||
|             username=f"pb-outpost-{outpost.uuid.hex}" |  | ||||||
|         ) |  | ||||||
|         if matching.exists(): |         if matching.exists(): | ||||||
|             matching.delete() |             matching.delete() | ||||||
|  |  | ||||||
|  | |||||||
| @ -13,8 +13,6 @@ class Migration(migrations.Migration): | |||||||
|         migrations.AlterField( |         migrations.AlterField( | ||||||
|             model_name="outpost", |             model_name="outpost", | ||||||
|             name="type", |             name="type", | ||||||
|             field=models.TextField( |             field=models.TextField(choices=[("proxy", "Proxy"), ("ldap", "Ldap")], default="proxy"), | ||||||
|                 choices=[("proxy", "Proxy"), ("ldap", "Ldap")], default="proxy" |  | ||||||
|             ), |  | ||||||
|         ), |         ), | ||||||
|     ] |     ] | ||||||
|  | |||||||
| @ -64,9 +64,7 @@ class OutpostConfig: | |||||||
|  |  | ||||||
|     log_level: str = CONFIG.y("log_level") |     log_level: str = CONFIG.y("log_level") | ||||||
|     error_reporting_enabled: bool = CONFIG.y_bool("error_reporting.enabled") |     error_reporting_enabled: bool = CONFIG.y_bool("error_reporting.enabled") | ||||||
|     error_reporting_environment: str = CONFIG.y( |     error_reporting_environment: str = CONFIG.y("error_reporting.environment", "customer") | ||||||
|         "error_reporting.environment", "customer" |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     object_naming_template: str = field(default="ak-outpost-%(name)s") |     object_naming_template: str = field(default="ak-outpost-%(name)s") | ||||||
|     kubernetes_replicas: int = field(default=1) |     kubernetes_replicas: int = field(default=1) | ||||||
| @ -264,9 +262,7 @@ class KubernetesServiceConnection(OutpostServiceConnection): | |||||||
|             client = self.client() |             client = self.client() | ||||||
|             api_instance = VersionApi(client) |             api_instance = VersionApi(client) | ||||||
|             version: VersionInfo = api_instance.get_code() |             version: VersionInfo = api_instance.get_code() | ||||||
|             return OutpostServiceConnectionState( |             return OutpostServiceConnectionState(version=version.git_version, healthy=True) | ||||||
|                 version=version.git_version, healthy=True |  | ||||||
|             ) |  | ||||||
|         except (OpenApiException, HTTPError, ServiceConnectionInvalid): |         except (OpenApiException, HTTPError, ServiceConnectionInvalid): | ||||||
|             return OutpostServiceConnectionState(version="", healthy=False) |             return OutpostServiceConnectionState(version="", healthy=False) | ||||||
|  |  | ||||||
| @ -360,8 +356,7 @@ class Outpost(ManagedModel): | |||||||
|                 if isinstance(model_or_perm, models.Model): |                 if isinstance(model_or_perm, models.Model): | ||||||
|                     model_or_perm: models.Model |                     model_or_perm: models.Model | ||||||
|                     code_name = ( |                     code_name = ( | ||||||
|                         f"{model_or_perm._meta.app_label}." |                         f"{model_or_perm._meta.app_label}." f"view_{model_or_perm._meta.model_name}" | ||||||
|                         f"view_{model_or_perm._meta.model_name}" |  | ||||||
|                     ) |                     ) | ||||||
|                     assign_perm(code_name, user, model_or_perm) |                     assign_perm(code_name, user, model_or_perm) | ||||||
|                 else: |                 else: | ||||||
| @ -417,9 +412,7 @@ class Outpost(ManagedModel): | |||||||
|             self, |             self, | ||||||
|             "authentik_events.add_event", |             "authentik_events.add_event", | ||||||
|         ] |         ] | ||||||
|         for provider in ( |         for provider in Provider.objects.filter(outpost=self).select_related().select_subclasses(): | ||||||
|             Provider.objects.filter(outpost=self).select_related().select_subclasses() |  | ||||||
|         ): |  | ||||||
|             if isinstance(provider, OutpostModel): |             if isinstance(provider, OutpostModel): | ||||||
|                 objects.extend(provider.get_required_objects()) |                 objects.extend(provider.get_required_objects()) | ||||||
|             else: |             else: | ||||||
|  | |||||||
| @ -9,11 +9,7 @@ from authentik.core.models import Provider | |||||||
| from authentik.crypto.models import CertificateKeyPair | from authentik.crypto.models import CertificateKeyPair | ||||||
| from authentik.lib.utils.reflection import class_to_path | from authentik.lib.utils.reflection import class_to_path | ||||||
| from authentik.outposts.models import Outpost, OutpostServiceConnection | from authentik.outposts.models import Outpost, OutpostServiceConnection | ||||||
| from authentik.outposts.tasks import ( | from authentik.outposts.tasks import CACHE_KEY_OUTPOST_DOWN, outpost_controller, outpost_post_save | ||||||
|     CACHE_KEY_OUTPOST_DOWN, |  | ||||||
|     outpost_controller, |  | ||||||
|     outpost_post_save, |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
| UPDATE_TRIGGERING_MODELS = ( | UPDATE_TRIGGERING_MODELS = ( | ||||||
| @ -37,9 +33,7 @@ def pre_save_outpost(sender, instance: Outpost, **_): | |||||||
|     # Name changes the deployment name, need to recreate |     # Name changes the deployment name, need to recreate | ||||||
|     dirty += old_instance.name != instance.name |     dirty += old_instance.name != instance.name | ||||||
|     # namespace requires re-create |     # namespace requires re-create | ||||||
|     dirty += ( |     dirty += old_instance.config.kubernetes_namespace != instance.config.kubernetes_namespace | ||||||
|         old_instance.config.kubernetes_namespace != instance.config.kubernetes_namespace |  | ||||||
|     ) |  | ||||||
|     if bool(dirty): |     if bool(dirty): | ||||||
|         LOGGER.info("Outpost needs re-deployment due to changes", instance=instance) |         LOGGER.info("Outpost needs re-deployment due to changes", instance=instance) | ||||||
|         cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, old_instance) |         cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, old_instance) | ||||||
|  | |||||||
| @ -62,9 +62,7 @@ def controller_for_outpost(outpost: Outpost) -> Optional[BaseController]: | |||||||
| def outpost_service_connection_state(connection_pk: Any): | def outpost_service_connection_state(connection_pk: Any): | ||||||
|     """Update cached state of a service connection""" |     """Update cached state of a service connection""" | ||||||
|     connection: OutpostServiceConnection = ( |     connection: OutpostServiceConnection = ( | ||||||
|         OutpostServiceConnection.objects.filter(pk=connection_pk) |         OutpostServiceConnection.objects.filter(pk=connection_pk).select_subclasses().first() | ||||||
|         .select_subclasses() |  | ||||||
|         .first() |  | ||||||
|     ) |     ) | ||||||
|     if not connection: |     if not connection: | ||||||
|         return |         return | ||||||
| @ -157,9 +155,7 @@ def outpost_post_save(model_class: str, model_pk: Any): | |||||||
|         outpost_controller.delay(instance.pk) |         outpost_controller.delay(instance.pk) | ||||||
|  |  | ||||||
|     if isinstance(instance, (OutpostModel, Outpost)): |     if isinstance(instance, (OutpostModel, Outpost)): | ||||||
|         LOGGER.debug( |         LOGGER.debug("triggering outpost update from outpostmodel/outpost", instance=instance) | ||||||
|             "triggering outpost update from outpostmodel/outpost", instance=instance |  | ||||||
|         ) |  | ||||||
|         outpost_send_update(instance) |         outpost_send_update(instance) | ||||||
|  |  | ||||||
|     if isinstance(instance, OutpostServiceConnection): |     if isinstance(instance, OutpostServiceConnection): | ||||||
| @ -208,9 +204,7 @@ def _outpost_single_update(outpost: Outpost, layer=None): | |||||||
|         layer = get_channel_layer() |         layer = get_channel_layer() | ||||||
|     for state in OutpostState.for_outpost(outpost): |     for state in OutpostState.for_outpost(outpost): | ||||||
|         for channel in state.channel_ids: |         for channel in state.channel_ids: | ||||||
|             LOGGER.debug( |             LOGGER.debug("sending update", channel=channel, instance=state.uid, outpost=outpost) | ||||||
|                 "sending update", channel=channel, instance=state.uid, outpost=outpost |  | ||||||
|             ) |  | ||||||
|             async_to_sync(layer.send)(channel, {"type": "event.update"}) |             async_to_sync(layer.send)(channel, {"type": "event.update"}) | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -231,9 +225,7 @@ def outpost_local_connection(): | |||||||
|     if Path(kubeconfig_path).exists(): |     if Path(kubeconfig_path).exists(): | ||||||
|         LOGGER.debug("Detected kubeconfig") |         LOGGER.debug("Detected kubeconfig") | ||||||
|         kubeconfig_local_name = f"k8s-{gethostname()}" |         kubeconfig_local_name = f"k8s-{gethostname()}" | ||||||
|         if not KubernetesServiceConnection.objects.filter( |         if not KubernetesServiceConnection.objects.filter(name=kubeconfig_local_name).exists(): | ||||||
|             name=kubeconfig_local_name |  | ||||||
|         ).exists(): |  | ||||||
|             LOGGER.debug("Creating kubeconfig Service Connection") |             LOGGER.debug("Creating kubeconfig Service Connection") | ||||||
|             with open(kubeconfig_path, "r") as _kubeconfig: |             with open(kubeconfig_path, "r") as _kubeconfig: | ||||||
|                 KubernetesServiceConnection.objects.create( |                 KubernetesServiceConnection.objects.create( | ||||||
|  | |||||||
| @ -63,9 +63,7 @@ class TestOutpostServiceConnectionsAPI(APITestCase): | |||||||
|         provider = ProxyProvider.objects.create( |         provider = ProxyProvider.objects.create( | ||||||
|             name="test", authorization_flow=Flow.objects.first() |             name="test", authorization_flow=Flow.objects.first() | ||||||
|         ) |         ) | ||||||
|         invalid = OutpostSerializer( |         invalid = OutpostSerializer(data={"name": "foo", "providers": [provider.pk], "config": {}}) | ||||||
|             data={"name": "foo", "providers": [provider.pk], "config": {}} |  | ||||||
|         ) |  | ||||||
|         self.assertFalse(invalid.is_valid()) |         self.assertFalse(invalid.is_valid()) | ||||||
|         self.assertIn("config", invalid.errors) |         self.assertIn("config", invalid.errors) | ||||||
|         valid = OutpostSerializer( |         valid = OutpostSerializer( | ||||||
|  | |||||||
| @ -2,11 +2,7 @@ | |||||||
| from typing import OrderedDict | from typing import OrderedDict | ||||||
|  |  | ||||||
| from django.core.exceptions import ObjectDoesNotExist | from django.core.exceptions import ObjectDoesNotExist | ||||||
| from rest_framework.serializers import ( | from rest_framework.serializers import ModelSerializer, PrimaryKeyRelatedField, ValidationError | ||||||
|     ModelSerializer, |  | ||||||
|     PrimaryKeyRelatedField, |  | ||||||
|     ValidationError, |  | ||||||
| ) |  | ||||||
| from rest_framework.viewsets import ModelViewSet | from rest_framework.viewsets import ModelViewSet | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
|  | |||||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user
	 Jens Langhammer
					Jens Langhammer