stages/prompt: migrate to SPA
This commit is contained in:
		@ -28,9 +28,3 @@ class TestOverviewViews(TestCase):
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            self.client.get(reverse("authentik_core:shell")).status_code, 200
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_overview(self):
 | 
			
		||||
        """Test overview"""
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            self.client.get(reverse("authentik_core:overview")).status_code, 200
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
@ -1,20 +1,7 @@
 | 
			
		||||
"""Prompt forms"""
 | 
			
		||||
from email.policy import Policy
 | 
			
		||||
from types import MethodType
 | 
			
		||||
from typing import Any, Callable, Iterator
 | 
			
		||||
 | 
			
		||||
from django import forms
 | 
			
		||||
from django.db.models.query import QuerySet
 | 
			
		||||
from django.http import HttpRequest
 | 
			
		||||
from django.utils.translation import gettext_lazy as _
 | 
			
		||||
from guardian.shortcuts import get_anonymous_user
 | 
			
		||||
 | 
			
		||||
from authentik.core.models import User
 | 
			
		||||
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
 | 
			
		||||
from authentik.policies.engine import PolicyEngine
 | 
			
		||||
from authentik.policies.models import PolicyBinding, PolicyBindingModel
 | 
			
		||||
from authentik.stages.prompt.models import FieldTypes, Prompt, PromptStage
 | 
			
		||||
from authentik.stages.prompt.signals import password_validate
 | 
			
		||||
from authentik.stages.prompt.models import Prompt, PromptStage
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PromptStageForm(forms.ModelForm):
 | 
			
		||||
@ -47,111 +34,3 @@ class PromptAdminForm(forms.ModelForm):
 | 
			
		||||
            "label": forms.TextInput(),
 | 
			
		||||
            "placeholder": forms.TextInput(),
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ListPolicyEngine(PolicyEngine):
 | 
			
		||||
    """Slightly modified policy engine, which uses a list instead of a PolicyBindingModel"""
 | 
			
		||||
 | 
			
		||||
    __list: list[Policy]
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self, policies: list[Policy], user: User, request: HttpRequest = None
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        super().__init__(PolicyBindingModel(), user, request)
 | 
			
		||||
        self.__list = policies
 | 
			
		||||
        self.use_cache = False
 | 
			
		||||
 | 
			
		||||
    def _iter_bindings(self) -> Iterator[PolicyBinding]:
 | 
			
		||||
        for policy in self.__list:
 | 
			
		||||
            yield PolicyBinding(
 | 
			
		||||
                policy=policy,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PromptForm(forms.Form):
 | 
			
		||||
    """Dynamically created form based on PromptStage"""
 | 
			
		||||
 | 
			
		||||
    stage: PromptStage
 | 
			
		||||
    plan: FlowPlan
 | 
			
		||||
 | 
			
		||||
    def __init__(self, stage: PromptStage, plan: FlowPlan, *args, **kwargs):
 | 
			
		||||
        self.stage = stage
 | 
			
		||||
        self.plan = plan
 | 
			
		||||
        super().__init__(*args, **kwargs)
 | 
			
		||||
        # list() is called so we only load the fields once
 | 
			
		||||
        fields = list(self.stage.fields.all())
 | 
			
		||||
        for field in fields:
 | 
			
		||||
            field: Prompt
 | 
			
		||||
            self.fields[field.field_key] = field.field
 | 
			
		||||
            # Special handling for fields with username type
 | 
			
		||||
            # these check for existing users with the same username
 | 
			
		||||
            if field.type == FieldTypes.USERNAME:
 | 
			
		||||
                setattr(
 | 
			
		||||
                    self,
 | 
			
		||||
                    f"clean_{field.field_key}",
 | 
			
		||||
                    MethodType(username_field_cleaner_factory(field), self),
 | 
			
		||||
                )
 | 
			
		||||
            # Check if we have a password field, add a handler that sends a signal
 | 
			
		||||
            # to validate it
 | 
			
		||||
            if field.type == FieldTypes.PASSWORD:
 | 
			
		||||
                setattr(
 | 
			
		||||
                    self,
 | 
			
		||||
                    f"clean_{field.field_key}",
 | 
			
		||||
                    MethodType(password_single_cleaner_factory(field), self),
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        self.field_order = sorted(fields, key=lambda x: x.order)
 | 
			
		||||
 | 
			
		||||
    def _clean_password_fields(self, *field_names):
 | 
			
		||||
        """Check if the value of all password fields match by merging them into a set
 | 
			
		||||
        and checking the length"""
 | 
			
		||||
        all_passwords = {self.cleaned_data[x] for x in field_names}
 | 
			
		||||
        if len(all_passwords) > 1:
 | 
			
		||||
            raise forms.ValidationError(_("Passwords don't match."))
 | 
			
		||||
 | 
			
		||||
    def clean(self):
 | 
			
		||||
        cleaned_data = super().clean()
 | 
			
		||||
        if cleaned_data == {}:
 | 
			
		||||
            return {}
 | 
			
		||||
        # Check if we have two password fields, and make sure they are the same
 | 
			
		||||
        password_fields: QuerySet[Prompt] = self.stage.fields.filter(
 | 
			
		||||
            type=FieldTypes.PASSWORD
 | 
			
		||||
        )
 | 
			
		||||
        if password_fields.exists() and password_fields.count() == 2:
 | 
			
		||||
            self._clean_password_fields(*[field.field_key for field in password_fields])
 | 
			
		||||
 | 
			
		||||
        user = self.plan.context.get(PLAN_CONTEXT_PENDING_USER, get_anonymous_user())
 | 
			
		||||
        engine = ListPolicyEngine(self.stage.validation_policies.all(), user)
 | 
			
		||||
        engine.request.context = cleaned_data
 | 
			
		||||
        engine.build()
 | 
			
		||||
        result = engine.result
 | 
			
		||||
        if not result.passing:
 | 
			
		||||
            raise forms.ValidationError(list(result.messages))
 | 
			
		||||
        return cleaned_data
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def username_field_cleaner_factory(field: Prompt) -> Callable:
 | 
			
		||||
    """Return a `clean_` method for `field`. Clean method checks if username is taken already."""
 | 
			
		||||
 | 
			
		||||
    def username_field_cleaner(self: PromptForm) -> Any:
 | 
			
		||||
        """Check for duplicate usernames"""
 | 
			
		||||
        username = self.cleaned_data.get(field.field_key)
 | 
			
		||||
        if User.objects.filter(username=username).exists():
 | 
			
		||||
            raise forms.ValidationError("Username is already taken.")
 | 
			
		||||
        return username
 | 
			
		||||
 | 
			
		||||
    return username_field_cleaner
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def password_single_cleaner_factory(field: Prompt) -> Callable[[PromptForm], Any]:
 | 
			
		||||
    """Return a `clean_` method for `field`. Clean method checks if username is taken already."""
 | 
			
		||||
 | 
			
		||||
    def password_single_clean(self: PromptForm) -> Any:
 | 
			
		||||
        """Send password validation signals for e.g. LDAP Source"""
 | 
			
		||||
        password = self.cleaned_data[field.field_key]
 | 
			
		||||
        password_validate.send(
 | 
			
		||||
            sender=self, password=password, plan_context=self.plan.context
 | 
			
		||||
        )
 | 
			
		||||
        return password
 | 
			
		||||
 | 
			
		||||
    return password_single_clean
 | 
			
		||||
 | 
			
		||||
@ -2,17 +2,23 @@
 | 
			
		||||
from typing import Type
 | 
			
		||||
from uuid import uuid4
 | 
			
		||||
 | 
			
		||||
from django import forms
 | 
			
		||||
from django.db import models
 | 
			
		||||
from django.forms import ModelForm
 | 
			
		||||
from django.utils.translation import gettext_lazy as _
 | 
			
		||||
from django.views import View
 | 
			
		||||
from rest_framework.fields import (
 | 
			
		||||
    BooleanField,
 | 
			
		||||
    CharField,
 | 
			
		||||
    DateField,
 | 
			
		||||
    DateTimeField,
 | 
			
		||||
    EmailField,
 | 
			
		||||
    IntegerField,
 | 
			
		||||
)
 | 
			
		||||
from rest_framework.serializers import BaseSerializer
 | 
			
		||||
 | 
			
		||||
from authentik.flows.models import Stage
 | 
			
		||||
from authentik.lib.models import SerializerModel
 | 
			
		||||
from authentik.policies.models import Policy
 | 
			
		||||
from authentik.stages.prompt.widgets import HorizontalRuleWidget, StaticTextWidget
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FieldTypes(models.TextChoices):
 | 
			
		||||
@ -43,8 +49,8 @@ class FieldTypes(models.TextChoices):
 | 
			
		||||
    )
 | 
			
		||||
    NUMBER = "number"
 | 
			
		||||
    CHECKBOX = "checkbox"
 | 
			
		||||
    DATE = "data"
 | 
			
		||||
    DATE_TIME = "data-time"
 | 
			
		||||
    DATE = "date"
 | 
			
		||||
    DATE_TIME = "date-time"
 | 
			
		||||
 | 
			
		||||
    SEPARATOR = "separator", _("Separator: Static Separator Line")
 | 
			
		||||
    HIDDEN = "hidden", _("Hidden: Hidden field, can be used to insert data into form.")
 | 
			
		||||
@ -73,49 +79,34 @@ class Prompt(SerializerModel):
 | 
			
		||||
        return PromptSerializer
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def field(self):
 | 
			
		||||
        """Return instantiated form input field"""
 | 
			
		||||
        attrs = {"placeholder": _(self.placeholder)}
 | 
			
		||||
        field_class = forms.CharField
 | 
			
		||||
        widget = forms.TextInput(attrs=attrs)
 | 
			
		||||
    def field(self) -> CharField:
 | 
			
		||||
        """Get field type for Challenge and response"""
 | 
			
		||||
        field_class = CharField
 | 
			
		||||
        kwargs = {
 | 
			
		||||
            "label": _(self.label),
 | 
			
		||||
            "required": self.required,
 | 
			
		||||
        }
 | 
			
		||||
        if self.type == FieldTypes.EMAIL:
 | 
			
		||||
            field_class = forms.EmailField
 | 
			
		||||
        if self.type == FieldTypes.USERNAME:
 | 
			
		||||
            attrs["autocomplete"] = "username"
 | 
			
		||||
        if self.type == FieldTypes.PASSWORD:
 | 
			
		||||
            widget = forms.PasswordInput(attrs=attrs)
 | 
			
		||||
            attrs["autocomplete"] = "new-password"
 | 
			
		||||
            field_class = EmailField
 | 
			
		||||
        if self.type == FieldTypes.NUMBER:
 | 
			
		||||
            field_class = forms.IntegerField
 | 
			
		||||
            widget = forms.NumberInput(attrs=attrs)
 | 
			
		||||
            field_class = IntegerField
 | 
			
		||||
        # TODO: Hidden?
 | 
			
		||||
        if self.type == FieldTypes.HIDDEN:
 | 
			
		||||
            widget = forms.HiddenInput(attrs=attrs)
 | 
			
		||||
            kwargs["required"] = False
 | 
			
		||||
            kwargs["initial"] = self.placeholder
 | 
			
		||||
        if self.type == FieldTypes.CHECKBOX:
 | 
			
		||||
            field_class = forms.BooleanField
 | 
			
		||||
            field_class = BooleanField
 | 
			
		||||
            kwargs["required"] = False
 | 
			
		||||
        if self.type == FieldTypes.DATE:
 | 
			
		||||
            attrs["type"] = "date"
 | 
			
		||||
            widget = forms.DateInput(attrs=attrs)
 | 
			
		||||
            field_class = DateField
 | 
			
		||||
        if self.type == FieldTypes.DATE_TIME:
 | 
			
		||||
            attrs["type"] = "datetime-local"
 | 
			
		||||
            widget = forms.DateTimeInput(attrs=attrs)
 | 
			
		||||
            field_class = DateTimeField
 | 
			
		||||
        if self.type == FieldTypes.STATIC:
 | 
			
		||||
            widget = StaticTextWidget(attrs=attrs)
 | 
			
		||||
            kwargs["initial"] = self.placeholder
 | 
			
		||||
            kwargs["required"] = False
 | 
			
		||||
            kwargs["label"] = ""
 | 
			
		||||
        if self.type == FieldTypes.SEPARATOR:
 | 
			
		||||
            widget = HorizontalRuleWidget(attrs=attrs)
 | 
			
		||||
            kwargs["required"] = False
 | 
			
		||||
            kwargs["label"] = ""
 | 
			
		||||
 | 
			
		||||
        kwargs["widget"] = widget
 | 
			
		||||
        return field_class(**kwargs)
 | 
			
		||||
 | 
			
		||||
    def save(self, *args, **kwargs):
 | 
			
		||||
 | 
			
		||||
@ -1,36 +1,189 @@
 | 
			
		||||
"""Prompt Stage Logic"""
 | 
			
		||||
from django.http import HttpResponse
 | 
			
		||||
from email.policy import Policy
 | 
			
		||||
from types import MethodType
 | 
			
		||||
from typing import Any, Callable, Iterator
 | 
			
		||||
 | 
			
		||||
from django.db.models.base import Model
 | 
			
		||||
from django.db.models.query import QuerySet
 | 
			
		||||
from django.http import HttpRequest, HttpResponse
 | 
			
		||||
from django.http.request import QueryDict
 | 
			
		||||
from django.utils.translation import gettext_lazy as _
 | 
			
		||||
from django.views.generic import FormView
 | 
			
		||||
from guardian.shortcuts import get_anonymous_user
 | 
			
		||||
from rest_framework.fields import BooleanField, CharField, IntegerField
 | 
			
		||||
from rest_framework.serializers import Serializer, ValidationError
 | 
			
		||||
from structlog.stdlib import get_logger
 | 
			
		||||
 | 
			
		||||
from authentik.flows.stage import StageView
 | 
			
		||||
from authentik.stages.prompt.forms import PromptForm
 | 
			
		||||
from authentik.core.models import User
 | 
			
		||||
from authentik.flows.challenge import Challenge, ChallengeResponse, ChallengeTypes
 | 
			
		||||
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
 | 
			
		||||
from authentik.flows.stage import ChallengeStageView
 | 
			
		||||
from authentik.policies.engine import PolicyEngine
 | 
			
		||||
from authentik.policies.models import PolicyBinding, PolicyBindingModel
 | 
			
		||||
from authentik.stages.prompt.models import FieldTypes, Prompt, PromptStage
 | 
			
		||||
from authentik.stages.prompt.signals import password_validate
 | 
			
		||||
 | 
			
		||||
LOGGER = get_logger()
 | 
			
		||||
PLAN_CONTEXT_PROMPT = "prompt_data"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PromptStageView(FormView, StageView):
 | 
			
		||||
class PromptSerializer(Serializer):
 | 
			
		||||
    """Serializer for a single Prompt field"""
 | 
			
		||||
 | 
			
		||||
    field_key = CharField()
 | 
			
		||||
    label = CharField()
 | 
			
		||||
    type = CharField()
 | 
			
		||||
    required = BooleanField()
 | 
			
		||||
    placeholder = CharField()
 | 
			
		||||
    order = IntegerField()
 | 
			
		||||
 | 
			
		||||
    def create(self, validated_data: dict) -> Model:
 | 
			
		||||
        return Model()
 | 
			
		||||
 | 
			
		||||
    def update(self, instance: Model, validated_data: dict) -> Model:
 | 
			
		||||
        return Model()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PromptChallenge(Challenge):
 | 
			
		||||
    """Initial challenge being sent, define fields"""
 | 
			
		||||
 | 
			
		||||
    fields = PromptSerializer(many=True)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PromptResponseChallenge(ChallengeResponse):
 | 
			
		||||
    """Validate response, fields are dynamically created based
 | 
			
		||||
    on the stage"""
 | 
			
		||||
 | 
			
		||||
    def __init__(self, *args, stage: PromptStage, plan: FlowPlan, **kwargs):
 | 
			
		||||
        super().__init__(*args, **kwargs)
 | 
			
		||||
        self.stage = stage
 | 
			
		||||
        self.plan = plan
 | 
			
		||||
        # list() is called so we only load the fields once
 | 
			
		||||
        fields = list(self.stage.fields.all())
 | 
			
		||||
        for field in fields:
 | 
			
		||||
            field: Prompt
 | 
			
		||||
            self.fields[field.field_key] = field.field
 | 
			
		||||
            # Special handling for fields with username type
 | 
			
		||||
            # these check for existing users with the same username
 | 
			
		||||
            if field.type == FieldTypes.USERNAME:
 | 
			
		||||
                setattr(
 | 
			
		||||
                    self,
 | 
			
		||||
                    f"validate_{field.field_key}",
 | 
			
		||||
                    MethodType(username_field_validator_factory(), self),
 | 
			
		||||
                )
 | 
			
		||||
            # Check if we have a password field, add a handler that sends a signal
 | 
			
		||||
            # to validate it
 | 
			
		||||
            if field.type == FieldTypes.PASSWORD:
 | 
			
		||||
                setattr(
 | 
			
		||||
                    self,
 | 
			
		||||
                    f"validate_{field.field_key}",
 | 
			
		||||
                    MethodType(password_single_validator_factory(), self),
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        self.field_order = sorted(fields, key=lambda x: x.order)
 | 
			
		||||
 | 
			
		||||
    def _validate_password_fields(self, *field_names):
 | 
			
		||||
        """Check if the value of all password fields match by merging them into a set
 | 
			
		||||
        and checking the length"""
 | 
			
		||||
        all_passwords = {self.initial_data[x] for x in field_names}
 | 
			
		||||
        if len(all_passwords) > 1:
 | 
			
		||||
            raise ValidationError(_("Passwords don't match."))
 | 
			
		||||
 | 
			
		||||
    def validate(self, attrs):
 | 
			
		||||
        if attrs == {}:
 | 
			
		||||
            return {}
 | 
			
		||||
        # Check if we have two password fields, and make sure they are the same
 | 
			
		||||
        password_fields: QuerySet[Prompt] = self.stage.fields.filter(
 | 
			
		||||
            type=FieldTypes.PASSWORD
 | 
			
		||||
        )
 | 
			
		||||
        if password_fields.exists() and password_fields.count() == 2:
 | 
			
		||||
            self._validate_password_fields(
 | 
			
		||||
                *[field.field_key for field in password_fields]
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        user = self.plan.context.get(PLAN_CONTEXT_PENDING_USER, get_anonymous_user())
 | 
			
		||||
        engine = ListPolicyEngine(self.stage.validation_policies.all(), user)
 | 
			
		||||
        engine.request.context = attrs
 | 
			
		||||
        engine.build()
 | 
			
		||||
        result = engine.result
 | 
			
		||||
        if not result.passing:
 | 
			
		||||
            raise ValidationError(list(result.messages))
 | 
			
		||||
        return attrs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def username_field_validator_factory() -> Callable[[PromptChallenge, str], Any]:
 | 
			
		||||
    """Return a `clean_` method for `field`. Clean method checks if username is taken already."""
 | 
			
		||||
 | 
			
		||||
    # pylint: disable=unused-argument
 | 
			
		||||
    def username_field_validator(self: PromptChallenge, value: str) -> Any:
 | 
			
		||||
        """Check for duplicate usernames"""
 | 
			
		||||
        if User.objects.filter(username=value).exists():
 | 
			
		||||
            raise ValidationError("Username is already taken.")
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
    return username_field_validator
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def password_single_validator_factory() -> Callable[[PromptChallenge, str], Any]:
 | 
			
		||||
    """Return a `clean_` method for `field`. Clean method checks if username is taken already."""
 | 
			
		||||
 | 
			
		||||
    def password_single_clean(self: PromptChallenge, value: str) -> Any:
 | 
			
		||||
        """Send password validation signals for e.g. LDAP Source"""
 | 
			
		||||
        password_validate.send(
 | 
			
		||||
            sender=self, password=value, plan_context=self.plan.context
 | 
			
		||||
        )
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
    return password_single_clean
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ListPolicyEngine(PolicyEngine):
 | 
			
		||||
    """Slightly modified policy engine, which uses a list instead of a PolicyBindingModel"""
 | 
			
		||||
 | 
			
		||||
    __list: list[Policy]
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self, policies: list[Policy], user: User, request: HttpRequest = None
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        super().__init__(PolicyBindingModel(), user, request)
 | 
			
		||||
        self.__list = policies
 | 
			
		||||
        self.use_cache = False
 | 
			
		||||
 | 
			
		||||
    def _iter_bindings(self) -> Iterator[PolicyBinding]:
 | 
			
		||||
        for policy in self.__list:
 | 
			
		||||
            yield PolicyBinding(
 | 
			
		||||
                policy=policy,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PromptStageView(ChallengeStageView):
 | 
			
		||||
    """Prompt Stage, save form data in plan context."""
 | 
			
		||||
 | 
			
		||||
    template_name = "login/form.html"
 | 
			
		||||
    form_class = PromptForm
 | 
			
		||||
    response_class = PromptResponseChallenge
 | 
			
		||||
 | 
			
		||||
    def get_context_data(self, **kwargs):
 | 
			
		||||
        ctx = super().get_context_data(**kwargs)
 | 
			
		||||
        ctx["title"] = _(self.executor.current_stage.name)
 | 
			
		||||
        return ctx
 | 
			
		||||
    def get_challenge(self, *args, **kwargs) -> Challenge:
 | 
			
		||||
        fields = list(self.executor.current_stage.fields.all())
 | 
			
		||||
        challenge = PromptChallenge(
 | 
			
		||||
            data={
 | 
			
		||||
                "type": ChallengeTypes.native,
 | 
			
		||||
                "component": "ak-stage-prompt",
 | 
			
		||||
                "fields": [PromptSerializer(field).data for field in fields],
 | 
			
		||||
            },
 | 
			
		||||
        )
 | 
			
		||||
        return challenge
 | 
			
		||||
 | 
			
		||||
    def get_form_kwargs(self):
 | 
			
		||||
        kwargs = super().get_form_kwargs()
 | 
			
		||||
        kwargs["stage"] = self.executor.current_stage
 | 
			
		||||
        kwargs["plan"] = self.executor.plan
 | 
			
		||||
        return kwargs
 | 
			
		||||
    def get_response_instance(self, data: QueryDict) -> ChallengeResponse:
 | 
			
		||||
        if not self.executor.plan:
 | 
			
		||||
            raise ValueError
 | 
			
		||||
        return PromptResponseChallenge(
 | 
			
		||||
            instance=None,
 | 
			
		||||
            data=data,
 | 
			
		||||
            stage=self.executor.current_stage,
 | 
			
		||||
            plan=self.executor.plan,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def form_valid(self, form: PromptForm) -> HttpResponse:
 | 
			
		||||
        """Form data is valid"""
 | 
			
		||||
    def challenge_valid(self, response: ChallengeResponse) -> HttpResponse:
 | 
			
		||||
        if PLAN_CONTEXT_PROMPT not in self.executor.plan.context:
 | 
			
		||||
            self.executor.plan.context[PLAN_CONTEXT_PROMPT] = {}
 | 
			
		||||
        self.executor.plan.context[PLAN_CONTEXT_PROMPT].update(form.cleaned_data)
 | 
			
		||||
        self.executor.plan.context[PLAN_CONTEXT_PROMPT].update(response.validated_data)
 | 
			
		||||
        print(self.executor.plan.context[PLAN_CONTEXT_PROMPT])
 | 
			
		||||
        return self.executor.stage_ok()
 | 
			
		||||
 | 
			
		||||
@ -11,9 +11,8 @@ from authentik.flows.models import Flow, FlowDesignation, FlowStageBinding
 | 
			
		||||
from authentik.flows.planner import FlowPlan
 | 
			
		||||
from authentik.flows.views import SESSION_KEY_PLAN
 | 
			
		||||
from authentik.policies.expression.models import ExpressionPolicy
 | 
			
		||||
from authentik.stages.prompt.forms import PromptForm
 | 
			
		||||
from authentik.stages.prompt.models import FieldTypes, Prompt, PromptStage
 | 
			
		||||
from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT
 | 
			
		||||
from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT, PromptResponseChallenge
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestPromptStage(TestCase):
 | 
			
		||||
@ -112,8 +111,8 @@ class TestPromptStage(TestCase):
 | 
			
		||||
            self.assertIn(prompt.label, force_str(response.content))
 | 
			
		||||
            self.assertIn(prompt.placeholder, force_str(response.content))
 | 
			
		||||
 | 
			
		||||
    def test_valid_form_with_policy(self) -> PromptForm:
 | 
			
		||||
        """Test form validation"""
 | 
			
		||||
    def test_valid_challenge_with_policy(self) -> PromptResponseChallenge:
 | 
			
		||||
        """Test challenge_response validation"""
 | 
			
		||||
        plan = FlowPlan(
 | 
			
		||||
            flow_pk=self.flow.pk.hex, stages=[self.stage], markers=[StageMarker()]
 | 
			
		||||
        )
 | 
			
		||||
@ -123,12 +122,14 @@ class TestPromptStage(TestCase):
 | 
			
		||||
        )
 | 
			
		||||
        self.stage.validation_policies.set([expr_policy])
 | 
			
		||||
        self.stage.save()
 | 
			
		||||
        form = PromptForm(stage=self.stage, plan=plan, data=self.prompt_data)
 | 
			
		||||
        self.assertEqual(form.is_valid(), True)
 | 
			
		||||
        return form
 | 
			
		||||
        challenge_response = PromptResponseChallenge(
 | 
			
		||||
            None, stage=self.stage, plan=plan, data=self.prompt_data
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(challenge_response.is_valid(), True)
 | 
			
		||||
        return challenge_response
 | 
			
		||||
 | 
			
		||||
    def test_invalid_form(self) -> PromptForm:
 | 
			
		||||
        """Test form validation"""
 | 
			
		||||
    def test_invalid_challenge(self) -> PromptResponseChallenge:
 | 
			
		||||
        """Test challenge_response validation"""
 | 
			
		||||
        plan = FlowPlan(
 | 
			
		||||
            flow_pk=self.flow.pk.hex, stages=[self.stage], markers=[StageMarker()]
 | 
			
		||||
        )
 | 
			
		||||
@ -138,12 +139,14 @@ class TestPromptStage(TestCase):
 | 
			
		||||
        )
 | 
			
		||||
        self.stage.validation_policies.set([expr_policy])
 | 
			
		||||
        self.stage.save()
 | 
			
		||||
        form = PromptForm(stage=self.stage, plan=plan, data=self.prompt_data)
 | 
			
		||||
        self.assertEqual(form.is_valid(), False)
 | 
			
		||||
        return form
 | 
			
		||||
        challenge_response = PromptResponseChallenge(
 | 
			
		||||
            None, stage=self.stage, plan=plan, data=self.prompt_data
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(challenge_response.is_valid(), False)
 | 
			
		||||
        return challenge_response
 | 
			
		||||
 | 
			
		||||
    def test_valid_form_request(self):
 | 
			
		||||
        """Test a request with valid form data"""
 | 
			
		||||
    def test_valid_challenge_request(self):
 | 
			
		||||
        """Test a request with valid challenge_response data"""
 | 
			
		||||
        plan = FlowPlan(
 | 
			
		||||
            flow_pk=self.flow.pk.hex, stages=[self.stage], markers=[StageMarker()]
 | 
			
		||||
        )
 | 
			
		||||
@ -151,7 +154,7 @@ class TestPromptStage(TestCase):
 | 
			
		||||
        session[SESSION_KEY_PLAN] = plan
 | 
			
		||||
        session.save()
 | 
			
		||||
 | 
			
		||||
        form = self.test_valid_form_with_policy()
 | 
			
		||||
        challenge_response = self.test_valid_challenge_with_policy()
 | 
			
		||||
 | 
			
		||||
        with patch("authentik.flows.views.FlowExecutorView.cancel", MagicMock()):
 | 
			
		||||
            response = self.client.post(
 | 
			
		||||
@ -159,7 +162,7 @@ class TestPromptStage(TestCase):
 | 
			
		||||
                    "authentik_api:flow-executor",
 | 
			
		||||
                    kwargs={"flow_slug": self.flow.slug},
 | 
			
		||||
                ),
 | 
			
		||||
                form.cleaned_data,
 | 
			
		||||
                challenge_response.validated_data,
 | 
			
		||||
            )
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        self.assertJSONEqual(
 | 
			
		||||
 | 
			
		||||
@ -1,17 +0,0 @@
 | 
			
		||||
"""Prompt Widgets"""
 | 
			
		||||
from django import forms
 | 
			
		||||
from django.utils.safestring import mark_safe
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class StaticTextWidget(forms.widgets.Widget):
 | 
			
		||||
    """Widget to render static text"""
 | 
			
		||||
 | 
			
		||||
    def render(self, name, value, attrs=None, renderer=None):
 | 
			
		||||
        return mark_safe(f"<p>{value}</p>")  # nosec
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HorizontalRuleWidget(forms.widgets.Widget):
 | 
			
		||||
    """Widget, which renders an <hr> element"""
 | 
			
		||||
 | 
			
		||||
    def render(self, name, value, attrs=None, renderer=None):
 | 
			
		||||
        return mark_safe("<hr>")  # nosec
 | 
			
		||||
@ -153,7 +153,7 @@ class SeleniumTestCase(StaticLiveServerTestCase):
 | 
			
		||||
        ObjectManager().run()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def retry(max_retires=3, exceptions=None):
 | 
			
		||||
def retry(max_retires=1, exceptions=None):
 | 
			
		||||
    """Retry test multiple times. Default to catching Selenium Timeout Exception"""
 | 
			
		||||
 | 
			
		||||
    if not exceptions:
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										144
									
								
								web/src/elements/stages/prompt/PromptStage.ts
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										144
									
								
								web/src/elements/stages/prompt/PromptStage.ts
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,144 @@
 | 
			
		||||
import { gettext } from "django";
 | 
			
		||||
import { CSSResult, customElement, html, property, TemplateResult } from "lit-element";
 | 
			
		||||
import { Challenge } from "../../../api/Flows";
 | 
			
		||||
import { COMMON_STYLES } from "../../../common/styles";
 | 
			
		||||
import { BaseStage } from "../base";
 | 
			
		||||
 | 
			
		||||
export interface Prompt {
 | 
			
		||||
    field_key: string;
 | 
			
		||||
    label: string;
 | 
			
		||||
    type: string;
 | 
			
		||||
    required: boolean;
 | 
			
		||||
    placeholder: string;
 | 
			
		||||
    order: number;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export interface PromptChallenge extends Challenge {
 | 
			
		||||
    fields: Prompt[];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@customElement("ak-stage-prompt")
 | 
			
		||||
export class PromptStage extends BaseStage {
 | 
			
		||||
 | 
			
		||||
    @property({attribute: false})
 | 
			
		||||
    challenge?: PromptChallenge;
 | 
			
		||||
 | 
			
		||||
    static get styles(): CSSResult[] {
 | 
			
		||||
        return COMMON_STYLES;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    renderPromptInner(prompt: Prompt): TemplateResult {
 | 
			
		||||
        switch (prompt.type) {
 | 
			
		||||
            case "text":
 | 
			
		||||
                return html`<input
 | 
			
		||||
                    type="text"
 | 
			
		||||
                    name="${prompt.field_key}"
 | 
			
		||||
                    placeholder="${prompt.placeholder}"
 | 
			
		||||
                    autocomplete="off"
 | 
			
		||||
                    class="pf-c-form-control"
 | 
			
		||||
                    ?required=${prompt.required}
 | 
			
		||||
                    value="">`;
 | 
			
		||||
            case "username":
 | 
			
		||||
                return html`<input
 | 
			
		||||
                    type="text"
 | 
			
		||||
                    name="${prompt.field_key}"
 | 
			
		||||
                    placeholder="${prompt.placeholder}"
 | 
			
		||||
                    autocomplete="username"
 | 
			
		||||
                    class="pf-c-form-control"
 | 
			
		||||
                    ?required=${prompt.required}
 | 
			
		||||
                    value="">`;
 | 
			
		||||
            case "email":
 | 
			
		||||
                return html`<input
 | 
			
		||||
                    type="email"
 | 
			
		||||
                    name="${prompt.field_key}"
 | 
			
		||||
                    placeholder="${prompt.placeholder}"
 | 
			
		||||
                    class="pf-c-form-control"
 | 
			
		||||
                    ?required=${prompt.required}
 | 
			
		||||
                    value="">`;
 | 
			
		||||
            case "password":
 | 
			
		||||
                return html`<input
 | 
			
		||||
                    type="password"
 | 
			
		||||
                    name="${prompt.field_key}"
 | 
			
		||||
                    placeholder="${prompt.placeholder}"
 | 
			
		||||
                    autocomplete="new-password"
 | 
			
		||||
                    class="pf-c-form-control"
 | 
			
		||||
                    ?required=${prompt.required}>`;
 | 
			
		||||
            case "number":
 | 
			
		||||
                return html`<input
 | 
			
		||||
                    type="number"
 | 
			
		||||
                    name="${prompt.field_key}"
 | 
			
		||||
                    placeholder="${prompt.placeholder}"
 | 
			
		||||
                    class="pf-c-form-control"
 | 
			
		||||
                    ?required=${prompt.required}>`;
 | 
			
		||||
            case "checkbox":
 | 
			
		||||
                return html`<input
 | 
			
		||||
                    type="checkbox"
 | 
			
		||||
                    name="${prompt.field_key}"
 | 
			
		||||
                    placeholder="${prompt.placeholder}"
 | 
			
		||||
                    class="pf-c-form-control"
 | 
			
		||||
                    ?required=${prompt.required}>`;
 | 
			
		||||
            case "date":
 | 
			
		||||
                return html`<input
 | 
			
		||||
                    type="date"
 | 
			
		||||
                    name="${prompt.field_key}"
 | 
			
		||||
                    placeholder="${prompt.placeholder}"
 | 
			
		||||
                    class="pf-c-form-control"
 | 
			
		||||
                    ?required=${prompt.required}>`;
 | 
			
		||||
            case "date-time":
 | 
			
		||||
                return html`<input
 | 
			
		||||
                    type="datetime"
 | 
			
		||||
                    name="${prompt.field_key}"
 | 
			
		||||
                    placeholder="${prompt.placeholder}"
 | 
			
		||||
                    class="pf-c-form-control"
 | 
			
		||||
                    ?required=${prompt.required}>`;
 | 
			
		||||
            case "separator":
 | 
			
		||||
                return html`<hr>`;
 | 
			
		||||
            case "hidden":
 | 
			
		||||
                return html`<input
 | 
			
		||||
                    type="hidden"
 | 
			
		||||
                    name="${prompt.field_key}"
 | 
			
		||||
                    value="${prompt.placeholder}"
 | 
			
		||||
                    class="pf-c-form-control"
 | 
			
		||||
                    ?required=${prompt.required}>`;
 | 
			
		||||
            case "static":
 | 
			
		||||
                return html`<p
 | 
			
		||||
                    class="pf-c-form-control">${prompt.placeholder}
 | 
			
		||||
                </p>`;
 | 
			
		||||
        }
 | 
			
		||||
        return html``;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    render(): TemplateResult {
 | 
			
		||||
        if (!this.challenge) {
 | 
			
		||||
            return html`<ak-loading-state></ak-loading-state>`;
 | 
			
		||||
        }
 | 
			
		||||
        return html`<header class="pf-c-login__main-header">
 | 
			
		||||
                <h1 class="pf-c-title pf-m-3xl">
 | 
			
		||||
                    ${this.challenge.title}
 | 
			
		||||
                </h1>
 | 
			
		||||
            </header>
 | 
			
		||||
            <div class="pf-c-login__main-body">
 | 
			
		||||
                <form class="pf-c-form" @submit=${(e: Event) => {this.submit(e);}}>
 | 
			
		||||
                    ${this.challenge.fields.map((prompt) => {
 | 
			
		||||
                        return html`<ak-form-element
 | 
			
		||||
                            label="${prompt.label}"
 | 
			
		||||
                            ?required="${prompt.required}"
 | 
			
		||||
                            class="pf-c-form__group"
 | 
			
		||||
                            .errors=${(this.challenge?.response_errors || {})[prompt.field_key]}>
 | 
			
		||||
                            ${this.renderPromptInner(prompt)}
 | 
			
		||||
                        </ak-form-element>`;
 | 
			
		||||
                    })}
 | 
			
		||||
                    <div class="pf-c-form__group pf-m-action">
 | 
			
		||||
                        <button type="submit" class="pf-c-button pf-m-primary pf-m-block">
 | 
			
		||||
                            ${gettext("Continue")}
 | 
			
		||||
                        </button>
 | 
			
		||||
                    </div>
 | 
			
		||||
                </form>
 | 
			
		||||
            </div>
 | 
			
		||||
            <footer class="pf-c-login__main-footer">
 | 
			
		||||
                <ul class="pf-c-login__main-footer-links">
 | 
			
		||||
                </ul>
 | 
			
		||||
            </footer>`;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
@ -7,6 +7,7 @@ import "../../elements/stages/password/PasswordStage";
 | 
			
		||||
import "../../elements/stages/consent/ConsentStage";
 | 
			
		||||
import "../../elements/stages/email/EmailStage";
 | 
			
		||||
import "../../elements/stages/autosubmit/AutosubmitStage";
 | 
			
		||||
import "../../elements/stages/prompt/PromptStage";
 | 
			
		||||
import { ShellChallenge, Challenge, ChallengeTypes, Flow, RedirectChallenge } from "../../api/Flows";
 | 
			
		||||
import { DefaultClient } from "../../api/Client";
 | 
			
		||||
import { IdentificationChallenge } from "../../elements/stages/identification/IdentificationStage";
 | 
			
		||||
@ -14,6 +15,7 @@ import { PasswordChallenge } from "../../elements/stages/password/PasswordStage"
 | 
			
		||||
import { ConsentChallenge } from "../../elements/stages/consent/ConsentStage";
 | 
			
		||||
import { EmailChallenge } from "../../elements/stages/email/EmailStage";
 | 
			
		||||
import { AutosubmitChallenge } from "../../elements/stages/autosubmit/AutosubmitStage";
 | 
			
		||||
import { PromptChallenge } from "../../elements/stages/prompt/PromptStage";
 | 
			
		||||
 | 
			
		||||
@customElement("ak-flow-executor")
 | 
			
		||||
export class FlowExecutor extends LitElement {
 | 
			
		||||
@ -120,6 +122,8 @@ export class FlowExecutor extends LitElement {
 | 
			
		||||
                    return html`<ak-stage-email .host=${this} .challenge=${this.challenge as EmailChallenge}></ak-stage-email>`;
 | 
			
		||||
                case "ak-stage-autosubmit":
 | 
			
		||||
                    return html`<ak-stage-autosubmit .host=${this} .challenge=${this.challenge as AutosubmitChallenge}></ak-stage-autosubmit>`;
 | 
			
		||||
                case "ak-stage-prompt":
 | 
			
		||||
                    return html`<ak-stage-prompt .host=${this} .challenge=${this.challenge as PromptChallenge}></ak-stage-prompt>`;
 | 
			
		||||
                default:
 | 
			
		||||
                    break;
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user