providers/oauth2: if a redirect_uri cannot be parsed as regex, compare strict (#3070)
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
		| @ -47,11 +47,11 @@ def create_test_tenant() -> Tenant: | ||||
|  | ||||
| def create_test_cert() -> CertificateKeyPair: | ||||
|     """Generate a certificate for testing""" | ||||
|     CertificateKeyPair.objects.filter(name="goauthentik.io").delete() | ||||
|     builder = CertificateBuilder() | ||||
|     builder.common_name = "goauthentik.io" | ||||
|     builder.build( | ||||
|         subject_alt_names=["goauthentik.io"], | ||||
|         validity_days=360, | ||||
|     ) | ||||
|     builder.name = generate_id() | ||||
|     return builder.save() | ||||
|  | ||||
| @ -53,10 +53,7 @@ class CertificateBuilder: | ||||
|             .subject_name( | ||||
|                 x509.Name( | ||||
|                     [ | ||||
|                         x509.NameAttribute( | ||||
|                             NameOID.COMMON_NAME, | ||||
|                             self.common_name, | ||||
|                         ), | ||||
|                         x509.NameAttribute(NameOID.COMMON_NAME, self.common_name), | ||||
|                         x509.NameAttribute(NameOID.ORGANIZATION_NAME, "authentik"), | ||||
|                         x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "Self-signed"), | ||||
|                     ] | ||||
| @ -65,10 +62,7 @@ class CertificateBuilder: | ||||
|             .issuer_name( | ||||
|                 x509.Name( | ||||
|                     [ | ||||
|                         x509.NameAttribute( | ||||
|                             NameOID.COMMON_NAME, | ||||
|                             f"authentik {__version__}", | ||||
|                         ), | ||||
|                         x509.NameAttribute(NameOID.COMMON_NAME, f"authentik {__version__}"), | ||||
|                     ] | ||||
|                 ) | ||||
|             ) | ||||
|  | ||||
| @ -3,7 +3,7 @@ from django.test import RequestFactory | ||||
| from django.urls import reverse | ||||
|  | ||||
| from authentik.core.models import Application | ||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow | ||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_flow | ||||
| from authentik.flows.challenge import ChallengeTypes | ||||
| from authentik.lib.generators import generate_id, generate_key | ||||
| from authentik.providers.oauth2.errors import AuthorizeError, ClientIdError, RedirectUriError | ||||
| @ -39,7 +39,7 @@ class TestAuthorize(OAuthTestCase): | ||||
|     def test_request(self): | ||||
|         """test request param""" | ||||
|         OAuth2Provider.objects.create( | ||||
|             name="test", | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://local.invalid/Foo", | ||||
| @ -59,7 +59,7 @@ class TestAuthorize(OAuthTestCase): | ||||
|     def test_invalid_redirect_uri(self): | ||||
|         """test missing/invalid redirect URI""" | ||||
|         OAuth2Provider.objects.create( | ||||
|             name="test", | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://local.invalid", | ||||
| @ -78,10 +78,55 @@ class TestAuthorize(OAuthTestCase): | ||||
|             ) | ||||
|             OAuthAuthorizationParams.from_request(request) | ||||
|  | ||||
|     def test_invalid_redirect_uri_empty(self): | ||||
|         """test missing/invalid redirect URI""" | ||||
|         provider = OAuth2Provider.objects.create( | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="", | ||||
|         ) | ||||
|         with self.assertRaises(RedirectUriError): | ||||
|             request = self.factory.get("/", data={"response_type": "code", "client_id": "test"}) | ||||
|             OAuthAuthorizationParams.from_request(request) | ||||
|         request = self.factory.get( | ||||
|             "/", | ||||
|             data={ | ||||
|                 "response_type": "code", | ||||
|                 "client_id": "test", | ||||
|                 "redirect_uri": "+", | ||||
|             }, | ||||
|         ) | ||||
|         OAuthAuthorizationParams.from_request(request) | ||||
|         provider.refresh_from_db() | ||||
|         self.assertEqual(provider.redirect_uris, "+") | ||||
|  | ||||
|     def test_invalid_redirect_uri_regex(self): | ||||
|         """test missing/invalid redirect URI""" | ||||
|         OAuth2Provider.objects.create( | ||||
|             name="test", | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://local.invalid?", | ||||
|         ) | ||||
|         with self.assertRaises(RedirectUriError): | ||||
|             request = self.factory.get("/", data={"response_type": "code", "client_id": "test"}) | ||||
|             OAuthAuthorizationParams.from_request(request) | ||||
|         with self.assertRaises(RedirectUriError): | ||||
|             request = self.factory.get( | ||||
|                 "/", | ||||
|                 data={ | ||||
|                     "response_type": "code", | ||||
|                     "client_id": "test", | ||||
|                     "redirect_uri": "http://localhost", | ||||
|                 }, | ||||
|             ) | ||||
|             OAuthAuthorizationParams.from_request(request) | ||||
|  | ||||
|     def test_redirect_uri_invalid_regex(self): | ||||
|         """test missing/invalid redirect URI (invalid regex)""" | ||||
|         OAuth2Provider.objects.create( | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="+", | ||||
| @ -103,7 +148,7 @@ class TestAuthorize(OAuthTestCase): | ||||
|     def test_empty_redirect_uri(self): | ||||
|         """test empty redirect URI (configure in provider)""" | ||||
|         OAuth2Provider.objects.create( | ||||
|             name="test", | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|         ) | ||||
| @ -123,7 +168,7 @@ class TestAuthorize(OAuthTestCase): | ||||
|     def test_response_type(self): | ||||
|         """test response_type""" | ||||
|         OAuth2Provider.objects.create( | ||||
|             name="test", | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://local.invalid/Foo", | ||||
| @ -201,7 +246,7 @@ class TestAuthorize(OAuthTestCase): | ||||
|         """Test full authorization""" | ||||
|         flow = create_test_flow() | ||||
|         provider = OAuth2Provider.objects.create( | ||||
|             name="test", | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             authorization_flow=flow, | ||||
|             redirect_uris="foo://localhost", | ||||
| @ -237,12 +282,12 @@ class TestAuthorize(OAuthTestCase): | ||||
|         """Test full authorization""" | ||||
|         flow = create_test_flow() | ||||
|         provider = OAuth2Provider.objects.create( | ||||
|             name="test", | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             client_secret=generate_key(), | ||||
|             authorization_flow=flow, | ||||
|             redirect_uris="http://localhost", | ||||
|             signing_key=create_test_cert(), | ||||
|             signing_key=self.keypair, | ||||
|         ) | ||||
|         Application.objects.create(name="app", slug="app", provider=provider) | ||||
|         state = generate_id() | ||||
| @ -281,12 +326,12 @@ class TestAuthorize(OAuthTestCase): | ||||
|         """Test full authorization (form_post response)""" | ||||
|         flow = create_test_flow() | ||||
|         provider = OAuth2Provider.objects.create( | ||||
|             name="test", | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             client_secret=generate_key(), | ||||
|             authorization_flow=flow, | ||||
|             redirect_uris="http://localhost", | ||||
|             signing_key=create_test_cert(), | ||||
|             signing_key=self.keypair, | ||||
|         ) | ||||
|         Application.objects.create(name="app", slug="app", provider=provider) | ||||
|         state = generate_id() | ||||
|  | ||||
| @ -5,7 +5,7 @@ from django.test import RequestFactory | ||||
| from django.urls import reverse | ||||
|  | ||||
| from authentik.core.models import Application | ||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow | ||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_flow | ||||
| from authentik.events.models import Event, EventAction | ||||
| from authentik.lib.generators import generate_id, generate_key | ||||
| from authentik.providers.oauth2.constants import ( | ||||
| @ -24,17 +24,17 @@ class TestToken(OAuthTestCase): | ||||
|     def setUp(self) -> None: | ||||
|         super().setUp() | ||||
|         self.factory = RequestFactory() | ||||
|         self.app = Application.objects.create(name="test", slug="test") | ||||
|         self.app = Application.objects.create(name=generate_id(), slug="test") | ||||
|  | ||||
|     def test_request_auth_code(self): | ||||
|         """test request param""" | ||||
|         provider = OAuth2Provider.objects.create( | ||||
|             name="test", | ||||
|             name=generate_id(), | ||||
|             client_id=generate_id(), | ||||
|             client_secret=generate_key(), | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://testserver", | ||||
|             signing_key=create_test_cert(), | ||||
|             signing_key=self.keypair, | ||||
|         ) | ||||
|         header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode() | ||||
|         user = create_test_admin_user() | ||||
| @ -56,12 +56,12 @@ class TestToken(OAuthTestCase): | ||||
|     def test_request_auth_code_invalid(self): | ||||
|         """test request param""" | ||||
|         provider = OAuth2Provider.objects.create( | ||||
|             name="test", | ||||
|             name=generate_id(), | ||||
|             client_id=generate_id(), | ||||
|             client_secret=generate_key(), | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://testserver", | ||||
|             signing_key=create_test_cert(), | ||||
|             signing_key=self.keypair, | ||||
|         ) | ||||
|         header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode() | ||||
|         request = self.factory.post( | ||||
| @ -79,12 +79,12 @@ class TestToken(OAuthTestCase): | ||||
|     def test_request_refresh_token(self): | ||||
|         """test request param""" | ||||
|         provider = OAuth2Provider.objects.create( | ||||
|             name="test", | ||||
|             name=generate_id(), | ||||
|             client_id=generate_id(), | ||||
|             client_secret=generate_key(), | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://local.invalid", | ||||
|             signing_key=create_test_cert(), | ||||
|             signing_key=self.keypair, | ||||
|         ) | ||||
|         header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode() | ||||
|         user = create_test_admin_user() | ||||
| @ -108,12 +108,12 @@ class TestToken(OAuthTestCase): | ||||
|     def test_auth_code_view(self): | ||||
|         """test request param""" | ||||
|         provider = OAuth2Provider.objects.create( | ||||
|             name="test", | ||||
|             name=generate_id(), | ||||
|             client_id=generate_id(), | ||||
|             client_secret=generate_key(), | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://local.invalid", | ||||
|             signing_key=create_test_cert(), | ||||
|             signing_key=self.keypair, | ||||
|         ) | ||||
|         # Needs to be assigned to an application for iss to be set | ||||
|         self.app.provider = provider | ||||
| @ -150,12 +150,12 @@ class TestToken(OAuthTestCase): | ||||
|     def test_refresh_token_view(self): | ||||
|         """test request param""" | ||||
|         provider = OAuth2Provider.objects.create( | ||||
|             name="test", | ||||
|             name=generate_id(), | ||||
|             client_id=generate_id(), | ||||
|             client_secret=generate_key(), | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://local.invalid", | ||||
|             signing_key=create_test_cert(), | ||||
|             signing_key=self.keypair, | ||||
|         ) | ||||
|         # Needs to be assigned to an application for iss to be set | ||||
|         self.app.provider = provider | ||||
| @ -199,12 +199,12 @@ class TestToken(OAuthTestCase): | ||||
|     def test_refresh_token_view_invalid_origin(self): | ||||
|         """test request param""" | ||||
|         provider = OAuth2Provider.objects.create( | ||||
|             name="test", | ||||
|             name=generate_id(), | ||||
|             client_id=generate_id(), | ||||
|             client_secret=generate_key(), | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://local.invalid", | ||||
|             signing_key=create_test_cert(), | ||||
|             signing_key=self.keypair, | ||||
|         ) | ||||
|         header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode() | ||||
|         user = create_test_admin_user() | ||||
| @ -244,12 +244,12 @@ class TestToken(OAuthTestCase): | ||||
|     def test_refresh_token_revoke(self): | ||||
|         """test request param""" | ||||
|         provider = OAuth2Provider.objects.create( | ||||
|             name="test", | ||||
|             name=generate_id(), | ||||
|             client_id=generate_id(), | ||||
|             client_secret=generate_key(), | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://testserver", | ||||
|             signing_key=create_test_cert(), | ||||
|             signing_key=self.keypair, | ||||
|         ) | ||||
|         # Needs to be assigned to an application for iss to be set | ||||
|         self.app.provider = provider | ||||
|  | ||||
| @ -2,12 +2,15 @@ | ||||
| from django.test import TestCase | ||||
| from jwt import decode | ||||
|  | ||||
| from authentik.core.tests.utils import create_test_cert | ||||
| from authentik.crypto.models import CertificateKeyPair | ||||
| from authentik.providers.oauth2.models import JWTAlgorithms, OAuth2Provider, RefreshToken | ||||
|  | ||||
|  | ||||
| class OAuthTestCase(TestCase): | ||||
|     """OAuth test helpers""" | ||||
|  | ||||
|     keypair: CertificateKeyPair | ||||
|     required_jwt_keys = [ | ||||
|         "exp", | ||||
|         "iat", | ||||
| @ -17,6 +20,11 @@ class OAuthTestCase(TestCase): | ||||
|         "iss", | ||||
|     ] | ||||
|  | ||||
|     @classmethod | ||||
|     def setUpClass(cls) -> None: | ||||
|         cls.keypair = create_test_cert() | ||||
|         super().setUpClass() | ||||
|  | ||||
|     def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider): | ||||
|         """Validate that all required fields are set""" | ||||
|         key, alg = provider.get_jwt_key() | ||||
|  | ||||
| @ -2,7 +2,7 @@ | ||||
| from dataclasses import dataclass, field | ||||
| from datetime import timedelta | ||||
| from re import error as RegexError | ||||
| from re import escape, fullmatch | ||||
| from re import fullmatch | ||||
| from typing import Optional | ||||
| from urllib.parse import parse_qs, urlencode, urlparse, urlsplit, urlunsplit | ||||
| from uuid import uuid4 | ||||
| @ -181,7 +181,7 @@ class OAuthAuthorizationParams: | ||||
|  | ||||
|         if self.provider.redirect_uris == "": | ||||
|             LOGGER.info("Setting redirect for blank redirect_uris", redirect=self.redirect_uri) | ||||
|             self.provider.redirect_uris = escape(self.redirect_uri) | ||||
|             self.provider.redirect_uris = self.redirect_uri | ||||
|             self.provider.save() | ||||
|             allowed_redirect_urls = self.provider.redirect_uris.split() | ||||
|  | ||||
| @ -194,14 +194,20 @@ class OAuthAuthorizationParams: | ||||
|         try: | ||||
|             if not any(fullmatch(x, self.redirect_uri) for x in allowed_redirect_urls): | ||||
|                 LOGGER.warning( | ||||
|                     "Invalid redirect uri", | ||||
|                     "Invalid redirect uri (regex comparison)", | ||||
|                     redirect_uri=self.redirect_uri, | ||||
|                     expected=allowed_redirect_urls, | ||||
|                 ) | ||||
|                 raise RedirectUriError(self.redirect_uri, allowed_redirect_urls) | ||||
|         except RegexError as exc: | ||||
|             LOGGER.warning("Invalid regular expression configured", exc=exc) | ||||
|             raise RedirectUriError(self.redirect_uri, allowed_redirect_urls) | ||||
|             LOGGER.info("Failed to parse regular expression, checking directly", exc=exc) | ||||
|             if not any(x == self.redirect_uri for x in allowed_redirect_urls): | ||||
|                 LOGGER.warning( | ||||
|                     "Invalid redirect uri (strict comparison)", | ||||
|                     redirect_uri=self.redirect_uri, | ||||
|                     expected=allowed_redirect_urls, | ||||
|                 ) | ||||
|                 raise RedirectUriError(self.redirect_uri, allowed_redirect_urls) | ||||
|         if self.request: | ||||
|             raise AuthorizeError( | ||||
|                 self.redirect_uri, "request_not_supported", self.grant_type, self.state | ||||
|  | ||||
| @ -154,7 +154,7 @@ class TokenParams: | ||||
|         try: | ||||
|             if not any(fullmatch(x, self.redirect_uri) for x in allowed_redirect_urls): | ||||
|                 LOGGER.warning( | ||||
|                     "Invalid redirect uri", | ||||
|                     "Invalid redirect uri (regex comparison)", | ||||
|                     redirect_uri=self.redirect_uri, | ||||
|                     expected=allowed_redirect_urls, | ||||
|                 ) | ||||
| @ -167,13 +167,19 @@ class TokenParams: | ||||
|                 ).from_http(request) | ||||
|                 raise TokenError("invalid_client") | ||||
|         except RegexError as exc: | ||||
|             LOGGER.warning("Invalid regular expression configured", exc=exc) | ||||
|             Event.new( | ||||
|                 EventAction.CONFIGURATION_ERROR, | ||||
|                 message="Invalid redirect_uri RegEx configured", | ||||
|                 provider=self.provider, | ||||
|             ).from_http(request) | ||||
|             raise TokenError("invalid_client") | ||||
|             LOGGER.info("Failed to parse regular expression, checking directly", exc=exc) | ||||
|             if not any(x == self.redirect_uri for x in allowed_redirect_urls): | ||||
|                 LOGGER.warning( | ||||
|                     "Invalid redirect uri (strict comparison)", | ||||
|                     redirect_uri=self.redirect_uri, | ||||
|                     expected=allowed_redirect_urls, | ||||
|                 ) | ||||
|                 Event.new( | ||||
|                     EventAction.CONFIGURATION_ERROR, | ||||
|                     message="Invalid redirect_uri configured", | ||||
|                     provider=self.provider, | ||||
|                 ).from_http(request) | ||||
|                 raise TokenError("invalid_client") | ||||
|  | ||||
|         try: | ||||
|             self.authorization_code = AuthorizationCode.objects.get(code=raw_code) | ||||
|  | ||||
		Reference in New Issue
	
	Block a user
	 Jens L
					Jens L