providers/oauth2: add missing kid header to JWT Tokens
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
		| @ -6,11 +6,10 @@ import time | ||||
| from dataclasses import asdict, dataclass, field | ||||
| from datetime import datetime | ||||
| from hashlib import sha256 | ||||
| from typing import Any, Optional, Type, Union | ||||
| from typing import Any, Optional, Type | ||||
| from urllib.parse import urlparse | ||||
| from uuid import uuid4 | ||||
|  | ||||
| from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey | ||||
| from dacite import from_dict | ||||
| from django.db import models | ||||
| from django.http import HttpRequest | ||||
| @ -238,7 +237,7 @@ class OAuth2Provider(Provider): | ||||
|         token.access_token = token.create_access_token(user, request) | ||||
|         return token | ||||
|  | ||||
|     def get_jwt_keys(self) -> Union[RSAPrivateKey, str]: | ||||
|     def get_jwt_key(self) -> str: | ||||
|         """ | ||||
|         Takes a provider and returns the set of keys associated with it. | ||||
|         Returns a list of keys. | ||||
| @ -255,7 +254,7 @@ class OAuth2Provider(Provider): | ||||
|                 self.jwt_alg = JWTAlgorithms.HS256 | ||||
|                 self.save() | ||||
|             else: | ||||
|                 return self.rsa_key.private_key | ||||
|                 return self.rsa_key.key_data | ||||
|  | ||||
|         if self.jwt_alg == JWTAlgorithms.HS256: | ||||
|             return self.client_secret | ||||
| @ -299,11 +298,14 @@ class OAuth2Provider(Provider): | ||||
|  | ||||
|     def encode(self, payload: dict[str, Any]) -> str: | ||||
|         """Represent the ID Token as a JSON Web Token (JWT).""" | ||||
|         key = self.get_jwt_keys() | ||||
|         headers = {} | ||||
|         if self.rsa_key: | ||||
|             headers["kid"] = self.rsa_key.kid | ||||
|         key = self.get_jwt_key() | ||||
|         # If the provider does not have an RSA Key assigned, it was switched to Symmetric | ||||
|         self.refresh_from_db() | ||||
|         # pyright: reportGeneralTypeIssues=false | ||||
|         return encode(payload, key, algorithm=self.jwt_alg) | ||||
|         return encode(payload, key, algorithm=self.jwt_alg, headers=headers) | ||||
|  | ||||
|     class Meta: | ||||
|  | ||||
|  | ||||
| @ -4,6 +4,7 @@ from django.urls import reverse | ||||
| from django.utils.encoding import force_str | ||||
|  | ||||
| from authentik.core.models import Application, User | ||||
| from authentik.crypto.models import CertificateKeyPair | ||||
| from authentik.flows.challenge import ChallengeTypes | ||||
| from authentik.flows.models import Flow | ||||
| from authentik.providers.oauth2.errors import ( | ||||
| @ -207,6 +208,7 @@ class TestAuthorize(OAuthTestCase): | ||||
|             client_secret=generate_client_secret(), | ||||
|             authorization_flow=flow, | ||||
|             redirect_uris="http://localhost", | ||||
|             rsa_key=CertificateKeyPair.objects.first(), | ||||
|         ) | ||||
|         Application.objects.create(name="app", slug="app", provider=provider) | ||||
|         state = generate_client_id() | ||||
|  | ||||
| @ -2,7 +2,11 @@ | ||||
| from django.test import TestCase | ||||
| from jwt import decode | ||||
|  | ||||
| from authentik.providers.oauth2.models import OAuth2Provider, RefreshToken | ||||
| from authentik.providers.oauth2.models import ( | ||||
|     JWTAlgorithms, | ||||
|     OAuth2Provider, | ||||
|     RefreshToken, | ||||
| ) | ||||
|  | ||||
|  | ||||
| class OAuthTestCase(TestCase): | ||||
| @ -19,9 +23,12 @@ class OAuthTestCase(TestCase): | ||||
|  | ||||
|     def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider): | ||||
|         """Validate that all required fields are set""" | ||||
|         key = provider.client_secret | ||||
|         if provider.jwt_alg == JWTAlgorithms.RS256: | ||||
|             key = provider.rsa_key.public_key | ||||
|         jwt = decode( | ||||
|             token.access_token, | ||||
|             provider.client_secret, | ||||
|             key, | ||||
|             algorithms=[provider.jwt_alg], | ||||
|             audience=provider.client_id, | ||||
|         ) | ||||
|  | ||||
		Reference in New Issue
	
	Block a user
	 Jens Langhammer
					Jens Langhammer