154 lines
		
	
	
		
			5.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			154 lines
		
	
	
		
			5.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """OAuth2/OpenID Utils"""
 | |
| import re
 | |
| from base64 import b64decode
 | |
| from binascii import Error
 | |
| from typing import List, Tuple
 | |
| 
 | |
| from django.http import HttpRequest, HttpResponse, JsonResponse
 | |
| from django.utils.cache import patch_vary_headers
 | |
| from jwkest.jwt import JWT
 | |
| from structlog import get_logger
 | |
| 
 | |
| from passbook.providers.oauth2.errors import BearerTokenError
 | |
| from passbook.providers.oauth2.models import RefreshToken
 | |
| 
 | |
| LOGGER = get_logger()
 | |
| 
 | |
| 
 | |
class TokenResponse(JsonResponse):
    """JSON response that is always sent with cache-suppressing headers.

    Every instance sets `Cache-Control: no-store` and `Pragma: no-cache`.

    https://openid.net/specs/openid-connect-core-1_0.html#TokenResponse"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Applied after the parent constructor so these values always win.
        no_cache_headers = (
            ("Cache-Control", "no-store"),
            ("Pragma", "no-cache"),
        )
        for header_name, header_value in no_cache_headers:
            self[header_name] = header_value
 | |
| 
 | |
| 
 | |
def cors_allow_any(request, response):
    """
    Decorate `response` with CORS headers that allow the requesting origin,
    with credentials, and (for preflight requests) with whatever request
    headers the client asked for.

    A request without an `Origin` header gets `response` back unchanged.
    The same `response` object is always returned.
    """
    meta = request.META
    requesting_origin = meta.get("HTTP_ORIGIN")

    if requesting_origin:
        # From the CORS spec: the string "*" cannot be used for a resource
        # that supports credentials, so the caller's origin is echoed back.
        response["Access-Control-Allow-Origin"] = requesting_origin
        patch_vary_headers(response, ["Origin"])
        response["Access-Control-Allow-Credentials"] = "true"

        if request.method == "OPTIONS":
            preflight_key = "HTTP_ACCESS_CONTROL_REQUEST_HEADERS"
            if preflight_key in meta:
                response["Access-Control-Allow-Headers"] = meta[preflight_key]
            response["Access-Control-Allow-Methods"] = "GET, POST, OPTIONS"

    return response
 | |
| 
 | |
| 
 | |
def extract_access_token(request: HttpRequest) -> str:
    """
    Get the access token using the Authorization Request Header Field method.
    If no usable Bearer header is present, fall back to the `access_token`
    POST parameter and then to the `access_token` GET parameter.
    See: http://tools.ietf.org/html/rfc6750#section-2.1

    A header that names the Bearer scheme but carries only whitespace after
    it is treated as absent rather than raising `IndexError`.

    Return a string (empty when no token was supplied).
    """
    auth_header = request.META.get("HTTP_AUTHORIZATION", "")

    if re.match(r"^[Bb]earer\s{1}.+$", auth_header):
        # `.+` also matches pure whitespace (e.g. "Bearer  "); in that case
        # split() yields only the scheme and there is no token to return.
        header_parts = auth_header.split()
        if len(header_parts) > 1:
            return header_parts[1]
    if "access_token" in request.POST:
        return request.POST.get("access_token")
    if "access_token" in request.GET:
        return request.GET.get("access_token")
    return ""
 | |
| 
 | |
| 
 | |
def extract_client_auth(request: HttpRequest) -> Tuple[str, str]:
    """
    Get client credentials using the HTTP Basic Authentication method.
    When no Basic header is present, read the `client_id` and
    `client_secret` POST parameters instead.
    See: https://tools.ietf.org/html/rfc6749#section-2.3.1

    A Basic header that cannot be decoded into `id:secret` yields empty
    credentials; the POST parameters are not consulted in that case.

    Return a tuple `(client_id, client_secret)`.
    """
    auth_header = request.META.get("HTTP_AUTHORIZATION", "")

    if re.match(r"^Basic\s{1}.+$", auth_header):
        try:
            # split() may yield only the scheme when the header is
            # "Basic" followed by whitespace -> IndexError, handled below.
            b64_user_pass = auth_header.split()[1]
            user_pass = b64decode(b64_user_pass).decode("utf-8")
            # Split on the first colon only: the secret itself may contain ":".
            client_id, client_secret = user_pass.split(":", 1)
        except (IndexError, ValueError, Error):
            # Missing payload, bad base64, non-UTF-8 bytes, or no colon at all.
            client_id = client_secret = ""
    else:
        client_id = request.POST.get("client_id", "")
        client_secret = request.POST.get("client_secret", "")

    return (client_id, client_secret)
 | |
| 
 | |
| 
 | |
def protected_resource_view(scopes: List[str]):
    """View decorator. The client accesses protected resources by presenting the
    access token to the resource server.

    https://tools.ietf.org/html/rfc6749#section-7

    The access token is read with `extract_access_token` and looked up as a
    `RefreshToken`. If it does not exist, is expired, or lacks any of the
    required `scopes`, the view is not called; instead an empty response with
    the error's status and a `WWW-Authenticate` header is returned.

    This decorator also injects the token into `kwargs` (as `token`)."""
    # Imported here so the decorator carries its own dependency.
    from functools import wraps

    # Required scopes never change per request, so build the set once.
    required_scopes = set(scopes)

    def wrapper(view):
        # Preserve the wrapped view's __name__/__doc__ for introspection.
        @wraps(view)
        def view_wrapper(request, *args, **kwargs):
            access_token = extract_access_token(request)

            try:
                try:
                    token = RefreshToken.objects.get(access_token=access_token)
                except RefreshToken.DoesNotExist as exc:
                    LOGGER.debug("Token does not exist", access_token=access_token)
                    raise BearerTokenError("invalid_token") from exc

                kwargs["token"] = token

                if token.is_expired:
                    LOGGER.debug("Token has expired", access_token=access_token)
                    raise BearerTokenError("invalid_token")

                granted_scopes = set(token.scope)
                if not required_scopes.issubset(granted_scopes):
                    LOGGER.warning(
                        "Scope missmatch.",
                        required=required_scopes,
                        token_has=granted_scopes,
                    )
                    raise BearerTokenError("insufficient_scope")
            except BearerTokenError as error:
                response = HttpResponse(status=error.status)
                # NOTE(review): RFC 6750 section 3 challenges start with the
                # "Bearer" auth-scheme; value kept as-is for compatibility.
                response[
                    "WWW-Authenticate"
                ] = f'error="{error.code}", error_description="{error.description}"'
                return response

            return view(request, *args, **kwargs)

        return view_wrapper

    return wrapper
 | |
| 
 | |
| 
 | |
def client_id_from_id_token(id_token):
    """
    Extracts the client id from a JSON Web Token (JWT).

    The client id is taken from the token's `aud` claim; when that claim is
    a list, its first entry is used.

    Returns a string or None (no `aud` claim, or an empty `aud` list).

    NOTE(review): the token is only unpacked here; no signature check is
    visible in this function -- confirm callers do not rely on one.
    """
    payload = JWT().unpack(id_token).payload()
    aud = payload.get("aud")
    if isinstance(aud, list):
        # An empty audience list has no first entry; report "no client id"
        # instead of raising IndexError.
        return aud[0] if aud else None
    return aud
 | 
