providers/oauth2: rewrite introspection endpoint to allow basic or bearer auth
This commit is contained in:
		| @ -4,14 +4,14 @@ from time import sleep | ||||
| from typing import Any, Dict, Optional | ||||
| from unittest.case import skipUnless | ||||
|  | ||||
| from channels.testing import ChannelsLiveServerTestCase | ||||
| from docker.client import DockerClient, from_env | ||||
| from docker.models.containers import Container | ||||
| from selenium.webdriver.common.by import By | ||||
| from selenium.webdriver.common.keys import Keys | ||||
| from channels.testing import ChannelsLiveServerTestCase | ||||
|  | ||||
| from passbook import __version__ | ||||
| from e2e.utils import USER, SeleniumTestCase | ||||
| from passbook import __version__ | ||||
| from passbook.core.models import Application | ||||
| from passbook.flows.models import Flow | ||||
| from passbook.outposts.models import Outpost, OutpostDeploymentType, OutpostType | ||||
| @ -124,6 +124,7 @@ class TestProviderProxyConnect(ChannelsLiveServerTestCase): | ||||
|         return container | ||||
|  | ||||
|     def test_proxy_connectivity(self): | ||||
|         """Test proxy connectivity over websocket""" | ||||
|         SeleniumTestCase().apply_default_data() | ||||
|         proxy: ProxyProvider = ProxyProvider.objects.create( | ||||
|             name="proxy_provider", | ||||
|  | ||||
| @ -7,7 +7,6 @@ PROMPT_CONSNET = "consent" | ||||
| SCOPE_OPENID = "openid" | ||||
| SCOPE_OPENID_PROFILE = "profile" | ||||
| SCOPE_OPENID_EMAIL = "email" | ||||
| SCOPE_OPENID_INTROSPECTION = "token_introspection" | ||||
|  | ||||
| # Read/write full user (including email) | ||||
| SCOPE_GITHUB_USER = "user" | ||||
|  | ||||
| @ -202,11 +202,6 @@ class OAuth2Provider(Provider): | ||||
|         ), | ||||
|     ) | ||||
|  | ||||
|     @property | ||||
|     def scope_names(self) -> List[str]: | ||||
|         """Return list of assigned scopes seperated with a space""" | ||||
|         return [pm.scope_name for pm in self.property_mappings.all()] | ||||
|  | ||||
|     def create_refresh_token( | ||||
|         self, user: User, scope: List[str], id_token: Optional["IDToken"] = None | ||||
|     ) -> "RefreshToken": | ||||
|  | ||||
| @ -2,7 +2,7 @@ | ||||
| import re | ||||
| from base64 import b64decode | ||||
| from binascii import Error | ||||
| from typing import List, Tuple | ||||
| from typing import List, Optional, Tuple | ||||
|  | ||||
| from django.http import HttpRequest, HttpResponse, JsonResponse | ||||
| from django.utils.cache import patch_vary_headers | ||||
| @ -50,7 +50,7 @@ def cors_allow_any(request, response): | ||||
|     return response | ||||
|  | ||||
|  | ||||
| def extract_access_token(request: HttpRequest) -> str: | ||||
| def extract_access_token(request: HttpRequest) -> Optional[str]: | ||||
|     """ | ||||
|     Get the access token using Authorization Request Header Field method. | ||||
|     Or try getting via GET. | ||||
| @ -66,7 +66,7 @@ def extract_access_token(request: HttpRequest) -> str: | ||||
|         return request.POST.get("access_token") | ||||
|     if "access_token" in request.GET: | ||||
|         return request.GET.get("access_token") | ||||
|     return "" | ||||
|     return None | ||||
|  | ||||
|  | ||||
| def extract_client_auth(request: HttpRequest) -> Tuple[str, str]: | ||||
| @ -103,9 +103,12 @@ def protected_resource_view(scopes: List[str]): | ||||
|  | ||||
|     def wrapper(view): | ||||
|         def view_wrapper(request, *args, **kwargs): | ||||
|             access_token = extract_access_token(request) | ||||
|  | ||||
|             try: | ||||
|                 access_token = extract_access_token(request) | ||||
|                 if not access_token: | ||||
|                     LOGGER.debug("No token passed") | ||||
|                     raise BearerTokenError("invalid_token") | ||||
|  | ||||
|                 try: | ||||
|                     kwargs["token"] = RefreshToken.objects.get( | ||||
|                         access_token=access_token | ||||
|  | ||||
| @ -1,15 +1,17 @@ | ||||
| """passbook OAuth2 Token Introspection Views""" | ||||
| from dataclasses import InitVar, dataclass | ||||
| from typing import Optional | ||||
| from dataclasses import dataclass, field | ||||
|  | ||||
| from django.http import HttpRequest, HttpResponse | ||||
| from django.views import View | ||||
| from structlog import get_logger | ||||
|  | ||||
| from passbook.providers.oauth2.constants import SCOPE_OPENID_INTROSPECTION | ||||
| from passbook.providers.oauth2.errors import TokenIntrospectionError | ||||
| from passbook.providers.oauth2.models import IDToken, OAuth2Provider, RefreshToken | ||||
| from passbook.providers.oauth2.utils import TokenResponse, extract_client_auth | ||||
| from passbook.providers.oauth2.utils import ( | ||||
|     TokenResponse, | ||||
|     extract_access_token, | ||||
|     extract_client_auth, | ||||
| ) | ||||
|  | ||||
| LOGGER = get_logger() | ||||
|  | ||||
| @ -18,39 +20,17 @@ LOGGER = get_logger() | ||||
| class TokenIntrospectionParams: | ||||
|     """Parameters for Token Introspection""" | ||||
|  | ||||
|     client_id: str | ||||
|     client_secret: str | ||||
|     token: RefreshToken | ||||
|  | ||||
|     raw_token: InitVar[str] | ||||
|     provider: OAuth2Provider = field(init=False) | ||||
|     id_token: IDToken = field(init=False) | ||||
|  | ||||
|     token: Optional[RefreshToken] = None | ||||
|  | ||||
|     provider: Optional[OAuth2Provider] = None | ||||
|     id_token: Optional[IDToken] = None | ||||
|  | ||||
|     def __post_init__(self, raw_token: str): | ||||
|         try: | ||||
|             self.token = RefreshToken.objects.get(access_token=raw_token) | ||||
|         except RefreshToken.DoesNotExist: | ||||
|             LOGGER.debug("Token does not exist", token=raw_token) | ||||
|             raise TokenIntrospectionError() | ||||
|     def __post_init__(self): | ||||
|         if self.token.is_expired: | ||||
|             LOGGER.debug("Token is not valid", token=raw_token) | ||||
|             raise TokenIntrospectionError() | ||||
|         try: | ||||
|             self.provider = OAuth2Provider.objects.get( | ||||
|                 client_id=self.client_id, client_secret=self.client_secret, | ||||
|             ) | ||||
|         except OAuth2Provider.DoesNotExist: | ||||
|             LOGGER.debug("provider for ID not found", client_id=self.client_id) | ||||
|             raise TokenIntrospectionError() | ||||
|         if SCOPE_OPENID_INTROSPECTION not in self.provider.scope_names: | ||||
|             LOGGER.debug( | ||||
|                 "OAuth2Provider does not have introspection scope", | ||||
|                 client_id=self.client_id, | ||||
|             ) | ||||
|             LOGGER.debug("Token is not valid") | ||||
|             raise TokenIntrospectionError() | ||||
|  | ||||
|         self.provider = self.token.provider | ||||
|         self.id_token = self.token.id_token | ||||
|  | ||||
|         if not self.token.id_token: | ||||
| @ -59,31 +39,61 @@ class TokenIntrospectionParams: | ||||
|             ) | ||||
|             raise TokenIntrospectionError() | ||||
|  | ||||
|         audience = self.token.id_token.aud | ||||
|         if not audience: | ||||
|             LOGGER.debug( | ||||
|                 "No audience found for token", token=self.token, | ||||
|             ) | ||||
|     def authenticate_basic(self, request: HttpRequest) -> bool: | ||||
|         """Attempt to authenticate via Basic auth of client_id:client_secret""" | ||||
|         client_id, client_secret = extract_client_auth(request) | ||||
|         if client_id == client_secret == "": | ||||
|             return False | ||||
|         if ( | ||||
|             client_id != self.provider.client_id | ||||
|             or client_secret != self.provider.client_secret | ||||
|         ): | ||||
|             LOGGER.debug("(basic) Provider for basic auth does not exist") | ||||
|             raise TokenIntrospectionError() | ||||
|         return True | ||||
|  | ||||
|         if audience not in self.provider.scope_names: | ||||
|             LOGGER.debug( | ||||
|                 "provider does not audience scope", | ||||
|                 client_id=self.client_id, | ||||
|                 audience=audience, | ||||
|             ) | ||||
|     def authenticate_bearer(self, request: HttpRequest) -> bool: | ||||
|         """Attempt to authenticate via token sent as bearer header""" | ||||
|         body_token = extract_access_token(request) | ||||
|         if not body_token: | ||||
|             return False | ||||
|         tokens = RefreshToken.objects.filter(access_token=body_token).select_related( | ||||
|             "provider" | ||||
|         ) | ||||
|         if not tokens.exists(): | ||||
|             LOGGER.debug("(bearer) Token does not exist") | ||||
|             raise TokenIntrospectionError() | ||||
|         if tokens.first().provider != self.provider: | ||||
|             LOGGER.debug("(bearer) Token providers don't match") | ||||
|             raise TokenIntrospectionError() | ||||
|         return True | ||||
|  | ||||
|     @staticmethod | ||||
|     def from_request(request: HttpRequest) -> "TokenIntrospectionParams": | ||||
|         """Extract required Parameters from HTTP Request""" | ||||
|         # Introspection only supports POST requests | ||||
|         client_id, client_secret = extract_client_auth(request) | ||||
|         return TokenIntrospectionParams( | ||||
|             raw_token=request.POST.get("token"), | ||||
|             client_id=client_id, | ||||
|             client_secret=client_secret, | ||||
|         ) | ||||
|         raw_token = request.POST.get("token") | ||||
|         token_type_hint = request.POST.get("token_type_hint", "access_token") | ||||
|         token_filter = {token_type_hint: raw_token} | ||||
|  | ||||
|         if token_type_hint not in ["access_token", "refresh_token"]: | ||||
|             LOGGER.debug("token_type_hint has invalid value", value=token_type_hint) | ||||
|             raise TokenIntrospectionError() | ||||
|  | ||||
|         try: | ||||
|             token: RefreshToken = RefreshToken.objects.select_related("provider").get( | ||||
|                 **token_filter | ||||
|             ) | ||||
|         except RefreshToken.DoesNotExist: | ||||
|             LOGGER.debug("Token does not exist", token=raw_token) | ||||
|             raise TokenIntrospectionError() | ||||
|  | ||||
|         params = TokenIntrospectionParams(token=token) | ||||
|         if not any( | ||||
|             [params.authenticate_basic(request), params.authenticate_bearer(request)] | ||||
|         ): | ||||
|             LOGGER.debug("Not authenticated") | ||||
|             raise TokenIntrospectionError() | ||||
|         return params | ||||
|  | ||||
|  | ||||
| class TokenIntrospectionView(View): | ||||
| @ -101,12 +111,12 @@ class TokenIntrospectionView(View): | ||||
|             self.params = TokenIntrospectionParams.from_request(request) | ||||
|  | ||||
|             response_dic = {} | ||||
|             if self.id_token: | ||||
|                 token_dict = self.id_token.to_dict() | ||||
|             if self.params.id_token: | ||||
|                 token_dict = self.params.id_token.to_dict() | ||||
|                 for k in ("aud", "sub", "exp", "iat", "iss"): | ||||
|                     response_dic[k] = token_dict[k] | ||||
|             response_dic["active"] = True | ||||
|             response_dic["client_id"] = self.token.provider.client_id | ||||
|             response_dic["client_id"] = self.params.token.provider.client_id | ||||
|  | ||||
|             return TokenResponse(response_dic) | ||||
|         except TokenIntrospectionError: | ||||
|  | ||||
		Reference in New Issue
	
	Block a user
	 Jens Langhammer
					Jens Langhammer