providers/oauth2: offline access (#8026)

* improve scope check (log when application requests non-configured scopes)

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add offline_access special scope

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* ensure scope is set

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* update tests for refresh tokens

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* special handling of scopes for github compat

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix spec

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* attempt to fix oidc tests

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* remove hardcoded slug

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* check scope from authorization code instead of request

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix injection for consent stage checking incorrectly

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L
2024-01-04 19:57:11 +01:00
committed by GitHub
parent 1b36cb8331
commit 509b502d3c
15 changed files with 369 additions and 171 deletions

View File

@ -7,8 +7,8 @@ GRANT_TYPE_CLIENT_CREDENTIALS = "client_credentials"
GRANT_TYPE_PASSWORD = "password" # nosec
GRANT_TYPE_DEVICE_CODE = "urn:ietf:params:oauth:grant-type:device_code"
CLIENT_ASSERTION_TYPE = "client_assertion_type"
CLIENT_ASSERTION = "client_assertion"
CLIENT_ASSERTION_TYPE = "client_assertion_type"
CLIENT_ASSERTION_TYPE_JWT = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
PROMPT_NONE = "none"
@ -18,9 +18,9 @@ PROMPT_LOGIN = "login"
SCOPE_OPENID = "openid"
SCOPE_OPENID_PROFILE = "profile"
SCOPE_OPENID_EMAIL = "email"
SCOPE_OFFLINE_ACCESS = "offline_access"
# https://www.iana.org/assignments/oauth-parameters/\
# oauth-parameters.xhtml#pkce-code-challenge-method
# https://www.iana.org/assignments/oauth-parameters/auth-parameters.xhtml#pkce-code-challenge-method
PKCE_METHOD_PLAIN = "plain"
PKCE_METHOD_S256 = "S256"
@ -36,6 +36,12 @@ SCOPE_GITHUB_USER_READ = "read:user"
SCOPE_GITHUB_USER_EMAIL = "user:email"
# Read info about teams
SCOPE_GITHUB_ORG_READ = "read:org"
SCOPE_GITHUB = {
SCOPE_GITHUB_USER,
SCOPE_GITHUB_USER_READ,
SCOPE_GITHUB_USER_EMAIL,
SCOPE_GITHUB_ORG_READ,
}
ACR_AUTHENTIK_DEFAULT = "goauthentik.io/providers/oauth2/default"

View File

@ -127,7 +127,7 @@ class AuthorizeError(OAuth2Error):
"account_selection_required": (
"The End-User is required to select a session at the Authorization Server"
),
"consent_required": "The Authorization Server requires End-Userconsent",
"consent_required": "The Authorization Server requires End-User consent",
"invalid_request_uri": (
"The request_uri in the Authorization Request returns an error or contains invalid data"
),

View File

@ -5,6 +5,7 @@ from django.test import RequestFactory
from django.urls import reverse
from django.utils.timezone import now
from authentik.blueprints.tests import apply_blueprint
from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.events.models import Event, EventAction
@ -18,6 +19,7 @@ from authentik.providers.oauth2.models import (
AuthorizationCode,
GrantTypes,
OAuth2Provider,
ScopeMapping,
)
from authentik.providers.oauth2.tests.utils import OAuthTestCase
from authentik.providers.oauth2.views.authorize import OAuthAuthorizationParams
@ -172,14 +174,24 @@ class TestAuthorize(OAuthTestCase):
)
OAuthAuthorizationParams.from_request(request)
@apply_blueprint("system/providers-oauth2.yaml")
def test_response_type(self):
"""test response_type"""
OAuth2Provider.objects.create(
provider = OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris="http://local.invalid/Foo",
)
provider.property_mappings.set(
ScopeMapping.objects.filter(
managed__in=[
"goauthentik.io/providers/oauth2/scope-openid",
"goauthentik.io/providers/oauth2/scope-email",
"goauthentik.io/providers/oauth2/scope-profile",
]
)
)
request = self.factory.get(
"/",
data={
@ -292,6 +304,7 @@ class TestAuthorize(OAuthTestCase):
delta=5,
)
@apply_blueprint("system/providers-oauth2.yaml")
def test_full_implicit(self):
"""Test full authorization"""
flow = create_test_flow()
@ -302,6 +315,15 @@ class TestAuthorize(OAuthTestCase):
redirect_uris="http://localhost",
signing_key=self.keypair,
)
provider.property_mappings.set(
ScopeMapping.objects.filter(
managed__in=[
"goauthentik.io/providers/oauth2/scope-openid",
"goauthentik.io/providers/oauth2/scope-email",
"goauthentik.io/providers/oauth2/scope-profile",
]
)
)
Application.objects.create(name="app", slug="app", provider=provider)
state = generate_id()
user = create_test_admin_user()
@ -409,6 +431,7 @@ class TestAuthorize(OAuthTestCase):
delta=5,
)
@apply_blueprint("system/providers-oauth2.yaml")
def test_full_form_post_id_token(self):
"""Test full authorization (form_post response)"""
flow = create_test_flow()
@ -419,6 +442,15 @@ class TestAuthorize(OAuthTestCase):
redirect_uris="http://localhost",
signing_key=self.keypair,
)
provider.property_mappings.set(
ScopeMapping.objects.filter(
managed__in=[
"goauthentik.io/providers/oauth2/scope-openid",
"goauthentik.io/providers/oauth2/scope-email",
"goauthentik.io/providers/oauth2/scope-profile",
]
)
)
app = Application.objects.create(name=generate_id(), slug=generate_id(), provider=provider)
state = generate_id()
user = create_test_admin_user()
@ -440,6 +472,7 @@ class TestAuthorize(OAuthTestCase):
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
)
token: AccessToken = AccessToken.objects.filter(user=user).first()
self.assertIsNotNone(token)
self.assertJSONEqual(
response.content.decode(),
{

View File

@ -6,6 +6,7 @@ from django.test import RequestFactory
from django.urls import reverse
from django.utils import timezone
from authentik.blueprints.tests import apply_blueprint
from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.events.models import Event, EventAction
@ -21,6 +22,7 @@ from authentik.providers.oauth2.models import (
AuthorizationCode,
OAuth2Provider,
RefreshToken,
ScopeMapping,
)
from authentik.providers.oauth2.tests.utils import OAuthTestCase
from authentik.providers.oauth2.views.token import TokenParams
@ -136,21 +138,20 @@ class TestToken(OAuthTestCase):
HTTP_AUTHORIZATION=f"Basic {header}",
)
access: AccessToken = AccessToken.objects.filter(user=user, provider=provider).first()
refresh: RefreshToken = RefreshToken.objects.filter(user=user, provider=provider).first()
self.assertJSONEqual(
response.content.decode(),
{
"access_token": access.token,
"refresh_token": refresh.token,
"token_type": TOKEN_TYPE,
"expires_in": 3600,
"id_token": provider.encode(
refresh.id_token.to_dict(),
access.id_token.to_dict(),
),
},
)
self.validate_jwt(access, provider)
@apply_blueprint("system/providers-oauth2.yaml")
def test_refresh_token_view(self):
"""test request param"""
provider = OAuth2Provider.objects.create(
@ -159,6 +160,16 @@ class TestToken(OAuthTestCase):
redirect_uris="http://local.invalid",
signing_key=self.keypair,
)
provider.property_mappings.set(
ScopeMapping.objects.filter(
managed__in=[
"goauthentik.io/providers/oauth2/scope-openid",
"goauthentik.io/providers/oauth2/scope-email",
"goauthentik.io/providers/oauth2/scope-profile",
"goauthentik.io/providers/oauth2/scope-offline_access",
]
)
)
# Needs to be assigned to an application for iss to be set
self.app.provider = provider
self.app.save()
@ -170,6 +181,7 @@ class TestToken(OAuthTestCase):
token=generate_id(),
_id_token=dumps({}),
auth_time=timezone.now(),
_scope="offline_access",
)
response = self.client.post(
reverse("authentik_providers_oauth2:token"),
@ -201,6 +213,7 @@ class TestToken(OAuthTestCase):
)
self.validate_jwt(access, provider)
@apply_blueprint("system/providers-oauth2.yaml")
def test_refresh_token_view_invalid_origin(self):
"""test request param"""
provider = OAuth2Provider.objects.create(
@ -209,6 +222,16 @@ class TestToken(OAuthTestCase):
redirect_uris="http://local.invalid",
signing_key=self.keypair,
)
provider.property_mappings.set(
ScopeMapping.objects.filter(
managed__in=[
"goauthentik.io/providers/oauth2/scope-openid",
"goauthentik.io/providers/oauth2/scope-email",
"goauthentik.io/providers/oauth2/scope-profile",
"goauthentik.io/providers/oauth2/scope-offline_access",
]
)
)
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
user = create_test_admin_user()
token: RefreshToken = RefreshToken.objects.create(
@ -217,6 +240,7 @@ class TestToken(OAuthTestCase):
token=generate_id(),
_id_token=dumps({}),
auth_time=timezone.now(),
_scope="offline_access",
)
response = self.client.post(
reverse("authentik_providers_oauth2:token"),
@ -247,6 +271,7 @@ class TestToken(OAuthTestCase):
},
)
@apply_blueprint("system/providers-oauth2.yaml")
def test_refresh_token_revoke(self):
"""test request param"""
provider = OAuth2Provider.objects.create(
@ -255,6 +280,16 @@ class TestToken(OAuthTestCase):
redirect_uris="http://testserver",
signing_key=self.keypair,
)
provider.property_mappings.set(
ScopeMapping.objects.filter(
managed__in=[
"goauthentik.io/providers/oauth2/scope-openid",
"goauthentik.io/providers/oauth2/scope-email",
"goauthentik.io/providers/oauth2/scope-profile",
"goauthentik.io/providers/oauth2/scope-offline_access",
]
)
)
# Needs to be assigned to an application for iss to be set
self.app.provider = provider
self.app.save()
@ -266,6 +301,7 @@ class TestToken(OAuthTestCase):
token=generate_id(),
_id_token=dumps({}),
auth_time=timezone.now(),
_scope="offline_access",
)
# Create initial refresh token
response = self.client.post(

View File

@ -10,7 +10,7 @@ from authentik.providers.oauth2.views.token import TokenView
github_urlpatterns = [
path(
"login/oauth/authorize",
AuthorizationFlowInitView.as_view(),
AuthorizationFlowInitView.as_view(github_compat=True),
name="github-authorize",
),
path(

View File

@ -1,5 +1,5 @@
"""authentik OAuth2 Authorization views"""
from dataclasses import dataclass, field
from dataclasses import InitVar, dataclass, field
from datetime import timedelta
from hashlib import sha256
from json import dumps
@ -41,6 +41,8 @@ from authentik.providers.oauth2.constants import (
PROMPT_CONSENT,
PROMPT_LOGIN,
PROMPT_NONE,
SCOPE_GITHUB,
SCOPE_OFFLINE_ACCESS,
SCOPE_OPENID,
TOKEN_TYPE,
)
@ -66,7 +68,6 @@ from authentik.stages.consent.models import ConsentMode, ConsentStage
from authentik.stages.consent.stage import (
PLAN_CONTEXT_CONSENT_HEADER,
PLAN_CONTEXT_CONSENT_PERMISSIONS,
ConsentStageView,
)
LOGGER = get_logger()
@ -86,7 +87,7 @@ class OAuthAuthorizationParams:
redirect_uri: str
response_type: str
response_mode: Optional[str]
scope: list[str]
scope: set[str]
state: str
nonce: Optional[str]
prompt: set[str]
@ -101,8 +102,10 @@ class OAuthAuthorizationParams:
code_challenge: Optional[str] = None
code_challenge_method: Optional[str] = None
github_compat: InitVar[bool] = False
@staticmethod
def from_request(request: HttpRequest) -> "OAuthAuthorizationParams":
def from_request(request: HttpRequest, github_compat=False) -> "OAuthAuthorizationParams":
"""
Get all the params used by the Authorization Code Flow
(and also for the Implicit and Hybrid).
@ -154,7 +157,7 @@ class OAuthAuthorizationParams:
response_type=response_type,
response_mode=response_mode,
grant_type=grant_type,
scope=query_dict.get("scope", "").split(),
scope=set(query_dict.get("scope", "").split()),
state=state,
nonce=query_dict.get("nonce"),
prompt=ALLOWED_PROMPT_PARAMS.intersection(set(query_dict.get("prompt", "").split())),
@ -162,9 +165,10 @@ class OAuthAuthorizationParams:
max_age=int(max_age) if max_age else None,
code_challenge=query_dict.get("code_challenge"),
code_challenge_method=query_dict.get("code_challenge_method", "plain"),
github_compat=github_compat,
)
def __post_init__(self):
def __post_init__(self, github_compat=False):
self.provider: OAuth2Provider = OAuth2Provider.objects.filter(
client_id=self.client_id
).first()
@ -172,7 +176,7 @@ class OAuthAuthorizationParams:
LOGGER.warning("Invalid client identifier", client_id=self.client_id)
raise ClientIdError(client_id=self.client_id)
self.check_redirect_uri()
self.check_scope()
self.check_scope(github_compat)
self.check_nonce()
self.check_code_challenge()
@ -199,8 +203,8 @@ class OAuthAuthorizationParams:
if not any(fullmatch(x, self.redirect_uri) for x in allowed_redirect_urls):
LOGGER.warning(
"Invalid redirect uri (regex comparison)",
redirect_uri=self.redirect_uri,
expected=allowed_redirect_urls,
redirect_uri_given=self.redirect_uri,
redirect_uri_expected=allowed_redirect_urls,
)
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls)
except RegexError as exc:
@ -208,8 +212,8 @@ class OAuthAuthorizationParams:
if not any(x == self.redirect_uri for x in allowed_redirect_urls):
LOGGER.warning(
"Invalid redirect uri (strict comparison)",
redirect_uri=self.redirect_uri,
expected=allowed_redirect_urls,
redirect_uri_given=self.redirect_uri,
redirect_uri_expected=allowed_redirect_urls,
)
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls)
if self.request:
@ -217,24 +221,50 @@ class OAuthAuthorizationParams:
self.redirect_uri, "request_not_supported", self.grant_type, self.state
)
def check_scope(self):
def check_scope(self, github_compat=False):
"""Ensure openid scope is set in Hybrid flows, or when requesting an id_token"""
if len(self.scope) == 0:
default_scope_names = set(
ScopeMapping.objects.filter(provider__in=[self.provider]).values_list(
"scope_name", flat=True
)
default_scope_names = set(
ScopeMapping.objects.filter(provider__in=[self.provider]).values_list(
"scope_name", flat=True
)
)
if len(self.scope) == 0:
self.scope = default_scope_names
LOGGER.info(
"No scopes requested, defaulting to all configured scopes", scopes=self.scope
)
scopes_to_check = self.scope
if github_compat:
scopes_to_check = self.scope - SCOPE_GITHUB
if not scopes_to_check.issubset(default_scope_names):
LOGGER.info(
"Application requested scopes not configured, setting to overlap",
scope_allowed=default_scope_names,
scope_given=self.scope,
)
self.scope = self.scope.intersection(default_scope_names)
if SCOPE_OPENID not in self.scope and (
self.grant_type == GrantTypes.HYBRID
or self.response_type in [ResponseTypes.ID_TOKEN, ResponseTypes.ID_TOKEN_TOKEN]
):
LOGGER.warning("Missing 'openid' scope.")
raise AuthorizeError(self.redirect_uri, "invalid_scope", self.grant_type, self.state)
if SCOPE_OFFLINE_ACCESS in self.scope:
# https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess
if PROMPT_CONSENT not in self.prompt:
raise AuthorizeError(
self.redirect_uri, "consent_required", self.grant_type, self.state
)
if self.response_type not in [
ResponseTypes.CODE,
ResponseTypes.CODE_TOKEN,
ResponseTypes.CODE_ID_TOKEN,
ResponseTypes.CODE_ID_TOKEN_TOKEN,
]:
# offline_access requires a response type that has some sort of token
# Spec says to ignore the scope when the response_type wouldn't result
# in an authorization code being generated
self.scope.remove(SCOPE_OFFLINE_ACCESS)
def check_nonce(self):
"""Nonce parameter validation."""
@ -297,6 +327,9 @@ class AuthorizationFlowInitView(PolicyAccessView):
"""OAuth2 Flow initializer, checks access to application and starts flow"""
params: OAuthAuthorizationParams
# Enable GitHub compatibility (only allow for scopes which are handled
# differently for github compat)
github_compat = False
def pre_permission_check(self):
"""Check prompt parameter before checking permission/authentication,
@ -305,7 +338,9 @@ class AuthorizationFlowInitView(PolicyAccessView):
if len(self.request.GET) < 1:
raise Http404
try:
self.params = OAuthAuthorizationParams.from_request(self.request)
self.params = OAuthAuthorizationParams.from_request(
self.request, github_compat=self.github_compat
)
except AuthorizeError as error:
LOGGER.warning(error.description, redirect_uri=error.redirect_uri)
raise RequestValidationError(error.get_response(self.request))
@ -402,7 +437,7 @@ class AuthorizationFlowInitView(PolicyAccessView):
# OpenID clients can specify a `prompt` parameter, and if its set to consent we
# need to inject a consent stage
if PROMPT_CONSENT in self.params.prompt:
if not any(isinstance(x.stage, ConsentStageView) for x in plan.bindings):
if not any(isinstance(x.stage, ConsentStage) for x in plan.bindings):
# Plan does not have any consent stage, so we add an in-memory one
stage = ConsentStage(
name="OAuth2 Provider In-memory consent stage",

View File

@ -41,6 +41,7 @@ from authentik.providers.oauth2.constants import (
GRANT_TYPE_PASSWORD,
GRANT_TYPE_REFRESH_TOKEN,
PKCE_METHOD_S256,
SCOPE_OFFLINE_ACCESS,
TOKEN_TYPE,
)
from authentik.providers.oauth2.errors import DeviceCodeError, TokenError, UserAuthError
@ -459,7 +460,7 @@ class TokenView(View):
op="authentik.providers.oauth2.post.response",
):
if self.params.grant_type == GRANT_TYPE_AUTHORIZATION_CODE:
LOGGER.debug("Converting authorization code to refresh token")
LOGGER.debug("Converting authorization code to access token")
return TokenResponse(self.create_code_response())
if self.params.grant_type == GRANT_TYPE_REFRESH_TOKEN:
LOGGER.debug("Refreshing refresh token")
@ -496,42 +497,47 @@ class TokenView(View):
)
access_token.save()
refresh_token_expiry = now + timedelta_from_string(self.provider.refresh_token_validity)
refresh_token = RefreshToken(
user=self.params.authorization_code.user,
scope=self.params.authorization_code.scope,
expires=refresh_token_expiry,
provider=self.provider,
auth_time=self.params.authorization_code.auth_time,
session_id=self.params.authorization_code.session_id,
)
id_token = IDToken.new(
self.provider,
refresh_token,
self.request,
)
id_token.nonce = self.params.authorization_code.nonce
id_token.at_hash = access_token.at_hash
refresh_token.id_token = id_token
refresh_token.save()
# Delete old code
self.params.authorization_code.delete()
return {
response = {
"access_token": access_token.token,
"refresh_token": refresh_token.token,
"token_type": TOKEN_TYPE,
"expires_in": int(
timedelta_from_string(self.provider.access_token_validity).total_seconds()
),
"id_token": id_token.to_jwt(self.provider),
"id_token": access_token.id_token.to_jwt(self.provider),
}
if SCOPE_OFFLINE_ACCESS in self.params.authorization_code.scope:
refresh_token_expiry = now + timedelta_from_string(self.provider.refresh_token_validity)
refresh_token = RefreshToken(
user=self.params.authorization_code.user,
scope=self.params.authorization_code.scope,
expires=refresh_token_expiry,
provider=self.provider,
auth_time=self.params.authorization_code.auth_time,
session_id=self.params.authorization_code.session_id,
)
id_token = IDToken.new(
self.provider,
refresh_token,
self.request,
)
id_token.nonce = self.params.authorization_code.nonce
id_token.at_hash = access_token.at_hash
refresh_token.id_token = id_token
refresh_token.save()
response["refresh_token"] = refresh_token.token
# Delete old code
self.params.authorization_code.delete()
return response
def create_refresh_response(self) -> dict[str, Any]:
"""See https://datatracker.ietf.org/doc/html/rfc6749#section-6"""
unauthorized_scopes = set(self.params.scope) - set(self.params.refresh_token.scope)
if unauthorized_scopes:
raise TokenError("invalid_scope")
if SCOPE_OFFLINE_ACCESS not in self.params.scope:
raise TokenError("invalid_scope")
now = timezone.now()
access_token_expiry = now + timedelta_from_string(self.provider.access_token_validity)
access_token = AccessToken(
@ -630,31 +636,34 @@ class TokenView(View):
)
access_token.save()
refresh_token_expiry = now + timedelta_from_string(self.provider.refresh_token_validity)
refresh_token = RefreshToken(
user=self.params.device_code.user,
scope=self.params.device_code.scope,
expires=refresh_token_expiry,
provider=self.provider,
auth_time=auth_event.created if auth_event else now,
)
id_token = IDToken.new(
self.provider,
refresh_token,
self.request,
)
id_token.at_hash = access_token.at_hash
refresh_token.id_token = id_token
refresh_token.save()
# Delete device code
self.params.device_code.delete()
return {
response = {
"access_token": access_token.token,
"refresh_token": refresh_token.token,
"token_type": TOKEN_TYPE,
"expires_in": int(
timedelta_from_string(self.provider.access_token_validity).total_seconds()
),
"id_token": id_token.to_jwt(self.provider),
"id_token": access_token.id_token.to_jwt(self.provider),
}
if SCOPE_OFFLINE_ACCESS in self.params.scope:
refresh_token_expiry = now + timedelta_from_string(self.provider.refresh_token_validity)
refresh_token = RefreshToken(
user=self.params.device_code.user,
scope=self.params.device_code.scope,
expires=refresh_token_expiry,
provider=self.provider,
auth_time=auth_event.created if auth_event else now,
)
id_token = IDToken.new(
self.provider,
refresh_token,
self.request,
)
id_token.at_hash = access_token.at_hash
refresh_token.id_token = id_token
refresh_token.save()
response["refresh_token"] = refresh_token.token
# Delete device code
self.params.device_code.delete()
return response