OAuth Provider Rewrite (#182)
This commit is contained in:
445
passbook/providers/oauth2/models.py
Normal file
445
passbook/providers/oauth2/models.py
Normal file
@ -0,0 +1,445 @@
|
||||
"""OAuth Provider Models"""
|
||||
import base64
|
||||
import binascii
|
||||
import json
|
||||
import time
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from hashlib import sha256
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
from uuid import uuid4
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import models
|
||||
from django.forms import ModelForm
|
||||
from django.http import HttpRequest
|
||||
from django.shortcuts import reverse
|
||||
from django.utils import dateformat, timezone
|
||||
from django.utils.translation import ugettext_lazy as _
|
||||
from jwkest.jwk import Key, RSAKey, SYMKey, import_rsa_key
|
||||
from jwkest.jws import JWS
|
||||
|
||||
from passbook.core.models import ExpiringModel, PropertyMapping, Provider, User
|
||||
from passbook.crypto.models import CertificateKeyPair
|
||||
from passbook.lib.utils.template import render_to_string
|
||||
from passbook.lib.utils.time import timedelta_from_string, timedelta_string_validator
|
||||
from passbook.providers.oauth2.apps import PassbookProviderOAuth2Config
|
||||
from passbook.providers.oauth2.generators import (
|
||||
generate_client_id,
|
||||
generate_client_secret,
|
||||
)
|
||||
|
||||
|
||||
class ClientTypes(models.TextChoices):
|
||||
"""<b>Confidential</b> clients are capable of maintaining the confidentiality
|
||||
of their credentials. <b>Public</b> clients are incapable."""
|
||||
|
||||
CONFIDENTIAL = "confidential", _("Confidential")
|
||||
PUBLIC = "public", _("Public")
|
||||
|
||||
|
||||
class GrantTypes(models.TextChoices):
|
||||
"""OAuth2 Grant types we support"""
|
||||
|
||||
AUTHORIZATION_CODE = "authorization_code"
|
||||
IMPLICIT = "implicit"
|
||||
HYBRID = "hybrid"
|
||||
|
||||
|
||||
class ResponseTypes(models.TextChoices):
|
||||
"""Response Type required by the client."""
|
||||
|
||||
CODE = "code", _("code (Authorization Code Flow)")
|
||||
ID_TOKEN = "id_token", _("id_token (Implicit Flow)")
|
||||
ID_TOKEN_TOKEN = "id_token token", _("id_token token (Implicit Flow)")
|
||||
CODE_TOKEN = "code token", _("code token (Hybrid Flow)")
|
||||
CODE_ID_TOKEN = "code id_token", _("code id_token (Hybrid Flow)")
|
||||
CODE_ID_TOKEN_TOKEN = "code id_token token", _("code id_token token (Hybrid Flow)")
|
||||
|
||||
|
||||
class JWTAlgorithms(models.TextChoices):
|
||||
"""Algorithm used to sign the JWT Token"""
|
||||
|
||||
HS256 = "HS256", _("HS256 (Symmetric Encryption)")
|
||||
RS256 = "RS256", _("RS256 (Asymmetric Encryption)")
|
||||
|
||||
|
||||
class ScopeMapping(PropertyMapping):
|
||||
"""Map an OAuth Scope to users properties"""
|
||||
|
||||
scope_name = models.TextField(help_text=_("Scope used by the client"))
|
||||
description = models.TextField(
|
||||
blank=True,
|
||||
help_text=_(
|
||||
(
|
||||
"Description shown to the user when consenting. "
|
||||
"If left empty, the user won't be informed."
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
def form(self) -> Type[ModelForm]:
|
||||
from passbook.providers.oauth2.forms import ScopeMappingForm
|
||||
|
||||
return ScopeMappingForm
|
||||
|
||||
def __str__(self):
|
||||
return f"Scope Mapping '{self.scope_name}'"
|
||||
|
||||
class Meta:
|
||||
|
||||
verbose_name = _("Scope Mapping")
|
||||
verbose_name_plural = _("Scope Mappings")
|
||||
|
||||
|
||||
class OAuth2Provider(Provider):
|
||||
"""OAuth2 Provider for generic OAuth and OpenID Connect Applications."""
|
||||
|
||||
name = models.TextField()
|
||||
|
||||
client_type = models.CharField(
|
||||
max_length=30,
|
||||
choices=ClientTypes.choices,
|
||||
default=ClientTypes.CONFIDENTIAL,
|
||||
verbose_name=_("Client Type"),
|
||||
help_text=_(ClientTypes.__doc__),
|
||||
)
|
||||
client_id = models.CharField(
|
||||
max_length=255,
|
||||
unique=True,
|
||||
verbose_name=_("Client ID"),
|
||||
default=generate_client_id,
|
||||
)
|
||||
client_secret = models.CharField(
|
||||
max_length=255,
|
||||
blank=True,
|
||||
verbose_name=_("Client Secret"),
|
||||
default=generate_client_secret,
|
||||
)
|
||||
response_type = models.TextField(
|
||||
choices=ResponseTypes.choices,
|
||||
default=ResponseTypes.CODE,
|
||||
help_text=_(ResponseTypes.__doc__),
|
||||
)
|
||||
jwt_alg = models.CharField(
|
||||
max_length=10,
|
||||
choices=JWTAlgorithms.choices,
|
||||
default=JWTAlgorithms.RS256,
|
||||
verbose_name=_("JWT Algorithm"),
|
||||
help_text=_(JWTAlgorithms.__doc__),
|
||||
)
|
||||
redirect_uris = models.TextField(
|
||||
default="",
|
||||
verbose_name=_("Redirect URIs"),
|
||||
help_text=_("Enter each URI on a new line."),
|
||||
)
|
||||
post_logout_redirect_uris = models.TextField(
|
||||
blank=True,
|
||||
default="",
|
||||
verbose_name=_("Post Logout Redirect URIs"),
|
||||
help_text=_("Enter each URI on a new line."),
|
||||
)
|
||||
|
||||
include_claims_in_id_token = models.BooleanField(
|
||||
default=True,
|
||||
verbose_name=_("Include claims in id_token"),
|
||||
help_text=_(
|
||||
(
|
||||
"Include User claims from scopes in the id_token, for applications "
|
||||
"that don't access the userinfo endpoint."
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
token_validity = models.TextField(
|
||||
default="minutes=10",
|
||||
validators=[timedelta_string_validator],
|
||||
help_text=_(
|
||||
(
|
||||
"Tokens not valid on or after current time + this value "
|
||||
"(Format: hours=1;minutes=2;seconds=3)."
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
rsa_key = models.ForeignKey(
|
||||
CertificateKeyPair,
|
||||
verbose_name=_("RSA Key"),
|
||||
on_delete=models.CASCADE,
|
||||
blank=True,
|
||||
null=True,
|
||||
help_text=_(
|
||||
"Key used to sign the tokens. Only required when JWT Algorithm is set to RS256."
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def scope_names(self) -> List[str]:
|
||||
"""Return list of assigned scopes seperated with a space"""
|
||||
return [pm.scope_name for pm in self.property_mappings.all()]
|
||||
|
||||
def create_refresh_token(
|
||||
self, user: User, scope: List[str], id_token: Optional["IDToken"] = None
|
||||
) -> "RefreshToken":
|
||||
"""Create and populate a RefreshToken object."""
|
||||
token = RefreshToken(
|
||||
user=user,
|
||||
provider=self,
|
||||
access_token=uuid4().hex,
|
||||
refresh_token=uuid4().hex,
|
||||
expires=timezone.now() + timedelta_from_string(self.token_validity),
|
||||
scope=scope,
|
||||
)
|
||||
if id_token:
|
||||
token.id_token = id_token
|
||||
return token
|
||||
|
||||
def get_jwt_keys(self) -> List[Key]:
|
||||
"""
|
||||
Takes a provider and returns the set of keys associated with it.
|
||||
Returns a list of keys.
|
||||
"""
|
||||
if self.jwt_alg == JWTAlgorithms.RS256:
|
||||
# if the user selected RS256 but didn't select a
|
||||
# CertificateKeyPair, we fall back to HS256
|
||||
if not self.rsa_key:
|
||||
self.jwt_alg = JWTAlgorithms.HS256
|
||||
self.save()
|
||||
else:
|
||||
# Because the JWT Library uses python cryptodome,
|
||||
# we can't directly pass the RSAPublicKey
|
||||
# object, but have to load it ourselves
|
||||
key = import_rsa_key(self.rsa_key.key_data)
|
||||
keys = [RSAKey(key=key, kid=self.rsa_key.kid)]
|
||||
if not keys:
|
||||
raise Exception("You must add at least one RSA Key.")
|
||||
return keys
|
||||
|
||||
if self.jwt_alg == JWTAlgorithms.HS256:
|
||||
return [SYMKey(key=self.client_secret, alg=self.jwt_alg)]
|
||||
|
||||
raise Exception("Unsupported key algorithm.")
|
||||
|
||||
def get_issuer(self, request: HttpRequest) -> Optional[str]:
|
||||
"""Get issuer, based on request"""
|
||||
try:
|
||||
mountpoint = PassbookProviderOAuth2Config.mountpoints[
|
||||
"passbook.providers.oauth2.urls"
|
||||
]
|
||||
# pylint: disable=no-member
|
||||
return request.build_absolute_uri(f"/{mountpoint}{self.application.slug}/")
|
||||
except Provider.application.RelatedObjectDoesNotExist:
|
||||
return None
|
||||
|
||||
def form(self) -> Type[ModelForm]:
|
||||
from passbook.providers.oauth2.forms import OAuth2ProviderForm
|
||||
|
||||
return OAuth2ProviderForm
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
def html_setup_urls(self, request: HttpRequest) -> Optional[str]:
|
||||
"""return template and context modal with URLs for authorize, token, openid-config, etc"""
|
||||
try:
|
||||
# pylint: disable=no-member
|
||||
return render_to_string(
|
||||
"providers/oauth2/setup_url_modal.html",
|
||||
{
|
||||
"provider": self,
|
||||
"authorize": request.build_absolute_uri(
|
||||
reverse("passbook_providers_oauth2:authorize",)
|
||||
),
|
||||
"token": request.build_absolute_uri(
|
||||
reverse("passbook_providers_oauth2:token",)
|
||||
),
|
||||
"userinfo": request.build_absolute_uri(
|
||||
reverse("passbook_providers_oauth2:userinfo",)
|
||||
),
|
||||
"provider_info": request.build_absolute_uri(
|
||||
reverse(
|
||||
"passbook_providers_oauth2:provider-info",
|
||||
kwargs={"application_slug": self.application.slug},
|
||||
)
|
||||
),
|
||||
},
|
||||
)
|
||||
except Provider.application.RelatedObjectDoesNotExist:
|
||||
return None
|
||||
|
||||
class Meta:
|
||||
|
||||
verbose_name = _("OAuth2/OpenID Provider")
|
||||
verbose_name_plural = _("OAuth2/OpenID Providers")
|
||||
|
||||
|
||||
class BaseGrantModel(models.Model):
|
||||
"""Base Model for all grants"""
|
||||
|
||||
provider = models.ForeignKey(OAuth2Provider, on_delete=models.CASCADE)
|
||||
user = models.ForeignKey(User, verbose_name=_("User"), on_delete=models.CASCADE)
|
||||
_scope = models.TextField(default="", verbose_name=_("Scopes"))
|
||||
|
||||
@property
|
||||
def scope(self) -> List[str]:
|
||||
"""Return scopes as list of strings"""
|
||||
return self._scope.split()
|
||||
|
||||
@scope.setter
|
||||
def scope(self, value):
|
||||
self._scope = " ".join(value)
|
||||
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
||||
|
||||
# pylint: disable=too-many-instance-attributes
|
||||
class AuthorizationCode(ExpiringModel, BaseGrantModel):
|
||||
"""OAuth2 Authorization Code"""
|
||||
|
||||
code = models.CharField(max_length=255, unique=True, verbose_name=_("Code"))
|
||||
nonce = models.CharField(
|
||||
max_length=255, blank=True, default="", verbose_name=_("Nonce")
|
||||
)
|
||||
is_open_id = models.BooleanField(
|
||||
default=False, verbose_name=_("Is Authentication?")
|
||||
)
|
||||
code_challenge = models.CharField(
|
||||
max_length=255, null=True, verbose_name=_("Code Challenge")
|
||||
)
|
||||
code_challenge_method = models.CharField(
|
||||
max_length=255, null=True, verbose_name=_("Code Challenge Method")
|
||||
)
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("Authorization Code")
|
||||
verbose_name_plural = _("Authorization Codes")
|
||||
|
||||
def __str__(self):
|
||||
return "{0} - {1}".format(self.provider, self.code)
|
||||
|
||||
|
||||
@dataclass
|
||||
# plyint: disable=too-many-instance-attributes
|
||||
class IDToken:
|
||||
"""The primary extension that OpenID Connect makes to OAuth 2.0 to enable End-Users to be
|
||||
Authenticated is the ID Token data structure. The ID Token is a security token that contains
|
||||
Claims about the Authentication of an End-User by an Authorization Server when using a Client,
|
||||
and potentially other requested Claims. The ID Token is represented as a
|
||||
JSON Web Token (JWT) [JWT].
|
||||
|
||||
https://openid.net/specs/openid-connect-core-1_0.html#IDToken"""
|
||||
|
||||
# All these fields need to optional so we can save an empty IDToken for non-OpenID flows.
|
||||
iss: Optional[str] = None
|
||||
sub: Optional[str] = None
|
||||
aud: Optional[str] = None
|
||||
exp: Optional[int] = None
|
||||
iat: Optional[int] = None
|
||||
auth_time: Optional[int] = None
|
||||
|
||||
nonce: Optional[str] = None
|
||||
at_hash: Optional[str] = None
|
||||
|
||||
claims: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@staticmethod
|
||||
def from_dict(data: Dict[str, Any]) -> "IDToken":
|
||||
"""Reconstruct ID Token from json dictionary"""
|
||||
token = IDToken()
|
||||
for key, value in data.items():
|
||||
setattr(token, key, value)
|
||||
return token
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert dataclass to dict, and update with keys from `claims`"""
|
||||
dic = asdict(self)
|
||||
dic.pop("claims")
|
||||
dic.update(self.claims)
|
||||
return dic
|
||||
|
||||
def encode(self, provider: OAuth2Provider) -> str:
|
||||
"""Represent the ID Token as a JSON Web Token (JWT)."""
|
||||
keys = provider.get_jwt_keys()
|
||||
# If the provider does not have an RSA Key assigned, it was switched to Symmetric
|
||||
provider.refresh_from_db()
|
||||
jws = JWS(self.to_dict(), alg=provider.jwt_alg)
|
||||
return jws.sign_compact(keys)
|
||||
|
||||
|
||||
class RefreshToken(ExpiringModel, BaseGrantModel):
|
||||
"""OAuth2 Refresh Token"""
|
||||
|
||||
access_token = models.CharField(
|
||||
max_length=255, unique=True, verbose_name=_("Access Token")
|
||||
)
|
||||
refresh_token = models.CharField(
|
||||
max_length=255, unique=True, verbose_name=_("Refresh Token")
|
||||
)
|
||||
_id_token = models.TextField(verbose_name=_("ID Token"))
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("Token")
|
||||
verbose_name_plural = _("Tokens")
|
||||
|
||||
@property
|
||||
def id_token(self) -> IDToken:
|
||||
"""Load ID Token from json"""
|
||||
if self._id_token:
|
||||
raw_token = json.loads(self._id_token)
|
||||
return IDToken.from_dict(raw_token)
|
||||
return IDToken()
|
||||
|
||||
@id_token.setter
|
||||
def id_token(self, value: IDToken):
|
||||
self._id_token = json.dumps(asdict(value))
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.provider} - {self.access_token}"
|
||||
|
||||
@property
|
||||
def at_hash(self):
|
||||
"""Get hashed access_token"""
|
||||
hashed_access_token = (
|
||||
sha256(self.access_token.encode("ascii")).hexdigest().encode("ascii")
|
||||
)
|
||||
return (
|
||||
base64.urlsafe_b64encode(
|
||||
binascii.unhexlify(hashed_access_token[: len(hashed_access_token) // 2])
|
||||
)
|
||||
.rstrip(b"=")
|
||||
.decode("ascii")
|
||||
)
|
||||
|
||||
def create_id_token(self, user: User, request: HttpRequest) -> IDToken:
|
||||
"""Creates the id_token.
|
||||
See: http://openid.net/specs/openid-connect-core-1_0.html#IDToken"""
|
||||
sub = sha256(f"{user.id}-{settings.SECRET_KEY}".encode("ascii")).hexdigest()
|
||||
|
||||
# Convert datetimes into timestamps.
|
||||
now = int(time.time())
|
||||
iat_time = now
|
||||
exp_time = int(
|
||||
now + timedelta_from_string(self.provider.token_validity).seconds
|
||||
)
|
||||
user_auth_time = user.last_login or user.date_joined
|
||||
auth_time = int(dateformat.format(user_auth_time, "U"))
|
||||
|
||||
token = IDToken(
|
||||
iss=self.provider.get_issuer(request),
|
||||
sub=sub,
|
||||
aud=self.provider.client_id,
|
||||
exp=exp_time,
|
||||
iat=iat_time,
|
||||
auth_time=auth_time,
|
||||
)
|
||||
|
||||
# Include (or not) user standard claims in the id_token.
|
||||
if self.provider.include_claims_in_id_token:
|
||||
from passbook.providers.oauth2.views.userinfo import UserInfoView
|
||||
|
||||
user_info = UserInfoView()
|
||||
user_info.request = request
|
||||
claims = user_info.get_claims(self)
|
||||
token.claims = claims
|
||||
|
||||
return token
|
Reference in New Issue
Block a user