Compare commits
2 Commits
sources/oa
...
providers/
Author | SHA1 | Date | |
---|---|---|---|
7549a6b83d | |||
bb45b714e2 |
@ -15,6 +15,7 @@ class OAuth2Error(SentryIgnoredException):
|
||||
|
||||
error: str
|
||||
description: str
|
||||
cause: str | None = None
|
||||
|
||||
def create_dict(self):
|
||||
"""Return error as dict for JSON Rendering"""
|
||||
@ -34,6 +35,10 @@ class OAuth2Error(SentryIgnoredException):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def with_cause(self, cause: str):
|
||||
self.cause = cause
|
||||
return self
|
||||
|
||||
|
||||
class RedirectUriError(OAuth2Error):
|
||||
"""The request fails due to a missing, invalid, or mismatching
|
||||
|
@ -12,7 +12,7 @@ from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.providers.oauth2.constants import TOKEN_TYPE
|
||||
from authentik.providers.oauth2.constants import SCOPE_OFFLINE_ACCESS, SCOPE_OPENID, TOKEN_TYPE
|
||||
from authentik.providers.oauth2.errors import AuthorizeError, ClientIdError, RedirectUriError
|
||||
from authentik.providers.oauth2.models import (
|
||||
AccessToken,
|
||||
@ -43,7 +43,7 @@ class TestAuthorize(OAuthTestCase):
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
|
||||
)
|
||||
with self.assertRaises(AuthorizeError):
|
||||
with self.assertRaises(AuthorizeError) as cm:
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
@ -53,6 +53,7 @@ class TestAuthorize(OAuthTestCase):
|
||||
},
|
||||
)
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
self.assertEqual(cm.exception.error, "unsupported_response_type")
|
||||
|
||||
def test_invalid_client_id(self):
|
||||
"""Test invalid client ID"""
|
||||
@ -68,7 +69,7 @@ class TestAuthorize(OAuthTestCase):
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
|
||||
)
|
||||
with self.assertRaises(AuthorizeError):
|
||||
with self.assertRaises(AuthorizeError) as cm:
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
@ -79,19 +80,30 @@ class TestAuthorize(OAuthTestCase):
|
||||
},
|
||||
)
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
self.assertEqual(cm.exception.error, "request_not_supported")
|
||||
|
||||
def test_invalid_redirect_uri(self):
|
||||
"""test missing/invalid redirect URI"""
|
||||
def test_invalid_redirect_uri_missing(self):
|
||||
"""test missing redirect URI"""
|
||||
OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
|
||||
)
|
||||
with self.assertRaises(RedirectUriError):
|
||||
with self.assertRaises(RedirectUriError) as cm:
|
||||
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
with self.assertRaises(RedirectUriError):
|
||||
self.assertEqual(cm.exception.cause, "redirect_uri_missing")
|
||||
|
||||
def test_invalid_redirect_uri(self):
|
||||
"""test invalid redirect URI"""
|
||||
OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
|
||||
)
|
||||
with self.assertRaises(RedirectUriError) as cm:
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
@ -101,6 +113,7 @@ class TestAuthorize(OAuthTestCase):
|
||||
},
|
||||
)
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
self.assertEqual(cm.exception.cause, "redirect_uri_no_match")
|
||||
|
||||
def test_blocked_redirect_uri(self):
|
||||
"""test missing/invalid redirect URI"""
|
||||
@ -108,9 +121,9 @@ class TestAuthorize(OAuthTestCase):
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "data:local.invalid")],
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "data:localhost")],
|
||||
)
|
||||
with self.assertRaises(RedirectUriError):
|
||||
with self.assertRaises(RedirectUriError) as cm:
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
@ -120,6 +133,7 @@ class TestAuthorize(OAuthTestCase):
|
||||
},
|
||||
)
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
self.assertEqual(cm.exception.cause, "redirect_uri_forbidden_scheme")
|
||||
|
||||
def test_invalid_redirect_uri_empty(self):
|
||||
"""test missing/invalid redirect URI"""
|
||||
@ -129,9 +143,6 @@ class TestAuthorize(OAuthTestCase):
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[],
|
||||
)
|
||||
with self.assertRaises(RedirectUriError):
|
||||
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
@ -150,12 +161,9 @@ class TestAuthorize(OAuthTestCase):
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid?")],
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, "http://local.invalid?")],
|
||||
)
|
||||
with self.assertRaises(RedirectUriError):
|
||||
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
with self.assertRaises(RedirectUriError):
|
||||
with self.assertRaises(RedirectUriError) as cm:
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
@ -165,6 +173,7 @@ class TestAuthorize(OAuthTestCase):
|
||||
},
|
||||
)
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
self.assertEqual(cm.exception.cause, "redirect_uri_no_match")
|
||||
|
||||
def test_redirect_uri_invalid_regex(self):
|
||||
"""test missing/invalid redirect URI (invalid regex)"""
|
||||
@ -172,12 +181,9 @@ class TestAuthorize(OAuthTestCase):
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "+")],
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, "+")],
|
||||
)
|
||||
with self.assertRaises(RedirectUriError):
|
||||
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
with self.assertRaises(RedirectUriError):
|
||||
with self.assertRaises(RedirectUriError) as cm:
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
@ -187,23 +193,22 @@ class TestAuthorize(OAuthTestCase):
|
||||
},
|
||||
)
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
self.assertEqual(cm.exception.cause, "redirect_uri_no_match")
|
||||
|
||||
def test_empty_redirect_uri(self):
|
||||
"""test empty redirect URI (configure in provider)"""
|
||||
def test_redirect_uri_regex(self):
|
||||
"""test valid redirect URI (regex)"""
|
||||
OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, ".+")],
|
||||
)
|
||||
with self.assertRaises(RedirectUriError):
|
||||
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
"response_type": "code",
|
||||
"client_id": "test",
|
||||
"redirect_uri": "http://localhost",
|
||||
"redirect_uri": "http://foo.bar.baz",
|
||||
},
|
||||
)
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
@ -258,7 +263,7 @@ class TestAuthorize(OAuthTestCase):
|
||||
GrantTypes.IMPLICIT,
|
||||
)
|
||||
# Implicit without openid scope
|
||||
with self.assertRaises(AuthorizeError):
|
||||
with self.assertRaises(AuthorizeError) as cm:
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
@ -285,7 +290,7 @@ class TestAuthorize(OAuthTestCase):
|
||||
self.assertEqual(
|
||||
OAuthAuthorizationParams.from_request(request).grant_type, GrantTypes.HYBRID
|
||||
)
|
||||
with self.assertRaises(AuthorizeError):
|
||||
with self.assertRaises(AuthorizeError) as cm:
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
@ -295,6 +300,7 @@ class TestAuthorize(OAuthTestCase):
|
||||
},
|
||||
)
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
self.assertEqual(cm.exception.error, "unsupported_response_type")
|
||||
|
||||
def test_full_code(self):
|
||||
"""Test full authorization"""
|
||||
@ -615,3 +621,54 @@ class TestAuthorize(OAuthTestCase):
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def test_openid_missing_invalid(self):
|
||||
"""test request requiring an OpenID scope to be set"""
|
||||
OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
|
||||
)
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
"response_type": "id_token",
|
||||
"client_id": "test",
|
||||
"redirect_uri": "http://localhost",
|
||||
"scope": "",
|
||||
},
|
||||
)
|
||||
with self.assertRaises(AuthorizeError) as cm:
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
self.assertEqual(cm.exception.cause, "scope_openid_missing")
|
||||
|
||||
@apply_blueprint("system/providers-oauth2.yaml")
|
||||
def test_offline_access_invalid(self):
|
||||
"""test request for offline_access with invalid response type"""
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
|
||||
)
|
||||
provider.property_mappings.set(
|
||||
ScopeMapping.objects.filter(
|
||||
managed__in=[
|
||||
"goauthentik.io/providers/oauth2/scope-openid",
|
||||
"goauthentik.io/providers/oauth2/scope-offline_access",
|
||||
]
|
||||
)
|
||||
)
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
"response_type": "id_token",
|
||||
"client_id": "test",
|
||||
"redirect_uri": "http://localhost",
|
||||
"scope": f"{SCOPE_OPENID} {SCOPE_OFFLINE_ACCESS}",
|
||||
"nonce": generate_id(),
|
||||
},
|
||||
)
|
||||
parsed = OAuthAuthorizationParams.from_request(request)
|
||||
self.assertNotIn(SCOPE_OFFLINE_ACCESS, parsed.scope)
|
||||
|
@ -190,7 +190,7 @@ class OAuthAuthorizationParams:
|
||||
allowed_redirect_urls = self.provider.redirect_uris
|
||||
if not self.redirect_uri:
|
||||
LOGGER.warning("Missing redirect uri.")
|
||||
raise RedirectUriError("", allowed_redirect_urls)
|
||||
raise RedirectUriError("", allowed_redirect_urls).with_cause("redirect_uri_missing")
|
||||
|
||||
if len(allowed_redirect_urls) < 1:
|
||||
LOGGER.info("Setting redirect for blank redirect_uris", redirect=self.redirect_uri)
|
||||
@ -219,10 +219,14 @@ class OAuthAuthorizationParams:
|
||||
provider=self.provider,
|
||||
)
|
||||
if not match_found:
|
||||
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls)
|
||||
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls).with_cause(
|
||||
"redirect_uri_no_match"
|
||||
)
|
||||
# Check against forbidden schemes
|
||||
if urlparse(self.redirect_uri).scheme in FORBIDDEN_URI_SCHEMES:
|
||||
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls)
|
||||
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls).with_cause(
|
||||
"redirect_uri_forbidden_scheme"
|
||||
)
|
||||
|
||||
def check_scope(self, github_compat=False):
|
||||
"""Ensure openid scope is set in Hybrid flows, or when requesting an id_token"""
|
||||
@ -251,7 +255,9 @@ class OAuthAuthorizationParams:
|
||||
or self.response_type in [ResponseTypes.ID_TOKEN, ResponseTypes.ID_TOKEN_TOKEN]
|
||||
):
|
||||
LOGGER.warning("Missing 'openid' scope.")
|
||||
raise AuthorizeError(self.redirect_uri, "invalid_scope", self.grant_type, self.state)
|
||||
raise AuthorizeError(
|
||||
self.redirect_uri, "invalid_scope", self.grant_type, self.state
|
||||
).with_cause("scope_openid_missing")
|
||||
if SCOPE_OFFLINE_ACCESS in self.scope:
|
||||
# https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess
|
||||
# Don't explicitly request consent with offline_access, as the spec allows for
|
||||
@ -286,7 +292,9 @@ class OAuthAuthorizationParams:
|
||||
return
|
||||
if not self.nonce:
|
||||
LOGGER.warning("Missing nonce for OpenID Request")
|
||||
raise AuthorizeError(self.redirect_uri, "invalid_request", self.grant_type, self.state)
|
||||
raise AuthorizeError(
|
||||
self.redirect_uri, "invalid_request", self.grant_type, self.state
|
||||
).with_cause("none_missing")
|
||||
|
||||
def check_code_challenge(self):
|
||||
"""PKCE validation of the transformation method."""
|
||||
@ -345,10 +353,10 @@ class AuthorizationFlowInitView(PolicyAccessView):
|
||||
self.request, github_compat=self.github_compat
|
||||
)
|
||||
except AuthorizeError as error:
|
||||
LOGGER.warning(error.description, redirect_uri=error.redirect_uri)
|
||||
LOGGER.warning(error.description, redirect_uri=error.redirect_uri, cause=error.cause)
|
||||
raise RequestValidationError(error.get_response(self.request)) from None
|
||||
except OAuth2Error as error:
|
||||
LOGGER.warning(error.description)
|
||||
LOGGER.warning(error.description, cause=error.cause)
|
||||
raise RequestValidationError(
|
||||
bad_request_message(self.request, error.description, title=error.error)
|
||||
) from None
|
||||
|
Reference in New Issue
Block a user