providers/oauth2: launch url: if URL parsing fails, return no launch URL (#5918)
* providers/oauth2: launch url: if URL parsing fails, return no launch URL Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space> * add test Signed-off-by: Jens Langhammer <jens@goauthentik.io> * only get provider launch URL when no url is set Signed-off-by: Jens Langhammer <jens@goauthentik.io> * only catch value error Signed-off-by: Jens Langhammer <jens@goauthentik.io> * format Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space> Signed-off-by: Jens Langhammer <jens@goauthentik.io> Co-authored-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
		@ -17,6 +17,7 @@ from django.urls import reverse
 | 
			
		||||
from django.utils.translation import gettext_lazy as _
 | 
			
		||||
from jwt import encode
 | 
			
		||||
from rest_framework.serializers import Serializer
 | 
			
		||||
from structlog.stdlib import get_logger
 | 
			
		||||
 | 
			
		||||
from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User
 | 
			
		||||
from authentik.crypto.models import CertificateKeyPair
 | 
			
		||||
@ -26,6 +27,8 @@ from authentik.lib.utils.time import timedelta_string_validator
 | 
			
		||||
from authentik.providers.oauth2.id_token import IDToken, SubModes
 | 
			
		||||
from authentik.sources.oauth.models import OAuthSource
 | 
			
		||||
 | 
			
		||||
LOGGER = get_logger()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def generate_client_secret() -> str:
 | 
			
		||||
    """Generate client secret with adequate length"""
 | 
			
		||||
@ -251,8 +254,12 @@ class OAuth2Provider(Provider):
 | 
			
		||||
        if self.redirect_uris == "":
 | 
			
		||||
            return None
 | 
			
		||||
        main_url = self.redirect_uris.split("\n", maxsplit=1)[0]
 | 
			
		||||
        launch_url = urlparse(main_url)._replace(path="")
 | 
			
		||||
        return urlunparse(launch_url)
 | 
			
		||||
        try:
 | 
			
		||||
            launch_url = urlparse(main_url)._replace(path="")
 | 
			
		||||
            return urlunparse(launch_url)
 | 
			
		||||
        except ValueError as exc:
 | 
			
		||||
            LOGGER.warning("Failed to format launch url", exc=exc)
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def component(self) -> str:
 | 
			
		||||
 | 
			
		||||
@ -1,5 +1,7 @@
 | 
			
		||||
"""Test OAuth2 API"""
 | 
			
		||||
from json import loads
 | 
			
		||||
from sys import version_info
 | 
			
		||||
from unittest import skipUnless
 | 
			
		||||
 | 
			
		||||
from django.urls import reverse
 | 
			
		||||
from rest_framework.test import APITestCase
 | 
			
		||||
@ -42,3 +44,14 @@ class TestAPI(APITestCase):
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        body = loads(response.content.decode())
 | 
			
		||||
        self.assertEqual(body["issuer"], "http://testserver/application/o/test/")
 | 
			
		||||
 | 
			
		||||
    # https://github.com/goauthentik/authentik/pull/5918
 | 
			
		||||
    @skipUnless(version_info >= (3, 11, 4), "This behaviour is only Python 3.11.4 and up")
 | 
			
		||||
    def test_launch_url(self):
 | 
			
		||||
        """Test launch_url"""
 | 
			
		||||
        self.provider.redirect_uris = (
 | 
			
		||||
            "https://[\\d\\w]+.pr.test.goauthentik.io/source/oauth/callback/authentik/\n"
 | 
			
		||||
        )
 | 
			
		||||
        self.provider.save()
 | 
			
		||||
        self.provider.refresh_from_db()
 | 
			
		||||
        self.assertIsNone(self.provider.launch_url)
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user