Files
authentik/authentik/providers/oauth2/views/jwks.py
Jens Langhammer a302a72379 crypto: fallback when no SAN values are given
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2023-01-18 19:40:24 +01:00

128 lines
4.4 KiB
Python

"""authentik OAuth2 JWKS Views"""
from base64 import b64encode, urlsafe_b64encode
from typing import Optional
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric.ec import (
SECP256R1,
SECP384R1,
SECP521R1,
EllipticCurvePrivateKey,
EllipticCurvePublicKey,
)
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey
from cryptography.hazmat.primitives.serialization import Encoding
from django.http import HttpRequest, HttpResponse, JsonResponse
from django.shortcuts import get_object_or_404
from django.views import View
from jwt.utils import base64url_encode
from authentik.core.models import Application
from authentik.crypto.models import CertificateKeyPair
from authentik.providers.oauth2.models import JWTAlgorithms, OAuth2Provider
# See https://notes.salrahman.com/generate-es256-es384-es512-private-keys/
# and _CURVE_TYPES in the same file as the below curve files
ec_crv_map = {
SECP256R1: "P-256",
SECP384R1: "P-384",
SECP521R1: "P-521",
}
min_length_map = {
SECP256R1: 32,
SECP384R1: 48,
SECP521R1: 66,
}
# https://github.com/jpadilla/pyjwt/issues/709
def bytes_from_int(val: int, min_length: int = 0) -> bytes:
"""Custom bytes_from_int that accepts a minimum length"""
remaining = val
byte_length = 0
while remaining != 0:
remaining >>= 8
byte_length += 1
length = max([byte_length, min_length])
return val.to_bytes(length, "big", signed=False)
def to_base64url_uint(val: int, min_length: int = 0) -> bytes:
"""Custom to_base64url_uint that accepts a minimum length"""
if val < 0:
raise ValueError("Must be a positive integer")
int_bytes = bytes_from_int(val, min_length)
if len(int_bytes) == 0:
int_bytes = b"\x00"
return base64url_encode(int_bytes)
class JWKSView(View):
"""Show RSA Key data for Provider"""
def get_jwk_for_key(self, key: CertificateKeyPair) -> Optional[dict]:
"""Convert a certificate-key pair into JWK"""
private_key = key.private_key
key_data = None
if not private_key:
return key_data
if isinstance(private_key, RSAPrivateKey):
public_key: RSAPublicKey = private_key.public_key()
public_numbers = public_key.public_numbers()
key_data = {
"kid": key.kid,
"kty": "RSA",
"alg": JWTAlgorithms.RS256,
"use": "sig",
"n": to_base64url_uint(public_numbers.n).decode(),
"e": to_base64url_uint(public_numbers.e).decode(),
}
elif isinstance(private_key, EllipticCurvePrivateKey):
public_key: EllipticCurvePublicKey = private_key.public_key()
public_numbers = public_key.public_numbers()
curve_type = type(public_key.curve)
key_data = {
"kid": key.kid,
"kty": "EC",
"alg": JWTAlgorithms.ES256,
"use": "sig",
"x": to_base64url_uint(public_numbers.x, min_length_map[curve_type]).decode(),
"y": to_base64url_uint(public_numbers.y, min_length_map[curve_type]).decode(),
"crv": ec_crv_map.get(curve_type, public_key.curve.name),
}
else:
return key_data
key_data["x5c"] = [b64encode(key.certificate.public_bytes(Encoding.DER)).decode("utf-8")]
key_data["x5t"] = (
urlsafe_b64encode(key.certificate.fingerprint(hashes.SHA1())) # nosec
.decode("utf-8")
.rstrip("=")
)
key_data["x5t#S256"] = (
urlsafe_b64encode(key.certificate.fingerprint(hashes.SHA256()))
.decode("utf-8")
.rstrip("=")
)
return key_data
def get(self, request: HttpRequest, application_slug: str) -> HttpResponse:
"""Show JWK Key data for Provider"""
application = get_object_or_404(Application, slug=application_slug)
provider: OAuth2Provider = get_object_or_404(OAuth2Provider, pk=application.provider_id)
signing_key: CertificateKeyPair = provider.signing_key
response_data = {}
if signing_key:
jwk = self.get_jwk_for_key(signing_key)
if jwk:
response_data["keys"] = [jwk]
response = JsonResponse(response_data)
response["Access-Control-Allow-Origin"] = "*"
return response