refactor from self.executor.current_stage to make nesting easier

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens Langhammer
2024-08-24 15:59:31 +02:00
parent 2149e81d8f
commit 1e6e4a0bbc
23 changed files with 144 additions and 159 deletions

View File

@ -69,8 +69,8 @@ class MessageStage(StageView):
def dispatch(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: def dispatch(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
"""Show a pre-configured message after the flow is done""" """Show a pre-configured message after the flow is done"""
message = getattr(self.executor.current_stage, "message", "") message = getattr(self.current_stage, "message", "")
level = getattr(self.executor.current_stage, "level", messages.SUCCESS) level = getattr(self.current_stage, "level", messages.SUCCESS)
messages.add_message( messages.add_message(
self.request, self.request,
level, level,
@ -486,9 +486,7 @@ class GroupUpdateStage(StageView):
def handle_groups(self) -> bool: def handle_groups(self) -> bool:
self.source: Source = self.executor.plan.context[PLAN_CONTEXT_SOURCE] self.source: Source = self.executor.plan.context[PLAN_CONTEXT_SOURCE]
self.user: User = self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] self.user: User = self.executor.plan.context[PLAN_CONTEXT_PENDING_USER]
self.group_connection_type: GroupSourceConnection = ( self.group_connection_type: GroupSourceConnection = self.current_stage.group_connection_type
self.executor.current_stage.group_connection_type
)
raw_groups: dict[str, dict[str, Any | dict[str, Any]]] = self.executor.plan.context[ raw_groups: dict[str, dict[str, Any | dict[str, Any]]] = self.executor.plan.context[
PLAN_CONTEXT_SOURCE_GROUPS PLAN_CONTEXT_SOURCE_GROUPS

View File

@ -91,9 +91,9 @@ class RACFinalStage(RedirectStage):
application: Application application: Application
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
self.endpoint = self.executor.current_stage.endpoint self.endpoint = self.current_stage.endpoint
self.provider = self.executor.current_stage.provider self.provider = self.current_stage.provider
self.application = self.executor.current_stage.application self.application = self.current_stage.application
# Check policies bound to endpoint directly # Check policies bound to endpoint directly
engine = PolicyEngine(self.endpoint, self.request.user, self.request) engine = PolicyEngine(self.endpoint, self.request.user, self.request)
engine.use_cache = False engine.use_cache = False
@ -132,7 +132,7 @@ class RACFinalStage(RedirectStage):
flow=self.executor.plan.flow_pk, flow=self.executor.plan.flow_pk,
endpoint=self.endpoint.name, endpoint=self.endpoint.name,
).from_http(self.request) ).from_http(self.request)
self.executor.current_stage.destination = self.request.build_absolute_uri( self.current_stage.destination = self.request.build_absolute_uri(
reverse("authentik_providers_rac:if-rac", kwargs={"token": str(token.token)}) reverse("authentik_providers_rac:if-rac", kwargs={"token": str(token.token)})
) )
return super().get_challenge(*args, **kwargs) return super().get_challenge(*args, **kwargs)

View File

@ -21,16 +21,15 @@ from authentik.lib.utils.time import timedelta_from_string
PLAN_CONTEXT_RESUME_TOKEN = "resume_token" # nosec PLAN_CONTEXT_RESUME_TOKEN = "resume_token" # nosec
class SourceStageView(ChallengeStageView): class SourceStageView(ChallengeStageView[SourceStage]):
"""Suspend the current flow execution and send the user to a source, """Suspend the current flow execution and send the user to a source,
after which this flow execution is resumed.""" after which this flow execution is resumed."""
login_button: UILoginButton login_button: UILoginButton
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
current_stage: SourceStage = self.executor.current_stage
source: Source = ( source: Source = (
Source.objects.filter(pk=current_stage.source_id).select_subclasses().first() Source.objects.filter(pk=self.current_stage.source_id).select_subclasses().first()
) )
if not source: if not source:
self.logger.warning("Source does not exist") self.logger.warning("Source does not exist")
@ -56,11 +55,10 @@ class SourceStageView(ChallengeStageView):
pending_user: User = self.get_pending_user() pending_user: User = self.get_pending_user()
if pending_user.is_anonymous or not pending_user.pk: if pending_user.is_anonymous or not pending_user.pk:
pending_user = get_anonymous_user() pending_user = get_anonymous_user()
current_stage: SourceStage = self.executor.current_stage identifier = slugify(f"ak-source-stage-{self.current_stage.name}-{str(uuid4())}")
identifier = slugify(f"ak-source-stage-{current_stage.name}-{str(uuid4())}")
# Don't check for validity here, we only care if the token exists # Don't check for validity here, we only care if the token exists
tokens = FlowToken.objects.filter(identifier=identifier) tokens = FlowToken.objects.filter(identifier=identifier)
valid_delta = timedelta_from_string(current_stage.resume_timeout) valid_delta = timedelta_from_string(self.current_stage.resume_timeout)
if not tokens.exists(): if not tokens.exists():
return FlowToken.objects.create( return FlowToken.objects.create(
expires=now() + valid_delta, expires=now() + valid_delta,

View File

@ -30,6 +30,7 @@ from authentik.lib.avatars import DEFAULT_AVATAR, get_avatar
from authentik.lib.utils.reflection import class_to_path from authentik.lib.utils.reflection import class_to_path
if TYPE_CHECKING: if TYPE_CHECKING:
from authentik.flows.models import Stage
from authentik.flows.views.executor import FlowExecutorView from authentik.flows.views.executor import FlowExecutorView
PLAN_CONTEXT_PENDING_USER_IDENTIFIER = "pending_user_identifier" PLAN_CONTEXT_PENDING_USER_IDENTIFIER = "pending_user_identifier"
@ -40,20 +41,21 @@ HIST_FLOWS_STAGE_TIME = Histogram(
) )
class StageView(View): class StageView[TStage: "Stage"](View):
"""Abstract Stage""" """Abstract Stage"""
executor: "FlowExecutorView" executor: "FlowExecutorView"
current_stage: TStage
request: HttpRequest = None request: HttpRequest = None
logger: BoundLogger logger: BoundLogger
def __init__(self, executor: "FlowExecutorView", **kwargs): def __init__(self, executor: "FlowExecutorView", current_stage: TStage | None = None, **kwargs):
self.executor = executor self.executor = executor
current_stage = getattr(self.executor, "current_stage", None) self.current_stage = current_stage or executor.current_stage
self.logger = get_logger().bind( self.logger = get_logger().bind(
stage=getattr(current_stage, "name", None), stage=getattr(self.current_stage, "name", None),
stage_view=class_to_path(type(self)), stage_view=class_to_path(type(self)),
) )
super().__init__(**kwargs) super().__init__(**kwargs)
@ -80,7 +82,7 @@ class StageView(View):
"""Cleanup session""" """Cleanup session"""
class ChallengeStageView(StageView): class ChallengeStageView[TStage: "Stage"](StageView[TStage]):
"""Stage view which response with a challenge""" """Stage view which response with a challenge"""
response_class = ChallengeResponse response_class = ChallengeResponse
@ -258,7 +260,7 @@ class RedirectStage(ChallengeStageView):
def get_challenge(self, *args, **kwargs) -> RedirectChallenge: def get_challenge(self, *args, **kwargs) -> RedirectChallenge:
destination = getattr( destination = getattr(
self.executor.current_stage, "destination", reverse("authentik_core:root-redirect") self.current_stage, "destination", reverse("authentik_core:root-redirect")
) )
return RedirectChallenge( return RedirectChallenge(
data={ data={

View File

@ -32,7 +32,7 @@ class AuthenticatorDuoChallengeResponse(ChallengeResponse):
component = CharField(default="ak-stage-authenticator-duo") component = CharField(default="ak-stage-authenticator-duo")
class AuthenticatorDuoStageView(ChallengeStageView): class AuthenticatorDuoStageView(ChallengeStageView[AuthenticatorDuoStage]):
"""Duo stage""" """Duo stage"""
response_class = AuthenticatorDuoChallengeResponse response_class = AuthenticatorDuoChallengeResponse
@ -40,9 +40,8 @@ class AuthenticatorDuoStageView(ChallengeStageView):
def duo_enroll(self): def duo_enroll(self):
"""Enroll User with Duo API and save results""" """Enroll User with Duo API and save results"""
user = self.get_pending_user() user = self.get_pending_user()
stage: AuthenticatorDuoStage = self.executor.current_stage
try: try:
enroll = stage.auth_client().enroll(user.username) enroll = self.current_stage.auth_client().enroll(user.username)
except RuntimeError as exc: except RuntimeError as exc:
Event.new( Event.new(
EventAction.CONFIGURATION_ERROR, EventAction.CONFIGURATION_ERROR,
@ -54,7 +53,6 @@ class AuthenticatorDuoStageView(ChallengeStageView):
return enroll return enroll
def get_challenge(self, *args, **kwargs) -> Challenge: def get_challenge(self, *args, **kwargs) -> Challenge:
stage: AuthenticatorDuoStage = self.executor.current_stage
if SESSION_KEY_DUO_ENROLL not in self.request.session: if SESSION_KEY_DUO_ENROLL not in self.request.session:
self.duo_enroll() self.duo_enroll()
enroll = self.request.session[SESSION_KEY_DUO_ENROLL] enroll = self.request.session[SESSION_KEY_DUO_ENROLL]
@ -62,15 +60,14 @@ class AuthenticatorDuoStageView(ChallengeStageView):
data={ data={
"activation_barcode": enroll["activation_barcode"], "activation_barcode": enroll["activation_barcode"],
"activation_code": enroll["activation_code"], "activation_code": enroll["activation_code"],
"stage_uuid": str(stage.stage_uuid), "stage_uuid": str(self.current_stage.stage_uuid),
} }
) )
def challenge_valid(self, response: ChallengeResponse) -> HttpResponse: def challenge_valid(self, response: ChallengeResponse) -> HttpResponse:
# Duo Challenge has already been validated # Duo Challenge has already been validated
stage: AuthenticatorDuoStage = self.executor.current_stage
enroll = self.request.session.get(SESSION_KEY_DUO_ENROLL) enroll = self.request.session.get(SESSION_KEY_DUO_ENROLL)
enroll_status = stage.auth_client().enroll_status( enroll_status = self.current_stage.auth_client().enroll_status(
enroll["user_id"], enroll["activation_code"] enroll["user_id"], enroll["activation_code"]
) )
if enroll_status != "success": if enroll_status != "success":
@ -82,7 +79,7 @@ class AuthenticatorDuoStageView(ChallengeStageView):
name="Duo Authenticator", name="Duo Authenticator",
user=self.get_pending_user(), user=self.get_pending_user(),
duo_user_id=enroll["user_id"], duo_user_id=enroll["user_id"],
stage=stage, stage=self.current_stage,
last_t=now(), last_t=now(),
) )
else: else:

View File

@ -57,21 +57,20 @@ class AuthenticatorSMSChallengeResponse(ChallengeResponse):
return super().validate(attrs) return super().validate(attrs)
class AuthenticatorSMSStageView(ChallengeStageView): class AuthenticatorSMSStageView(ChallengeStageView[AuthenticatorSMSStage]):
"""OTP sms Setup stage""" """OTP sms Setup stage"""
response_class = AuthenticatorSMSChallengeResponse response_class = AuthenticatorSMSChallengeResponse
def validate_and_send(self, phone_number: str): def validate_and_send(self, phone_number: str):
"""Validate phone number and send message""" """Validate phone number and send message"""
stage: AuthenticatorSMSStage = self.executor.current_stage
hashed_number = hash_phone_number(phone_number) hashed_number = hash_phone_number(phone_number)
query = Q(phone_number=hashed_number) | Q(phone_number=phone_number) query = Q(phone_number=hashed_number) | Q(phone_number=phone_number)
if SMSDevice.objects.filter(query, stage=stage.pk).exists(): if SMSDevice.objects.filter(query, stage=self.current_stage.pk).exists():
raise ValidationError(_("Invalid phone number")) raise ValidationError(_("Invalid phone number"))
# No code yet, but we have a phone number, so send a verification message # No code yet, but we have a phone number, so send a verification message
device: SMSDevice = self.request.session[SESSION_KEY_SMS_DEVICE] device: SMSDevice = self.request.session[SESSION_KEY_SMS_DEVICE]
stage.send(device.token, device) self.current_stage.send(device.token, device)
def _has_phone_number(self) -> str | None: def _has_phone_number(self) -> str | None:
context = self.executor.plan.context context = self.executor.plan.context
@ -101,10 +100,10 @@ class AuthenticatorSMSStageView(ChallengeStageView):
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
user = self.get_pending_user() user = self.get_pending_user()
stage: AuthenticatorSMSStage = self.executor.current_stage
if SESSION_KEY_SMS_DEVICE not in self.request.session: if SESSION_KEY_SMS_DEVICE not in self.request.session:
device = SMSDevice(user=user, confirmed=False, stage=stage, name="SMS Device") device = SMSDevice(
user=user, confirmed=False, stage=self.current_stage, name="SMS Device"
)
device.generate_token(commit=False) device.generate_token(commit=False)
self.request.session[SESSION_KEY_SMS_DEVICE] = device self.request.session[SESSION_KEY_SMS_DEVICE] = device
if phone_number := self._has_phone_number(): if phone_number := self._has_phone_number():
@ -130,8 +129,7 @@ class AuthenticatorSMSStageView(ChallengeStageView):
device: SMSDevice = self.request.session[SESSION_KEY_SMS_DEVICE] device: SMSDevice = self.request.session[SESSION_KEY_SMS_DEVICE]
if not device.confirmed: if not device.confirmed:
return self.challenge_invalid(response) return self.challenge_invalid(response)
stage: AuthenticatorSMSStage = self.executor.current_stage if self.current_stage.verify_only:
if stage.verify_only:
self.logger.debug("Hashing number on device") self.logger.debug("Hashing number on device")
device.set_hashed_number() device.set_hashed_number()
device.save() device.save()

View File

@ -29,7 +29,7 @@ class AuthenticatorStaticChallengeResponse(ChallengeResponse):
component = CharField(default="ak-stage-authenticator-static") component = CharField(default="ak-stage-authenticator-static")
class AuthenticatorStaticStageView(ChallengeStageView): class AuthenticatorStaticStageView(ChallengeStageView[AuthenticatorStaticStage]):
"""Static OTP Setup stage""" """Static OTP Setup stage"""
response_class = AuthenticatorStaticChallengeResponse response_class = AuthenticatorStaticChallengeResponse
@ -48,14 +48,14 @@ class AuthenticatorStaticStageView(ChallengeStageView):
self.logger.debug("No pending user, continuing") self.logger.debug("No pending user, continuing")
return self.executor.stage_ok() return self.executor.stage_ok()
stage: AuthenticatorStaticStage = self.executor.current_stage
if SESSION_STATIC_DEVICE not in self.request.session: if SESSION_STATIC_DEVICE not in self.request.session:
device = StaticDevice(user=user, confirmed=False, name="Static Token") device = StaticDevice(user=user, confirmed=False, name="Static Token")
tokens = [] tokens = []
for _ in range(0, stage.token_count): for _ in range(0, self.current_stage.token_count):
tokens.append( tokens.append(
StaticToken(device=device, token=generate_id(length=stage.token_length)) StaticToken(
device=device, token=generate_id(length=self.current_stage.token_length)
)
) )
self.request.session[SESSION_STATIC_DEVICE] = device self.request.session[SESSION_STATIC_DEVICE] = device
self.request.session[SESSION_STATIC_TOKENS] = tokens self.request.session[SESSION_STATIC_TOKENS] = tokens

View File

@ -45,7 +45,7 @@ class AuthenticatorTOTPChallengeResponse(ChallengeResponse):
return code return code
class AuthenticatorTOTPStageView(ChallengeStageView): class AuthenticatorTOTPStageView(ChallengeStageView[AuthenticatorTOTPStage]):
"""OTP totp Setup stage""" """OTP totp Setup stage"""
response_class = AuthenticatorTOTPChallengeResponse response_class = AuthenticatorTOTPChallengeResponse
@ -71,11 +71,12 @@ class AuthenticatorTOTPStageView(ChallengeStageView):
self.logger.debug("No pending user, continuing") self.logger.debug("No pending user, continuing")
return self.executor.stage_ok() return self.executor.stage_ok()
stage: AuthenticatorTOTPStage = self.executor.current_stage
if SESSION_TOTP_DEVICE not in self.request.session: if SESSION_TOTP_DEVICE not in self.request.session:
device = TOTPDevice( device = TOTPDevice(
user=user, confirmed=False, digits=stage.digits, name="TOTP Authenticator" user=user,
confirmed=False,
digits=self.current_stage.digits,
name="TOTP Authenticator",
) )
self.request.session[SESSION_TOTP_DEVICE] = device self.request.session[SESSION_TOTP_DEVICE] = device

View File

@ -151,7 +151,7 @@ class AuthenticatorValidationChallengeResponse(ChallengeResponse):
return attrs return attrs
class AuthenticatorValidateStageView(ChallengeStageView): class AuthenticatorValidateStageView(ChallengeStageView[AuthenticatorValidateStage]):
"""Authenticator Validation""" """Authenticator Validation"""
response_class = AuthenticatorValidationChallengeResponse response_class = AuthenticatorValidationChallengeResponse
@ -177,16 +177,14 @@ class AuthenticatorValidateStageView(ChallengeStageView):
# since their challenges are device-independent # since their challenges are device-independent
seen_classes = [] seen_classes = []
stage: AuthenticatorValidateStage = self.executor.current_stage threshold = timedelta_from_string(self.current_stage.last_auth_threshold)
threshold = timedelta_from_string(stage.last_auth_threshold)
allowed_devices = [] allowed_devices = []
has_webauthn_filters_set = stage.webauthn_allowed_device_types.exists() has_webauthn_filters_set = self.current_stage.webauthn_allowed_device_types.exists()
for device in user_devices: for device in user_devices:
device_class = device.__class__.__name__.lower().replace("device", "") device_class = device.__class__.__name__.lower().replace("device", "")
if device_class not in stage.device_classes: if device_class not in self.current_stage.device_classes:
self.logger.debug("device class not allowed", device_class=device_class) self.logger.debug("device class not allowed", device_class=device_class)
continue continue
if isinstance(device, SMSDevice) and device.is_hashed: if isinstance(device, SMSDevice) and device.is_hashed:
@ -199,7 +197,7 @@ class AuthenticatorValidateStageView(ChallengeStageView):
and device.device_type and device.device_type
and has_webauthn_filters_set and has_webauthn_filters_set
): ):
if not stage.webauthn_allowed_device_types.filter( if not self.current_stage.webauthn_allowed_device_types.filter(
pk=device.device_type.pk pk=device.device_type.pk
).exists(): ).exists():
self.logger.debug( self.logger.debug(
@ -216,7 +214,7 @@ class AuthenticatorValidateStageView(ChallengeStageView):
data={ data={
"device_class": device_class, "device_class": device_class,
"device_uid": device.pk, "device_uid": device.pk,
"challenge": get_challenge_for_device(self.request, stage, device), "challenge": get_challenge_for_device(self.request, self.current_stage, device),
} }
) )
challenge.is_valid() challenge.is_valid()
@ -235,7 +233,7 @@ class AuthenticatorValidateStageView(ChallengeStageView):
"device_uid": -1, "device_uid": -1,
"challenge": get_webauthn_challenge_without_user( "challenge": get_webauthn_challenge_without_user(
self.request, self.request,
self.executor.current_stage, self.current_stage,
), ),
} }
) )
@ -246,7 +244,6 @@ class AuthenticatorValidateStageView(ChallengeStageView):
"""Check if a user is set, and check if the user has any devices """Check if a user is set, and check if the user has any devices
if not, we can skip this entire stage""" if not, we can skip this entire stage"""
user = self.get_pending_user() user = self.get_pending_user()
stage: AuthenticatorValidateStage = self.executor.current_stage
if user and not user.is_anonymous: if user and not user.is_anonymous:
try: try:
challenges = self.get_device_challenges() challenges = self.get_device_challenges()
@ -257,7 +254,7 @@ class AuthenticatorValidateStageView(ChallengeStageView):
self.logger.debug("Refusing passwordless flow in non-authentication flow") self.logger.debug("Refusing passwordless flow in non-authentication flow")
return self.executor.stage_ok() return self.executor.stage_ok()
# Passwordless auth, with just webauthn # Passwordless auth, with just webauthn
if DeviceClasses.WEBAUTHN in stage.device_classes: if DeviceClasses.WEBAUTHN in self.current_stage.device_classes:
self.logger.debug("Flow without user, getting generic webauthn challenge") self.logger.debug("Flow without user, getting generic webauthn challenge")
challenges = self.get_webauthn_challenge_without_user() challenges = self.get_webauthn_challenge_without_user()
else: else:
@ -267,13 +264,13 @@ class AuthenticatorValidateStageView(ChallengeStageView):
# No allowed devices # No allowed devices
if len(challenges) < 1: if len(challenges) < 1:
if stage.not_configured_action == NotConfiguredAction.SKIP: if self.current_stage.not_configured_action == NotConfiguredAction.SKIP:
self.logger.debug("Authenticator not configured, skipping stage") self.logger.debug("Authenticator not configured, skipping stage")
return self.executor.stage_ok() return self.executor.stage_ok()
if stage.not_configured_action == NotConfiguredAction.DENY: if self.current_stage.not_configured_action == NotConfiguredAction.DENY:
self.logger.debug("Authenticator not configured, denying") self.logger.debug("Authenticator not configured, denying")
return self.executor.stage_invalid(_("No (allowed) MFA authenticator configured.")) return self.executor.stage_invalid(_("No (allowed) MFA authenticator configured."))
if stage.not_configured_action == NotConfiguredAction.CONFIGURE: if self.current_stage.not_configured_action == NotConfiguredAction.CONFIGURE:
self.logger.debug("Authenticator not configured, forcing configure") self.logger.debug("Authenticator not configured, forcing configure")
return self.prepare_stages(user) return self.prepare_stages(user)
return super().get(request, *args, **kwargs) return super().get(request, *args, **kwargs)
@ -282,8 +279,7 @@ class AuthenticatorValidateStageView(ChallengeStageView):
"""Check how the user can configure themselves. If no stages are set, return an error. """Check how the user can configure themselves. If no stages are set, return an error.
If a single stage is set, insert that stage directly. If multiple are selected, include If a single stage is set, insert that stage directly. If multiple are selected, include
them in the challenge.""" them in the challenge."""
stage: AuthenticatorValidateStage = self.executor.current_stage if not self.current_stage.configuration_stages.exists():
if not stage.configuration_stages.exists():
Event.new( Event.new(
EventAction.CONFIGURATION_ERROR, EventAction.CONFIGURATION_ERROR,
message=( message=(
@ -293,15 +289,19 @@ class AuthenticatorValidateStageView(ChallengeStageView):
stage=self, stage=self,
).from_http(self.request).set_user(user).save() ).from_http(self.request).set_user(user).save()
return self.executor.stage_invalid() return self.executor.stage_invalid()
if stage.configuration_stages.count() == 1: if self.current_stage.configuration_stages.count() == 1:
next_stage = Stage.objects.get_subclass(pk=stage.configuration_stages.first().pk) next_stage = Stage.objects.get_subclass(
pk=self.current_stage.configuration_stages.first().pk
)
self.logger.debug("Single stage configured, auto-selecting", stage=next_stage) self.logger.debug("Single stage configured, auto-selecting", stage=next_stage)
self.executor.plan.context[PLAN_CONTEXT_SELECTED_STAGE] = next_stage self.executor.plan.context[PLAN_CONTEXT_SELECTED_STAGE] = next_stage
# Because that normal execution only happens on post, we directly inject it here and # Because that normal execution only happens on post, we directly inject it here and
# return it # return it
self.executor.plan.insert_stage(next_stage) self.executor.plan.insert_stage(next_stage)
return self.executor.stage_ok() return self.executor.stage_ok()
stages = Stage.objects.filter(pk__in=stage.configuration_stages.all()).select_subclasses() stages = Stage.objects.filter(
pk__in=self.current_stage.configuration_stages.all()
).select_subclasses()
self.executor.plan.context[PLAN_CONTEXT_STAGES] = stages self.executor.plan.context[PLAN_CONTEXT_STAGES] = stages
return super().get(self.request, *args, **kwargs) return super().get(self.request, *args, **kwargs)
@ -309,7 +309,7 @@ class AuthenticatorValidateStageView(ChallengeStageView):
res = super().post(request, *args, **kwargs) res = super().post(request, *args, **kwargs)
if ( if (
PLAN_CONTEXT_SELECTED_STAGE in self.executor.plan.context PLAN_CONTEXT_SELECTED_STAGE in self.executor.plan.context
and self.executor.current_stage.not_configured_action == NotConfiguredAction.CONFIGURE and self.current_stage.not_configured_action == NotConfiguredAction.CONFIGURE
): ):
self.logger.debug("Got selected stage in context, running that") self.logger.debug("Got selected stage in context, running that")
stage_pk = self.executor.plan.context.get(PLAN_CONTEXT_SELECTED_STAGE) stage_pk = self.executor.plan.context.get(PLAN_CONTEXT_SELECTED_STAGE)
@ -351,7 +351,7 @@ class AuthenticatorValidateStageView(ChallengeStageView):
def cookie_jwt_key(self) -> str: def cookie_jwt_key(self) -> str:
"""Signing key for MFA Cookie for this stage""" """Signing key for MFA Cookie for this stage"""
return sha256( return sha256(
f"{get_unique_identifier()}:{self.executor.current_stage.pk.hex}".encode("ascii") f"{get_unique_identifier()}:{self.current_stage.pk.hex}".encode("ascii")
).hexdigest() ).hexdigest()
def check_mfa_cookie(self, allowed_devices: list[Device]): def check_mfa_cookie(self, allowed_devices: list[Device]):
@ -362,12 +362,11 @@ class AuthenticatorValidateStageView(ChallengeStageView):
correct user and with an allowed class""" correct user and with an allowed class"""
if COOKIE_NAME_MFA not in self.request.COOKIES: if COOKIE_NAME_MFA not in self.request.COOKIES:
return return
stage: AuthenticatorValidateStage = self.executor.current_stage threshold = timedelta_from_string(self.current_stage.last_auth_threshold)
threshold = timedelta_from_string(stage.last_auth_threshold)
latest_allowed = datetime.now() + threshold latest_allowed = datetime.now() + threshold
try: try:
payload = decode(self.request.COOKIES[COOKIE_NAME_MFA], self.cookie_jwt_key, ["HS256"]) payload = decode(self.request.COOKIES[COOKIE_NAME_MFA], self.cookie_jwt_key, ["HS256"])
if payload["stage"] != stage.pk.hex: if payload["stage"] != self.current_stage.pk.hex:
self.logger.warning("Invalid stage PK") self.logger.warning("Invalid stage PK")
return return
if datetime.fromtimestamp(payload["exp"]) > latest_allowed: if datetime.fromtimestamp(payload["exp"]) > latest_allowed:
@ -385,15 +384,14 @@ class AuthenticatorValidateStageView(ChallengeStageView):
"""Set an MFA cookie to allow users to skip MFA validation in this context (browser) """Set an MFA cookie to allow users to skip MFA validation in this context (browser)
The cookie is JWT which is signed with a hash of the secret key and the UID of the stage""" The cookie is JWT which is signed with a hash of the secret key and the UID of the stage"""
stage: AuthenticatorValidateStage = self.executor.current_stage delta = timedelta_from_string(self.current_stage.last_auth_threshold)
delta = timedelta_from_string(stage.last_auth_threshold)
if delta.total_seconds() < 1: if delta.total_seconds() < 1:
self.logger.info("Not setting MFA cookie since threshold is not set.") self.logger.info("Not setting MFA cookie since threshold is not set.")
return self.executor.stage_ok() return self.executor.stage_ok()
expiry = datetime.now() + delta expiry = datetime.now() + delta
cookie_payload = { cookie_payload = {
"device": device.pk, "device": device.pk,
"stage": stage.pk.hex, "stage": self.current_stage.pk.hex,
"exp": expiry.timestamp(), "exp": expiry.timestamp(),
} }
response = self.executor.stage_ok() response = self.executor.stage_ok()

View File

@ -108,7 +108,7 @@ class AuthenticatorWebAuthnChallengeResponse(ChallengeResponse):
return registration return registration
class AuthenticatorWebAuthnStageView(ChallengeStageView): class AuthenticatorWebAuthnStageView(ChallengeStageView[AuthenticatorWebAuthnStage]):
"""WebAuthn stage""" """WebAuthn stage"""
response_class = AuthenticatorWebAuthnChallengeResponse response_class = AuthenticatorWebAuthnChallengeResponse
@ -116,12 +116,11 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView):
def get_challenge(self, *args, **kwargs) -> Challenge: def get_challenge(self, *args, **kwargs) -> Challenge:
# clear session variables prior to starting a new registration # clear session variables prior to starting a new registration
self.request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None) self.request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None)
stage: AuthenticatorWebAuthnStage = self.executor.current_stage
user = self.get_pending_user() user = self.get_pending_user()
# library accepts none so we store null in the database, but if there is a value # library accepts none so we store null in the database, but if there is a value
# set, cast it to string to ensure it's not a django class # set, cast it to string to ensure it's not a django class
authenticator_attachment = stage.authenticator_attachment authenticator_attachment = self.current_stage.authenticator_attachment
if authenticator_attachment: if authenticator_attachment:
authenticator_attachment = AuthenticatorAttachment(str(authenticator_attachment)) authenticator_attachment = AuthenticatorAttachment(str(authenticator_attachment))
@ -132,8 +131,12 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView):
user_name=user.username, user_name=user.username,
user_display_name=user.name, user_display_name=user.name,
authenticator_selection=AuthenticatorSelectionCriteria( authenticator_selection=AuthenticatorSelectionCriteria(
resident_key=ResidentKeyRequirement(str(stage.resident_key_requirement)), resident_key=ResidentKeyRequirement(
user_verification=UserVerificationRequirement(str(stage.user_verification)), str(self.current_stage.resident_key_requirement)
),
user_verification=UserVerificationRequirement(
str(self.current_stage.user_verification)
),
authenticator_attachment=authenticator_attachment, authenticator_attachment=authenticator_attachment,
), ),
attestation=AttestationConveyancePreference.DIRECT, attestation=AttestationConveyancePreference.DIRECT,

View File

@ -70,7 +70,7 @@ class CaptchaChallengeResponse(ChallengeResponse):
return data return data
class CaptchaStageView(ChallengeStageView): class CaptchaStageView(ChallengeStageView[CaptchaChallenge]):
"""Simple captcha checker, logic is handled in django-captcha module""" """Simple captcha checker, logic is handled in django-captcha module"""
response_class = CaptchaChallengeResponse response_class = CaptchaChallengeResponse
@ -78,8 +78,8 @@ class CaptchaStageView(ChallengeStageView):
def get_challenge(self, *args, **kwargs) -> Challenge: def get_challenge(self, *args, **kwargs) -> Challenge:
return CaptchaChallenge( return CaptchaChallenge(
data={ data={
"js_url": self.executor.current_stage.js_url, "js_url": self.current_stage.js_url,
"site_key": self.executor.current_stage.public_key, "site_key": self.current_stage.public_key,
} }
) )
@ -87,6 +87,6 @@ class CaptchaStageView(ChallengeStageView):
response = response.validated_data["token"] response = response.validated_data["token"]
self.executor.plan.context[PLAN_CONTEXT_CAPTCHA] = { self.executor.plan.context[PLAN_CONTEXT_CAPTCHA] = {
"response": response, "response": response,
"stage": self.executor.current_stage, "stage": self.current_stage,
} }
return self.executor.stage_ok() return self.executor.stage_ok()

View File

@ -48,7 +48,7 @@ class ConsentChallengeResponse(ChallengeResponse):
token = CharField(required=True) token = CharField(required=True)
class ConsentStageView(ChallengeStageView): class ConsentStageView(ChallengeStageView[ConsentStage]):
"""Simple consent checker.""" """Simple consent checker."""
response_class = ConsentChallengeResponse response_class = ConsentChallengeResponse
@ -72,14 +72,13 @@ class ConsentStageView(ChallengeStageView):
"""Check if the current request should require a prompt for non consent reasons, """Check if the current request should require a prompt for non consent reasons,
i.e. this stage injected from another stage, mode is always requireed or no application i.e. this stage injected from another stage, mode is always requireed or no application
is set.""" is set."""
current_stage: ConsentStage = self.executor.current_stage
# Make this StageView work when injected, in which case `current_stage` is an instance # Make this StageView work when injected, in which case `current_stage` is an instance
# of the base class, and we don't save any consent, as it is assumed to be a one-time # of the base class, and we don't save any consent, as it is assumed to be a one-time
# prompt # prompt
if not isinstance(current_stage, ConsentStage): if not isinstance(self.current_stage, ConsentStage):
return True return True
# For always require, we always return the challenge # For always require, we always return the challenge
if current_stage.mode == ConsentMode.ALWAYS_REQUIRE: if self.current_stage.mode == ConsentMode.ALWAYS_REQUIRE:
return True return True
# at this point we need to check consent from database # at this point we need to check consent from database
if PLAN_CONTEXT_APPLICATION not in self.executor.plan.context: if PLAN_CONTEXT_APPLICATION not in self.executor.plan.context:
@ -125,7 +124,6 @@ class ConsentStageView(ChallengeStageView):
return self.get(self.request) return self.get(self.request)
if self.should_always_prompt(): if self.should_always_prompt():
return self.executor.stage_ok() return self.executor.stage_ok()
current_stage: ConsentStage = self.executor.current_stage
application = self.executor.plan.context[PLAN_CONTEXT_APPLICATION] application = self.executor.plan.context[PLAN_CONTEXT_APPLICATION]
permissions = self.executor.plan.context.get( permissions = self.executor.plan.context.get(
PLAN_CONTEXT_CONSENT_PERMISSIONS, [] PLAN_CONTEXT_CONSENT_PERMISSIONS, []
@ -139,9 +137,9 @@ class ConsentStageView(ChallengeStageView):
) )
consent: UserConsent = self.executor.plan.context[PLAN_CONTEXT_CONSENT] consent: UserConsent = self.executor.plan.context[PLAN_CONTEXT_CONSENT]
consent.permissions = permissions_string consent.permissions = permissions_string
if current_stage.mode == ConsentMode.PERMANENT: if self.current_stage.mode == ConsentMode.PERMANENT:
consent.expiring = False consent.expiring = False
if current_stage.mode == ConsentMode.EXPIRING: if self.current_stage.mode == ConsentMode.EXPIRING:
consent.expires = now() + timedelta_from_string(current_stage.consent_expire_in) consent.expires = now() + timedelta_from_string(self.current_stage.consent_expire_in)
consent.save() consent.save()
return self.executor.stage_ok() return self.executor.stage_ok()

View File

@ -6,11 +6,10 @@ from authentik.flows.stage import StageView
from authentik.stages.deny.models import DenyStage from authentik.stages.deny.models import DenyStage
class DenyStageView(StageView): class DenyStageView(StageView[DenyStage]):
"""Cancels the current flow""" """Cancels the current flow"""
def dispatch(self, request: HttpRequest) -> HttpResponse: def dispatch(self, request: HttpRequest) -> HttpResponse:
"""Cancels the current flow""" """Cancels the current flow"""
stage: DenyStage = self.executor.current_stage message = self.executor.plan.context.get("deny_message", self.current_stage.deny_message)
message = self.executor.plan.context.get("deny_message", stage.deny_message)
return self.executor.stage_invalid(message) return self.executor.stage_invalid(message)

View File

@ -30,11 +30,11 @@ class DummyStageView(ChallengeStageView):
return self.executor.stage_ok() return self.executor.stage_ok()
def get_challenge(self, *args, **kwargs) -> Challenge: def get_challenge(self, *args, **kwargs) -> Challenge:
if self.executor.current_stage.throw_error: if self.current_stage.throw_error:
raise SentryIgnoredException("Test error") raise SentryIgnoredException("Test error")
return DummyChallenge( return DummyChallenge(
data={ data={
"title": self.executor.current_stage.name, "title": self.current_stage.name,
"name": self.executor.current_stage.name, "name": self.current_stage.name,
} }
) )

View File

@ -46,7 +46,7 @@ class EmailChallengeResponse(ChallengeResponse):
raise ValidationError(detail="email-sent", code="email-sent") raise ValidationError(detail="email-sent", code="email-sent")
class EmailStageView(ChallengeStageView): class EmailStageView(ChallengeStageView[EmailStage]):
"""Email stage which sends Email for verification""" """Email stage which sends Email for verification"""
response_class = EmailChallengeResponse response_class = EmailChallengeResponse
@ -72,11 +72,10 @@ class EmailStageView(ChallengeStageView):
def get_token(self) -> FlowToken: def get_token(self) -> FlowToken:
"""Get token""" """Get token"""
pending_user = self.get_pending_user() pending_user = self.get_pending_user()
current_stage: EmailStage = self.executor.current_stage
valid_delta = timedelta( valid_delta = timedelta(
minutes=current_stage.token_expiry + 1 minutes=self.current_stage.token_expiry + 1
) # + 1 because django timesince always rounds down ) # + 1 because django timesince always rounds down
identifier = slugify(f"ak-email-stage-{current_stage.name}-{str(uuid4())}") identifier = slugify(f"ak-email-stage-{self.current_stage.name}-{str(uuid4())}")
# Don't check for validity here, we only care if the token exists # Don't check for validity here, we only care if the token exists
tokens = FlowToken.objects.filter(identifier=identifier) tokens = FlowToken.objects.filter(identifier=identifier)
if not tokens.exists(): if not tokens.exists():
@ -105,15 +104,14 @@ class EmailStageView(ChallengeStageView):
email = self.executor.plan.context.get(PLAN_CONTEXT_EMAIL_OVERRIDE, None) email = self.executor.plan.context.get(PLAN_CONTEXT_EMAIL_OVERRIDE, None)
if not email: if not email:
email = pending_user.email email = pending_user.email
current_stage: EmailStage = self.executor.current_stage
token = self.get_token() token = self.get_token()
# Send mail to user # Send mail to user
try: try:
message = TemplateEmailMessage( message = TemplateEmailMessage(
subject=_(current_stage.subject), subject=_(self.current_stage.subject),
to=[(pending_user.name, email)], to=[(pending_user.name, email)],
language=pending_user.locale(self.request), language=pending_user.locale(self.request),
template_name=current_stage.template, template_name=self.current_stage.template,
template_context={ template_context={
"url": self.get_full_url(**{QS_KEY_TOKEN: token.key}), "url": self.get_full_url(**{QS_KEY_TOKEN: token.key}),
"user": pending_user, "user": pending_user,
@ -121,26 +119,28 @@ class EmailStageView(ChallengeStageView):
"token": token.key, "token": token.key,
}, },
) )
send_mails(current_stage, message) send_mails(self.current_stage, message)
except TemplateSyntaxError as exc: except TemplateSyntaxError as exc:
Event.new( Event.new(
EventAction.CONFIGURATION_ERROR, EventAction.CONFIGURATION_ERROR,
message=_("Exception occurred while rendering E-mail template"), message=_("Exception occurred while rendering E-mail template"),
error=exception_to_string(exc), error=exception_to_string(exc),
template=current_stage.template, template=self.current_stage.template,
).from_http(self.request) ).from_http(self.request)
raise StageInvalidException from exc raise StageInvalidException from exc
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
# Check if the user came back from the email link to verify # Check if the user came back from the email link to verify
restore_token: FlowToken = self.executor.plan.context.get(PLAN_CONTEXT_IS_RESTORED, None) restore_token: FlowToken | None = self.executor.plan.context.get(
PLAN_CONTEXT_IS_RESTORED, None
)
user = self.get_pending_user() user = self.get_pending_user()
if restore_token: if restore_token:
if restore_token.user != user: if restore_token.user != user:
self.logger.warning("Flow token for non-matching user, denying request") self.logger.warning("Flow token for non-matching user, denying request")
return self.executor.stage_invalid() return self.executor.stage_invalid()
messages.success(request, _("Successfully verified Email.")) messages.success(request, _("Successfully verified Email."))
if self.executor.current_stage.activate_user_on_success: if self.current_stage.activate_user_on_success:
user.is_active = True user.is_active = True
user.save() user.save()
return self.executor.stage_ok() return self.executor.stage_ok()

View File

@ -164,22 +164,21 @@ class IdentificationChallengeResponse(ChallengeResponse):
return attrs return attrs
class IdentificationStageView(ChallengeStageView): class IdentificationStageView(ChallengeStageView[IdentificationStage]):
"""Form to identify the user""" """Form to identify the user"""
response_class = IdentificationChallengeResponse response_class = IdentificationChallengeResponse
def get_user(self, uid_value: str) -> User | None: def get_user(self, uid_value: str) -> User | None:
"""Find user instance. Returns None if no user was found.""" """Find user instance. Returns None if no user was found."""
current_stage: IdentificationStage = self.executor.current_stage
query = Q() query = Q()
for search_field in current_stage.user_fields: for search_field in self.current_stage.user_fields:
model_field = { model_field = {
"email": "email", "email": "email",
"username": "username", "username": "username",
"upn": "attributes__upn", "upn": "attributes__upn",
}[search_field] }[search_field]
if current_stage.case_insensitive_matching: if self.current_stage.case_insensitive_matching:
model_field += "__iexact" model_field += "__iexact"
else: else:
model_field += "__exact" model_field += "__exact"
@ -200,16 +199,15 @@ class IdentificationStageView(ChallengeStageView):
return _("Continue") return _("Continue")
def get_challenge(self) -> Challenge: def get_challenge(self) -> Challenge:
current_stage: IdentificationStage = self.executor.current_stage
challenge = IdentificationChallenge( challenge = IdentificationChallenge(
data={ data={
"component": "ak-stage-identification", "component": "ak-stage-identification",
"primary_action": self.get_primary_action(), "primary_action": self.get_primary_action(),
"user_fields": current_stage.user_fields, "user_fields": self.current_stage.user_fields,
"password_fields": bool(current_stage.password_stage), "password_fields": bool(self.current_stage.password_stage),
"allow_show_password": bool(current_stage.password_stage) "allow_show_password": bool(self.current_stage.password_stage)
and current_stage.password_stage.allow_show_password, and self.current_stage.password_stage.allow_show_password,
"show_source_labels": current_stage.show_source_labels, "show_source_labels": self.current_stage.show_source_labels,
"flow_designation": self.executor.flow.designation, "flow_designation": self.executor.flow.designation,
} }
) )
@ -221,27 +219,26 @@ class IdentificationStageView(ChallengeStageView):
).name ).name
get_qs = self.request.session.get(SESSION_KEY_GET, self.request.GET) get_qs = self.request.session.get(SESSION_KEY_GET, self.request.GET)
# Check for related enrollment and recovery flow, add URL to view # Check for related enrollment and recovery flow, add URL to view
if current_stage.enrollment_flow: if self.current_stage.enrollment_flow:
challenge.initial_data["enroll_url"] = reverse_with_qs( challenge.initial_data["enroll_url"] = reverse_with_qs(
"authentik_core:if-flow", "authentik_core:if-flow",
query=get_qs, query=get_qs,
kwargs={"flow_slug": current_stage.enrollment_flow.slug}, kwargs={"flow_slug": self.current_stage.enrollment_flow.slug},
) )
if current_stage.recovery_flow: if self.current_stage.recovery_flow:
challenge.initial_data["recovery_url"] = reverse_with_qs( challenge.initial_data["recovery_url"] = reverse_with_qs(
"authentik_core:if-flow", "authentik_core:if-flow",
query=get_qs, query=get_qs,
kwargs={"flow_slug": current_stage.recovery_flow.slug}, kwargs={"flow_slug": self.current_stage.recovery_flow.slug},
) )
if current_stage.passwordless_flow: if self.current_stage.passwordless_flow:
challenge.initial_data["passwordless_url"] = reverse_with_qs( challenge.initial_data["passwordless_url"] = reverse_with_qs(
"authentik_core:if-flow", "authentik_core:if-flow",
query=get_qs, query=get_qs,
kwargs={"flow_slug": current_stage.passwordless_flow.slug}, kwargs={"flow_slug": self.current_stage.passwordless_flow.slug},
) )
if current_stage.captcha_stage: if self.current_stage.captcha_stage:
captcha = CaptchaStageView(self.executor) captcha = CaptchaStageView(self.executor, self.current_stage.captcha_stage)
captcha.stage = current_stage.captcha_stage
captcha_challenge = captcha.get_challenge() captcha_challenge = captcha.get_challenge()
captcha_challenge.is_valid() captcha_challenge.is_valid()
challenge.initial_data["captcha_stage"] = captcha_challenge.data challenge.initial_data["captcha_stage"] = captcha_challenge.data
@ -249,7 +246,7 @@ class IdentificationStageView(ChallengeStageView):
# Check all enabled source, add them if they have a UI Login button. # Check all enabled source, add them if they have a UI Login button.
ui_sources = [] ui_sources = []
sources: list[Source] = ( sources: list[Source] = (
current_stage.sources.filter(enabled=True).order_by("name").select_subclasses() self.current_stage.sources.filter(enabled=True).order_by("name").select_subclasses()
) )
for source in sources: for source in sources:
ui_login_button = source.ui_login_button(self.request) ui_login_button = source.ui_login_button(self.request)
@ -264,8 +261,7 @@ class IdentificationStageView(ChallengeStageView):
def challenge_valid(self, response: IdentificationChallengeResponse) -> HttpResponse: def challenge_valid(self, response: IdentificationChallengeResponse) -> HttpResponse:
self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] = response.pre_user self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] = response.pre_user
current_stage: IdentificationStage = self.executor.current_stage if not self.current_stage.show_matched_user:
if not current_stage.show_matched_user:
self.executor.plan.context[PLAN_CONTEXT_PENDING_USER_IDENTIFIER] = ( self.executor.plan.context[PLAN_CONTEXT_PENDING_USER_IDENTIFIER] = (
response.validated_data.get("uid_field") response.validated_data.get("uid_field")
) )

View File

@ -17,7 +17,7 @@ INVITATION_IN_EFFECT = "invitation_in_effect"
INVITATION = "invitation" INVITATION = "invitation"
class InvitationStageView(StageView): class InvitationStageView(StageView[InvitationStage]):
"""Finalise Authentication flow by logging the user in""" """Finalise Authentication flow by logging the user in"""
def get_token(self) -> str | None: def get_token(self) -> str | None:
@ -52,11 +52,10 @@ class InvitationStageView(StageView):
def dispatch(self, request: HttpRequest) -> HttpResponse: def dispatch(self, request: HttpRequest) -> HttpResponse:
"""Apply data to the current flow based on a URL""" """Apply data to the current flow based on a URL"""
stage: InvitationStage = self.executor.current_stage
invite = self.get_invite() invite = self.get_invite()
if not invite: if not invite:
if stage.continue_flow_without_invitation: if self.current_stage.continue_flow_without_invitation:
return self.executor.stage_ok() return self.executor.stage_ok()
return self.executor.stage_invalid(_("Invalid invite/invite not found")) return self.executor.stage_invalid(_("Invalid invite/invite not found"))

View File

@ -130,7 +130,7 @@ class PasswordChallengeResponse(ChallengeResponse):
return password return password
class PasswordStageView(ChallengeStageView): class PasswordStageView(ChallengeStageView[PasswordStage]):
"""Authentication stage which authenticates against django's AuthBackend""" """Authentication stage which authenticates against django's AuthBackend"""
response_class = PasswordChallengeResponse response_class = PasswordChallengeResponse
@ -138,7 +138,7 @@ class PasswordStageView(ChallengeStageView):
def get_challenge(self) -> Challenge: def get_challenge(self) -> Challenge:
challenge = PasswordChallenge( challenge = PasswordChallenge(
data={ data={
"allow_show_password": self.executor.current_stage.allow_show_password, "allow_show_password": self.current_stage.allow_show_password,
} }
) )
recovery_flow = Flow.objects.filter(designation=FlowDesignation.RECOVERY) recovery_flow = Flow.objects.filter(designation=FlowDesignation.RECOVERY)
@ -154,10 +154,9 @@ class PasswordStageView(ChallengeStageView):
if SESSION_KEY_INVALID_TRIES not in self.request.session: if SESSION_KEY_INVALID_TRIES not in self.request.session:
self.request.session[SESSION_KEY_INVALID_TRIES] = 0 self.request.session[SESSION_KEY_INVALID_TRIES] = 0
self.request.session[SESSION_KEY_INVALID_TRIES] += 1 self.request.session[SESSION_KEY_INVALID_TRIES] += 1
current_stage: PasswordStage = self.executor.current_stage
if ( if (
self.request.session[SESSION_KEY_INVALID_TRIES] self.request.session[SESSION_KEY_INVALID_TRIES]
>= current_stage.failed_attempts_before_cancel >= self.current_stage.failed_attempts_before_cancel
): ):
self.logger.debug("User has exceeded maximum tries") self.logger.debug("User has exceeded maximum tries")
del self.request.session[SESSION_KEY_INVALID_TRIES] del self.request.session[SESSION_KEY_INVALID_TRIES]

View File

@ -222,7 +222,7 @@ class PromptStageView(ChallengeStageView):
return serializers return serializers
def get_challenge(self, *args, **kwargs) -> Challenge: def get_challenge(self, *args, **kwargs) -> Challenge:
fields: list[Prompt] = list(self.executor.current_stage.fields.all().order_by("order")) fields: list[Prompt] = list(self.current_stage.fields.all().order_by("order"))
context_prompt = self.executor.plan.context.get(PLAN_CONTEXT_PROMPT, {}) context_prompt = self.executor.plan.context.get(PLAN_CONTEXT_PROMPT, {})
serializers = self.get_prompt_challenge_fields(fields, context_prompt) serializers = self.get_prompt_challenge_fields(fields, context_prompt)
challenge = PromptChallenge( challenge = PromptChallenge(
@ -239,7 +239,7 @@ class PromptStageView(ChallengeStageView):
instance=None, instance=None,
data=data, data=data,
request=self.request, request=self.request,
stage_instance=self.executor.current_stage, stage_instance=self.current_stage,
stage=self, stage=self,
plan=self.executor.plan, plan=self.executor.plan,
user=self.get_pending_user(), user=self.get_pending_user(),

View File

@ -7,9 +7,10 @@ from django.utils.translation import gettext as _
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER
from authentik.flows.stage import StageView from authentik.flows.stage import StageView
from authentik.stages.user_delete.models import UserDeleteStage
class UserDeleteStageView(StageView): class UserDeleteStageView(StageView[UserDeleteStage]):
"""Finalise unenrollment flow by deleting the user object.""" """Finalise unenrollment flow by deleting the user object."""
def dispatch(self, request: HttpRequest) -> HttpResponse: def dispatch(self, request: HttpRequest) -> HttpResponse:

View File

@ -39,7 +39,7 @@ class UserLoginChallengeResponse(ChallengeResponse):
remember_me = BooleanField(required=True) remember_me = BooleanField(required=True)
class UserLoginStageView(ChallengeStageView): class UserLoginStageView(ChallengeStageView[UserLoginStage]):
"""Finalise Authentication flow by logging the user in""" """Finalise Authentication flow by logging the user in"""
response_class = UserLoginChallengeResponse response_class = UserLoginChallengeResponse
@ -49,8 +49,7 @@ class UserLoginStageView(ChallengeStageView):
def dispatch(self, request: HttpRequest) -> HttpResponse: def dispatch(self, request: HttpRequest) -> HttpResponse:
"""Check for remember_me, and do login""" """Check for remember_me, and do login"""
stage: UserLoginStage = self.executor.current_stage if timedelta_from_string(self.current_stage.remember_me_offset).total_seconds() > 0:
if timedelta_from_string(stage.remember_me_offset).total_seconds() > 0:
return super().dispatch(request) return super().dispatch(request)
return self.do_login(request) return self.do_login(request)
@ -59,9 +58,9 @@ class UserLoginStageView(ChallengeStageView):
def set_session_duration(self, remember: bool) -> timedelta: def set_session_duration(self, remember: bool) -> timedelta:
"""Update the sessions' expiry""" """Update the sessions' expiry"""
delta = timedelta_from_string(self.executor.current_stage.session_duration) delta = timedelta_from_string(self.current_stage.session_duration)
if remember: if remember:
offset = timedelta_from_string(self.executor.current_stage.remember_me_offset) offset = timedelta_from_string(self.current_stage.remember_me_offset)
delta = delta + offset delta = delta + offset
if delta.total_seconds() == 0: if delta.total_seconds() == 0:
self.request.session.set_expiry(0) self.request.session.set_expiry(0)
@ -71,11 +70,9 @@ class UserLoginStageView(ChallengeStageView):
def set_session_ip(self): def set_session_ip(self):
"""Set the sessions' last IP and session bindings""" """Set the sessions' last IP and session bindings"""
stage: UserLoginStage = self.executor.current_stage
self.request.session[SESSION_KEY_LAST_IP] = ClientIPMiddleware.get_client_ip(self.request) self.request.session[SESSION_KEY_LAST_IP] = ClientIPMiddleware.get_client_ip(self.request)
self.request.session[SESSION_KEY_BINDING_NET] = stage.network_binding self.request.session[SESSION_KEY_BINDING_NET] = self.current_stage.network_binding
self.request.session[SESSION_KEY_BINDING_GEO] = stage.geoip_binding self.request.session[SESSION_KEY_BINDING_GEO] = self.current_stage.geoip_binding
def do_login(self, request: HttpRequest, remember: bool = False) -> HttpResponse: def do_login(self, request: HttpRequest, remember: bool = False) -> HttpResponse:
"""Attach the currently pending user to the current session""" """Attach the currently pending user to the current session"""
@ -111,7 +108,7 @@ class UserLoginStageView(ChallengeStageView):
# as sources show their own success messages # as sources show their own success messages
if not self.executor.plan.context.get(PLAN_CONTEXT_SOURCE, None): if not self.executor.plan.context.get(PLAN_CONTEXT_SOURCE, None):
messages.success(self.request, _("Successfully logged in!")) messages.success(self.request, _("Successfully logged in!"))
if self.executor.current_stage.terminate_other_sessions: if self.current_stage.terminate_other_sessions:
AuthenticatedSession.objects.filter( AuthenticatedSession.objects.filter(
user=user, user=user,
).exclude(session_key=self.request.session.session_key).delete() ).exclude(session_key=self.request.session.session_key).delete()

View File

@ -4,9 +4,10 @@ from django.contrib.auth import logout
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse
from authentik.flows.stage import StageView from authentik.flows.stage import StageView
from authentik.stages.user_logout.models import UserLogoutStage
class UserLogoutStageView(StageView): class UserLogoutStageView(StageView[UserLogoutStage]):
"""Finalise Authentication flow by logging the user in""" """Finalise Authentication flow by logging the user in"""
def dispatch(self, request: HttpRequest) -> HttpResponse: def dispatch(self, request: HttpRequest) -> HttpResponse:

View File

@ -55,7 +55,7 @@ class UserWriteStageView(StageView):
"""Ensure a user exists""" """Ensure a user exists"""
user_created = False user_created = False
path = self.executor.plan.context.get( path = self.executor.plan.context.get(
PLAN_CONTEXT_USER_PATH, self.executor.current_stage.user_path_template PLAN_CONTEXT_USER_PATH, self.current_stage.user_path_template
) )
if path == "": if path == "":
path = User.default_path() path = User.default_path()
@ -64,11 +64,11 @@ class UserWriteStageView(StageView):
user_type = UserTypes( user_type = UserTypes(
self.executor.plan.context.get( self.executor.plan.context.get(
PLAN_CONTEXT_USER_TYPE, PLAN_CONTEXT_USER_TYPE,
self.executor.current_stage.user_type, self.current_stage.user_type,
) )
) )
except ValueError: except ValueError:
user_type = self.executor.current_stage.user_type user_type = self.current_stage.user_type
if user_type == UserTypes.INTERNAL_SERVICE_ACCOUNT: if user_type == UserTypes.INTERNAL_SERVICE_ACCOUNT:
user_type = UserTypes.SERVICE_ACCOUNT user_type = UserTypes.SERVICE_ACCOUNT
@ -76,12 +76,12 @@ class UserWriteStageView(StageView):
self.executor.plan.context.setdefault(PLAN_CONTEXT_PENDING_USER, self.request.user) self.executor.plan.context.setdefault(PLAN_CONTEXT_PENDING_USER, self.request.user)
if ( if (
PLAN_CONTEXT_PENDING_USER not in self.executor.plan.context PLAN_CONTEXT_PENDING_USER not in self.executor.plan.context
or self.executor.current_stage.user_creation_mode == UserCreationMode.ALWAYS_CREATE or self.current_stage.user_creation_mode == UserCreationMode.ALWAYS_CREATE
): ):
if self.executor.current_stage.user_creation_mode == UserCreationMode.NEVER_CREATE: if self.current_stage.user_creation_mode == UserCreationMode.NEVER_CREATE:
return None, False return None, False
self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] = User( self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] = User(
is_active=not self.executor.current_stage.create_users_as_inactive, is_active=not self.current_stage.create_users_as_inactive,
path=path, path=path,
type=user_type, type=user_type,
) )
@ -180,8 +180,8 @@ class UserWriteStageView(StageView):
try: try:
with transaction.atomic(): with transaction.atomic():
user.save() user.save()
if self.executor.current_stage.create_users_group: if self.current_stage.create_users_group:
user.ak_groups.add(self.executor.current_stage.create_users_group) user.ak_groups.add(self.current_stage.create_users_group)
if PLAN_CONTEXT_GROUPS in self.executor.plan.context: if PLAN_CONTEXT_GROUPS in self.executor.plan.context:
user.ak_groups.add(*self.executor.plan.context[PLAN_CONTEXT_GROUPS]) user.ak_groups.add(*self.executor.plan.context[PLAN_CONTEXT_GROUPS])
except (IntegrityError, ValueError, TypeError, InternalError) as exc: except (IntegrityError, ValueError, TypeError, InternalError) as exc: