From 1e6e4a0bbcd0090cbacc677fbd4e83edc8830e18 Mon Sep 17 00:00:00 2001 From: Jens Langhammer Date: Sat, 24 Aug 2024 15:59:31 +0200 Subject: [PATCH] refactor from self.executor.current_stage to make nesting easier Signed-off-by: Jens Langhammer --- authentik/core/sources/flow_manager.py | 8 ++- authentik/enterprise/providers/rac/views.py | 8 +-- authentik/enterprise/stages/source/stage.py | 10 ++-- authentik/flows/stage.py | 14 ++--- authentik/stages/authenticator_duo/stage.py | 13 ++--- authentik/stages/authenticator_sms/stage.py | 16 +++--- .../stages/authenticator_static/stage.py | 10 ++-- authentik/stages/authenticator_totp/stage.py | 9 ++-- .../stages/authenticator_validate/stage.py | 52 +++++++++---------- .../stages/authenticator_webauthn/stage.py | 13 +++-- authentik/stages/captcha/stage.py | 8 +-- authentik/stages/consent/stage.py | 14 +++-- authentik/stages/deny/stage.py | 5 +- authentik/stages/dummy/stage.py | 6 +-- authentik/stages/email/stage.py | 22 ++++---- authentik/stages/identification/stage.py | 40 +++++++------- authentik/stages/invitation/stage.py | 5 +- authentik/stages/password/stage.py | 7 ++- authentik/stages/prompt/stage.py | 4 +- authentik/stages/user_delete/stage.py | 3 +- authentik/stages/user_login/stage.py | 17 +++--- authentik/stages/user_logout/stage.py | 3 +- authentik/stages/user_write/stage.py | 16 +++--- 23 files changed, 144 insertions(+), 159 deletions(-) diff --git a/authentik/core/sources/flow_manager.py b/authentik/core/sources/flow_manager.py index 86b78d47ef..eead5bb29f 100644 --- a/authentik/core/sources/flow_manager.py +++ b/authentik/core/sources/flow_manager.py @@ -69,8 +69,8 @@ class MessageStage(StageView): def dispatch(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: """Show a pre-configured message after the flow is done""" - message = getattr(self.executor.current_stage, "message", "") - level = getattr(self.executor.current_stage, "level", messages.SUCCESS) + message = getattr(self.current_stage, "message", "") + level = getattr(self.current_stage, "level", messages.SUCCESS) messages.add_message( self.request, level, @@ -486,9 +486,7 @@ class GroupUpdateStage(StageView): def handle_groups(self) -> bool: self.source: Source = self.executor.plan.context[PLAN_CONTEXT_SOURCE] self.user: User = self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] - self.group_connection_type: GroupSourceConnection = ( - self.executor.current_stage.group_connection_type - ) + self.group_connection_type: GroupSourceConnection = self.current_stage.group_connection_type raw_groups: dict[str, dict[str, Any | dict[str, Any]]] = self.executor.plan.context[ PLAN_CONTEXT_SOURCE_GROUPS diff --git a/authentik/enterprise/providers/rac/views.py b/authentik/enterprise/providers/rac/views.py index 3cdcce2e0a..492a2715d7 100644 --- a/authentik/enterprise/providers/rac/views.py +++ b/authentik/enterprise/providers/rac/views.py @@ -91,9 +91,9 @@ class RACFinalStage(RedirectStage): application: Application def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: - self.endpoint = self.executor.current_stage.endpoint - self.provider = self.executor.current_stage.provider - self.application = self.executor.current_stage.application + self.endpoint = self.current_stage.endpoint + self.provider = self.current_stage.provider + self.application = self.current_stage.application # Check policies bound to endpoint directly engine = PolicyEngine(self.endpoint, self.request.user, self.request) engine.use_cache = False @@ -132,7 +132,7 @@ class RACFinalStage(RedirectStage): flow=self.executor.plan.flow_pk, endpoint=self.endpoint.name, ).from_http(self.request) - self.executor.current_stage.destination = self.request.build_absolute_uri( + self.current_stage.destination = self.request.build_absolute_uri( reverse("authentik_providers_rac:if-rac", kwargs={"token": str(token.token)}) ) return super().get_challenge(*args, **kwargs) diff --git a/authentik/enterprise/stages/source/stage.py b/authentik/enterprise/stages/source/stage.py index 44d405d33a..789c76116f 100644 --- a/authentik/enterprise/stages/source/stage.py +++ b/authentik/enterprise/stages/source/stage.py @@ -21,16 +21,15 @@ from authentik.lib.utils.time import timedelta_from_string PLAN_CONTEXT_RESUME_TOKEN = "resume_token" # nosec -class SourceStageView(ChallengeStageView): +class SourceStageView(ChallengeStageView[SourceStage]): """Suspend the current flow execution and send the user to a source, after which this flow execution is resumed.""" login_button: UILoginButton def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: - current_stage: SourceStage = self.executor.current_stage source: Source = ( - Source.objects.filter(pk=current_stage.source_id).select_subclasses().first() + Source.objects.filter(pk=self.current_stage.source_id).select_subclasses().first() ) if not source: self.logger.warning("Source does not exist") @@ -56,11 +55,10 @@ class SourceStageView(ChallengeStageView): pending_user: User = self.get_pending_user() if pending_user.is_anonymous or not pending_user.pk: pending_user = get_anonymous_user() - current_stage: SourceStage = self.executor.current_stage - identifier = slugify(f"ak-source-stage-{current_stage.name}-{str(uuid4())}") + identifier = slugify(f"ak-source-stage-{self.current_stage.name}-{str(uuid4())}") # Don't check for validity here, we only care if the token exists tokens = FlowToken.objects.filter(identifier=identifier) - valid_delta = timedelta_from_string(current_stage.resume_timeout) + valid_delta = timedelta_from_string(self.current_stage.resume_timeout) if not tokens.exists(): return FlowToken.objects.create( expires=now() + valid_delta, diff --git a/authentik/flows/stage.py b/authentik/flows/stage.py index 3eac86cf1d..beb1e949e9 100644 --- a/authentik/flows/stage.py +++ b/authentik/flows/stage.py @@ -30,6 +30,7 @@ from authentik.lib.avatars import DEFAULT_AVATAR, get_avatar from authentik.lib.utils.reflection import class_to_path if TYPE_CHECKING: + from authentik.flows.models import Stage from authentik.flows.views.executor import FlowExecutorView PLAN_CONTEXT_PENDING_USER_IDENTIFIER = "pending_user_identifier" @@ -40,20 +41,21 @@ HIST_FLOWS_STAGE_TIME = Histogram( ) -class StageView(View): +class StageView[TStage: "Stage"](View): """Abstract Stage""" executor: "FlowExecutorView" + current_stage: TStage request: HttpRequest = None logger: BoundLogger - def __init__(self, executor: "FlowExecutorView", **kwargs): + def __init__(self, executor: "FlowExecutorView", current_stage: TStage | None = None, **kwargs): self.executor = executor - current_stage = getattr(self.executor, "current_stage", None) + self.current_stage = current_stage or executor.current_stage self.logger = get_logger().bind( - stage=getattr(current_stage, "name", None), + stage=getattr(self.current_stage, "name", None), stage_view=class_to_path(type(self)), ) super().__init__(**kwargs) @@ -80,7 +82,7 @@ class StageView(View): """Cleanup session""" -class ChallengeStageView(StageView): +class ChallengeStageView[TStage: "Stage"](StageView[TStage]): """Stage view which response with a challenge""" response_class = ChallengeResponse @@ -258,7 +260,7 @@ class RedirectStage(ChallengeStageView): def get_challenge(self, *args, **kwargs) -> RedirectChallenge: destination = getattr( - self.executor.current_stage, "destination", reverse("authentik_core:root-redirect") + self.current_stage, "destination", reverse("authentik_core:root-redirect") ) return RedirectChallenge( data={ diff --git a/authentik/stages/authenticator_duo/stage.py b/authentik/stages/authenticator_duo/stage.py index 06272b720b..e1b06b0c87 100644 --- a/authentik/stages/authenticator_duo/stage.py +++ b/authentik/stages/authenticator_duo/stage.py @@ -32,7 +32,7 @@ class AuthenticatorDuoChallengeResponse(ChallengeResponse): component = CharField(default="ak-stage-authenticator-duo") -class AuthenticatorDuoStageView(ChallengeStageView): +class AuthenticatorDuoStageView(ChallengeStageView[AuthenticatorDuoStage]): """Duo stage""" response_class = AuthenticatorDuoChallengeResponse @@ -40,9 +40,8 @@ class AuthenticatorDuoStageView(ChallengeStageView): def duo_enroll(self): """Enroll User with Duo API and save results""" user = self.get_pending_user() - stage: AuthenticatorDuoStage = self.executor.current_stage try: - enroll = stage.auth_client().enroll(user.username) + enroll = self.current_stage.auth_client().enroll(user.username) except RuntimeError as exc: Event.new( EventAction.CONFIGURATION_ERROR, @@ -54,7 +53,6 @@ class AuthenticatorDuoStageView(ChallengeStageView): return enroll def get_challenge(self, *args, **kwargs) -> Challenge: - stage: AuthenticatorDuoStage = self.executor.current_stage if SESSION_KEY_DUO_ENROLL not in self.request.session: self.duo_enroll() enroll = self.request.session[SESSION_KEY_DUO_ENROLL] @@ -62,15 +60,14 @@ class AuthenticatorDuoStageView(ChallengeStageView): data={ "activation_barcode": enroll["activation_barcode"], "activation_code": enroll["activation_code"], - "stage_uuid": str(stage.stage_uuid), + "stage_uuid": str(self.current_stage.stage_uuid), } ) def challenge_valid(self, response: ChallengeResponse) -> HttpResponse: # Duo Challenge has already been validated - stage: AuthenticatorDuoStage = self.executor.current_stage enroll = self.request.session.get(SESSION_KEY_DUO_ENROLL) - enroll_status = stage.auth_client().enroll_status( + enroll_status = self.current_stage.auth_client().enroll_status( enroll["user_id"], enroll["activation_code"] ) if enroll_status != "success": @@ -82,7 +79,7 @@ class AuthenticatorDuoStageView(ChallengeStageView): name="Duo Authenticator", user=self.get_pending_user(), duo_user_id=enroll["user_id"], - stage=stage, + stage=self.current_stage, last_t=now(), ) else: diff --git a/authentik/stages/authenticator_sms/stage.py b/authentik/stages/authenticator_sms/stage.py index 17ce6ac65e..2da674ee45 100644 --- a/authentik/stages/authenticator_sms/stage.py +++ b/authentik/stages/authenticator_sms/stage.py @@ -57,21 +57,20 @@ class AuthenticatorSMSChallengeResponse(ChallengeResponse): return super().validate(attrs) -class AuthenticatorSMSStageView(ChallengeStageView): +class AuthenticatorSMSStageView(ChallengeStageView[AuthenticatorSMSStage]): """OTP sms Setup stage""" response_class = AuthenticatorSMSChallengeResponse def validate_and_send(self, phone_number: str): """Validate phone number and send message""" - stage: AuthenticatorSMSStage = self.executor.current_stage hashed_number = hash_phone_number(phone_number) query = Q(phone_number=hashed_number) | Q(phone_number=phone_number) - if SMSDevice.objects.filter(query, stage=stage.pk).exists(): + if SMSDevice.objects.filter(query, stage=self.current_stage.pk).exists(): raise ValidationError(_("Invalid phone number")) # No code yet, but we have a phone number, so send a verification message device: SMSDevice = self.request.session[SESSION_KEY_SMS_DEVICE] - stage.send(device.token, device) + self.current_stage.send(device.token, device) def _has_phone_number(self) -> str | None: context = self.executor.plan.context @@ -101,10 +100,10 @@ class AuthenticatorSMSStageView(ChallengeStageView): def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: user = self.get_pending_user() - stage: AuthenticatorSMSStage = self.executor.current_stage - if SESSION_KEY_SMS_DEVICE not in self.request.session: - device = SMSDevice(user=user, confirmed=False, stage=stage, name="SMS Device") + device = SMSDevice( + user=user, confirmed=False, stage=self.current_stage, name="SMS Device" + ) device.generate_token(commit=False) self.request.session[SESSION_KEY_SMS_DEVICE] = device if phone_number := self._has_phone_number(): @@ -130,8 +129,7 @@ class AuthenticatorSMSStageView(ChallengeStageView): device: SMSDevice = self.request.session[SESSION_KEY_SMS_DEVICE] if not device.confirmed: return self.challenge_invalid(response) - stage: AuthenticatorSMSStage = self.executor.current_stage - if stage.verify_only: + if self.current_stage.verify_only: self.logger.debug("Hashing number on device") device.set_hashed_number() device.save() diff --git a/authentik/stages/authenticator_static/stage.py b/authentik/stages/authenticator_static/stage.py index face89bfd8..f4a9fbc67e 100644 --- a/authentik/stages/authenticator_static/stage.py +++ b/authentik/stages/authenticator_static/stage.py @@ -29,7 +29,7 @@ class AuthenticatorStaticChallengeResponse(ChallengeResponse): component = CharField(default="ak-stage-authenticator-static") -class AuthenticatorStaticStageView(ChallengeStageView): +class AuthenticatorStaticStageView(ChallengeStageView[AuthenticatorStaticStage]): """Static OTP Setup stage""" response_class = AuthenticatorStaticChallengeResponse @@ -48,14 +48,14 @@ class AuthenticatorStaticStageView(ChallengeStageView): self.logger.debug("No pending user, continuing") return self.executor.stage_ok() - stage: AuthenticatorStaticStage = self.executor.current_stage - if SESSION_STATIC_DEVICE not in self.request.session: device = StaticDevice(user=user, confirmed=False, name="Static Token") tokens = [] - for _ in range(0, stage.token_count): + for _ in range(0, self.current_stage.token_count): tokens.append( - StaticToken(device=device, token=generate_id(length=stage.token_length)) + StaticToken( + device=device, token=generate_id(length=self.current_stage.token_length) + ) ) self.request.session[SESSION_STATIC_DEVICE] = device self.request.session[SESSION_STATIC_TOKENS] = tokens diff --git a/authentik/stages/authenticator_totp/stage.py b/authentik/stages/authenticator_totp/stage.py index 1275fca174..861db144ae 100644 --- a/authentik/stages/authenticator_totp/stage.py +++ b/authentik/stages/authenticator_totp/stage.py @@ -45,7 +45,7 @@ class AuthenticatorTOTPChallengeResponse(ChallengeResponse): return code -class AuthenticatorTOTPStageView(ChallengeStageView): +class AuthenticatorTOTPStageView(ChallengeStageView[AuthenticatorTOTPStage]): """OTP totp Setup stage""" response_class = AuthenticatorTOTPChallengeResponse @@ -71,11 +71,12 @@ class AuthenticatorTOTPStageView(ChallengeStageView): self.logger.debug("No pending user, continuing") return self.executor.stage_ok() - stage: AuthenticatorTOTPStage = self.executor.current_stage - if SESSION_TOTP_DEVICE not in self.request.session: device = TOTPDevice( - user=user, confirmed=False, digits=stage.digits, name="TOTP Authenticator" + user=user, + confirmed=False, + digits=self.current_stage.digits, + name="TOTP Authenticator", ) self.request.session[SESSION_TOTP_DEVICE] = device diff --git a/authentik/stages/authenticator_validate/stage.py b/authentik/stages/authenticator_validate/stage.py index 96ae7e6215..11181e9042 100644 --- a/authentik/stages/authenticator_validate/stage.py +++ b/authentik/stages/authenticator_validate/stage.py @@ -151,7 +151,7 @@ class AuthenticatorValidationChallengeResponse(ChallengeResponse): return attrs -class AuthenticatorValidateStageView(ChallengeStageView): +class AuthenticatorValidateStageView(ChallengeStageView[AuthenticatorValidateStage]): """Authenticator Validation""" response_class = AuthenticatorValidationChallengeResponse @@ -177,16 +177,14 @@ class AuthenticatorValidateStageView(ChallengeStageView): # since their challenges are device-independent seen_classes = [] - stage: AuthenticatorValidateStage = self.executor.current_stage - - threshold = timedelta_from_string(stage.last_auth_threshold) + threshold = timedelta_from_string(self.current_stage.last_auth_threshold) allowed_devices = [] - has_webauthn_filters_set = stage.webauthn_allowed_device_types.exists() + has_webauthn_filters_set = self.current_stage.webauthn_allowed_device_types.exists() for device in user_devices: device_class = device.__class__.__name__.lower().replace("device", "") - if device_class not in stage.device_classes: + if device_class not in self.current_stage.device_classes: self.logger.debug("device class not allowed", device_class=device_class) continue if isinstance(device, SMSDevice) and device.is_hashed: @@ -199,7 +197,7 @@ class AuthenticatorValidateStageView(ChallengeStageView): and device.device_type and has_webauthn_filters_set ): - if not stage.webauthn_allowed_device_types.filter( + if not self.current_stage.webauthn_allowed_device_types.filter( pk=device.device_type.pk ).exists(): self.logger.debug( @@ -216,7 +214,7 @@ class AuthenticatorValidateStageView(ChallengeStageView): data={ "device_class": device_class, "device_uid": device.pk, - "challenge": get_challenge_for_device(self.request, stage, device), + "challenge": get_challenge_for_device(self.request, self.current_stage, device), } ) challenge.is_valid() @@ -235,7 +233,7 @@ class AuthenticatorValidateStageView(ChallengeStageView): "device_uid": -1, "challenge": get_webauthn_challenge_without_user( self.request, - self.executor.current_stage, + self.current_stage, ), } ) @@ -246,7 +244,6 @@ class AuthenticatorValidateStageView(ChallengeStageView): """Check if a user is set, and check if the user has any devices if not, we can skip this entire stage""" user = self.get_pending_user() - stage: AuthenticatorValidateStage = self.executor.current_stage if user and not user.is_anonymous: try: challenges = self.get_device_challenges() @@ -257,7 +254,7 @@ class AuthenticatorValidateStageView(ChallengeStageView): self.logger.debug("Refusing passwordless flow in non-authentication flow") return self.executor.stage_ok() # Passwordless auth, with just webauthn - if DeviceClasses.WEBAUTHN in stage.device_classes: + if DeviceClasses.WEBAUTHN in self.current_stage.device_classes: self.logger.debug("Flow without user, getting generic webauthn challenge") challenges = self.get_webauthn_challenge_without_user() else: @@ -267,13 +264,13 @@ class AuthenticatorValidateStageView(ChallengeStageView): # No allowed devices if len(challenges) < 1: - if stage.not_configured_action == NotConfiguredAction.SKIP: + if self.current_stage.not_configured_action == NotConfiguredAction.SKIP: self.logger.debug("Authenticator not configured, skipping stage") return self.executor.stage_ok() - if stage.not_configured_action == NotConfiguredAction.DENY: + if self.current_stage.not_configured_action == NotConfiguredAction.DENY: self.logger.debug("Authenticator not configured, denying") return self.executor.stage_invalid(_("No (allowed) MFA authenticator configured.")) - if stage.not_configured_action == NotConfiguredAction.CONFIGURE: + if self.current_stage.not_configured_action == NotConfiguredAction.CONFIGURE: self.logger.debug("Authenticator not configured, forcing configure") return self.prepare_stages(user) return super().get(request, *args, **kwargs) @@ -282,8 +279,7 @@ class AuthenticatorValidateStageView(ChallengeStageView): """Check how the user can configure themselves. If no stages are set, return an error. If a single stage is set, insert that stage directly. If multiple are selected, include them in the challenge.""" - stage: AuthenticatorValidateStage = self.executor.current_stage - if not stage.configuration_stages.exists(): + if not self.current_stage.configuration_stages.exists(): Event.new( EventAction.CONFIGURATION_ERROR, message=( @@ -293,15 +289,19 @@ class AuthenticatorValidateStageView(ChallengeStageView): stage=self, ).from_http(self.request).set_user(user).save() return self.executor.stage_invalid() - if stage.configuration_stages.count() == 1: - next_stage = Stage.objects.get_subclass(pk=stage.configuration_stages.first().pk) + if self.current_stage.configuration_stages.count() == 1: + next_stage = Stage.objects.get_subclass( + pk=self.current_stage.configuration_stages.first().pk + ) self.logger.debug("Single stage configured, auto-selecting", stage=next_stage) self.executor.plan.context[PLAN_CONTEXT_SELECTED_STAGE] = next_stage # Because that normal execution only happens on post, we directly inject it here and # return it self.executor.plan.insert_stage(next_stage) return self.executor.stage_ok() - stages = Stage.objects.filter(pk__in=stage.configuration_stages.all()).select_subclasses() + stages = Stage.objects.filter( + pk__in=self.current_stage.configuration_stages.all() + ).select_subclasses() self.executor.plan.context[PLAN_CONTEXT_STAGES] = stages return super().get(self.request, *args, **kwargs) @@ -309,7 +309,7 @@ class AuthenticatorValidateStageView(ChallengeStageView): res = super().post(request, *args, **kwargs) if ( PLAN_CONTEXT_SELECTED_STAGE in self.executor.plan.context - and self.executor.current_stage.not_configured_action == NotConfiguredAction.CONFIGURE + and self.current_stage.not_configured_action == NotConfiguredAction.CONFIGURE ): self.logger.debug("Got selected stage in context, running that") stage_pk = self.executor.plan.context.get(PLAN_CONTEXT_SELECTED_STAGE) @@ -351,7 +351,7 @@ class AuthenticatorValidateStageView(ChallengeStageView): def cookie_jwt_key(self) -> str: """Signing key for MFA Cookie for this stage""" return sha256( - f"{get_unique_identifier()}:{self.executor.current_stage.pk.hex}".encode("ascii") + f"{get_unique_identifier()}:{self.current_stage.pk.hex}".encode("ascii") ).hexdigest() def check_mfa_cookie(self, allowed_devices: list[Device]): @@ -362,12 +362,11 @@ class AuthenticatorValidateStageView(ChallengeStageView): correct user and with an allowed class""" if COOKIE_NAME_MFA not in self.request.COOKIES: return - stage: AuthenticatorValidateStage = self.executor.current_stage - threshold = timedelta_from_string(stage.last_auth_threshold) + threshold = timedelta_from_string(self.current_stage.last_auth_threshold) latest_allowed = datetime.now() + threshold try: payload = decode(self.request.COOKIES[COOKIE_NAME_MFA], self.cookie_jwt_key, ["HS256"]) - if payload["stage"] != stage.pk.hex: + if payload["stage"] != self.current_stage.pk.hex: self.logger.warning("Invalid stage PK") return if datetime.fromtimestamp(payload["exp"]) > latest_allowed: @@ -385,15 +384,14 @@ class AuthenticatorValidateStageView(ChallengeStageView): """Set an MFA cookie to allow users to skip MFA validation in this context (browser) The cookie is JWT which is signed with a hash of the secret key and the UID of the stage""" - stage: AuthenticatorValidateStage = self.executor.current_stage - delta = timedelta_from_string(stage.last_auth_threshold) + delta = timedelta_from_string(self.current_stage.last_auth_threshold) if delta.total_seconds() < 1: self.logger.info("Not setting MFA cookie since threshold is not set.") return self.executor.stage_ok() expiry = datetime.now() + delta cookie_payload = { "device": device.pk, - "stage": stage.pk.hex, + "stage": self.current_stage.pk.hex, "exp": expiry.timestamp(), } response = self.executor.stage_ok() diff --git a/authentik/stages/authenticator_webauthn/stage.py b/authentik/stages/authenticator_webauthn/stage.py index 7c25f0c449..e679c60e61 100644 --- a/authentik/stages/authenticator_webauthn/stage.py +++ b/authentik/stages/authenticator_webauthn/stage.py @@ -108,7 +108,7 @@ class AuthenticatorWebAuthnChallengeResponse(ChallengeResponse): return registration -class AuthenticatorWebAuthnStageView(ChallengeStageView): +class AuthenticatorWebAuthnStageView(ChallengeStageView[AuthenticatorWebAuthnStage]): """WebAuthn stage""" response_class = AuthenticatorWebAuthnChallengeResponse @@ -116,12 +116,11 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView): def get_challenge(self, *args, **kwargs) -> Challenge: # clear session variables prior to starting a new registration self.request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None) - stage: AuthenticatorWebAuthnStage = self.executor.current_stage user = self.get_pending_user() # library accepts none so we store null in the database, but if there is a value # set, cast it to string to ensure it's not a django class - authenticator_attachment = stage.authenticator_attachment + authenticator_attachment = self.current_stage.authenticator_attachment if authenticator_attachment: authenticator_attachment = AuthenticatorAttachment(str(authenticator_attachment)) @@ -132,8 +131,12 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView): user_name=user.username, user_display_name=user.name, authenticator_selection=AuthenticatorSelectionCriteria( - resident_key=ResidentKeyRequirement(str(stage.resident_key_requirement)), - user_verification=UserVerificationRequirement(str(stage.user_verification)), + resident_key=ResidentKeyRequirement( + str(self.current_stage.resident_key_requirement) + ), + user_verification=UserVerificationRequirement( + str(self.current_stage.user_verification) + ), authenticator_attachment=authenticator_attachment, ), attestation=AttestationConveyancePreference.DIRECT, diff --git a/authentik/stages/captcha/stage.py b/authentik/stages/captcha/stage.py index 3967e6d3d3..d8517d8096 100644 --- a/authentik/stages/captcha/stage.py +++ b/authentik/stages/captcha/stage.py @@ -70,7 +70,7 @@ class CaptchaChallengeResponse(ChallengeResponse): return data -class CaptchaStageView(ChallengeStageView): +class CaptchaStageView(ChallengeStageView[CaptchaChallenge]): """Simple captcha checker, logic is handled in django-captcha module""" response_class = CaptchaChallengeResponse @@ -78,8 +78,8 @@ class CaptchaStageView(ChallengeStageView): def get_challenge(self, *args, **kwargs) -> Challenge: return CaptchaChallenge( data={ - "js_url": self.executor.current_stage.js_url, - "site_key": self.executor.current_stage.public_key, + "js_url": self.current_stage.js_url, + "site_key": self.current_stage.public_key, } ) @@ -87,6 +87,6 @@ class CaptchaStageView(ChallengeStageView): response = response.validated_data["token"] self.executor.plan.context[PLAN_CONTEXT_CAPTCHA] = { "response": response, - "stage": self.executor.current_stage, + "stage": self.current_stage, } return self.executor.stage_ok() diff --git a/authentik/stages/consent/stage.py b/authentik/stages/consent/stage.py index 36648c899a..156402a74f 100644 --- a/authentik/stages/consent/stage.py +++ b/authentik/stages/consent/stage.py @@ -48,7 +48,7 @@ class ConsentChallengeResponse(ChallengeResponse): token = CharField(required=True) -class ConsentStageView(ChallengeStageView): +class ConsentStageView(ChallengeStageView[ConsentStage]): """Simple consent checker.""" response_class = ConsentChallengeResponse @@ -72,14 +72,13 @@ class ConsentStageView(ChallengeStageView): """Check if the current request should require a prompt for non consent reasons, i.e. this stage injected from another stage, mode is always requireed or no application is set.""" - current_stage: ConsentStage = self.executor.current_stage # Make this StageView work when injected, in which case `current_stage` is an instance # of the base class, and we don't save any consent, as it is assumed to be a one-time # prompt - if not isinstance(current_stage, ConsentStage): + if not isinstance(self.current_stage, ConsentStage): return True # For always require, we always return the challenge - if current_stage.mode == ConsentMode.ALWAYS_REQUIRE: + if self.current_stage.mode == ConsentMode.ALWAYS_REQUIRE: return True # at this point we need to check consent from database if PLAN_CONTEXT_APPLICATION not in self.executor.plan.context: @@ -125,7 +124,6 @@ class ConsentStageView(ChallengeStageView): return self.get(self.request) if self.should_always_prompt(): return self.executor.stage_ok() - current_stage: ConsentStage = self.executor.current_stage application = self.executor.plan.context[PLAN_CONTEXT_APPLICATION] permissions = self.executor.plan.context.get( PLAN_CONTEXT_CONSENT_PERMISSIONS, [] @@ -139,9 +137,9 @@ class ConsentStageView(ChallengeStageView): ) consent: UserConsent = self.executor.plan.context[PLAN_CONTEXT_CONSENT] consent.permissions = permissions_string - if current_stage.mode == ConsentMode.PERMANENT: + if self.current_stage.mode == ConsentMode.PERMANENT: consent.expiring = False - if current_stage.mode == ConsentMode.EXPIRING: - consent.expires = now() + timedelta_from_string(current_stage.consent_expire_in) + if self.current_stage.mode == ConsentMode.EXPIRING: + consent.expires = now() + timedelta_from_string(self.current_stage.consent_expire_in) consent.save() return self.executor.stage_ok() diff --git a/authentik/stages/deny/stage.py b/authentik/stages/deny/stage.py index a5cd134661..4966f3d931 100644 --- a/authentik/stages/deny/stage.py +++ b/authentik/stages/deny/stage.py @@ -6,11 +6,10 @@ from authentik.flows.stage import StageView from authentik.stages.deny.models import DenyStage -class DenyStageView(StageView): +class DenyStageView(StageView[DenyStage]): """Cancels the current flow""" def dispatch(self, request: HttpRequest) -> HttpResponse: """Cancels the current flow""" - stage: DenyStage = self.executor.current_stage - message = self.executor.plan.context.get("deny_message", stage.deny_message) + message = self.executor.plan.context.get("deny_message", self.current_stage.deny_message) return self.executor.stage_invalid(message) diff --git a/authentik/stages/dummy/stage.py b/authentik/stages/dummy/stage.py index a3a32ebd34..3ec4b011ec 100644 --- a/authentik/stages/dummy/stage.py +++ b/authentik/stages/dummy/stage.py @@ -30,11 +30,11 @@ class DummyStageView(ChallengeStageView): return self.executor.stage_ok() def get_challenge(self, *args, **kwargs) -> Challenge: - if self.executor.current_stage.throw_error: + if self.current_stage.throw_error: raise SentryIgnoredException("Test error") return DummyChallenge( data={ - "title": self.executor.current_stage.name, - "name": self.executor.current_stage.name, + "title": self.current_stage.name, + "name": self.current_stage.name, } ) diff --git a/authentik/stages/email/stage.py b/authentik/stages/email/stage.py index 062dae62ef..7bbf260cb0 100644 --- a/authentik/stages/email/stage.py +++ b/authentik/stages/email/stage.py @@ -46,7 +46,7 @@ class EmailChallengeResponse(ChallengeResponse): raise ValidationError(detail="email-sent", code="email-sent") -class EmailStageView(ChallengeStageView): +class EmailStageView(ChallengeStageView[EmailStage]): """Email stage which sends Email for verification""" response_class = EmailChallengeResponse @@ -72,11 +72,10 @@ class EmailStageView(ChallengeStageView): def get_token(self) -> FlowToken: """Get token""" pending_user = self.get_pending_user() - current_stage: EmailStage = self.executor.current_stage valid_delta = timedelta( - minutes=current_stage.token_expiry + 1 + minutes=self.current_stage.token_expiry + 1 ) # + 1 because django timesince always rounds down - identifier = slugify(f"ak-email-stage-{current_stage.name}-{str(uuid4())}") + identifier = slugify(f"ak-email-stage-{self.current_stage.name}-{str(uuid4())}") # Don't check for validity here, we only care if the token exists tokens = FlowToken.objects.filter(identifier=identifier) if not tokens.exists(): @@ -105,15 +104,14 @@ class EmailStageView(ChallengeStageView): email = self.executor.plan.context.get(PLAN_CONTEXT_EMAIL_OVERRIDE, None) if not email: email = pending_user.email - current_stage: EmailStage = self.executor.current_stage token = self.get_token() # Send mail to user try: message = TemplateEmailMessage( - subject=_(current_stage.subject), + subject=_(self.current_stage.subject), to=[(pending_user.name, email)], language=pending_user.locale(self.request), - template_name=current_stage.template, + template_name=self.current_stage.template, template_context={ "url": self.get_full_url(**{QS_KEY_TOKEN: token.key}), "user": pending_user, @@ -121,26 +119,28 @@ class EmailStageView(ChallengeStageView): "token": token.key, }, ) - send_mails(current_stage, message) + send_mails(self.current_stage, message) except TemplateSyntaxError as exc: Event.new( EventAction.CONFIGURATION_ERROR, message=_("Exception occurred while rendering E-mail template"), error=exception_to_string(exc), - template=current_stage.template, + template=self.current_stage.template, ).from_http(self.request) raise StageInvalidException from exc def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: # Check if the user came back from the email link to verify - restore_token: FlowToken = self.executor.plan.context.get(PLAN_CONTEXT_IS_RESTORED, None) + restore_token: FlowToken | None = self.executor.plan.context.get( + PLAN_CONTEXT_IS_RESTORED, None + ) user = self.get_pending_user() if restore_token: if restore_token.user != user: self.logger.warning("Flow token for non-matching user, denying request") return self.executor.stage_invalid() messages.success(request, _("Successfully verified Email.")) - if self.executor.current_stage.activate_user_on_success: + if self.current_stage.activate_user_on_success: user.is_active = True user.save() return self.executor.stage_ok() diff --git a/authentik/stages/identification/stage.py b/authentik/stages/identification/stage.py index 1591949741..b3b0098415 100644 --- a/authentik/stages/identification/stage.py +++ b/authentik/stages/identification/stage.py @@ -164,22 +164,21 @@ class IdentificationChallengeResponse(ChallengeResponse): return attrs -class IdentificationStageView(ChallengeStageView): +class IdentificationStageView(ChallengeStageView[IdentificationStage]): """Form to identify the user""" response_class = IdentificationChallengeResponse def get_user(self, uid_value: str) -> User | None: """Find user instance. Returns None if no user was found.""" - current_stage: IdentificationStage = self.executor.current_stage query = Q() - for search_field in current_stage.user_fields: + for search_field in self.current_stage.user_fields: model_field = { "email": "email", "username": "username", "upn": "attributes__upn", }[search_field] - if current_stage.case_insensitive_matching: + if self.current_stage.case_insensitive_matching: model_field += "__iexact" else: model_field += "__exact" @@ -200,16 +199,15 @@ class IdentificationStageView(ChallengeStageView): return _("Continue") def get_challenge(self) -> Challenge: - current_stage: IdentificationStage = self.executor.current_stage challenge = IdentificationChallenge( data={ "component": "ak-stage-identification", "primary_action": self.get_primary_action(), - "user_fields": current_stage.user_fields, - "password_fields": bool(current_stage.password_stage), - "allow_show_password": bool(current_stage.password_stage) - and current_stage.password_stage.allow_show_password, - "show_source_labels": current_stage.show_source_labels, + "user_fields": self.current_stage.user_fields, + "password_fields": bool(self.current_stage.password_stage), + "allow_show_password": bool(self.current_stage.password_stage) + and self.current_stage.password_stage.allow_show_password, + "show_source_labels": self.current_stage.show_source_labels, "flow_designation": self.executor.flow.designation, } ) @@ -221,27 +219,26 @@ class IdentificationStageView(ChallengeStageView): ).name get_qs = self.request.session.get(SESSION_KEY_GET, self.request.GET) # Check for related enrollment and recovery flow, add URL to view - if current_stage.enrollment_flow: + if self.current_stage.enrollment_flow: challenge.initial_data["enroll_url"] = reverse_with_qs( "authentik_core:if-flow", query=get_qs, - kwargs={"flow_slug": current_stage.enrollment_flow.slug}, + kwargs={"flow_slug": self.current_stage.enrollment_flow.slug}, ) - if current_stage.recovery_flow: + if self.current_stage.recovery_flow: challenge.initial_data["recovery_url"] = reverse_with_qs( "authentik_core:if-flow", query=get_qs, - kwargs={"flow_slug": current_stage.recovery_flow.slug}, + kwargs={"flow_slug": self.current_stage.recovery_flow.slug}, ) - if current_stage.passwordless_flow: + if self.current_stage.passwordless_flow: challenge.initial_data["passwordless_url"] = reverse_with_qs( "authentik_core:if-flow", query=get_qs, - kwargs={"flow_slug": current_stage.passwordless_flow.slug}, + kwargs={"flow_slug": self.current_stage.passwordless_flow.slug}, ) - if current_stage.captcha_stage: - captcha = CaptchaStageView(self.executor) - captcha.stage = current_stage.captcha_stage + if self.current_stage.captcha_stage: + captcha = CaptchaStageView(self.executor, self.current_stage.captcha_stage) captcha_challenge = captcha.get_challenge() captcha_challenge.is_valid() challenge.initial_data["captcha_stage"] = captcha_challenge.data @@ -249,7 +246,7 @@ class IdentificationStageView(ChallengeStageView): # Check all enabled source, add them if they have a UI Login button. ui_sources = [] sources: list[Source] = ( - current_stage.sources.filter(enabled=True).order_by("name").select_subclasses() + self.current_stage.sources.filter(enabled=True).order_by("name").select_subclasses() ) for source in sources: ui_login_button = source.ui_login_button(self.request) @@ -264,8 +261,7 @@ class IdentificationStageView(ChallengeStageView): def challenge_valid(self, response: IdentificationChallengeResponse) -> HttpResponse: self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] = response.pre_user - current_stage: IdentificationStage = self.executor.current_stage - if not current_stage.show_matched_user: + if not self.current_stage.show_matched_user: self.executor.plan.context[PLAN_CONTEXT_PENDING_USER_IDENTIFIER] = ( response.validated_data.get("uid_field") ) diff --git a/authentik/stages/invitation/stage.py b/authentik/stages/invitation/stage.py index 3a9cc5746c..c647e81e50 100644 --- a/authentik/stages/invitation/stage.py +++ b/authentik/stages/invitation/stage.py @@ -17,7 +17,7 @@ INVITATION_IN_EFFECT = "invitation_in_effect" INVITATION = "invitation" -class InvitationStageView(StageView): +class InvitationStageView(StageView[InvitationStage]): """Finalise Authentication flow by logging the user in""" def get_token(self) -> str | None: @@ -52,11 +52,10 @@ class InvitationStageView(StageView): def dispatch(self, request: HttpRequest) -> HttpResponse: """Apply data to the current flow based on a URL""" - stage: InvitationStage = self.executor.current_stage invite = self.get_invite() if not invite: - if stage.continue_flow_without_invitation: + if self.current_stage.continue_flow_without_invitation: return self.executor.stage_ok() return self.executor.stage_invalid(_("Invalid invite/invite not found")) diff --git a/authentik/stages/password/stage.py b/authentik/stages/password/stage.py index 873467cf80..9316696435 100644 --- a/authentik/stages/password/stage.py +++ b/authentik/stages/password/stage.py @@ -130,7 +130,7 @@ class PasswordChallengeResponse(ChallengeResponse): return password -class PasswordStageView(ChallengeStageView): +class PasswordStageView(ChallengeStageView[PasswordStage]): """Authentication stage which authenticates against django's AuthBackend""" response_class = PasswordChallengeResponse @@ -138,7 +138,7 @@ class PasswordStageView(ChallengeStageView): def get_challenge(self) -> Challenge: challenge = PasswordChallenge( data={ - "allow_show_password": self.executor.current_stage.allow_show_password, + "allow_show_password": self.current_stage.allow_show_password, } ) recovery_flow = Flow.objects.filter(designation=FlowDesignation.RECOVERY) @@ -154,10 +154,9 @@ class PasswordStageView(ChallengeStageView): if SESSION_KEY_INVALID_TRIES not in self.request.session: self.request.session[SESSION_KEY_INVALID_TRIES] = 0 self.request.session[SESSION_KEY_INVALID_TRIES] += 1 - current_stage: PasswordStage = self.executor.current_stage if ( self.request.session[SESSION_KEY_INVALID_TRIES] - >= current_stage.failed_attempts_before_cancel + >= self.current_stage.failed_attempts_before_cancel ): self.logger.debug("User has exceeded maximum tries") del self.request.session[SESSION_KEY_INVALID_TRIES] diff --git a/authentik/stages/prompt/stage.py b/authentik/stages/prompt/stage.py index 987a164446..1d57f70620 100644 --- a/authentik/stages/prompt/stage.py +++ b/authentik/stages/prompt/stage.py @@ -222,7 +222,7 @@ class PromptStageView(ChallengeStageView): return serializers def get_challenge(self, *args, **kwargs) -> Challenge: - fields: list[Prompt] = list(self.executor.current_stage.fields.all().order_by("order")) + fields: list[Prompt] = list(self.current_stage.fields.all().order_by("order")) context_prompt = self.executor.plan.context.get(PLAN_CONTEXT_PROMPT, {}) serializers = self.get_prompt_challenge_fields(fields, context_prompt) challenge = PromptChallenge( @@ -239,7 +239,7 @@ class PromptStageView(ChallengeStageView): instance=None, data=data, request=self.request, - stage_instance=self.executor.current_stage, + stage_instance=self.current_stage, stage=self, plan=self.executor.plan, user=self.get_pending_user(), diff --git a/authentik/stages/user_delete/stage.py b/authentik/stages/user_delete/stage.py index 3ea73f1268..da043e958b 100644 --- a/authentik/stages/user_delete/stage.py +++ b/authentik/stages/user_delete/stage.py @@ -7,9 +7,10 @@ from django.utils.translation import gettext as _ from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER from authentik.flows.stage import StageView +from authentik.stages.user_delete.models import UserDeleteStage -class UserDeleteStageView(StageView): +class UserDeleteStageView(StageView[UserDeleteStage]): """Finalise unenrollment flow by deleting the user object.""" def dispatch(self, request: HttpRequest) -> HttpResponse: diff --git a/authentik/stages/user_login/stage.py b/authentik/stages/user_login/stage.py index 5f180a977c..0cf1811da2 100644 --- a/authentik/stages/user_login/stage.py +++ b/authentik/stages/user_login/stage.py @@ -39,7 +39,7 @@ class UserLoginChallengeResponse(ChallengeResponse): remember_me = BooleanField(required=True) -class UserLoginStageView(ChallengeStageView): +class UserLoginStageView(ChallengeStageView[UserLoginStage]): """Finalise Authentication flow by logging the user in""" response_class = UserLoginChallengeResponse @@ -49,8 +49,7 @@ class UserLoginStageView(ChallengeStageView): def dispatch(self, request: HttpRequest) -> HttpResponse: """Check for remember_me, and do login""" - stage: UserLoginStage = self.executor.current_stage - if timedelta_from_string(stage.remember_me_offset).total_seconds() > 0: + if timedelta_from_string(self.current_stage.remember_me_offset).total_seconds() > 0: return super().dispatch(request) return self.do_login(request) @@ -59,9 +58,9 @@ class UserLoginStageView(ChallengeStageView): def set_session_duration(self, remember: bool) -> timedelta: """Update the sessions' expiry""" - delta = timedelta_from_string(self.executor.current_stage.session_duration) + delta = timedelta_from_string(self.current_stage.session_duration) if remember: - offset = timedelta_from_string(self.executor.current_stage.remember_me_offset) + offset = timedelta_from_string(self.current_stage.remember_me_offset) delta = delta + offset if delta.total_seconds() == 0: self.request.session.set_expiry(0) @@ -71,11 +70,9 @@ class UserLoginStageView(ChallengeStageView): def set_session_ip(self): """Set the sessions' last IP and session bindings""" - stage: UserLoginStage = self.executor.current_stage - self.request.session[SESSION_KEY_LAST_IP] = ClientIPMiddleware.get_client_ip(self.request) - self.request.session[SESSION_KEY_BINDING_NET] = stage.network_binding - self.request.session[SESSION_KEY_BINDING_GEO] = stage.geoip_binding + self.request.session[SESSION_KEY_BINDING_NET] = self.current_stage.network_binding + self.request.session[SESSION_KEY_BINDING_GEO] = self.current_stage.geoip_binding def do_login(self, request: HttpRequest, remember: bool = False) -> HttpResponse: """Attach the currently pending user to the current session""" @@ -111,7 +108,7 @@ class UserLoginStageView(ChallengeStageView): # as sources show their own success messages if not self.executor.plan.context.get(PLAN_CONTEXT_SOURCE, None): messages.success(self.request, _("Successfully logged in!")) - if self.executor.current_stage.terminate_other_sessions: + if self.current_stage.terminate_other_sessions: AuthenticatedSession.objects.filter( user=user, ).exclude(session_key=self.request.session.session_key).delete() diff --git a/authentik/stages/user_logout/stage.py b/authentik/stages/user_logout/stage.py index 84bba66dd3..9cf250aa0c 100644 --- a/authentik/stages/user_logout/stage.py +++ b/authentik/stages/user_logout/stage.py @@ -4,9 +4,10 @@ from django.contrib.auth import logout from django.http import HttpRequest, HttpResponse from authentik.flows.stage import StageView +from authentik.stages.user_logout.models import UserLogoutStage -class UserLogoutStageView(StageView): +class UserLogoutStageView(StageView[UserLogoutStage]): """Finalise Authentication flow by logging the user in""" def dispatch(self, request: HttpRequest) -> HttpResponse: diff --git a/authentik/stages/user_write/stage.py b/authentik/stages/user_write/stage.py index 71240171eb..be5b4e2cb0 100644 --- a/authentik/stages/user_write/stage.py +++ b/authentik/stages/user_write/stage.py @@ -55,7 +55,7 @@ class UserWriteStageView(StageView): """Ensure a user exists""" user_created = False path = self.executor.plan.context.get( - PLAN_CONTEXT_USER_PATH, self.executor.current_stage.user_path_template + PLAN_CONTEXT_USER_PATH, self.current_stage.user_path_template ) if path == "": path = User.default_path() @@ -64,11 +64,11 @@ class UserWriteStageView(StageView): user_type = UserTypes( self.executor.plan.context.get( PLAN_CONTEXT_USER_TYPE, - self.executor.current_stage.user_type, + self.current_stage.user_type, ) ) except ValueError: - user_type = self.executor.current_stage.user_type + user_type = self.current_stage.user_type if user_type == UserTypes.INTERNAL_SERVICE_ACCOUNT: user_type = UserTypes.SERVICE_ACCOUNT @@ -76,12 +76,12 @@ class UserWriteStageView(StageView): self.executor.plan.context.setdefault(PLAN_CONTEXT_PENDING_USER, self.request.user) if ( PLAN_CONTEXT_PENDING_USER not in self.executor.plan.context - or self.executor.current_stage.user_creation_mode == UserCreationMode.ALWAYS_CREATE + or self.current_stage.user_creation_mode == UserCreationMode.ALWAYS_CREATE ): - if self.executor.current_stage.user_creation_mode == UserCreationMode.NEVER_CREATE: + if self.current_stage.user_creation_mode == UserCreationMode.NEVER_CREATE: return None, False self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] = User( - is_active=not self.executor.current_stage.create_users_as_inactive, + is_active=not self.current_stage.create_users_as_inactive, path=path, type=user_type, ) @@ -180,8 +180,8 @@ class UserWriteStageView(StageView): try: with transaction.atomic(): user.save() - if self.executor.current_stage.create_users_group: - user.ak_groups.add(self.executor.current_stage.create_users_group) + if self.current_stage.create_users_group: + user.ak_groups.add(self.current_stage.create_users_group) if PLAN_CONTEXT_GROUPS in self.executor.plan.context: user.ak_groups.add(*self.executor.plan.context[PLAN_CONTEXT_GROUPS]) except (IntegrityError, ValueError, TypeError, InternalError) as exc: