providers/oauth2: always test JWT keys in tests
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
		| @ -1,8 +1,7 @@ | |||||||
| """Test authorize view""" | """Test authorize view""" | ||||||
| from django.test import RequestFactory, TestCase | from django.test import RequestFactory | ||||||
| from django.urls import reverse | from django.urls import reverse | ||||||
| from django.utils.encoding import force_str | from django.utils.encoding import force_str | ||||||
| from jwt import decode |  | ||||||
|  |  | ||||||
| from authentik.core.models import Application, User | from authentik.core.models import Application, User | ||||||
| from authentik.flows.challenge import ChallengeTypes | from authentik.flows.challenge import ChallengeTypes | ||||||
| @ -22,10 +21,11 @@ from authentik.providers.oauth2.models import ( | |||||||
|     OAuth2Provider, |     OAuth2Provider, | ||||||
|     RefreshToken, |     RefreshToken, | ||||||
| ) | ) | ||||||
|  | from authentik.providers.oauth2.tests.utils import OAuthTestCase | ||||||
| from authentik.providers.oauth2.views.authorize import OAuthAuthorizationParams | from authentik.providers.oauth2.views.authorize import OAuthAuthorizationParams | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestAuthorize(TestCase): | class TestAuthorize(OAuthTestCase): | ||||||
|     """Test authorize view""" |     """Test authorize view""" | ||||||
|  |  | ||||||
|     def setUp(self) -> None: |     def setUp(self) -> None: | ||||||
| @ -238,23 +238,4 @@ class TestAuthorize(TestCase): | |||||||
|                 ), |                 ), | ||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
|         jwt = decode( |         self.validate_jwt(token, provider) | ||||||
|             token.access_token, |  | ||||||
|             provider.client_secret, |  | ||||||
|             algorithms=[provider.jwt_alg], |  | ||||||
|             audience=provider.client_id, |  | ||||||
|         ) |  | ||||||
|         self.assertIsNotNone(jwt["exp"]) |  | ||||||
|         self.assertIsNotNone(jwt["iat"]) |  | ||||||
|         self.assertIsNotNone(jwt["auth_time"]) |  | ||||||
|         self.assertIsNotNone(jwt["acr"]) |  | ||||||
|         self.assertIsNotNone(jwt["sub"]) |  | ||||||
|         self.assertIsNotNone(jwt["iss"]) |  | ||||||
|         # Check id_token |  | ||||||
|         id_token = token.id_token.to_dict() |  | ||||||
|         self.assertIsNotNone(id_token["exp"]) |  | ||||||
|         self.assertIsNotNone(id_token["iat"]) |  | ||||||
|         self.assertIsNotNone(id_token["auth_time"]) |  | ||||||
|         self.assertIsNotNone(id_token["acr"]) |  | ||||||
|         self.assertIsNotNone(id_token["sub"]) |  | ||||||
|         self.assertIsNotNone(id_token["iss"]) |  | ||||||
|  | |||||||
| @ -1,11 +1,11 @@ | |||||||
| """Test token view""" | """Test token view""" | ||||||
| from base64 import b64encode | from base64 import b64encode | ||||||
|  |  | ||||||
| from django.test import RequestFactory, TestCase | from django.test import RequestFactory | ||||||
| from django.urls import reverse | from django.urls import reverse | ||||||
| from django.utils.encoding import force_str | from django.utils.encoding import force_str | ||||||
|  |  | ||||||
| from authentik.core.models import User | from authentik.core.models import Application, User | ||||||
| from authentik.flows.models import Flow | from authentik.flows.models import Flow | ||||||
| from authentik.providers.oauth2.constants import ( | from authentik.providers.oauth2.constants import ( | ||||||
|     GRANT_TYPE_AUTHORIZATION_CODE, |     GRANT_TYPE_AUTHORIZATION_CODE, | ||||||
| @ -20,15 +20,17 @@ from authentik.providers.oauth2.models import ( | |||||||
|     OAuth2Provider, |     OAuth2Provider, | ||||||
|     RefreshToken, |     RefreshToken, | ||||||
| ) | ) | ||||||
|  | from authentik.providers.oauth2.tests.utils import OAuthTestCase | ||||||
| from authentik.providers.oauth2.views.token import TokenParams | from authentik.providers.oauth2.views.token import TokenParams | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestToken(TestCase): | class TestToken(OAuthTestCase): | ||||||
|     """Test token view""" |     """Test token view""" | ||||||
|  |  | ||||||
|     def setUp(self) -> None: |     def setUp(self) -> None: | ||||||
|         super().setUp() |         super().setUp() | ||||||
|         self.factory = RequestFactory() |         self.factory = RequestFactory() | ||||||
|  |         self.app = Application.objects.create(name="test", slug="test") | ||||||
|  |  | ||||||
|     def test_request_auth_code(self): |     def test_request_auth_code(self): | ||||||
|         """test request param""" |         """test request param""" | ||||||
| @ -97,12 +99,15 @@ class TestToken(TestCase): | |||||||
|             authorization_flow=Flow.objects.first(), |             authorization_flow=Flow.objects.first(), | ||||||
|             redirect_uris="http://local.invalid", |             redirect_uris="http://local.invalid", | ||||||
|         ) |         ) | ||||||
|  |         # Needs to be assigned to an application for iss to be set | ||||||
|  |         self.app.provider = provider | ||||||
|  |         self.app.save() | ||||||
|         header = b64encode( |         header = b64encode( | ||||||
|             f"{provider.client_id}:{provider.client_secret}".encode() |             f"{provider.client_id}:{provider.client_secret}".encode() | ||||||
|         ).decode() |         ).decode() | ||||||
|         user = User.objects.get(username="akadmin") |         user = User.objects.get(username="akadmin") | ||||||
|         code = AuthorizationCode.objects.create( |         code = AuthorizationCode.objects.create( | ||||||
|             code="foobar", provider=provider, user=user |             code="foobar", provider=provider, user=user, is_open_id=True | ||||||
|         ) |         ) | ||||||
|         response = self.client.post( |         response = self.client.post( | ||||||
|             reverse("authentik_providers_oauth2:token"), |             reverse("authentik_providers_oauth2:token"), | ||||||
| @ -126,6 +131,7 @@ class TestToken(TestCase): | |||||||
|                 ), |                 ), | ||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
|  |         self.validate_jwt(new_token, provider) | ||||||
|  |  | ||||||
|     def test_refresh_token_view(self): |     def test_refresh_token_view(self): | ||||||
|         """test request param""" |         """test request param""" | ||||||
| @ -136,6 +142,9 @@ class TestToken(TestCase): | |||||||
|             authorization_flow=Flow.objects.first(), |             authorization_flow=Flow.objects.first(), | ||||||
|             redirect_uris="http://local.invalid", |             redirect_uris="http://local.invalid", | ||||||
|         ) |         ) | ||||||
|  |         # Needs to be assigned to an application for iss to be set | ||||||
|  |         self.app.provider = provider | ||||||
|  |         self.app.save() | ||||||
|         header = b64encode( |         header = b64encode( | ||||||
|             f"{provider.client_id}:{provider.client_secret}".encode() |             f"{provider.client_id}:{provider.client_secret}".encode() | ||||||
|         ).decode() |         ).decode() | ||||||
| @ -174,6 +183,7 @@ class TestToken(TestCase): | |||||||
|                 ), |                 ), | ||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
|  |         self.validate_jwt(new_token, provider) | ||||||
|  |  | ||||||
|     def test_refresh_token_view_invalid_origin(self): |     def test_refresh_token_view_invalid_origin(self): | ||||||
|         """test request param""" |         """test request param""" | ||||||
|  | |||||||
							
								
								
									
										31
									
								
								authentik/providers/oauth2/tests/utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								authentik/providers/oauth2/tests/utils.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,31 @@ | |||||||
|  | """OAuth test helpers""" | ||||||
|  | from django.test import TestCase | ||||||
|  | from jwt import decode | ||||||
|  |  | ||||||
|  | from authentik.providers.oauth2.models import OAuth2Provider, RefreshToken | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class OAuthTestCase(TestCase): | ||||||
|  |     """OAuth test helpers""" | ||||||
|  |  | ||||||
|  |     required_jwt_keys = [ | ||||||
|  |         "exp", | ||||||
|  |         "iat", | ||||||
|  |         "auth_time", | ||||||
|  |         "acr", | ||||||
|  |         "sub", | ||||||
|  |         "iss", | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |     def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider): | ||||||
|  |         """Validate that all required fields are set""" | ||||||
|  |         jwt = decode( | ||||||
|  |             token.access_token, | ||||||
|  |             provider.client_secret, | ||||||
|  |             algorithms=[provider.jwt_alg], | ||||||
|  |             audience=provider.client_id, | ||||||
|  |         ) | ||||||
|  |         id_token = token.id_token.to_dict() | ||||||
|  |         for key in self.required_jwt_keys: | ||||||
|  |             self.assertIsNotNone(jwt[key], f"Key {key} is missing in access_token") | ||||||
|  |             self.assertIsNotNone(id_token[key], f"Key {key} is missing in id_token") | ||||||
| @ -16,6 +16,7 @@ from authentik.providers.oauth2.constants import ( | |||||||
| from authentik.providers.oauth2.errors import TokenError, UserAuthError | from authentik.providers.oauth2.errors import TokenError, UserAuthError | ||||||
| from authentik.providers.oauth2.models import ( | from authentik.providers.oauth2.models import ( | ||||||
|     AuthorizationCode, |     AuthorizationCode, | ||||||
|  |     ClientTypes, | ||||||
|     OAuth2Provider, |     OAuth2Provider, | ||||||
|     RefreshToken, |     RefreshToken, | ||||||
| ) | ) | ||||||
| @ -75,7 +76,7 @@ class TokenParams: | |||||||
|             LOGGER.warning("OAuth2Provider does not exist", client_id=self.client_id) |             LOGGER.warning("OAuth2Provider does not exist", client_id=self.client_id) | ||||||
|             raise TokenError("invalid_client") |             raise TokenError("invalid_client") | ||||||
|  |  | ||||||
|         if self.provider.client_type == "confidential": |         if self.provider.client_type == ClientTypes.CONFIDENTIAL: | ||||||
|             if self.provider.client_secret != self.client_secret: |             if self.provider.client_secret != self.client_secret: | ||||||
|                 LOGGER.warning( |                 LOGGER.warning( | ||||||
|                     "Invalid client secret: client does not have secret", |                     "Invalid client secret: client does not have secret", | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user
	 Jens Langhammer
					Jens Langhammer