500 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			500 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
"""OAuth Provider Models"""
 | 
						|
import base64
 | 
						|
import binascii
 | 
						|
import json
 | 
						|
import time
 | 
						|
from dataclasses import asdict, dataclass, field
 | 
						|
from hashlib import sha256
 | 
						|
from typing import Any, Dict, List, Optional, Type
 | 
						|
from urllib.parse import urlparse
 | 
						|
from uuid import uuid4
 | 
						|
 | 
						|
from django.conf import settings
 | 
						|
from django.db import models
 | 
						|
from django.forms import ModelForm
 | 
						|
from django.http import HttpRequest
 | 
						|
from django.shortcuts import reverse
 | 
						|
from django.utils import dateformat, timezone
 | 
						|
from django.utils.translation import gettext_lazy as _
 | 
						|
from jwkest.jwk import Key, RSAKey, SYMKey, import_rsa_key
 | 
						|
from jwkest.jws import JWS
 | 
						|
 | 
						|
from passbook.core.models import ExpiringModel, PropertyMapping, Provider, User
 | 
						|
from passbook.crypto.models import CertificateKeyPair
 | 
						|
from passbook.lib.utils.template import render_to_string
 | 
						|
from passbook.lib.utils.time import timedelta_from_string, timedelta_string_validator
 | 
						|
from passbook.providers.oauth2.apps import PassbookProviderOAuth2Config
 | 
						|
from passbook.providers.oauth2.generators import (
 | 
						|
    generate_client_id,
 | 
						|
    generate_client_secret,
 | 
						|
)
 | 
						|
 | 
						|
 | 
						|
class ClientTypes(models.TextChoices):
    """Confidential clients are capable of maintaining the confidentiality
    of their credentials. Public clients are incapable."""

    # NOTE: the docstring above is used verbatim as the `client_type` help text
    # on OAuth2Provider, so it doubles as user-facing text.

    # Client that can keep its credentials confidential.
    CONFIDENTIAL = ("confidential", _("Confidential"))
    # Client that cannot keep its credentials confidential.
    PUBLIC = ("public", _("Public"))
 | 
						|
 | 
						|
 | 
						|
class GrantTypes(models.TextChoices):
    """OAuth2 grant types supported by this provider.

    Members carry no explicit label, so Django derives the label from the
    member name.
    """

    # Authorization Code grant
    AUTHORIZATION_CODE = "authorization_code"
    # Implicit grant
    IMPLICIT = "implicit"
    # Hybrid (OpenID Connect) flow
    HYBRID = "hybrid"
 | 
						|
 | 
						|
 | 
						|
class SubModes(models.TextChoices):
    """Mode after which the 'sub' attribute is generated, for compatibility reasons"""

    # sha256 of the user's id combined with SECRET_KEY (see RefreshToken.create_id_token)
    HASHED_USER_ID = "hashed_user_id", _("Based on the Hashed User ID")
    USER_USERNAME = "user_username", _("Based on the username")
    USER_EMAIL = (
        "user_email",
        _("Based on the User's Email. This is recommended over the UPN method."),
    )
    # Reads user.attributes["upn"]; users without that attribute cannot be served.
    USER_UPN = (
        "user_upn",
        _(
            (
                "Based on the User's UPN, only works if user has a 'upn' attribute set. "
                "Use this method only if you have different UPN and Mail domains."
            )
        ),
    )
 | 
						|
 | 
						|
 | 
						|
class ResponseTypes(models.TextChoices):
    """Response Type required by the client."""

    # NOTE: the docstring above is used verbatim as the `response_type` help
    # text on OAuth2Provider, so it doubles as user-facing text.

    # --- Authorization Code Flow ---
    CODE = ("code", _("code (Authorization Code Flow)"))
    # Compatibility variant, per its label: sends the id_token as access_token.
    CODE_ADFS = (
        "code#adfs",
        _("code (ADFS Compatibility Mode, sends id_token as access_token)"),
    )

    # --- Implicit Flow ---
    ID_TOKEN = ("id_token", _("id_token (Implicit Flow)"))
    ID_TOKEN_TOKEN = ("id_token token", _("id_token token (Implicit Flow)"))

    # --- Hybrid Flow ---
    CODE_TOKEN = ("code token", _("code token (Hybrid Flow)"))
    CODE_ID_TOKEN = ("code id_token", _("code id_token (Hybrid Flow)"))
    CODE_ID_TOKEN_TOKEN = (
        "code id_token token",
        _("code id_token token (Hybrid Flow)"),
    )
 | 
						|
 | 
						|
 | 
						|
class JWTAlgorithms(models.TextChoices):
    """Algorithm used to sign the JWT Token"""

    # NOTE: the docstring above is used verbatim as the `jwt_alg` help text on
    # OAuth2Provider, so it doubles as user-facing text.

    # Signed with a shared secret (OAuth2Provider uses the client_secret).
    HS256 = ("HS256", _("HS256 (Symmetric Encryption)"))
    # Signed with the RSA key of a CertificateKeyPair.
    RS256 = ("RS256", _("RS256 (Asymmetric Encryption)"))
 | 
						|
 | 
						|
 | 
						|
class ScopeMapping(PropertyMapping):
    """Map an OAuth Scope to users properties"""

    # Name of the OAuth scope this mapping is applied for.
    scope_name = models.TextField(help_text=_("Scope used by the client"))
    # Optional consent text; an empty value means the user is not informed.
    description = models.TextField(
        blank=True,
        help_text=_(
            (
                "Description shown to the user when consenting. "
                "If left empty, the user won't be informed."
            )
        ),
    )

    @property
    def form(self) -> Type[ModelForm]:
        """ModelForm class used to edit this mapping."""
        # Local import, presumably to avoid an import cycle with the forms module.
        from passbook.providers.oauth2.forms import ScopeMappingForm

        form_class = ScopeMappingForm
        return form_class

    def __str__(self):
        return "Scope Mapping {} ({})".format(self.name, self.scope_name)

    class Meta:
        verbose_name = _("Scope Mapping")
        verbose_name_plural = _("Scope Mappings")
 | 
						|
 | 
						|
 | 
						|
class OAuth2Provider(Provider):
    """OAuth2 Provider for generic OAuth and OpenID Connect Applications."""

    client_type = models.CharField(
        max_length=30,
        choices=ClientTypes.choices,
        default=ClientTypes.CONFIDENTIAL,
        verbose_name=_("Client Type"),
        help_text=_(ClientTypes.__doc__),
    )
    client_id = models.CharField(
        max_length=255,
        unique=True,
        verbose_name=_("Client ID"),
        default=generate_client_id,
    )
    client_secret = models.CharField(
        max_length=255,
        blank=True,
        verbose_name=_("Client Secret"),
        default=generate_client_secret,
    )
    response_type = models.TextField(
        choices=ResponseTypes.choices,
        default=ResponseTypes.CODE,
        help_text=_(ResponseTypes.__doc__),
    )
    jwt_alg = models.CharField(
        max_length=10,
        choices=JWTAlgorithms.choices,
        default=JWTAlgorithms.RS256,
        verbose_name=_("JWT Algorithm"),
        help_text=_(JWTAlgorithms.__doc__),
    )
    # One redirect URI per line; the first one is used to guess `launch_url`.
    redirect_uris = models.TextField(
        default="",
        verbose_name=_("Redirect URIs"),
        help_text=_("Enter each URI on a new line."),
    )

    include_claims_in_id_token = models.BooleanField(
        default=True,
        verbose_name=_("Include claims in id_token"),
        help_text=_(
            (
                "Include User claims from scopes in the id_token, for applications "
                "that don't access the userinfo endpoint."
            )
        ),
    )

    # Parsed with timedelta_from_string when tokens are issued.
    token_validity = models.TextField(
        default="minutes=10",
        validators=[timedelta_string_validator],
        help_text=_(
            (
                "Tokens not valid on or after current time + this value "
                "(Format: hours=1;minutes=2;seconds=3)."
            )
        ),
    )

    sub_mode = models.TextField(
        choices=SubModes.choices,
        default=SubModes.HASHED_USER_ID,
        help_text=_(
            (
                "Configure what data should be used as unique User Identifier. For most cases, "
                "the default should be fine."
            )
        ),
    )

    # NOTE(review): CASCADE means deleting the CertificateKeyPair deletes this
    # provider, although get_jwt_keys() can already fall back to HS256 when no
    # key is set; SET_NULL may be intended (would need a migration).
    rsa_key = models.ForeignKey(
        CertificateKeyPair,
        verbose_name=_("RSA Key"),
        on_delete=models.CASCADE,
        blank=True,
        null=True,
        help_text=_(
            "Key used to sign the tokens. Only required when JWT Algorithm is set to RS256."
        ),
    )

    def create_refresh_token(
        self, user: User, scope: List[str], id_token: Optional["IDToken"] = None
    ) -> "RefreshToken":
        """Create and populate (but do not save) a RefreshToken.

        Args:
            user: User the token is issued to.
            scope: List of granted scope names.
            id_token: Optional IDToken to attach to the token.

        Returns:
            An unsaved RefreshToken with random uuid4-hex access and refresh
            tokens, expiring ``token_validity`` from now.
        """
        token = RefreshToken(
            user=user,
            provider=self,
            access_token=uuid4().hex,
            refresh_token=uuid4().hex,
            expires=timezone.now() + timedelta_from_string(self.token_validity),
            scope=scope,
        )
        if id_token is not None:
            token.id_token = id_token
        return token

    def get_jwt_keys(self) -> List[Key]:
        """Return the keys used to sign tokens with this provider's ``jwt_alg``.

        If RS256 is selected but no ``rsa_key`` is set, the provider is switched
        to HS256 *and saved*; the client secret is then used as symmetric key.

        Returns:
            A single-element list: an RSAKey for RS256 or a SYMKey for HS256.

        Raises:
            Exception: if ``jwt_alg`` is neither RS256 nor HS256.
        """
        if self.jwt_alg == JWTAlgorithms.RS256:
            if self.rsa_key:
                # Because the JWT library uses python cryptodome, we can't
                # directly pass the RSAPublicKey object, but have to load it
                # ourselves from the key pair's key_data.
                key = import_rsa_key(self.rsa_key.key_data)
                return [RSAKey(key=key, kid=self.rsa_key.kid)]
            # RS256 selected without a CertificateKeyPair: fall back to HS256
            # and persist it, so later calls behave the same way.
            self.jwt_alg = JWTAlgorithms.HS256
            self.save()

        if self.jwt_alg == JWTAlgorithms.HS256:
            return [SYMKey(key=self.client_secret, alg=self.jwt_alg)]

        raise Exception("Unsupported key algorithm.")

    def get_issuer(self, request: HttpRequest) -> Optional[str]:
        """Return the issuer URL for this provider, based on ``request``.

        The URL is built from the oauth2 URL mountpoint and the slug of the
        provider's application. Returns None if no application is assigned.
        """
        try:
            mountpoint = PassbookProviderOAuth2Config.mountpoints[
                "passbook.providers.oauth2.urls"
            ]
            # pylint: disable=no-member
            return request.build_absolute_uri(f"/{mountpoint}{self.application.slug}/")
        except Provider.application.RelatedObjectDoesNotExist:
            return None

    @property
    def launch_url(self) -> Optional[str]:
        """Guess the launch URL from the first redirect URI.

        The first non-blank line of ``redirect_uris`` is taken and its path is
        removed, e.g. ``https://app.example.com/oauth/callback`` becomes
        ``https://app.example.com``. Returns None if no redirect URI is set.
        """
        # strip() also removes the "\r" left behind by CRLF textarea input.
        uris = [line.strip() for line in self.redirect_uris.splitlines() if line.strip()]
        if not uris:
            return None
        # Blank out only the parsed path component. A plain str.replace() of the
        # path would also hit matching substrings elsewhere in the URL (a "/"
        # path would strip the slashes of "https://").
        return urlparse(uris[0])._replace(path="").geturl()

    @property
    def form(self) -> Type[ModelForm]:
        """ModelForm class used to edit this provider."""
        # Local import, presumably to avoid an import cycle with the forms module.
        from passbook.providers.oauth2.forms import OAuth2ProviderForm

        return OAuth2ProviderForm

    def __str__(self):
        return f"OAuth2 Provider {self.name}"

    def encode(self, payload: Dict[str, Any]) -> str:
        """Sign ``payload`` as a compact JWS (e.g. to represent an ID Token as JWT).

        Args:
            payload: Claims to sign.

        Returns:
            The compact-serialized JWS string.
        """
        # get_jwt_keys() may switch jwt_alg from RS256 to HS256; it updates
        # self.jwt_alg in memory as well as in the database, so reading the
        # algorithm afterwards is enough (no refresh_from_db() round trip, which
        # would also discard any unsaved changes on this instance).
        keys = self.get_jwt_keys()
        jws = JWS(payload, alg=self.jwt_alg)
        return jws.sign_compact(keys)

    def html_setup_urls(self, request: HttpRequest) -> Optional[str]:
        """Render the setup modal listing issuer, authorize, token, userinfo and
        provider-info URLs. Returns None if no application is assigned."""
        try:
            # pylint: disable=no-member
            return render_to_string(
                "providers/oauth2/setup_url_modal.html",
                {
                    "provider": self,
                    "issuer": self.get_issuer(request),
                    "authorize": request.build_absolute_uri(
                        reverse(
                            "passbook_providers_oauth2:authorize",
                        )
                    ),
                    "token": request.build_absolute_uri(
                        reverse(
                            "passbook_providers_oauth2:token",
                        )
                    ),
                    "userinfo": request.build_absolute_uri(
                        reverse(
                            "passbook_providers_oauth2:userinfo",
                        )
                    ),
                    "provider_info": request.build_absolute_uri(
                        reverse(
                            "passbook_providers_oauth2:provider-info",
                            kwargs={"application_slug": self.application.slug},
                        )
                    ),
                },
            )
        except Provider.application.RelatedObjectDoesNotExist:
            return None

    class Meta:

        verbose_name = _("OAuth2/OpenID Provider")
        verbose_name_plural = _("OAuth2/OpenID Providers")
 | 
						|
 | 
						|
 | 
						|
class BaseGrantModel(models.Model):
    """Base Model for all grants"""

    provider = models.ForeignKey(OAuth2Provider, on_delete=models.CASCADE)
    user = models.ForeignKey(User, verbose_name=_("User"), on_delete=models.CASCADE)
    # Scope names stored space-separated; use the `scope` property instead.
    _scope = models.TextField(default="", verbose_name=_("Scopes"))

    @property
    def scope(self) -> List[str]:
        """Granted scopes as a list, split on whitespace."""
        stored: str = self._scope
        return stored.split()

    @scope.setter
    def scope(self, value):
        """Store an iterable of scope names as one space-separated string."""
        separator = " "
        self._scope = separator.join(value)

    class Meta:
        abstract = True
 | 
						|
 | 
						|
 | 
						|
class AuthorizationCode(ExpiringModel, BaseGrantModel):
    """OAuth2 Authorization Code"""

    code = models.CharField(max_length=255, unique=True, verbose_name=_("Code"))
    nonce = models.CharField(
        max_length=255,
        blank=True,
        default="",
        verbose_name=_("Nonce"),
    )
    # Whether this code was issued for an OpenID Connect authentication request.
    is_open_id = models.BooleanField(
        default=False,
        verbose_name=_("Is Authentication?"),
    )
    # Stored PKCE parameters; NULL when the client sent none.
    code_challenge = models.CharField(
        max_length=255,
        null=True,
        verbose_name=_("Code Challenge"),
    )
    code_challenge_method = models.CharField(
        max_length=255,
        null=True,
        verbose_name=_("Code Challenge Method"),
    )

    class Meta:
        verbose_name = _("Authorization Code")
        verbose_name_plural = _("Authorization Codes")

    def __str__(self):
        return f"{self.provider} - {self.code}"
 | 
						|
 | 
						|
 | 
						|
@dataclass
class IDToken:
    """OpenID Connect ID Token payload.

    The ID Token is the data structure OpenID Connect adds to OAuth 2.0 to
    authenticate End-Users: a security token containing Claims about the
    Authentication of an End-User by an Authorization Server, and possibly
    other requested Claims. It is represented as a JSON Web Token (JWT).

    https://openid.net/specs/openid-connect-core-1_0.html#IDToken"""

    # Every field has to be optional so an empty IDToken can be stored for
    # non-OpenID flows.
    iss: Optional[str] = None  # issuer
    sub: Optional[str] = None  # subject identifier
    aud: Optional[str] = None  # audience
    exp: Optional[int] = None  # expiration time, unix timestamp
    iat: Optional[int] = None  # issued-at time, unix timestamp
    auth_time: Optional[int] = None  # authentication time, unix timestamp

    nonce: Optional[str] = None
    at_hash: Optional[str] = None

    # Additional claims, flattened into the top level by to_dict().
    claims: Dict[str, Any] = field(default_factory=dict)

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> "IDToken":
        """Rebuild an IDToken from a (JSON-decoded) dictionary.

        Each key is set as an attribute on a fresh IDToken, including keys that
        are not dataclass fields; a "claims" key replaces the claims dict.
        """
        result = IDToken()
        for attr_name in data:
            setattr(result, attr_name, data[attr_name])
        return result

    def to_dict(self) -> Dict[str, Any]:
        """Return the standard fields merged with ``claims`` as one flat dict.

        The "claims" key itself is omitted; entries from ``claims`` override
        standard fields of the same name.
        """
        flat = {name: val for name, val in asdict(self).items() if name != "claims"}
        flat.update(self.claims)
        return flat
 | 
						|
 | 
						|
 | 
						|
class RefreshToken(ExpiringModel, BaseGrantModel):
    """OAuth2 Refresh Token"""

    access_token = models.CharField(
        max_length=255, unique=True, verbose_name=_("Access Token")
    )
    refresh_token = models.CharField(
        max_length=255, unique=True, verbose_name=_("Refresh Token")
    )
    # JSON-serialized IDToken; read and write it through the `id_token` property.
    _id_token = models.TextField(verbose_name=_("ID Token"))

    class Meta:
        verbose_name = _("OAuth2 Token")
        verbose_name_plural = _("OAuth2 Tokens")

    @property
    def id_token(self) -> IDToken:
        """Load the ID Token from JSON; an empty IDToken if none is stored."""
        if self._id_token:
            raw_token = json.loads(self._id_token)
            return IDToken.from_dict(raw_token)
        return IDToken()

    @id_token.setter
    def id_token(self, value: IDToken):
        """Serialize ``value`` (including its nested ``claims``) to JSON."""
        self._id_token = json.dumps(asdict(value))

    def __str__(self):
        return f"{self.provider} - {self.access_token}"

    @property
    def at_hash(self):
        """Access token hash (``at_hash`` claim) for the ID Token.

        Base64url encoding, without padding, of the left-most half of the
        SHA-256 digest of the ASCII access token. SHA-256 matches both
        supported signing algorithms (HS256 and RS256).
        """
        digest = sha256(self.access_token.encode("ascii")).digest()
        left_half = digest[: len(digest) // 2]
        return base64.urlsafe_b64encode(left_half).rstrip(b"=").decode("ascii")

    def _get_sub(self, user: User) -> str:
        """Compute the ``sub`` claim for ``user`` based on the provider's sub_mode.

        Raises:
            ValueError: if the provider has an unknown sub_mode.
        """
        sub_mode = self.provider.sub_mode
        if sub_mode == SubModes.HASHED_USER_ID:
            # utf-8 yields the same bytes as ascii for ASCII input, so existing
            # subs are unchanged, but does not crash on a non-ASCII SECRET_KEY.
            return sha256(f"{user.id}-{settings.SECRET_KEY}".encode("utf-8")).hexdigest()
        if sub_mode == SubModes.USER_EMAIL:
            return user.email
        if sub_mode == SubModes.USER_USERNAME:
            return user.username
        if sub_mode == SubModes.USER_UPN:
            # Fails if the user has no "upn" attribute (see SubModes.USER_UPN label).
            return user.attributes["upn"]
        raise ValueError(
            (
                f"Provider {self.provider} has invalid sub_mode "
                f"selected: {self.provider.sub_mode}"
            )
        )

    def create_id_token(self, user: User, request: HttpRequest) -> IDToken:
        """Creates the id_token.

        See: http://openid.net/specs/openid-connect-core-1_0.html#IDToken

        Args:
            user: User the token is issued for.
            request: Current request, used to build the issuer URL (and for
                claim collection when claims are included in the id_token).

        Returns:
            An IDToken whose ``exp`` is ``token_validity`` after now. Scope
            claims are included when the provider's
            ``include_claims_in_id_token`` is set.

        Raises:
            ValueError: if the provider has an unknown sub_mode.
        """
        sub = self._get_sub(user)

        # Convert datetimes into timestamps.
        now = int(time.time())
        iat_time = now
        validity = timedelta_from_string(self.provider.token_validity)
        # total_seconds(), not .seconds: .seconds is only the remainder below
        # one day, so e.g. "hours=25" would have produced a 1 hour lifetime.
        exp_time = int(now + validity.total_seconds())
        user_auth_time = user.last_login or user.date_joined
        auth_time = int(dateformat.format(user_auth_time, "U"))

        token = IDToken(
            iss=self.provider.get_issuer(request),
            sub=sub,
            aud=self.provider.client_id,
            exp=exp_time,
            iat=iat_time,
            auth_time=auth_time,
        )

        # Include (or not) user standard claims in the id_token.
        if self.provider.include_claims_in_id_token:
            # Local import, presumably to avoid an import cycle with the views.
            from passbook.providers.oauth2.views.userinfo import UserInfoView

            user_info = UserInfoView()
            user_info.request = request
            claims = user_info.get_claims(self)
            token.claims = claims

        return token
 |