Compare commits
	
		
			70 Commits
		
	
	
		
			enterprise
			...
			version/20
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| c99a33baee | |||
| b17d482e50 | |||
| 524d46ad7c | |||
| f90d6bb3d9 | |||
| 2340bced63 | |||
| 0a51e1b696 | |||
| 13636c0efe | |||
| e7f49d97a8 | |||
| 736240f60d | |||
| e8b5e4c127 | |||
| 81ec98b198 | |||
| c46ab19e79 | |||
| de9fc5de6b | |||
| eab3d9b411 | |||
| 7cb40d786f | |||
| b4fce08bbc | |||
| 8a2ba1c518 | |||
| 25b4306693 | |||
| 1e279950f1 | |||
| 960429355f | |||
| b4f3748353 | |||
| 91d2445c61 | |||
| dd8f809161 | |||
| 57a31b5dd1 | |||
| 09125b6236 | |||
| 832126c6fe | |||
| 25fe489b34 | |||
| 18078fd68f | |||
| 4fa71d995d | |||
| 22cec64234 | |||
| a87cc27366 | |||
| ad7ad1fa78 | |||
| c70e609e50 | |||
| 5f08485fff | |||
| 3a2ed11821 | |||
| ee04f39e28 | |||
| 2c6aa72f3c | |||
| bd0afef790 | |||
| fc11cc0a1a | |||
| fb78303e8f | |||
| 2ea04440db | |||
| 96e1636be3 | |||
| c546451a73 | |||
| 61778053b4 | |||
| f5580d311d | |||
| 99d292bce0 | |||
| b2801641bc | |||
| bfaa1046b2 | |||
| 95c30400cc | |||
| e77480ee1d | |||
| 905800e535 | |||
| fadeaef4c6 | |||
| 437efda649 | |||
| dd75d5f54b | |||
| 392a2e582e | |||
| a1da183721 | |||
| feea2df0b1 | |||
| b47acd8c76 | |||
| 6fd87d9ced | |||
| acbb065808 | |||
| 2fb097061d | |||
| 8962d17e03 | |||
| 8326e1490c | |||
| 091e4d3e4c | |||
| 6ee77edcbb | |||
| 763e2288bf | |||
| 9cdb177ca7 | |||
| 6070508058 | |||
| ec13a5d84d | |||
| 057de82b01 | 
| @ -1,5 +1,5 @@ | |||||||
| [bumpversion] | [bumpversion] | ||||||
| current_version = 2024.6.4 | current_version = 2024.8.5 | ||||||
| tag = True | tag = True | ||||||
| commit = True | commit = True | ||||||
| parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))? | parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))? | ||||||
|  | |||||||
| @ -29,9 +29,9 @@ outputs: | |||||||
|   imageTags: |   imageTags: | ||||||
|     description: "Docker image tags" |     description: "Docker image tags" | ||||||
|     value: ${{ steps.ev.outputs.imageTags }} |     value: ${{ steps.ev.outputs.imageTags }} | ||||||
|   imageNames: |   attestImageNames: | ||||||
|     description: "Docker image names" |     description: "Docker image names used for attestation" | ||||||
|     value: ${{ steps.ev.outputs.imageNames }} |     value: ${{ steps.ev.outputs.attestImageNames }} | ||||||
|   imageMainTag: |   imageMainTag: | ||||||
|     description: "Docker image main tag" |     description: "Docker image main tag" | ||||||
|     value: ${{ steps.ev.outputs.imageMainTag }} |     value: ${{ steps.ev.outputs.imageMainTag }} | ||||||
|  | |||||||
| @ -51,15 +51,24 @@ else: | |||||||
|         ] |         ] | ||||||
|  |  | ||||||
| image_main_tag = image_tags[0].split(":")[-1] | image_main_tag = image_tags[0].split(":")[-1] | ||||||
| image_tags_rendered = ",".join(image_tags) |  | ||||||
| image_names_rendered = ",".join(set(name.split(":")[0] for name in image_tags)) |  | ||||||
|  | def get_attest_image_names(image_with_tags: list[str]): | ||||||
|  |     """Attestation only for GHCR""" | ||||||
|  |     image_tags = [] | ||||||
|  |     for image_name in set(name.split(":")[0] for name in image_with_tags): | ||||||
|  |         if not image_name.startswith("ghcr.io"): | ||||||
|  |             continue | ||||||
|  |         image_tags.append(image_name) | ||||||
|  |     return ",".join(set(image_tags)) | ||||||
|  |  | ||||||
|  |  | ||||||
| with open(os.environ["GITHUB_OUTPUT"], "a+", encoding="utf-8") as _output: | with open(os.environ["GITHUB_OUTPUT"], "a+", encoding="utf-8") as _output: | ||||||
|     print(f"shouldBuild={should_build}", file=_output) |     print(f"shouldBuild={should_build}", file=_output) | ||||||
|     print(f"sha={sha}", file=_output) |     print(f"sha={sha}", file=_output) | ||||||
|     print(f"version={version}", file=_output) |     print(f"version={version}", file=_output) | ||||||
|     print(f"prerelease={prerelease}", file=_output) |     print(f"prerelease={prerelease}", file=_output) | ||||||
|     print(f"imageTags={image_tags_rendered}", file=_output) |     print(f"imageTags={','.join(image_tags)}", file=_output) | ||||||
|     print(f"imageNames={image_names_rendered}", file=_output) |     print(f"attestImageNames={get_attest_image_names(image_tags)}", file=_output) | ||||||
|     print(f"imageMainTag={image_main_tag}", file=_output) |     print(f"imageMainTag={image_main_tag}", file=_output) | ||||||
|     print(f"imageMainName={image_tags[0]}", file=_output) |     print(f"imageMainName={image_tags[0]}", file=_output) | ||||||
|  | |||||||
							
								
								
									
										2
									
								
								.github/workflows/ci-main.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/ci-main.yml
									
									
									
									
										vendored
									
									
								
							| @ -261,7 +261,7 @@ jobs: | |||||||
|         id: attest |         id: attest | ||||||
|         if: ${{ steps.ev.outputs.shouldBuild == 'true' }} |         if: ${{ steps.ev.outputs.shouldBuild == 'true' }} | ||||||
|         with: |         with: | ||||||
|           subject-name: ${{ steps.ev.outputs.imageNames }} |           subject-name: ${{ steps.ev.outputs.attestImageNames }} | ||||||
|           subject-digest: ${{ steps.push.outputs.digest }} |           subject-digest: ${{ steps.push.outputs.digest }} | ||||||
|           push-to-registry: true |           push-to-registry: true | ||||||
|   pr-comment: |   pr-comment: | ||||||
|  | |||||||
							
								
								
									
										2
									
								
								.github/workflows/ci-outpost.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/ci-outpost.yml
									
									
									
									
										vendored
									
									
								
							| @ -115,7 +115,7 @@ jobs: | |||||||
|         id: attest |         id: attest | ||||||
|         if: ${{ steps.ev.outputs.shouldBuild == 'true' }} |         if: ${{ steps.ev.outputs.shouldBuild == 'true' }} | ||||||
|         with: |         with: | ||||||
|           subject-name: ${{ steps.ev.outputs.imageNames }} |           subject-name: ${{ steps.ev.outputs.attestImageNames }} | ||||||
|           subject-digest: ${{ steps.push.outputs.digest }} |           subject-digest: ${{ steps.push.outputs.digest }} | ||||||
|           push-to-registry: true |           push-to-registry: true | ||||||
|   build-binary: |   build-binary: | ||||||
|  | |||||||
							
								
								
									
										4
									
								
								.github/workflows/release-publish.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.github/workflows/release-publish.yml
									
									
									
									
										vendored
									
									
								
							| @ -58,7 +58,7 @@ jobs: | |||||||
|       - uses: actions/attest-build-provenance@v1 |       - uses: actions/attest-build-provenance@v1 | ||||||
|         id: attest |         id: attest | ||||||
|         with: |         with: | ||||||
|           subject-name: ${{ steps.ev.outputs.imageNames }} |           subject-name: ${{ steps.ev.outputs.attestImageNames }} | ||||||
|           subject-digest: ${{ steps.push.outputs.digest }} |           subject-digest: ${{ steps.push.outputs.digest }} | ||||||
|           push-to-registry: true |           push-to-registry: true | ||||||
|   build-outpost: |   build-outpost: | ||||||
| @ -122,7 +122,7 @@ jobs: | |||||||
|       - uses: actions/attest-build-provenance@v1 |       - uses: actions/attest-build-provenance@v1 | ||||||
|         id: attest |         id: attest | ||||||
|         with: |         with: | ||||||
|           subject-name: ${{ steps.ev.outputs.imageNames }} |           subject-name: ${{ steps.ev.outputs.attestImageNames }} | ||||||
|           subject-digest: ${{ steps.push.outputs.digest }} |           subject-digest: ${{ steps.push.outputs.digest }} | ||||||
|           push-to-registry: true |           push-to-registry: true | ||||||
|   build-outpost-binary: |   build-outpost-binary: | ||||||
|  | |||||||
							
								
								
									
										2
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								Makefile
									
									
									
									
									
								
							| @ -205,7 +205,7 @@ gen: gen-build gen-client-ts | |||||||
| web-build: web-install  ## Build the Authentik UI | web-build: web-install  ## Build the Authentik UI | ||||||
| 	cd web && npm run build | 	cd web && npm run build | ||||||
|  |  | ||||||
| web: web-lint-fix web-lint web-check-compile web-test  ## Automatically fix formatting issues in the Authentik UI source code, lint the code, and compile it | web: web-lint-fix web-lint web-check-compile  ## Automatically fix formatting issues in the Authentik UI source code, lint the code, and compile it | ||||||
|  |  | ||||||
| web-install:  ## Install the necessary libraries to build the Authentik UI | web-install:  ## Install the necessary libraries to build the Authentik UI | ||||||
| 	cd web && npm ci | 	cd web && npm ci | ||||||
|  | |||||||
| @ -2,7 +2,7 @@ | |||||||
|  |  | ||||||
| from os import environ | from os import environ | ||||||
|  |  | ||||||
| __version__ = "2024.6.4" | __version__ = "2024.8.5" | ||||||
| ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" | ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| @ -51,9 +51,11 @@ class BlueprintInstanceSerializer(ModelSerializer): | |||||||
|         context = self.instance.context if self.instance else {} |         context = self.instance.context if self.instance else {} | ||||||
|         valid, logs = Importer.from_string(content, context).validate() |         valid, logs = Importer.from_string(content, context).validate() | ||||||
|         if not valid: |         if not valid: | ||||||
|             text_logs = "\n".join([x["event"] for x in logs]) |  | ||||||
|             raise ValidationError( |             raise ValidationError( | ||||||
|                 _("Failed to validate blueprint: {logs}".format_map({"logs": text_logs})) |                 [ | ||||||
|  |                     _("Failed to validate blueprint"), | ||||||
|  |                     *[f"- {x.event}" for x in logs], | ||||||
|  |                 ] | ||||||
|             ) |             ) | ||||||
|         return content |         return content | ||||||
|  |  | ||||||
|  | |||||||
| @ -78,5 +78,5 @@ class TestBlueprintsV1API(APITestCase): | |||||||
|         self.assertEqual(res.status_code, 400) |         self.assertEqual(res.status_code, 400) | ||||||
|         self.assertJSONEqual( |         self.assertJSONEqual( | ||||||
|             res.content.decode(), |             res.content.decode(), | ||||||
|             {"content": ["Failed to validate blueprint: Invalid blueprint version"]}, |             {"content": ["Failed to validate blueprint", "- Invalid blueprint version"]}, | ||||||
|         ) |         ) | ||||||
|  | |||||||
| @ -429,7 +429,7 @@ class Importer: | |||||||
|         orig_import = deepcopy(self._import) |         orig_import = deepcopy(self._import) | ||||||
|         if self._import.version != 1: |         if self._import.version != 1: | ||||||
|             self.logger.warning("Invalid blueprint version") |             self.logger.warning("Invalid blueprint version") | ||||||
|             return False, [{"event": "Invalid blueprint version"}] |             return False, [LogEvent("Invalid blueprint version", log_level="warning", logger=None)] | ||||||
|         with ( |         with ( | ||||||
|             transaction_rollback(), |             transaction_rollback(), | ||||||
|             capture_logs() as logs, |             capture_logs() as logs, | ||||||
|  | |||||||
| @ -30,8 +30,10 @@ from authentik.core.api.utils import ( | |||||||
|     PassiveSerializer, |     PassiveSerializer, | ||||||
| ) | ) | ||||||
| from authentik.core.expression.evaluator import PropertyMappingEvaluator | from authentik.core.expression.evaluator import PropertyMappingEvaluator | ||||||
|  | from authentik.core.expression.exceptions import PropertyMappingExpressionException | ||||||
| from authentik.core.models import Group, PropertyMapping, User | from authentik.core.models import Group, PropertyMapping, User | ||||||
| from authentik.events.utils import sanitize_item | from authentik.events.utils import sanitize_item | ||||||
|  | from authentik.lib.utils.errors import exception_to_string | ||||||
| from authentik.policies.api.exec import PolicyTestSerializer | from authentik.policies.api.exec import PolicyTestSerializer | ||||||
| from authentik.rbac.decorators import permission_required | from authentik.rbac.decorators import permission_required | ||||||
|  |  | ||||||
| @ -162,12 +164,15 @@ class PropertyMappingViewSet( | |||||||
|  |  | ||||||
|         response_data = {"successful": True, "result": ""} |         response_data = {"successful": True, "result": ""} | ||||||
|         try: |         try: | ||||||
|             result = mapping.evaluate(**context) |             result = mapping.evaluate(dry_run=True, **context) | ||||||
|             response_data["result"] = dumps( |             response_data["result"] = dumps( | ||||||
|                 sanitize_item(result), indent=(4 if format_result else None) |                 sanitize_item(result), indent=(4 if format_result else None) | ||||||
|             ) |             ) | ||||||
|  |         except PropertyMappingExpressionException as exc: | ||||||
|  |             response_data["result"] = exception_to_string(exc.exc) | ||||||
|  |             response_data["successful"] = False | ||||||
|         except Exception as exc: |         except Exception as exc: | ||||||
|             response_data["result"] = str(exc) |             response_data["result"] = exception_to_string(exc) | ||||||
|             response_data["successful"] = False |             response_data["successful"] = False | ||||||
|         response = PropertyMappingTestResultSerializer(response_data) |         response = PropertyMappingTestResultSerializer(response_data) | ||||||
|         return Response(response.data) |         return Response(response.data) | ||||||
|  | |||||||
| @ -678,10 +678,13 @@ class UserViewSet(UsedByMixin, ModelViewSet): | |||||||
|         if not request.tenant.impersonation: |         if not request.tenant.impersonation: | ||||||
|             LOGGER.debug("User attempted to impersonate", user=request.user) |             LOGGER.debug("User attempted to impersonate", user=request.user) | ||||||
|             return Response(status=401) |             return Response(status=401) | ||||||
|         if not request.user.has_perm("impersonate"): |         user_to_be = self.get_object() | ||||||
|  |         # Check both object-level perms and global perms | ||||||
|  |         if not request.user.has_perm( | ||||||
|  |             "authentik_core.impersonate", user_to_be | ||||||
|  |         ) and not request.user.has_perm("authentik_core.impersonate"): | ||||||
|             LOGGER.debug("User attempted to impersonate without permissions", user=request.user) |             LOGGER.debug("User attempted to impersonate without permissions", user=request.user) | ||||||
|             return Response(status=401) |             return Response(status=401) | ||||||
|         user_to_be = self.get_object() |  | ||||||
|         if user_to_be.pk == self.request.user.pk: |         if user_to_be.pk == self.request.user.pk: | ||||||
|             LOGGER.debug("User attempted to impersonate themselves", user=request.user) |             LOGGER.debug("User attempted to impersonate themselves", user=request.user) | ||||||
|             return Response(status=401) |             return Response(status=401) | ||||||
|  | |||||||
| @ -9,10 +9,11 @@ class Command(TenantCommand): | |||||||
|  |  | ||||||
|     def add_arguments(self, parser): |     def add_arguments(self, parser): | ||||||
|         parser.add_argument("--type", type=str, required=True) |         parser.add_argument("--type", type=str, required=True) | ||||||
|         parser.add_argument("--all", action="store_true") |         parser.add_argument("--all", action="store_true", default=False) | ||||||
|         parser.add_argument("usernames", nargs="+", type=str) |         parser.add_argument("usernames", nargs="*", type=str) | ||||||
|  |  | ||||||
|     def handle_per_tenant(self, **options): |     def handle_per_tenant(self, **options): | ||||||
|  |         print(options) | ||||||
|         new_type = UserTypes(options["type"]) |         new_type = UserTypes(options["type"]) | ||||||
|         qs = ( |         qs = ( | ||||||
|             User.objects.exclude_anonymous() |             User.objects.exclude_anonymous() | ||||||
| @ -22,6 +23,9 @@ class Command(TenantCommand): | |||||||
|         if options["usernames"] and options["all"]: |         if options["usernames"] and options["all"]: | ||||||
|             self.stderr.write("--all and usernames specified, only one can be specified") |             self.stderr.write("--all and usernames specified, only one can be specified") | ||||||
|             return |             return | ||||||
|  |         if not options["usernames"] and not options["all"]: | ||||||
|  |             self.stderr.write("--all or usernames must be specified") | ||||||
|  |             return | ||||||
|         if options["usernames"] and not options["all"]: |         if options["usernames"] and not options["all"]: | ||||||
|             qs = qs.filter(username__in=options["usernames"]) |             qs = qs.filter(username__in=options["usernames"]) | ||||||
|         updated = qs.update(type=new_type) |         updated = qs.update(type=new_type) | ||||||
|  | |||||||
| @ -466,8 +466,6 @@ class ApplicationQuerySet(QuerySet): | |||||||
|     def with_provider(self) -> "QuerySet[Application]": |     def with_provider(self) -> "QuerySet[Application]": | ||||||
|         qs = self.select_related("provider") |         qs = self.select_related("provider") | ||||||
|         for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider): |         for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider): | ||||||
|             if LOOKUP_SEP in subclass: |  | ||||||
|                 continue |  | ||||||
|             qs = qs.select_related(f"provider__{subclass}") |             qs = qs.select_related(f"provider__{subclass}") | ||||||
|         return qs |         return qs | ||||||
|  |  | ||||||
| @ -545,15 +543,24 @@ class Application(SerializerModel, PolicyBindingModel): | |||||||
|         if not self.provider: |         if not self.provider: | ||||||
|             return None |             return None | ||||||
|  |  | ||||||
|         for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider): |         candidates = [] | ||||||
|             # We don't care about recursion, skip nested models |         base_class = Provider | ||||||
|             if LOOKUP_SEP in subclass: |         for subclass in base_class.objects.get_queryset()._get_subclasses_recurse(base_class): | ||||||
|                 continue |             parent = self.provider | ||||||
|  |             for level in subclass.split(LOOKUP_SEP): | ||||||
|                 try: |                 try: | ||||||
|                 return getattr(self.provider, subclass) |                     parent = getattr(parent, level) | ||||||
|                 except AttributeError: |                 except AttributeError: | ||||||
|                 pass |                     break | ||||||
|  |             if parent in candidates: | ||||||
|  |                 continue | ||||||
|  |             idx = subclass.count(LOOKUP_SEP) | ||||||
|  |             if type(parent) is not base_class: | ||||||
|  |                 idx += 1 | ||||||
|  |             candidates.insert(idx, parent) | ||||||
|  |         if not candidates: | ||||||
|             return None |             return None | ||||||
|  |         return candidates[-1] | ||||||
|  |  | ||||||
|     def __str__(self): |     def __str__(self): | ||||||
|         return str(self.name) |         return str(self.name) | ||||||
| @ -901,7 +908,7 @@ class PropertyMapping(SerializerModel, ManagedModel): | |||||||
|         except ControlFlowException as exc: |         except ControlFlowException as exc: | ||||||
|             raise exc |             raise exc | ||||||
|         except Exception as exc: |         except Exception as exc: | ||||||
|             raise PropertyMappingExpressionException(self, exc) from exc |             raise PropertyMappingExpressionException(exc, self) from exc | ||||||
|  |  | ||||||
|     def __str__(self): |     def __str__(self): | ||||||
|         return f"Property Mapping {self.name}" |         return f"Property Mapping {self.name}" | ||||||
|  | |||||||
| @ -9,9 +9,12 @@ from rest_framework.test import APITestCase | |||||||
|  |  | ||||||
| from authentik.core.models import Application | from authentik.core.models import Application | ||||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_flow | from authentik.core.tests.utils import create_test_admin_user, create_test_flow | ||||||
|  | from authentik.lib.generators import generate_id | ||||||
| from authentik.policies.dummy.models import DummyPolicy | from authentik.policies.dummy.models import DummyPolicy | ||||||
| from authentik.policies.models import PolicyBinding | from authentik.policies.models import PolicyBinding | ||||||
| from authentik.providers.oauth2.models import OAuth2Provider | from authentik.providers.oauth2.models import OAuth2Provider, RedirectURI, RedirectURIMatchingMode | ||||||
|  | from authentik.providers.proxy.models import ProxyProvider | ||||||
|  | from authentik.providers.saml.models import SAMLProvider | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestApplicationsAPI(APITestCase): | class TestApplicationsAPI(APITestCase): | ||||||
| @ -21,7 +24,7 @@ class TestApplicationsAPI(APITestCase): | |||||||
|         self.user = create_test_admin_user() |         self.user = create_test_admin_user() | ||||||
|         self.provider = OAuth2Provider.objects.create( |         self.provider = OAuth2Provider.objects.create( | ||||||
|             name="test", |             name="test", | ||||||
|             redirect_uris="http://some-other-domain", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://some-other-domain")], | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|         ) |         ) | ||||||
|         self.allowed: Application = Application.objects.create( |         self.allowed: Application = Application.objects.create( | ||||||
| @ -222,3 +225,31 @@ class TestApplicationsAPI(APITestCase): | |||||||
|                 ], |                 ], | ||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     def test_get_provider(self): | ||||||
|  |         """Ensure that proxy providers (at the time of writing that is the only provider | ||||||
|  |         that inherits from another proxy type (OAuth) instead of inheriting from the root | ||||||
|  |         provider class) is correctly looked up and selected from the database""" | ||||||
|  |         slug = generate_id() | ||||||
|  |         provider = ProxyProvider.objects.create(name=generate_id()) | ||||||
|  |         Application.objects.create( | ||||||
|  |             name=generate_id(), | ||||||
|  |             slug=slug, | ||||||
|  |             provider=provider, | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(Application.objects.get(slug=slug).get_provider(), provider) | ||||||
|  |         self.assertEqual( | ||||||
|  |             Application.objects.with_provider().get(slug=slug).get_provider(), provider | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         slug = generate_id() | ||||||
|  |         provider = SAMLProvider.objects.create(name=generate_id()) | ||||||
|  |         Application.objects.create( | ||||||
|  |             name=generate_id(), | ||||||
|  |             slug=slug, | ||||||
|  |             provider=provider, | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(Application.objects.get(slug=slug).get_provider(), provider) | ||||||
|  |         self.assertEqual( | ||||||
|  |             Application.objects.with_provider().get(slug=slug).get_provider(), provider | ||||||
|  |         ) | ||||||
|  | |||||||
| @ -3,10 +3,10 @@ | |||||||
| from json import loads | from json import loads | ||||||
|  |  | ||||||
| from django.urls import reverse | from django.urls import reverse | ||||||
|  | from guardian.shortcuts import assign_perm | ||||||
| from rest_framework.test import APITestCase | from rest_framework.test import APITestCase | ||||||
|  |  | ||||||
| from authentik.core.models import User | from authentik.core.tests.utils import create_test_admin_user, create_test_user | ||||||
| from authentik.core.tests.utils import create_test_admin_user |  | ||||||
| from authentik.tenants.utils import get_current_tenant | from authentik.tenants.utils import get_current_tenant | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -15,7 +15,7 @@ class TestImpersonation(APITestCase): | |||||||
|  |  | ||||||
|     def setUp(self) -> None: |     def setUp(self) -> None: | ||||||
|         super().setUp() |         super().setUp() | ||||||
|         self.other_user = User.objects.create(username="to-impersonate") |         self.other_user = create_test_user() | ||||||
|         self.user = create_test_admin_user() |         self.user = create_test_admin_user() | ||||||
|  |  | ||||||
|     def test_impersonate_simple(self): |     def test_impersonate_simple(self): | ||||||
| @ -44,6 +44,46 @@ class TestImpersonation(APITestCase): | |||||||
|         self.assertEqual(response_body["user"]["username"], self.user.username) |         self.assertEqual(response_body["user"]["username"], self.user.username) | ||||||
|         self.assertNotIn("original", response_body) |         self.assertNotIn("original", response_body) | ||||||
|  |  | ||||||
|  |     def test_impersonate_global(self): | ||||||
|  |         """Test impersonation with global permissions""" | ||||||
|  |         new_user = create_test_user() | ||||||
|  |         assign_perm("authentik_core.impersonate", new_user) | ||||||
|  |         assign_perm("authentik_core.view_user", new_user) | ||||||
|  |         self.client.force_login(new_user) | ||||||
|  |  | ||||||
|  |         response = self.client.post( | ||||||
|  |             reverse( | ||||||
|  |                 "authentik_api:user-impersonate", | ||||||
|  |                 kwargs={"pk": self.other_user.pk}, | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(response.status_code, 201) | ||||||
|  |  | ||||||
|  |         response = self.client.get(reverse("authentik_api:user-me")) | ||||||
|  |         response_body = loads(response.content.decode()) | ||||||
|  |         self.assertEqual(response_body["user"]["username"], self.other_user.username) | ||||||
|  |         self.assertEqual(response_body["original"]["username"], new_user.username) | ||||||
|  |  | ||||||
|  |     def test_impersonate_scoped(self): | ||||||
|  |         """Test impersonation with scoped permissions""" | ||||||
|  |         new_user = create_test_user() | ||||||
|  |         assign_perm("authentik_core.impersonate", new_user, self.other_user) | ||||||
|  |         assign_perm("authentik_core.view_user", new_user, self.other_user) | ||||||
|  |         self.client.force_login(new_user) | ||||||
|  |  | ||||||
|  |         response = self.client.post( | ||||||
|  |             reverse( | ||||||
|  |                 "authentik_api:user-impersonate", | ||||||
|  |                 kwargs={"pk": self.other_user.pk}, | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(response.status_code, 201) | ||||||
|  |  | ||||||
|  |         response = self.client.get(reverse("authentik_api:user-me")) | ||||||
|  |         response_body = loads(response.content.decode()) | ||||||
|  |         self.assertEqual(response_body["user"]["username"], self.other_user.username) | ||||||
|  |         self.assertEqual(response_body["original"]["username"], new_user.username) | ||||||
|  |  | ||||||
|     def test_impersonate_denied(self): |     def test_impersonate_denied(self): | ||||||
|         """test impersonation without permissions""" |         """test impersonation without permissions""" | ||||||
|         self.client.force_login(self.other_user) |         self.client.force_login(self.other_user) | ||||||
|  | |||||||
| @ -31,6 +31,7 @@ class TestTransactionalApplicationsAPI(APITestCase): | |||||||
|                 "provider": { |                 "provider": { | ||||||
|                     "name": uid, |                     "name": uid, | ||||||
|                     "authorization_flow": str(authorization_flow.pk), |                     "authorization_flow": str(authorization_flow.pk), | ||||||
|  |                     "redirect_uris": [], | ||||||
|                 }, |                 }, | ||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
| @ -56,6 +57,7 @@ class TestTransactionalApplicationsAPI(APITestCase): | |||||||
|                 "provider": { |                 "provider": { | ||||||
|                     "name": uid, |                     "name": uid, | ||||||
|                     "authorization_flow": "", |                     "authorization_flow": "", | ||||||
|  |                     "redirect_uris": [], | ||||||
|                 }, |                 }, | ||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
|  | |||||||
| @ -18,7 +18,7 @@ from authentik.crypto.models import CertificateKeyPair | |||||||
| from authentik.crypto.tasks import MANAGED_DISCOVERED, certificate_discovery | from authentik.crypto.tasks import MANAGED_DISCOVERED, certificate_discovery | ||||||
| from authentik.lib.config import CONFIG | from authentik.lib.config import CONFIG | ||||||
| from authentik.lib.generators import generate_id, generate_key | from authentik.lib.generators import generate_id, generate_key | ||||||
| from authentik.providers.oauth2.models import OAuth2Provider | from authentik.providers.oauth2.models import OAuth2Provider, RedirectURI, RedirectURIMatchingMode | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestCrypto(APITestCase): | class TestCrypto(APITestCase): | ||||||
| @ -263,7 +263,7 @@ class TestCrypto(APITestCase): | |||||||
|             client_id="test", |             client_id="test", | ||||||
|             client_secret=generate_key(), |             client_secret=generate_key(), | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris="http://localhost", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")], | ||||||
|             signing_key=keypair, |             signing_key=keypair, | ||||||
|         ) |         ) | ||||||
|         response = self.client.get( |         response = self.client.get( | ||||||
| @ -295,7 +295,7 @@ class TestCrypto(APITestCase): | |||||||
|             client_id="test", |             client_id="test", | ||||||
|             client_secret=generate_key(), |             client_secret=generate_key(), | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris="http://localhost", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")], | ||||||
|             signing_key=keypair, |             signing_key=keypair, | ||||||
|         ) |         ) | ||||||
|         response = self.client.get( |         response = self.client.get( | ||||||
|  | |||||||
| @ -18,7 +18,7 @@ from authentik.core.api.used_by import UsedByMixin | |||||||
| from authentik.core.api.utils import ModelSerializer, PassiveSerializer | from authentik.core.api.utils import ModelSerializer, PassiveSerializer | ||||||
| from authentik.core.models import User, UserTypes | from authentik.core.models import User, UserTypes | ||||||
| from authentik.enterprise.license import LicenseKey, LicenseSummarySerializer | from authentik.enterprise.license import LicenseKey, LicenseSummarySerializer | ||||||
| from authentik.enterprise.models import License, LicenseUsageStatus | from authentik.enterprise.models import License | ||||||
| from authentik.rbac.decorators import permission_required | from authentik.rbac.decorators import permission_required | ||||||
| from authentik.tenants.utils import get_unique_identifier | from authentik.tenants.utils import get_unique_identifier | ||||||
|  |  | ||||||
| @ -29,7 +29,7 @@ class EnterpriseRequiredMixin: | |||||||
|  |  | ||||||
|     def validate(self, attrs: dict) -> dict: |     def validate(self, attrs: dict) -> dict: | ||||||
|         """Check that a valid license exists""" |         """Check that a valid license exists""" | ||||||
|         if LicenseKey.cached_summary().status != LicenseUsageStatus.UNLICENSED: |         if not LicenseKey.cached_summary().status.is_valid: | ||||||
|             raise ValidationError(_("Enterprise is required to create/update this object.")) |             raise ValidationError(_("Enterprise is required to create/update this object.")) | ||||||
|         return super().validate(attrs) |         return super().validate(attrs) | ||||||
|  |  | ||||||
|  | |||||||
| @ -25,4 +25,4 @@ class AuthentikEnterpriseConfig(EnterpriseConfig): | |||||||
|         """Actual enterprise check, cached""" |         """Actual enterprise check, cached""" | ||||||
|         from authentik.enterprise.license import LicenseKey |         from authentik.enterprise.license import LicenseKey | ||||||
|  |  | ||||||
|         return LicenseKey.cached_summary().status |         return LicenseKey.cached_summary().status.is_valid | ||||||
|  | |||||||
| @ -117,10 +117,13 @@ class LicenseKey: | |||||||
|                     our_cert.public_key(), |                     our_cert.public_key(), | ||||||
|                     algorithms=["ES512"], |                     algorithms=["ES512"], | ||||||
|                     audience=get_license_aud(), |                     audience=get_license_aud(), | ||||||
|                     options={"verify_exp": check_expiry}, |                     options={"verify_exp": check_expiry, "verify_signature": check_expiry}, | ||||||
|                 ), |                 ), | ||||||
|             ) |             ) | ||||||
|         except PyJWTError: |         except PyJWTError: | ||||||
|  |             unverified = decode(jwt, options={"verify_signature": False}) | ||||||
|  |             if unverified["aud"] != get_license_aud(): | ||||||
|  |                 raise ValidationError("Invalid Install ID in license") from None | ||||||
|             raise ValidationError("Unable to verify license") from None |             raise ValidationError("Unable to verify license") from None | ||||||
|         return body |         return body | ||||||
|  |  | ||||||
| @ -134,7 +137,7 @@ class LicenseKey: | |||||||
|             exp_ts = int(mktime(lic.expiry.timetuple())) |             exp_ts = int(mktime(lic.expiry.timetuple())) | ||||||
|             if total.exp == 0: |             if total.exp == 0: | ||||||
|                 total.exp = exp_ts |                 total.exp = exp_ts | ||||||
|             total.exp = min(total.exp, exp_ts) |             total.exp = max(total.exp, exp_ts) | ||||||
|             total.license_flags.extend(lic.status.license_flags) |             total.license_flags.extend(lic.status.license_flags) | ||||||
|         return total |         return total | ||||||
|  |  | ||||||
|  | |||||||
| @ -3,7 +3,7 @@ | |||||||
| from datetime import datetime | from datetime import datetime | ||||||
|  |  | ||||||
| from django.core.cache import cache | from django.core.cache import cache | ||||||
| from django.db.models.signals import post_save, pre_save | from django.db.models.signals import post_delete, post_save, pre_save | ||||||
| from django.dispatch import receiver | from django.dispatch import receiver | ||||||
| from django.utils.timezone import get_current_timezone | from django.utils.timezone import get_current_timezone | ||||||
|  |  | ||||||
| @ -27,3 +27,9 @@ def post_save_license(sender: type[License], instance: License, **_): | |||||||
|     """Trigger license usage calculation when license is saved""" |     """Trigger license usage calculation when license is saved""" | ||||||
|     cache.delete(CACHE_KEY_ENTERPRISE_LICENSE) |     cache.delete(CACHE_KEY_ENTERPRISE_LICENSE) | ||||||
|     enterprise_update_usage.delay() |     enterprise_update_usage.delay() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @receiver(post_delete, sender=License) | ||||||
|  | def post_delete_license(sender: type[License], instance: License, **_): | ||||||
|  |     """Clear license cache when license is deleted""" | ||||||
|  |     cache.delete(CACHE_KEY_ENTERPRISE_LICENSE) | ||||||
|  | |||||||
| @ -69,8 +69,5 @@ class NotificationViewSet( | |||||||
|     @action(detail=False, methods=["post"]) |     @action(detail=False, methods=["post"]) | ||||||
|     def mark_all_seen(self, request: Request) -> Response: |     def mark_all_seen(self, request: Request) -> Response: | ||||||
|         """Mark all the user's notifications as seen""" |         """Mark all the user's notifications as seen""" | ||||||
|         notifications = Notification.objects.filter(user=request.user) |         Notification.objects.filter(user=request.user, seen=False).update(seen=True) | ||||||
|         for notification in notifications: |  | ||||||
|             notification.seen = True |  | ||||||
|         Notification.objects.bulk_update(notifications, ["seen"]) |  | ||||||
|         return Response({}, status=204) |         return Response({}, status=204) | ||||||
|  | |||||||
| @ -49,6 +49,7 @@ from authentik.policies.models import PolicyBindingModel | |||||||
| from authentik.root.middleware import ClientIPMiddleware | from authentik.root.middleware import ClientIPMiddleware | ||||||
| from authentik.stages.email.utils import TemplateEmailMessage | from authentik.stages.email.utils import TemplateEmailMessage | ||||||
| from authentik.tenants.models import Tenant | from authentik.tenants.models import Tenant | ||||||
|  | from authentik.tenants.utils import get_current_tenant | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
| DISCORD_FIELD_LIMIT = 25 | DISCORD_FIELD_LIMIT = 25 | ||||||
| @ -58,6 +59,10 @@ NOTIFICATION_SUMMARY_LENGTH = 75 | |||||||
| def default_event_duration(): | def default_event_duration(): | ||||||
|     """Default duration an Event is saved. |     """Default duration an Event is saved. | ||||||
|     This is used as a fallback when no brand is available""" |     This is used as a fallback when no brand is available""" | ||||||
|  |     try: | ||||||
|  |         tenant = get_current_tenant() | ||||||
|  |         return now() + timedelta_from_string(tenant.event_retention) | ||||||
|  |     except Tenant.DoesNotExist: | ||||||
|         return now() + timedelta(days=365) |         return now() + timedelta(days=365) | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -245,12 +250,6 @@ class Event(SerializerModel, ExpiringModel): | |||||||
|             if QS_QUERY in self.context["http_request"]["args"]: |             if QS_QUERY in self.context["http_request"]["args"]: | ||||||
|                 wrapped = self.context["http_request"]["args"][QS_QUERY] |                 wrapped = self.context["http_request"]["args"][QS_QUERY] | ||||||
|                 self.context["http_request"]["args"] = cleanse_dict(QueryDict(wrapped)) |                 self.context["http_request"]["args"] = cleanse_dict(QueryDict(wrapped)) | ||||||
|         if hasattr(request, "tenant"): |  | ||||||
|             tenant: Tenant = request.tenant |  | ||||||
|             # Because self.created only gets set on save, we can't use it's value here |  | ||||||
|             # hence we set self.created to now and then use it |  | ||||||
|             self.created = now() |  | ||||||
|             self.expires = self.created + timedelta_from_string(tenant.event_retention) |  | ||||||
|         if hasattr(request, "brand"): |         if hasattr(request, "brand"): | ||||||
|             brand: Brand = request.brand |             brand: Brand = request.brand | ||||||
|             self.brand = sanitize_dict(model_to_dict(brand)) |             self.brand = sanitize_dict(model_to_dict(brand)) | ||||||
|  | |||||||
| @ -1,13 +1,16 @@ | |||||||
| """authentik events signal listener""" | """authentik events signal listener""" | ||||||
|  |  | ||||||
|  | from importlib import import_module | ||||||
| from typing import Any | from typing import Any | ||||||
|  |  | ||||||
|  | from django.conf import settings | ||||||
| from django.contrib.auth.signals import user_logged_in, user_logged_out | from django.contrib.auth.signals import user_logged_in, user_logged_out | ||||||
| from django.db.models.signals import post_save, pre_delete | from django.db.models.signals import post_save, pre_delete | ||||||
| from django.dispatch import receiver | from django.dispatch import receiver | ||||||
| from django.http import HttpRequest | from django.http import HttpRequest | ||||||
|  | from rest_framework.request import Request | ||||||
|  |  | ||||||
| from authentik.core.models import User | from authentik.core.models import AuthenticatedSession, User | ||||||
| from authentik.core.signals import login_failed, password_changed | from authentik.core.signals import login_failed, password_changed | ||||||
| from authentik.events.apps import SYSTEM_TASK_STATUS | from authentik.events.apps import SYSTEM_TASK_STATUS | ||||||
| from authentik.events.models import Event, EventAction, SystemTask | from authentik.events.models import Event, EventAction, SystemTask | ||||||
| @ -23,6 +26,7 @@ from authentik.stages.user_write.signals import user_write | |||||||
| from authentik.tenants.utils import get_current_tenant | from authentik.tenants.utils import get_current_tenant | ||||||
|  |  | ||||||
| SESSION_LOGIN_EVENT = "login_event" | SESSION_LOGIN_EVENT = "login_event" | ||||||
|  | _session_engine = import_module(settings.SESSION_ENGINE) | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(user_logged_in) | @receiver(user_logged_in) | ||||||
| @ -40,11 +44,20 @@ def on_user_logged_in(sender, request: HttpRequest, user: User, **_): | |||||||
|             kwargs[PLAN_CONTEXT_METHOD_ARGS] = flow_plan.context.get(PLAN_CONTEXT_METHOD_ARGS, {}) |             kwargs[PLAN_CONTEXT_METHOD_ARGS] = flow_plan.context.get(PLAN_CONTEXT_METHOD_ARGS, {}) | ||||||
|     event = Event.new(EventAction.LOGIN, **kwargs).from_http(request, user=user) |     event = Event.new(EventAction.LOGIN, **kwargs).from_http(request, user=user) | ||||||
|     request.session[SESSION_LOGIN_EVENT] = event |     request.session[SESSION_LOGIN_EVENT] = event | ||||||
|  |     request.session.save() | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_login_event(request: HttpRequest) -> Event | None: | def get_login_event(request_or_session: HttpRequest | AuthenticatedSession | None) -> Event | None: | ||||||
|     """Wrapper to get login event that can be mocked in tests""" |     """Wrapper to get login event that can be mocked in tests""" | ||||||
|     return request.session.get(SESSION_LOGIN_EVENT, None) |     session = None | ||||||
|  |     if not request_or_session: | ||||||
|  |         return None | ||||||
|  |     if isinstance(request_or_session, HttpRequest | Request): | ||||||
|  |         session = request_or_session.session | ||||||
|  |     if isinstance(request_or_session, AuthenticatedSession): | ||||||
|  |         SessionStore = _session_engine.SessionStore | ||||||
|  |         session = SessionStore(request_or_session.session_key) | ||||||
|  |     return session.get(SESSION_LOGIN_EVENT, None) | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(user_logged_out) | @receiver(user_logged_out) | ||||||
|  | |||||||
| @ -6,6 +6,7 @@ from django.db.models import Model | |||||||
| from django.test import TestCase | from django.test import TestCase | ||||||
|  |  | ||||||
| from authentik.core.models import default_token_key | from authentik.core.models import default_token_key | ||||||
|  | from authentik.events.models import default_event_duration | ||||||
| from authentik.lib.utils.reflection import get_apps | from authentik.lib.utils.reflection import get_apps | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -20,7 +21,7 @@ def model_tester_factory(test_model: type[Model]) -> Callable: | |||||||
|         allowed = 0 |         allowed = 0 | ||||||
|         # Token-like objects need to lookup the current tenant to get the default token length |         # Token-like objects need to lookup the current tenant to get the default token length | ||||||
|         for field in test_model._meta.fields: |         for field in test_model._meta.fields: | ||||||
|             if field.default == default_token_key: |             if field.default in [default_token_key, default_event_duration]: | ||||||
|                 allowed += 1 |                 allowed += 1 | ||||||
|         with self.assertNumQueries(allowed): |         with self.assertNumQueries(allowed): | ||||||
|             str(test_model()) |             str(test_model()) | ||||||
|  | |||||||
| @ -2,7 +2,8 @@ | |||||||
|  |  | ||||||
| from unittest.mock import MagicMock, patch | from unittest.mock import MagicMock, patch | ||||||
|  |  | ||||||
| from django.test import TestCase | from django.urls import reverse | ||||||
|  | from rest_framework.test import APITestCase | ||||||
|  |  | ||||||
| from authentik.core.models import Group, User | from authentik.core.models import Group, User | ||||||
| from authentik.events.models import ( | from authentik.events.models import ( | ||||||
| @ -10,6 +11,7 @@ from authentik.events.models import ( | |||||||
|     EventAction, |     EventAction, | ||||||
|     Notification, |     Notification, | ||||||
|     NotificationRule, |     NotificationRule, | ||||||
|  |     NotificationSeverity, | ||||||
|     NotificationTransport, |     NotificationTransport, | ||||||
|     NotificationWebhookMapping, |     NotificationWebhookMapping, | ||||||
|     TransportMode, |     TransportMode, | ||||||
| @ -20,7 +22,7 @@ from authentik.policies.exceptions import PolicyException | |||||||
| from authentik.policies.models import PolicyBinding | from authentik.policies.models import PolicyBinding | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestEventsNotifications(TestCase): | class TestEventsNotifications(APITestCase): | ||||||
|     """Test Event Notifications""" |     """Test Event Notifications""" | ||||||
|  |  | ||||||
|     def setUp(self) -> None: |     def setUp(self) -> None: | ||||||
| @ -131,3 +133,15 @@ class TestEventsNotifications(TestCase): | |||||||
|         Notification.objects.all().delete() |         Notification.objects.all().delete() | ||||||
|         Event.new(EventAction.CUSTOM_PREFIX).save() |         Event.new(EventAction.CUSTOM_PREFIX).save() | ||||||
|         self.assertEqual(Notification.objects.first().body, "foo") |         self.assertEqual(Notification.objects.first().body, "foo") | ||||||
|  |  | ||||||
|  |     def test_api_mark_all_seen(self): | ||||||
|  |         """Test mark_all_seen""" | ||||||
|  |         self.client.force_login(self.user) | ||||||
|  |  | ||||||
|  |         Notification.objects.create( | ||||||
|  |             severity=NotificationSeverity.NOTICE, body="foo", user=self.user, seen=False | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         response = self.client.post(reverse("authentik_api:notification-mark-all-seen")) | ||||||
|  |         self.assertEqual(response.status_code, 204) | ||||||
|  |         self.assertFalse(Notification.objects.filter(body="foo", seen=False).exists()) | ||||||
|  | |||||||
| @ -2,7 +2,6 @@ | |||||||
|  |  | ||||||
| import re | import re | ||||||
| import socket | import socket | ||||||
| from collections.abc import Iterable |  | ||||||
| from ipaddress import ip_address, ip_network | from ipaddress import ip_address, ip_network | ||||||
| from textwrap import indent | from textwrap import indent | ||||||
| from types import CodeType | from types import CodeType | ||||||
| @ -28,6 +27,12 @@ from authentik.stages.authenticator import devices_for_user | |||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  | ARG_SANITIZE = re.compile(r"[:.-]") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def sanitize_arg(arg_name: str) -> str: | ||||||
|  |     return re.sub(ARG_SANITIZE, "_", arg_name) | ||||||
|  |  | ||||||
|  |  | ||||||
| class BaseEvaluator: | class BaseEvaluator: | ||||||
|     """Validate and evaluate python-based expressions""" |     """Validate and evaluate python-based expressions""" | ||||||
| @ -177,9 +182,9 @@ class BaseEvaluator: | |||||||
|         proc = PolicyProcess(PolicyBinding(policy=policy), request=req, connection=None) |         proc = PolicyProcess(PolicyBinding(policy=policy), request=req, connection=None) | ||||||
|         return proc.profiling_wrapper() |         return proc.profiling_wrapper() | ||||||
|  |  | ||||||
|     def wrap_expression(self, expression: str, params: Iterable[str]) -> str: |     def wrap_expression(self, expression: str) -> str: | ||||||
|         """Wrap expression in a function, call it, and save the result as `result`""" |         """Wrap expression in a function, call it, and save the result as `result`""" | ||||||
|         handler_signature = ",".join(params) |         handler_signature = ",".join(sanitize_arg(x) for x in self._context.keys()) | ||||||
|         full_expression = "" |         full_expression = "" | ||||||
|         full_expression += f"def handler({handler_signature}):\n" |         full_expression += f"def handler({handler_signature}):\n" | ||||||
|         full_expression += indent(expression, "    ") |         full_expression += indent(expression, "    ") | ||||||
| @ -188,8 +193,8 @@ class BaseEvaluator: | |||||||
|  |  | ||||||
|     def compile(self, expression: str) -> CodeType: |     def compile(self, expression: str) -> CodeType: | ||||||
|         """Parse expression. Raises SyntaxError or ValueError if the syntax is incorrect.""" |         """Parse expression. Raises SyntaxError or ValueError if the syntax is incorrect.""" | ||||||
|         param_keys = self._context.keys() |         expression = self.wrap_expression(expression) | ||||||
|         return compile(self.wrap_expression(expression, param_keys), self._filename, "exec") |         return compile(expression, self._filename, "exec") | ||||||
|  |  | ||||||
|     def evaluate(self, expression_source: str) -> Any: |     def evaluate(self, expression_source: str) -> Any: | ||||||
|         """Parse and evaluate expression. If the syntax is incorrect, a SyntaxError is raised. |         """Parse and evaluate expression. If the syntax is incorrect, a SyntaxError is raised. | ||||||
| @ -205,7 +210,7 @@ class BaseEvaluator: | |||||||
|                 self.handle_error(exc, expression_source) |                 self.handle_error(exc, expression_source) | ||||||
|                 raise exc |                 raise exc | ||||||
|             try: |             try: | ||||||
|                 _locals = self._context |                 _locals = {sanitize_arg(x): y for x, y in self._context.items()} | ||||||
|                 # Yes this is an exec, yes it is potentially bad. Since we limit what variables are |                 # Yes this is an exec, yes it is potentially bad. Since we limit what variables are | ||||||
|                 # available here, and these policies can only be edited by admins, this is a risk |                 # available here, and these policies can only be edited by admins, this is a risk | ||||||
|                 # we're willing to take. |                 # we're willing to take. | ||||||
|  | |||||||
| @ -30,6 +30,11 @@ class TestHTTP(TestCase): | |||||||
|         request = self.factory.get("/", HTTP_X_FORWARDED_FOR="127.0.0.2") |         request = self.factory.get("/", HTTP_X_FORWARDED_FOR="127.0.0.2") | ||||||
|         self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.2") |         self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.2") | ||||||
|  |  | ||||||
|  |     def test_forward_for_invalid(self): | ||||||
|  |         """Test invalid forward for""" | ||||||
|  |         request = self.factory.get("/", HTTP_X_FORWARDED_FOR="foobar") | ||||||
|  |         self.assertEqual(ClientIPMiddleware.get_client_ip(request), ClientIPMiddleware.default_ip) | ||||||
|  |  | ||||||
|     def test_fake_outpost(self): |     def test_fake_outpost(self): | ||||||
|         """Test faked IP which is overridden by an outpost""" |         """Test faked IP which is overridden by an outpost""" | ||||||
|         token = Token.objects.create( |         token = Token.objects.create( | ||||||
| @ -53,6 +58,17 @@ class TestHTTP(TestCase): | |||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
|         self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.1") |         self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.1") | ||||||
|  |         # Invalid, not a real IP | ||||||
|  |         self.user.type = UserTypes.INTERNAL_SERVICE_ACCOUNT | ||||||
|  |         self.user.save() | ||||||
|  |         request = self.factory.get( | ||||||
|  |             "/", | ||||||
|  |             **{ | ||||||
|  |                 ClientIPMiddleware.outpost_remote_ip_header: "foobar", | ||||||
|  |                 ClientIPMiddleware.outpost_token_header: token.key, | ||||||
|  |             }, | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.1") | ||||||
|         # Valid |         # Valid | ||||||
|         self.user.type = UserTypes.INTERNAL_SERVICE_ACCOUNT |         self.user.type = UserTypes.INTERNAL_SERVICE_ACCOUNT | ||||||
|         self.user.save() |         self.user.save() | ||||||
|  | |||||||
| @ -21,7 +21,14 @@ class DebugSession(Session): | |||||||
|  |  | ||||||
|     def send(self, req: PreparedRequest, *args, **kwargs): |     def send(self, req: PreparedRequest, *args, **kwargs): | ||||||
|         request_id = str(uuid4()) |         request_id = str(uuid4()) | ||||||
|         LOGGER.debug("HTTP request sent", uid=request_id, path=req.path_url, headers=req.headers) |         LOGGER.debug( | ||||||
|  |             "HTTP request sent", | ||||||
|  |             uid=request_id, | ||||||
|  |             url=req.url, | ||||||
|  |             method=req.method, | ||||||
|  |             headers=req.headers, | ||||||
|  |             body=req.body, | ||||||
|  |         ) | ||||||
|         resp = super().send(req, *args, **kwargs) |         resp = super().send(req, *args, **kwargs) | ||||||
|         LOGGER.debug( |         LOGGER.debug( | ||||||
|             "HTTP response received", |             "HTTP response received", | ||||||
|  | |||||||
| @ -108,7 +108,7 @@ class EventMatcherPolicy(Policy): | |||||||
|                 result=result, |                 result=result, | ||||||
|             ) |             ) | ||||||
|             matches.append(result) |             matches.append(result) | ||||||
|         passing = any(x.passing for x in matches) |         passing = all(x.passing for x in matches) | ||||||
|         messages = chain(*[x.messages for x in matches]) |         messages = chain(*[x.messages for x in matches]) | ||||||
|         result = PolicyResult(passing, *messages) |         result = PolicyResult(passing, *messages) | ||||||
|         result.source_results = matches |         result.source_results = matches | ||||||
|  | |||||||
| @ -77,11 +77,24 @@ class TestEventMatcherPolicy(TestCase): | |||||||
|         request = PolicyRequest(get_anonymous_user()) |         request = PolicyRequest(get_anonymous_user()) | ||||||
|         request.context["event"] = event |         request.context["event"] = event | ||||||
|         policy: EventMatcherPolicy = EventMatcherPolicy.objects.create( |         policy: EventMatcherPolicy = EventMatcherPolicy.objects.create( | ||||||
|             client_ip="1.2.3.5", app="bar" |             client_ip="1.2.3.5", app="foo" | ||||||
|         ) |         ) | ||||||
|         response = policy.passes(request) |         response = policy.passes(request) | ||||||
|         self.assertFalse(response.passing) |         self.assertFalse(response.passing) | ||||||
|  |  | ||||||
|  |     def test_multiple(self): | ||||||
|  |         """Test multiple""" | ||||||
|  |         event = Event.new(EventAction.LOGIN) | ||||||
|  |         event.app = "foo" | ||||||
|  |         event.client_ip = "1.2.3.4" | ||||||
|  |         request = PolicyRequest(get_anonymous_user()) | ||||||
|  |         request.context["event"] = event | ||||||
|  |         policy: EventMatcherPolicy = EventMatcherPolicy.objects.create( | ||||||
|  |             client_ip="1.2.3.4", app="foo" | ||||||
|  |         ) | ||||||
|  |         response = policy.passes(request) | ||||||
|  |         self.assertTrue(response.passing) | ||||||
|  |  | ||||||
|     def test_invalid(self): |     def test_invalid(self): | ||||||
|         """Test passing event""" |         """Test passing event""" | ||||||
|         request = PolicyRequest(get_anonymous_user()) |         request = PolicyRequest(get_anonymous_user()) | ||||||
|  | |||||||
| @ -4,13 +4,13 @@ from django.apps.registry import Apps | |||||||
| from django.db.backends.base.schema import BaseDatabaseSchemaEditor | from django.db.backends.base.schema import BaseDatabaseSchemaEditor | ||||||
|  |  | ||||||
| from django.db import migrations | from django.db import migrations | ||||||
| from django.contrib.auth.management import create_permissions |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def migrate_search_group(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | def migrate_search_group(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | ||||||
|     from guardian.shortcuts import assign_perm |  | ||||||
|     from authentik.core.models import User |     from authentik.core.models import User | ||||||
|     from django.apps import apps as real_apps |     from django.apps import apps as real_apps | ||||||
|  |     from django.contrib.auth.management import create_permissions | ||||||
|  |     from guardian.shortcuts import UserObjectPermission | ||||||
|  |  | ||||||
|     db_alias = schema_editor.connection.alias |     db_alias = schema_editor.connection.alias | ||||||
|  |  | ||||||
| @ -20,14 +20,25 @@ def migrate_search_group(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | |||||||
|     create_permissions(real_apps.get_app_config("authentik_providers_ldap"), using=db_alias) |     create_permissions(real_apps.get_app_config("authentik_providers_ldap"), using=db_alias) | ||||||
|  |  | ||||||
|     LDAPProvider = apps.get_model("authentik_providers_ldap", "ldapprovider") |     LDAPProvider = apps.get_model("authentik_providers_ldap", "ldapprovider") | ||||||
|  |     Permission = apps.get_model("auth", "Permission") | ||||||
|  |     UserObjectPermission = apps.get_model("guardian", "UserObjectPermission") | ||||||
|  |     ContentType = apps.get_model("contenttypes", "ContentType") | ||||||
|  |  | ||||||
|  |     new_prem = Permission.objects.using(db_alias).get(codename="search_full_directory") | ||||||
|  |     ct = ContentType.objects.using(db_alias).get( | ||||||
|  |         app_label="authentik_providers_ldap", | ||||||
|  |         model="ldapprovider", | ||||||
|  |     ) | ||||||
|  |  | ||||||
|     for provider in LDAPProvider.objects.using(db_alias).all(): |     for provider in LDAPProvider.objects.using(db_alias).all(): | ||||||
|         for user_pk in ( |         if not provider.search_group: | ||||||
|             provider.search_group.users.using(db_alias).all().values_list("pk", flat=True) |             continue | ||||||
|         ): |         for user in provider.search_group.users.using(db_alias).all(): | ||||||
|             # We need the correct user model instance to assign the permission |             UserObjectPermission.objects.using(db_alias).create( | ||||||
|             assign_perm( |                 user=user, | ||||||
|                 "search_full_directory", User.objects.using(db_alias).get(pk=user_pk), provider |                 permission=new_prem, | ||||||
|  |                 object_pk=provider.pk, | ||||||
|  |                 content_type=ct, | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -35,6 +46,7 @@ class Migration(migrations.Migration): | |||||||
|  |  | ||||||
|     dependencies = [ |     dependencies = [ | ||||||
|         ("authentik_providers_ldap", "0003_ldapprovider_mfa_support_and_more"), |         ("authentik_providers_ldap", "0003_ldapprovider_mfa_support_and_more"), | ||||||
|  |         ("guardian", "0002_generic_permissions_index"), | ||||||
|     ] |     ] | ||||||
|  |  | ||||||
|     operations = [ |     operations = [ | ||||||
|  | |||||||
| @ -1,15 +1,18 @@ | |||||||
| """OAuth2Provider API Views""" | """OAuth2Provider API Views""" | ||||||
|  |  | ||||||
| from copy import copy | from copy import copy | ||||||
|  | from re import compile | ||||||
|  | from re import error as RegexError | ||||||
|  |  | ||||||
| from django.urls import reverse | from django.urls import reverse | ||||||
| from django.utils import timezone | from django.utils import timezone | ||||||
|  | from django.utils.translation import gettext_lazy as _ | ||||||
| from drf_spectacular.types import OpenApiTypes | from drf_spectacular.types import OpenApiTypes | ||||||
| from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema | from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema | ||||||
| from guardian.shortcuts import get_objects_for_user | from guardian.shortcuts import get_objects_for_user | ||||||
| from rest_framework.decorators import action | from rest_framework.decorators import action | ||||||
| from rest_framework.exceptions import ValidationError | from rest_framework.exceptions import ValidationError | ||||||
| from rest_framework.fields import CharField | from rest_framework.fields import CharField, ChoiceField | ||||||
| from rest_framework.generics import get_object_or_404 | from rest_framework.generics import get_object_or_404 | ||||||
| from rest_framework.request import Request | from rest_framework.request import Request | ||||||
| from rest_framework.response import Response | from rest_framework.response import Response | ||||||
| @ -20,13 +23,39 @@ from authentik.core.api.used_by import UsedByMixin | |||||||
| from authentik.core.api.utils import PassiveSerializer, PropertyMappingPreviewSerializer | from authentik.core.api.utils import PassiveSerializer, PropertyMappingPreviewSerializer | ||||||
| from authentik.core.models import Provider | from authentik.core.models import Provider | ||||||
| from authentik.providers.oauth2.id_token import IDToken | from authentik.providers.oauth2.id_token import IDToken | ||||||
| from authentik.providers.oauth2.models import AccessToken, OAuth2Provider, ScopeMapping | from authentik.providers.oauth2.models import ( | ||||||
|  |     AccessToken, | ||||||
|  |     OAuth2Provider, | ||||||
|  |     RedirectURIMatchingMode, | ||||||
|  |     ScopeMapping, | ||||||
|  | ) | ||||||
| from authentik.rbac.decorators import permission_required | from authentik.rbac.decorators import permission_required | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class RedirectURISerializer(PassiveSerializer): | ||||||
|  |     """A single allowed redirect URI entry""" | ||||||
|  |  | ||||||
|  |     matching_mode = ChoiceField(choices=RedirectURIMatchingMode.choices) | ||||||
|  |     url = CharField() | ||||||
|  |  | ||||||
|  |  | ||||||
| class OAuth2ProviderSerializer(ProviderSerializer): | class OAuth2ProviderSerializer(ProviderSerializer): | ||||||
|     """OAuth2Provider Serializer""" |     """OAuth2Provider Serializer""" | ||||||
|  |  | ||||||
|  |     redirect_uris = RedirectURISerializer(many=True, source="_redirect_uris") | ||||||
|  |  | ||||||
|  |     def validate_redirect_uris(self, data: list) -> list: | ||||||
|  |         for entry in data: | ||||||
|  |             if entry.get("matching_mode") == RedirectURIMatchingMode.REGEX: | ||||||
|  |                 url = entry.get("url") | ||||||
|  |                 try: | ||||||
|  |                     compile(url) | ||||||
|  |                 except RegexError: | ||||||
|  |                     raise ValidationError( | ||||||
|  |                         _("Invalid Regex Pattern: {url}".format(url=url)) | ||||||
|  |                     ) from None | ||||||
|  |         return data | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|         model = OAuth2Provider |         model = OAuth2Provider | ||||||
|         fields = ProviderSerializer.Meta.fields + [ |         fields = ProviderSerializer.Meta.fields + [ | ||||||
| @ -78,7 +107,6 @@ class OAuth2ProviderViewSet(UsedByMixin, ModelViewSet): | |||||||
|         "refresh_token_validity", |         "refresh_token_validity", | ||||||
|         "include_claims_in_id_token", |         "include_claims_in_id_token", | ||||||
|         "signing_key", |         "signing_key", | ||||||
|         "redirect_uris", |  | ||||||
|         "sub_mode", |         "sub_mode", | ||||||
|         "property_mappings", |         "property_mappings", | ||||||
|         "issuer_mode", |         "issuer_mode", | ||||||
|  | |||||||
| @ -7,7 +7,7 @@ from django.http import HttpRequest, HttpResponse, HttpResponseRedirect | |||||||
| from authentik.events.models import Event, EventAction | from authentik.events.models import Event, EventAction | ||||||
| from authentik.lib.sentry import SentryIgnoredException | from authentik.lib.sentry import SentryIgnoredException | ||||||
| from authentik.lib.views import bad_request_message | from authentik.lib.views import bad_request_message | ||||||
| from authentik.providers.oauth2.models import GrantTypes | from authentik.providers.oauth2.models import GrantTypes, RedirectURI | ||||||
|  |  | ||||||
|  |  | ||||||
| class OAuth2Error(SentryIgnoredException): | class OAuth2Error(SentryIgnoredException): | ||||||
| @ -46,9 +46,9 @@ class RedirectUriError(OAuth2Error): | |||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     provided_uri: str |     provided_uri: str | ||||||
|     allowed_uris: list[str] |     allowed_uris: list[RedirectURI] | ||||||
|  |  | ||||||
|     def __init__(self, provided_uri: str, allowed_uris: list[str]) -> None: |     def __init__(self, provided_uri: str, allowed_uris: list[RedirectURI]) -> None: | ||||||
|         super().__init__() |         super().__init__() | ||||||
|         self.provided_uri = provided_uri |         self.provided_uri = provided_uri | ||||||
|         self.allowed_uris = allowed_uris |         self.allowed_uris = allowed_uris | ||||||
|  | |||||||
| @ -1,6 +1,7 @@ | |||||||
| """id_token utils""" | """id_token utils""" | ||||||
|  |  | ||||||
| from dataclasses import asdict, dataclass, field | from dataclasses import asdict, dataclass, field | ||||||
|  | from hashlib import sha256 | ||||||
| from typing import TYPE_CHECKING, Any | from typing import TYPE_CHECKING, Any | ||||||
|  |  | ||||||
| from django.db import models | from django.db import models | ||||||
| @ -23,8 +24,13 @@ if TYPE_CHECKING: | |||||||
|     from authentik.providers.oauth2.models import BaseGrantModel, OAuth2Provider |     from authentik.providers.oauth2.models import BaseGrantModel, OAuth2Provider | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def hash_session_key(session_key: str) -> str: | ||||||
|  |     """Hash the session key for inclusion in JWTs as `sid`""" | ||||||
|  |     return sha256(session_key.encode("ascii")).hexdigest() | ||||||
|  |  | ||||||
|  |  | ||||||
| class SubModes(models.TextChoices): | class SubModes(models.TextChoices): | ||||||
|     """Mode after which 'sub' attribute is generateed, for compatibility reasons""" |     """Mode after which 'sub' attribute is generated, for compatibility reasons""" | ||||||
|  |  | ||||||
|     HASHED_USER_ID = "hashed_user_id", _("Based on the Hashed User ID") |     HASHED_USER_ID = "hashed_user_id", _("Based on the Hashed User ID") | ||||||
|     USER_ID = "user_id", _("Based on user ID") |     USER_ID = "user_id", _("Based on user ID") | ||||||
| @ -51,7 +57,8 @@ class IDToken: | |||||||
|     and potentially other requested Claims. The ID Token is represented as a |     and potentially other requested Claims. The ID Token is represented as a | ||||||
|     JSON Web Token (JWT) [JWT]. |     JSON Web Token (JWT) [JWT]. | ||||||
|  |  | ||||||
|     https://openid.net/specs/openid-connect-core-1_0.html#IDToken""" |     https://openid.net/specs/openid-connect-core-1_0.html#IDToken | ||||||
|  |     https://www.iana.org/assignments/jwt/jwt.xhtml""" | ||||||
|  |  | ||||||
|     # Issuer, https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.1 |     # Issuer, https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.1 | ||||||
|     iss: str | None = None |     iss: str | None = None | ||||||
| @ -79,6 +86,8 @@ class IDToken: | |||||||
|     nonce: str | None = None |     nonce: str | None = None | ||||||
|     # Access Token hash value, http://openid.net/specs/openid-connect-core-1_0.html |     # Access Token hash value, http://openid.net/specs/openid-connect-core-1_0.html | ||||||
|     at_hash: str | None = None |     at_hash: str | None = None | ||||||
|  |     # Session ID, https://openid.net/specs/openid-connect-frontchannel-1_0.html#ClaimsContents | ||||||
|  |     sid: str | None = None | ||||||
|  |  | ||||||
|     claims: dict[str, Any] = field(default_factory=dict) |     claims: dict[str, Any] = field(default_factory=dict) | ||||||
|  |  | ||||||
| @ -116,9 +125,11 @@ class IDToken: | |||||||
|         now = timezone.now() |         now = timezone.now() | ||||||
|         id_token.iat = int(now.timestamp()) |         id_token.iat = int(now.timestamp()) | ||||||
|         id_token.auth_time = int(token.auth_time.timestamp()) |         id_token.auth_time = int(token.auth_time.timestamp()) | ||||||
|  |         if token.session: | ||||||
|  |             id_token.sid = hash_session_key(token.session.session_key) | ||||||
|  |  | ||||||
|         # We use the timestamp of the user's last successful login (EventAction.LOGIN) for auth_time |         # We use the timestamp of the user's last successful login (EventAction.LOGIN) for auth_time | ||||||
|         auth_event = get_login_event(request) |         auth_event = get_login_event(token.session) | ||||||
|         if auth_event: |         if auth_event: | ||||||
|             # Also check which method was used for authentication |             # Also check which method was used for authentication | ||||||
|             method = auth_event.context.get(PLAN_CONTEXT_METHOD, "") |             method = auth_event.context.get(PLAN_CONTEXT_METHOD, "") | ||||||
|  | |||||||
| @ -3,6 +3,7 @@ | |||||||
| import django.db.models.deletion | import django.db.models.deletion | ||||||
| from django.apps.registry import Apps | from django.apps.registry import Apps | ||||||
| from django.db import migrations, models | from django.db import migrations, models | ||||||
|  | from django.db.backends.base.schema import BaseDatabaseSchemaEditor | ||||||
|  |  | ||||||
| import authentik.lib.utils.time | import authentik.lib.utils.time | ||||||
|  |  | ||||||
| @ -14,7 +15,7 @@ scope_uid_map = { | |||||||
| } | } | ||||||
|  |  | ||||||
|  |  | ||||||
| def set_managed_flag(apps: Apps, schema_editor): | def set_managed_flag(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | ||||||
|     ScopeMapping = apps.get_model("authentik_providers_oauth2", "ScopeMapping") |     ScopeMapping = apps.get_model("authentik_providers_oauth2", "ScopeMapping") | ||||||
|     db_alias = schema_editor.connection.alias |     db_alias = schema_editor.connection.alias | ||||||
|     for mapping in ScopeMapping.objects.using(db_alias).filter(name__startswith="Autogenerated "): |     for mapping in ScopeMapping.objects.using(db_alias).filter(name__startswith="Autogenerated "): | ||||||
|  | |||||||
| @ -0,0 +1,26 @@ | |||||||
|  | # Generated by Django 5.0.9 on 2024-09-26 16:25 | ||||||
|  |  | ||||||
|  | from django.conf import settings | ||||||
|  | from django.db import migrations, models | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Migration(migrations.Migration): | ||||||
|  |  | ||||||
|  |     dependencies = [ | ||||||
|  |         ("authentik_providers_oauth2", "0018_alter_accesstoken_expires_and_more"), | ||||||
|  |         migrations.swappable_dependency(settings.AUTH_USER_MODEL), | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |     # Original preserved | ||||||
|  |     # See https://github.com/goauthentik/authentik/issues/11874 | ||||||
|  |     # operations = [ | ||||||
|  |     #     migrations.AddIndex( | ||||||
|  |     #         model_name="accesstoken", | ||||||
|  |     #         index=models.Index(fields=["token"], name="authentik_p_token_4bc870_idx"), | ||||||
|  |     #     ), | ||||||
|  |     #     migrations.AddIndex( | ||||||
|  |     #         model_name="refreshtoken", | ||||||
|  |     #         index=models.Index(fields=["token"], name="authentik_p_token_1a841f_idx"), | ||||||
|  |     #     ), | ||||||
|  |     # ] | ||||||
|  |     operations = [] | ||||||
| @ -0,0 +1,34 @@ | |||||||
|  | # Generated by Django 5.0.9 on 2024-09-27 14:50 | ||||||
|  |  | ||||||
|  | from django.conf import settings | ||||||
|  | from django.db import migrations, models | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Migration(migrations.Migration): | ||||||
|  |  | ||||||
|  |     dependencies = [ | ||||||
|  |         ("authentik_providers_oauth2", "0019_accesstoken_authentik_p_token_4bc870_idx_and_more"), | ||||||
|  |         migrations.swappable_dependency(settings.AUTH_USER_MODEL), | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |     # Original preserved | ||||||
|  |     # See https://github.com/goauthentik/authentik/issues/11874 | ||||||
|  |     # operations = [ | ||||||
|  |     #     migrations.RemoveIndex( | ||||||
|  |     #         model_name="accesstoken", | ||||||
|  |     #         name="authentik_p_token_4bc870_idx", | ||||||
|  |     #     ), | ||||||
|  |     #     migrations.RemoveIndex( | ||||||
|  |     #         model_name="refreshtoken", | ||||||
|  |     #         name="authentik_p_token_1a841f_idx", | ||||||
|  |     #     ), | ||||||
|  |     #     migrations.AddIndex( | ||||||
|  |     #         model_name="accesstoken", | ||||||
|  |     #         index=models.Index(fields=["token", "provider"], name="authentik_p_token_f99422_idx"), | ||||||
|  |     #     ), | ||||||
|  |     #     migrations.AddIndex( | ||||||
|  |     #         model_name="refreshtoken", | ||||||
|  |     #         index=models.Index(fields=["token", "provider"], name="authentik_p_token_a1d921_idx"), | ||||||
|  |     #     ), | ||||||
|  |     # ] | ||||||
|  |     operations = [] | ||||||
| @ -0,0 +1,42 @@ | |||||||
|  | # Generated by Django 5.0.9 on 2024-10-16 14:53 | ||||||
|  |  | ||||||
|  | import django.db.models.deletion | ||||||
|  | from django.db import migrations, models | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Migration(migrations.Migration): | ||||||
|  |  | ||||||
|  |     dependencies = [ | ||||||
|  |         ("authentik_crypto", "0004_alter_certificatekeypair_name"), | ||||||
|  |         ( | ||||||
|  |             "authentik_providers_oauth2", | ||||||
|  |             "0020_remove_accesstoken_authentik_p_token_4bc870_idx_and_more", | ||||||
|  |         ), | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |     operations = [ | ||||||
|  |         migrations.AddField( | ||||||
|  |             model_name="oauth2provider", | ||||||
|  |             name="encryption_key", | ||||||
|  |             field=models.ForeignKey( | ||||||
|  |                 help_text="Key used to encrypt the tokens. When set, tokens will be encrypted and returned as JWEs.", | ||||||
|  |                 null=True, | ||||||
|  |                 on_delete=django.db.models.deletion.SET_NULL, | ||||||
|  |                 related_name="oauth2provider_encryption_key_set", | ||||||
|  |                 to="authentik_crypto.certificatekeypair", | ||||||
|  |                 verbose_name="Encryption Key", | ||||||
|  |             ), | ||||||
|  |         ), | ||||||
|  |         migrations.AlterField( | ||||||
|  |             model_name="oauth2provider", | ||||||
|  |             name="signing_key", | ||||||
|  |             field=models.ForeignKey( | ||||||
|  |                 help_text="Key used to sign the tokens.", | ||||||
|  |                 null=True, | ||||||
|  |                 on_delete=django.db.models.deletion.SET_NULL, | ||||||
|  |                 related_name="oauth2provider_signing_key_set", | ||||||
|  |                 to="authentik_crypto.certificatekeypair", | ||||||
|  |                 verbose_name="Signing Key", | ||||||
|  |             ), | ||||||
|  |         ), | ||||||
|  |     ] | ||||||
| @ -0,0 +1,113 @@ | |||||||
|  | # Generated by Django 5.0.9 on 2024-10-23 13:38 | ||||||
|  |  | ||||||
|  | from hashlib import sha256 | ||||||
|  | import django.db.models.deletion | ||||||
|  | from django.db import migrations, models | ||||||
|  | from django.apps.registry import Apps | ||||||
|  | from django.db.backends.base.schema import BaseDatabaseSchemaEditor | ||||||
|  | from authentik.lib.migrations import progress_bar | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def migrate_session(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | ||||||
|  |     AuthenticatedSession = apps.get_model("authentik_core", "authenticatedsession") | ||||||
|  |     AuthorizationCode = apps.get_model("authentik_providers_oauth2", "authorizationcode") | ||||||
|  |     AccessToken = apps.get_model("authentik_providers_oauth2", "accesstoken") | ||||||
|  |     RefreshToken = apps.get_model("authentik_providers_oauth2", "refreshtoken") | ||||||
|  |     db_alias = schema_editor.connection.alias | ||||||
|  |  | ||||||
|  |     print(f"\nFetching session keys, this might take a couple of minutes...") | ||||||
|  |     session_ids = {} | ||||||
|  |     for session in progress_bar(AuthenticatedSession.objects.using(db_alias).all()): | ||||||
|  |         session_ids[sha256(session.session_key.encode("ascii")).hexdigest()] = session.session_key | ||||||
|  |     for model in [AuthorizationCode, AccessToken, RefreshToken]: | ||||||
|  |         print( | ||||||
|  |             f"\nAdding session to {model._meta.verbose_name}, this might take a couple of minutes..." | ||||||
|  |         ) | ||||||
|  |         for code in progress_bar(model.objects.using(db_alias).all()): | ||||||
|  |             if code.session_id_old not in session_ids: | ||||||
|  |                 continue | ||||||
|  |             code.session = ( | ||||||
|  |                 AuthenticatedSession.objects.using(db_alias) | ||||||
|  |                 .filter(session_key=session_ids[code.session_id_old]) | ||||||
|  |                 .first() | ||||||
|  |             ) | ||||||
|  |             code.save() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Migration(migrations.Migration): | ||||||
|  |  | ||||||
|  |     dependencies = [ | ||||||
|  |         ("authentik_core", "0039_source_group_matching_mode_alter_group_name_and_more"), | ||||||
|  |         ("authentik_providers_oauth2", "0021_oauth2provider_encryption_key_and_more"), | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |     operations = [ | ||||||
|  |         migrations.RenameField( | ||||||
|  |             model_name="accesstoken", | ||||||
|  |             old_name="session_id", | ||||||
|  |             new_name="session_id_old", | ||||||
|  |         ), | ||||||
|  |         migrations.RenameField( | ||||||
|  |             model_name="authorizationcode", | ||||||
|  |             old_name="session_id", | ||||||
|  |             new_name="session_id_old", | ||||||
|  |         ), | ||||||
|  |         migrations.RenameField( | ||||||
|  |             model_name="refreshtoken", | ||||||
|  |             old_name="session_id", | ||||||
|  |             new_name="session_id_old", | ||||||
|  |         ), | ||||||
|  |         migrations.AddField( | ||||||
|  |             model_name="accesstoken", | ||||||
|  |             name="session", | ||||||
|  |             field=models.ForeignKey( | ||||||
|  |                 default=None, | ||||||
|  |                 null=True, | ||||||
|  |                 on_delete=django.db.models.deletion.SET_DEFAULT, | ||||||
|  |                 to="authentik_core.authenticatedsession", | ||||||
|  |             ), | ||||||
|  |         ), | ||||||
|  |         migrations.AddField( | ||||||
|  |             model_name="authorizationcode", | ||||||
|  |             name="session", | ||||||
|  |             field=models.ForeignKey( | ||||||
|  |                 default=None, | ||||||
|  |                 null=True, | ||||||
|  |                 on_delete=django.db.models.deletion.SET_DEFAULT, | ||||||
|  |                 to="authentik_core.authenticatedsession", | ||||||
|  |             ), | ||||||
|  |         ), | ||||||
|  |         migrations.AddField( | ||||||
|  |             model_name="devicetoken", | ||||||
|  |             name="session", | ||||||
|  |             field=models.ForeignKey( | ||||||
|  |                 default=None, | ||||||
|  |                 null=True, | ||||||
|  |                 on_delete=django.db.models.deletion.SET_DEFAULT, | ||||||
|  |                 to="authentik_core.authenticatedsession", | ||||||
|  |             ), | ||||||
|  |         ), | ||||||
|  |         migrations.AddField( | ||||||
|  |             model_name="refreshtoken", | ||||||
|  |             name="session", | ||||||
|  |             field=models.ForeignKey( | ||||||
|  |                 default=None, | ||||||
|  |                 null=True, | ||||||
|  |                 on_delete=django.db.models.deletion.SET_DEFAULT, | ||||||
|  |                 to="authentik_core.authenticatedsession", | ||||||
|  |             ), | ||||||
|  |         ), | ||||||
|  |         migrations.RunPython(migrate_session), | ||||||
|  |         migrations.RemoveField( | ||||||
|  |             model_name="accesstoken", | ||||||
|  |             name="session_id_old", | ||||||
|  |         ), | ||||||
|  |         migrations.RemoveField( | ||||||
|  |             model_name="authorizationcode", | ||||||
|  |             name="session_id_old", | ||||||
|  |         ), | ||||||
|  |         migrations.RemoveField( | ||||||
|  |             model_name="refreshtoken", | ||||||
|  |             name="session_id_old", | ||||||
|  |         ), | ||||||
|  |     ] | ||||||
| @ -0,0 +1,31 @@ | |||||||
|  | # Generated by Django 5.0.9 on 2024-10-31 14:28 | ||||||
|  |  | ||||||
|  | import django.contrib.postgres.indexes | ||||||
|  | from django.conf import settings | ||||||
|  | from django.db import migrations | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Migration(migrations.Migration): | ||||||
|  |  | ||||||
|  |     dependencies = [ | ||||||
|  |         ("authentik_core", "0039_source_group_matching_mode_alter_group_name_and_more"), | ||||||
|  |         ("authentik_providers_oauth2", "0022_remove_accesstoken_session_id_and_more"), | ||||||
|  |         migrations.swappable_dependency(settings.AUTH_USER_MODEL), | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |     operations = [ | ||||||
|  |         migrations.RunSQL("DROP INDEX IF EXISTS authentik_p_token_f99422_idx;"), | ||||||
|  |         migrations.RunSQL("DROP INDEX IF EXISTS authentik_p_token_a1d921_idx;"), | ||||||
|  |         migrations.AddIndex( | ||||||
|  |             model_name="accesstoken", | ||||||
|  |             index=django.contrib.postgres.indexes.HashIndex( | ||||||
|  |                 fields=["token"], name="authentik_p_token_e00883_hash" | ||||||
|  |             ), | ||||||
|  |         ), | ||||||
|  |         migrations.AddIndex( | ||||||
|  |             model_name="refreshtoken", | ||||||
|  |             index=django.contrib.postgres.indexes.HashIndex( | ||||||
|  |                 fields=["token"], name="authentik_p_token_32e2b7_hash" | ||||||
|  |             ), | ||||||
|  |         ), | ||||||
|  |     ] | ||||||
| @ -0,0 +1,48 @@ | |||||||
|  | # Generated by Django 5.0.9 on 2024-11-04 12:56 | ||||||
|  | from django.apps.registry import Apps | ||||||
|  |  | ||||||
|  | from django.db.backends.base.schema import BaseDatabaseSchemaEditor | ||||||
|  |  | ||||||
|  | from django.db import migrations, models | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def migrate_redirect_uris(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | ||||||
|  |     from authentik.providers.oauth2.models import RedirectURI, RedirectURIMatchingMode | ||||||
|  |  | ||||||
|  |     OAuth2Provider = apps.get_model("authentik_providers_oauth2", "oauth2provider") | ||||||
|  |  | ||||||
|  |     db_alias = schema_editor.connection.alias | ||||||
|  |     for provider in OAuth2Provider.objects.using(db_alias).all(): | ||||||
|  |         uris = [] | ||||||
|  |         for old in provider.old_redirect_uris.split("\n"): | ||||||
|  |             mode = RedirectURIMatchingMode.STRICT | ||||||
|  |             if old == "*" or old == ".*": | ||||||
|  |                 mode = RedirectURIMatchingMode.REGEX | ||||||
|  |             uris.append(RedirectURI(mode, url=old)) | ||||||
|  |         provider.redirect_uris = uris | ||||||
|  |         provider.save() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Migration(migrations.Migration): | ||||||
|  |  | ||||||
|  |     dependencies = [ | ||||||
|  |         ("authentik_providers_oauth2", "0023_alter_accesstoken_refreshtoken_use_hash_index"), | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |     operations = [ | ||||||
|  |         migrations.RenameField( | ||||||
|  |             model_name="oauth2provider", | ||||||
|  |             old_name="redirect_uris", | ||||||
|  |             new_name="old_redirect_uris", | ||||||
|  |         ), | ||||||
|  |         migrations.AddField( | ||||||
|  |             model_name="oauth2provider", | ||||||
|  |             name="_redirect_uris", | ||||||
|  |             field=models.JSONField(default=dict, verbose_name="Redirect URIs"), | ||||||
|  |         ), | ||||||
|  |         migrations.RunPython(migrate_redirect_uris, lambda *args: ...), | ||||||
|  |         migrations.RemoveField( | ||||||
|  |             model_name="oauth2provider", | ||||||
|  |             name="old_redirect_uris", | ||||||
|  |         ), | ||||||
|  |     ] | ||||||
| @ -3,7 +3,7 @@ | |||||||
| import base64 | import base64 | ||||||
| import binascii | import binascii | ||||||
| import json | import json | ||||||
| from dataclasses import asdict | from dataclasses import asdict, dataclass | ||||||
| from functools import cached_property | from functools import cached_property | ||||||
| from hashlib import sha256 | from hashlib import sha256 | ||||||
| from typing import Any | from typing import Any | ||||||
| @ -12,6 +12,7 @@ from urllib.parse import urlparse, urlunparse | |||||||
| from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey | from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey | ||||||
| from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey | from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey | ||||||
| from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes | from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes | ||||||
|  | from dacite import Config | ||||||
| from dacite.core import from_dict | from dacite.core import from_dict | ||||||
| from django.db import models | from django.db import models | ||||||
| from django.http import HttpRequest | from django.http import HttpRequest | ||||||
| @ -23,7 +24,13 @@ from rest_framework.serializers import Serializer | |||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.brands.models import WebfingerProvider | from authentik.brands.models import WebfingerProvider | ||||||
| from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User | from authentik.core.models import ( | ||||||
|  |     AuthenticatedSession, | ||||||
|  |     ExpiringModel, | ||||||
|  |     PropertyMapping, | ||||||
|  |     Provider, | ||||||
|  |     User, | ||||||
|  | ) | ||||||
| from authentik.crypto.models import CertificateKeyPair | from authentik.crypto.models import CertificateKeyPair | ||||||
| from authentik.lib.generators import generate_code_fixed_length, generate_id, generate_key | from authentik.lib.generators import generate_code_fixed_length, generate_id, generate_key | ||||||
| from authentik.lib.models import SerializerModel | from authentik.lib.models import SerializerModel | ||||||
| @ -67,11 +74,25 @@ class IssuerMode(models.TextChoices): | |||||||
|     """Configure how the `iss` field is created.""" |     """Configure how the `iss` field is created.""" | ||||||
|  |  | ||||||
|     GLOBAL = "global", _("Same identifier is used for all providers") |     GLOBAL = "global", _("Same identifier is used for all providers") | ||||||
|     PER_PROVIDER = "per_provider", _( |     PER_PROVIDER = ( | ||||||
|         "Each provider has a different issuer, based on the application slug." |         "per_provider", | ||||||
|  |         _("Each provider has a different issuer, based on the application slug."), | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class RedirectURIMatchingMode(models.TextChoices): | ||||||
|  |     STRICT = "strict", _("Strict URL comparison") | ||||||
|  |     REGEX = "regex", _("Regular Expression URL matching") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @dataclass | ||||||
|  | class RedirectURI: | ||||||
|  |     """A single redirect URI entry""" | ||||||
|  |  | ||||||
|  |     matching_mode: RedirectURIMatchingMode | ||||||
|  |     url: str | ||||||
|  |  | ||||||
|  |  | ||||||
| class ResponseTypes(models.TextChoices): | class ResponseTypes(models.TextChoices): | ||||||
|     """Response Type required by the client.""" |     """Response Type required by the client.""" | ||||||
|  |  | ||||||
| @ -146,11 +167,9 @@ class OAuth2Provider(WebfingerProvider, Provider): | |||||||
|         verbose_name=_("Client Secret"), |         verbose_name=_("Client Secret"), | ||||||
|         default=generate_client_secret, |         default=generate_client_secret, | ||||||
|     ) |     ) | ||||||
|     redirect_uris = models.TextField( |     _redirect_uris = models.JSONField( | ||||||
|         default="", |         default=dict, | ||||||
|         blank=True, |  | ||||||
|         verbose_name=_("Redirect URIs"), |         verbose_name=_("Redirect URIs"), | ||||||
|         help_text=_("Enter each URI on a new line."), |  | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     include_claims_in_id_token = models.BooleanField( |     include_claims_in_id_token = models.BooleanField( | ||||||
| @ -251,12 +270,33 @@ class OAuth2Provider(WebfingerProvider, Provider): | |||||||
|         except Provider.application.RelatedObjectDoesNotExist: |         except Provider.application.RelatedObjectDoesNotExist: | ||||||
|             return None |             return None | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def redirect_uris(self) -> list[RedirectURI]: | ||||||
|  |         uris = [] | ||||||
|  |         for entry in self._redirect_uris: | ||||||
|  |             uris.append( | ||||||
|  |                 from_dict( | ||||||
|  |                     RedirectURI, | ||||||
|  |                     entry, | ||||||
|  |                     config=Config(type_hooks={RedirectURIMatchingMode: RedirectURIMatchingMode}), | ||||||
|  |                 ) | ||||||
|  |             ) | ||||||
|  |         return uris | ||||||
|  |  | ||||||
|  |     @redirect_uris.setter | ||||||
|  |     def redirect_uris(self, value: list[RedirectURI]): | ||||||
|  |         cleansed = [] | ||||||
|  |         for entry in value: | ||||||
|  |             cleansed.append(asdict(entry)) | ||||||
|  |         self._redirect_uris = cleansed | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def launch_url(self) -> str | None: |     def launch_url(self) -> str | None: | ||||||
|         """Guess launch_url based on first redirect_uri""" |         """Guess launch_url based on first redirect_uri""" | ||||||
|         if self.redirect_uris == "": |         redirects = self.redirect_uris | ||||||
|  |         if len(redirects) < 1: | ||||||
|             return None |             return None | ||||||
|         main_url = self.redirect_uris.split("\n", maxsplit=1)[0] |         main_url = redirects[0].url | ||||||
|         try: |         try: | ||||||
|             launch_url = urlparse(main_url)._replace(path="") |             launch_url = urlparse(main_url)._replace(path="") | ||||||
|             return urlunparse(launch_url) |             return urlunparse(launch_url) | ||||||
| @ -320,7 +360,9 @@ class BaseGrantModel(models.Model): | |||||||
|     revoked = models.BooleanField(default=False) |     revoked = models.BooleanField(default=False) | ||||||
|     _scope = models.TextField(default="", verbose_name=_("Scopes")) |     _scope = models.TextField(default="", verbose_name=_("Scopes")) | ||||||
|     auth_time = models.DateTimeField(verbose_name="Authentication time") |     auth_time = models.DateTimeField(verbose_name="Authentication time") | ||||||
|     session_id = models.CharField(default="", blank=True) |     session = models.ForeignKey( | ||||||
|  |         AuthenticatedSession, null=True, on_delete=models.SET_DEFAULT, default=None | ||||||
|  |     ) | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|         abstract = True |         abstract = True | ||||||
| @ -452,6 +494,9 @@ class DeviceToken(ExpiringModel): | |||||||
|     device_code = models.TextField(default=generate_key) |     device_code = models.TextField(default=generate_key) | ||||||
|     user_code = models.TextField(default=generate_code_fixed_length) |     user_code = models.TextField(default=generate_code_fixed_length) | ||||||
|     _scope = models.TextField(default="", verbose_name=_("Scopes")) |     _scope = models.TextField(default="", verbose_name=_("Scopes")) | ||||||
|  |     session = models.ForeignKey( | ||||||
|  |         AuthenticatedSession, null=True, on_delete=models.SET_DEFAULT, default=None | ||||||
|  |     ) | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def scope(self) -> list[str]: |     def scope(self) -> list[str]: | ||||||
|  | |||||||
| @ -1,5 +1,3 @@ | |||||||
| from hashlib import sha256 |  | ||||||
|  |  | ||||||
| from django.contrib.auth.signals import user_logged_out | from django.contrib.auth.signals import user_logged_out | ||||||
| from django.dispatch import receiver | from django.dispatch import receiver | ||||||
| from django.http import HttpRequest | from django.http import HttpRequest | ||||||
| @ -13,5 +11,4 @@ def user_logged_out_oauth_access_token(sender, request: HttpRequest, user: User, | |||||||
|     """Revoke access tokens upon user logout""" |     """Revoke access tokens upon user logout""" | ||||||
|     if not request.session or not request.session.session_key: |     if not request.session or not request.session.session_key: | ||||||
|         return |         return | ||||||
|     hashed_session_key = sha256(request.session.session_key.encode("ascii")).hexdigest() |     AccessToken.objects.filter(user=user, session__session_key=request.session.session_key).delete() | ||||||
|     AccessToken.objects.filter(user=user, session_id=hashed_session_key).delete() |  | ||||||
|  | |||||||
| @ -10,7 +10,13 @@ from rest_framework.test import APITestCase | |||||||
| from authentik.blueprints.tests import apply_blueprint | from authentik.blueprints.tests import apply_blueprint | ||||||
| from authentik.core.models import Application | from authentik.core.models import Application | ||||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_flow | from authentik.core.tests.utils import create_test_admin_user, create_test_flow | ||||||
| from authentik.providers.oauth2.models import OAuth2Provider, ScopeMapping | from authentik.lib.generators import generate_id | ||||||
|  | from authentik.providers.oauth2.models import ( | ||||||
|  |     OAuth2Provider, | ||||||
|  |     RedirectURI, | ||||||
|  |     RedirectURIMatchingMode, | ||||||
|  |     ScopeMapping, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestAPI(APITestCase): | class TestAPI(APITestCase): | ||||||
| @ -21,7 +27,7 @@ class TestAPI(APITestCase): | |||||||
|         self.provider: OAuth2Provider = OAuth2Provider.objects.create( |         self.provider: OAuth2Provider = OAuth2Provider.objects.create( | ||||||
|             name="test", |             name="test", | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris="http://testserver", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")], | ||||||
|         ) |         ) | ||||||
|         self.provider.property_mappings.set(ScopeMapping.objects.all()) |         self.provider.property_mappings.set(ScopeMapping.objects.all()) | ||||||
|         self.app = Application.objects.create(name="test", slug="test", provider=self.provider) |         self.app = Application.objects.create(name="test", slug="test", provider=self.provider) | ||||||
| @ -50,9 +56,29 @@ class TestAPI(APITestCase): | |||||||
|     @skipUnless(version_info >= (3, 11, 4), "This behaviour is only Python 3.11.4 and up") |     @skipUnless(version_info >= (3, 11, 4), "This behaviour is only Python 3.11.4 and up") | ||||||
|     def test_launch_url(self): |     def test_launch_url(self): | ||||||
|         """Test launch_url""" |         """Test launch_url""" | ||||||
|         self.provider.redirect_uris = ( |         self.provider.redirect_uris = [ | ||||||
|             "https://[\\d\\w]+.pr.test.goauthentik.io/source/oauth/callback/authentik/\n" |             RedirectURI( | ||||||
|         ) |                 RedirectURIMatchingMode.REGEX, | ||||||
|  |                 "https://[\\d\\w]+.pr.test.goauthentik.io/source/oauth/callback/authentik/", | ||||||
|  |             ), | ||||||
|  |         ] | ||||||
|         self.provider.save() |         self.provider.save() | ||||||
|         self.provider.refresh_from_db() |         self.provider.refresh_from_db() | ||||||
|         self.assertIsNone(self.provider.launch_url) |         self.assertIsNone(self.provider.launch_url) | ||||||
|  |  | ||||||
|  |     def test_validate_redirect_uris(self): | ||||||
|  |         """Test redirect_uris API""" | ||||||
|  |         response = self.client.post( | ||||||
|  |             reverse("authentik_api:oauth2provider-list"), | ||||||
|  |             data={ | ||||||
|  |                 "name": generate_id(), | ||||||
|  |                 "authorization_flow": create_test_flow().pk, | ||||||
|  |                 "invalidation_flow": create_test_flow().pk, | ||||||
|  |                 "redirect_uris": [ | ||||||
|  |                     {"matching_mode": "strict", "url": "http://goauthentik.io"}, | ||||||
|  |                     {"matching_mode": "regex", "url": "**"}, | ||||||
|  |                 ], | ||||||
|  |             }, | ||||||
|  |         ) | ||||||
|  |         self.assertJSONEqual(response.content, {"redirect_uris": ["Invalid Regex Pattern: **"]}) | ||||||
|  |         self.assertEqual(response.status_code, 400) | ||||||
|  | |||||||
| @ -19,6 +19,8 @@ from authentik.providers.oauth2.models import ( | |||||||
|     AuthorizationCode, |     AuthorizationCode, | ||||||
|     GrantTypes, |     GrantTypes, | ||||||
|     OAuth2Provider, |     OAuth2Provider, | ||||||
|  |     RedirectURI, | ||||||
|  |     RedirectURIMatchingMode, | ||||||
|     ScopeMapping, |     ScopeMapping, | ||||||
| ) | ) | ||||||
| from authentik.providers.oauth2.tests.utils import OAuthTestCase | from authentik.providers.oauth2.tests.utils import OAuthTestCase | ||||||
| @ -39,7 +41,7 @@ class TestAuthorize(OAuthTestCase): | |||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|             client_id="test", |             client_id="test", | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris="http://local.invalid/Foo", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")], | ||||||
|         ) |         ) | ||||||
|         with self.assertRaises(AuthorizeError): |         with self.assertRaises(AuthorizeError): | ||||||
|             request = self.factory.get( |             request = self.factory.get( | ||||||
| @ -64,7 +66,7 @@ class TestAuthorize(OAuthTestCase): | |||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|             client_id="test", |             client_id="test", | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris="http://local.invalid/Foo", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")], | ||||||
|         ) |         ) | ||||||
|         with self.assertRaises(AuthorizeError): |         with self.assertRaises(AuthorizeError): | ||||||
|             request = self.factory.get( |             request = self.factory.get( | ||||||
| @ -84,7 +86,7 @@ class TestAuthorize(OAuthTestCase): | |||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|             client_id="test", |             client_id="test", | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris="http://local.invalid", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")], | ||||||
|         ) |         ) | ||||||
|         with self.assertRaises(RedirectUriError): |         with self.assertRaises(RedirectUriError): | ||||||
|             request = self.factory.get("/", data={"response_type": "code", "client_id": "test"}) |             request = self.factory.get("/", data={"response_type": "code", "client_id": "test"}) | ||||||
| @ -106,7 +108,7 @@ class TestAuthorize(OAuthTestCase): | |||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|             client_id="test", |             client_id="test", | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris="data:local.invalid", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "data:local.invalid")], | ||||||
|         ) |         ) | ||||||
|         with self.assertRaises(RedirectUriError): |         with self.assertRaises(RedirectUriError): | ||||||
|             request = self.factory.get( |             request = self.factory.get( | ||||||
| @ -125,7 +127,7 @@ class TestAuthorize(OAuthTestCase): | |||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|             client_id="test", |             client_id="test", | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris="", |             redirect_uris=[], | ||||||
|         ) |         ) | ||||||
|         with self.assertRaises(RedirectUriError): |         with self.assertRaises(RedirectUriError): | ||||||
|             request = self.factory.get("/", data={"response_type": "code", "client_id": "test"}) |             request = self.factory.get("/", data={"response_type": "code", "client_id": "test"}) | ||||||
| @ -140,7 +142,7 @@ class TestAuthorize(OAuthTestCase): | |||||||
|         ) |         ) | ||||||
|         OAuthAuthorizationParams.from_request(request) |         OAuthAuthorizationParams.from_request(request) | ||||||
|         provider.refresh_from_db() |         provider.refresh_from_db() | ||||||
|         self.assertEqual(provider.redirect_uris, "+") |         self.assertEqual(provider.redirect_uris, [RedirectURI(RedirectURIMatchingMode.STRICT, "+")]) | ||||||
|  |  | ||||||
|     def test_invalid_redirect_uri_regex(self): |     def test_invalid_redirect_uri_regex(self): | ||||||
|         """test missing/invalid redirect URI""" |         """test missing/invalid redirect URI""" | ||||||
| @ -148,7 +150,7 @@ class TestAuthorize(OAuthTestCase): | |||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|             client_id="test", |             client_id="test", | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris="http://local.invalid?", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid?")], | ||||||
|         ) |         ) | ||||||
|         with self.assertRaises(RedirectUriError): |         with self.assertRaises(RedirectUriError): | ||||||
|             request = self.factory.get("/", data={"response_type": "code", "client_id": "test"}) |             request = self.factory.get("/", data={"response_type": "code", "client_id": "test"}) | ||||||
| @ -170,7 +172,7 @@ class TestAuthorize(OAuthTestCase): | |||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|             client_id="test", |             client_id="test", | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris="+", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "+")], | ||||||
|         ) |         ) | ||||||
|         with self.assertRaises(RedirectUriError): |         with self.assertRaises(RedirectUriError): | ||||||
|             request = self.factory.get("/", data={"response_type": "code", "client_id": "test"}) |             request = self.factory.get("/", data={"response_type": "code", "client_id": "test"}) | ||||||
| @ -213,7 +215,7 @@ class TestAuthorize(OAuthTestCase): | |||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|             client_id="test", |             client_id="test", | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris="http://local.invalid/Foo", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")], | ||||||
|         ) |         ) | ||||||
|         provider.property_mappings.set( |         provider.property_mappings.set( | ||||||
|             ScopeMapping.objects.filter( |             ScopeMapping.objects.filter( | ||||||
| @ -301,7 +303,7 @@ class TestAuthorize(OAuthTestCase): | |||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|             client_id="test", |             client_id="test", | ||||||
|             authorization_flow=flow, |             authorization_flow=flow, | ||||||
|             redirect_uris="foo://localhost", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")], | ||||||
|             access_code_validity="seconds=100", |             access_code_validity="seconds=100", | ||||||
|         ) |         ) | ||||||
|         Application.objects.create(name="app", slug="app", provider=provider) |         Application.objects.create(name="app", slug="app", provider=provider) | ||||||
| @ -343,7 +345,7 @@ class TestAuthorize(OAuthTestCase): | |||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|             client_id="test", |             client_id="test", | ||||||
|             authorization_flow=flow, |             authorization_flow=flow, | ||||||
|             redirect_uris="http://localhost", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")], | ||||||
|             signing_key=self.keypair, |             signing_key=self.keypair, | ||||||
|         ) |         ) | ||||||
|         provider.property_mappings.set( |         provider.property_mappings.set( | ||||||
| @ -419,7 +421,7 @@ class TestAuthorize(OAuthTestCase): | |||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|             client_id="test", |             client_id="test", | ||||||
|             authorization_flow=flow, |             authorization_flow=flow, | ||||||
|             redirect_uris="http://localhost", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")], | ||||||
|             signing_key=self.keypair, |             signing_key=self.keypair, | ||||||
|         ) |         ) | ||||||
|         Application.objects.create(name="app", slug="app", provider=provider) |         Application.objects.create(name="app", slug="app", provider=provider) | ||||||
| @ -474,7 +476,7 @@ class TestAuthorize(OAuthTestCase): | |||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|             client_id=generate_id(), |             client_id=generate_id(), | ||||||
|             authorization_flow=flow, |             authorization_flow=flow, | ||||||
|             redirect_uris="http://localhost", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")], | ||||||
|             signing_key=self.keypair, |             signing_key=self.keypair, | ||||||
|         ) |         ) | ||||||
|         provider.property_mappings.set( |         provider.property_mappings.set( | ||||||
| @ -532,7 +534,7 @@ class TestAuthorize(OAuthTestCase): | |||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|             client_id=generate_id(), |             client_id=generate_id(), | ||||||
|             authorization_flow=flow, |             authorization_flow=flow, | ||||||
|             redirect_uris="http://localhost", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")], | ||||||
|             signing_key=self.keypair, |             signing_key=self.keypair, | ||||||
|         ) |         ) | ||||||
|         app = Application.objects.create(name=generate_id(), slug=generate_id(), provider=provider) |         app = Application.objects.create(name=generate_id(), slug=generate_id(), provider=provider) | ||||||
|  | |||||||
| @ -11,7 +11,14 @@ from authentik.core.models import Application | |||||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow | from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| from authentik.providers.oauth2.constants import ACR_AUTHENTIK_DEFAULT | from authentik.providers.oauth2.constants import ACR_AUTHENTIK_DEFAULT | ||||||
| from authentik.providers.oauth2.models import AccessToken, IDToken, OAuth2Provider, RefreshToken | from authentik.providers.oauth2.models import ( | ||||||
|  |     AccessToken, | ||||||
|  |     IDToken, | ||||||
|  |     OAuth2Provider, | ||||||
|  |     RedirectURI, | ||||||
|  |     RedirectURIMatchingMode, | ||||||
|  |     RefreshToken, | ||||||
|  | ) | ||||||
| from authentik.providers.oauth2.tests.utils import OAuthTestCase | from authentik.providers.oauth2.tests.utils import OAuthTestCase | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -23,13 +30,12 @@ class TesOAuth2Introspection(OAuthTestCase): | |||||||
|         self.provider: OAuth2Provider = OAuth2Provider.objects.create( |         self.provider: OAuth2Provider = OAuth2Provider.objects.create( | ||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris="", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "")], | ||||||
|             signing_key=create_test_cert(), |             signing_key=create_test_cert(), | ||||||
|         ) |         ) | ||||||
|         self.app = Application.objects.create( |         self.app = Application.objects.create( | ||||||
|             name=generate_id(), slug=generate_id(), provider=self.provider |             name=generate_id(), slug=generate_id(), provider=self.provider | ||||||
|         ) |         ) | ||||||
|         self.app.save() |  | ||||||
|         self.user = create_test_admin_user() |         self.user = create_test_admin_user() | ||||||
|         self.auth = b64encode( |         self.auth = b64encode( | ||||||
|             f"{self.provider.client_id}:{self.provider.client_secret}".encode() |             f"{self.provider.client_id}:{self.provider.client_secret}".encode() | ||||||
| @ -114,6 +120,41 @@ class TesOAuth2Introspection(OAuthTestCase): | |||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     def test_introspect_invalid_provider(self): | ||||||
|  |         """Test introspection (mismatched provider and token)""" | ||||||
|  |         provider: OAuth2Provider = OAuth2Provider.objects.create( | ||||||
|  |             name=generate_id(), | ||||||
|  |             authorization_flow=create_test_flow(), | ||||||
|  |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "")], | ||||||
|  |             signing_key=create_test_cert(), | ||||||
|  |         ) | ||||||
|  |         auth = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode() | ||||||
|  |  | ||||||
|  |         token: AccessToken = AccessToken.objects.create( | ||||||
|  |             provider=self.provider, | ||||||
|  |             user=self.user, | ||||||
|  |             token=generate_id(), | ||||||
|  |             auth_time=timezone.now(), | ||||||
|  |             _scope="openid user profile", | ||||||
|  |             _id_token=json.dumps( | ||||||
|  |                 asdict( | ||||||
|  |                     IDToken("foo", "bar"), | ||||||
|  |                 ) | ||||||
|  |             ), | ||||||
|  |         ) | ||||||
|  |         res = self.client.post( | ||||||
|  |             reverse("authentik_providers_oauth2:token-introspection"), | ||||||
|  |             HTTP_AUTHORIZATION=f"Basic {auth}", | ||||||
|  |             data={"token": token.token}, | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(res.status_code, 200) | ||||||
|  |         self.assertJSONEqual( | ||||||
|  |             res.content.decode(), | ||||||
|  |             { | ||||||
|  |                 "active": False, | ||||||
|  |             }, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     def test_introspect_invalid_auth(self): |     def test_introspect_invalid_auth(self): | ||||||
|         """Test introspect (invalid auth)""" |         """Test introspect (invalid auth)""" | ||||||
|         res = self.client.post( |         res = self.client.post( | ||||||
|  | |||||||
| @ -13,7 +13,7 @@ from authentik.core.tests.utils import create_test_cert, create_test_flow | |||||||
| from authentik.crypto.builder import PrivateKeyAlg | from authentik.crypto.builder import PrivateKeyAlg | ||||||
| from authentik.crypto.models import CertificateKeyPair | from authentik.crypto.models import CertificateKeyPair | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| from authentik.providers.oauth2.models import OAuth2Provider | from authentik.providers.oauth2.models import OAuth2Provider, RedirectURI, RedirectURIMatchingMode | ||||||
| from authentik.providers.oauth2.tests.utils import OAuthTestCase | from authentik.providers.oauth2.tests.utils import OAuthTestCase | ||||||
|  |  | ||||||
| TEST_CORDS_CERT = """ | TEST_CORDS_CERT = """ | ||||||
| @ -49,7 +49,7 @@ class TestJWKS(OAuthTestCase): | |||||||
|             name="test", |             name="test", | ||||||
|             client_id="test", |             client_id="test", | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris="http://local.invalid", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")], | ||||||
|             signing_key=create_test_cert(), |             signing_key=create_test_cert(), | ||||||
|         ) |         ) | ||||||
|         app = Application.objects.create(name="test", slug="test", provider=provider) |         app = Application.objects.create(name="test", slug="test", provider=provider) | ||||||
| @ -68,7 +68,7 @@ class TestJWKS(OAuthTestCase): | |||||||
|             name="test", |             name="test", | ||||||
|             client_id="test", |             client_id="test", | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris="http://local.invalid", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")], | ||||||
|         ) |         ) | ||||||
|         app = Application.objects.create(name="test", slug="test", provider=provider) |         app = Application.objects.create(name="test", slug="test", provider=provider) | ||||||
|         response = self.client.get( |         response = self.client.get( | ||||||
| @ -82,7 +82,7 @@ class TestJWKS(OAuthTestCase): | |||||||
|             name="test", |             name="test", | ||||||
|             client_id="test", |             client_id="test", | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris="http://local.invalid", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")], | ||||||
|             signing_key=create_test_cert(PrivateKeyAlg.ECDSA), |             signing_key=create_test_cert(PrivateKeyAlg.ECDSA), | ||||||
|         ) |         ) | ||||||
|         app = Application.objects.create(name="test", slug="test", provider=provider) |         app = Application.objects.create(name="test", slug="test", provider=provider) | ||||||
| @ -104,7 +104,7 @@ class TestJWKS(OAuthTestCase): | |||||||
|             name="test", |             name="test", | ||||||
|             client_id="test", |             client_id="test", | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris="http://local.invalid", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")], | ||||||
|             signing_key=cert, |             signing_key=cert, | ||||||
|         ) |         ) | ||||||
|         app = Application.objects.create(name="test", slug="test", provider=provider) |         app = Application.objects.create(name="test", slug="test", provider=provider) | ||||||
|  | |||||||
| @ -10,7 +10,14 @@ from django.utils import timezone | |||||||
| from authentik.core.models import Application | from authentik.core.models import Application | ||||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow | from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| from authentik.providers.oauth2.models import AccessToken, IDToken, OAuth2Provider, RefreshToken | from authentik.providers.oauth2.models import ( | ||||||
|  |     AccessToken, | ||||||
|  |     IDToken, | ||||||
|  |     OAuth2Provider, | ||||||
|  |     RedirectURI, | ||||||
|  |     RedirectURIMatchingMode, | ||||||
|  |     RefreshToken, | ||||||
|  | ) | ||||||
| from authentik.providers.oauth2.tests.utils import OAuthTestCase | from authentik.providers.oauth2.tests.utils import OAuthTestCase | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -22,7 +29,7 @@ class TesOAuth2Revoke(OAuthTestCase): | |||||||
|         self.provider: OAuth2Provider = OAuth2Provider.objects.create( |         self.provider: OAuth2Provider = OAuth2Provider.objects.create( | ||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris="", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "")], | ||||||
|             signing_key=create_test_cert(), |             signing_key=create_test_cert(), | ||||||
|         ) |         ) | ||||||
|         self.app = Application.objects.create( |         self.app = Application.objects.create( | ||||||
|  | |||||||
| @ -22,6 +22,8 @@ from authentik.providers.oauth2.models import ( | |||||||
|     AccessToken, |     AccessToken, | ||||||
|     AuthorizationCode, |     AuthorizationCode, | ||||||
|     OAuth2Provider, |     OAuth2Provider, | ||||||
|  |     RedirectURI, | ||||||
|  |     RedirectURIMatchingMode, | ||||||
|     RefreshToken, |     RefreshToken, | ||||||
|     ScopeMapping, |     ScopeMapping, | ||||||
| ) | ) | ||||||
| @ -42,7 +44,7 @@ class TestToken(OAuthTestCase): | |||||||
|         provider = OAuth2Provider.objects.create( |         provider = OAuth2Provider.objects.create( | ||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris="http://TestServer", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://TestServer")], | ||||||
|             signing_key=self.keypair, |             signing_key=self.keypair, | ||||||
|         ) |         ) | ||||||
|         header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode() |         header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode() | ||||||
| @ -69,7 +71,7 @@ class TestToken(OAuthTestCase): | |||||||
|         provider = OAuth2Provider.objects.create( |         provider = OAuth2Provider.objects.create( | ||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris="http://testserver", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")], | ||||||
|             signing_key=self.keypair, |             signing_key=self.keypair, | ||||||
|         ) |         ) | ||||||
|         header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode() |         header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode() | ||||||
| @ -90,7 +92,7 @@ class TestToken(OAuthTestCase): | |||||||
|         provider = OAuth2Provider.objects.create( |         provider = OAuth2Provider.objects.create( | ||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris="http://local.invalid", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")], | ||||||
|             signing_key=self.keypair, |             signing_key=self.keypair, | ||||||
|         ) |         ) | ||||||
|         header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode() |         header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode() | ||||||
| @ -118,7 +120,7 @@ class TestToken(OAuthTestCase): | |||||||
|         provider = OAuth2Provider.objects.create( |         provider = OAuth2Provider.objects.create( | ||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris="http://local.invalid", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")], | ||||||
|             signing_key=self.keypair, |             signing_key=self.keypair, | ||||||
|         ) |         ) | ||||||
|         # Needs to be assigned to an application for iss to be set |         # Needs to be assigned to an application for iss to be set | ||||||
| @ -158,7 +160,7 @@ class TestToken(OAuthTestCase): | |||||||
|         provider = OAuth2Provider.objects.create( |         provider = OAuth2Provider.objects.create( | ||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris="http://local.invalid", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")], | ||||||
|             signing_key=self.keypair, |             signing_key=self.keypair, | ||||||
|         ) |         ) | ||||||
|         provider.property_mappings.set( |         provider.property_mappings.set( | ||||||
| @ -220,7 +222,7 @@ class TestToken(OAuthTestCase): | |||||||
|         provider = OAuth2Provider.objects.create( |         provider = OAuth2Provider.objects.create( | ||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris="http://local.invalid", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")], | ||||||
|             signing_key=self.keypair, |             signing_key=self.keypair, | ||||||
|         ) |         ) | ||||||
|         provider.property_mappings.set( |         provider.property_mappings.set( | ||||||
| @ -278,7 +280,7 @@ class TestToken(OAuthTestCase): | |||||||
|         provider = OAuth2Provider.objects.create( |         provider = OAuth2Provider.objects.create( | ||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris="http://testserver", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")], | ||||||
|             signing_key=self.keypair, |             signing_key=self.keypair, | ||||||
|         ) |         ) | ||||||
|         provider.property_mappings.set( |         provider.property_mappings.set( | ||||||
|  | |||||||
| @ -19,7 +19,12 @@ from authentik.providers.oauth2.constants import ( | |||||||
|     SCOPE_OPENID_PROFILE, |     SCOPE_OPENID_PROFILE, | ||||||
|     TOKEN_TYPE, |     TOKEN_TYPE, | ||||||
| ) | ) | ||||||
| from authentik.providers.oauth2.models import OAuth2Provider, ScopeMapping | from authentik.providers.oauth2.models import ( | ||||||
|  |     OAuth2Provider, | ||||||
|  |     RedirectURI, | ||||||
|  |     RedirectURIMatchingMode, | ||||||
|  |     ScopeMapping, | ||||||
|  | ) | ||||||
| from authentik.providers.oauth2.tests.utils import OAuthTestCase | from authentik.providers.oauth2.tests.utils import OAuthTestCase | ||||||
| from authentik.providers.oauth2.views.jwks import JWKSView | from authentik.providers.oauth2.views.jwks import JWKSView | ||||||
| from authentik.sources.oauth.models import OAuthSource | from authentik.sources.oauth.models import OAuthSource | ||||||
| @ -54,7 +59,7 @@ class TestTokenClientCredentialsJWTSource(OAuthTestCase): | |||||||
|         self.provider: OAuth2Provider = OAuth2Provider.objects.create( |         self.provider: OAuth2Provider = OAuth2Provider.objects.create( | ||||||
|             name="test", |             name="test", | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris="http://testserver", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")], | ||||||
|             signing_key=self.cert, |             signing_key=self.cert, | ||||||
|         ) |         ) | ||||||
|         self.provider.jwks_sources.add(self.source) |         self.provider.jwks_sources.add(self.source) | ||||||
|  | |||||||
| @ -19,7 +19,13 @@ from authentik.providers.oauth2.constants import ( | |||||||
|     TOKEN_TYPE, |     TOKEN_TYPE, | ||||||
| ) | ) | ||||||
| from authentik.providers.oauth2.errors import TokenError | from authentik.providers.oauth2.errors import TokenError | ||||||
| from authentik.providers.oauth2.models import OAuth2Provider, ScopeMapping | from authentik.providers.oauth2.models import ( | ||||||
|  |     AccessToken, | ||||||
|  |     OAuth2Provider, | ||||||
|  |     RedirectURI, | ||||||
|  |     RedirectURIMatchingMode, | ||||||
|  |     ScopeMapping, | ||||||
|  | ) | ||||||
| from authentik.providers.oauth2.tests.utils import OAuthTestCase | from authentik.providers.oauth2.tests.utils import OAuthTestCase | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -33,7 +39,7 @@ class TestTokenClientCredentialsStandard(OAuthTestCase): | |||||||
|         self.provider = OAuth2Provider.objects.create( |         self.provider = OAuth2Provider.objects.create( | ||||||
|             name="test", |             name="test", | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris="http://testserver", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")], | ||||||
|             signing_key=create_test_cert(), |             signing_key=create_test_cert(), | ||||||
|         ) |         ) | ||||||
|         self.provider.property_mappings.set(ScopeMapping.objects.all()) |         self.provider.property_mappings.set(ScopeMapping.objects.all()) | ||||||
| @ -107,6 +113,48 @@ class TestTokenClientCredentialsStandard(OAuthTestCase): | |||||||
|             {"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]}, |             {"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]}, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     def test_incorrect_scopes(self): | ||||||
|  |         """test scope that isn't configured""" | ||||||
|  |         response = self.client.post( | ||||||
|  |             reverse("authentik_providers_oauth2:token"), | ||||||
|  |             { | ||||||
|  |                 "grant_type": GRANT_TYPE_CLIENT_CREDENTIALS, | ||||||
|  |                 "scope": f"{SCOPE_OPENID} {SCOPE_OPENID_EMAIL} {SCOPE_OPENID_PROFILE} extra_scope", | ||||||
|  |                 "client_id": self.provider.client_id, | ||||||
|  |                 "client_secret": self.provider.client_secret, | ||||||
|  |             }, | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(response.status_code, 200) | ||||||
|  |         body = loads(response.content.decode()) | ||||||
|  |         self.assertEqual(body["token_type"], TOKEN_TYPE) | ||||||
|  |         token = AccessToken.objects.filter( | ||||||
|  |             provider=self.provider, token=body["access_token"] | ||||||
|  |         ).first() | ||||||
|  |         self.assertSetEqual( | ||||||
|  |             set(token.scope), {SCOPE_OPENID, SCOPE_OPENID_EMAIL, SCOPE_OPENID_PROFILE} | ||||||
|  |         ) | ||||||
|  |         _, alg = self.provider.jwt_key | ||||||
|  |         jwt = decode( | ||||||
|  |             body["access_token"], | ||||||
|  |             key=self.provider.signing_key.public_key, | ||||||
|  |             algorithms=[alg], | ||||||
|  |             audience=self.provider.client_id, | ||||||
|  |         ) | ||||||
|  |         self.assertEqual( | ||||||
|  |             jwt["given_name"], "Autogenerated user from application test (client credentials)" | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(jwt["preferred_username"], "ak-test-client_credentials") | ||||||
|  |         jwt = decode( | ||||||
|  |             body["id_token"], | ||||||
|  |             key=self.provider.signing_key.public_key, | ||||||
|  |             algorithms=[alg], | ||||||
|  |             audience=self.provider.client_id, | ||||||
|  |         ) | ||||||
|  |         self.assertEqual( | ||||||
|  |             jwt["given_name"], "Autogenerated user from application test (client credentials)" | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(jwt["preferred_username"], "ak-test-client_credentials") | ||||||
|  |  | ||||||
|     def test_successful(self): |     def test_successful(self): | ||||||
|         """test successful""" |         """test successful""" | ||||||
|         response = self.client.post( |         response = self.client.post( | ||||||
|  | |||||||
| @ -20,7 +20,12 @@ from authentik.providers.oauth2.constants import ( | |||||||
|     TOKEN_TYPE, |     TOKEN_TYPE, | ||||||
| ) | ) | ||||||
| from authentik.providers.oauth2.errors import TokenError | from authentik.providers.oauth2.errors import TokenError | ||||||
| from authentik.providers.oauth2.models import OAuth2Provider, ScopeMapping | from authentik.providers.oauth2.models import ( | ||||||
|  |     OAuth2Provider, | ||||||
|  |     RedirectURI, | ||||||
|  |     RedirectURIMatchingMode, | ||||||
|  |     ScopeMapping, | ||||||
|  | ) | ||||||
| from authentik.providers.oauth2.tests.utils import OAuthTestCase | from authentik.providers.oauth2.tests.utils import OAuthTestCase | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -34,7 +39,7 @@ class TestTokenClientCredentialsStandardCompat(OAuthTestCase): | |||||||
|         self.provider = OAuth2Provider.objects.create( |         self.provider = OAuth2Provider.objects.create( | ||||||
|             name="test", |             name="test", | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris="http://testserver", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")], | ||||||
|             signing_key=create_test_cert(), |             signing_key=create_test_cert(), | ||||||
|         ) |         ) | ||||||
|         self.provider.property_mappings.set(ScopeMapping.objects.all()) |         self.provider.property_mappings.set(ScopeMapping.objects.all()) | ||||||
|  | |||||||
| @ -19,7 +19,12 @@ from authentik.providers.oauth2.constants import ( | |||||||
|     TOKEN_TYPE, |     TOKEN_TYPE, | ||||||
| ) | ) | ||||||
| from authentik.providers.oauth2.errors import TokenError | from authentik.providers.oauth2.errors import TokenError | ||||||
| from authentik.providers.oauth2.models import OAuth2Provider, ScopeMapping | from authentik.providers.oauth2.models import ( | ||||||
|  |     OAuth2Provider, | ||||||
|  |     RedirectURI, | ||||||
|  |     RedirectURIMatchingMode, | ||||||
|  |     ScopeMapping, | ||||||
|  | ) | ||||||
| from authentik.providers.oauth2.tests.utils import OAuthTestCase | from authentik.providers.oauth2.tests.utils import OAuthTestCase | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -33,7 +38,7 @@ class TestTokenClientCredentialsUserNamePassword(OAuthTestCase): | |||||||
|         self.provider = OAuth2Provider.objects.create( |         self.provider = OAuth2Provider.objects.create( | ||||||
|             name="test", |             name="test", | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris="http://testserver", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")], | ||||||
|             signing_key=create_test_cert(), |             signing_key=create_test_cert(), | ||||||
|         ) |         ) | ||||||
|         self.provider.property_mappings.set(ScopeMapping.objects.all()) |         self.provider.property_mappings.set(ScopeMapping.objects.all()) | ||||||
|  | |||||||
| @ -9,8 +9,19 @@ from authentik.blueprints.tests import apply_blueprint | |||||||
| from authentik.core.models import Application | from authentik.core.models import Application | ||||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow | from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow | ||||||
| from authentik.lib.generators import generate_code_fixed_length, generate_id | from authentik.lib.generators import generate_code_fixed_length, generate_id | ||||||
| from authentik.providers.oauth2.constants import GRANT_TYPE_DEVICE_CODE | from authentik.providers.oauth2.constants import ( | ||||||
| from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider, ScopeMapping |     GRANT_TYPE_DEVICE_CODE, | ||||||
|  |     SCOPE_OPENID, | ||||||
|  |     SCOPE_OPENID_EMAIL, | ||||||
|  | ) | ||||||
|  | from authentik.providers.oauth2.models import ( | ||||||
|  |     AccessToken, | ||||||
|  |     DeviceToken, | ||||||
|  |     OAuth2Provider, | ||||||
|  |     RedirectURI, | ||||||
|  |     RedirectURIMatchingMode, | ||||||
|  |     ScopeMapping, | ||||||
|  | ) | ||||||
| from authentik.providers.oauth2.tests.utils import OAuthTestCase | from authentik.providers.oauth2.tests.utils import OAuthTestCase | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -24,7 +35,7 @@ class TestTokenDeviceCode(OAuthTestCase): | |||||||
|         self.provider = OAuth2Provider.objects.create( |         self.provider = OAuth2Provider.objects.create( | ||||||
|             name="test", |             name="test", | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris="http://testserver", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")], | ||||||
|             signing_key=create_test_cert(), |             signing_key=create_test_cert(), | ||||||
|         ) |         ) | ||||||
|         self.provider.property_mappings.set(ScopeMapping.objects.all()) |         self.provider.property_mappings.set(ScopeMapping.objects.all()) | ||||||
| @ -80,3 +91,28 @@ class TestTokenDeviceCode(OAuthTestCase): | |||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
|         self.assertEqual(res.status_code, 200) |         self.assertEqual(res.status_code, 200) | ||||||
|  |  | ||||||
|  |     def test_code_mismatched_scope(self): | ||||||
|  |         """Test code with user (mismatched scopes)""" | ||||||
|  |         device_token = DeviceToken.objects.create( | ||||||
|  |             provider=self.provider, | ||||||
|  |             user_code=generate_code_fixed_length(), | ||||||
|  |             device_code=generate_id(), | ||||||
|  |             user=self.user, | ||||||
|  |             scope=[SCOPE_OPENID, SCOPE_OPENID_EMAIL], | ||||||
|  |         ) | ||||||
|  |         res = self.client.post( | ||||||
|  |             reverse("authentik_providers_oauth2:token"), | ||||||
|  |             data={ | ||||||
|  |                 "client_id": self.provider.client_id, | ||||||
|  |                 "grant_type": GRANT_TYPE_DEVICE_CODE, | ||||||
|  |                 "device_code": device_token.device_code, | ||||||
|  |                 "scope": f"{SCOPE_OPENID} {SCOPE_OPENID_EMAIL} invalid", | ||||||
|  |             }, | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(res.status_code, 200) | ||||||
|  |         body = loads(res.content) | ||||||
|  |         token = AccessToken.objects.filter( | ||||||
|  |             provider=self.provider, token=body["access_token"] | ||||||
|  |         ).first() | ||||||
|  |         self.assertSetEqual(set(token.scope), {SCOPE_OPENID, SCOPE_OPENID_EMAIL}) | ||||||
|  | |||||||
| @ -10,7 +10,12 @@ from authentik.core.models import Application | |||||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_flow | from authentik.core.tests.utils import create_test_admin_user, create_test_flow | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| from authentik.providers.oauth2.constants import GRANT_TYPE_AUTHORIZATION_CODE | from authentik.providers.oauth2.constants import GRANT_TYPE_AUTHORIZATION_CODE | ||||||
| from authentik.providers.oauth2.models import AuthorizationCode, OAuth2Provider | from authentik.providers.oauth2.models import ( | ||||||
|  |     AuthorizationCode, | ||||||
|  |     OAuth2Provider, | ||||||
|  |     RedirectURI, | ||||||
|  |     RedirectURIMatchingMode, | ||||||
|  | ) | ||||||
| from authentik.providers.oauth2.tests.utils import OAuthTestCase | from authentik.providers.oauth2.tests.utils import OAuthTestCase | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -30,7 +35,7 @@ class TestTokenPKCE(OAuthTestCase): | |||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|             client_id="test", |             client_id="test", | ||||||
|             authorization_flow=flow, |             authorization_flow=flow, | ||||||
|             redirect_uris="foo://localhost", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")], | ||||||
|             access_code_validity="seconds=100", |             access_code_validity="seconds=100", | ||||||
|         ) |         ) | ||||||
|         Application.objects.create(name="app", slug="app", provider=provider) |         Application.objects.create(name="app", slug="app", provider=provider) | ||||||
| @ -93,7 +98,7 @@ class TestTokenPKCE(OAuthTestCase): | |||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|             client_id="test", |             client_id="test", | ||||||
|             authorization_flow=flow, |             authorization_flow=flow, | ||||||
|             redirect_uris="foo://localhost", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")], | ||||||
|             access_code_validity="seconds=100", |             access_code_validity="seconds=100", | ||||||
|         ) |         ) | ||||||
|         Application.objects.create(name="app", slug="app", provider=provider) |         Application.objects.create(name="app", slug="app", provider=provider) | ||||||
| @ -154,7 +159,7 @@ class TestTokenPKCE(OAuthTestCase): | |||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|             client_id="test", |             client_id="test", | ||||||
|             authorization_flow=flow, |             authorization_flow=flow, | ||||||
|             redirect_uris="foo://localhost", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")], | ||||||
|             access_code_validity="seconds=100", |             access_code_validity="seconds=100", | ||||||
|         ) |         ) | ||||||
|         Application.objects.create(name="app", slug="app", provider=provider) |         Application.objects.create(name="app", slug="app", provider=provider) | ||||||
| @ -210,7 +215,7 @@ class TestTokenPKCE(OAuthTestCase): | |||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|             client_id="test", |             client_id="test", | ||||||
|             authorization_flow=flow, |             authorization_flow=flow, | ||||||
|             redirect_uris="foo://localhost", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")], | ||||||
|             access_code_validity="seconds=100", |             access_code_validity="seconds=100", | ||||||
|         ) |         ) | ||||||
|         Application.objects.create(name="app", slug="app", provider=provider) |         Application.objects.create(name="app", slug="app", provider=provider) | ||||||
|  | |||||||
| @ -11,7 +11,14 @@ from authentik.core.models import Application | |||||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow | from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow | ||||||
| from authentik.events.models import Event, EventAction | from authentik.events.models import Event, EventAction | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| from authentik.providers.oauth2.models import AccessToken, IDToken, OAuth2Provider, ScopeMapping | from authentik.providers.oauth2.models import ( | ||||||
|  |     AccessToken, | ||||||
|  |     IDToken, | ||||||
|  |     OAuth2Provider, | ||||||
|  |     RedirectURI, | ||||||
|  |     RedirectURIMatchingMode, | ||||||
|  |     ScopeMapping, | ||||||
|  | ) | ||||||
| from authentik.providers.oauth2.tests.utils import OAuthTestCase | from authentik.providers.oauth2.tests.utils import OAuthTestCase | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -25,7 +32,7 @@ class TestUserinfo(OAuthTestCase): | |||||||
|         self.provider: OAuth2Provider = OAuth2Provider.objects.create( |         self.provider: OAuth2Provider = OAuth2Provider.objects.create( | ||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris="", |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "")], | ||||||
|             signing_key=create_test_cert(), |             signing_key=create_test_cert(), | ||||||
|         ) |         ) | ||||||
|         self.provider.property_mappings.set(ScopeMapping.objects.all()) |         self.provider.property_mappings.set(ScopeMapping.objects.all()) | ||||||
|  | |||||||
| @ -2,7 +2,6 @@ | |||||||
|  |  | ||||||
| from dataclasses import InitVar, dataclass, field | from dataclasses import InitVar, dataclass, field | ||||||
| from datetime import timedelta | from datetime import timedelta | ||||||
| from hashlib import sha256 |  | ||||||
| from json import dumps | from json import dumps | ||||||
| from re import error as RegexError | from re import error as RegexError | ||||||
| from re import fullmatch | from re import fullmatch | ||||||
| @ -16,7 +15,7 @@ from django.utils import timezone | |||||||
| from django.utils.translation import gettext as _ | from django.utils.translation import gettext as _ | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.core.models import Application | from authentik.core.models import Application, AuthenticatedSession | ||||||
| from authentik.events.models import Event, EventAction | from authentik.events.models import Event, EventAction | ||||||
| from authentik.events.signals import get_login_event | from authentik.events.signals import get_login_event | ||||||
| from authentik.flows.challenge import ( | from authentik.flows.challenge import ( | ||||||
| @ -57,6 +56,8 @@ from authentik.providers.oauth2.models import ( | |||||||
|     AuthorizationCode, |     AuthorizationCode, | ||||||
|     GrantTypes, |     GrantTypes, | ||||||
|     OAuth2Provider, |     OAuth2Provider, | ||||||
|  |     RedirectURI, | ||||||
|  |     RedirectURIMatchingMode, | ||||||
|     ResponseMode, |     ResponseMode, | ||||||
|     ResponseTypes, |     ResponseTypes, | ||||||
|     ScopeMapping, |     ScopeMapping, | ||||||
| @ -188,40 +189,39 @@ class OAuthAuthorizationParams: | |||||||
|  |  | ||||||
|     def check_redirect_uri(self): |     def check_redirect_uri(self): | ||||||
|         """Redirect URI validation.""" |         """Redirect URI validation.""" | ||||||
|         allowed_redirect_urls = self.provider.redirect_uris.split() |         allowed_redirect_urls = self.provider.redirect_uris | ||||||
|         if not self.redirect_uri: |         if not self.redirect_uri: | ||||||
|             LOGGER.warning("Missing redirect uri.") |             LOGGER.warning("Missing redirect uri.") | ||||||
|             raise RedirectUriError("", allowed_redirect_urls) |             raise RedirectUriError("", allowed_redirect_urls) | ||||||
|  |  | ||||||
|         if self.provider.redirect_uris == "": |         if len(allowed_redirect_urls) < 1: | ||||||
|             LOGGER.info("Setting redirect for blank redirect_uris", redirect=self.redirect_uri) |             LOGGER.info("Setting redirect for blank redirect_uris", redirect=self.redirect_uri) | ||||||
|             self.provider.redirect_uris = self.redirect_uri |             self.provider.redirect_uris = [ | ||||||
|  |                 RedirectURI(RedirectURIMatchingMode.STRICT, self.redirect_uri) | ||||||
|  |             ] | ||||||
|             self.provider.save() |             self.provider.save() | ||||||
|             allowed_redirect_urls = self.provider.redirect_uris.split() |             allowed_redirect_urls = self.provider.redirect_uris | ||||||
|  |  | ||||||
|         if self.provider.redirect_uris == "*": |  | ||||||
|             LOGGER.info("Converting redirect_uris to regex", redirect=self.redirect_uri) |  | ||||||
|             self.provider.redirect_uris = ".*" |  | ||||||
|             self.provider.save() |  | ||||||
|             allowed_redirect_urls = self.provider.redirect_uris.split() |  | ||||||
|  |  | ||||||
|  |         match_found = False | ||||||
|  |         for allowed in allowed_redirect_urls: | ||||||
|  |             if allowed.matching_mode == RedirectURIMatchingMode.STRICT: | ||||||
|  |                 if self.redirect_uri == allowed.url: | ||||||
|  |                     match_found = True | ||||||
|  |                     break | ||||||
|  |             if allowed.matching_mode == RedirectURIMatchingMode.REGEX: | ||||||
|                 try: |                 try: | ||||||
|             if not any(fullmatch(x, self.redirect_uri) for x in allowed_redirect_urls): |                     if fullmatch(allowed.url, self.redirect_uri): | ||||||
|                 LOGGER.warning( |                         match_found = True | ||||||
|                     "Invalid redirect uri (regex comparison)", |                         break | ||||||
|                     redirect_uri_given=self.redirect_uri, |  | ||||||
|                     redirect_uri_expected=allowed_redirect_urls, |  | ||||||
|                 ) |  | ||||||
|                 raise RedirectUriError(self.redirect_uri, allowed_redirect_urls) |  | ||||||
|                 except RegexError as exc: |                 except RegexError as exc: | ||||||
|             LOGGER.info("Failed to parse regular expression, checking directly", exc=exc) |  | ||||||
|             if not any(x == self.redirect_uri for x in allowed_redirect_urls): |  | ||||||
|                     LOGGER.warning( |                     LOGGER.warning( | ||||||
|                     "Invalid redirect uri (strict comparison)", |                         "Failed to parse regular expression", | ||||||
|                     redirect_uri_given=self.redirect_uri, |                         exc=exc, | ||||||
|                     redirect_uri_expected=allowed_redirect_urls, |                         url=allowed.url, | ||||||
|  |                         provider=self.provider, | ||||||
|                     ) |                     ) | ||||||
|                 raise RedirectUriError(self.redirect_uri, allowed_redirect_urls) from None |         if not match_found: | ||||||
|  |             raise RedirectUriError(self.redirect_uri, allowed_redirect_urls) | ||||||
|         # Check against forbidden schemes |         # Check against forbidden schemes | ||||||
|         if urlparse(self.redirect_uri).scheme in FORBIDDEN_URI_SCHEMES: |         if urlparse(self.redirect_uri).scheme in FORBIDDEN_URI_SCHEMES: | ||||||
|             raise RedirectUriError(self.redirect_uri, allowed_redirect_urls) |             raise RedirectUriError(self.redirect_uri, allowed_redirect_urls) | ||||||
| @ -318,7 +318,9 @@ class OAuthAuthorizationParams: | |||||||
|             expires=now + timedelta_from_string(self.provider.access_code_validity), |             expires=now + timedelta_from_string(self.provider.access_code_validity), | ||||||
|             scope=self.scope, |             scope=self.scope, | ||||||
|             nonce=self.nonce, |             nonce=self.nonce, | ||||||
|             session_id=sha256(request.session.session_key.encode("ascii")).hexdigest(), |             session=AuthenticatedSession.objects.filter( | ||||||
|  |                 session_key=request.session.session_key | ||||||
|  |             ).first(), | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         if self.code_challenge and self.code_challenge_method: |         if self.code_challenge and self.code_challenge_method: | ||||||
| @ -610,7 +612,9 @@ class OAuthFulfillmentStage(StageView): | |||||||
|             expires=access_token_expiry, |             expires=access_token_expiry, | ||||||
|             provider=self.provider, |             provider=self.provider, | ||||||
|             auth_time=auth_event.created if auth_event else now, |             auth_time=auth_event.created if auth_event else now, | ||||||
|             session_id=sha256(self.request.session.session_key.encode("ascii")).hexdigest(), |             session=AuthenticatedSession.objects.filter( | ||||||
|  |                 session_key=self.request.session.session_key | ||||||
|  |             ).first(), | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         id_token = IDToken.new(self.provider, token, self.request) |         id_token = IDToken.new(self.provider, token, self.request) | ||||||
|  | |||||||
| @ -46,10 +46,10 @@ class TokenIntrospectionParams: | |||||||
|         if not provider: |         if not provider: | ||||||
|             raise TokenIntrospectionError |             raise TokenIntrospectionError | ||||||
|  |  | ||||||
|         access_token = AccessToken.objects.filter(token=raw_token).first() |         access_token = AccessToken.objects.filter(token=raw_token, provider=provider).first() | ||||||
|         if access_token: |         if access_token: | ||||||
|             return TokenIntrospectionParams(access_token, provider) |             return TokenIntrospectionParams(access_token, provider) | ||||||
|         refresh_token = RefreshToken.objects.filter(token=raw_token).first() |         refresh_token = RefreshToken.objects.filter(token=raw_token, provider=provider).first() | ||||||
|         if refresh_token: |         if refresh_token: | ||||||
|             return TokenIntrospectionParams(refresh_token, provider) |             return TokenIntrospectionParams(refresh_token, provider) | ||||||
|         LOGGER.debug("Token does not exist", token=raw_token) |         LOGGER.debug("Token does not exist", token=raw_token) | ||||||
|  | |||||||
| @ -158,5 +158,5 @@ class ProviderInfoView(View): | |||||||
|             OAuth2Provider, pk=application.provider_id |             OAuth2Provider, pk=application.provider_id | ||||||
|         ) |         ) | ||||||
|         response = super().dispatch(request, *args, **kwargs) |         response = super().dispatch(request, *args, **kwargs) | ||||||
|         cors_allow(request, response, *self.provider.redirect_uris.split("\n")) |         cors_allow(request, response, *[x.url for x in self.provider.redirect_uris]) | ||||||
|         return response |         return response | ||||||
|  | |||||||
| @ -58,7 +58,9 @@ from authentik.providers.oauth2.models import ( | |||||||
|     ClientTypes, |     ClientTypes, | ||||||
|     DeviceToken, |     DeviceToken, | ||||||
|     OAuth2Provider, |     OAuth2Provider, | ||||||
|  |     RedirectURIMatchingMode, | ||||||
|     RefreshToken, |     RefreshToken, | ||||||
|  |     ScopeMapping, | ||||||
| ) | ) | ||||||
| from authentik.providers.oauth2.utils import TokenResponse, cors_allow, extract_client_auth | from authentik.providers.oauth2.utils import TokenResponse, cors_allow, extract_client_auth | ||||||
| from authentik.providers.oauth2.views.authorize import FORBIDDEN_URI_SCHEMES | from authentik.providers.oauth2.views.authorize import FORBIDDEN_URI_SCHEMES | ||||||
| @ -77,7 +79,7 @@ class TokenParams: | |||||||
|     redirect_uri: str |     redirect_uri: str | ||||||
|     grant_type: str |     grant_type: str | ||||||
|     state: str |     state: str | ||||||
|     scope: list[str] |     scope: set[str] | ||||||
|  |  | ||||||
|     provider: OAuth2Provider |     provider: OAuth2Provider | ||||||
|  |  | ||||||
| @ -112,11 +114,26 @@ class TokenParams: | |||||||
|             redirect_uri=request.POST.get("redirect_uri", ""), |             redirect_uri=request.POST.get("redirect_uri", ""), | ||||||
|             grant_type=request.POST.get("grant_type", ""), |             grant_type=request.POST.get("grant_type", ""), | ||||||
|             state=request.POST.get("state", ""), |             state=request.POST.get("state", ""), | ||||||
|             scope=request.POST.get("scope", "").split(), |             scope=set(request.POST.get("scope", "").split()), | ||||||
|             # PKCE parameter. |             # PKCE parameter. | ||||||
|             code_verifier=request.POST.get("code_verifier"), |             code_verifier=request.POST.get("code_verifier"), | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     def __check_scopes(self): | ||||||
|  |         allowed_scope_names = set( | ||||||
|  |             ScopeMapping.objects.filter(provider__in=[self.provider]).values_list( | ||||||
|  |                 "scope_name", flat=True | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         scopes_to_check = self.scope | ||||||
|  |         if not scopes_to_check.issubset(allowed_scope_names): | ||||||
|  |             LOGGER.info( | ||||||
|  |                 "Application requested scopes not configured, setting to overlap", | ||||||
|  |                 scope_allowed=allowed_scope_names, | ||||||
|  |                 scope_given=self.scope, | ||||||
|  |             ) | ||||||
|  |             self.scope = self.scope.intersection(allowed_scope_names) | ||||||
|  |  | ||||||
|     def __check_policy_access(self, app: Application, request: HttpRequest, **kwargs): |     def __check_policy_access(self, app: Application, request: HttpRequest, **kwargs): | ||||||
|         with start_span( |         with start_span( | ||||||
|             op="authentik.providers.oauth2.token.policy", |             op="authentik.providers.oauth2.token.policy", | ||||||
| @ -149,7 +166,7 @@ class TokenParams: | |||||||
|                     client_id=self.provider.client_id, |                     client_id=self.provider.client_id, | ||||||
|                 ) |                 ) | ||||||
|                 raise TokenError("invalid_client") |                 raise TokenError("invalid_client") | ||||||
|  |         self.__check_scopes() | ||||||
|         if self.grant_type == GRANT_TYPE_AUTHORIZATION_CODE: |         if self.grant_type == GRANT_TYPE_AUTHORIZATION_CODE: | ||||||
|             with start_span( |             with start_span( | ||||||
|                 op="authentik.providers.oauth2.post.parse.code", |                 op="authentik.providers.oauth2.post.parse.code", | ||||||
| @ -179,42 +196,7 @@ class TokenParams: | |||||||
|             LOGGER.warning("Missing authorization code") |             LOGGER.warning("Missing authorization code") | ||||||
|             raise TokenError("invalid_grant") |             raise TokenError("invalid_grant") | ||||||
|  |  | ||||||
|         allowed_redirect_urls = self.provider.redirect_uris.split() |         self.__check_redirect_uri(request) | ||||||
|         # At this point, no provider should have a blank redirect_uri, in case they do |  | ||||||
|         # this will check an empty array and raise an error |  | ||||||
|         try: |  | ||||||
|             if not any(fullmatch(x, self.redirect_uri) for x in allowed_redirect_urls): |  | ||||||
|                 LOGGER.warning( |  | ||||||
|                     "Invalid redirect uri (regex comparison)", |  | ||||||
|                     redirect_uri=self.redirect_uri, |  | ||||||
|                     expected=allowed_redirect_urls, |  | ||||||
|                 ) |  | ||||||
|                 Event.new( |  | ||||||
|                     EventAction.CONFIGURATION_ERROR, |  | ||||||
|                     message="Invalid redirect URI used by provider", |  | ||||||
|                     provider=self.provider, |  | ||||||
|                     redirect_uri=self.redirect_uri, |  | ||||||
|                     expected=allowed_redirect_urls, |  | ||||||
|                 ).from_http(request) |  | ||||||
|                 raise TokenError("invalid_client") |  | ||||||
|         except RegexError as exc: |  | ||||||
|             LOGGER.info("Failed to parse regular expression, checking directly", exc=exc) |  | ||||||
|             if not any(x == self.redirect_uri for x in allowed_redirect_urls): |  | ||||||
|                 LOGGER.warning( |  | ||||||
|                     "Invalid redirect uri (strict comparison)", |  | ||||||
|                     redirect_uri=self.redirect_uri, |  | ||||||
|                     expected=allowed_redirect_urls, |  | ||||||
|                 ) |  | ||||||
|                 Event.new( |  | ||||||
|                     EventAction.CONFIGURATION_ERROR, |  | ||||||
|                     message="Invalid redirect_uri configured", |  | ||||||
|                     provider=self.provider, |  | ||||||
|                 ).from_http(request) |  | ||||||
|                 raise TokenError("invalid_client") from None |  | ||||||
|  |  | ||||||
|         # Check against forbidden schemes |  | ||||||
|         if urlparse(self.redirect_uri).scheme in FORBIDDEN_URI_SCHEMES: |  | ||||||
|             raise TokenError("invalid_request") |  | ||||||
|  |  | ||||||
|         self.authorization_code = AuthorizationCode.objects.filter(code=raw_code).first() |         self.authorization_code = AuthorizationCode.objects.filter(code=raw_code).first() | ||||||
|         if not self.authorization_code: |         if not self.authorization_code: | ||||||
| @ -254,6 +236,48 @@ class TokenParams: | |||||||
|         if not self.authorization_code.code_challenge and self.code_verifier: |         if not self.authorization_code.code_challenge and self.code_verifier: | ||||||
|             raise TokenError("invalid_grant") |             raise TokenError("invalid_grant") | ||||||
|  |  | ||||||
|  |     def __check_redirect_uri(self, request: HttpRequest): | ||||||
|  |         allowed_redirect_urls = self.provider.redirect_uris | ||||||
|  |         # At this point, no provider should have a blank redirect_uri, in case they do | ||||||
|  |         # this will check an empty array and raise an error | ||||||
|  |  | ||||||
|  |         match_found = False | ||||||
|  |         for allowed in allowed_redirect_urls: | ||||||
|  |             if allowed.matching_mode == RedirectURIMatchingMode.STRICT: | ||||||
|  |                 if self.redirect_uri == allowed.url: | ||||||
|  |                     match_found = True | ||||||
|  |                     break | ||||||
|  |             if allowed.matching_mode == RedirectURIMatchingMode.REGEX: | ||||||
|  |                 try: | ||||||
|  |                     if fullmatch(allowed.url, self.redirect_uri): | ||||||
|  |                         match_found = True | ||||||
|  |                         break | ||||||
|  |                 except RegexError as exc: | ||||||
|  |                     LOGGER.warning( | ||||||
|  |                         "Failed to parse regular expression", | ||||||
|  |                         exc=exc, | ||||||
|  |                         url=allowed.url, | ||||||
|  |                         provider=self.provider, | ||||||
|  |                     ) | ||||||
|  |                     Event.new( | ||||||
|  |                         EventAction.CONFIGURATION_ERROR, | ||||||
|  |                         message="Invalid redirect_uri configured", | ||||||
|  |                         provider=self.provider, | ||||||
|  |                     ).from_http(request) | ||||||
|  |         if not match_found: | ||||||
|  |             Event.new( | ||||||
|  |                 EventAction.CONFIGURATION_ERROR, | ||||||
|  |                 message="Invalid redirect URI used by provider", | ||||||
|  |                 provider=self.provider, | ||||||
|  |                 redirect_uri=self.redirect_uri, | ||||||
|  |                 expected=allowed_redirect_urls, | ||||||
|  |             ).from_http(request) | ||||||
|  |             raise TokenError("invalid_client") | ||||||
|  |  | ||||||
|  |         # Check against forbidden schemes | ||||||
|  |         if urlparse(self.redirect_uri).scheme in FORBIDDEN_URI_SCHEMES: | ||||||
|  |             raise TokenError("invalid_request") | ||||||
|  |  | ||||||
|     def __post_init_refresh(self, raw_token: str, request: HttpRequest): |     def __post_init_refresh(self, raw_token: str, request: HttpRequest): | ||||||
|         if not raw_token: |         if not raw_token: | ||||||
|             LOGGER.warning("Missing refresh token") |             LOGGER.warning("Missing refresh token") | ||||||
| @ -433,20 +457,20 @@ class TokenParams: | |||||||
|         app = Application.objects.filter(provider=self.provider).first() |         app = Application.objects.filter(provider=self.provider).first() | ||||||
|         if not app or not app.provider: |         if not app or not app.provider: | ||||||
|             raise TokenError("invalid_grant") |             raise TokenError("invalid_grant") | ||||||
|  |         with audit_ignore(): | ||||||
|             self.user, _ = User.objects.update_or_create( |             self.user, _ = User.objects.update_or_create( | ||||||
|                 # trim username to ensure the entire username is max 150 chars |                 # trim username to ensure the entire username is max 150 chars | ||||||
|                 # (22 chars being the length of the "template") |                 # (22 chars being the length of the "template") | ||||||
|                 username=f"ak-{self.provider.name[:150-22]}-client_credentials", |                 username=f"ak-{self.provider.name[:150-22]}-client_credentials", | ||||||
|                 defaults={ |                 defaults={ | ||||||
|                 "attributes": { |  | ||||||
|                     USER_ATTRIBUTE_GENERATED: True, |  | ||||||
|                 }, |  | ||||||
|                     "last_login": timezone.now(), |                     "last_login": timezone.now(), | ||||||
|                     "name": f"Autogenerated user from application {app.name} (client credentials)", |                     "name": f"Autogenerated user from application {app.name} (client credentials)", | ||||||
|                     "path": f"{USER_PATH_SYSTEM_PREFIX}/apps/{app.slug}", |                     "path": f"{USER_PATH_SYSTEM_PREFIX}/apps/{app.slug}", | ||||||
|                     "type": UserTypes.SERVICE_ACCOUNT, |                     "type": UserTypes.SERVICE_ACCOUNT, | ||||||
|                 }, |                 }, | ||||||
|             ) |             ) | ||||||
|  |             self.user.attributes[USER_ATTRIBUTE_GENERATED] = True | ||||||
|  |             self.user.save() | ||||||
|         self.__check_policy_access(app, request) |         self.__check_policy_access(app, request) | ||||||
|  |  | ||||||
|         Event.new( |         Event.new( | ||||||
| @ -470,9 +494,6 @@ class TokenParams: | |||||||
|             self.user, created = User.objects.update_or_create( |             self.user, created = User.objects.update_or_create( | ||||||
|                 username=f"{self.provider.name}-{token.get('sub')}", |                 username=f"{self.provider.name}-{token.get('sub')}", | ||||||
|                 defaults={ |                 defaults={ | ||||||
|                     "attributes": { |  | ||||||
|                         USER_ATTRIBUTE_GENERATED: True, |  | ||||||
|                     }, |  | ||||||
|                     "last_login": timezone.now(), |                     "last_login": timezone.now(), | ||||||
|                     "name": ( |                     "name": ( | ||||||
|                         f"Autogenerated user from application {app.name} (client credentials JWT)" |                         f"Autogenerated user from application {app.name} (client credentials JWT)" | ||||||
| @ -481,6 +502,8 @@ class TokenParams: | |||||||
|                     "type": UserTypes.SERVICE_ACCOUNT, |                     "type": UserTypes.SERVICE_ACCOUNT, | ||||||
|                 }, |                 }, | ||||||
|             ) |             ) | ||||||
|  |             self.user.attributes[USER_ATTRIBUTE_GENERATED] = True | ||||||
|  |             self.user.save() | ||||||
|             exp = token.get("exp") |             exp = token.get("exp") | ||||||
|             if created and exp: |             if created and exp: | ||||||
|                 self.user.attributes[USER_ATTRIBUTE_EXPIRES] = exp |                 self.user.attributes[USER_ATTRIBUTE_EXPIRES] = exp | ||||||
| @ -498,7 +521,7 @@ class TokenView(View): | |||||||
|         response = super().dispatch(request, *args, **kwargs) |         response = super().dispatch(request, *args, **kwargs) | ||||||
|         allowed_origins = [] |         allowed_origins = [] | ||||||
|         if self.provider: |         if self.provider: | ||||||
|             allowed_origins = self.provider.redirect_uris.split("\n") |             allowed_origins = [x.url for x in self.provider.redirect_uris] | ||||||
|         cors_allow(self.request, response, *allowed_origins) |         cors_allow(self.request, response, *allowed_origins) | ||||||
|         return response |         return response | ||||||
|  |  | ||||||
| @ -551,7 +574,7 @@ class TokenView(View): | |||||||
|             # Keep same scopes as previous token |             # Keep same scopes as previous token | ||||||
|             scope=self.params.authorization_code.scope, |             scope=self.params.authorization_code.scope, | ||||||
|             auth_time=self.params.authorization_code.auth_time, |             auth_time=self.params.authorization_code.auth_time, | ||||||
|             session_id=self.params.authorization_code.session_id, |             session=self.params.authorization_code.session, | ||||||
|         ) |         ) | ||||||
|         access_id_token = IDToken.new( |         access_id_token = IDToken.new( | ||||||
|             self.provider, |             self.provider, | ||||||
| @ -579,7 +602,7 @@ class TokenView(View): | |||||||
|                 expires=refresh_token_expiry, |                 expires=refresh_token_expiry, | ||||||
|                 provider=self.provider, |                 provider=self.provider, | ||||||
|                 auth_time=self.params.authorization_code.auth_time, |                 auth_time=self.params.authorization_code.auth_time, | ||||||
|                 session_id=self.params.authorization_code.session_id, |                 session=self.params.authorization_code.session, | ||||||
|             ) |             ) | ||||||
|             id_token = IDToken.new( |             id_token = IDToken.new( | ||||||
|                 self.provider, |                 self.provider, | ||||||
| @ -612,7 +635,7 @@ class TokenView(View): | |||||||
|             # Keep same scopes as previous token |             # Keep same scopes as previous token | ||||||
|             scope=self.params.refresh_token.scope, |             scope=self.params.refresh_token.scope, | ||||||
|             auth_time=self.params.refresh_token.auth_time, |             auth_time=self.params.refresh_token.auth_time, | ||||||
|             session_id=self.params.refresh_token.session_id, |             session=self.params.refresh_token.session, | ||||||
|         ) |         ) | ||||||
|         access_token.id_token = IDToken.new( |         access_token.id_token = IDToken.new( | ||||||
|             self.provider, |             self.provider, | ||||||
| @ -628,7 +651,7 @@ class TokenView(View): | |||||||
|             expires=refresh_token_expiry, |             expires=refresh_token_expiry, | ||||||
|             provider=self.provider, |             provider=self.provider, | ||||||
|             auth_time=self.params.refresh_token.auth_time, |             auth_time=self.params.refresh_token.auth_time, | ||||||
|             session_id=self.params.refresh_token.session_id, |             session=self.params.refresh_token.session, | ||||||
|         ) |         ) | ||||||
|         id_token = IDToken.new( |         id_token = IDToken.new( | ||||||
|             self.provider, |             self.provider, | ||||||
| @ -686,13 +709,14 @@ class TokenView(View): | |||||||
|             raise DeviceCodeError("authorization_pending") |             raise DeviceCodeError("authorization_pending") | ||||||
|         now = timezone.now() |         now = timezone.now() | ||||||
|         access_token_expiry = now + timedelta_from_string(self.provider.access_token_validity) |         access_token_expiry = now + timedelta_from_string(self.provider.access_token_validity) | ||||||
|         auth_event = get_login_event(self.request) |         auth_event = get_login_event(self.params.device_code.session) | ||||||
|         access_token = AccessToken( |         access_token = AccessToken( | ||||||
|             provider=self.provider, |             provider=self.provider, | ||||||
|             user=self.params.device_code.user, |             user=self.params.device_code.user, | ||||||
|             expires=access_token_expiry, |             expires=access_token_expiry, | ||||||
|             scope=self.params.device_code.scope, |             scope=self.params.device_code.scope, | ||||||
|             auth_time=auth_event.created if auth_event else now, |             auth_time=auth_event.created if auth_event else now, | ||||||
|  |             session=self.params.device_code.session, | ||||||
|         ) |         ) | ||||||
|         access_token.id_token = IDToken.new( |         access_token.id_token = IDToken.new( | ||||||
|             self.provider, |             self.provider, | ||||||
| @ -710,7 +734,7 @@ class TokenView(View): | |||||||
|             "id_token": access_token.id_token.to_jwt(self.provider), |             "id_token": access_token.id_token.to_jwt(self.provider), | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         if SCOPE_OFFLINE_ACCESS in self.params.scope: |         if SCOPE_OFFLINE_ACCESS in self.params.device_code.scope: | ||||||
|             refresh_token_expiry = now + timedelta_from_string(self.provider.refresh_token_validity) |             refresh_token_expiry = now + timedelta_from_string(self.provider.refresh_token_validity) | ||||||
|             refresh_token = RefreshToken( |             refresh_token = RefreshToken( | ||||||
|                 user=self.params.device_code.user, |                 user=self.params.device_code.user, | ||||||
|  | |||||||
| @ -108,7 +108,7 @@ class UserInfoView(View): | |||||||
|         response = super().dispatch(request, *args, **kwargs) |         response = super().dispatch(request, *args, **kwargs) | ||||||
|         allowed_origins = [] |         allowed_origins = [] | ||||||
|         if self.token: |         if self.token: | ||||||
|             allowed_origins = self.token.provider.redirect_uris.split("\n") |             allowed_origins = [x.url for x in self.token.provider.redirect_uris] | ||||||
|         cors_allow(self.request, response, *allowed_origins) |         cors_allow(self.request, response, *allowed_origins) | ||||||
|         return response |         return response | ||||||
|  |  | ||||||
|  | |||||||
| @ -121,7 +121,6 @@ class ProxyProviderViewSet(UsedByMixin, ModelViewSet): | |||||||
|         "basic_auth_password_attribute": ["iexact"], |         "basic_auth_password_attribute": ["iexact"], | ||||||
|         "basic_auth_user_attribute": ["iexact"], |         "basic_auth_user_attribute": ["iexact"], | ||||||
|         "mode": ["iexact"], |         "mode": ["iexact"], | ||||||
|         "redirect_uris": ["iexact"], |  | ||||||
|         "cookie_domain": ["iexact"], |         "cookie_domain": ["iexact"], | ||||||
|     } |     } | ||||||
|     search_fields = ["name"] |     search_fields = ["name"] | ||||||
|  | |||||||
| @ -28,7 +28,7 @@ class ProxyDockerController(DockerController): | |||||||
|         labels = super()._get_labels() |         labels = super()._get_labels() | ||||||
|         labels["traefik.enable"] = "true" |         labels["traefik.enable"] = "true" | ||||||
|         labels[f"traefik.http.routers.{traefik_name}-router.rule"] = ( |         labels[f"traefik.http.routers.{traefik_name}-router.rule"] = ( | ||||||
|             f"({' || '.join([f'Host(`{host}`)' for host in hosts])})" |             f"({' || '.join([f'Host({host})' for host in hosts])})" | ||||||
|             f" && PathPrefix(`/outpost.goauthentik.io`)" |             f" && PathPrefix(`/outpost.goauthentik.io`)" | ||||||
|         ) |         ) | ||||||
|         labels[f"traefik.http.routers.{traefik_name}-router.tls"] = "true" |         labels[f"traefik.http.routers.{traefik_name}-router.tls"] = "true" | ||||||
|  | |||||||
| @ -13,7 +13,13 @@ from rest_framework.serializers import Serializer | |||||||
| from authentik.crypto.models import CertificateKeyPair | from authentik.crypto.models import CertificateKeyPair | ||||||
| from authentik.lib.models import DomainlessURLValidator | from authentik.lib.models import DomainlessURLValidator | ||||||
| from authentik.outposts.models import OutpostModel | from authentik.outposts.models import OutpostModel | ||||||
| from authentik.providers.oauth2.models import ClientTypes, OAuth2Provider, ScopeMapping | from authentik.providers.oauth2.models import ( | ||||||
|  |     ClientTypes, | ||||||
|  |     OAuth2Provider, | ||||||
|  |     RedirectURI, | ||||||
|  |     RedirectURIMatchingMode, | ||||||
|  |     ScopeMapping, | ||||||
|  | ) | ||||||
|  |  | ||||||
| SCOPE_AK_PROXY = "ak_proxy" | SCOPE_AK_PROXY = "ak_proxy" | ||||||
| OUTPOST_CALLBACK_SIGNATURE = "X-authentik-auth-callback" | OUTPOST_CALLBACK_SIGNATURE = "X-authentik-auth-callback" | ||||||
| @ -24,14 +30,15 @@ def get_cookie_secret(): | |||||||
|     return "".join(SystemRandom().choice(string.ascii_uppercase + string.digits) for _ in range(32)) |     return "".join(SystemRandom().choice(string.ascii_uppercase + string.digits) for _ in range(32)) | ||||||
|  |  | ||||||
|  |  | ||||||
| def _get_callback_url(uri: str) -> str: | def _get_callback_url(uri: str) -> list[RedirectURI]: | ||||||
|     return "\n".join( |     return [ | ||||||
|         [ |         RedirectURI( | ||||||
|  |             RedirectURIMatchingMode.STRICT, | ||||||
|             urljoin(uri, "outpost.goauthentik.io/callback") |             urljoin(uri, "outpost.goauthentik.io/callback") | ||||||
|             + f"\\?{OUTPOST_CALLBACK_SIGNATURE}=true", |             + f"\\?{OUTPOST_CALLBACK_SIGNATURE}=true", | ||||||
|             uri + f"\\?{OUTPOST_CALLBACK_SIGNATURE}=true", |         ), | ||||||
|  |         RedirectURI(RedirectURIMatchingMode.STRICT, uri + f"\\?{OUTPOST_CALLBACK_SIGNATURE}=true"), | ||||||
|     ] |     ] | ||||||
|     ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ProxyMode(models.TextChoices): | class ProxyMode(models.TextChoices): | ||||||
|  | |||||||
| @ -1,13 +1,12 @@ | |||||||
| """proxy provider tasks""" | """proxy provider tasks""" | ||||||
|  |  | ||||||
| from hashlib import sha256 |  | ||||||
|  |  | ||||||
| from asgiref.sync import async_to_sync | from asgiref.sync import async_to_sync | ||||||
| from channels.layers import get_channel_layer | from channels.layers import get_channel_layer | ||||||
| from django.db import DatabaseError, InternalError, ProgrammingError | from django.db import DatabaseError, InternalError, ProgrammingError | ||||||
|  |  | ||||||
| from authentik.outposts.consumer import OUTPOST_GROUP | from authentik.outposts.consumer import OUTPOST_GROUP | ||||||
| from authentik.outposts.models import Outpost, OutpostType | from authentik.outposts.models import Outpost, OutpostType | ||||||
|  | from authentik.providers.oauth2.id_token import hash_session_key | ||||||
| from authentik.providers.proxy.models import ProxyProvider | from authentik.providers.proxy.models import ProxyProvider | ||||||
| from authentik.root.celery import CELERY_APP | from authentik.root.celery import CELERY_APP | ||||||
|  |  | ||||||
| @ -26,7 +25,7 @@ def proxy_set_defaults(): | |||||||
| def proxy_on_logout(session_id: str): | def proxy_on_logout(session_id: str): | ||||||
|     """Update outpost instances connected to a single outpost""" |     """Update outpost instances connected to a single outpost""" | ||||||
|     layer = get_channel_layer() |     layer = get_channel_layer() | ||||||
|     hashed_session_id = sha256(session_id.encode("ascii")).hexdigest() |     hashed_session_id = hash_session_key(session_id) | ||||||
|     for outpost in Outpost.objects.filter(type=OutpostType.PROXY): |     for outpost in Outpost.objects.filter(type=OutpostType.PROXY): | ||||||
|         group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)} |         group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)} | ||||||
|         async_to_sync(layer.group_send)( |         async_to_sync(layer.group_send)( | ||||||
|  | |||||||
| @ -164,7 +164,7 @@ class SAMLProvider(Provider): | |||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     sign_assertion = models.BooleanField(default=True) |     sign_assertion = models.BooleanField(default=True) | ||||||
|     sign_response = models.BooleanField(default=True) |     sign_response = models.BooleanField(default=False) | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def launch_url(self) -> str | None: |     def launch_url(self) -> str | None: | ||||||
|  | |||||||
| @ -50,6 +50,7 @@ class AssertionProcessor: | |||||||
|  |  | ||||||
|     _issue_instant: str |     _issue_instant: str | ||||||
|     _assertion_id: str |     _assertion_id: str | ||||||
|  |     _response_id: str | ||||||
|  |  | ||||||
|     _valid_not_before: str |     _valid_not_before: str | ||||||
|     _session_not_on_or_after: str |     _session_not_on_or_after: str | ||||||
| @ -62,6 +63,7 @@ class AssertionProcessor: | |||||||
|  |  | ||||||
|         self._issue_instant = get_time_string() |         self._issue_instant = get_time_string() | ||||||
|         self._assertion_id = get_random_id() |         self._assertion_id = get_random_id() | ||||||
|  |         self._response_id = get_random_id() | ||||||
|  |  | ||||||
|         self._valid_not_before = get_time_string( |         self._valid_not_before = get_time_string( | ||||||
|             timedelta_from_string(self.provider.assertion_valid_not_before) |             timedelta_from_string(self.provider.assertion_valid_not_before) | ||||||
| @ -130,7 +132,9 @@ class AssertionProcessor: | |||||||
|         """Generate AuthnStatement with AuthnContext and ContextClassRef Elements.""" |         """Generate AuthnStatement with AuthnContext and ContextClassRef Elements.""" | ||||||
|         auth_n_statement = Element(f"{{{NS_SAML_ASSERTION}}}AuthnStatement") |         auth_n_statement = Element(f"{{{NS_SAML_ASSERTION}}}AuthnStatement") | ||||||
|         auth_n_statement.attrib["AuthnInstant"] = self._valid_not_before |         auth_n_statement.attrib["AuthnInstant"] = self._valid_not_before | ||||||
|         auth_n_statement.attrib["SessionIndex"] = self._assertion_id |         auth_n_statement.attrib["SessionIndex"] = sha256( | ||||||
|  |             self.http_request.session.session_key.encode("ascii") | ||||||
|  |         ).hexdigest() | ||||||
|         auth_n_statement.attrib["SessionNotOnOrAfter"] = self._session_not_on_or_after |         auth_n_statement.attrib["SessionNotOnOrAfter"] = self._session_not_on_or_after | ||||||
|  |  | ||||||
|         auth_n_context = SubElement(auth_n_statement, f"{{{NS_SAML_ASSERTION}}}AuthnContext") |         auth_n_context = SubElement(auth_n_statement, f"{{{NS_SAML_ASSERTION}}}AuthnContext") | ||||||
| @ -285,7 +289,7 @@ class AssertionProcessor: | |||||||
|         response.attrib["Version"] = "2.0" |         response.attrib["Version"] = "2.0" | ||||||
|         response.attrib["IssueInstant"] = self._issue_instant |         response.attrib["IssueInstant"] = self._issue_instant | ||||||
|         response.attrib["Destination"] = self.provider.acs_url |         response.attrib["Destination"] = self.provider.acs_url | ||||||
|         response.attrib["ID"] = get_random_id() |         response.attrib["ID"] = self._response_id | ||||||
|         if self.auth_n_request.id: |         if self.auth_n_request.id: | ||||||
|             response.attrib["InResponseTo"] = self.auth_n_request.id |             response.attrib["InResponseTo"] = self.auth_n_request.id | ||||||
|  |  | ||||||
| @ -308,7 +312,7 @@ class AssertionProcessor: | |||||||
|         ref = xmlsec.template.add_reference( |         ref = xmlsec.template.add_reference( | ||||||
|             signature_node, |             signature_node, | ||||||
|             digest_algorithm_transform, |             digest_algorithm_transform, | ||||||
|             uri="#" + self._assertion_id, |             uri="#" + element.attrib["ID"], | ||||||
|         ) |         ) | ||||||
|         xmlsec.template.add_transform(ref, xmlsec.constants.TransformEnveloped) |         xmlsec.template.add_transform(ref, xmlsec.constants.TransformEnveloped) | ||||||
|         xmlsec.template.add_transform(ref, xmlsec.constants.TransformExclC14N) |         xmlsec.template.add_transform(ref, xmlsec.constants.TransformExclC14N) | ||||||
|  | |||||||
| @ -180,6 +180,10 @@ class TestAuthNRequest(TestCase): | |||||||
|         # Now create a response and convert it to string (provider) |         # Now create a response and convert it to string (provider) | ||||||
|         response_proc = AssertionProcessor(self.provider, http_request, parsed_request) |         response_proc = AssertionProcessor(self.provider, http_request, parsed_request) | ||||||
|         response = response_proc.build_response() |         response = response_proc.build_response() | ||||||
|  |         # Ensure both response and assertion ID are in the response twice (once as ID attribute, | ||||||
|  |         # once as ds:Reference URI) | ||||||
|  |         self.assertEqual(response.count(response_proc._assertion_id), 2) | ||||||
|  |         self.assertEqual(response.count(response_proc._response_id), 2) | ||||||
|  |  | ||||||
|         # Now parse the response (source) |         # Now parse the response (source) | ||||||
|         http_request.POST = QueryDict(mutable=True) |         http_request.POST = QueryDict(mutable=True) | ||||||
|  | |||||||
| @ -54,7 +54,11 @@ class TestServiceProviderMetadataParser(TestCase): | |||||||
|         request = self.factory.get("/") |         request = self.factory.get("/") | ||||||
|         metadata = lxml_from_string(MetadataProcessor(provider, request).build_entity_descriptor()) |         metadata = lxml_from_string(MetadataProcessor(provider, request).build_entity_descriptor()) | ||||||
|  |  | ||||||
|         schema = etree.XMLSchema(etree.parse("schemas/saml-schema-metadata-2.0.xsd"))  # nosec |         schema = etree.XMLSchema( | ||||||
|  |             etree.parse( | ||||||
|  |                 source="schemas/saml-schema-metadata-2.0.xsd", parser=etree.XMLParser() | ||||||
|  |             )  # nosec | ||||||
|  |         ) | ||||||
|         self.assertTrue(schema.validate(metadata)) |         self.assertTrue(schema.validate(metadata)) | ||||||
|  |  | ||||||
|     def test_schema_want_authn_requests_signed(self): |     def test_schema_want_authn_requests_signed(self): | ||||||
|  | |||||||
| @ -47,7 +47,9 @@ class TestSchema(TestCase): | |||||||
|  |  | ||||||
|         metadata = lxml_from_string(request) |         metadata = lxml_from_string(request) | ||||||
|  |  | ||||||
|         schema = etree.XMLSchema(etree.parse("schemas/saml-schema-protocol-2.0.xsd"))  # nosec |         schema = etree.XMLSchema( | ||||||
|  |             etree.parse("schemas/saml-schema-protocol-2.0.xsd", parser=etree.XMLParser())  # nosec | ||||||
|  |         ) | ||||||
|         self.assertTrue(schema.validate(metadata)) |         self.assertTrue(schema.validate(metadata)) | ||||||
|  |  | ||||||
|     def test_response_schema(self): |     def test_response_schema(self): | ||||||
| @ -68,5 +70,7 @@ class TestSchema(TestCase): | |||||||
|  |  | ||||||
|         metadata = lxml_from_string(response) |         metadata = lxml_from_string(response) | ||||||
|  |  | ||||||
|         schema = etree.XMLSchema(etree.parse("schemas/saml-schema-protocol-2.0.xsd"))  # nosec |         schema = etree.XMLSchema( | ||||||
|  |             etree.parse("schemas/saml-schema-protocol-2.0.xsd", parser=etree.XMLParser())  # nosec | ||||||
|  |         ) | ||||||
|         self.assertTrue(schema.validate(metadata)) |         self.assertTrue(schema.validate(metadata)) | ||||||
|  | |||||||
| @ -2,9 +2,10 @@ | |||||||
|  |  | ||||||
| from itertools import batched | from itertools import batched | ||||||
|  |  | ||||||
|  | from django.db import transaction | ||||||
| from pydantic import ValidationError | from pydantic import ValidationError | ||||||
| from pydanticscim.group import GroupMember | from pydanticscim.group import GroupMember | ||||||
| from pydanticscim.responses import PatchOp, PatchOperation | from pydanticscim.responses import PatchOp | ||||||
|  |  | ||||||
| from authentik.core.models import Group | from authentik.core.models import Group | ||||||
| from authentik.lib.sync.mapper import PropertyMappingManager | from authentik.lib.sync.mapper import PropertyMappingManager | ||||||
| @ -19,7 +20,7 @@ from authentik.providers.scim.clients.base import SCIMClient | |||||||
| from authentik.providers.scim.clients.exceptions import ( | from authentik.providers.scim.clients.exceptions import ( | ||||||
|     SCIMRequestException, |     SCIMRequestException, | ||||||
| ) | ) | ||||||
| from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchRequest | from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchOperation, PatchRequest | ||||||
| from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema | from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema | ||||||
| from authentik.providers.scim.models import ( | from authentik.providers.scim.models import ( | ||||||
|     SCIMMapping, |     SCIMMapping, | ||||||
| @ -104,13 +105,47 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]): | |||||||
|             provider=self.provider, group=group, scim_id=scim_id |             provider=self.provider, group=group, scim_id=scim_id | ||||||
|         ) |         ) | ||||||
|         users = list(group.users.order_by("id").values_list("id", flat=True)) |         users = list(group.users.order_by("id").values_list("id", flat=True)) | ||||||
|         self._patch_add_users(group, users) |         self._patch_add_users(connection, users) | ||||||
|         return connection |         return connection | ||||||
|  |  | ||||||
|     def update(self, group: Group, connection: SCIMProviderGroup): |     def update(self, group: Group, connection: SCIMProviderGroup): | ||||||
|         """Update existing group""" |         """Update existing group""" | ||||||
|         scim_group = self.to_schema(group, connection) |         scim_group = self.to_schema(group, connection) | ||||||
|         scim_group.id = connection.scim_id |         scim_group.id = connection.scim_id | ||||||
|  |         try: | ||||||
|  |             if self._config.patch.supported: | ||||||
|  |                 return self._update_patch(group, scim_group, connection) | ||||||
|  |             return self._update_put(group, scim_group, connection) | ||||||
|  |         except NotFoundSyncException: | ||||||
|  |             # Resource missing is handled by self.write, which will re-create the group | ||||||
|  |             raise | ||||||
|  |  | ||||||
|  |     def _update_patch( | ||||||
|  |         self, group: Group, scim_group: SCIMGroupSchema, connection: SCIMProviderGroup | ||||||
|  |     ): | ||||||
|  |         """Update a group via PATCH request""" | ||||||
|  |         # Patch group's attributes instead of replacing it and re-adding users if we can | ||||||
|  |         self._request( | ||||||
|  |             "PATCH", | ||||||
|  |             f"/Groups/{connection.scim_id}", | ||||||
|  |             json=PatchRequest( | ||||||
|  |                 Operations=[ | ||||||
|  |                     PatchOperation( | ||||||
|  |                         op=PatchOp.replace, | ||||||
|  |                         path=None, | ||||||
|  |                         value=scim_group.model_dump(mode="json", exclude_unset=True), | ||||||
|  |                     ) | ||||||
|  |                 ] | ||||||
|  |             ).model_dump( | ||||||
|  |                 mode="json", | ||||||
|  |                 exclude_unset=True, | ||||||
|  |                 exclude_none=True, | ||||||
|  |             ), | ||||||
|  |         ) | ||||||
|  |         return self.patch_compare_users(group) | ||||||
|  |  | ||||||
|  |     def _update_put(self, group: Group, scim_group: SCIMGroupSchema, connection: SCIMProviderGroup): | ||||||
|  |         """Update a group via PUT request""" | ||||||
|         try: |         try: | ||||||
|             self._request( |             self._request( | ||||||
|                 "PUT", |                 "PUT", | ||||||
| @ -120,33 +155,25 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]): | |||||||
|                     exclude_unset=True, |                     exclude_unset=True, | ||||||
|                 ), |                 ), | ||||||
|             ) |             ) | ||||||
|             users = list(group.users.order_by("id").values_list("id", flat=True)) |             return self.patch_compare_users(group) | ||||||
|             return self._patch_add_users(group, users) |  | ||||||
|         except NotFoundSyncException: |  | ||||||
|             # Resource missing is handled by self.write, which will re-create the group |  | ||||||
|             raise |  | ||||||
|         except (SCIMRequestException, ObjectExistsSyncException): |         except (SCIMRequestException, ObjectExistsSyncException): | ||||||
|             # Some providers don't support PUT on groups, so this is mainly a fix for the initial |             # Some providers don't support PUT on groups, so this is mainly a fix for the initial | ||||||
|             # sync, send patch add requests for all the users the group currently has |             # sync, send patch add requests for all the users the group currently has | ||||||
|             users = list(group.users.order_by("id").values_list("id", flat=True)) |             return self._update_patch(group, scim_group, connection) | ||||||
|             self._patch_add_users(group, users) |  | ||||||
|             # Also update the group name |  | ||||||
|             return self._patch( |  | ||||||
|                 scim_group.id, |  | ||||||
|                 PatchOperation( |  | ||||||
|                     op=PatchOp.replace, |  | ||||||
|                     path="displayName", |  | ||||||
|                     value=scim_group.displayName, |  | ||||||
|                 ), |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|     def update_group(self, group: Group, action: Direction, users_set: set[int]): |     def update_group(self, group: Group, action: Direction, users_set: set[int]): | ||||||
|         """Update a group, either using PUT to replace it or PATCH if supported""" |         """Update a group, either using PUT to replace it or PATCH if supported""" | ||||||
|  |         scim_group = SCIMProviderGroup.objects.filter(provider=self.provider, group=group).first() | ||||||
|  |         if not scim_group: | ||||||
|  |             self.logger.warning( | ||||||
|  |                 "could not sync group membership, group does not exist", group=group | ||||||
|  |             ) | ||||||
|  |             return | ||||||
|         if self._config.patch.supported: |         if self._config.patch.supported: | ||||||
|             if action == Direction.add: |             if action == Direction.add: | ||||||
|                 return self._patch_add_users(group, users_set) |                 return self._patch_add_users(scim_group, users_set) | ||||||
|             if action == Direction.remove: |             if action == Direction.remove: | ||||||
|                 return self._patch_remove_users(group, users_set) |                 return self._patch_remove_users(scim_group, users_set) | ||||||
|         try: |         try: | ||||||
|             return self.write(group) |             return self.write(group) | ||||||
|         except SCIMRequestException as exc: |         except SCIMRequestException as exc: | ||||||
| @ -154,19 +181,24 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]): | |||||||
|                 # Assume that provider does not support PUT and also doesn't support |                 # Assume that provider does not support PUT and also doesn't support | ||||||
|                 # ServiceProviderConfig, so try PATCH as a fallback |                 # ServiceProviderConfig, so try PATCH as a fallback | ||||||
|                 if action == Direction.add: |                 if action == Direction.add: | ||||||
|                     return self._patch_add_users(group, users_set) |                     return self._patch_add_users(scim_group, users_set) | ||||||
|                 if action == Direction.remove: |                 if action == Direction.remove: | ||||||
|                     return self._patch_remove_users(group, users_set) |                     return self._patch_remove_users(scim_group, users_set) | ||||||
|             raise exc |             raise exc | ||||||
|  |  | ||||||
|     def _patch( |     def _patch_chunked( | ||||||
|         self, |         self, | ||||||
|         group_id: str, |         group_id: str, | ||||||
|         *ops: PatchOperation, |         *ops: PatchOperation, | ||||||
|     ): |     ): | ||||||
|  |         """Helper function that chunks patch requests based on the maxOperations attribute. | ||||||
|  |         This is not strictly according to specs but there's nothing in the schema that allows the | ||||||
|  |         us to know what the maximum patch operations per request should be.""" | ||||||
|         chunk_size = self._config.bulk.maxOperations |         chunk_size = self._config.bulk.maxOperations | ||||||
|         if chunk_size < 1: |         if chunk_size < 1: | ||||||
|             chunk_size = len(ops) |             chunk_size = len(ops) | ||||||
|  |         if len(ops) < 1: | ||||||
|  |             return | ||||||
|         for chunk in batched(ops, chunk_size): |         for chunk in batched(ops, chunk_size): | ||||||
|             req = PatchRequest(Operations=list(chunk)) |             req = PatchRequest(Operations=list(chunk)) | ||||||
|             self._request( |             self._request( | ||||||
| @ -177,16 +209,70 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]): | |||||||
|                 ), |                 ), | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|     def _patch_add_users(self, group: Group, users_set: set[int]): |     @transaction.atomic | ||||||
|         """Add users in users_set to group""" |     def patch_compare_users(self, group: Group): | ||||||
|         if len(users_set) < 1: |         """Compare users with a SCIM group and add/remove any differences""" | ||||||
|             return |         # Get scim group first | ||||||
|         scim_group = SCIMProviderGroup.objects.filter(provider=self.provider, group=group).first() |         scim_group = SCIMProviderGroup.objects.filter(provider=self.provider, group=group).first() | ||||||
|         if not scim_group: |         if not scim_group: | ||||||
|             self.logger.warning( |             self.logger.warning( | ||||||
|                 "could not sync group membership, group does not exist", group=group |                 "could not sync group membership, group does not exist", group=group | ||||||
|             ) |             ) | ||||||
|             return |             return | ||||||
|  |         # Get a list of all users in the authentik group | ||||||
|  |         raw_users_should = list(group.users.order_by("id").values_list("id", flat=True)) | ||||||
|  |         # Lookup the SCIM IDs of the users | ||||||
|  |         users_should: list[str] = list( | ||||||
|  |             SCIMProviderUser.objects.filter( | ||||||
|  |                 user__pk__in=raw_users_should, provider=self.provider | ||||||
|  |             ).values_list("scim_id", flat=True) | ||||||
|  |         ) | ||||||
|  |         if len(raw_users_should) != len(users_should): | ||||||
|  |             self.logger.warning( | ||||||
|  |                 "User count mismatch, not all users in the group are synced to SCIM yet.", | ||||||
|  |                 group=group, | ||||||
|  |             ) | ||||||
|  |         # Get current group status | ||||||
|  |         current_group = SCIMGroupSchema.model_validate( | ||||||
|  |             self._request("GET", f"/Groups/{scim_group.scim_id}") | ||||||
|  |         ) | ||||||
|  |         users_to_add = [] | ||||||
|  |         users_to_remove = [] | ||||||
|  |         # Check users currently in group and if they shouldn't be in the group and remove them | ||||||
|  |         for user in current_group.members or []: | ||||||
|  |             if user.value not in users_should: | ||||||
|  |                 users_to_remove.append(user.value) | ||||||
|  |         # Check users that should be in the group and add them | ||||||
|  |         for user in users_should: | ||||||
|  |             if len([x for x in current_group.members if x.value == user]) < 1: | ||||||
|  |                 users_to_add.append(user) | ||||||
|  |         # Only send request if we need to make changes | ||||||
|  |         if len(users_to_add) < 1 and len(users_to_remove) < 1: | ||||||
|  |             return | ||||||
|  |         return self._patch_chunked( | ||||||
|  |             scim_group.scim_id, | ||||||
|  |             *[ | ||||||
|  |                 PatchOperation( | ||||||
|  |                     op=PatchOp.add, | ||||||
|  |                     path="members", | ||||||
|  |                     value=[{"value": x}], | ||||||
|  |                 ) | ||||||
|  |                 for x in users_to_add | ||||||
|  |             ], | ||||||
|  |             *[ | ||||||
|  |                 PatchOperation( | ||||||
|  |                     op=PatchOp.remove, | ||||||
|  |                     path="members", | ||||||
|  |                     value=[{"value": x}], | ||||||
|  |                 ) | ||||||
|  |                 for x in users_to_remove | ||||||
|  |             ], | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def _patch_add_users(self, scim_group: SCIMProviderGroup, users_set: set[int]): | ||||||
|  |         """Add users in users_set to group""" | ||||||
|  |         if len(users_set) < 1: | ||||||
|  |             return | ||||||
|         user_ids = list( |         user_ids = list( | ||||||
|             SCIMProviderUser.objects.filter( |             SCIMProviderUser.objects.filter( | ||||||
|                 user__pk__in=users_set, provider=self.provider |                 user__pk__in=users_set, provider=self.provider | ||||||
| @ -194,7 +280,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]): | |||||||
|         ) |         ) | ||||||
|         if len(user_ids) < 1: |         if len(user_ids) < 1: | ||||||
|             return |             return | ||||||
|         self._patch( |         self._patch_chunked( | ||||||
|             scim_group.scim_id, |             scim_group.scim_id, | ||||||
|             *[ |             *[ | ||||||
|                 PatchOperation( |                 PatchOperation( | ||||||
| @ -206,16 +292,10 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]): | |||||||
|             ], |             ], | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def _patch_remove_users(self, group: Group, users_set: set[int]): |     def _patch_remove_users(self, scim_group: SCIMProviderGroup, users_set: set[int]): | ||||||
|         """Remove users in users_set from group""" |         """Remove users in users_set from group""" | ||||||
|         if len(users_set) < 1: |         if len(users_set) < 1: | ||||||
|             return |             return | ||||||
|         scim_group = SCIMProviderGroup.objects.filter(provider=self.provider, group=group).first() |  | ||||||
|         if not scim_group: |  | ||||||
|             self.logger.warning( |  | ||||||
|                 "could not sync group membership, group does not exist", group=group |  | ||||||
|             ) |  | ||||||
|             return |  | ||||||
|         user_ids = list( |         user_ids = list( | ||||||
|             SCIMProviderUser.objects.filter( |             SCIMProviderUser.objects.filter( | ||||||
|                 user__pk__in=users_set, provider=self.provider |                 user__pk__in=users_set, provider=self.provider | ||||||
| @ -223,7 +303,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]): | |||||||
|         ) |         ) | ||||||
|         if len(user_ids) < 1: |         if len(user_ids) < 1: | ||||||
|             return |             return | ||||||
|         self._patch( |         self._patch_chunked( | ||||||
|             scim_group.scim_id, |             scim_group.scim_id, | ||||||
|             *[ |             *[ | ||||||
|                 PatchOperation( |                 PatchOperation( | ||||||
|  | |||||||
| @ -2,6 +2,7 @@ | |||||||
|  |  | ||||||
| from pydantic import Field | from pydantic import Field | ||||||
| from pydanticscim.group import Group as BaseGroup | from pydanticscim.group import Group as BaseGroup | ||||||
|  | from pydanticscim.responses import PatchOperation as BasePatchOperation | ||||||
| from pydanticscim.responses import PatchRequest as BasePatchRequest | from pydanticscim.responses import PatchRequest as BasePatchRequest | ||||||
| from pydanticscim.responses import SCIMError as BaseSCIMError | from pydanticscim.responses import SCIMError as BaseSCIMError | ||||||
| from pydanticscim.service_provider import Bulk as BaseBulk | from pydanticscim.service_provider import Bulk as BaseBulk | ||||||
| @ -68,6 +69,12 @@ class PatchRequest(BasePatchRequest): | |||||||
|     schemas: tuple[str] = ("urn:ietf:params:scim:api:messages:2.0:PatchOp",) |     schemas: tuple[str] = ("urn:ietf:params:scim:api:messages:2.0:PatchOp",) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class PatchOperation(BasePatchOperation): | ||||||
|  |     """PatchOperation with optional path""" | ||||||
|  |  | ||||||
|  |     path: str | None | ||||||
|  |  | ||||||
|  |  | ||||||
| class SCIMError(BaseSCIMError): | class SCIMError(BaseSCIMError): | ||||||
|     """SCIM error with optional status code""" |     """SCIM error with optional status code""" | ||||||
|  |  | ||||||
|  | |||||||
| @ -252,3 +252,118 @@ class SCIMMembershipTests(TestCase): | |||||||
|                     ], |                     ], | ||||||
|                 }, |                 }, | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|  |     def test_member_add_save(self): | ||||||
|  |         """Test member add + save""" | ||||||
|  |         config = ServiceProviderConfiguration.default() | ||||||
|  |  | ||||||
|  |         config.patch.supported = True | ||||||
|  |         user_scim_id = generate_id() | ||||||
|  |         group_scim_id = generate_id() | ||||||
|  |         uid = generate_id() | ||||||
|  |         group = Group.objects.create( | ||||||
|  |             name=uid, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         user = User.objects.create(username=generate_id()) | ||||||
|  |  | ||||||
|  |         # Test initial sync of group creation | ||||||
|  |         with Mocker() as mocker: | ||||||
|  |             mocker.get( | ||||||
|  |                 "https://localhost/ServiceProviderConfig", | ||||||
|  |                 json=config.model_dump(), | ||||||
|  |             ) | ||||||
|  |             mocker.post( | ||||||
|  |                 "https://localhost/Users", | ||||||
|  |                 json={ | ||||||
|  |                     "id": user_scim_id, | ||||||
|  |                 }, | ||||||
|  |             ) | ||||||
|  |             mocker.post( | ||||||
|  |                 "https://localhost/Groups", | ||||||
|  |                 json={ | ||||||
|  |                     "id": group_scim_id, | ||||||
|  |                 }, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |             self.configure() | ||||||
|  |             sync_tasks.trigger_single_task(self.provider, scim_sync).get() | ||||||
|  |  | ||||||
|  |             self.assertEqual(mocker.call_count, 6) | ||||||
|  |             self.assertEqual(mocker.request_history[0].method, "GET") | ||||||
|  |             self.assertEqual(mocker.request_history[1].method, "GET") | ||||||
|  |             self.assertEqual(mocker.request_history[2].method, "GET") | ||||||
|  |             self.assertEqual(mocker.request_history[3].method, "POST") | ||||||
|  |             self.assertEqual(mocker.request_history[4].method, "GET") | ||||||
|  |             self.assertEqual(mocker.request_history[5].method, "POST") | ||||||
|  |             self.assertJSONEqual( | ||||||
|  |                 mocker.request_history[3].body, | ||||||
|  |                 { | ||||||
|  |                     "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"], | ||||||
|  |                     "emails": [], | ||||||
|  |                     "active": True, | ||||||
|  |                     "externalId": user.uid, | ||||||
|  |                     "name": {"familyName": " ", "formatted": " ", "givenName": ""}, | ||||||
|  |                     "displayName": "", | ||||||
|  |                     "userName": user.username, | ||||||
|  |                 }, | ||||||
|  |             ) | ||||||
|  |             self.assertJSONEqual( | ||||||
|  |                 mocker.request_history[5].body, | ||||||
|  |                 { | ||||||
|  |                     "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"], | ||||||
|  |                     "externalId": str(group.pk), | ||||||
|  |                     "displayName": group.name, | ||||||
|  |                 }, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         with Mocker() as mocker: | ||||||
|  |             mocker.get( | ||||||
|  |                 "https://localhost/ServiceProviderConfig", | ||||||
|  |                 json=config.model_dump(), | ||||||
|  |             ) | ||||||
|  |             mocker.get( | ||||||
|  |                 f"https://localhost/Groups/{group_scim_id}", | ||||||
|  |                 json={}, | ||||||
|  |             ) | ||||||
|  |             mocker.patch( | ||||||
|  |                 f"https://localhost/Groups/{group_scim_id}", | ||||||
|  |                 json={}, | ||||||
|  |             ) | ||||||
|  |             group.users.add(user) | ||||||
|  |             group.save() | ||||||
|  |             self.assertEqual(mocker.call_count, 5) | ||||||
|  |             self.assertEqual(mocker.request_history[0].method, "GET") | ||||||
|  |             self.assertEqual(mocker.request_history[1].method, "PATCH") | ||||||
|  |             self.assertEqual(mocker.request_history[2].method, "GET") | ||||||
|  |             self.assertEqual(mocker.request_history[3].method, "PATCH") | ||||||
|  |             self.assertEqual(mocker.request_history[4].method, "GET") | ||||||
|  |             self.assertJSONEqual( | ||||||
|  |                 mocker.request_history[1].body, | ||||||
|  |                 { | ||||||
|  |                     "schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"], | ||||||
|  |                     "Operations": [ | ||||||
|  |                         { | ||||||
|  |                             "op": "add", | ||||||
|  |                             "path": "members", | ||||||
|  |                             "value": [{"value": user_scim_id}], | ||||||
|  |                         } | ||||||
|  |                     ], | ||||||
|  |                 }, | ||||||
|  |             ) | ||||||
|  |             self.assertJSONEqual( | ||||||
|  |                 mocker.request_history[3].body, | ||||||
|  |                 { | ||||||
|  |                     "Operations": [ | ||||||
|  |                         { | ||||||
|  |                             "op": "replace", | ||||||
|  |                             "value": { | ||||||
|  |                                 "id": group_scim_id, | ||||||
|  |                                 "displayName": group.name, | ||||||
|  |                                 "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"], | ||||||
|  |                                 "externalId": str(group.pk), | ||||||
|  |                             }, | ||||||
|  |                         } | ||||||
|  |                     ] | ||||||
|  |                 }, | ||||||
|  |             ) | ||||||
|  | |||||||
| @ -87,7 +87,11 @@ def task_error_hook(task_id: str, exception: Exception, traceback, *args, **kwar | |||||||
|  |  | ||||||
| def _get_startup_tasks_default_tenant() -> list[Callable]: | def _get_startup_tasks_default_tenant() -> list[Callable]: | ||||||
|     """Get all tasks to be run on startup for the default tenant""" |     """Get all tasks to be run on startup for the default tenant""" | ||||||
|     return [] |     from authentik.outposts.tasks import outpost_connection_discovery | ||||||
|  |  | ||||||
|  |     return [ | ||||||
|  |         outpost_connection_discovery, | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |  | ||||||
| def _get_startup_tasks_all_tenants() -> list[Callable]: | def _get_startup_tasks_all_tenants() -> list[Callable]: | ||||||
|  | |||||||
| @ -2,6 +2,7 @@ | |||||||
|  |  | ||||||
| from collections.abc import Callable | from collections.abc import Callable | ||||||
| from hashlib import sha512 | from hashlib import sha512 | ||||||
|  | from ipaddress import ip_address | ||||||
| from time import perf_counter, time | from time import perf_counter, time | ||||||
| from typing import Any | from typing import Any | ||||||
|  |  | ||||||
| @ -174,6 +175,7 @@ class ClientIPMiddleware: | |||||||
|  |  | ||||||
|     def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]): |     def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]): | ||||||
|         self.get_response = get_response |         self.get_response = get_response | ||||||
|  |         self.logger = get_logger().bind() | ||||||
|  |  | ||||||
|     def _get_client_ip_from_meta(self, meta: dict[str, Any]) -> str: |     def _get_client_ip_from_meta(self, meta: dict[str, Any]) -> str: | ||||||
|         """Attempt to get the client's IP by checking common HTTP Headers. |         """Attempt to get the client's IP by checking common HTTP Headers. | ||||||
| @ -185,10 +187,15 @@ class ClientIPMiddleware: | |||||||
|             "HTTP_X_FORWARDED_FOR", |             "HTTP_X_FORWARDED_FOR", | ||||||
|             "REMOTE_ADDR", |             "REMOTE_ADDR", | ||||||
|         ) |         ) | ||||||
|  |         try: | ||||||
|             for _header in headers: |             for _header in headers: | ||||||
|                 if _header in meta: |                 if _header in meta: | ||||||
|                     ips: list[str] = meta.get(_header).split(",") |                     ips: list[str] = meta.get(_header).split(",") | ||||||
|                 return ips[0].strip() |                     # Ensure the IP parses as a valid IP | ||||||
|  |                     return str(ip_address(ips[0].strip())) | ||||||
|  |             return self.default_ip | ||||||
|  |         except ValueError as exc: | ||||||
|  |             self.logger.debug("Invalid remote IP", exc=exc) | ||||||
|             return self.default_ip |             return self.default_ip | ||||||
|  |  | ||||||
|     # FIXME: this should probably not be in `root` but rather in a middleware in `outposts` |     # FIXME: this should probably not be in `root` but rather in a middleware in `outposts` | ||||||
| @ -226,7 +233,11 @@ class ClientIPMiddleware: | |||||||
|         Scope.get_isolation_scope().set_user(user) |         Scope.get_isolation_scope().set_user(user) | ||||||
|         # Set the outpost service account on the request |         # Set the outpost service account on the request | ||||||
|         setattr(request, self.request_attr_outpost_user, user) |         setattr(request, self.request_attr_outpost_user, user) | ||||||
|         return delegated_ip |         try: | ||||||
|  |             return str(ip_address(delegated_ip)) | ||||||
|  |         except ValueError as exc: | ||||||
|  |             self.logger.debug("Invalid remote IP from Outpost", exc=exc) | ||||||
|  |             return None | ||||||
|  |  | ||||||
|     def _get_client_ip(self, request: HttpRequest | None) -> str: |     def _get_client_ip(self, request: HttpRequest | None) -> str: | ||||||
|         """Attempt to get the client's IP by checking common HTTP Headers. |         """Attempt to get the client's IP by checking common HTTP Headers. | ||||||
|  | |||||||
| @ -1,6 +1,8 @@ | |||||||
| """Metrics view""" | """Metrics view""" | ||||||
|  |  | ||||||
| from base64 import b64encode | from hmac import compare_digest | ||||||
|  | from pathlib import Path | ||||||
|  | from tempfile import gettempdir | ||||||
|  |  | ||||||
| from django.conf import settings | from django.conf import settings | ||||||
| from django.db import connections | from django.db import connections | ||||||
| @ -16,22 +18,21 @@ monitoring_set = Signal() | |||||||
|  |  | ||||||
|  |  | ||||||
| class MetricsView(View): | class MetricsView(View): | ||||||
|     """Wrapper around ExportToDjangoView, using http-basic auth""" |     """Wrapper around ExportToDjangoView with authentication, accessed by the authentik router""" | ||||||
|  |  | ||||||
|  |     def __init__(self, **kwargs): | ||||||
|  |         _tmp = Path(gettempdir()) | ||||||
|  |         with open(_tmp / "authentik-core-metrics.key") as _f: | ||||||
|  |             self.monitoring_key = _f.read() | ||||||
|  |  | ||||||
|     def get(self, request: HttpRequest) -> HttpResponse: |     def get(self, request: HttpRequest) -> HttpResponse: | ||||||
|         """Check for HTTP-Basic auth""" |         """Check for HTTP-Basic auth""" | ||||||
|         auth_header = request.META.get("HTTP_AUTHORIZATION", "") |         auth_header = request.META.get("HTTP_AUTHORIZATION", "") | ||||||
|         auth_type, _, given_credentials = auth_header.partition(" ") |         auth_type, _, given_credentials = auth_header.partition(" ") | ||||||
|         credentials = f"monitor:{settings.SECRET_KEY}" |         authed = auth_type == "Bearer" and compare_digest(given_credentials, self.monitoring_key) | ||||||
|         expected = b64encode(str.encode(credentials)).decode() |  | ||||||
|         authed = auth_type == "Basic" and given_credentials == expected |  | ||||||
|         if not authed and not settings.DEBUG: |         if not authed and not settings.DEBUG: | ||||||
|             response = HttpResponse(status=401) |             return HttpResponse(status=401) | ||||||
|             response["WWW-Authenticate"] = 'Basic realm="authentik-monitoring"' |  | ||||||
|             return response |  | ||||||
|  |  | ||||||
|         monitoring_set.send_robust(self) |         monitoring_set.send_robust(self) | ||||||
|  |  | ||||||
|         return ExportToDjangoView(request) |         return ExportToDjangoView(request) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| @ -1,6 +1,7 @@ | |||||||
| """authentik storage backends""" | """authentik storage backends""" | ||||||
|  |  | ||||||
| import os | import os | ||||||
|  | from urllib.parse import parse_qsl, urlsplit | ||||||
|  |  | ||||||
| from django.conf import settings | from django.conf import settings | ||||||
| from django.core.exceptions import SuspiciousOperation | from django.core.exceptions import SuspiciousOperation | ||||||
| @ -110,3 +111,34 @@ class S3Storage(BaseS3Storage): | |||||||
|         if self.querystring_auth: |         if self.querystring_auth: | ||||||
|             return url |             return url | ||||||
|         return self._strip_signing_parameters(url) |         return self._strip_signing_parameters(url) | ||||||
|  |  | ||||||
|  |     def _strip_signing_parameters(self, url): | ||||||
|  |         # Boto3 does not currently support generating URLs that are unsigned. Instead | ||||||
|  |         # we take the signed URLs and strip any querystring params related to signing | ||||||
|  |         # and expiration. | ||||||
|  |         # Note that this may end up with URLs that are still invalid, especially if | ||||||
|  |         # params are passed in that only work with signed URLs, e.g. response header | ||||||
|  |         # params. | ||||||
|  |         # The code attempts to strip all query parameters that match names of known | ||||||
|  |         # parameters from v2 and v4 signatures, regardless of the actual signature | ||||||
|  |         # version used. | ||||||
|  |         split_url = urlsplit(url) | ||||||
|  |         qs = parse_qsl(split_url.query, keep_blank_values=True) | ||||||
|  |         blacklist = { | ||||||
|  |             "x-amz-algorithm", | ||||||
|  |             "x-amz-credential", | ||||||
|  |             "x-amz-date", | ||||||
|  |             "x-amz-expires", | ||||||
|  |             "x-amz-signedheaders", | ||||||
|  |             "x-amz-signature", | ||||||
|  |             "x-amz-security-token", | ||||||
|  |             "awsaccesskeyid", | ||||||
|  |             "expires", | ||||||
|  |             "signature", | ||||||
|  |         } | ||||||
|  |         filtered_qs = ((key, val) for key, val in qs if key.lower() not in blacklist) | ||||||
|  |         # Note: Parameters that did not have a value in the original query string will | ||||||
|  |         # have an '=' sign appended to it, e.g ?foo&bar becomes ?foo=&bar= | ||||||
|  |         joined_qs = ("=".join(keyval) for keyval in filtered_qs) | ||||||
|  |         split_url = split_url._replace(query="&".join(joined_qs)) | ||||||
|  |         return split_url.geturl() | ||||||
|  | |||||||
| @ -1,8 +1,9 @@ | |||||||
| """root tests""" | """root tests""" | ||||||
|  |  | ||||||
| from base64 import b64encode | from pathlib import Path | ||||||
|  | from secrets import token_urlsafe | ||||||
|  | from tempfile import gettempdir | ||||||
|  |  | ||||||
| from django.conf import settings |  | ||||||
| from django.test import TestCase | from django.test import TestCase | ||||||
| from django.urls import reverse | from django.urls import reverse | ||||||
|  |  | ||||||
| @ -10,6 +11,16 @@ from django.urls import reverse | |||||||
| class TestRoot(TestCase): | class TestRoot(TestCase): | ||||||
|     """Test root application""" |     """Test root application""" | ||||||
|  |  | ||||||
|  |     def setUp(self): | ||||||
|  |         _tmp = Path(gettempdir()) | ||||||
|  |         self.token = token_urlsafe(32) | ||||||
|  |         with open(_tmp / "authentik-core-metrics.key", "w") as _f: | ||||||
|  |             _f.write(self.token) | ||||||
|  |  | ||||||
|  |     def tearDown(self): | ||||||
|  |         _tmp = Path(gettempdir()) | ||||||
|  |         (_tmp / "authentik-core-metrics.key").unlink() | ||||||
|  |  | ||||||
|     def test_monitoring_error(self): |     def test_monitoring_error(self): | ||||||
|         """Test monitoring without any credentials""" |         """Test monitoring without any credentials""" | ||||||
|         response = self.client.get(reverse("metrics")) |         response = self.client.get(reverse("metrics")) | ||||||
| @ -17,8 +28,7 @@ class TestRoot(TestCase): | |||||||
|  |  | ||||||
|     def test_monitoring_ok(self): |     def test_monitoring_ok(self): | ||||||
|         """Test monitoring with credentials""" |         """Test monitoring with credentials""" | ||||||
|         creds = "Basic " + b64encode(f"monitor:{settings.SECRET_KEY}".encode()).decode("utf-8") |         auth_headers = {"HTTP_AUTHORIZATION": f"Bearer {self.token}"} | ||||||
|         auth_headers = {"HTTP_AUTHORIZATION": creds} |  | ||||||
|         response = self.client.get(reverse("metrics"), **auth_headers) |         response = self.client.get(reverse("metrics"), **auth_headers) | ||||||
|         self.assertEqual(response.status_code, 200) |         self.assertEqual(response.status_code, 200) | ||||||
|  |  | ||||||
|  | |||||||
| @ -3,6 +3,7 @@ | |||||||
| from typing import Any | from typing import Any | ||||||
|  |  | ||||||
| from django.core.cache import cache | from django.core.cache import cache | ||||||
|  | from django.utils.translation import gettext_lazy as _ | ||||||
| from drf_spectacular.utils import extend_schema, inline_serializer | from drf_spectacular.utils import extend_schema, inline_serializer | ||||||
| from guardian.shortcuts import get_objects_for_user | from guardian.shortcuts import get_objects_for_user | ||||||
| from rest_framework.decorators import action | from rest_framework.decorators import action | ||||||
| @ -39,9 +40,8 @@ class LDAPSourceSerializer(SourceSerializer): | |||||||
|         """Get cached source connectivity""" |         """Get cached source connectivity""" | ||||||
|         return cache.get(CACHE_KEY_STATUS + source.slug, None) |         return cache.get(CACHE_KEY_STATUS + source.slug, None) | ||||||
|  |  | ||||||
|     def validate(self, attrs: dict[str, Any]) -> dict[str, Any]: |     def validate_sync_users_password(self, sync_users_password: bool) -> bool: | ||||||
|         """Check that only a single source has password_sync on""" |         """Check that only a single source has password_sync on""" | ||||||
|         sync_users_password = attrs.get("sync_users_password", True) |  | ||||||
|         if sync_users_password: |         if sync_users_password: | ||||||
|             sources = LDAPSource.objects.filter(sync_users_password=True) |             sources = LDAPSource.objects.filter(sync_users_password=True) | ||||||
|             if self.instance: |             if self.instance: | ||||||
| @ -49,11 +49,31 @@ class LDAPSourceSerializer(SourceSerializer): | |||||||
|             if sources.exists(): |             if sources.exists(): | ||||||
|                 raise ValidationError( |                 raise ValidationError( | ||||||
|                     { |                     { | ||||||
|                         "sync_users_password": ( |                         "sync_users_password": _( | ||||||
|                             "Only a single LDAP Source with password synchronization is allowed" |                             "Only a single LDAP Source with password synchronization is allowed" | ||||||
|                         ) |                         ) | ||||||
|                     } |                     } | ||||||
|                 ) |                 ) | ||||||
|  |         return sync_users_password | ||||||
|  |  | ||||||
|  |     def validate(self, attrs: dict[str, Any]) -> dict[str, Any]: | ||||||
|  |         """Validate property mappings with sync_ flags""" | ||||||
|  |         types = ["user", "group"] | ||||||
|  |         for type in types: | ||||||
|  |             toggle_value = attrs.get(f"sync_{type}s", False) | ||||||
|  |             mappings_field = f"{type}_property_mappings" | ||||||
|  |             mappings_value = attrs.get(mappings_field, []) | ||||||
|  |             if toggle_value and len(mappings_value) == 0: | ||||||
|  |                 raise ValidationError( | ||||||
|  |                     { | ||||||
|  |                         mappings_field: _( | ||||||
|  |                             ( | ||||||
|  |                                 "When 'Sync {type}s' is enabled, '{type}s property " | ||||||
|  |                                 "mappings' cannot be empty." | ||||||
|  |                             ).format(type=type) | ||||||
|  |                         ) | ||||||
|  |                     } | ||||||
|  |                 ) | ||||||
|         return super().validate(attrs) |         return super().validate(attrs) | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
| @ -166,7 +186,8 @@ class LDAPSourceViewSet(UsedByMixin, ModelViewSet): | |||||||
|         for sync_class in SYNC_CLASSES: |         for sync_class in SYNC_CLASSES: | ||||||
|             class_name = sync_class.name() |             class_name = sync_class.name() | ||||||
|             all_objects.setdefault(class_name, []) |             all_objects.setdefault(class_name, []) | ||||||
|             for obj in sync_class(source).get_objects(size_limit=10): |             for page in sync_class(source).get_objects(size_limit=10): | ||||||
|  |                 for obj in page: | ||||||
|                     obj: dict |                     obj: dict | ||||||
|                     obj.pop("raw_attributes", None) |                     obj.pop("raw_attributes", None) | ||||||
|                     obj.pop("raw_dn", None) |                     obj.pop("raw_dn", None) | ||||||
|  | |||||||
| @ -26,17 +26,16 @@ def sync_ldap_source_on_save(sender, instance: LDAPSource, **_): | |||||||
|     """Ensure that source is synced on save (if enabled)""" |     """Ensure that source is synced on save (if enabled)""" | ||||||
|     if not instance.enabled: |     if not instance.enabled: | ||||||
|         return |         return | ||||||
|  |     ldap_connectivity_check.delay(instance.pk) | ||||||
|     # Don't sync sources when they don't have any property mappings. This will only happen if: |     # Don't sync sources when they don't have any property mappings. This will only happen if: | ||||||
|     # - the user forgets to set them or |     # - the user forgets to set them or | ||||||
|     # - the source is newly created, this is the first save event |     # - the source is newly created, this is the first save event | ||||||
|     #   and the mappings are created with an m2m event |     #   and the mappings are created with an m2m event | ||||||
|     if ( |     if instance.sync_users and not instance.user_property_mappings.exists(): | ||||||
|         not instance.user_property_mappings.exists() |         return | ||||||
|         or not instance.group_property_mappings.exists() |     if instance.sync_groups and not instance.group_property_mappings.exists(): | ||||||
|     ): |  | ||||||
|         return |         return | ||||||
|     ldap_sync_single.delay(instance.pk) |     ldap_sync_single.delay(instance.pk) | ||||||
|     ldap_connectivity_check.delay(instance.pk) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(password_validate) | @receiver(password_validate) | ||||||
|  | |||||||
| @ -38,7 +38,11 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer): | |||||||
|             search_base=self.base_dn_groups, |             search_base=self.base_dn_groups, | ||||||
|             search_filter=self._source.group_object_filter, |             search_filter=self._source.group_object_filter, | ||||||
|             search_scope=SUBTREE, |             search_scope=SUBTREE, | ||||||
|             attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES], |             attributes=[ | ||||||
|  |                 ALL_ATTRIBUTES, | ||||||
|  |                 ALL_OPERATIONAL_ATTRIBUTES, | ||||||
|  |                 self._source.object_uniqueness_field, | ||||||
|  |             ], | ||||||
|             **kwargs, |             **kwargs, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
| @ -53,9 +57,9 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer): | |||||||
|                 continue |                 continue | ||||||
|             attributes = group.get("attributes", {}) |             attributes = group.get("attributes", {}) | ||||||
|             group_dn = flatten(flatten(group.get("entryDN", group.get("dn")))) |             group_dn = flatten(flatten(group.get("entryDN", group.get("dn")))) | ||||||
|             if self._source.object_uniqueness_field not in attributes: |             if not attributes.get(self._source.object_uniqueness_field): | ||||||
|                 self.message( |                 self.message( | ||||||
|                     f"Cannot find uniqueness field in attributes: '{group_dn}'", |                     f"Uniqueness field not found/not set in attributes: '{group_dn}'", | ||||||
|                     attributes=attributes.keys(), |                     attributes=attributes.keys(), | ||||||
|                     dn=group_dn, |                     dn=group_dn, | ||||||
|                 ) |                 ) | ||||||
|  | |||||||
| @ -40,7 +40,11 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer): | |||||||
|             search_base=self.base_dn_users, |             search_base=self.base_dn_users, | ||||||
|             search_filter=self._source.user_object_filter, |             search_filter=self._source.user_object_filter, | ||||||
|             search_scope=SUBTREE, |             search_scope=SUBTREE, | ||||||
|             attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES], |             attributes=[ | ||||||
|  |                 ALL_ATTRIBUTES, | ||||||
|  |                 ALL_OPERATIONAL_ATTRIBUTES, | ||||||
|  |                 self._source.object_uniqueness_field, | ||||||
|  |             ], | ||||||
|             **kwargs, |             **kwargs, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
| @ -55,9 +59,9 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer): | |||||||
|                 continue |                 continue | ||||||
|             attributes = user.get("attributes", {}) |             attributes = user.get("attributes", {}) | ||||||
|             user_dn = flatten(user.get("entryDN", user.get("dn"))) |             user_dn = flatten(user.get("entryDN", user.get("dn"))) | ||||||
|             if self._source.object_uniqueness_field not in attributes: |             if not attributes.get(self._source.object_uniqueness_field): | ||||||
|                 self.message( |                 self.message( | ||||||
|                     f"Cannot find uniqueness field in attributes: '{user_dn}'", |                     f"Uniqueness field not found/not set in attributes: '{user_dn}'", | ||||||
|                     attributes=attributes.keys(), |                     attributes=attributes.keys(), | ||||||
|                     dn=user_dn, |                     dn=user_dn, | ||||||
|                 ) |                 ) | ||||||
|  | |||||||
							
								
								
									
										4
									
								
								authentik/sources/ldap/sync/vendor/ms_ad.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								authentik/sources/ldap/sync/vendor/ms_ad.py
									
									
									
									
										vendored
									
									
								
							| @ -78,7 +78,9 @@ class MicrosoftActiveDirectory(BaseLDAPSynchronizer): | |||||||
|         #   /useraccountcontrol-manipulate-account-properties |         #   /useraccountcontrol-manipulate-account-properties | ||||||
|         uac_bit = attributes.get("userAccountControl", 512) |         uac_bit = attributes.get("userAccountControl", 512) | ||||||
|         uac = UserAccountControl(uac_bit) |         uac = UserAccountControl(uac_bit) | ||||||
|         is_active = UserAccountControl.ACCOUNTDISABLE not in uac |         is_active = ( | ||||||
|  |             UserAccountControl.ACCOUNTDISABLE not in uac and UserAccountControl.LOCKOUT not in uac | ||||||
|  |         ) | ||||||
|         if is_active != user.is_active: |         if is_active != user.is_active: | ||||||
|             user.is_active = is_active |             user.is_active = is_active | ||||||
|             user.save() |             user.save() | ||||||
|  | |||||||
| @ -50,3 +50,35 @@ class LDAPAPITests(APITestCase): | |||||||
|             } |             } | ||||||
|         ) |         ) | ||||||
|         self.assertFalse(serializer.is_valid()) |         self.assertFalse(serializer.is_valid()) | ||||||
|  |  | ||||||
|  |     def test_sync_users_mapping_empty(self): | ||||||
|  |         """Check that when sync_users is enabled, property mappings must be set""" | ||||||
|  |         serializer = LDAPSourceSerializer( | ||||||
|  |             data={ | ||||||
|  |                 "name": "foo", | ||||||
|  |                 "slug": " foo", | ||||||
|  |                 "server_uri": "ldaps://1.2.3.4", | ||||||
|  |                 "bind_cn": "", | ||||||
|  |                 "bind_password": LDAP_PASSWORD, | ||||||
|  |                 "base_dn": "dc=foo", | ||||||
|  |                 "sync_users": True, | ||||||
|  |                 "user_property_mappings": [], | ||||||
|  |             } | ||||||
|  |         ) | ||||||
|  |         self.assertFalse(serializer.is_valid()) | ||||||
|  |  | ||||||
|  |     def test_sync_groups_mapping_empty(self): | ||||||
|  |         """Check that when sync_groups is enabled, property mappings must be set""" | ||||||
|  |         serializer = LDAPSourceSerializer( | ||||||
|  |             data={ | ||||||
|  |                 "name": "foo", | ||||||
|  |                 "slug": " foo", | ||||||
|  |                 "server_uri": "ldaps://1.2.3.4", | ||||||
|  |                 "bind_cn": "", | ||||||
|  |                 "bind_password": LDAP_PASSWORD, | ||||||
|  |                 "base_dn": "dc=foo", | ||||||
|  |                 "sync_groups": True, | ||||||
|  |                 "group_property_mappings": [], | ||||||
|  |             } | ||||||
|  |         ) | ||||||
|  |         self.assertFalse(serializer.is_valid()) | ||||||
|  | |||||||
| @ -30,7 +30,9 @@ class TestMetadataProcessor(TestCase): | |||||||
|         xml = MetadataProcessor(self.source, request).build_entity_descriptor() |         xml = MetadataProcessor(self.source, request).build_entity_descriptor() | ||||||
|         metadata = lxml_from_string(xml) |         metadata = lxml_from_string(xml) | ||||||
|  |  | ||||||
|         schema = etree.XMLSchema(etree.parse("schemas/saml-schema-metadata-2.0.xsd"))  # nosec |         schema = etree.XMLSchema( | ||||||
|  |             etree.parse("schemas/saml-schema-metadata-2.0.xsd", parser=etree.XMLParser())  # nosec | ||||||
|  |         ) | ||||||
|         self.assertTrue(schema.validate(metadata)) |         self.assertTrue(schema.validate(metadata)) | ||||||
|  |  | ||||||
|     def test_metadata_consistent(self): |     def test_metadata_consistent(self): | ||||||
|  | |||||||
| @ -82,3 +82,5 @@ entries: | |||||||
|     order: 10 |     order: 10 | ||||||
|     target: !KeyOf default-authentication-flow-password-binding |     target: !KeyOf default-authentication-flow-password-binding | ||||||
|     policy: !KeyOf default-authentication-flow-password-optional |     policy: !KeyOf default-authentication-flow-password-optional | ||||||
|  |   attrs: | ||||||
|  |     failure_result: true | ||||||
|  | |||||||
| @ -2,7 +2,7 @@ | |||||||
|     "$schema": "http://json-schema.org/draft-07/schema", |     "$schema": "http://json-schema.org/draft-07/schema", | ||||||
|     "$id": "https://goauthentik.io/blueprints/schema.json", |     "$id": "https://goauthentik.io/blueprints/schema.json", | ||||||
|     "type": "object", |     "type": "object", | ||||||
|     "title": "authentik 2024.6.4 Blueprint schema", |     "title": "authentik 2024.8.5 Blueprint schema", | ||||||
|     "required": [ |     "required": [ | ||||||
|         "version", |         "version", | ||||||
|         "entries" |         "entries" | ||||||
| @ -5345,9 +5345,30 @@ | |||||||
|                     "description": "Key used to sign the tokens. Only required when JWT Algorithm is set to RS256." |                     "description": "Key used to sign the tokens. Only required when JWT Algorithm is set to RS256." | ||||||
|                 }, |                 }, | ||||||
|                 "redirect_uris": { |                 "redirect_uris": { | ||||||
|  |                     "type": "array", | ||||||
|  |                     "items": { | ||||||
|  |                         "type": "object", | ||||||
|  |                         "properties": { | ||||||
|  |                             "matching_mode": { | ||||||
|                                 "type": "string", |                                 "type": "string", | ||||||
|                     "title": "Redirect URIs", |                                 "enum": [ | ||||||
|                     "description": "Enter each URI on a new line." |                                     "strict", | ||||||
|  |                                     "regex" | ||||||
|  |                                 ], | ||||||
|  |                                 "title": "Matching mode" | ||||||
|  |                             }, | ||||||
|  |                             "url": { | ||||||
|  |                                 "type": "string", | ||||||
|  |                                 "minLength": 1, | ||||||
|  |                                 "title": "Url" | ||||||
|  |                             } | ||||||
|  |                         }, | ||||||
|  |                         "required": [ | ||||||
|  |                             "matching_mode", | ||||||
|  |                             "url" | ||||||
|  |                         ] | ||||||
|  |                     }, | ||||||
|  |                     "title": "Redirect uris" | ||||||
|                 }, |                 }, | ||||||
|                 "sub_mode": { |                 "sub_mode": { | ||||||
|                     "type": "string", |                     "type": "string", | ||||||
|  | |||||||
| @ -14,11 +14,7 @@ entries: | |||||||
|       expression: | |       expression: | | ||||||
|         # This mapping is used by the authentik proxy. It passes extra user attributes, |         # This mapping is used by the authentik proxy. It passes extra user attributes, | ||||||
|         # which are used for example for the HTTP-Basic Authentication mapping. |         # which are used for example for the HTTP-Basic Authentication mapping. | ||||||
|         session_id = None |  | ||||||
|         if "token" in request.context: |  | ||||||
|             session_id = request.context.get("token").session_id |  | ||||||
|         return { |         return { | ||||||
|             "sid": session_id, |  | ||||||
|             "ak_proxy": { |             "ak_proxy": { | ||||||
|                 "user_attributes": request.user.group_attributes(request), |                 "user_attributes": request.user.group_attributes(request), | ||||||
|                 "is_superuser": request.user.is_superuser, |                 "is_superuser": request.user.is_superuser, | ||||||
|  | |||||||
| @ -31,7 +31,7 @@ services: | |||||||
|     volumes: |     volumes: | ||||||
|       - redis:/data |       - redis:/data | ||||||
|   server: |   server: | ||||||
|     image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2024.6.4} |     image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2024.8.5} | ||||||
|     restart: unless-stopped |     restart: unless-stopped | ||||||
|     command: server |     command: server | ||||||
|     environment: |     environment: | ||||||
| @ -52,7 +52,7 @@ services: | |||||||
|       - postgresql |       - postgresql | ||||||
|       - redis |       - redis | ||||||
|   worker: |   worker: | ||||||
|     image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2024.6.4} |     image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2024.8.5} | ||||||
|     restart: unless-stopped |     restart: unless-stopped | ||||||
|     command: worker |     command: worker | ||||||
|     environment: |     environment: | ||||||
|  | |||||||
| @ -29,4 +29,4 @@ func UserAgent() string { | |||||||
| 	return fmt.Sprintf("authentik@%s", FullVersion()) | 	return fmt.Sprintf("authentik@%s", FullVersion()) | ||||||
| } | } | ||||||
|  |  | ||||||
| const VERSION = "2024.6.4" | const VERSION = "2024.8.5" | ||||||
|  | |||||||
| @ -35,10 +35,11 @@ func Paginator[Tobj any, Treq any, Tres PaginatorResponse[Tobj]]( | |||||||
| 	req PaginatorRequest[Treq, Tres], | 	req PaginatorRequest[Treq, Tres], | ||||||
| 	opts PaginatorOptions, | 	opts PaginatorOptions, | ||||||
| ) ([]Tobj, error) { | ) ([]Tobj, error) { | ||||||
|  | 	var bfreq, cfreq interface{} | ||||||
| 	fetchOffset := func(page int32) (Tres, error) { | 	fetchOffset := func(page int32) (Tres, error) { | ||||||
| 		req.Page(page) | 		bfreq = req.Page(page) | ||||||
| 		req.PageSize(int32(opts.PageSize)) | 		cfreq = bfreq.(PaginatorRequest[Treq, Tres]).PageSize(int32(opts.PageSize)) | ||||||
| 		res, _, err := req.Execute() | 		res, _, err := cfreq.(PaginatorRequest[Treq, Tres]).Execute() | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			opts.Logger.WithError(err).WithField("page", page).Warning("failed to fetch page") | 			opts.Logger.WithError(err).WithField("page", page).Warning("failed to fetch page") | ||||||
| 		} | 		} | ||||||
|  | |||||||
							
								
								
									
										26
									
								
								internal/outpost/ak/api_utils_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								internal/outpost/ak/api_utils_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,26 @@ | |||||||
|  | package ak | ||||||
|  |  | ||||||
|  | // func Test_PaginatorCompile(t *testing.T) { | ||||||
|  | // 	req := api.ApiCoreUsersListRequest{} | ||||||
|  | // 	Paginator(req, PaginatorOptions{ | ||||||
|  | // 		PageSize: 100, | ||||||
|  | // 	}) | ||||||
|  | // } | ||||||
|  |  | ||||||
|  | // func Test_PaginatorCompileExplicit(t *testing.T) { | ||||||
|  | // 	req := api.ApiCoreUsersListRequest{} | ||||||
|  | // 	Paginator[ | ||||||
|  | // 		api.User, | ||||||
|  | // 		api.ApiCoreUsersListRequest, | ||||||
|  | // 		*api.PaginatedUserList, | ||||||
|  | // 	](req, PaginatorOptions{ | ||||||
|  | // 		PageSize: 100, | ||||||
|  | // 	}) | ||||||
|  | // } | ||||||
|  |  | ||||||
|  | // func Test_PaginatorCompileOther(t *testing.T) { | ||||||
|  | // 	req := api.ApiOutpostsProxyListRequest{} | ||||||
|  | // 	Paginator(req, PaginatorOptions{ | ||||||
|  | // 		PageSize: 100, | ||||||
|  | // 	}) | ||||||
|  | // } | ||||||
| @ -96,7 +96,7 @@ func (db *DirectBinder) Bind(username string, req *bind.Request) (ldap.LDAPResul | |||||||
| 		return ldap.LDAPResultOperationsError, nil | 		return ldap.LDAPResultOperationsError, nil | ||||||
| 	} | 	} | ||||||
| 	flags.UserPk = userInfo.User.Pk | 	flags.UserPk = userInfo.User.Pk | ||||||
| 	flags.CanSearch = access.HasSearchPermission != nil | 	flags.CanSearch = access.GetHasSearchPermission() | ||||||
| 	db.si.SetFlags(req.BindDN, &flags) | 	db.si.SetFlags(req.BindDN, &flags) | ||||||
| 	if flags.CanSearch { | 	if flags.CanSearch { | ||||||
| 		req.Log().Debug("Allowed access to search") | 		req.Log().Debug("Allowed access to search") | ||||||
|  | |||||||
| @ -193,7 +193,17 @@ func NewApplication(p api.ProxyOutpostConfig, c *http.Client, server Server) (*A | |||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	mux.HandleFunc("/outpost.goauthentik.io/start", func(w http.ResponseWriter, r *http.Request) { | 	mux.HandleFunc("/outpost.goauthentik.io/start", func(w http.ResponseWriter, r *http.Request) { | ||||||
| 		a.handleAuthStart(w, r, "") | 		fwd := "" | ||||||
|  | 		// This should only really be hit for nginx forward_auth | ||||||
|  | 		// as for that the auth start redirect URL is generated by the | ||||||
|  | 		// reverse proxy, and as such we won't have a request we just | ||||||
|  | 		// denied to reference for final URL | ||||||
|  | 		rd, ok := a.checkRedirectParam(r) | ||||||
|  | 		if ok { | ||||||
|  | 			a.log.WithField("rd", rd).Trace("Setting redirect") | ||||||
|  | 			fwd = rd | ||||||
|  | 		} | ||||||
|  | 		a.handleAuthStart(w, r, fwd) | ||||||
| 	}) | 	}) | ||||||
| 	mux.HandleFunc("/outpost.goauthentik.io/callback", a.handleAuthCallback) | 	mux.HandleFunc("/outpost.goauthentik.io/callback", a.handleAuthCallback) | ||||||
| 	mux.HandleFunc("/outpost.goauthentik.io/sign_out", a.handleSignOut) | 	mux.HandleFunc("/outpost.goauthentik.io/sign_out", a.handleSignOut) | ||||||
|  | |||||||
| @ -15,36 +15,6 @@ const ( | |||||||
| 	LogoutSignature   = "X-authentik-logout" | 	LogoutSignature   = "X-authentik-logout" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func (a *Application) checkRedirectParam(r *http.Request) (string, bool) { |  | ||||||
| 	rd := r.URL.Query().Get(redirectParam) |  | ||||||
| 	if rd == "" { |  | ||||||
| 		return "", false |  | ||||||
| 	} |  | ||||||
| 	u, err := url.Parse(rd) |  | ||||||
| 	if err != nil { |  | ||||||
| 		a.log.WithError(err).Warning("Failed to parse redirect URL") |  | ||||||
| 		return "", false |  | ||||||
| 	} |  | ||||||
| 	// Check to make sure we only redirect to allowed places |  | ||||||
| 	if a.Mode() == api.PROXYMODE_PROXY || a.Mode() == api.PROXYMODE_FORWARD_SINGLE { |  | ||||||
| 		ext, err := url.Parse(a.proxyConfig.ExternalHost) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return "", false |  | ||||||
| 		} |  | ||||||
| 		ext.Scheme = "" |  | ||||||
| 		if !strings.Contains(u.String(), ext.String()) { |  | ||||||
| 			a.log.WithField("url", u.String()).WithField("ext", ext.String()).Warning("redirect URI did not contain external host") |  | ||||||
| 			return "", false |  | ||||||
| 		} |  | ||||||
| 	} else { |  | ||||||
| 		if !strings.HasSuffix(u.Host, *a.proxyConfig.CookieDomain) { |  | ||||||
| 			a.log.WithField("host", u.Host).WithField("dom", *a.proxyConfig.CookieDomain).Warning("redirect URI Host was not included in cookie domain") |  | ||||||
| 			return "", false |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return u.String(), true |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (a *Application) handleAuthStart(rw http.ResponseWriter, r *http.Request, fwd string) { | func (a *Application) handleAuthStart(rw http.ResponseWriter, r *http.Request, fwd string) { | ||||||
| 	state, err := a.createState(r, fwd) | 	state, err := a.createState(r, fwd) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | |||||||
| @ -5,10 +5,13 @@ import ( | |||||||
| 	"encoding/base64" | 	"encoding/base64" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"net/url" | ||||||
|  | 	"strings" | ||||||
|  |  | ||||||
| 	"github.com/golang-jwt/jwt/v5" | 	"github.com/golang-jwt/jwt/v5" | ||||||
| 	"github.com/gorilla/securecookie" | 	"github.com/gorilla/securecookie" | ||||||
| 	"github.com/mitchellh/mapstructure" | 	"github.com/mitchellh/mapstructure" | ||||||
|  | 	"goauthentik.io/api/v3" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type OAuthState struct { | type OAuthState struct { | ||||||
| @ -27,6 +30,44 @@ func (oas *OAuthState) GetAudience() (jwt.ClaimStrings, error)       { return ni | |||||||
|  |  | ||||||
| var base32RawStdEncoding = base32.StdEncoding.WithPadding(base32.NoPadding) | var base32RawStdEncoding = base32.StdEncoding.WithPadding(base32.NoPadding) | ||||||
|  |  | ||||||
|  | // Validate that the given redirect parameter (?rd=...) is valid and can be used | ||||||
|  | // For proxy/forward_single this checks that if the `rd` param has a Hostname (and is a full URL) | ||||||
|  | // the hostname matches what's configured, or no hostname must be given | ||||||
|  | // For forward_domain this checks if the domain of the URL in `rd` ends with the configured domain | ||||||
|  | func (a *Application) checkRedirectParam(r *http.Request) (string, bool) { | ||||||
|  | 	rd := r.URL.Query().Get(redirectParam) | ||||||
|  | 	if rd == "" { | ||||||
|  | 		return "", false | ||||||
|  | 	} | ||||||
|  | 	u, err := url.Parse(rd) | ||||||
|  | 	if err != nil { | ||||||
|  | 		a.log.WithError(err).Warning("Failed to parse redirect URL") | ||||||
|  | 		return "", false | ||||||
|  | 	} | ||||||
|  | 	// Check to make sure we only redirect to allowed places | ||||||
|  | 	if a.Mode() == api.PROXYMODE_PROXY || a.Mode() == api.PROXYMODE_FORWARD_SINGLE { | ||||||
|  | 		ext, err := url.Parse(a.proxyConfig.ExternalHost) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return "", false | ||||||
|  | 		} | ||||||
|  | 		// Either hostname needs to match the configured domain, or host name must be empty for just a path | ||||||
|  | 		if u.Host == "" { | ||||||
|  | 			u.Host = ext.Host | ||||||
|  | 			u.Scheme = ext.Scheme | ||||||
|  | 		} | ||||||
|  | 		if u.Host != ext.Host { | ||||||
|  | 			a.log.WithField("url", u.String()).WithField("ext", ext.String()).Warning("redirect URI did not contain external host") | ||||||
|  | 			return "", false | ||||||
|  | 		} | ||||||
|  | 	} else { | ||||||
|  | 		if !strings.HasSuffix(u.Host, *a.proxyConfig.CookieDomain) { | ||||||
|  | 			a.log.WithField("host", u.Host).WithField("dom", *a.proxyConfig.CookieDomain).Warning("redirect URI Host was not included in cookie domain") | ||||||
|  | 			return "", false | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return u.String(), true | ||||||
|  | } | ||||||
|  |  | ||||||
| func (a *Application) createState(r *http.Request, fwd string) (string, error) { | func (a *Application) createState(r *http.Request, fwd string) (string, error) { | ||||||
| 	s, _ := a.sessions.Get(r, a.SessionName()) | 	s, _ := a.sessions.Get(r, a.SessionName()) | ||||||
| 	if s.ID == "" { | 	if s.ID == "" { | ||||||
| @ -39,17 +80,6 @@ func (a *Application) createState(r *http.Request, fwd string) (string, error) { | |||||||
| 		SessionID: s.ID, | 		SessionID: s.ID, | ||||||
| 		Redirect:  fwd, | 		Redirect:  fwd, | ||||||
| 	} | 	} | ||||||
| 	if fwd == "" { |  | ||||||
| 		// This should only really be hit for nginx forward_auth |  | ||||||
| 		// as for that the auth start redirect URL is generated by the |  | ||||||
| 		// reverse proxy, and as such we won't have a request we just |  | ||||||
| 		// denied to reference for final URL |  | ||||||
| 		rd, ok := a.checkRedirectParam(r) |  | ||||||
| 		if ok { |  | ||||||
| 			a.log.WithField("rd", rd).Trace("Setting redirect") |  | ||||||
| 			st.Redirect = rd |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	token := jwt.NewWithClaims(jwt.SigningMethodHS256, st) | 	token := jwt.NewWithClaims(jwt.SigningMethodHS256, st) | ||||||
| 	tokenString, err := token.SignedString([]byte(a.proxyConfig.GetCookieSecret())) | 	tokenString, err := token.SignedString([]byte(a.proxyConfig.GetCookieSecret())) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | |||||||
| @ -8,25 +8,45 @@ import ( | |||||||
| 	"goauthentik.io/api/v3" | 	"goauthentik.io/api/v3" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func TestCheckRedirectParam(t *testing.T) { | func TestCheckRedirectParam_None(t *testing.T) { | ||||||
| 	a := newTestApplication() | 	a := newTestApplication() | ||||||
|  | 	// Test no rd param | ||||||
| 	req, _ := http.NewRequest("GET", "/outpost.goauthentik.io/auth/start", nil) | 	req, _ := http.NewRequest("GET", "/outpost.goauthentik.io/auth/start", nil) | ||||||
|  |  | ||||||
| 	rd, ok := a.checkRedirectParam(req) | 	rd, ok := a.checkRedirectParam(req) | ||||||
|  |  | ||||||
| 	assert.Equal(t, false, ok) | 	assert.Equal(t, false, ok) | ||||||
| 	assert.Equal(t, "", rd) | 	assert.Equal(t, "", rd) | ||||||
|  | } | ||||||
|  |  | ||||||
| 	req, _ = http.NewRequest("GET", "/outpost.goauthentik.io/auth/start?rd=https://google.com", nil) | func TestCheckRedirectParam_Invalid(t *testing.T) { | ||||||
|  | 	a := newTestApplication() | ||||||
|  | 	// Test invalid rd param | ||||||
|  | 	req, _ := http.NewRequest("GET", "/outpost.goauthentik.io/auth/start?rd=https://google.com", nil) | ||||||
|  |  | ||||||
| 	rd, ok = a.checkRedirectParam(req) | 	rd, ok := a.checkRedirectParam(req) | ||||||
|  |  | ||||||
| 	assert.Equal(t, false, ok) | 	assert.Equal(t, false, ok) | ||||||
| 	assert.Equal(t, "", rd) | 	assert.Equal(t, "", rd) | ||||||
|  | } | ||||||
|  |  | ||||||
| 	req, _ = http.NewRequest("GET", "/outpost.goauthentik.io/auth/start?rd=https://ext.t.goauthentik.io/test?foo", nil) | func TestCheckRedirectParam_ValidFull(t *testing.T) { | ||||||
|  | 	a := newTestApplication() | ||||||
|  | 	// Test valid full rd param | ||||||
|  | 	req, _ := http.NewRequest("GET", "/outpost.goauthentik.io/auth/start?rd=https://ext.t.goauthentik.io/test?foo", nil) | ||||||
|  |  | ||||||
| 	rd, ok = a.checkRedirectParam(req) | 	rd, ok := a.checkRedirectParam(req) | ||||||
|  |  | ||||||
|  | 	assert.Equal(t, true, ok) | ||||||
|  | 	assert.Equal(t, "https://ext.t.goauthentik.io/test?foo", rd) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestCheckRedirectParam_ValidPartial(t *testing.T) { | ||||||
|  | 	a := newTestApplication() | ||||||
|  | 	// Test valid partial rd param | ||||||
|  | 	req, _ := http.NewRequest("GET", "/outpost.goauthentik.io/auth/start?rd=/test?foo", nil) | ||||||
|  |  | ||||||
|  | 	rd, ok := a.checkRedirectParam(req) | ||||||
|  |  | ||||||
| 	assert.Equal(t, true, ok) | 	assert.Equal(t, true, ok) | ||||||
| 	assert.Equal(t, "https://ext.t.goauthentik.io/test?foo", rd) | 	assert.Equal(t, "https://ext.t.goauthentik.io/test?foo", rd) | ||||||
|  | |||||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user
	