*: migrate ui_* properties to functions to allow context being passed
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
		@ -104,14 +104,14 @@ class SourceViewSet(
 | 
			
		||||
        )
 | 
			
		||||
        matching_sources: list[UserSettingSerializer] = []
 | 
			
		||||
        for source in _all_sources:
 | 
			
		||||
            user_settings = source.ui_user_settings
 | 
			
		||||
            user_settings = source.ui_user_settings()
 | 
			
		||||
            if not user_settings:
 | 
			
		||||
                continue
 | 
			
		||||
            policy_engine = PolicyEngine(source, request.user, request)
 | 
			
		||||
            policy_engine.build()
 | 
			
		||||
            if not policy_engine.passing:
 | 
			
		||||
                continue
 | 
			
		||||
            source_settings = source.ui_user_settings
 | 
			
		||||
            source_settings = source.ui_user_settings()
 | 
			
		||||
            source_settings.initial_data["object_uid"] = source.slug
 | 
			
		||||
            if not source_settings.is_valid():
 | 
			
		||||
                LOGGER.warning(source_settings.errors)
 | 
			
		||||
 | 
			
		||||
@ -359,13 +359,11 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel):
 | 
			
		||||
        """Return component used to edit this object"""
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def ui_login_button(self) -> Optional[UILoginButton]:
 | 
			
		||||
    def ui_login_button(self, request: HttpRequest) -> Optional[UILoginButton]:
 | 
			
		||||
        """If source uses a http-based flow, return UI Information about the login
 | 
			
		||||
        button. If source doesn't use http-based flow, return None."""
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def ui_user_settings(self) -> Optional[UserSettingSerializer]:
 | 
			
		||||
        """Entrypoint to integrate with User settings. Can either return None if no
 | 
			
		||||
        user settings are available, or UserSettingSerializer."""
 | 
			
		||||
 | 
			
		||||
@ -2,7 +2,7 @@
 | 
			
		||||
from time import sleep
 | 
			
		||||
from typing import Callable, Type
 | 
			
		||||
 | 
			
		||||
from django.test import TestCase
 | 
			
		||||
from django.test import RequestFactory, TestCase
 | 
			
		||||
from django.utils.timezone import now
 | 
			
		||||
from guardian.shortcuts import get_anonymous_user
 | 
			
		||||
 | 
			
		||||
@ -30,6 +30,9 @@ class TestModels(TestCase):
 | 
			
		||||
def source_tester_factory(test_model: Type[Stage]) -> Callable:
 | 
			
		||||
    """Test source"""
 | 
			
		||||
 | 
			
		||||
    factory = RequestFactory()
 | 
			
		||||
    request = factory.get("/")
 | 
			
		||||
 | 
			
		||||
    def tester(self: TestModels):
 | 
			
		||||
        model_class = None
 | 
			
		||||
        if test_model._meta.abstract:
 | 
			
		||||
@ -38,8 +41,8 @@ def source_tester_factory(test_model: Type[Stage]) -> Callable:
 | 
			
		||||
            model_class = test_model()
 | 
			
		||||
        model_class.slug = "test"
 | 
			
		||||
        self.assertIsNotNone(model_class.component)
 | 
			
		||||
        _ = model_class.ui_login_button
 | 
			
		||||
        _ = model_class.ui_user_settings
 | 
			
		||||
        _ = model_class.ui_login_button(request)
 | 
			
		||||
        _ = model_class.ui_user_settings()
 | 
			
		||||
 | 
			
		||||
    return tester
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -90,7 +90,7 @@ class StageViewSet(
 | 
			
		||||
            stages += list(configurable_stage.objects.all().order_by("name"))
 | 
			
		||||
        matching_stages: list[dict] = []
 | 
			
		||||
        for stage in stages:
 | 
			
		||||
            user_settings = stage.ui_user_settings
 | 
			
		||||
            user_settings = stage.ui_user_settings()
 | 
			
		||||
            if not user_settings:
 | 
			
		||||
                continue
 | 
			
		||||
            user_settings.initial_data["object_uid"] = str(stage.pk)
 | 
			
		||||
 | 
			
		||||
@ -75,7 +75,6 @@ class Stage(SerializerModel):
 | 
			
		||||
        """Return component used to edit this object"""
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def ui_user_settings(self) -> Optional[UserSettingSerializer]:
 | 
			
		||||
        """Entrypoint to integrate with User settings. Can either return None if no
 | 
			
		||||
        user settings are available, or a challenge."""
 | 
			
		||||
 | 
			
		||||
@ -32,7 +32,7 @@ class TestFlowsAPI(APITestCase):
 | 
			
		||||
 | 
			
		||||
    def test_models(self):
 | 
			
		||||
        """Test that ui_user_settings returns none"""
 | 
			
		||||
        self.assertIsNone(Stage().ui_user_settings)
 | 
			
		||||
        self.assertIsNone(Stage().ui_user_settings())
 | 
			
		||||
 | 
			
		||||
    def test_api_serializer(self):
 | 
			
		||||
        """Test that stage serializer returns the correct type"""
 | 
			
		||||
 | 
			
		||||
@ -23,7 +23,7 @@ def model_tester_factory(test_model: Type[Stage]) -> Callable:
 | 
			
		||||
            model_class = test_model()
 | 
			
		||||
        self.assertTrue(issubclass(model_class.type, StageView))
 | 
			
		||||
        self.assertIsNotNone(test_model.component)
 | 
			
		||||
        _ = model_class.ui_user_settings
 | 
			
		||||
        _ = model_class.ui_user_settings()
 | 
			
		||||
 | 
			
		||||
    return tester
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -2,6 +2,7 @@
 | 
			
		||||
from typing import TYPE_CHECKING, Optional, Type
 | 
			
		||||
 | 
			
		||||
from django.db import models
 | 
			
		||||
from django.http.request import HttpRequest
 | 
			
		||||
from django.urls import reverse
 | 
			
		||||
from django.utils.translation import gettext_lazy as _
 | 
			
		||||
from rest_framework.serializers import Serializer
 | 
			
		||||
@ -63,11 +64,15 @@ class OAuthSource(Source):
 | 
			
		||||
 | 
			
		||||
        return OAuthSourceSerializer
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def ui_login_button(self) -> UILoginButton:
 | 
			
		||||
        return self.type().ui_login_button()
 | 
			
		||||
    def ui_login_button(self, request: HttpRequest) -> UILoginButton:
 | 
			
		||||
        provider_type = self.type
 | 
			
		||||
        provider = provider_type()
 | 
			
		||||
        return UILoginButton(
 | 
			
		||||
            name=self.name,
 | 
			
		||||
            icon_url=provider.icon_url(),
 | 
			
		||||
            challenge=provider.login_challenge(self, request),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def ui_user_settings(self) -> Optional[UserSettingSerializer]:
 | 
			
		||||
        return UserSettingSerializer(
 | 
			
		||||
            data={
 | 
			
		||||
 | 
			
		||||
@ -2,12 +2,13 @@
 | 
			
		||||
from enum import Enum
 | 
			
		||||
from typing import Callable, Optional, Type
 | 
			
		||||
 | 
			
		||||
from django.http.request import HttpRequest
 | 
			
		||||
from django.templatetags.static import static
 | 
			
		||||
from django.urls.base import reverse
 | 
			
		||||
from structlog.stdlib import get_logger
 | 
			
		||||
 | 
			
		||||
from authentik.core.types import UILoginButton
 | 
			
		||||
from authentik.flows.challenge import ChallengeTypes, RedirectChallenge
 | 
			
		||||
from authentik.flows.challenge import Challenge, ChallengeTypes, RedirectChallenge
 | 
			
		||||
from authentik.sources.oauth.models import OAuthSource
 | 
			
		||||
from authentik.sources.oauth.views.callback import OAuthCallback
 | 
			
		||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
 | 
			
		||||
 | 
			
		||||
@ -40,20 +41,17 @@ class SourceType:
 | 
			
		||||
        """Get Icon URL for login"""
 | 
			
		||||
        return static(f"authentik/sources/{self.slug}.svg")
 | 
			
		||||
 | 
			
		||||
    def ui_login_button(self) -> UILoginButton:
 | 
			
		||||
    # pylint: disable=unused-argument
 | 
			
		||||
    def login_challenge(self, source: OAuthSource, request: HttpRequest) -> Challenge:
 | 
			
		||||
        """Allow types to return custom challenges"""
 | 
			
		||||
        return UILoginButton(
 | 
			
		||||
            challenge=RedirectChallenge(
 | 
			
		||||
                instance={
 | 
			
		||||
                    "type": ChallengeTypes.REDIRECT.value,
 | 
			
		||||
                    "to": reverse(
 | 
			
		||||
                        "authentik_sources_oauth:oauth-client-login",
 | 
			
		||||
                        kwargs={"source_slug": self.slug},
 | 
			
		||||
                    ),
 | 
			
		||||
                }
 | 
			
		||||
            ),
 | 
			
		||||
            icon_url=self.icon_url(),
 | 
			
		||||
            name=self.name,
 | 
			
		||||
        return RedirectChallenge(
 | 
			
		||||
            instance={
 | 
			
		||||
                "type": ChallengeTypes.REDIRECT.value,
 | 
			
		||||
                "to": reverse(
 | 
			
		||||
                    "authentik_sources_oauth:oauth-client-login",
 | 
			
		||||
                    kwargs={"source_slug": self.slug},
 | 
			
		||||
                ),
 | 
			
		||||
            }
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -3,6 +3,7 @@ from typing import Optional
 | 
			
		||||
 | 
			
		||||
from django.contrib.postgres.fields import ArrayField
 | 
			
		||||
from django.db import models
 | 
			
		||||
from django.http.request import HttpRequest
 | 
			
		||||
from django.templatetags.static import static
 | 
			
		||||
from django.utils.translation import gettext_lazy as _
 | 
			
		||||
from rest_framework.fields import CharField
 | 
			
		||||
@ -62,8 +63,7 @@ class PlexSource(Source):
 | 
			
		||||
 | 
			
		||||
        return PlexSourceSerializer
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def ui_login_button(self) -> UILoginButton:
 | 
			
		||||
    def ui_login_button(self, request: HttpRequest) -> UILoginButton:
 | 
			
		||||
        return UILoginButton(
 | 
			
		||||
            challenge=PlexAuthenticationChallenge(
 | 
			
		||||
                {
 | 
			
		||||
@ -77,7 +77,6 @@ class PlexSource(Source):
 | 
			
		||||
            name=self.name,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def ui_user_settings(self) -> Optional[UserSettingSerializer]:
 | 
			
		||||
        return UserSettingSerializer(
 | 
			
		||||
            data={
 | 
			
		||||
 | 
			
		||||
@ -167,8 +167,7 @@ class SAMLSource(Source):
 | 
			
		||||
            reverse(f"authentik_sources_saml:{view}", kwargs={"source_slug": self.slug})
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def ui_login_button(self) -> UILoginButton:
 | 
			
		||||
    def ui_login_button(self, request: HttpRequest) -> UILoginButton:
 | 
			
		||||
        return UILoginButton(
 | 
			
		||||
            challenge=RedirectChallenge(
 | 
			
		||||
                instance={
 | 
			
		||||
 | 
			
		||||
@ -48,7 +48,6 @@ class AuthenticatorDuoStage(ConfigurableStage, Stage):
 | 
			
		||||
    def component(self) -> str:
 | 
			
		||||
        return "ak-stage-authenticator-duo-form"
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def ui_user_settings(self) -> Optional[UserSettingSerializer]:
 | 
			
		||||
        return UserSettingSerializer(
 | 
			
		||||
            data={
 | 
			
		||||
 | 
			
		||||
@ -141,7 +141,6 @@ class AuthenticatorSMSStage(ConfigurableStage, Stage):
 | 
			
		||||
    def component(self) -> str:
 | 
			
		||||
        return "ak-stage-authenticator-sms-form"
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def ui_user_settings(self) -> Optional[UserSettingSerializer]:
 | 
			
		||||
        return UserSettingSerializer(
 | 
			
		||||
            data={
 | 
			
		||||
 | 
			
		||||
@ -31,7 +31,6 @@ class AuthenticatorStaticStage(ConfigurableStage, Stage):
 | 
			
		||||
    def component(self) -> str:
 | 
			
		||||
        return "ak-stage-authenticator-static-form"
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def ui_user_settings(self) -> Optional[UserSettingSerializer]:
 | 
			
		||||
        return UserSettingSerializer(
 | 
			
		||||
            data={
 | 
			
		||||
 | 
			
		||||
@ -38,7 +38,6 @@ class AuthenticatorTOTPStage(ConfigurableStage, Stage):
 | 
			
		||||
    def component(self) -> str:
 | 
			
		||||
        return "ak-stage-authenticator-totp-form"
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def ui_user_settings(self) -> Optional[UserSettingSerializer]:
 | 
			
		||||
        return UserSettingSerializer(
 | 
			
		||||
            data={
 | 
			
		||||
 | 
			
		||||
@ -34,7 +34,6 @@ class AuthenticateWebAuthnStage(ConfigurableStage, Stage):
 | 
			
		||||
    def component(self) -> str:
 | 
			
		||||
        return "ak-stage-authenticator-webauthn-form"
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def ui_user_settings(self) -> Optional[UserSettingSerializer]:
 | 
			
		||||
        return UserSettingSerializer(
 | 
			
		||||
            data={
 | 
			
		||||
 | 
			
		||||
@ -191,7 +191,7 @@ class IdentificationStageView(ChallengeStageView):
 | 
			
		||||
            current_stage.sources.filter(enabled=True).order_by("name").select_subclasses()
 | 
			
		||||
        )
 | 
			
		||||
        for source in sources:
 | 
			
		||||
            ui_login_button = source.ui_login_button
 | 
			
		||||
            ui_login_button = source.ui_login_button(self.request)
 | 
			
		||||
            if ui_login_button:
 | 
			
		||||
                button = asdict(ui_login_button)
 | 
			
		||||
                button["challenge"] = ui_login_button.challenge.data
 | 
			
		||||
 | 
			
		||||
@ -63,7 +63,6 @@ class PasswordStage(ConfigurableStage, Stage):
 | 
			
		||||
    def component(self) -> str:
 | 
			
		||||
        return "ak-stage-password-form"
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def ui_user_settings(self) -> Optional[UserSettingSerializer]:
 | 
			
		||||
        if not self.configure_flow:
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user