Compare commits
2 Commits
imports-fo
...
providers/
Author | SHA1 | Date | |
---|---|---|---|
7549a6b83d | |||
bb45b714e2 |
@ -15,6 +15,7 @@ class OAuth2Error(SentryIgnoredException):
|
|||||||
|
|
||||||
error: str
|
error: str
|
||||||
description: str
|
description: str
|
||||||
|
cause: str | None = None
|
||||||
|
|
||||||
def create_dict(self):
|
def create_dict(self):
|
||||||
"""Return error as dict for JSON Rendering"""
|
"""Return error as dict for JSON Rendering"""
|
||||||
@ -34,6 +35,10 @@ class OAuth2Error(SentryIgnoredException):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def with_cause(self, cause: str):
|
||||||
|
self.cause = cause
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
class RedirectUriError(OAuth2Error):
|
class RedirectUriError(OAuth2Error):
|
||||||
"""The request fails due to a missing, invalid, or mismatching
|
"""The request fails due to a missing, invalid, or mismatching
|
||||||
|
@ -12,7 +12,7 @@ from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
|||||||
from authentik.events.models import Event, EventAction
|
from authentik.events.models import Event, EventAction
|
||||||
from authentik.lib.generators import generate_id
|
from authentik.lib.generators import generate_id
|
||||||
from authentik.lib.utils.time import timedelta_from_string
|
from authentik.lib.utils.time import timedelta_from_string
|
||||||
from authentik.providers.oauth2.constants import TOKEN_TYPE
|
from authentik.providers.oauth2.constants import SCOPE_OFFLINE_ACCESS, SCOPE_OPENID, TOKEN_TYPE
|
||||||
from authentik.providers.oauth2.errors import AuthorizeError, ClientIdError, RedirectUriError
|
from authentik.providers.oauth2.errors import AuthorizeError, ClientIdError, RedirectUriError
|
||||||
from authentik.providers.oauth2.models import (
|
from authentik.providers.oauth2.models import (
|
||||||
AccessToken,
|
AccessToken,
|
||||||
@ -43,7 +43,7 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
authorization_flow=create_test_flow(),
|
authorization_flow=create_test_flow(),
|
||||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
|
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
|
||||||
)
|
)
|
||||||
with self.assertRaises(AuthorizeError):
|
with self.assertRaises(AuthorizeError) as cm:
|
||||||
request = self.factory.get(
|
request = self.factory.get(
|
||||||
"/",
|
"/",
|
||||||
data={
|
data={
|
||||||
@ -53,6 +53,7 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
OAuthAuthorizationParams.from_request(request)
|
OAuthAuthorizationParams.from_request(request)
|
||||||
|
self.assertEqual(cm.exception.error, "unsupported_response_type")
|
||||||
|
|
||||||
def test_invalid_client_id(self):
|
def test_invalid_client_id(self):
|
||||||
"""Test invalid client ID"""
|
"""Test invalid client ID"""
|
||||||
@ -68,7 +69,7 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
authorization_flow=create_test_flow(),
|
authorization_flow=create_test_flow(),
|
||||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
|
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
|
||||||
)
|
)
|
||||||
with self.assertRaises(AuthorizeError):
|
with self.assertRaises(AuthorizeError) as cm:
|
||||||
request = self.factory.get(
|
request = self.factory.get(
|
||||||
"/",
|
"/",
|
||||||
data={
|
data={
|
||||||
@ -79,19 +80,30 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
OAuthAuthorizationParams.from_request(request)
|
OAuthAuthorizationParams.from_request(request)
|
||||||
|
self.assertEqual(cm.exception.error, "request_not_supported")
|
||||||
|
|
||||||
def test_invalid_redirect_uri(self):
|
def test_invalid_redirect_uri_missing(self):
|
||||||
"""test missing/invalid redirect URI"""
|
"""test missing redirect URI"""
|
||||||
OAuth2Provider.objects.create(
|
OAuth2Provider.objects.create(
|
||||||
name=generate_id(),
|
name=generate_id(),
|
||||||
client_id="test",
|
client_id="test",
|
||||||
authorization_flow=create_test_flow(),
|
authorization_flow=create_test_flow(),
|
||||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
|
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
|
||||||
)
|
)
|
||||||
with self.assertRaises(RedirectUriError):
|
with self.assertRaises(RedirectUriError) as cm:
|
||||||
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
|
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
|
||||||
OAuthAuthorizationParams.from_request(request)
|
OAuthAuthorizationParams.from_request(request)
|
||||||
with self.assertRaises(RedirectUriError):
|
self.assertEqual(cm.exception.cause, "redirect_uri_missing")
|
||||||
|
|
||||||
|
def test_invalid_redirect_uri(self):
|
||||||
|
"""test invalid redirect URI"""
|
||||||
|
OAuth2Provider.objects.create(
|
||||||
|
name=generate_id(),
|
||||||
|
client_id="test",
|
||||||
|
authorization_flow=create_test_flow(),
|
||||||
|
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
|
||||||
|
)
|
||||||
|
with self.assertRaises(RedirectUriError) as cm:
|
||||||
request = self.factory.get(
|
request = self.factory.get(
|
||||||
"/",
|
"/",
|
||||||
data={
|
data={
|
||||||
@ -101,6 +113,7 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
OAuthAuthorizationParams.from_request(request)
|
OAuthAuthorizationParams.from_request(request)
|
||||||
|
self.assertEqual(cm.exception.cause, "redirect_uri_no_match")
|
||||||
|
|
||||||
def test_blocked_redirect_uri(self):
|
def test_blocked_redirect_uri(self):
|
||||||
"""test missing/invalid redirect URI"""
|
"""test missing/invalid redirect URI"""
|
||||||
@ -108,9 +121,9 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
name=generate_id(),
|
name=generate_id(),
|
||||||
client_id="test",
|
client_id="test",
|
||||||
authorization_flow=create_test_flow(),
|
authorization_flow=create_test_flow(),
|
||||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "data:local.invalid")],
|
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "data:localhost")],
|
||||||
)
|
)
|
||||||
with self.assertRaises(RedirectUriError):
|
with self.assertRaises(RedirectUriError) as cm:
|
||||||
request = self.factory.get(
|
request = self.factory.get(
|
||||||
"/",
|
"/",
|
||||||
data={
|
data={
|
||||||
@ -120,6 +133,7 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
OAuthAuthorizationParams.from_request(request)
|
OAuthAuthorizationParams.from_request(request)
|
||||||
|
self.assertEqual(cm.exception.cause, "redirect_uri_forbidden_scheme")
|
||||||
|
|
||||||
def test_invalid_redirect_uri_empty(self):
|
def test_invalid_redirect_uri_empty(self):
|
||||||
"""test missing/invalid redirect URI"""
|
"""test missing/invalid redirect URI"""
|
||||||
@ -129,9 +143,6 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
authorization_flow=create_test_flow(),
|
authorization_flow=create_test_flow(),
|
||||||
redirect_uris=[],
|
redirect_uris=[],
|
||||||
)
|
)
|
||||||
with self.assertRaises(RedirectUriError):
|
|
||||||
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
|
|
||||||
OAuthAuthorizationParams.from_request(request)
|
|
||||||
request = self.factory.get(
|
request = self.factory.get(
|
||||||
"/",
|
"/",
|
||||||
data={
|
data={
|
||||||
@ -150,12 +161,9 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
name=generate_id(),
|
name=generate_id(),
|
||||||
client_id="test",
|
client_id="test",
|
||||||
authorization_flow=create_test_flow(),
|
authorization_flow=create_test_flow(),
|
||||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid?")],
|
redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, "http://local.invalid?")],
|
||||||
)
|
)
|
||||||
with self.assertRaises(RedirectUriError):
|
with self.assertRaises(RedirectUriError) as cm:
|
||||||
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
|
|
||||||
OAuthAuthorizationParams.from_request(request)
|
|
||||||
with self.assertRaises(RedirectUriError):
|
|
||||||
request = self.factory.get(
|
request = self.factory.get(
|
||||||
"/",
|
"/",
|
||||||
data={
|
data={
|
||||||
@ -165,6 +173,7 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
OAuthAuthorizationParams.from_request(request)
|
OAuthAuthorizationParams.from_request(request)
|
||||||
|
self.assertEqual(cm.exception.cause, "redirect_uri_no_match")
|
||||||
|
|
||||||
def test_redirect_uri_invalid_regex(self):
|
def test_redirect_uri_invalid_regex(self):
|
||||||
"""test missing/invalid redirect URI (invalid regex)"""
|
"""test missing/invalid redirect URI (invalid regex)"""
|
||||||
@ -172,12 +181,9 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
name=generate_id(),
|
name=generate_id(),
|
||||||
client_id="test",
|
client_id="test",
|
||||||
authorization_flow=create_test_flow(),
|
authorization_flow=create_test_flow(),
|
||||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "+")],
|
redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, "+")],
|
||||||
)
|
)
|
||||||
with self.assertRaises(RedirectUriError):
|
with self.assertRaises(RedirectUriError) as cm:
|
||||||
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
|
|
||||||
OAuthAuthorizationParams.from_request(request)
|
|
||||||
with self.assertRaises(RedirectUriError):
|
|
||||||
request = self.factory.get(
|
request = self.factory.get(
|
||||||
"/",
|
"/",
|
||||||
data={
|
data={
|
||||||
@ -187,23 +193,22 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
OAuthAuthorizationParams.from_request(request)
|
OAuthAuthorizationParams.from_request(request)
|
||||||
|
self.assertEqual(cm.exception.cause, "redirect_uri_no_match")
|
||||||
|
|
||||||
def test_empty_redirect_uri(self):
|
def test_redirect_uri_regex(self):
|
||||||
"""test empty redirect URI (configure in provider)"""
|
"""test valid redirect URI (regex)"""
|
||||||
OAuth2Provider.objects.create(
|
OAuth2Provider.objects.create(
|
||||||
name=generate_id(),
|
name=generate_id(),
|
||||||
client_id="test",
|
client_id="test",
|
||||||
authorization_flow=create_test_flow(),
|
authorization_flow=create_test_flow(),
|
||||||
|
redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, ".+")],
|
||||||
)
|
)
|
||||||
with self.assertRaises(RedirectUriError):
|
|
||||||
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
|
|
||||||
OAuthAuthorizationParams.from_request(request)
|
|
||||||
request = self.factory.get(
|
request = self.factory.get(
|
||||||
"/",
|
"/",
|
||||||
data={
|
data={
|
||||||
"response_type": "code",
|
"response_type": "code",
|
||||||
"client_id": "test",
|
"client_id": "test",
|
||||||
"redirect_uri": "http://localhost",
|
"redirect_uri": "http://foo.bar.baz",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
OAuthAuthorizationParams.from_request(request)
|
OAuthAuthorizationParams.from_request(request)
|
||||||
@ -258,7 +263,7 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
GrantTypes.IMPLICIT,
|
GrantTypes.IMPLICIT,
|
||||||
)
|
)
|
||||||
# Implicit without openid scope
|
# Implicit without openid scope
|
||||||
with self.assertRaises(AuthorizeError):
|
with self.assertRaises(AuthorizeError) as cm:
|
||||||
request = self.factory.get(
|
request = self.factory.get(
|
||||||
"/",
|
"/",
|
||||||
data={
|
data={
|
||||||
@ -285,7 +290,7 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
OAuthAuthorizationParams.from_request(request).grant_type, GrantTypes.HYBRID
|
OAuthAuthorizationParams.from_request(request).grant_type, GrantTypes.HYBRID
|
||||||
)
|
)
|
||||||
with self.assertRaises(AuthorizeError):
|
with self.assertRaises(AuthorizeError) as cm:
|
||||||
request = self.factory.get(
|
request = self.factory.get(
|
||||||
"/",
|
"/",
|
||||||
data={
|
data={
|
||||||
@ -295,6 +300,7 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
OAuthAuthorizationParams.from_request(request)
|
OAuthAuthorizationParams.from_request(request)
|
||||||
|
self.assertEqual(cm.exception.error, "unsupported_response_type")
|
||||||
|
|
||||||
def test_full_code(self):
|
def test_full_code(self):
|
||||||
"""Test full authorization"""
|
"""Test full authorization"""
|
||||||
@ -615,3 +621,54 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_openid_missing_invalid(self):
|
||||||
|
"""test request requiring an OpenID scope to be set"""
|
||||||
|
OAuth2Provider.objects.create(
|
||||||
|
name=generate_id(),
|
||||||
|
client_id="test",
|
||||||
|
authorization_flow=create_test_flow(),
|
||||||
|
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
|
||||||
|
)
|
||||||
|
request = self.factory.get(
|
||||||
|
"/",
|
||||||
|
data={
|
||||||
|
"response_type": "id_token",
|
||||||
|
"client_id": "test",
|
||||||
|
"redirect_uri": "http://localhost",
|
||||||
|
"scope": "",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
with self.assertRaises(AuthorizeError) as cm:
|
||||||
|
OAuthAuthorizationParams.from_request(request)
|
||||||
|
self.assertEqual(cm.exception.cause, "scope_openid_missing")
|
||||||
|
|
||||||
|
@apply_blueprint("system/providers-oauth2.yaml")
|
||||||
|
def test_offline_access_invalid(self):
|
||||||
|
"""test request for offline_access with invalid response type"""
|
||||||
|
provider = OAuth2Provider.objects.create(
|
||||||
|
name=generate_id(),
|
||||||
|
client_id="test",
|
||||||
|
authorization_flow=create_test_flow(),
|
||||||
|
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
|
||||||
|
)
|
||||||
|
provider.property_mappings.set(
|
||||||
|
ScopeMapping.objects.filter(
|
||||||
|
managed__in=[
|
||||||
|
"goauthentik.io/providers/oauth2/scope-openid",
|
||||||
|
"goauthentik.io/providers/oauth2/scope-offline_access",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
request = self.factory.get(
|
||||||
|
"/",
|
||||||
|
data={
|
||||||
|
"response_type": "id_token",
|
||||||
|
"client_id": "test",
|
||||||
|
"redirect_uri": "http://localhost",
|
||||||
|
"scope": f"{SCOPE_OPENID} {SCOPE_OFFLINE_ACCESS}",
|
||||||
|
"nonce": generate_id(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
parsed = OAuthAuthorizationParams.from_request(request)
|
||||||
|
self.assertNotIn(SCOPE_OFFLINE_ACCESS, parsed.scope)
|
||||||
|
@ -190,7 +190,7 @@ class OAuthAuthorizationParams:
|
|||||||
allowed_redirect_urls = self.provider.redirect_uris
|
allowed_redirect_urls = self.provider.redirect_uris
|
||||||
if not self.redirect_uri:
|
if not self.redirect_uri:
|
||||||
LOGGER.warning("Missing redirect uri.")
|
LOGGER.warning("Missing redirect uri.")
|
||||||
raise RedirectUriError("", allowed_redirect_urls)
|
raise RedirectUriError("", allowed_redirect_urls).with_cause("redirect_uri_missing")
|
||||||
|
|
||||||
if len(allowed_redirect_urls) < 1:
|
if len(allowed_redirect_urls) < 1:
|
||||||
LOGGER.info("Setting redirect for blank redirect_uris", redirect=self.redirect_uri)
|
LOGGER.info("Setting redirect for blank redirect_uris", redirect=self.redirect_uri)
|
||||||
@ -219,10 +219,14 @@ class OAuthAuthorizationParams:
|
|||||||
provider=self.provider,
|
provider=self.provider,
|
||||||
)
|
)
|
||||||
if not match_found:
|
if not match_found:
|
||||||
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls)
|
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls).with_cause(
|
||||||
|
"redirect_uri_no_match"
|
||||||
|
)
|
||||||
# Check against forbidden schemes
|
# Check against forbidden schemes
|
||||||
if urlparse(self.redirect_uri).scheme in FORBIDDEN_URI_SCHEMES:
|
if urlparse(self.redirect_uri).scheme in FORBIDDEN_URI_SCHEMES:
|
||||||
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls)
|
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls).with_cause(
|
||||||
|
"redirect_uri_forbidden_scheme"
|
||||||
|
)
|
||||||
|
|
||||||
def check_scope(self, github_compat=False):
|
def check_scope(self, github_compat=False):
|
||||||
"""Ensure openid scope is set in Hybrid flows, or when requesting an id_token"""
|
"""Ensure openid scope is set in Hybrid flows, or when requesting an id_token"""
|
||||||
@ -251,7 +255,9 @@ class OAuthAuthorizationParams:
|
|||||||
or self.response_type in [ResponseTypes.ID_TOKEN, ResponseTypes.ID_TOKEN_TOKEN]
|
or self.response_type in [ResponseTypes.ID_TOKEN, ResponseTypes.ID_TOKEN_TOKEN]
|
||||||
):
|
):
|
||||||
LOGGER.warning("Missing 'openid' scope.")
|
LOGGER.warning("Missing 'openid' scope.")
|
||||||
raise AuthorizeError(self.redirect_uri, "invalid_scope", self.grant_type, self.state)
|
raise AuthorizeError(
|
||||||
|
self.redirect_uri, "invalid_scope", self.grant_type, self.state
|
||||||
|
).with_cause("scope_openid_missing")
|
||||||
if SCOPE_OFFLINE_ACCESS in self.scope:
|
if SCOPE_OFFLINE_ACCESS in self.scope:
|
||||||
# https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess
|
# https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess
|
||||||
# Don't explicitly request consent with offline_access, as the spec allows for
|
# Don't explicitly request consent with offline_access, as the spec allows for
|
||||||
@ -286,7 +292,9 @@ class OAuthAuthorizationParams:
|
|||||||
return
|
return
|
||||||
if not self.nonce:
|
if not self.nonce:
|
||||||
LOGGER.warning("Missing nonce for OpenID Request")
|
LOGGER.warning("Missing nonce for OpenID Request")
|
||||||
raise AuthorizeError(self.redirect_uri, "invalid_request", self.grant_type, self.state)
|
raise AuthorizeError(
|
||||||
|
self.redirect_uri, "invalid_request", self.grant_type, self.state
|
||||||
|
).with_cause("none_missing")
|
||||||
|
|
||||||
def check_code_challenge(self):
|
def check_code_challenge(self):
|
||||||
"""PKCE validation of the transformation method."""
|
"""PKCE validation of the transformation method."""
|
||||||
@ -345,10 +353,10 @@ class AuthorizationFlowInitView(PolicyAccessView):
|
|||||||
self.request, github_compat=self.github_compat
|
self.request, github_compat=self.github_compat
|
||||||
)
|
)
|
||||||
except AuthorizeError as error:
|
except AuthorizeError as error:
|
||||||
LOGGER.warning(error.description, redirect_uri=error.redirect_uri)
|
LOGGER.warning(error.description, redirect_uri=error.redirect_uri, cause=error.cause)
|
||||||
raise RequestValidationError(error.get_response(self.request)) from None
|
raise RequestValidationError(error.get_response(self.request)) from None
|
||||||
except OAuth2Error as error:
|
except OAuth2Error as error:
|
||||||
LOGGER.warning(error.description)
|
LOGGER.warning(error.description, cause=error.cause)
|
||||||
raise RequestValidationError(
|
raise RequestValidationError(
|
||||||
bad_request_message(self.request, error.description, title=error.error)
|
bad_request_message(self.request, error.description, title=error.error)
|
||||||
) from None
|
) from None
|
||||||
|
Reference in New Issue
Block a user