Compare commits

...

2 Commits

Author SHA1 Message Date
7549a6b83d add cause for oauth authorization errors
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-06-16 03:04:20 +02:00
bb45b714e2 fix incorrect tests/add more
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-06-16 02:54:16 +02:00
3 changed files with 107 additions and 37 deletions

View File

@ -15,6 +15,7 @@ class OAuth2Error(SentryIgnoredException):
error: str
description: str
cause: str | None = None
def create_dict(self):
"""Return error as dict for JSON Rendering"""
@ -34,6 +35,10 @@ class OAuth2Error(SentryIgnoredException):
**kwargs,
)
def with_cause(self, cause: str):
self.cause = cause
return self
class RedirectUriError(OAuth2Error):
"""The request fails due to a missing, invalid, or mismatching

View File

@ -12,7 +12,7 @@ from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.events.models import Event, EventAction
from authentik.lib.generators import generate_id
from authentik.lib.utils.time import timedelta_from_string
from authentik.providers.oauth2.constants import TOKEN_TYPE
from authentik.providers.oauth2.constants import SCOPE_OFFLINE_ACCESS, SCOPE_OPENID, TOKEN_TYPE
from authentik.providers.oauth2.errors import AuthorizeError, ClientIdError, RedirectUriError
from authentik.providers.oauth2.models import (
AccessToken,
@ -43,7 +43,7 @@ class TestAuthorize(OAuthTestCase):
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
)
with self.assertRaises(AuthorizeError):
with self.assertRaises(AuthorizeError) as cm:
request = self.factory.get(
"/",
data={
@ -53,6 +53,7 @@ class TestAuthorize(OAuthTestCase):
},
)
OAuthAuthorizationParams.from_request(request)
self.assertEqual(cm.exception.error, "unsupported_response_type")
def test_invalid_client_id(self):
"""Test invalid client ID"""
@ -68,7 +69,7 @@ class TestAuthorize(OAuthTestCase):
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
)
with self.assertRaises(AuthorizeError):
with self.assertRaises(AuthorizeError) as cm:
request = self.factory.get(
"/",
data={
@ -79,19 +80,30 @@ class TestAuthorize(OAuthTestCase):
},
)
OAuthAuthorizationParams.from_request(request)
self.assertEqual(cm.exception.error, "request_not_supported")
def test_invalid_redirect_uri(self):
"""test missing/invalid redirect URI"""
def test_invalid_redirect_uri_missing(self):
"""test missing redirect URI"""
OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
)
with self.assertRaises(RedirectUriError):
with self.assertRaises(RedirectUriError) as cm:
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
OAuthAuthorizationParams.from_request(request)
with self.assertRaises(RedirectUriError):
self.assertEqual(cm.exception.cause, "redirect_uri_missing")
def test_invalid_redirect_uri(self):
"""test invalid redirect URI"""
OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
)
with self.assertRaises(RedirectUriError) as cm:
request = self.factory.get(
"/",
data={
@ -101,6 +113,7 @@ class TestAuthorize(OAuthTestCase):
},
)
OAuthAuthorizationParams.from_request(request)
self.assertEqual(cm.exception.cause, "redirect_uri_no_match")
def test_blocked_redirect_uri(self):
"""test missing/invalid redirect URI"""
@ -108,9 +121,9 @@ class TestAuthorize(OAuthTestCase):
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "data:local.invalid")],
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "data:localhost")],
)
with self.assertRaises(RedirectUriError):
with self.assertRaises(RedirectUriError) as cm:
request = self.factory.get(
"/",
data={
@ -120,6 +133,7 @@ class TestAuthorize(OAuthTestCase):
},
)
OAuthAuthorizationParams.from_request(request)
self.assertEqual(cm.exception.cause, "redirect_uri_forbidden_scheme")
def test_invalid_redirect_uri_empty(self):
"""test missing/invalid redirect URI"""
@ -129,9 +143,6 @@ class TestAuthorize(OAuthTestCase):
authorization_flow=create_test_flow(),
redirect_uris=[],
)
with self.assertRaises(RedirectUriError):
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
OAuthAuthorizationParams.from_request(request)
request = self.factory.get(
"/",
data={
@ -150,12 +161,9 @@ class TestAuthorize(OAuthTestCase):
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid?")],
redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, "http://local.invalid?")],
)
with self.assertRaises(RedirectUriError):
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
OAuthAuthorizationParams.from_request(request)
with self.assertRaises(RedirectUriError):
with self.assertRaises(RedirectUriError) as cm:
request = self.factory.get(
"/",
data={
@ -165,6 +173,7 @@ class TestAuthorize(OAuthTestCase):
},
)
OAuthAuthorizationParams.from_request(request)
self.assertEqual(cm.exception.cause, "redirect_uri_no_match")
def test_redirect_uri_invalid_regex(self):
"""test missing/invalid redirect URI (invalid regex)"""
@ -172,12 +181,9 @@ class TestAuthorize(OAuthTestCase):
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "+")],
redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, "+")],
)
with self.assertRaises(RedirectUriError):
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
OAuthAuthorizationParams.from_request(request)
with self.assertRaises(RedirectUriError):
with self.assertRaises(RedirectUriError) as cm:
request = self.factory.get(
"/",
data={
@ -187,23 +193,22 @@ class TestAuthorize(OAuthTestCase):
},
)
OAuthAuthorizationParams.from_request(request)
self.assertEqual(cm.exception.cause, "redirect_uri_no_match")
def test_empty_redirect_uri(self):
"""test empty redirect URI (configure in provider)"""
def test_redirect_uri_regex(self):
"""test valid redirect URI (regex)"""
OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, ".+")],
)
with self.assertRaises(RedirectUriError):
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
OAuthAuthorizationParams.from_request(request)
request = self.factory.get(
"/",
data={
"response_type": "code",
"client_id": "test",
"redirect_uri": "http://localhost",
"redirect_uri": "http://foo.bar.baz",
},
)
OAuthAuthorizationParams.from_request(request)
@ -258,7 +263,7 @@ class TestAuthorize(OAuthTestCase):
GrantTypes.IMPLICIT,
)
# Implicit without openid scope
with self.assertRaises(AuthorizeError):
with self.assertRaises(AuthorizeError) as cm:
request = self.factory.get(
"/",
data={
@ -285,7 +290,7 @@ class TestAuthorize(OAuthTestCase):
self.assertEqual(
OAuthAuthorizationParams.from_request(request).grant_type, GrantTypes.HYBRID
)
with self.assertRaises(AuthorizeError):
with self.assertRaises(AuthorizeError) as cm:
request = self.factory.get(
"/",
data={
@ -295,6 +300,7 @@ class TestAuthorize(OAuthTestCase):
},
)
OAuthAuthorizationParams.from_request(request)
self.assertEqual(cm.exception.error, "unsupported_response_type")
def test_full_code(self):
"""Test full authorization"""
@ -615,3 +621,54 @@ class TestAuthorize(OAuthTestCase):
},
},
)
def test_openid_missing_invalid(self):
"""test request requiring an OpenID scope to be set"""
OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
)
request = self.factory.get(
"/",
data={
"response_type": "id_token",
"client_id": "test",
"redirect_uri": "http://localhost",
"scope": "",
},
)
with self.assertRaises(AuthorizeError) as cm:
OAuthAuthorizationParams.from_request(request)
self.assertEqual(cm.exception.cause, "scope_openid_missing")
@apply_blueprint("system/providers-oauth2.yaml")
def test_offline_access_invalid(self):
"""test request for offline_access with invalid response type"""
provider = OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
)
provider.property_mappings.set(
ScopeMapping.objects.filter(
managed__in=[
"goauthentik.io/providers/oauth2/scope-openid",
"goauthentik.io/providers/oauth2/scope-offline_access",
]
)
)
request = self.factory.get(
"/",
data={
"response_type": "id_token",
"client_id": "test",
"redirect_uri": "http://localhost",
"scope": f"{SCOPE_OPENID} {SCOPE_OFFLINE_ACCESS}",
"nonce": generate_id(),
},
)
parsed = OAuthAuthorizationParams.from_request(request)
self.assertNotIn(SCOPE_OFFLINE_ACCESS, parsed.scope)

View File

@ -190,7 +190,7 @@ class OAuthAuthorizationParams:
allowed_redirect_urls = self.provider.redirect_uris
if not self.redirect_uri:
LOGGER.warning("Missing redirect uri.")
raise RedirectUriError("", allowed_redirect_urls)
raise RedirectUriError("", allowed_redirect_urls).with_cause("redirect_uri_missing")
if len(allowed_redirect_urls) < 1:
LOGGER.info("Setting redirect for blank redirect_uris", redirect=self.redirect_uri)
@ -219,10 +219,14 @@ class OAuthAuthorizationParams:
provider=self.provider,
)
if not match_found:
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls)
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls).with_cause(
"redirect_uri_no_match"
)
# Check against forbidden schemes
if urlparse(self.redirect_uri).scheme in FORBIDDEN_URI_SCHEMES:
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls)
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls).with_cause(
"redirect_uri_forbidden_scheme"
)
def check_scope(self, github_compat=False):
"""Ensure openid scope is set in Hybrid flows, or when requesting an id_token"""
@ -251,7 +255,9 @@ class OAuthAuthorizationParams:
or self.response_type in [ResponseTypes.ID_TOKEN, ResponseTypes.ID_TOKEN_TOKEN]
):
LOGGER.warning("Missing 'openid' scope.")
raise AuthorizeError(self.redirect_uri, "invalid_scope", self.grant_type, self.state)
raise AuthorizeError(
self.redirect_uri, "invalid_scope", self.grant_type, self.state
).with_cause("scope_openid_missing")
if SCOPE_OFFLINE_ACCESS in self.scope:
# https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess
# Don't explicitly request consent with offline_access, as the spec allows for
@ -286,7 +292,9 @@ class OAuthAuthorizationParams:
return
if not self.nonce:
LOGGER.warning("Missing nonce for OpenID Request")
raise AuthorizeError(self.redirect_uri, "invalid_request", self.grant_type, self.state)
raise AuthorizeError(
self.redirect_uri, "invalid_request", self.grant_type, self.state
).with_cause("none_missing")
def check_code_challenge(self):
"""PKCE validation of the transformation method."""
@ -345,10 +353,10 @@ class AuthorizationFlowInitView(PolicyAccessView):
self.request, github_compat=self.github_compat
)
except AuthorizeError as error:
LOGGER.warning(error.description, redirect_uri=error.redirect_uri)
LOGGER.warning(error.description, redirect_uri=error.redirect_uri, cause=error.cause)
raise RequestValidationError(error.get_response(self.request)) from None
except OAuth2Error as error:
LOGGER.warning(error.description)
LOGGER.warning(error.description, cause=error.cause)
raise RequestValidationError(
bad_request_message(self.request, error.description, title=error.error)
) from None