257 lines
		
	
	
		
			9.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			257 lines
		
	
	
		
			9.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
"""passbook OAuth2 Token views"""
 | 
						|
from base64 import urlsafe_b64encode
 | 
						|
from dataclasses import InitVar, dataclass
 | 
						|
from hashlib import sha256
 | 
						|
from typing import Any, Dict, List, Optional
 | 
						|
 | 
						|
from django.http import HttpRequest, HttpResponse
 | 
						|
from django.views import View
 | 
						|
from structlog import get_logger
 | 
						|
 | 
						|
from passbook.lib.utils.time import timedelta_from_string
 | 
						|
from passbook.providers.oauth2.constants import (
 | 
						|
    GRANT_TYPE_AUTHORIZATION_CODE,
 | 
						|
    GRANT_TYPE_REFRESH_TOKEN,
 | 
						|
)
 | 
						|
from passbook.providers.oauth2.errors import TokenError, UserAuthError
 | 
						|
from passbook.providers.oauth2.models import (
 | 
						|
    AuthorizationCode,
 | 
						|
    OAuth2Provider,
 | 
						|
    RefreshToken,
 | 
						|
    ResponseTypes,
 | 
						|
)
 | 
						|
from passbook.providers.oauth2.utils import TokenResponse, extract_client_auth
 | 
						|
 | 
						|
# Module-level structlog logger used by the token parameter validation below.
LOGGER = get_logger()
 | 
						|
 | 
						|
 | 
						|
@dataclass
class TokenParams:
    """Validated parameters of a request to the OAuth2 token endpoint.

    Constructing an instance performs the validation: ``__post_init__`` looks
    up the provider matching ``client_id``, checks the client secret of
    confidential clients and resolves the submitted authorization code or
    refresh token. Every validation failure is raised as ``TokenError``.
    """

    # Client credentials as returned by extract_client_auth().
    client_id: str
    client_secret: str

    # Values read from the POST body of the token request.
    redirect_uri: str
    grant_type: str
    state: str
    scope: List[str]

    # Resolved in __post_init__, depending on the grant type.
    authorization_code: Optional[AuthorizationCode] = None
    refresh_token: Optional[RefreshToken] = None

    # PKCE verifier sent by the client, if any.
    code_verifier: Optional[str] = None

    # Init-only values: the raw code / refresh token strings are only used to
    # look up the corresponding model instances and are not stored.
    raw_code: InitVar[str] = ""
    raw_token: InitVar[str] = ""

    @staticmethod
    def from_request(request: HttpRequest) -> "TokenParams":
        """Extract Token Parameters from http request

        Raises TokenError (from ``__post_init__``) when the request does not
        pass validation.
        """
        client_id, client_secret = extract_client_auth(request)

        return TokenParams(
            client_id=client_id,
            client_secret=client_secret,
            redirect_uri=request.POST.get("redirect_uri", ""),
            grant_type=request.POST.get("grant_type", ""),
            raw_code=request.POST.get("code", ""),
            raw_token=request.POST.get("refresh_token", ""),
            state=request.POST.get("state", ""),
            scope=request.POST.get("scope", "").split(),
            # PKCE parameter.
            code_verifier=request.POST.get("code_verifier"),
        )

    def __post_init__(self, raw_code, raw_token):
        """Validate the client and resolve the grant.

        Sets ``self.provider`` and, depending on ``grant_type``, either
        ``self.authorization_code`` or ``self.refresh_token``.

        Raises:
            TokenError: "invalid_client" for an unknown client or wrong secret,
                "invalid_grant" for a missing/unknown code or refresh token,
                "unsupported_grant_type" for any other grant type.
        """
        try:
            provider: OAuth2Provider = OAuth2Provider.objects.get(
                client_id=self.client_id
            )
        except OAuth2Provider.DoesNotExist as exc:
            LOGGER.warning("OAuth2Provider does not exist", client_id=self.client_id)
            raise TokenError("invalid_client") from exc
        self.provider = provider

        if (
            self.provider.client_type == "confidential"
            and self.provider.client_secret != self.client_secret
        ):
            # Deliberately do NOT log the configured secret: log files must
            # never contain client credentials.
            LOGGER.warning(
                "Invalid client secret",
                client_id=self.provider.client_id,
            )
            raise TokenError("invalid_client")

        if self.grant_type == GRANT_TYPE_AUTHORIZATION_CODE:
            self.__post_init_code(raw_code)

        elif self.grant_type == GRANT_TYPE_REFRESH_TOKEN:
            if not raw_token:
                LOGGER.warning("Missing refresh token")
                raise TokenError("invalid_grant")

            try:
                # The token must belong to the authenticated client's provider.
                self.refresh_token = RefreshToken.objects.get(
                    refresh_token=raw_token, provider=self.provider
                )
            except RefreshToken.DoesNotExist as exc:
                LOGGER.warning(
                    "Refresh token does not exist",
                    token=raw_token,
                )
                raise TokenError("invalid_grant") from exc

        else:
            LOGGER.warning("Invalid grant type", grant_type=self.grant_type)
            raise TokenError("unsupported_grant_type")

    def __post_init_code(self, raw_code):
        """Validate an authorization_code grant and set ``authorization_code``.

        Checks, in order: the code is present, the redirect URI is one of the
        provider's registered URIs, the code exists, belongs to this provider
        and has not expired, and finally the PKCE parameters.

        Raises:
            TokenError: "invalid_client" for an unregistered redirect URI,
                "invalid_grant" for every other failure.
        """
        if not raw_code:
            LOGGER.warning("Missing authorization code")
            raise TokenError("invalid_grant")

        allowed_redirect_uris = self.provider.redirect_uris.split()
        if self.redirect_uri not in allowed_redirect_uris:
            LOGGER.warning(
                "Invalid redirect uri",
                uri=self.redirect_uri,
                expected=allowed_redirect_uris,
            )
            raise TokenError("invalid_client")

        try:
            self.authorization_code = AuthorizationCode.objects.get(code=raw_code)
        except AuthorizationCode.DoesNotExist as exc:
            LOGGER.warning("Code does not exist", code=raw_code)
            raise TokenError("invalid_grant") from exc

        if (
            self.authorization_code.provider != self.provider
            or self.authorization_code.is_expired
        ):
            LOGGER.warning("Invalid code: invalid client or code has expired")
            raise TokenError("invalid_grant")

        self._validate_pkce()

    def _validate_pkce(self):
        """Check the PKCE code_verifier against the code's stored challenge.

        The stored challenge drives the check: if the authorization code was
        issued with a code_challenge, a matching code_verifier is mandatory.
        Checking only when the client happens to send a verifier would let an
        attacker bypass PKCE by omitting the parameter.

        Raises:
            TokenError: "invalid_grant" if the verifier is missing, malformed,
                does not match, or was sent for a code without a challenge.
        """
        expected_challenge = self.authorization_code.code_challenge

        if not expected_challenge:
            if self.code_verifier:
                # A verifier for a code that never had a challenge can not
                # match anything; keep rejecting it.
                LOGGER.warning("Code challenge not matching")
                raise TokenError("invalid_grant")
            return

        if not self.code_verifier:
            LOGGER.warning("Missing code verifier")
            raise TokenError("invalid_grant")

        if self.authorization_code.code_challenge_method == "S256":
            try:
                verifier_digest = sha256(self.code_verifier.encode("ascii")).digest()
            except UnicodeEncodeError as exc:
                # Client-supplied garbage must yield an OAuth error, not a
                # server error.
                LOGGER.warning("Code verifier is not valid ASCII")
                raise TokenError("invalid_grant") from exc
            # base64url without padding.
            new_code_challenge = (
                urlsafe_b64encode(verifier_digest).decode("utf-8").replace("=", "")
            )
        else:
            # Any other method is treated as "plain": the verifier itself is
            # the challenge.
            new_code_challenge = self.code_verifier

        if new_code_challenge != expected_challenge:
            LOGGER.warning("Code challenge not matching")
            raise TokenError("invalid_grant")
 | 
						|
 | 
						|
 | 
						|
class TokenView(View):
    """Generate tokens for clients"""

    # Validated request parameters, set at the start of post().
    params: TokenParams

    def post(self, request: HttpRequest) -> HttpResponse:
        """Generate tokens for clients

        Validates the request via ``TokenParams.from_request`` and dispatches
        on the grant type. ``TokenError`` is answered with status 400 and
        ``UserAuthError`` with status 403, each carrying the error's dict.
        """
        try:
            self.params = TokenParams.from_request(request)

            if self.params.grant_type == GRANT_TYPE_AUTHORIZATION_CODE:
                return TokenResponse(self.create_code_response_dic())
            if self.params.grant_type == GRANT_TYPE_REFRESH_TOKEN:
                return TokenResponse(self.create_refresh_response_dic())
            raise ValueError(f"Invalid grant_type: {self.params.grant_type}")
        except TokenError as error:
            return TokenResponse(error.create_dict(), status=400)
        except UserAuthError as error:
            return TokenResponse(error.create_dict(), status=403)

    @staticmethod
    def _expires_in(provider: OAuth2Provider) -> int:
        """Return the provider's token validity as a whole number of seconds.

        Uses ``total_seconds()`` rather than ``.seconds``: the latter is only
        the seconds component of the duration and drops whole days, so a
        validity of one day or more would be reported far too short.
        Presumably ``timedelta_from_string`` returns a ``datetime.timedelta``.
        """
        return int(timedelta_from_string(provider.token_validity).total_seconds())

    def create_code_response_dic(self) -> Dict[str, Any]:
        """See https://tools.ietf.org/html/rfc6749#section-4.1

        Exchanges the validated authorization code for a new refresh/access
        token pair, stores the token, deletes the code and returns the token
        response payload.
        """
        code = self.params.authorization_code

        refresh_token = code.provider.create_refresh_token(
            user=code.user,
            scope=code.scope,
        )

        if code.is_open_id:
            id_token = refresh_token.create_id_token(
                user=code.user,
                request=self.request,
            )
            id_token.nonce = code.nonce
            id_token.at_hash = refresh_token.at_hash
            refresh_token.id_token = id_token

        # Store the token.
        refresh_token.save()

        # We don't need to store the code anymore.
        code.delete()

        response_dict = {
            "access_token": refresh_token.access_token,
            "refresh_token": refresh_token.refresh_token,
            "token_type": "Bearer",
            "expires_in": self._expires_in(self.params.provider),
            "id_token": refresh_token.provider.encode(refresh_token.id_token.to_dict()),
        }

        if self.params.provider.response_type == ResponseTypes.CODE_ADFS:
            # This seems to be expected by some OIDC Clients
            # namely VMware vCenter. This is not documented in any OpenID or OAuth2 Standard.
            # Maybe this should be a setting
            # in the future?
            response_dict["access_token"] = response_dict["id_token"]

        return response_dict

    def create_refresh_response_dic(self) -> Dict[str, Any]:
        """See https://tools.ietf.org/html/rfc6749#section-6

        Replaces the presented refresh token with a newly created one and
        returns the token response payload.

        Raises:
            TokenError: "invalid_scope" if the request asks for scopes the
                original token did not have.
        """
        old_token = self.params.refresh_token

        unauthorized_scopes = set(self.params.scope) - set(old_token.scope)
        if unauthorized_scopes:
            raise TokenError("invalid_scope")

        provider: OAuth2Provider = old_token.provider

        refresh_token: RefreshToken = provider.create_refresh_token(
            user=old_token.user,
            scope=self.params.scope,
        )

        # If the Token has an id_token it's an Authentication request.
        issued_id_token = False
        if old_token.id_token:
            # Finish the id_token before assigning it, as the code flow does,
            # so at_hash is kept even if the attribute serialises on assignment.
            id_token = refresh_token.create_id_token(
                user=old_token.user,
                request=self.request,
            )
            id_token.at_hash = refresh_token.at_hash
            refresh_token.id_token = id_token
            issued_id_token = True

        # Store the refresh_token. This must happen unconditionally: the old
        # token is deleted below, so the new one has to exist.
        refresh_token.save()

        # Forget the old token.
        old_token.delete()

        dic = {
            "access_token": refresh_token.access_token,
            "refresh_token": refresh_token.refresh_token,
            "token_type": "bearer",
            "expires_in": self._expires_in(refresh_token.provider),
        }

        if issued_id_token:
            # Return the newly issued id_token (whose at_hash matches the new
            # access token), not the one of the token that was just replaced.
            dic["id_token"] = self.params.provider.encode(
                refresh_token.id_token.to_dict()
            )

        return dic
 |