Compare commits
	
		
			70 Commits
		
	
	
		
			enterprise
			...
			version/20
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| c99a33baee | |||
| b17d482e50 | |||
| 524d46ad7c | |||
| f90d6bb3d9 | |||
| 2340bced63 | |||
| 0a51e1b696 | |||
| 13636c0efe | |||
| e7f49d97a8 | |||
| 736240f60d | |||
| e8b5e4c127 | |||
| 81ec98b198 | |||
| c46ab19e79 | |||
| de9fc5de6b | |||
| eab3d9b411 | |||
| 7cb40d786f | |||
| b4fce08bbc | |||
| 8a2ba1c518 | |||
| 25b4306693 | |||
| 1e279950f1 | |||
| 960429355f | |||
| b4f3748353 | |||
| 91d2445c61 | |||
| dd8f809161 | |||
| 57a31b5dd1 | |||
| 09125b6236 | |||
| 832126c6fe | |||
| 25fe489b34 | |||
| 18078fd68f | |||
| 4fa71d995d | |||
| 22cec64234 | |||
| a87cc27366 | |||
| ad7ad1fa78 | |||
| c70e609e50 | |||
| 5f08485fff | |||
| 3a2ed11821 | |||
| ee04f39e28 | |||
| 2c6aa72f3c | |||
| bd0afef790 | |||
| fc11cc0a1a | |||
| fb78303e8f | |||
| 2ea04440db | |||
| 96e1636be3 | |||
| c546451a73 | |||
| 61778053b4 | |||
| f5580d311d | |||
| 99d292bce0 | |||
| b2801641bc | |||
| bfaa1046b2 | |||
| 95c30400cc | |||
| e77480ee1d | |||
| 905800e535 | |||
| fadeaef4c6 | |||
| 437efda649 | |||
| dd75d5f54b | |||
| 392a2e582e | |||
| a1da183721 | |||
| feea2df0b1 | |||
| b47acd8c76 | |||
| 6fd87d9ced | |||
| acbb065808 | |||
| 2fb097061d | |||
| 8962d17e03 | |||
| 8326e1490c | |||
| 091e4d3e4c | |||
| 6ee77edcbb | |||
| 763e2288bf | |||
| 9cdb177ca7 | |||
| 6070508058 | |||
| ec13a5d84d | |||
| 057de82b01 | 
| @ -1,5 +1,5 @@ | ||||
| [bumpversion] | ||||
| current_version = 2024.6.4 | ||||
| current_version = 2024.8.5 | ||||
| tag = True | ||||
| commit = True | ||||
| parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))? | ||||
|  | ||||
| @ -29,9 +29,9 @@ outputs: | ||||
|   imageTags: | ||||
|     description: "Docker image tags" | ||||
|     value: ${{ steps.ev.outputs.imageTags }} | ||||
|   imageNames: | ||||
|     description: "Docker image names" | ||||
|     value: ${{ steps.ev.outputs.imageNames }} | ||||
|   attestImageNames: | ||||
|     description: "Docker image names used for attestation" | ||||
|     value: ${{ steps.ev.outputs.attestImageNames }} | ||||
|   imageMainTag: | ||||
|     description: "Docker image main tag" | ||||
|     value: ${{ steps.ev.outputs.imageMainTag }} | ||||
|  | ||||
| @ -51,15 +51,24 @@ else: | ||||
|         ] | ||||
|  | ||||
| image_main_tag = image_tags[0].split(":")[-1] | ||||
| image_tags_rendered = ",".join(image_tags) | ||||
| image_names_rendered = ",".join(set(name.split(":")[0] for name in image_tags)) | ||||
|  | ||||
|  | ||||
| def get_attest_image_names(image_with_tags: list[str]): | ||||
|     """Attestation only for GHCR""" | ||||
|     image_tags = [] | ||||
|     for image_name in set(name.split(":")[0] for name in image_with_tags): | ||||
|         if not image_name.startswith("ghcr.io"): | ||||
|             continue | ||||
|         image_tags.append(image_name) | ||||
|     return ",".join(set(image_tags)) | ||||
|  | ||||
|  | ||||
| with open(os.environ["GITHUB_OUTPUT"], "a+", encoding="utf-8") as _output: | ||||
|     print(f"shouldBuild={should_build}", file=_output) | ||||
|     print(f"sha={sha}", file=_output) | ||||
|     print(f"version={version}", file=_output) | ||||
|     print(f"prerelease={prerelease}", file=_output) | ||||
|     print(f"imageTags={image_tags_rendered}", file=_output) | ||||
|     print(f"imageNames={image_names_rendered}", file=_output) | ||||
|     print(f"imageTags={','.join(image_tags)}", file=_output) | ||||
|     print(f"attestImageNames={get_attest_image_names(image_tags)}", file=_output) | ||||
|     print(f"imageMainTag={image_main_tag}", file=_output) | ||||
|     print(f"imageMainName={image_tags[0]}", file=_output) | ||||
|  | ||||
							
								
								
									
										2
									
								
								.github/workflows/ci-main.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/ci-main.yml
									
									
									
									
										vendored
									
									
								
							| @ -261,7 +261,7 @@ jobs: | ||||
|         id: attest | ||||
|         if: ${{ steps.ev.outputs.shouldBuild == 'true' }} | ||||
|         with: | ||||
|           subject-name: ${{ steps.ev.outputs.imageNames }} | ||||
|           subject-name: ${{ steps.ev.outputs.attestImageNames }} | ||||
|           subject-digest: ${{ steps.push.outputs.digest }} | ||||
|           push-to-registry: true | ||||
|   pr-comment: | ||||
|  | ||||
							
								
								
									
										2
									
								
								.github/workflows/ci-outpost.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/ci-outpost.yml
									
									
									
									
										vendored
									
									
								
							| @ -115,7 +115,7 @@ jobs: | ||||
|         id: attest | ||||
|         if: ${{ steps.ev.outputs.shouldBuild == 'true' }} | ||||
|         with: | ||||
|           subject-name: ${{ steps.ev.outputs.imageNames }} | ||||
|           subject-name: ${{ steps.ev.outputs.attestImageNames }} | ||||
|           subject-digest: ${{ steps.push.outputs.digest }} | ||||
|           push-to-registry: true | ||||
|   build-binary: | ||||
|  | ||||
							
								
								
									
										4
									
								
								.github/workflows/release-publish.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.github/workflows/release-publish.yml
									
									
									
									
										vendored
									
									
								
							| @ -58,7 +58,7 @@ jobs: | ||||
|       - uses: actions/attest-build-provenance@v1 | ||||
|         id: attest | ||||
|         with: | ||||
|           subject-name: ${{ steps.ev.outputs.imageNames }} | ||||
|           subject-name: ${{ steps.ev.outputs.attestImageNames }} | ||||
|           subject-digest: ${{ steps.push.outputs.digest }} | ||||
|           push-to-registry: true | ||||
|   build-outpost: | ||||
| @ -122,7 +122,7 @@ jobs: | ||||
|       - uses: actions/attest-build-provenance@v1 | ||||
|         id: attest | ||||
|         with: | ||||
|           subject-name: ${{ steps.ev.outputs.imageNames }} | ||||
|           subject-name: ${{ steps.ev.outputs.attestImageNames }} | ||||
|           subject-digest: ${{ steps.push.outputs.digest }} | ||||
|           push-to-registry: true | ||||
|   build-outpost-binary: | ||||
|  | ||||
							
								
								
									
										2
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								Makefile
									
									
									
									
									
								
							| @ -205,7 +205,7 @@ gen: gen-build gen-client-ts | ||||
| web-build: web-install  ## Build the Authentik UI | ||||
| 	cd web && npm run build | ||||
|  | ||||
| web: web-lint-fix web-lint web-check-compile web-test  ## Automatically fix formatting issues in the Authentik UI source code, lint the code, and compile it | ||||
| web: web-lint-fix web-lint web-check-compile  ## Automatically fix formatting issues in the Authentik UI source code, lint the code, and compile it | ||||
|  | ||||
| web-install:  ## Install the necessary libraries to build the Authentik UI | ||||
| 	cd web && npm ci | ||||
|  | ||||
| @ -2,7 +2,7 @@ | ||||
|  | ||||
| from os import environ | ||||
|  | ||||
| __version__ = "2024.6.4" | ||||
| __version__ = "2024.8.5" | ||||
| ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -51,9 +51,11 @@ class BlueprintInstanceSerializer(ModelSerializer): | ||||
|         context = self.instance.context if self.instance else {} | ||||
|         valid, logs = Importer.from_string(content, context).validate() | ||||
|         if not valid: | ||||
|             text_logs = "\n".join([x["event"] for x in logs]) | ||||
|             raise ValidationError( | ||||
|                 _("Failed to validate blueprint: {logs}".format_map({"logs": text_logs})) | ||||
|                 [ | ||||
|                     _("Failed to validate blueprint"), | ||||
|                     *[f"- {x.event}" for x in logs], | ||||
|                 ] | ||||
|             ) | ||||
|         return content | ||||
|  | ||||
|  | ||||
| @ -78,5 +78,5 @@ class TestBlueprintsV1API(APITestCase): | ||||
|         self.assertEqual(res.status_code, 400) | ||||
|         self.assertJSONEqual( | ||||
|             res.content.decode(), | ||||
|             {"content": ["Failed to validate blueprint: Invalid blueprint version"]}, | ||||
|             {"content": ["Failed to validate blueprint", "- Invalid blueprint version"]}, | ||||
|         ) | ||||
|  | ||||
| @ -429,7 +429,7 @@ class Importer: | ||||
|         orig_import = deepcopy(self._import) | ||||
|         if self._import.version != 1: | ||||
|             self.logger.warning("Invalid blueprint version") | ||||
|             return False, [{"event": "Invalid blueprint version"}] | ||||
|             return False, [LogEvent("Invalid blueprint version", log_level="warning", logger=None)] | ||||
|         with ( | ||||
|             transaction_rollback(), | ||||
|             capture_logs() as logs, | ||||
|  | ||||
| @ -30,8 +30,10 @@ from authentik.core.api.utils import ( | ||||
|     PassiveSerializer, | ||||
| ) | ||||
| from authentik.core.expression.evaluator import PropertyMappingEvaluator | ||||
| from authentik.core.expression.exceptions import PropertyMappingExpressionException | ||||
| from authentik.core.models import Group, PropertyMapping, User | ||||
| from authentik.events.utils import sanitize_item | ||||
| from authentik.lib.utils.errors import exception_to_string | ||||
| from authentik.policies.api.exec import PolicyTestSerializer | ||||
| from authentik.rbac.decorators import permission_required | ||||
|  | ||||
| @ -162,12 +164,15 @@ class PropertyMappingViewSet( | ||||
|  | ||||
|         response_data = {"successful": True, "result": ""} | ||||
|         try: | ||||
|             result = mapping.evaluate(**context) | ||||
|             result = mapping.evaluate(dry_run=True, **context) | ||||
|             response_data["result"] = dumps( | ||||
|                 sanitize_item(result), indent=(4 if format_result else None) | ||||
|             ) | ||||
|         except PropertyMappingExpressionException as exc: | ||||
|             response_data["result"] = exception_to_string(exc.exc) | ||||
|             response_data["successful"] = False | ||||
|         except Exception as exc: | ||||
|             response_data["result"] = str(exc) | ||||
|             response_data["result"] = exception_to_string(exc) | ||||
|             response_data["successful"] = False | ||||
|         response = PropertyMappingTestResultSerializer(response_data) | ||||
|         return Response(response.data) | ||||
|  | ||||
| @ -678,10 +678,13 @@ class UserViewSet(UsedByMixin, ModelViewSet): | ||||
|         if not request.tenant.impersonation: | ||||
|             LOGGER.debug("User attempted to impersonate", user=request.user) | ||||
|             return Response(status=401) | ||||
|         if not request.user.has_perm("impersonate"): | ||||
|         user_to_be = self.get_object() | ||||
|         # Check both object-level perms and global perms | ||||
|         if not request.user.has_perm( | ||||
|             "authentik_core.impersonate", user_to_be | ||||
|         ) and not request.user.has_perm("authentik_core.impersonate"): | ||||
|             LOGGER.debug("User attempted to impersonate without permissions", user=request.user) | ||||
|             return Response(status=401) | ||||
|         user_to_be = self.get_object() | ||||
|         if user_to_be.pk == self.request.user.pk: | ||||
|             LOGGER.debug("User attempted to impersonate themselves", user=request.user) | ||||
|             return Response(status=401) | ||||
|  | ||||
| @ -9,10 +9,11 @@ class Command(TenantCommand): | ||||
|  | ||||
|     def add_arguments(self, parser): | ||||
|         parser.add_argument("--type", type=str, required=True) | ||||
|         parser.add_argument("--all", action="store_true") | ||||
|         parser.add_argument("usernames", nargs="+", type=str) | ||||
|         parser.add_argument("--all", action="store_true", default=False) | ||||
|         parser.add_argument("usernames", nargs="*", type=str) | ||||
|  | ||||
|     def handle_per_tenant(self, **options): | ||||
|         print(options) | ||||
|         new_type = UserTypes(options["type"]) | ||||
|         qs = ( | ||||
|             User.objects.exclude_anonymous() | ||||
| @ -22,6 +23,9 @@ class Command(TenantCommand): | ||||
|         if options["usernames"] and options["all"]: | ||||
|             self.stderr.write("--all and usernames specified, only one can be specified") | ||||
|             return | ||||
|         if not options["usernames"] and not options["all"]: | ||||
|             self.stderr.write("--all or usernames must be specified") | ||||
|             return | ||||
|         if options["usernames"] and not options["all"]: | ||||
|             qs = qs.filter(username__in=options["usernames"]) | ||||
|         updated = qs.update(type=new_type) | ||||
|  | ||||
| @ -466,8 +466,6 @@ class ApplicationQuerySet(QuerySet): | ||||
|     def with_provider(self) -> "QuerySet[Application]": | ||||
|         qs = self.select_related("provider") | ||||
|         for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider): | ||||
|             if LOOKUP_SEP in subclass: | ||||
|                 continue | ||||
|             qs = qs.select_related(f"provider__{subclass}") | ||||
|         return qs | ||||
|  | ||||
| @ -545,15 +543,24 @@ class Application(SerializerModel, PolicyBindingModel): | ||||
|         if not self.provider: | ||||
|             return None | ||||
|  | ||||
|         for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider): | ||||
|             # We don't care about recursion, skip nested models | ||||
|             if LOOKUP_SEP in subclass: | ||||
|                 continue | ||||
|         candidates = [] | ||||
|         base_class = Provider | ||||
|         for subclass in base_class.objects.get_queryset()._get_subclasses_recurse(base_class): | ||||
|             parent = self.provider | ||||
|             for level in subclass.split(LOOKUP_SEP): | ||||
|                 try: | ||||
|                 return getattr(self.provider, subclass) | ||||
|                     parent = getattr(parent, level) | ||||
|                 except AttributeError: | ||||
|                 pass | ||||
|                     break | ||||
|             if parent in candidates: | ||||
|                 continue | ||||
|             idx = subclass.count(LOOKUP_SEP) | ||||
|             if type(parent) is not base_class: | ||||
|                 idx += 1 | ||||
|             candidates.insert(idx, parent) | ||||
|         if not candidates: | ||||
|             return None | ||||
|         return candidates[-1] | ||||
|  | ||||
|     def __str__(self): | ||||
|         return str(self.name) | ||||
| @ -901,7 +908,7 @@ class PropertyMapping(SerializerModel, ManagedModel): | ||||
|         except ControlFlowException as exc: | ||||
|             raise exc | ||||
|         except Exception as exc: | ||||
|             raise PropertyMappingExpressionException(self, exc) from exc | ||||
|             raise PropertyMappingExpressionException(exc, self) from exc | ||||
|  | ||||
|     def __str__(self): | ||||
|         return f"Property Mapping {self.name}" | ||||
|  | ||||
| @ -9,9 +9,12 @@ from rest_framework.test import APITestCase | ||||
|  | ||||
| from authentik.core.models import Application | ||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_flow | ||||
| from authentik.lib.generators import generate_id | ||||
| from authentik.policies.dummy.models import DummyPolicy | ||||
| from authentik.policies.models import PolicyBinding | ||||
| from authentik.providers.oauth2.models import OAuth2Provider | ||||
| from authentik.providers.oauth2.models import OAuth2Provider, RedirectURI, RedirectURIMatchingMode | ||||
| from authentik.providers.proxy.models import ProxyProvider | ||||
| from authentik.providers.saml.models import SAMLProvider | ||||
|  | ||||
|  | ||||
| class TestApplicationsAPI(APITestCase): | ||||
| @ -21,7 +24,7 @@ class TestApplicationsAPI(APITestCase): | ||||
|         self.user = create_test_admin_user() | ||||
|         self.provider = OAuth2Provider.objects.create( | ||||
|             name="test", | ||||
|             redirect_uris="http://some-other-domain", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://some-other-domain")], | ||||
|             authorization_flow=create_test_flow(), | ||||
|         ) | ||||
|         self.allowed: Application = Application.objects.create( | ||||
| @ -222,3 +225,31 @@ class TestApplicationsAPI(APITestCase): | ||||
|                 ], | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def test_get_provider(self): | ||||
|         """Ensure that proxy providers (at the time of writing that is the only provider | ||||
|         that inherits from another proxy type (OAuth) instead of inheriting from the root | ||||
|         provider class) is correctly looked up and selected from the database""" | ||||
|         slug = generate_id() | ||||
|         provider = ProxyProvider.objects.create(name=generate_id()) | ||||
|         Application.objects.create( | ||||
|             name=generate_id(), | ||||
|             slug=slug, | ||||
|             provider=provider, | ||||
|         ) | ||||
|         self.assertEqual(Application.objects.get(slug=slug).get_provider(), provider) | ||||
|         self.assertEqual( | ||||
|             Application.objects.with_provider().get(slug=slug).get_provider(), provider | ||||
|         ) | ||||
|  | ||||
|         slug = generate_id() | ||||
|         provider = SAMLProvider.objects.create(name=generate_id()) | ||||
|         Application.objects.create( | ||||
|             name=generate_id(), | ||||
|             slug=slug, | ||||
|             provider=provider, | ||||
|         ) | ||||
|         self.assertEqual(Application.objects.get(slug=slug).get_provider(), provider) | ||||
|         self.assertEqual( | ||||
|             Application.objects.with_provider().get(slug=slug).get_provider(), provider | ||||
|         ) | ||||
|  | ||||
| @ -3,10 +3,10 @@ | ||||
| from json import loads | ||||
|  | ||||
| from django.urls import reverse | ||||
| from guardian.shortcuts import assign_perm | ||||
| from rest_framework.test import APITestCase | ||||
|  | ||||
| from authentik.core.models import User | ||||
| from authentik.core.tests.utils import create_test_admin_user | ||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_user | ||||
| from authentik.tenants.utils import get_current_tenant | ||||
|  | ||||
|  | ||||
| @ -15,7 +15,7 @@ class TestImpersonation(APITestCase): | ||||
|  | ||||
|     def setUp(self) -> None: | ||||
|         super().setUp() | ||||
|         self.other_user = User.objects.create(username="to-impersonate") | ||||
|         self.other_user = create_test_user() | ||||
|         self.user = create_test_admin_user() | ||||
|  | ||||
|     def test_impersonate_simple(self): | ||||
| @ -44,6 +44,46 @@ class TestImpersonation(APITestCase): | ||||
|         self.assertEqual(response_body["user"]["username"], self.user.username) | ||||
|         self.assertNotIn("original", response_body) | ||||
|  | ||||
|     def test_impersonate_global(self): | ||||
|         """Test impersonation with global permissions""" | ||||
|         new_user = create_test_user() | ||||
|         assign_perm("authentik_core.impersonate", new_user) | ||||
|         assign_perm("authentik_core.view_user", new_user) | ||||
|         self.client.force_login(new_user) | ||||
|  | ||||
|         response = self.client.post( | ||||
|             reverse( | ||||
|                 "authentik_api:user-impersonate", | ||||
|                 kwargs={"pk": self.other_user.pk}, | ||||
|             ) | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 201) | ||||
|  | ||||
|         response = self.client.get(reverse("authentik_api:user-me")) | ||||
|         response_body = loads(response.content.decode()) | ||||
|         self.assertEqual(response_body["user"]["username"], self.other_user.username) | ||||
|         self.assertEqual(response_body["original"]["username"], new_user.username) | ||||
|  | ||||
|     def test_impersonate_scoped(self): | ||||
|         """Test impersonation with scoped permissions""" | ||||
|         new_user = create_test_user() | ||||
|         assign_perm("authentik_core.impersonate", new_user, self.other_user) | ||||
|         assign_perm("authentik_core.view_user", new_user, self.other_user) | ||||
|         self.client.force_login(new_user) | ||||
|  | ||||
|         response = self.client.post( | ||||
|             reverse( | ||||
|                 "authentik_api:user-impersonate", | ||||
|                 kwargs={"pk": self.other_user.pk}, | ||||
|             ) | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 201) | ||||
|  | ||||
|         response = self.client.get(reverse("authentik_api:user-me")) | ||||
|         response_body = loads(response.content.decode()) | ||||
|         self.assertEqual(response_body["user"]["username"], self.other_user.username) | ||||
|         self.assertEqual(response_body["original"]["username"], new_user.username) | ||||
|  | ||||
|     def test_impersonate_denied(self): | ||||
|         """test impersonation without permissions""" | ||||
|         self.client.force_login(self.other_user) | ||||
|  | ||||
| @ -31,6 +31,7 @@ class TestTransactionalApplicationsAPI(APITestCase): | ||||
|                 "provider": { | ||||
|                     "name": uid, | ||||
|                     "authorization_flow": str(authorization_flow.pk), | ||||
|                     "redirect_uris": [], | ||||
|                 }, | ||||
|             }, | ||||
|         ) | ||||
| @ -56,6 +57,7 @@ class TestTransactionalApplicationsAPI(APITestCase): | ||||
|                 "provider": { | ||||
|                     "name": uid, | ||||
|                     "authorization_flow": "", | ||||
|                     "redirect_uris": [], | ||||
|                 }, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
| @ -18,7 +18,7 @@ from authentik.crypto.models import CertificateKeyPair | ||||
| from authentik.crypto.tasks import MANAGED_DISCOVERED, certificate_discovery | ||||
| from authentik.lib.config import CONFIG | ||||
| from authentik.lib.generators import generate_id, generate_key | ||||
| from authentik.providers.oauth2.models import OAuth2Provider | ||||
| from authentik.providers.oauth2.models import OAuth2Provider, RedirectURI, RedirectURIMatchingMode | ||||
|  | ||||
|  | ||||
| class TestCrypto(APITestCase): | ||||
| @ -263,7 +263,7 @@ class TestCrypto(APITestCase): | ||||
|             client_id="test", | ||||
|             client_secret=generate_key(), | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://localhost", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")], | ||||
|             signing_key=keypair, | ||||
|         ) | ||||
|         response = self.client.get( | ||||
| @ -295,7 +295,7 @@ class TestCrypto(APITestCase): | ||||
|             client_id="test", | ||||
|             client_secret=generate_key(), | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://localhost", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")], | ||||
|             signing_key=keypair, | ||||
|         ) | ||||
|         response = self.client.get( | ||||
|  | ||||
| @ -18,7 +18,7 @@ from authentik.core.api.used_by import UsedByMixin | ||||
| from authentik.core.api.utils import ModelSerializer, PassiveSerializer | ||||
| from authentik.core.models import User, UserTypes | ||||
| from authentik.enterprise.license import LicenseKey, LicenseSummarySerializer | ||||
| from authentik.enterprise.models import License, LicenseUsageStatus | ||||
| from authentik.enterprise.models import License | ||||
| from authentik.rbac.decorators import permission_required | ||||
| from authentik.tenants.utils import get_unique_identifier | ||||
|  | ||||
| @ -29,7 +29,7 @@ class EnterpriseRequiredMixin: | ||||
|  | ||||
|     def validate(self, attrs: dict) -> dict: | ||||
|         """Check that a valid license exists""" | ||||
|         if LicenseKey.cached_summary().status != LicenseUsageStatus.UNLICENSED: | ||||
|         if not LicenseKey.cached_summary().status.is_valid: | ||||
|             raise ValidationError(_("Enterprise is required to create/update this object.")) | ||||
|         return super().validate(attrs) | ||||
|  | ||||
|  | ||||
| @ -25,4 +25,4 @@ class AuthentikEnterpriseConfig(EnterpriseConfig): | ||||
|         """Actual enterprise check, cached""" | ||||
|         from authentik.enterprise.license import LicenseKey | ||||
|  | ||||
|         return LicenseKey.cached_summary().status | ||||
|         return LicenseKey.cached_summary().status.is_valid | ||||
|  | ||||
| @ -117,10 +117,13 @@ class LicenseKey: | ||||
|                     our_cert.public_key(), | ||||
|                     algorithms=["ES512"], | ||||
|                     audience=get_license_aud(), | ||||
|                     options={"verify_exp": check_expiry}, | ||||
|                     options={"verify_exp": check_expiry, "verify_signature": check_expiry}, | ||||
|                 ), | ||||
|             ) | ||||
|         except PyJWTError: | ||||
|             unverified = decode(jwt, options={"verify_signature": False}) | ||||
|             if unverified["aud"] != get_license_aud(): | ||||
|                 raise ValidationError("Invalid Install ID in license") from None | ||||
|             raise ValidationError("Unable to verify license") from None | ||||
|         return body | ||||
|  | ||||
| @ -134,7 +137,7 @@ class LicenseKey: | ||||
|             exp_ts = int(mktime(lic.expiry.timetuple())) | ||||
|             if total.exp == 0: | ||||
|                 total.exp = exp_ts | ||||
|             total.exp = min(total.exp, exp_ts) | ||||
|             total.exp = max(total.exp, exp_ts) | ||||
|             total.license_flags.extend(lic.status.license_flags) | ||||
|         return total | ||||
|  | ||||
|  | ||||
| @ -3,7 +3,7 @@ | ||||
| from datetime import datetime | ||||
|  | ||||
| from django.core.cache import cache | ||||
| from django.db.models.signals import post_save, pre_save | ||||
| from django.db.models.signals import post_delete, post_save, pre_save | ||||
| from django.dispatch import receiver | ||||
| from django.utils.timezone import get_current_timezone | ||||
|  | ||||
| @ -27,3 +27,9 @@ def post_save_license(sender: type[License], instance: License, **_): | ||||
|     """Trigger license usage calculation when license is saved""" | ||||
|     cache.delete(CACHE_KEY_ENTERPRISE_LICENSE) | ||||
|     enterprise_update_usage.delay() | ||||
|  | ||||
|  | ||||
| @receiver(post_delete, sender=License) | ||||
| def post_delete_license(sender: type[License], instance: License, **_): | ||||
|     """Clear license cache when license is deleted""" | ||||
|     cache.delete(CACHE_KEY_ENTERPRISE_LICENSE) | ||||
|  | ||||
| @ -69,8 +69,5 @@ class NotificationViewSet( | ||||
|     @action(detail=False, methods=["post"]) | ||||
|     def mark_all_seen(self, request: Request) -> Response: | ||||
|         """Mark all the user's notifications as seen""" | ||||
|         notifications = Notification.objects.filter(user=request.user) | ||||
|         for notification in notifications: | ||||
|             notification.seen = True | ||||
|         Notification.objects.bulk_update(notifications, ["seen"]) | ||||
|         Notification.objects.filter(user=request.user, seen=False).update(seen=True) | ||||
|         return Response({}, status=204) | ||||
|  | ||||
| @ -49,6 +49,7 @@ from authentik.policies.models import PolicyBindingModel | ||||
| from authentik.root.middleware import ClientIPMiddleware | ||||
| from authentik.stages.email.utils import TemplateEmailMessage | ||||
| from authentik.tenants.models import Tenant | ||||
| from authentik.tenants.utils import get_current_tenant | ||||
|  | ||||
| LOGGER = get_logger() | ||||
| DISCORD_FIELD_LIMIT = 25 | ||||
| @ -58,6 +59,10 @@ NOTIFICATION_SUMMARY_LENGTH = 75 | ||||
| def default_event_duration(): | ||||
|     """Default duration an Event is saved. | ||||
|     This is used as a fallback when no brand is available""" | ||||
|     try: | ||||
|         tenant = get_current_tenant() | ||||
|         return now() + timedelta_from_string(tenant.event_retention) | ||||
|     except Tenant.DoesNotExist: | ||||
|         return now() + timedelta(days=365) | ||||
|  | ||||
|  | ||||
| @ -245,12 +250,6 @@ class Event(SerializerModel, ExpiringModel): | ||||
|             if QS_QUERY in self.context["http_request"]["args"]: | ||||
|                 wrapped = self.context["http_request"]["args"][QS_QUERY] | ||||
|                 self.context["http_request"]["args"] = cleanse_dict(QueryDict(wrapped)) | ||||
|         if hasattr(request, "tenant"): | ||||
|             tenant: Tenant = request.tenant | ||||
|             # Because self.created only gets set on save, we can't use it's value here | ||||
|             # hence we set self.created to now and then use it | ||||
|             self.created = now() | ||||
|             self.expires = self.created + timedelta_from_string(tenant.event_retention) | ||||
|         if hasattr(request, "brand"): | ||||
|             brand: Brand = request.brand | ||||
|             self.brand = sanitize_dict(model_to_dict(brand)) | ||||
|  | ||||
| @ -1,13 +1,16 @@ | ||||
| """authentik events signal listener""" | ||||
|  | ||||
| from importlib import import_module | ||||
| from typing import Any | ||||
|  | ||||
| from django.conf import settings | ||||
| from django.contrib.auth.signals import user_logged_in, user_logged_out | ||||
| from django.db.models.signals import post_save, pre_delete | ||||
| from django.dispatch import receiver | ||||
| from django.http import HttpRequest | ||||
| from rest_framework.request import Request | ||||
|  | ||||
| from authentik.core.models import User | ||||
| from authentik.core.models import AuthenticatedSession, User | ||||
| from authentik.core.signals import login_failed, password_changed | ||||
| from authentik.events.apps import SYSTEM_TASK_STATUS | ||||
| from authentik.events.models import Event, EventAction, SystemTask | ||||
| @ -23,6 +26,7 @@ from authentik.stages.user_write.signals import user_write | ||||
| from authentik.tenants.utils import get_current_tenant | ||||
|  | ||||
| SESSION_LOGIN_EVENT = "login_event" | ||||
| _session_engine = import_module(settings.SESSION_ENGINE) | ||||
|  | ||||
|  | ||||
| @receiver(user_logged_in) | ||||
| @ -40,11 +44,20 @@ def on_user_logged_in(sender, request: HttpRequest, user: User, **_): | ||||
|             kwargs[PLAN_CONTEXT_METHOD_ARGS] = flow_plan.context.get(PLAN_CONTEXT_METHOD_ARGS, {}) | ||||
|     event = Event.new(EventAction.LOGIN, **kwargs).from_http(request, user=user) | ||||
|     request.session[SESSION_LOGIN_EVENT] = event | ||||
|     request.session.save() | ||||
|  | ||||
|  | ||||
| def get_login_event(request: HttpRequest) -> Event | None: | ||||
| def get_login_event(request_or_session: HttpRequest | AuthenticatedSession | None) -> Event | None: | ||||
|     """Wrapper to get login event that can be mocked in tests""" | ||||
|     return request.session.get(SESSION_LOGIN_EVENT, None) | ||||
|     session = None | ||||
|     if not request_or_session: | ||||
|         return None | ||||
|     if isinstance(request_or_session, HttpRequest | Request): | ||||
|         session = request_or_session.session | ||||
|     if isinstance(request_or_session, AuthenticatedSession): | ||||
|         SessionStore = _session_engine.SessionStore | ||||
|         session = SessionStore(request_or_session.session_key) | ||||
|     return session.get(SESSION_LOGIN_EVENT, None) | ||||
|  | ||||
|  | ||||
| @receiver(user_logged_out) | ||||
|  | ||||
| @ -6,6 +6,7 @@ from django.db.models import Model | ||||
| from django.test import TestCase | ||||
|  | ||||
| from authentik.core.models import default_token_key | ||||
| from authentik.events.models import default_event_duration | ||||
| from authentik.lib.utils.reflection import get_apps | ||||
|  | ||||
|  | ||||
| @ -20,7 +21,7 @@ def model_tester_factory(test_model: type[Model]) -> Callable: | ||||
|         allowed = 0 | ||||
|         # Token-like objects need to lookup the current tenant to get the default token length | ||||
|         for field in test_model._meta.fields: | ||||
|             if field.default == default_token_key: | ||||
|             if field.default in [default_token_key, default_event_duration]: | ||||
|                 allowed += 1 | ||||
|         with self.assertNumQueries(allowed): | ||||
|             str(test_model()) | ||||
|  | ||||
| @ -2,7 +2,8 @@ | ||||
|  | ||||
| from unittest.mock import MagicMock, patch | ||||
|  | ||||
| from django.test import TestCase | ||||
| from django.urls import reverse | ||||
| from rest_framework.test import APITestCase | ||||
|  | ||||
| from authentik.core.models import Group, User | ||||
| from authentik.events.models import ( | ||||
| @ -10,6 +11,7 @@ from authentik.events.models import ( | ||||
|     EventAction, | ||||
|     Notification, | ||||
|     NotificationRule, | ||||
|     NotificationSeverity, | ||||
|     NotificationTransport, | ||||
|     NotificationWebhookMapping, | ||||
|     TransportMode, | ||||
| @ -20,7 +22,7 @@ from authentik.policies.exceptions import PolicyException | ||||
| from authentik.policies.models import PolicyBinding | ||||
|  | ||||
|  | ||||
| class TestEventsNotifications(TestCase): | ||||
| class TestEventsNotifications(APITestCase): | ||||
|     """Test Event Notifications""" | ||||
|  | ||||
|     def setUp(self) -> None: | ||||
| @ -131,3 +133,15 @@ class TestEventsNotifications(TestCase): | ||||
|         Notification.objects.all().delete() | ||||
|         Event.new(EventAction.CUSTOM_PREFIX).save() | ||||
|         self.assertEqual(Notification.objects.first().body, "foo") | ||||
|  | ||||
|     def test_api_mark_all_seen(self): | ||||
|         """Test mark_all_seen""" | ||||
|         self.client.force_login(self.user) | ||||
|  | ||||
|         Notification.objects.create( | ||||
|             severity=NotificationSeverity.NOTICE, body="foo", user=self.user, seen=False | ||||
|         ) | ||||
|  | ||||
|         response = self.client.post(reverse("authentik_api:notification-mark-all-seen")) | ||||
|         self.assertEqual(response.status_code, 204) | ||||
|         self.assertFalse(Notification.objects.filter(body="foo", seen=False).exists()) | ||||
|  | ||||
| @ -2,7 +2,6 @@ | ||||
|  | ||||
| import re | ||||
| import socket | ||||
| from collections.abc import Iterable | ||||
| from ipaddress import ip_address, ip_network | ||||
| from textwrap import indent | ||||
| from types import CodeType | ||||
| @ -28,6 +27,12 @@ from authentik.stages.authenticator import devices_for_user | ||||
|  | ||||
| LOGGER = get_logger() | ||||
|  | ||||
| ARG_SANITIZE = re.compile(r"[:.-]") | ||||
|  | ||||
|  | ||||
| def sanitize_arg(arg_name: str) -> str: | ||||
|     return re.sub(ARG_SANITIZE, "_", arg_name) | ||||
|  | ||||
|  | ||||
| class BaseEvaluator: | ||||
|     """Validate and evaluate python-based expressions""" | ||||
| @ -177,9 +182,9 @@ class BaseEvaluator: | ||||
|         proc = PolicyProcess(PolicyBinding(policy=policy), request=req, connection=None) | ||||
|         return proc.profiling_wrapper() | ||||
|  | ||||
|     def wrap_expression(self, expression: str, params: Iterable[str]) -> str: | ||||
|     def wrap_expression(self, expression: str) -> str: | ||||
|         """Wrap expression in a function, call it, and save the result as `result`""" | ||||
|         handler_signature = ",".join(params) | ||||
|         handler_signature = ",".join(sanitize_arg(x) for x in self._context.keys()) | ||||
|         full_expression = "" | ||||
|         full_expression += f"def handler({handler_signature}):\n" | ||||
|         full_expression += indent(expression, "    ") | ||||
| @ -188,8 +193,8 @@ class BaseEvaluator: | ||||
|  | ||||
|     def compile(self, expression: str) -> CodeType: | ||||
|         """Parse expression. Raises SyntaxError or ValueError if the syntax is incorrect.""" | ||||
|         param_keys = self._context.keys() | ||||
|         return compile(self.wrap_expression(expression, param_keys), self._filename, "exec") | ||||
|         expression = self.wrap_expression(expression) | ||||
|         return compile(expression, self._filename, "exec") | ||||
|  | ||||
|     def evaluate(self, expression_source: str) -> Any: | ||||
|         """Parse and evaluate expression. If the syntax is incorrect, a SyntaxError is raised. | ||||
| @ -205,7 +210,7 @@ class BaseEvaluator: | ||||
|                 self.handle_error(exc, expression_source) | ||||
|                 raise exc | ||||
|             try: | ||||
|                 _locals = self._context | ||||
|                 _locals = {sanitize_arg(x): y for x, y in self._context.items()} | ||||
|                 # Yes this is an exec, yes it is potentially bad. Since we limit what variables are | ||||
|                 # available here, and these policies can only be edited by admins, this is a risk | ||||
|                 # we're willing to take. | ||||
|  | ||||
| @ -30,6 +30,11 @@ class TestHTTP(TestCase): | ||||
|         request = self.factory.get("/", HTTP_X_FORWARDED_FOR="127.0.0.2") | ||||
|         self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.2") | ||||
|  | ||||
|     def test_forward_for_invalid(self): | ||||
|         """Test invalid forward for""" | ||||
|         request = self.factory.get("/", HTTP_X_FORWARDED_FOR="foobar") | ||||
|         self.assertEqual(ClientIPMiddleware.get_client_ip(request), ClientIPMiddleware.default_ip) | ||||
|  | ||||
|     def test_fake_outpost(self): | ||||
|         """Test faked IP which is overridden by an outpost""" | ||||
|         token = Token.objects.create( | ||||
| @ -53,6 +58,17 @@ class TestHTTP(TestCase): | ||||
|             }, | ||||
|         ) | ||||
|         self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.1") | ||||
|         # Invalid, not a real IP | ||||
|         self.user.type = UserTypes.INTERNAL_SERVICE_ACCOUNT | ||||
|         self.user.save() | ||||
|         request = self.factory.get( | ||||
|             "/", | ||||
|             **{ | ||||
|                 ClientIPMiddleware.outpost_remote_ip_header: "foobar", | ||||
|                 ClientIPMiddleware.outpost_token_header: token.key, | ||||
|             }, | ||||
|         ) | ||||
|         self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.1") | ||||
|         # Valid | ||||
|         self.user.type = UserTypes.INTERNAL_SERVICE_ACCOUNT | ||||
|         self.user.save() | ||||
|  | ||||
| @ -21,7 +21,14 @@ class DebugSession(Session): | ||||
|  | ||||
|     def send(self, req: PreparedRequest, *args, **kwargs): | ||||
|         request_id = str(uuid4()) | ||||
|         LOGGER.debug("HTTP request sent", uid=request_id, path=req.path_url, headers=req.headers) | ||||
|         LOGGER.debug( | ||||
|             "HTTP request sent", | ||||
|             uid=request_id, | ||||
|             url=req.url, | ||||
|             method=req.method, | ||||
|             headers=req.headers, | ||||
|             body=req.body, | ||||
|         ) | ||||
|         resp = super().send(req, *args, **kwargs) | ||||
|         LOGGER.debug( | ||||
|             "HTTP response received", | ||||
|  | ||||
| @ -108,7 +108,7 @@ class EventMatcherPolicy(Policy): | ||||
|                 result=result, | ||||
|             ) | ||||
|             matches.append(result) | ||||
|         passing = any(x.passing for x in matches) | ||||
|         passing = all(x.passing for x in matches) | ||||
|         messages = chain(*[x.messages for x in matches]) | ||||
|         result = PolicyResult(passing, *messages) | ||||
|         result.source_results = matches | ||||
|  | ||||
| @ -77,11 +77,24 @@ class TestEventMatcherPolicy(TestCase): | ||||
|         request = PolicyRequest(get_anonymous_user()) | ||||
|         request.context["event"] = event | ||||
|         policy: EventMatcherPolicy = EventMatcherPolicy.objects.create( | ||||
|             client_ip="1.2.3.5", app="bar" | ||||
|             client_ip="1.2.3.5", app="foo" | ||||
|         ) | ||||
|         response = policy.passes(request) | ||||
|         self.assertFalse(response.passing) | ||||
|  | ||||
|     def test_multiple(self): | ||||
|         """Test multiple""" | ||||
|         event = Event.new(EventAction.LOGIN) | ||||
|         event.app = "foo" | ||||
|         event.client_ip = "1.2.3.4" | ||||
|         request = PolicyRequest(get_anonymous_user()) | ||||
|         request.context["event"] = event | ||||
|         policy: EventMatcherPolicy = EventMatcherPolicy.objects.create( | ||||
|             client_ip="1.2.3.4", app="foo" | ||||
|         ) | ||||
|         response = policy.passes(request) | ||||
|         self.assertTrue(response.passing) | ||||
|  | ||||
|     def test_invalid(self): | ||||
|         """Test passing event""" | ||||
|         request = PolicyRequest(get_anonymous_user()) | ||||
|  | ||||
| @ -4,13 +4,13 @@ from django.apps.registry import Apps | ||||
| from django.db.backends.base.schema import BaseDatabaseSchemaEditor | ||||
|  | ||||
| from django.db import migrations | ||||
| from django.contrib.auth.management import create_permissions | ||||
|  | ||||
|  | ||||
| def migrate_search_group(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | ||||
|     from guardian.shortcuts import assign_perm | ||||
|     from authentik.core.models import User | ||||
|     from django.apps import apps as real_apps | ||||
|     from django.contrib.auth.management import create_permissions | ||||
|     from guardian.shortcuts import UserObjectPermission | ||||
|  | ||||
|     db_alias = schema_editor.connection.alias | ||||
|  | ||||
| @ -20,14 +20,25 @@ def migrate_search_group(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | ||||
|     create_permissions(real_apps.get_app_config("authentik_providers_ldap"), using=db_alias) | ||||
|  | ||||
|     LDAPProvider = apps.get_model("authentik_providers_ldap", "ldapprovider") | ||||
|     Permission = apps.get_model("auth", "Permission") | ||||
|     UserObjectPermission = apps.get_model("guardian", "UserObjectPermission") | ||||
|     ContentType = apps.get_model("contenttypes", "ContentType") | ||||
|  | ||||
|     new_prem = Permission.objects.using(db_alias).get(codename="search_full_directory") | ||||
|     ct = ContentType.objects.using(db_alias).get( | ||||
|         app_label="authentik_providers_ldap", | ||||
|         model="ldapprovider", | ||||
|     ) | ||||
|  | ||||
|     for provider in LDAPProvider.objects.using(db_alias).all(): | ||||
|         for user_pk in ( | ||||
|             provider.search_group.users.using(db_alias).all().values_list("pk", flat=True) | ||||
|         ): | ||||
|             # We need the correct user model instance to assign the permission | ||||
|             assign_perm( | ||||
|                 "search_full_directory", User.objects.using(db_alias).get(pk=user_pk), provider | ||||
|         if not provider.search_group: | ||||
|             continue | ||||
|         for user in provider.search_group.users.using(db_alias).all(): | ||||
|             UserObjectPermission.objects.using(db_alias).create( | ||||
|                 user=user, | ||||
|                 permission=new_prem, | ||||
|                 object_pk=provider.pk, | ||||
|                 content_type=ct, | ||||
|             ) | ||||
|  | ||||
|  | ||||
| @ -35,6 +46,7 @@ class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("authentik_providers_ldap", "0003_ldapprovider_mfa_support_and_more"), | ||||
|         ("guardian", "0002_generic_permissions_index"), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|  | ||||
| @ -1,15 +1,18 @@ | ||||
| """OAuth2Provider API Views""" | ||||
|  | ||||
| from copy import copy | ||||
| from re import compile | ||||
| from re import error as RegexError | ||||
|  | ||||
| from django.urls import reverse | ||||
| from django.utils import timezone | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
| from drf_spectacular.types import OpenApiTypes | ||||
| from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema | ||||
| from guardian.shortcuts import get_objects_for_user | ||||
| from rest_framework.decorators import action | ||||
| from rest_framework.exceptions import ValidationError | ||||
| from rest_framework.fields import CharField | ||||
| from rest_framework.fields import CharField, ChoiceField | ||||
| from rest_framework.generics import get_object_or_404 | ||||
| from rest_framework.request import Request | ||||
| from rest_framework.response import Response | ||||
| @ -20,13 +23,39 @@ from authentik.core.api.used_by import UsedByMixin | ||||
| from authentik.core.api.utils import PassiveSerializer, PropertyMappingPreviewSerializer | ||||
| from authentik.core.models import Provider | ||||
| from authentik.providers.oauth2.id_token import IDToken | ||||
| from authentik.providers.oauth2.models import AccessToken, OAuth2Provider, ScopeMapping | ||||
| from authentik.providers.oauth2.models import ( | ||||
|     AccessToken, | ||||
|     OAuth2Provider, | ||||
|     RedirectURIMatchingMode, | ||||
|     ScopeMapping, | ||||
| ) | ||||
| from authentik.rbac.decorators import permission_required | ||||
|  | ||||
|  | ||||
| class RedirectURISerializer(PassiveSerializer): | ||||
|     """A single allowed redirect URI entry""" | ||||
|  | ||||
|     matching_mode = ChoiceField(choices=RedirectURIMatchingMode.choices) | ||||
|     url = CharField() | ||||
|  | ||||
|  | ||||
| class OAuth2ProviderSerializer(ProviderSerializer): | ||||
|     """OAuth2Provider Serializer""" | ||||
|  | ||||
|     redirect_uris = RedirectURISerializer(many=True, source="_redirect_uris") | ||||
|  | ||||
|     def validate_redirect_uris(self, data: list) -> list: | ||||
|         for entry in data: | ||||
|             if entry.get("matching_mode") == RedirectURIMatchingMode.REGEX: | ||||
|                 url = entry.get("url") | ||||
|                 try: | ||||
|                     compile(url) | ||||
|                 except RegexError: | ||||
|                     raise ValidationError( | ||||
|                         _("Invalid Regex Pattern: {url}".format(url=url)) | ||||
|                     ) from None | ||||
|         return data | ||||
|  | ||||
|     class Meta: | ||||
|         model = OAuth2Provider | ||||
|         fields = ProviderSerializer.Meta.fields + [ | ||||
| @ -78,7 +107,6 @@ class OAuth2ProviderViewSet(UsedByMixin, ModelViewSet): | ||||
|         "refresh_token_validity", | ||||
|         "include_claims_in_id_token", | ||||
|         "signing_key", | ||||
|         "redirect_uris", | ||||
|         "sub_mode", | ||||
|         "property_mappings", | ||||
|         "issuer_mode", | ||||
|  | ||||
| @ -7,7 +7,7 @@ from django.http import HttpRequest, HttpResponse, HttpResponseRedirect | ||||
| from authentik.events.models import Event, EventAction | ||||
| from authentik.lib.sentry import SentryIgnoredException | ||||
| from authentik.lib.views import bad_request_message | ||||
| from authentik.providers.oauth2.models import GrantTypes | ||||
| from authentik.providers.oauth2.models import GrantTypes, RedirectURI | ||||
|  | ||||
|  | ||||
| class OAuth2Error(SentryIgnoredException): | ||||
| @ -46,9 +46,9 @@ class RedirectUriError(OAuth2Error): | ||||
|     ) | ||||
|  | ||||
|     provided_uri: str | ||||
|     allowed_uris: list[str] | ||||
|     allowed_uris: list[RedirectURI] | ||||
|  | ||||
|     def __init__(self, provided_uri: str, allowed_uris: list[str]) -> None: | ||||
|     def __init__(self, provided_uri: str, allowed_uris: list[RedirectURI]) -> None: | ||||
|         super().__init__() | ||||
|         self.provided_uri = provided_uri | ||||
|         self.allowed_uris = allowed_uris | ||||
|  | ||||
| @ -1,6 +1,7 @@ | ||||
| """id_token utils""" | ||||
|  | ||||
| from dataclasses import asdict, dataclass, field | ||||
| from hashlib import sha256 | ||||
| from typing import TYPE_CHECKING, Any | ||||
|  | ||||
| from django.db import models | ||||
| @ -23,8 +24,13 @@ if TYPE_CHECKING: | ||||
|     from authentik.providers.oauth2.models import BaseGrantModel, OAuth2Provider | ||||
|  | ||||
|  | ||||
| def hash_session_key(session_key: str) -> str: | ||||
|     """Hash the session key for inclusion in JWTs as `sid`""" | ||||
|     return sha256(session_key.encode("ascii")).hexdigest() | ||||
|  | ||||
|  | ||||
| class SubModes(models.TextChoices): | ||||
|     """Mode after which 'sub' attribute is generateed, for compatibility reasons""" | ||||
|     """Mode after which 'sub' attribute is generated, for compatibility reasons""" | ||||
|  | ||||
|     HASHED_USER_ID = "hashed_user_id", _("Based on the Hashed User ID") | ||||
|     USER_ID = "user_id", _("Based on user ID") | ||||
| @ -51,7 +57,8 @@ class IDToken: | ||||
|     and potentially other requested Claims. The ID Token is represented as a | ||||
|     JSON Web Token (JWT) [JWT]. | ||||
|  | ||||
|     https://openid.net/specs/openid-connect-core-1_0.html#IDToken""" | ||||
|     https://openid.net/specs/openid-connect-core-1_0.html#IDToken | ||||
|     https://www.iana.org/assignments/jwt/jwt.xhtml""" | ||||
|  | ||||
|     # Issuer, https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.1 | ||||
|     iss: str | None = None | ||||
| @ -79,6 +86,8 @@ class IDToken: | ||||
|     nonce: str | None = None | ||||
|     # Access Token hash value, http://openid.net/specs/openid-connect-core-1_0.html | ||||
|     at_hash: str | None = None | ||||
|     # Session ID, https://openid.net/specs/openid-connect-frontchannel-1_0.html#ClaimsContents | ||||
|     sid: str | None = None | ||||
|  | ||||
|     claims: dict[str, Any] = field(default_factory=dict) | ||||
|  | ||||
| @ -116,9 +125,11 @@ class IDToken: | ||||
|         now = timezone.now() | ||||
|         id_token.iat = int(now.timestamp()) | ||||
|         id_token.auth_time = int(token.auth_time.timestamp()) | ||||
|         if token.session: | ||||
|             id_token.sid = hash_session_key(token.session.session_key) | ||||
|  | ||||
|         # We use the timestamp of the user's last successful login (EventAction.LOGIN) for auth_time | ||||
|         auth_event = get_login_event(request) | ||||
|         auth_event = get_login_event(token.session) | ||||
|         if auth_event: | ||||
|             # Also check which method was used for authentication | ||||
|             method = auth_event.context.get(PLAN_CONTEXT_METHOD, "") | ||||
|  | ||||
| @ -3,6 +3,7 @@ | ||||
| import django.db.models.deletion | ||||
| from django.apps.registry import Apps | ||||
| from django.db import migrations, models | ||||
| from django.db.backends.base.schema import BaseDatabaseSchemaEditor | ||||
|  | ||||
| import authentik.lib.utils.time | ||||
|  | ||||
| @ -14,7 +15,7 @@ scope_uid_map = { | ||||
| } | ||||
|  | ||||
|  | ||||
| def set_managed_flag(apps: Apps, schema_editor): | ||||
| def set_managed_flag(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | ||||
|     ScopeMapping = apps.get_model("authentik_providers_oauth2", "ScopeMapping") | ||||
|     db_alias = schema_editor.connection.alias | ||||
|     for mapping in ScopeMapping.objects.using(db_alias).filter(name__startswith="Autogenerated "): | ||||
|  | ||||
| @ -0,0 +1,26 @@ | ||||
| # Generated by Django 5.0.9 on 2024-09-26 16:25 | ||||
|  | ||||
| from django.conf import settings | ||||
| from django.db import migrations, models | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("authentik_providers_oauth2", "0018_alter_accesstoken_expires_and_more"), | ||||
|         migrations.swappable_dependency(settings.AUTH_USER_MODEL), | ||||
|     ] | ||||
|  | ||||
|     # Original preserved | ||||
|     # See https://github.com/goauthentik/authentik/issues/11874 | ||||
|     # operations = [ | ||||
|     #     migrations.AddIndex( | ||||
|     #         model_name="accesstoken", | ||||
|     #         index=models.Index(fields=["token"], name="authentik_p_token_4bc870_idx"), | ||||
|     #     ), | ||||
|     #     migrations.AddIndex( | ||||
|     #         model_name="refreshtoken", | ||||
|     #         index=models.Index(fields=["token"], name="authentik_p_token_1a841f_idx"), | ||||
|     #     ), | ||||
|     # ] | ||||
|     operations = [] | ||||
| @ -0,0 +1,34 @@ | ||||
| # Generated by Django 5.0.9 on 2024-09-27 14:50 | ||||
|  | ||||
| from django.conf import settings | ||||
| from django.db import migrations, models | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("authentik_providers_oauth2", "0019_accesstoken_authentik_p_token_4bc870_idx_and_more"), | ||||
|         migrations.swappable_dependency(settings.AUTH_USER_MODEL), | ||||
|     ] | ||||
|  | ||||
|     # Original preserved | ||||
|     # See https://github.com/goauthentik/authentik/issues/11874 | ||||
|     # operations = [ | ||||
|     #     migrations.RemoveIndex( | ||||
|     #         model_name="accesstoken", | ||||
|     #         name="authentik_p_token_4bc870_idx", | ||||
|     #     ), | ||||
|     #     migrations.RemoveIndex( | ||||
|     #         model_name="refreshtoken", | ||||
|     #         name="authentik_p_token_1a841f_idx", | ||||
|     #     ), | ||||
|     #     migrations.AddIndex( | ||||
|     #         model_name="accesstoken", | ||||
|     #         index=models.Index(fields=["token", "provider"], name="authentik_p_token_f99422_idx"), | ||||
|     #     ), | ||||
|     #     migrations.AddIndex( | ||||
|     #         model_name="refreshtoken", | ||||
|     #         index=models.Index(fields=["token", "provider"], name="authentik_p_token_a1d921_idx"), | ||||
|     #     ), | ||||
|     # ] | ||||
|     operations = [] | ||||
| @ -0,0 +1,42 @@ | ||||
| # Generated by Django 5.0.9 on 2024-10-16 14:53 | ||||
|  | ||||
| import django.db.models.deletion | ||||
| from django.db import migrations, models | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("authentik_crypto", "0004_alter_certificatekeypair_name"), | ||||
|         ( | ||||
|             "authentik_providers_oauth2", | ||||
|             "0020_remove_accesstoken_authentik_p_token_4bc870_idx_and_more", | ||||
|         ), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.AddField( | ||||
|             model_name="oauth2provider", | ||||
|             name="encryption_key", | ||||
|             field=models.ForeignKey( | ||||
|                 help_text="Key used to encrypt the tokens. When set, tokens will be encrypted and returned as JWEs.", | ||||
|                 null=True, | ||||
|                 on_delete=django.db.models.deletion.SET_NULL, | ||||
|                 related_name="oauth2provider_encryption_key_set", | ||||
|                 to="authentik_crypto.certificatekeypair", | ||||
|                 verbose_name="Encryption Key", | ||||
|             ), | ||||
|         ), | ||||
|         migrations.AlterField( | ||||
|             model_name="oauth2provider", | ||||
|             name="signing_key", | ||||
|             field=models.ForeignKey( | ||||
|                 help_text="Key used to sign the tokens.", | ||||
|                 null=True, | ||||
|                 on_delete=django.db.models.deletion.SET_NULL, | ||||
|                 related_name="oauth2provider_signing_key_set", | ||||
|                 to="authentik_crypto.certificatekeypair", | ||||
|                 verbose_name="Signing Key", | ||||
|             ), | ||||
|         ), | ||||
|     ] | ||||
| @ -0,0 +1,113 @@ | ||||
| # Generated by Django 5.0.9 on 2024-10-23 13:38 | ||||
|  | ||||
| from hashlib import sha256 | ||||
| import django.db.models.deletion | ||||
| from django.db import migrations, models | ||||
| from django.apps.registry import Apps | ||||
| from django.db.backends.base.schema import BaseDatabaseSchemaEditor | ||||
| from authentik.lib.migrations import progress_bar | ||||
|  | ||||
|  | ||||
| def migrate_session(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | ||||
|     AuthenticatedSession = apps.get_model("authentik_core", "authenticatedsession") | ||||
|     AuthorizationCode = apps.get_model("authentik_providers_oauth2", "authorizationcode") | ||||
|     AccessToken = apps.get_model("authentik_providers_oauth2", "accesstoken") | ||||
|     RefreshToken = apps.get_model("authentik_providers_oauth2", "refreshtoken") | ||||
|     db_alias = schema_editor.connection.alias | ||||
|  | ||||
|     print(f"\nFetching session keys, this might take a couple of minutes...") | ||||
|     session_ids = {} | ||||
|     for session in progress_bar(AuthenticatedSession.objects.using(db_alias).all()): | ||||
|         session_ids[sha256(session.session_key.encode("ascii")).hexdigest()] = session.session_key | ||||
|     for model in [AuthorizationCode, AccessToken, RefreshToken]: | ||||
|         print( | ||||
|             f"\nAdding session to {model._meta.verbose_name}, this might take a couple of minutes..." | ||||
|         ) | ||||
|         for code in progress_bar(model.objects.using(db_alias).all()): | ||||
|             if code.session_id_old not in session_ids: | ||||
|                 continue | ||||
|             code.session = ( | ||||
|                 AuthenticatedSession.objects.using(db_alias) | ||||
|                 .filter(session_key=session_ids[code.session_id_old]) | ||||
|                 .first() | ||||
|             ) | ||||
|             code.save() | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("authentik_core", "0039_source_group_matching_mode_alter_group_name_and_more"), | ||||
|         ("authentik_providers_oauth2", "0021_oauth2provider_encryption_key_and_more"), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.RenameField( | ||||
|             model_name="accesstoken", | ||||
|             old_name="session_id", | ||||
|             new_name="session_id_old", | ||||
|         ), | ||||
|         migrations.RenameField( | ||||
|             model_name="authorizationcode", | ||||
|             old_name="session_id", | ||||
|             new_name="session_id_old", | ||||
|         ), | ||||
|         migrations.RenameField( | ||||
|             model_name="refreshtoken", | ||||
|             old_name="session_id", | ||||
|             new_name="session_id_old", | ||||
|         ), | ||||
|         migrations.AddField( | ||||
|             model_name="accesstoken", | ||||
|             name="session", | ||||
|             field=models.ForeignKey( | ||||
|                 default=None, | ||||
|                 null=True, | ||||
|                 on_delete=django.db.models.deletion.SET_DEFAULT, | ||||
|                 to="authentik_core.authenticatedsession", | ||||
|             ), | ||||
|         ), | ||||
|         migrations.AddField( | ||||
|             model_name="authorizationcode", | ||||
|             name="session", | ||||
|             field=models.ForeignKey( | ||||
|                 default=None, | ||||
|                 null=True, | ||||
|                 on_delete=django.db.models.deletion.SET_DEFAULT, | ||||
|                 to="authentik_core.authenticatedsession", | ||||
|             ), | ||||
|         ), | ||||
|         migrations.AddField( | ||||
|             model_name="devicetoken", | ||||
|             name="session", | ||||
|             field=models.ForeignKey( | ||||
|                 default=None, | ||||
|                 null=True, | ||||
|                 on_delete=django.db.models.deletion.SET_DEFAULT, | ||||
|                 to="authentik_core.authenticatedsession", | ||||
|             ), | ||||
|         ), | ||||
|         migrations.AddField( | ||||
|             model_name="refreshtoken", | ||||
|             name="session", | ||||
|             field=models.ForeignKey( | ||||
|                 default=None, | ||||
|                 null=True, | ||||
|                 on_delete=django.db.models.deletion.SET_DEFAULT, | ||||
|                 to="authentik_core.authenticatedsession", | ||||
|             ), | ||||
|         ), | ||||
|         migrations.RunPython(migrate_session), | ||||
|         migrations.RemoveField( | ||||
|             model_name="accesstoken", | ||||
|             name="session_id_old", | ||||
|         ), | ||||
|         migrations.RemoveField( | ||||
|             model_name="authorizationcode", | ||||
|             name="session_id_old", | ||||
|         ), | ||||
|         migrations.RemoveField( | ||||
|             model_name="refreshtoken", | ||||
|             name="session_id_old", | ||||
|         ), | ||||
|     ] | ||||
| @ -0,0 +1,31 @@ | ||||
| # Generated by Django 5.0.9 on 2024-10-31 14:28 | ||||
|  | ||||
| import django.contrib.postgres.indexes | ||||
| from django.conf import settings | ||||
| from django.db import migrations | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("authentik_core", "0039_source_group_matching_mode_alter_group_name_and_more"), | ||||
|         ("authentik_providers_oauth2", "0022_remove_accesstoken_session_id_and_more"), | ||||
|         migrations.swappable_dependency(settings.AUTH_USER_MODEL), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.RunSQL("DROP INDEX IF EXISTS authentik_p_token_f99422_idx;"), | ||||
|         migrations.RunSQL("DROP INDEX IF EXISTS authentik_p_token_a1d921_idx;"), | ||||
|         migrations.AddIndex( | ||||
|             model_name="accesstoken", | ||||
|             index=django.contrib.postgres.indexes.HashIndex( | ||||
|                 fields=["token"], name="authentik_p_token_e00883_hash" | ||||
|             ), | ||||
|         ), | ||||
|         migrations.AddIndex( | ||||
|             model_name="refreshtoken", | ||||
|             index=django.contrib.postgres.indexes.HashIndex( | ||||
|                 fields=["token"], name="authentik_p_token_32e2b7_hash" | ||||
|             ), | ||||
|         ), | ||||
|     ] | ||||
| @ -0,0 +1,48 @@ | ||||
| # Generated by Django 5.0.9 on 2024-11-04 12:56 | ||||
| from django.apps.registry import Apps | ||||
|  | ||||
| from django.db.backends.base.schema import BaseDatabaseSchemaEditor | ||||
|  | ||||
| from django.db import migrations, models | ||||
|  | ||||
|  | ||||
| def migrate_redirect_uris(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | ||||
|     from authentik.providers.oauth2.models import RedirectURI, RedirectURIMatchingMode | ||||
|  | ||||
|     OAuth2Provider = apps.get_model("authentik_providers_oauth2", "oauth2provider") | ||||
|  | ||||
|     db_alias = schema_editor.connection.alias | ||||
|     for provider in OAuth2Provider.objects.using(db_alias).all(): | ||||
|         uris = [] | ||||
|         for old in provider.old_redirect_uris.split("\n"): | ||||
|             mode = RedirectURIMatchingMode.STRICT | ||||
|             if old == "*" or old == ".*": | ||||
|                 mode = RedirectURIMatchingMode.REGEX | ||||
|             uris.append(RedirectURI(mode, url=old)) | ||||
|         provider.redirect_uris = uris | ||||
|         provider.save() | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("authentik_providers_oauth2", "0023_alter_accesstoken_refreshtoken_use_hash_index"), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.RenameField( | ||||
|             model_name="oauth2provider", | ||||
|             old_name="redirect_uris", | ||||
|             new_name="old_redirect_uris", | ||||
|         ), | ||||
|         migrations.AddField( | ||||
|             model_name="oauth2provider", | ||||
|             name="_redirect_uris", | ||||
|             field=models.JSONField(default=dict, verbose_name="Redirect URIs"), | ||||
|         ), | ||||
|         migrations.RunPython(migrate_redirect_uris, lambda *args: ...), | ||||
|         migrations.RemoveField( | ||||
|             model_name="oauth2provider", | ||||
|             name="old_redirect_uris", | ||||
|         ), | ||||
|     ] | ||||
| @ -3,7 +3,7 @@ | ||||
| import base64 | ||||
| import binascii | ||||
| import json | ||||
| from dataclasses import asdict | ||||
| from dataclasses import asdict, dataclass | ||||
| from functools import cached_property | ||||
| from hashlib import sha256 | ||||
| from typing import Any | ||||
| @ -12,6 +12,7 @@ from urllib.parse import urlparse, urlunparse | ||||
| from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey | ||||
| from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey | ||||
| from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes | ||||
| from dacite import Config | ||||
| from dacite.core import from_dict | ||||
| from django.db import models | ||||
| from django.http import HttpRequest | ||||
| @ -23,7 +24,13 @@ from rest_framework.serializers import Serializer | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.brands.models import WebfingerProvider | ||||
| from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User | ||||
| from authentik.core.models import ( | ||||
|     AuthenticatedSession, | ||||
|     ExpiringModel, | ||||
|     PropertyMapping, | ||||
|     Provider, | ||||
|     User, | ||||
| ) | ||||
| from authentik.crypto.models import CertificateKeyPair | ||||
| from authentik.lib.generators import generate_code_fixed_length, generate_id, generate_key | ||||
| from authentik.lib.models import SerializerModel | ||||
| @ -67,11 +74,25 @@ class IssuerMode(models.TextChoices): | ||||
|     """Configure how the `iss` field is created.""" | ||||
|  | ||||
|     GLOBAL = "global", _("Same identifier is used for all providers") | ||||
|     PER_PROVIDER = "per_provider", _( | ||||
|         "Each provider has a different issuer, based on the application slug." | ||||
|     PER_PROVIDER = ( | ||||
|         "per_provider", | ||||
|         _("Each provider has a different issuer, based on the application slug."), | ||||
|     ) | ||||
|  | ||||
|  | ||||
| class RedirectURIMatchingMode(models.TextChoices): | ||||
|     STRICT = "strict", _("Strict URL comparison") | ||||
|     REGEX = "regex", _("Regular Expression URL matching") | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class RedirectURI: | ||||
|     """A single redirect URI entry""" | ||||
|  | ||||
|     matching_mode: RedirectURIMatchingMode | ||||
|     url: str | ||||
|  | ||||
|  | ||||
| class ResponseTypes(models.TextChoices): | ||||
|     """Response Type required by the client.""" | ||||
|  | ||||
| @ -146,11 +167,9 @@ class OAuth2Provider(WebfingerProvider, Provider): | ||||
|         verbose_name=_("Client Secret"), | ||||
|         default=generate_client_secret, | ||||
|     ) | ||||
|     redirect_uris = models.TextField( | ||||
|         default="", | ||||
|         blank=True, | ||||
|     _redirect_uris = models.JSONField( | ||||
|         default=dict, | ||||
|         verbose_name=_("Redirect URIs"), | ||||
|         help_text=_("Enter each URI on a new line."), | ||||
|     ) | ||||
|  | ||||
|     include_claims_in_id_token = models.BooleanField( | ||||
| @ -251,12 +270,33 @@ class OAuth2Provider(WebfingerProvider, Provider): | ||||
|         except Provider.application.RelatedObjectDoesNotExist: | ||||
|             return None | ||||
|  | ||||
|     @property | ||||
|     def redirect_uris(self) -> list[RedirectURI]: | ||||
|         uris = [] | ||||
|         for entry in self._redirect_uris: | ||||
|             uris.append( | ||||
|                 from_dict( | ||||
|                     RedirectURI, | ||||
|                     entry, | ||||
|                     config=Config(type_hooks={RedirectURIMatchingMode: RedirectURIMatchingMode}), | ||||
|                 ) | ||||
|             ) | ||||
|         return uris | ||||
|  | ||||
|     @redirect_uris.setter | ||||
|     def redirect_uris(self, value: list[RedirectURI]): | ||||
|         cleansed = [] | ||||
|         for entry in value: | ||||
|             cleansed.append(asdict(entry)) | ||||
|         self._redirect_uris = cleansed | ||||
|  | ||||
|     @property | ||||
|     def launch_url(self) -> str | None: | ||||
|         """Guess launch_url based on first redirect_uri""" | ||||
|         if self.redirect_uris == "": | ||||
|         redirects = self.redirect_uris | ||||
|         if len(redirects) < 1: | ||||
|             return None | ||||
|         main_url = self.redirect_uris.split("\n", maxsplit=1)[0] | ||||
|         main_url = redirects[0].url | ||||
|         try: | ||||
|             launch_url = urlparse(main_url)._replace(path="") | ||||
|             return urlunparse(launch_url) | ||||
| @ -320,7 +360,9 @@ class BaseGrantModel(models.Model): | ||||
|     revoked = models.BooleanField(default=False) | ||||
|     _scope = models.TextField(default="", verbose_name=_("Scopes")) | ||||
|     auth_time = models.DateTimeField(verbose_name="Authentication time") | ||||
|     session_id = models.CharField(default="", blank=True) | ||||
|     session = models.ForeignKey( | ||||
|         AuthenticatedSession, null=True, on_delete=models.SET_DEFAULT, default=None | ||||
|     ) | ||||
|  | ||||
|     class Meta: | ||||
|         abstract = True | ||||
| @ -452,6 +494,9 @@ class DeviceToken(ExpiringModel): | ||||
|     device_code = models.TextField(default=generate_key) | ||||
|     user_code = models.TextField(default=generate_code_fixed_length) | ||||
|     _scope = models.TextField(default="", verbose_name=_("Scopes")) | ||||
|     session = models.ForeignKey( | ||||
|         AuthenticatedSession, null=True, on_delete=models.SET_DEFAULT, default=None | ||||
|     ) | ||||
|  | ||||
|     @property | ||||
|     def scope(self) -> list[str]: | ||||
|  | ||||
| @ -1,5 +1,3 @@ | ||||
| from hashlib import sha256 | ||||
|  | ||||
| from django.contrib.auth.signals import user_logged_out | ||||
| from django.dispatch import receiver | ||||
| from django.http import HttpRequest | ||||
| @ -13,5 +11,4 @@ def user_logged_out_oauth_access_token(sender, request: HttpRequest, user: User, | ||||
|     """Revoke access tokens upon user logout""" | ||||
|     if not request.session or not request.session.session_key: | ||||
|         return | ||||
|     hashed_session_key = sha256(request.session.session_key.encode("ascii")).hexdigest() | ||||
|     AccessToken.objects.filter(user=user, session_id=hashed_session_key).delete() | ||||
|     AccessToken.objects.filter(user=user, session__session_key=request.session.session_key).delete() | ||||
|  | ||||
| @ -10,7 +10,13 @@ from rest_framework.test import APITestCase | ||||
| from authentik.blueprints.tests import apply_blueprint | ||||
| from authentik.core.models import Application | ||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_flow | ||||
| from authentik.providers.oauth2.models import OAuth2Provider, ScopeMapping | ||||
| from authentik.lib.generators import generate_id | ||||
| from authentik.providers.oauth2.models import ( | ||||
|     OAuth2Provider, | ||||
|     RedirectURI, | ||||
|     RedirectURIMatchingMode, | ||||
|     ScopeMapping, | ||||
| ) | ||||
|  | ||||
|  | ||||
| class TestAPI(APITestCase): | ||||
| @ -21,7 +27,7 @@ class TestAPI(APITestCase): | ||||
|         self.provider: OAuth2Provider = OAuth2Provider.objects.create( | ||||
|             name="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://testserver", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")], | ||||
|         ) | ||||
|         self.provider.property_mappings.set(ScopeMapping.objects.all()) | ||||
|         self.app = Application.objects.create(name="test", slug="test", provider=self.provider) | ||||
| @ -50,9 +56,29 @@ class TestAPI(APITestCase): | ||||
|     @skipUnless(version_info >= (3, 11, 4), "This behaviour is only Python 3.11.4 and up") | ||||
|     def test_launch_url(self): | ||||
|         """Test launch_url""" | ||||
|         self.provider.redirect_uris = ( | ||||
|             "https://[\\d\\w]+.pr.test.goauthentik.io/source/oauth/callback/authentik/\n" | ||||
|         ) | ||||
|         self.provider.redirect_uris = [ | ||||
|             RedirectURI( | ||||
|                 RedirectURIMatchingMode.REGEX, | ||||
|                 "https://[\\d\\w]+.pr.test.goauthentik.io/source/oauth/callback/authentik/", | ||||
|             ), | ||||
|         ] | ||||
|         self.provider.save() | ||||
|         self.provider.refresh_from_db() | ||||
|         self.assertIsNone(self.provider.launch_url) | ||||
|  | ||||
|     def test_validate_redirect_uris(self): | ||||
|         """Test redirect_uris API""" | ||||
|         response = self.client.post( | ||||
|             reverse("authentik_api:oauth2provider-list"), | ||||
|             data={ | ||||
|                 "name": generate_id(), | ||||
|                 "authorization_flow": create_test_flow().pk, | ||||
|                 "invalidation_flow": create_test_flow().pk, | ||||
|                 "redirect_uris": [ | ||||
|                     {"matching_mode": "strict", "url": "http://goauthentik.io"}, | ||||
|                     {"matching_mode": "regex", "url": "**"}, | ||||
|                 ], | ||||
|             }, | ||||
|         ) | ||||
|         self.assertJSONEqual(response.content, {"redirect_uris": ["Invalid Regex Pattern: **"]}) | ||||
|         self.assertEqual(response.status_code, 400) | ||||
|  | ||||
| @ -19,6 +19,8 @@ from authentik.providers.oauth2.models import ( | ||||
|     AuthorizationCode, | ||||
|     GrantTypes, | ||||
|     OAuth2Provider, | ||||
|     RedirectURI, | ||||
|     RedirectURIMatchingMode, | ||||
|     ScopeMapping, | ||||
| ) | ||||
| from authentik.providers.oauth2.tests.utils import OAuthTestCase | ||||
| @ -39,7 +41,7 @@ class TestAuthorize(OAuthTestCase): | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://local.invalid/Foo", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")], | ||||
|         ) | ||||
|         with self.assertRaises(AuthorizeError): | ||||
|             request = self.factory.get( | ||||
| @ -64,7 +66,7 @@ class TestAuthorize(OAuthTestCase): | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://local.invalid/Foo", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")], | ||||
|         ) | ||||
|         with self.assertRaises(AuthorizeError): | ||||
|             request = self.factory.get( | ||||
| @ -84,7 +86,7 @@ class TestAuthorize(OAuthTestCase): | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://local.invalid", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")], | ||||
|         ) | ||||
|         with self.assertRaises(RedirectUriError): | ||||
|             request = self.factory.get("/", data={"response_type": "code", "client_id": "test"}) | ||||
| @ -106,7 +108,7 @@ class TestAuthorize(OAuthTestCase): | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="data:local.invalid", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "data:local.invalid")], | ||||
|         ) | ||||
|         with self.assertRaises(RedirectUriError): | ||||
|             request = self.factory.get( | ||||
| @ -125,7 +127,7 @@ class TestAuthorize(OAuthTestCase): | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="", | ||||
|             redirect_uris=[], | ||||
|         ) | ||||
|         with self.assertRaises(RedirectUriError): | ||||
|             request = self.factory.get("/", data={"response_type": "code", "client_id": "test"}) | ||||
| @ -140,7 +142,7 @@ class TestAuthorize(OAuthTestCase): | ||||
|         ) | ||||
|         OAuthAuthorizationParams.from_request(request) | ||||
|         provider.refresh_from_db() | ||||
|         self.assertEqual(provider.redirect_uris, "+") | ||||
|         self.assertEqual(provider.redirect_uris, [RedirectURI(RedirectURIMatchingMode.STRICT, "+")]) | ||||
|  | ||||
|     def test_invalid_redirect_uri_regex(self): | ||||
|         """test missing/invalid redirect URI""" | ||||
| @ -148,7 +150,7 @@ class TestAuthorize(OAuthTestCase): | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://local.invalid?", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid?")], | ||||
|         ) | ||||
|         with self.assertRaises(RedirectUriError): | ||||
|             request = self.factory.get("/", data={"response_type": "code", "client_id": "test"}) | ||||
| @ -170,7 +172,7 @@ class TestAuthorize(OAuthTestCase): | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="+", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "+")], | ||||
|         ) | ||||
|         with self.assertRaises(RedirectUriError): | ||||
|             request = self.factory.get("/", data={"response_type": "code", "client_id": "test"}) | ||||
| @ -213,7 +215,7 @@ class TestAuthorize(OAuthTestCase): | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://local.invalid/Foo", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")], | ||||
|         ) | ||||
|         provider.property_mappings.set( | ||||
|             ScopeMapping.objects.filter( | ||||
| @ -301,7 +303,7 @@ class TestAuthorize(OAuthTestCase): | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             authorization_flow=flow, | ||||
|             redirect_uris="foo://localhost", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")], | ||||
|             access_code_validity="seconds=100", | ||||
|         ) | ||||
|         Application.objects.create(name="app", slug="app", provider=provider) | ||||
| @ -343,7 +345,7 @@ class TestAuthorize(OAuthTestCase): | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             authorization_flow=flow, | ||||
|             redirect_uris="http://localhost", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")], | ||||
|             signing_key=self.keypair, | ||||
|         ) | ||||
|         provider.property_mappings.set( | ||||
| @ -419,7 +421,7 @@ class TestAuthorize(OAuthTestCase): | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             authorization_flow=flow, | ||||
|             redirect_uris="http://localhost", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")], | ||||
|             signing_key=self.keypair, | ||||
|         ) | ||||
|         Application.objects.create(name="app", slug="app", provider=provider) | ||||
| @ -474,7 +476,7 @@ class TestAuthorize(OAuthTestCase): | ||||
|             name=generate_id(), | ||||
|             client_id=generate_id(), | ||||
|             authorization_flow=flow, | ||||
|             redirect_uris="http://localhost", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")], | ||||
|             signing_key=self.keypair, | ||||
|         ) | ||||
|         provider.property_mappings.set( | ||||
| @ -532,7 +534,7 @@ class TestAuthorize(OAuthTestCase): | ||||
|             name=generate_id(), | ||||
|             client_id=generate_id(), | ||||
|             authorization_flow=flow, | ||||
|             redirect_uris="http://localhost", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")], | ||||
|             signing_key=self.keypair, | ||||
|         ) | ||||
|         app = Application.objects.create(name=generate_id(), slug=generate_id(), provider=provider) | ||||
|  | ||||
| @ -11,7 +11,14 @@ from authentik.core.models import Application | ||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow | ||||
| from authentik.lib.generators import generate_id | ||||
| from authentik.providers.oauth2.constants import ACR_AUTHENTIK_DEFAULT | ||||
| from authentik.providers.oauth2.models import AccessToken, IDToken, OAuth2Provider, RefreshToken | ||||
| from authentik.providers.oauth2.models import ( | ||||
|     AccessToken, | ||||
|     IDToken, | ||||
|     OAuth2Provider, | ||||
|     RedirectURI, | ||||
|     RedirectURIMatchingMode, | ||||
|     RefreshToken, | ||||
| ) | ||||
| from authentik.providers.oauth2.tests.utils import OAuthTestCase | ||||
|  | ||||
|  | ||||
| @ -23,13 +30,12 @@ class TesOAuth2Introspection(OAuthTestCase): | ||||
|         self.provider: OAuth2Provider = OAuth2Provider.objects.create( | ||||
|             name=generate_id(), | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "")], | ||||
|             signing_key=create_test_cert(), | ||||
|         ) | ||||
|         self.app = Application.objects.create( | ||||
|             name=generate_id(), slug=generate_id(), provider=self.provider | ||||
|         ) | ||||
|         self.app.save() | ||||
|         self.user = create_test_admin_user() | ||||
|         self.auth = b64encode( | ||||
|             f"{self.provider.client_id}:{self.provider.client_secret}".encode() | ||||
| @ -114,6 +120,41 @@ class TesOAuth2Introspection(OAuthTestCase): | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def test_introspect_invalid_provider(self): | ||||
|         """Test introspection (mismatched provider and token)""" | ||||
|         provider: OAuth2Provider = OAuth2Provider.objects.create( | ||||
|             name=generate_id(), | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "")], | ||||
|             signing_key=create_test_cert(), | ||||
|         ) | ||||
|         auth = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode() | ||||
|  | ||||
|         token: AccessToken = AccessToken.objects.create( | ||||
|             provider=self.provider, | ||||
|             user=self.user, | ||||
|             token=generate_id(), | ||||
|             auth_time=timezone.now(), | ||||
|             _scope="openid user profile", | ||||
|             _id_token=json.dumps( | ||||
|                 asdict( | ||||
|                     IDToken("foo", "bar"), | ||||
|                 ) | ||||
|             ), | ||||
|         ) | ||||
|         res = self.client.post( | ||||
|             reverse("authentik_providers_oauth2:token-introspection"), | ||||
|             HTTP_AUTHORIZATION=f"Basic {auth}", | ||||
|             data={"token": token.token}, | ||||
|         ) | ||||
|         self.assertEqual(res.status_code, 200) | ||||
|         self.assertJSONEqual( | ||||
|             res.content.decode(), | ||||
|             { | ||||
|                 "active": False, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def test_introspect_invalid_auth(self): | ||||
|         """Test introspect (invalid auth)""" | ||||
|         res = self.client.post( | ||||
|  | ||||
| @ -13,7 +13,7 @@ from authentik.core.tests.utils import create_test_cert, create_test_flow | ||||
| from authentik.crypto.builder import PrivateKeyAlg | ||||
| from authentik.crypto.models import CertificateKeyPair | ||||
| from authentik.lib.generators import generate_id | ||||
| from authentik.providers.oauth2.models import OAuth2Provider | ||||
| from authentik.providers.oauth2.models import OAuth2Provider, RedirectURI, RedirectURIMatchingMode | ||||
| from authentik.providers.oauth2.tests.utils import OAuthTestCase | ||||
|  | ||||
| TEST_CORDS_CERT = """ | ||||
| @ -49,7 +49,7 @@ class TestJWKS(OAuthTestCase): | ||||
|             name="test", | ||||
|             client_id="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://local.invalid", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")], | ||||
|             signing_key=create_test_cert(), | ||||
|         ) | ||||
|         app = Application.objects.create(name="test", slug="test", provider=provider) | ||||
| @ -68,7 +68,7 @@ class TestJWKS(OAuthTestCase): | ||||
|             name="test", | ||||
|             client_id="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://local.invalid", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")], | ||||
|         ) | ||||
|         app = Application.objects.create(name="test", slug="test", provider=provider) | ||||
|         response = self.client.get( | ||||
| @ -82,7 +82,7 @@ class TestJWKS(OAuthTestCase): | ||||
|             name="test", | ||||
|             client_id="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://local.invalid", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")], | ||||
|             signing_key=create_test_cert(PrivateKeyAlg.ECDSA), | ||||
|         ) | ||||
|         app = Application.objects.create(name="test", slug="test", provider=provider) | ||||
| @ -104,7 +104,7 @@ class TestJWKS(OAuthTestCase): | ||||
|             name="test", | ||||
|             client_id="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://local.invalid", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")], | ||||
|             signing_key=cert, | ||||
|         ) | ||||
|         app = Application.objects.create(name="test", slug="test", provider=provider) | ||||
|  | ||||
| @ -10,7 +10,14 @@ from django.utils import timezone | ||||
| from authentik.core.models import Application | ||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow | ||||
| from authentik.lib.generators import generate_id | ||||
| from authentik.providers.oauth2.models import AccessToken, IDToken, OAuth2Provider, RefreshToken | ||||
| from authentik.providers.oauth2.models import ( | ||||
|     AccessToken, | ||||
|     IDToken, | ||||
|     OAuth2Provider, | ||||
|     RedirectURI, | ||||
|     RedirectURIMatchingMode, | ||||
|     RefreshToken, | ||||
| ) | ||||
| from authentik.providers.oauth2.tests.utils import OAuthTestCase | ||||
|  | ||||
|  | ||||
| @ -22,7 +29,7 @@ class TesOAuth2Revoke(OAuthTestCase): | ||||
|         self.provider: OAuth2Provider = OAuth2Provider.objects.create( | ||||
|             name=generate_id(), | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "")], | ||||
|             signing_key=create_test_cert(), | ||||
|         ) | ||||
|         self.app = Application.objects.create( | ||||
|  | ||||
| @ -22,6 +22,8 @@ from authentik.providers.oauth2.models import ( | ||||
|     AccessToken, | ||||
|     AuthorizationCode, | ||||
|     OAuth2Provider, | ||||
|     RedirectURI, | ||||
|     RedirectURIMatchingMode, | ||||
|     RefreshToken, | ||||
|     ScopeMapping, | ||||
| ) | ||||
| @ -42,7 +44,7 @@ class TestToken(OAuthTestCase): | ||||
|         provider = OAuth2Provider.objects.create( | ||||
|             name=generate_id(), | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://TestServer", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://TestServer")], | ||||
|             signing_key=self.keypair, | ||||
|         ) | ||||
|         header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode() | ||||
| @ -69,7 +71,7 @@ class TestToken(OAuthTestCase): | ||||
|         provider = OAuth2Provider.objects.create( | ||||
|             name=generate_id(), | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://testserver", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")], | ||||
|             signing_key=self.keypair, | ||||
|         ) | ||||
|         header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode() | ||||
| @ -90,7 +92,7 @@ class TestToken(OAuthTestCase): | ||||
|         provider = OAuth2Provider.objects.create( | ||||
|             name=generate_id(), | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://local.invalid", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")], | ||||
|             signing_key=self.keypair, | ||||
|         ) | ||||
|         header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode() | ||||
| @ -118,7 +120,7 @@ class TestToken(OAuthTestCase): | ||||
|         provider = OAuth2Provider.objects.create( | ||||
|             name=generate_id(), | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://local.invalid", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")], | ||||
|             signing_key=self.keypair, | ||||
|         ) | ||||
|         # Needs to be assigned to an application for iss to be set | ||||
| @ -158,7 +160,7 @@ class TestToken(OAuthTestCase): | ||||
|         provider = OAuth2Provider.objects.create( | ||||
|             name=generate_id(), | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://local.invalid", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")], | ||||
|             signing_key=self.keypair, | ||||
|         ) | ||||
|         provider.property_mappings.set( | ||||
| @ -220,7 +222,7 @@ class TestToken(OAuthTestCase): | ||||
|         provider = OAuth2Provider.objects.create( | ||||
|             name=generate_id(), | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://local.invalid", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")], | ||||
|             signing_key=self.keypair, | ||||
|         ) | ||||
|         provider.property_mappings.set( | ||||
| @ -278,7 +280,7 @@ class TestToken(OAuthTestCase): | ||||
|         provider = OAuth2Provider.objects.create( | ||||
|             name=generate_id(), | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://testserver", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")], | ||||
|             signing_key=self.keypair, | ||||
|         ) | ||||
|         provider.property_mappings.set( | ||||
|  | ||||
| @ -19,7 +19,12 @@ from authentik.providers.oauth2.constants import ( | ||||
|     SCOPE_OPENID_PROFILE, | ||||
|     TOKEN_TYPE, | ||||
| ) | ||||
| from authentik.providers.oauth2.models import OAuth2Provider, ScopeMapping | ||||
| from authentik.providers.oauth2.models import ( | ||||
|     OAuth2Provider, | ||||
|     RedirectURI, | ||||
|     RedirectURIMatchingMode, | ||||
|     ScopeMapping, | ||||
| ) | ||||
| from authentik.providers.oauth2.tests.utils import OAuthTestCase | ||||
| from authentik.providers.oauth2.views.jwks import JWKSView | ||||
| from authentik.sources.oauth.models import OAuthSource | ||||
| @ -54,7 +59,7 @@ class TestTokenClientCredentialsJWTSource(OAuthTestCase): | ||||
|         self.provider: OAuth2Provider = OAuth2Provider.objects.create( | ||||
|             name="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://testserver", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")], | ||||
|             signing_key=self.cert, | ||||
|         ) | ||||
|         self.provider.jwks_sources.add(self.source) | ||||
|  | ||||
| @ -19,7 +19,13 @@ from authentik.providers.oauth2.constants import ( | ||||
|     TOKEN_TYPE, | ||||
| ) | ||||
| from authentik.providers.oauth2.errors import TokenError | ||||
| from authentik.providers.oauth2.models import OAuth2Provider, ScopeMapping | ||||
| from authentik.providers.oauth2.models import ( | ||||
|     AccessToken, | ||||
|     OAuth2Provider, | ||||
|     RedirectURI, | ||||
|     RedirectURIMatchingMode, | ||||
|     ScopeMapping, | ||||
| ) | ||||
| from authentik.providers.oauth2.tests.utils import OAuthTestCase | ||||
|  | ||||
|  | ||||
| @ -33,7 +39,7 @@ class TestTokenClientCredentialsStandard(OAuthTestCase): | ||||
|         self.provider = OAuth2Provider.objects.create( | ||||
|             name="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://testserver", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")], | ||||
|             signing_key=create_test_cert(), | ||||
|         ) | ||||
|         self.provider.property_mappings.set(ScopeMapping.objects.all()) | ||||
| @ -107,6 +113,48 @@ class TestTokenClientCredentialsStandard(OAuthTestCase): | ||||
|             {"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]}, | ||||
|         ) | ||||
|  | ||||
|     def test_incorrect_scopes(self): | ||||
|         """test scope that isn't configured""" | ||||
|         response = self.client.post( | ||||
|             reverse("authentik_providers_oauth2:token"), | ||||
|             { | ||||
|                 "grant_type": GRANT_TYPE_CLIENT_CREDENTIALS, | ||||
|                 "scope": f"{SCOPE_OPENID} {SCOPE_OPENID_EMAIL} {SCOPE_OPENID_PROFILE} extra_scope", | ||||
|                 "client_id": self.provider.client_id, | ||||
|                 "client_secret": self.provider.client_secret, | ||||
|             }, | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|         body = loads(response.content.decode()) | ||||
|         self.assertEqual(body["token_type"], TOKEN_TYPE) | ||||
|         token = AccessToken.objects.filter( | ||||
|             provider=self.provider, token=body["access_token"] | ||||
|         ).first() | ||||
|         self.assertSetEqual( | ||||
|             set(token.scope), {SCOPE_OPENID, SCOPE_OPENID_EMAIL, SCOPE_OPENID_PROFILE} | ||||
|         ) | ||||
|         _, alg = self.provider.jwt_key | ||||
|         jwt = decode( | ||||
|             body["access_token"], | ||||
|             key=self.provider.signing_key.public_key, | ||||
|             algorithms=[alg], | ||||
|             audience=self.provider.client_id, | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             jwt["given_name"], "Autogenerated user from application test (client credentials)" | ||||
|         ) | ||||
|         self.assertEqual(jwt["preferred_username"], "ak-test-client_credentials") | ||||
|         jwt = decode( | ||||
|             body["id_token"], | ||||
|             key=self.provider.signing_key.public_key, | ||||
|             algorithms=[alg], | ||||
|             audience=self.provider.client_id, | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             jwt["given_name"], "Autogenerated user from application test (client credentials)" | ||||
|         ) | ||||
|         self.assertEqual(jwt["preferred_username"], "ak-test-client_credentials") | ||||
|  | ||||
|     def test_successful(self): | ||||
|         """test successful""" | ||||
|         response = self.client.post( | ||||
|  | ||||
| @ -20,7 +20,12 @@ from authentik.providers.oauth2.constants import ( | ||||
|     TOKEN_TYPE, | ||||
| ) | ||||
| from authentik.providers.oauth2.errors import TokenError | ||||
| from authentik.providers.oauth2.models import OAuth2Provider, ScopeMapping | ||||
| from authentik.providers.oauth2.models import ( | ||||
|     OAuth2Provider, | ||||
|     RedirectURI, | ||||
|     RedirectURIMatchingMode, | ||||
|     ScopeMapping, | ||||
| ) | ||||
| from authentik.providers.oauth2.tests.utils import OAuthTestCase | ||||
|  | ||||
|  | ||||
| @ -34,7 +39,7 @@ class TestTokenClientCredentialsStandardCompat(OAuthTestCase): | ||||
|         self.provider = OAuth2Provider.objects.create( | ||||
|             name="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://testserver", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")], | ||||
|             signing_key=create_test_cert(), | ||||
|         ) | ||||
|         self.provider.property_mappings.set(ScopeMapping.objects.all()) | ||||
|  | ||||
| @ -19,7 +19,12 @@ from authentik.providers.oauth2.constants import ( | ||||
|     TOKEN_TYPE, | ||||
| ) | ||||
| from authentik.providers.oauth2.errors import TokenError | ||||
| from authentik.providers.oauth2.models import OAuth2Provider, ScopeMapping | ||||
| from authentik.providers.oauth2.models import ( | ||||
|     OAuth2Provider, | ||||
|     RedirectURI, | ||||
|     RedirectURIMatchingMode, | ||||
|     ScopeMapping, | ||||
| ) | ||||
| from authentik.providers.oauth2.tests.utils import OAuthTestCase | ||||
|  | ||||
|  | ||||
| @ -33,7 +38,7 @@ class TestTokenClientCredentialsUserNamePassword(OAuthTestCase): | ||||
|         self.provider = OAuth2Provider.objects.create( | ||||
|             name="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://testserver", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")], | ||||
|             signing_key=create_test_cert(), | ||||
|         ) | ||||
|         self.provider.property_mappings.set(ScopeMapping.objects.all()) | ||||
|  | ||||
| @ -9,8 +9,19 @@ from authentik.blueprints.tests import apply_blueprint | ||||
| from authentik.core.models import Application | ||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow | ||||
| from authentik.lib.generators import generate_code_fixed_length, generate_id | ||||
| from authentik.providers.oauth2.constants import GRANT_TYPE_DEVICE_CODE | ||||
| from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider, ScopeMapping | ||||
| from authentik.providers.oauth2.constants import ( | ||||
|     GRANT_TYPE_DEVICE_CODE, | ||||
|     SCOPE_OPENID, | ||||
|     SCOPE_OPENID_EMAIL, | ||||
| ) | ||||
| from authentik.providers.oauth2.models import ( | ||||
|     AccessToken, | ||||
|     DeviceToken, | ||||
|     OAuth2Provider, | ||||
|     RedirectURI, | ||||
|     RedirectURIMatchingMode, | ||||
|     ScopeMapping, | ||||
| ) | ||||
| from authentik.providers.oauth2.tests.utils import OAuthTestCase | ||||
|  | ||||
|  | ||||
| @ -24,7 +35,7 @@ class TestTokenDeviceCode(OAuthTestCase): | ||||
|         self.provider = OAuth2Provider.objects.create( | ||||
|             name="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://testserver", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")], | ||||
|             signing_key=create_test_cert(), | ||||
|         ) | ||||
|         self.provider.property_mappings.set(ScopeMapping.objects.all()) | ||||
| @ -80,3 +91,28 @@ class TestTokenDeviceCode(OAuthTestCase): | ||||
|             }, | ||||
|         ) | ||||
|         self.assertEqual(res.status_code, 200) | ||||
|  | ||||
|     def test_code_mismatched_scope(self): | ||||
|         """Test code with user (mismatched scopes)""" | ||||
|         device_token = DeviceToken.objects.create( | ||||
|             provider=self.provider, | ||||
|             user_code=generate_code_fixed_length(), | ||||
|             device_code=generate_id(), | ||||
|             user=self.user, | ||||
|             scope=[SCOPE_OPENID, SCOPE_OPENID_EMAIL], | ||||
|         ) | ||||
|         res = self.client.post( | ||||
|             reverse("authentik_providers_oauth2:token"), | ||||
|             data={ | ||||
|                 "client_id": self.provider.client_id, | ||||
|                 "grant_type": GRANT_TYPE_DEVICE_CODE, | ||||
|                 "device_code": device_token.device_code, | ||||
|                 "scope": f"{SCOPE_OPENID} {SCOPE_OPENID_EMAIL} invalid", | ||||
|             }, | ||||
|         ) | ||||
|         self.assertEqual(res.status_code, 200) | ||||
|         body = loads(res.content) | ||||
|         token = AccessToken.objects.filter( | ||||
|             provider=self.provider, token=body["access_token"] | ||||
|         ).first() | ||||
|         self.assertSetEqual(set(token.scope), {SCOPE_OPENID, SCOPE_OPENID_EMAIL}) | ||||
|  | ||||
| @ -10,7 +10,12 @@ from authentik.core.models import Application | ||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_flow | ||||
| from authentik.lib.generators import generate_id | ||||
| from authentik.providers.oauth2.constants import GRANT_TYPE_AUTHORIZATION_CODE | ||||
| from authentik.providers.oauth2.models import AuthorizationCode, OAuth2Provider | ||||
| from authentik.providers.oauth2.models import ( | ||||
|     AuthorizationCode, | ||||
|     OAuth2Provider, | ||||
|     RedirectURI, | ||||
|     RedirectURIMatchingMode, | ||||
| ) | ||||
| from authentik.providers.oauth2.tests.utils import OAuthTestCase | ||||
|  | ||||
|  | ||||
| @ -30,7 +35,7 @@ class TestTokenPKCE(OAuthTestCase): | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             authorization_flow=flow, | ||||
|             redirect_uris="foo://localhost", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")], | ||||
|             access_code_validity="seconds=100", | ||||
|         ) | ||||
|         Application.objects.create(name="app", slug="app", provider=provider) | ||||
| @ -93,7 +98,7 @@ class TestTokenPKCE(OAuthTestCase): | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             authorization_flow=flow, | ||||
|             redirect_uris="foo://localhost", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")], | ||||
|             access_code_validity="seconds=100", | ||||
|         ) | ||||
|         Application.objects.create(name="app", slug="app", provider=provider) | ||||
| @ -154,7 +159,7 @@ class TestTokenPKCE(OAuthTestCase): | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             authorization_flow=flow, | ||||
|             redirect_uris="foo://localhost", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")], | ||||
|             access_code_validity="seconds=100", | ||||
|         ) | ||||
|         Application.objects.create(name="app", slug="app", provider=provider) | ||||
| @ -210,7 +215,7 @@ class TestTokenPKCE(OAuthTestCase): | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             authorization_flow=flow, | ||||
|             redirect_uris="foo://localhost", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")], | ||||
|             access_code_validity="seconds=100", | ||||
|         ) | ||||
|         Application.objects.create(name="app", slug="app", provider=provider) | ||||
|  | ||||
| @ -11,7 +11,14 @@ from authentik.core.models import Application | ||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow | ||||
| from authentik.events.models import Event, EventAction | ||||
| from authentik.lib.generators import generate_id | ||||
| from authentik.providers.oauth2.models import AccessToken, IDToken, OAuth2Provider, ScopeMapping | ||||
| from authentik.providers.oauth2.models import ( | ||||
|     AccessToken, | ||||
|     IDToken, | ||||
|     OAuth2Provider, | ||||
|     RedirectURI, | ||||
|     RedirectURIMatchingMode, | ||||
|     ScopeMapping, | ||||
| ) | ||||
| from authentik.providers.oauth2.tests.utils import OAuthTestCase | ||||
|  | ||||
|  | ||||
| @ -25,7 +32,7 @@ class TestUserinfo(OAuthTestCase): | ||||
|         self.provider: OAuth2Provider = OAuth2Provider.objects.create( | ||||
|             name=generate_id(), | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="", | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "")], | ||||
|             signing_key=create_test_cert(), | ||||
|         ) | ||||
|         self.provider.property_mappings.set(ScopeMapping.objects.all()) | ||||
|  | ||||
| @ -2,7 +2,6 @@ | ||||
|  | ||||
| from dataclasses import InitVar, dataclass, field | ||||
| from datetime import timedelta | ||||
| from hashlib import sha256 | ||||
| from json import dumps | ||||
| from re import error as RegexError | ||||
| from re import fullmatch | ||||
| @ -16,7 +15,7 @@ from django.utils import timezone | ||||
| from django.utils.translation import gettext as _ | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.core.models import Application | ||||
| from authentik.core.models import Application, AuthenticatedSession | ||||
| from authentik.events.models import Event, EventAction | ||||
| from authentik.events.signals import get_login_event | ||||
| from authentik.flows.challenge import ( | ||||
| @ -57,6 +56,8 @@ from authentik.providers.oauth2.models import ( | ||||
|     AuthorizationCode, | ||||
|     GrantTypes, | ||||
|     OAuth2Provider, | ||||
|     RedirectURI, | ||||
|     RedirectURIMatchingMode, | ||||
|     ResponseMode, | ||||
|     ResponseTypes, | ||||
|     ScopeMapping, | ||||
| @ -188,40 +189,39 @@ class OAuthAuthorizationParams: | ||||
|  | ||||
|     def check_redirect_uri(self): | ||||
|         """Redirect URI validation.""" | ||||
|         allowed_redirect_urls = self.provider.redirect_uris.split() | ||||
|         allowed_redirect_urls = self.provider.redirect_uris | ||||
|         if not self.redirect_uri: | ||||
|             LOGGER.warning("Missing redirect uri.") | ||||
|             raise RedirectUriError("", allowed_redirect_urls) | ||||
|  | ||||
|         if self.provider.redirect_uris == "": | ||||
|         if len(allowed_redirect_urls) < 1: | ||||
|             LOGGER.info("Setting redirect for blank redirect_uris", redirect=self.redirect_uri) | ||||
|             self.provider.redirect_uris = self.redirect_uri | ||||
|             self.provider.redirect_uris = [ | ||||
|                 RedirectURI(RedirectURIMatchingMode.STRICT, self.redirect_uri) | ||||
|             ] | ||||
|             self.provider.save() | ||||
|             allowed_redirect_urls = self.provider.redirect_uris.split() | ||||
|  | ||||
|         if self.provider.redirect_uris == "*": | ||||
|             LOGGER.info("Converting redirect_uris to regex", redirect=self.redirect_uri) | ||||
|             self.provider.redirect_uris = ".*" | ||||
|             self.provider.save() | ||||
|             allowed_redirect_urls = self.provider.redirect_uris.split() | ||||
|             allowed_redirect_urls = self.provider.redirect_uris | ||||
|  | ||||
|         match_found = False | ||||
|         for allowed in allowed_redirect_urls: | ||||
|             if allowed.matching_mode == RedirectURIMatchingMode.STRICT: | ||||
|                 if self.redirect_uri == allowed.url: | ||||
|                     match_found = True | ||||
|                     break | ||||
|             if allowed.matching_mode == RedirectURIMatchingMode.REGEX: | ||||
|                 try: | ||||
|             if not any(fullmatch(x, self.redirect_uri) for x in allowed_redirect_urls): | ||||
|                 LOGGER.warning( | ||||
|                     "Invalid redirect uri (regex comparison)", | ||||
|                     redirect_uri_given=self.redirect_uri, | ||||
|                     redirect_uri_expected=allowed_redirect_urls, | ||||
|                 ) | ||||
|                 raise RedirectUriError(self.redirect_uri, allowed_redirect_urls) | ||||
|                     if fullmatch(allowed.url, self.redirect_uri): | ||||
|                         match_found = True | ||||
|                         break | ||||
|                 except RegexError as exc: | ||||
|             LOGGER.info("Failed to parse regular expression, checking directly", exc=exc) | ||||
|             if not any(x == self.redirect_uri for x in allowed_redirect_urls): | ||||
|                     LOGGER.warning( | ||||
|                     "Invalid redirect uri (strict comparison)", | ||||
|                     redirect_uri_given=self.redirect_uri, | ||||
|                     redirect_uri_expected=allowed_redirect_urls, | ||||
|                         "Failed to parse regular expression", | ||||
|                         exc=exc, | ||||
|                         url=allowed.url, | ||||
|                         provider=self.provider, | ||||
|                     ) | ||||
|                 raise RedirectUriError(self.redirect_uri, allowed_redirect_urls) from None | ||||
|         if not match_found: | ||||
|             raise RedirectUriError(self.redirect_uri, allowed_redirect_urls) | ||||
|         # Check against forbidden schemes | ||||
|         if urlparse(self.redirect_uri).scheme in FORBIDDEN_URI_SCHEMES: | ||||
|             raise RedirectUriError(self.redirect_uri, allowed_redirect_urls) | ||||
| @ -318,7 +318,9 @@ class OAuthAuthorizationParams: | ||||
|             expires=now + timedelta_from_string(self.provider.access_code_validity), | ||||
|             scope=self.scope, | ||||
|             nonce=self.nonce, | ||||
|             session_id=sha256(request.session.session_key.encode("ascii")).hexdigest(), | ||||
|             session=AuthenticatedSession.objects.filter( | ||||
|                 session_key=request.session.session_key | ||||
|             ).first(), | ||||
|         ) | ||||
|  | ||||
|         if self.code_challenge and self.code_challenge_method: | ||||
| @ -610,7 +612,9 @@ class OAuthFulfillmentStage(StageView): | ||||
|             expires=access_token_expiry, | ||||
|             provider=self.provider, | ||||
|             auth_time=auth_event.created if auth_event else now, | ||||
|             session_id=sha256(self.request.session.session_key.encode("ascii")).hexdigest(), | ||||
|             session=AuthenticatedSession.objects.filter( | ||||
|                 session_key=self.request.session.session_key | ||||
|             ).first(), | ||||
|         ) | ||||
|  | ||||
|         id_token = IDToken.new(self.provider, token, self.request) | ||||
|  | ||||
| @ -46,10 +46,10 @@ class TokenIntrospectionParams: | ||||
|         if not provider: | ||||
|             raise TokenIntrospectionError | ||||
|  | ||||
|         access_token = AccessToken.objects.filter(token=raw_token).first() | ||||
|         access_token = AccessToken.objects.filter(token=raw_token, provider=provider).first() | ||||
|         if access_token: | ||||
|             return TokenIntrospectionParams(access_token, provider) | ||||
|         refresh_token = RefreshToken.objects.filter(token=raw_token).first() | ||||
|         refresh_token = RefreshToken.objects.filter(token=raw_token, provider=provider).first() | ||||
|         if refresh_token: | ||||
|             return TokenIntrospectionParams(refresh_token, provider) | ||||
|         LOGGER.debug("Token does not exist", token=raw_token) | ||||
|  | ||||
| @ -158,5 +158,5 @@ class ProviderInfoView(View): | ||||
|             OAuth2Provider, pk=application.provider_id | ||||
|         ) | ||||
|         response = super().dispatch(request, *args, **kwargs) | ||||
|         cors_allow(request, response, *self.provider.redirect_uris.split("\n")) | ||||
|         cors_allow(request, response, *[x.url for x in self.provider.redirect_uris]) | ||||
|         return response | ||||
|  | ||||
| @ -58,7 +58,9 @@ from authentik.providers.oauth2.models import ( | ||||
|     ClientTypes, | ||||
|     DeviceToken, | ||||
|     OAuth2Provider, | ||||
|     RedirectURIMatchingMode, | ||||
|     RefreshToken, | ||||
|     ScopeMapping, | ||||
| ) | ||||
| from authentik.providers.oauth2.utils import TokenResponse, cors_allow, extract_client_auth | ||||
| from authentik.providers.oauth2.views.authorize import FORBIDDEN_URI_SCHEMES | ||||
| @ -77,7 +79,7 @@ class TokenParams: | ||||
|     redirect_uri: str | ||||
|     grant_type: str | ||||
|     state: str | ||||
|     scope: list[str] | ||||
|     scope: set[str] | ||||
|  | ||||
|     provider: OAuth2Provider | ||||
|  | ||||
| @ -112,11 +114,26 @@ class TokenParams: | ||||
|             redirect_uri=request.POST.get("redirect_uri", ""), | ||||
|             grant_type=request.POST.get("grant_type", ""), | ||||
|             state=request.POST.get("state", ""), | ||||
|             scope=request.POST.get("scope", "").split(), | ||||
|             scope=set(request.POST.get("scope", "").split()), | ||||
|             # PKCE parameter. | ||||
|             code_verifier=request.POST.get("code_verifier"), | ||||
|         ) | ||||
|  | ||||
|     def __check_scopes(self): | ||||
|         allowed_scope_names = set( | ||||
|             ScopeMapping.objects.filter(provider__in=[self.provider]).values_list( | ||||
|                 "scope_name", flat=True | ||||
|             ) | ||||
|         ) | ||||
|         scopes_to_check = self.scope | ||||
|         if not scopes_to_check.issubset(allowed_scope_names): | ||||
|             LOGGER.info( | ||||
|                 "Application requested scopes not configured, setting to overlap", | ||||
|                 scope_allowed=allowed_scope_names, | ||||
|                 scope_given=self.scope, | ||||
|             ) | ||||
|             self.scope = self.scope.intersection(allowed_scope_names) | ||||
|  | ||||
|     def __check_policy_access(self, app: Application, request: HttpRequest, **kwargs): | ||||
|         with start_span( | ||||
|             op="authentik.providers.oauth2.token.policy", | ||||
| @ -149,7 +166,7 @@ class TokenParams: | ||||
|                     client_id=self.provider.client_id, | ||||
|                 ) | ||||
|                 raise TokenError("invalid_client") | ||||
|  | ||||
|         self.__check_scopes() | ||||
|         if self.grant_type == GRANT_TYPE_AUTHORIZATION_CODE: | ||||
|             with start_span( | ||||
|                 op="authentik.providers.oauth2.post.parse.code", | ||||
| @ -179,42 +196,7 @@ class TokenParams: | ||||
|             LOGGER.warning("Missing authorization code") | ||||
|             raise TokenError("invalid_grant") | ||||
|  | ||||
|         allowed_redirect_urls = self.provider.redirect_uris.split() | ||||
|         # At this point, no provider should have a blank redirect_uri, in case they do | ||||
|         # this will check an empty array and raise an error | ||||
|         try: | ||||
|             if not any(fullmatch(x, self.redirect_uri) for x in allowed_redirect_urls): | ||||
|                 LOGGER.warning( | ||||
|                     "Invalid redirect uri (regex comparison)", | ||||
|                     redirect_uri=self.redirect_uri, | ||||
|                     expected=allowed_redirect_urls, | ||||
|                 ) | ||||
|                 Event.new( | ||||
|                     EventAction.CONFIGURATION_ERROR, | ||||
|                     message="Invalid redirect URI used by provider", | ||||
|                     provider=self.provider, | ||||
|                     redirect_uri=self.redirect_uri, | ||||
|                     expected=allowed_redirect_urls, | ||||
|                 ).from_http(request) | ||||
|                 raise TokenError("invalid_client") | ||||
|         except RegexError as exc: | ||||
|             LOGGER.info("Failed to parse regular expression, checking directly", exc=exc) | ||||
|             if not any(x == self.redirect_uri for x in allowed_redirect_urls): | ||||
|                 LOGGER.warning( | ||||
|                     "Invalid redirect uri (strict comparison)", | ||||
|                     redirect_uri=self.redirect_uri, | ||||
|                     expected=allowed_redirect_urls, | ||||
|                 ) | ||||
|                 Event.new( | ||||
|                     EventAction.CONFIGURATION_ERROR, | ||||
|                     message="Invalid redirect_uri configured", | ||||
|                     provider=self.provider, | ||||
|                 ).from_http(request) | ||||
|                 raise TokenError("invalid_client") from None | ||||
|  | ||||
|         # Check against forbidden schemes | ||||
|         if urlparse(self.redirect_uri).scheme in FORBIDDEN_URI_SCHEMES: | ||||
|             raise TokenError("invalid_request") | ||||
|         self.__check_redirect_uri(request) | ||||
|  | ||||
|         self.authorization_code = AuthorizationCode.objects.filter(code=raw_code).first() | ||||
|         if not self.authorization_code: | ||||
| @ -254,6 +236,48 @@ class TokenParams: | ||||
|         if not self.authorization_code.code_challenge and self.code_verifier: | ||||
|             raise TokenError("invalid_grant") | ||||
|  | ||||
|     def __check_redirect_uri(self, request: HttpRequest): | ||||
|         allowed_redirect_urls = self.provider.redirect_uris | ||||
|         # At this point, no provider should have a blank redirect_uri, in case they do | ||||
|         # this will check an empty array and raise an error | ||||
|  | ||||
|         match_found = False | ||||
|         for allowed in allowed_redirect_urls: | ||||
|             if allowed.matching_mode == RedirectURIMatchingMode.STRICT: | ||||
|                 if self.redirect_uri == allowed.url: | ||||
|                     match_found = True | ||||
|                     break | ||||
|             if allowed.matching_mode == RedirectURIMatchingMode.REGEX: | ||||
|                 try: | ||||
|                     if fullmatch(allowed.url, self.redirect_uri): | ||||
|                         match_found = True | ||||
|                         break | ||||
|                 except RegexError as exc: | ||||
|                     LOGGER.warning( | ||||
|                         "Failed to parse regular expression", | ||||
|                         exc=exc, | ||||
|                         url=allowed.url, | ||||
|                         provider=self.provider, | ||||
|                     ) | ||||
|                     Event.new( | ||||
|                         EventAction.CONFIGURATION_ERROR, | ||||
|                         message="Invalid redirect_uri configured", | ||||
|                         provider=self.provider, | ||||
|                     ).from_http(request) | ||||
|         if not match_found: | ||||
|             Event.new( | ||||
|                 EventAction.CONFIGURATION_ERROR, | ||||
|                 message="Invalid redirect URI used by provider", | ||||
|                 provider=self.provider, | ||||
|                 redirect_uri=self.redirect_uri, | ||||
|                 expected=allowed_redirect_urls, | ||||
|             ).from_http(request) | ||||
|             raise TokenError("invalid_client") | ||||
|  | ||||
|         # Check against forbidden schemes | ||||
|         if urlparse(self.redirect_uri).scheme in FORBIDDEN_URI_SCHEMES: | ||||
|             raise TokenError("invalid_request") | ||||
|  | ||||
|     def __post_init_refresh(self, raw_token: str, request: HttpRequest): | ||||
|         if not raw_token: | ||||
|             LOGGER.warning("Missing refresh token") | ||||
| @ -433,20 +457,20 @@ class TokenParams: | ||||
|         app = Application.objects.filter(provider=self.provider).first() | ||||
|         if not app or not app.provider: | ||||
|             raise TokenError("invalid_grant") | ||||
|         with audit_ignore(): | ||||
|             self.user, _ = User.objects.update_or_create( | ||||
|                 # trim username to ensure the entire username is max 150 chars | ||||
|                 # (22 chars being the length of the "template") | ||||
|                 username=f"ak-{self.provider.name[:150-22]}-client_credentials", | ||||
|                 defaults={ | ||||
|                 "attributes": { | ||||
|                     USER_ATTRIBUTE_GENERATED: True, | ||||
|                 }, | ||||
|                     "last_login": timezone.now(), | ||||
|                     "name": f"Autogenerated user from application {app.name} (client credentials)", | ||||
|                     "path": f"{USER_PATH_SYSTEM_PREFIX}/apps/{app.slug}", | ||||
|                     "type": UserTypes.SERVICE_ACCOUNT, | ||||
|                 }, | ||||
|             ) | ||||
|             self.user.attributes[USER_ATTRIBUTE_GENERATED] = True | ||||
|             self.user.save() | ||||
|         self.__check_policy_access(app, request) | ||||
|  | ||||
|         Event.new( | ||||
| @ -470,9 +494,6 @@ class TokenParams: | ||||
|             self.user, created = User.objects.update_or_create( | ||||
|                 username=f"{self.provider.name}-{token.get('sub')}", | ||||
|                 defaults={ | ||||
|                     "attributes": { | ||||
|                         USER_ATTRIBUTE_GENERATED: True, | ||||
|                     }, | ||||
|                     "last_login": timezone.now(), | ||||
|                     "name": ( | ||||
|                         f"Autogenerated user from application {app.name} (client credentials JWT)" | ||||
| @ -481,6 +502,8 @@ class TokenParams: | ||||
|                     "type": UserTypes.SERVICE_ACCOUNT, | ||||
|                 }, | ||||
|             ) | ||||
|             self.user.attributes[USER_ATTRIBUTE_GENERATED] = True | ||||
|             self.user.save() | ||||
|             exp = token.get("exp") | ||||
|             if created and exp: | ||||
|                 self.user.attributes[USER_ATTRIBUTE_EXPIRES] = exp | ||||
| @ -498,7 +521,7 @@ class TokenView(View): | ||||
|         response = super().dispatch(request, *args, **kwargs) | ||||
|         allowed_origins = [] | ||||
|         if self.provider: | ||||
|             allowed_origins = self.provider.redirect_uris.split("\n") | ||||
|             allowed_origins = [x.url for x in self.provider.redirect_uris] | ||||
|         cors_allow(self.request, response, *allowed_origins) | ||||
|         return response | ||||
|  | ||||
| @ -551,7 +574,7 @@ class TokenView(View): | ||||
|             # Keep same scopes as previous token | ||||
|             scope=self.params.authorization_code.scope, | ||||
|             auth_time=self.params.authorization_code.auth_time, | ||||
|             session_id=self.params.authorization_code.session_id, | ||||
|             session=self.params.authorization_code.session, | ||||
|         ) | ||||
|         access_id_token = IDToken.new( | ||||
|             self.provider, | ||||
| @ -579,7 +602,7 @@ class TokenView(View): | ||||
|                 expires=refresh_token_expiry, | ||||
|                 provider=self.provider, | ||||
|                 auth_time=self.params.authorization_code.auth_time, | ||||
|                 session_id=self.params.authorization_code.session_id, | ||||
|                 session=self.params.authorization_code.session, | ||||
|             ) | ||||
|             id_token = IDToken.new( | ||||
|                 self.provider, | ||||
| @ -612,7 +635,7 @@ class TokenView(View): | ||||
|             # Keep same scopes as previous token | ||||
|             scope=self.params.refresh_token.scope, | ||||
|             auth_time=self.params.refresh_token.auth_time, | ||||
|             session_id=self.params.refresh_token.session_id, | ||||
|             session=self.params.refresh_token.session, | ||||
|         ) | ||||
|         access_token.id_token = IDToken.new( | ||||
|             self.provider, | ||||
| @ -628,7 +651,7 @@ class TokenView(View): | ||||
|             expires=refresh_token_expiry, | ||||
|             provider=self.provider, | ||||
|             auth_time=self.params.refresh_token.auth_time, | ||||
|             session_id=self.params.refresh_token.session_id, | ||||
|             session=self.params.refresh_token.session, | ||||
|         ) | ||||
|         id_token = IDToken.new( | ||||
|             self.provider, | ||||
| @ -686,13 +709,14 @@ class TokenView(View): | ||||
|             raise DeviceCodeError("authorization_pending") | ||||
|         now = timezone.now() | ||||
|         access_token_expiry = now + timedelta_from_string(self.provider.access_token_validity) | ||||
|         auth_event = get_login_event(self.request) | ||||
|         auth_event = get_login_event(self.params.device_code.session) | ||||
|         access_token = AccessToken( | ||||
|             provider=self.provider, | ||||
|             user=self.params.device_code.user, | ||||
|             expires=access_token_expiry, | ||||
|             scope=self.params.device_code.scope, | ||||
|             auth_time=auth_event.created if auth_event else now, | ||||
|             session=self.params.device_code.session, | ||||
|         ) | ||||
|         access_token.id_token = IDToken.new( | ||||
|             self.provider, | ||||
| @ -710,7 +734,7 @@ class TokenView(View): | ||||
|             "id_token": access_token.id_token.to_jwt(self.provider), | ||||
|         } | ||||
|  | ||||
|         if SCOPE_OFFLINE_ACCESS in self.params.scope: | ||||
|         if SCOPE_OFFLINE_ACCESS in self.params.device_code.scope: | ||||
|             refresh_token_expiry = now + timedelta_from_string(self.provider.refresh_token_validity) | ||||
|             refresh_token = RefreshToken( | ||||
|                 user=self.params.device_code.user, | ||||
|  | ||||
| @ -108,7 +108,7 @@ class UserInfoView(View): | ||||
|         response = super().dispatch(request, *args, **kwargs) | ||||
|         allowed_origins = [] | ||||
|         if self.token: | ||||
|             allowed_origins = self.token.provider.redirect_uris.split("\n") | ||||
|             allowed_origins = [x.url for x in self.token.provider.redirect_uris] | ||||
|         cors_allow(self.request, response, *allowed_origins) | ||||
|         return response | ||||
|  | ||||
|  | ||||
| @ -121,7 +121,6 @@ class ProxyProviderViewSet(UsedByMixin, ModelViewSet): | ||||
|         "basic_auth_password_attribute": ["iexact"], | ||||
|         "basic_auth_user_attribute": ["iexact"], | ||||
|         "mode": ["iexact"], | ||||
|         "redirect_uris": ["iexact"], | ||||
|         "cookie_domain": ["iexact"], | ||||
|     } | ||||
|     search_fields = ["name"] | ||||
|  | ||||
| @ -28,7 +28,7 @@ class ProxyDockerController(DockerController): | ||||
|         labels = super()._get_labels() | ||||
|         labels["traefik.enable"] = "true" | ||||
|         labels[f"traefik.http.routers.{traefik_name}-router.rule"] = ( | ||||
|             f"({' || '.join([f'Host(`{host}`)' for host in hosts])})" | ||||
|             f"({' || '.join([f'Host({host})' for host in hosts])})" | ||||
|             f" && PathPrefix(`/outpost.goauthentik.io`)" | ||||
|         ) | ||||
|         labels[f"traefik.http.routers.{traefik_name}-router.tls"] = "true" | ||||
|  | ||||
| @ -13,7 +13,13 @@ from rest_framework.serializers import Serializer | ||||
| from authentik.crypto.models import CertificateKeyPair | ||||
| from authentik.lib.models import DomainlessURLValidator | ||||
| from authentik.outposts.models import OutpostModel | ||||
| from authentik.providers.oauth2.models import ClientTypes, OAuth2Provider, ScopeMapping | ||||
| from authentik.providers.oauth2.models import ( | ||||
|     ClientTypes, | ||||
|     OAuth2Provider, | ||||
|     RedirectURI, | ||||
|     RedirectURIMatchingMode, | ||||
|     ScopeMapping, | ||||
| ) | ||||
|  | ||||
| SCOPE_AK_PROXY = "ak_proxy" | ||||
| OUTPOST_CALLBACK_SIGNATURE = "X-authentik-auth-callback" | ||||
| @ -24,14 +30,15 @@ def get_cookie_secret(): | ||||
|     return "".join(SystemRandom().choice(string.ascii_uppercase + string.digits) for _ in range(32)) | ||||
|  | ||||
|  | ||||
| def _get_callback_url(uri: str) -> str: | ||||
|     return "\n".join( | ||||
|         [ | ||||
| def _get_callback_url(uri: str) -> list[RedirectURI]: | ||||
|     return [ | ||||
|         RedirectURI( | ||||
|             RedirectURIMatchingMode.STRICT, | ||||
|             urljoin(uri, "outpost.goauthentik.io/callback") | ||||
|             + f"\\?{OUTPOST_CALLBACK_SIGNATURE}=true", | ||||
|             uri + f"\\?{OUTPOST_CALLBACK_SIGNATURE}=true", | ||||
|         ), | ||||
|         RedirectURI(RedirectURIMatchingMode.STRICT, uri + f"\\?{OUTPOST_CALLBACK_SIGNATURE}=true"), | ||||
|     ] | ||||
|     ) | ||||
|  | ||||
|  | ||||
| class ProxyMode(models.TextChoices): | ||||
|  | ||||
| @ -1,13 +1,12 @@ | ||||
| """proxy provider tasks""" | ||||
|  | ||||
| from hashlib import sha256 | ||||
|  | ||||
| from asgiref.sync import async_to_sync | ||||
| from channels.layers import get_channel_layer | ||||
| from django.db import DatabaseError, InternalError, ProgrammingError | ||||
|  | ||||
| from authentik.outposts.consumer import OUTPOST_GROUP | ||||
| from authentik.outposts.models import Outpost, OutpostType | ||||
| from authentik.providers.oauth2.id_token import hash_session_key | ||||
| from authentik.providers.proxy.models import ProxyProvider | ||||
| from authentik.root.celery import CELERY_APP | ||||
|  | ||||
| @ -26,7 +25,7 @@ def proxy_set_defaults(): | ||||
| def proxy_on_logout(session_id: str): | ||||
|     """Update outpost instances connected to a single outpost""" | ||||
|     layer = get_channel_layer() | ||||
|     hashed_session_id = sha256(session_id.encode("ascii")).hexdigest() | ||||
|     hashed_session_id = hash_session_key(session_id) | ||||
|     for outpost in Outpost.objects.filter(type=OutpostType.PROXY): | ||||
|         group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)} | ||||
|         async_to_sync(layer.group_send)( | ||||
|  | ||||
| @ -164,7 +164,7 @@ class SAMLProvider(Provider): | ||||
|     ) | ||||
|  | ||||
|     sign_assertion = models.BooleanField(default=True) | ||||
|     sign_response = models.BooleanField(default=True) | ||||
|     sign_response = models.BooleanField(default=False) | ||||
|  | ||||
|     @property | ||||
|     def launch_url(self) -> str | None: | ||||
|  | ||||
| @ -50,6 +50,7 @@ class AssertionProcessor: | ||||
|  | ||||
|     _issue_instant: str | ||||
|     _assertion_id: str | ||||
|     _response_id: str | ||||
|  | ||||
|     _valid_not_before: str | ||||
|     _session_not_on_or_after: str | ||||
| @ -62,6 +63,7 @@ class AssertionProcessor: | ||||
|  | ||||
|         self._issue_instant = get_time_string() | ||||
|         self._assertion_id = get_random_id() | ||||
|         self._response_id = get_random_id() | ||||
|  | ||||
|         self._valid_not_before = get_time_string( | ||||
|             timedelta_from_string(self.provider.assertion_valid_not_before) | ||||
| @ -130,7 +132,9 @@ class AssertionProcessor: | ||||
|         """Generate AuthnStatement with AuthnContext and ContextClassRef Elements.""" | ||||
|         auth_n_statement = Element(f"{{{NS_SAML_ASSERTION}}}AuthnStatement") | ||||
|         auth_n_statement.attrib["AuthnInstant"] = self._valid_not_before | ||||
|         auth_n_statement.attrib["SessionIndex"] = self._assertion_id | ||||
|         auth_n_statement.attrib["SessionIndex"] = sha256( | ||||
|             self.http_request.session.session_key.encode("ascii") | ||||
|         ).hexdigest() | ||||
|         auth_n_statement.attrib["SessionNotOnOrAfter"] = self._session_not_on_or_after | ||||
|  | ||||
|         auth_n_context = SubElement(auth_n_statement, f"{{{NS_SAML_ASSERTION}}}AuthnContext") | ||||
| @ -285,7 +289,7 @@ class AssertionProcessor: | ||||
|         response.attrib["Version"] = "2.0" | ||||
|         response.attrib["IssueInstant"] = self._issue_instant | ||||
|         response.attrib["Destination"] = self.provider.acs_url | ||||
|         response.attrib["ID"] = get_random_id() | ||||
|         response.attrib["ID"] = self._response_id | ||||
|         if self.auth_n_request.id: | ||||
|             response.attrib["InResponseTo"] = self.auth_n_request.id | ||||
|  | ||||
| @ -308,7 +312,7 @@ class AssertionProcessor: | ||||
|         ref = xmlsec.template.add_reference( | ||||
|             signature_node, | ||||
|             digest_algorithm_transform, | ||||
|             uri="#" + self._assertion_id, | ||||
|             uri="#" + element.attrib["ID"], | ||||
|         ) | ||||
|         xmlsec.template.add_transform(ref, xmlsec.constants.TransformEnveloped) | ||||
|         xmlsec.template.add_transform(ref, xmlsec.constants.TransformExclC14N) | ||||
|  | ||||
| @ -180,6 +180,10 @@ class TestAuthNRequest(TestCase): | ||||
|         # Now create a response and convert it to string (provider) | ||||
|         response_proc = AssertionProcessor(self.provider, http_request, parsed_request) | ||||
|         response = response_proc.build_response() | ||||
|         # Ensure both response and assertion ID are in the response twice (once as ID attribute, | ||||
|         # once as ds:Reference URI) | ||||
|         self.assertEqual(response.count(response_proc._assertion_id), 2) | ||||
|         self.assertEqual(response.count(response_proc._response_id), 2) | ||||
|  | ||||
|         # Now parse the response (source) | ||||
|         http_request.POST = QueryDict(mutable=True) | ||||
|  | ||||
| @ -54,7 +54,11 @@ class TestServiceProviderMetadataParser(TestCase): | ||||
|         request = self.factory.get("/") | ||||
|         metadata = lxml_from_string(MetadataProcessor(provider, request).build_entity_descriptor()) | ||||
|  | ||||
|         schema = etree.XMLSchema(etree.parse("schemas/saml-schema-metadata-2.0.xsd"))  # nosec | ||||
|         schema = etree.XMLSchema( | ||||
|             etree.parse( | ||||
|                 source="schemas/saml-schema-metadata-2.0.xsd", parser=etree.XMLParser() | ||||
|             )  # nosec | ||||
|         ) | ||||
|         self.assertTrue(schema.validate(metadata)) | ||||
|  | ||||
|     def test_schema_want_authn_requests_signed(self): | ||||
|  | ||||
| @ -47,7 +47,9 @@ class TestSchema(TestCase): | ||||
|  | ||||
|         metadata = lxml_from_string(request) | ||||
|  | ||||
|         schema = etree.XMLSchema(etree.parse("schemas/saml-schema-protocol-2.0.xsd"))  # nosec | ||||
|         schema = etree.XMLSchema( | ||||
|             etree.parse("schemas/saml-schema-protocol-2.0.xsd", parser=etree.XMLParser())  # nosec | ||||
|         ) | ||||
|         self.assertTrue(schema.validate(metadata)) | ||||
|  | ||||
|     def test_response_schema(self): | ||||
| @ -68,5 +70,7 @@ class TestSchema(TestCase): | ||||
|  | ||||
|         metadata = lxml_from_string(response) | ||||
|  | ||||
|         schema = etree.XMLSchema(etree.parse("schemas/saml-schema-protocol-2.0.xsd"))  # nosec | ||||
|         schema = etree.XMLSchema( | ||||
|             etree.parse("schemas/saml-schema-protocol-2.0.xsd", parser=etree.XMLParser())  # nosec | ||||
|         ) | ||||
|         self.assertTrue(schema.validate(metadata)) | ||||
|  | ||||
| @ -2,9 +2,10 @@ | ||||
|  | ||||
| from itertools import batched | ||||
|  | ||||
| from django.db import transaction | ||||
| from pydantic import ValidationError | ||||
| from pydanticscim.group import GroupMember | ||||
| from pydanticscim.responses import PatchOp, PatchOperation | ||||
| from pydanticscim.responses import PatchOp | ||||
|  | ||||
| from authentik.core.models import Group | ||||
| from authentik.lib.sync.mapper import PropertyMappingManager | ||||
| @ -19,7 +20,7 @@ from authentik.providers.scim.clients.base import SCIMClient | ||||
| from authentik.providers.scim.clients.exceptions import ( | ||||
|     SCIMRequestException, | ||||
| ) | ||||
| from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchRequest | ||||
| from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchOperation, PatchRequest | ||||
| from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema | ||||
| from authentik.providers.scim.models import ( | ||||
|     SCIMMapping, | ||||
| @ -104,13 +105,47 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]): | ||||
|             provider=self.provider, group=group, scim_id=scim_id | ||||
|         ) | ||||
|         users = list(group.users.order_by("id").values_list("id", flat=True)) | ||||
|         self._patch_add_users(group, users) | ||||
|         self._patch_add_users(connection, users) | ||||
|         return connection | ||||
|  | ||||
|     def update(self, group: Group, connection: SCIMProviderGroup): | ||||
|         """Update existing group""" | ||||
|         scim_group = self.to_schema(group, connection) | ||||
|         scim_group.id = connection.scim_id | ||||
|         try: | ||||
|             if self._config.patch.supported: | ||||
|                 return self._update_patch(group, scim_group, connection) | ||||
|             return self._update_put(group, scim_group, connection) | ||||
|         except NotFoundSyncException: | ||||
|             # Resource missing is handled by self.write, which will re-create the group | ||||
|             raise | ||||
|  | ||||
|     def _update_patch( | ||||
|         self, group: Group, scim_group: SCIMGroupSchema, connection: SCIMProviderGroup | ||||
|     ): | ||||
|         """Update a group via PATCH request""" | ||||
|         # Patch group's attributes instead of replacing it and re-adding users if we can | ||||
|         self._request( | ||||
|             "PATCH", | ||||
|             f"/Groups/{connection.scim_id}", | ||||
|             json=PatchRequest( | ||||
|                 Operations=[ | ||||
|                     PatchOperation( | ||||
|                         op=PatchOp.replace, | ||||
|                         path=None, | ||||
|                         value=scim_group.model_dump(mode="json", exclude_unset=True), | ||||
|                     ) | ||||
|                 ] | ||||
|             ).model_dump( | ||||
|                 mode="json", | ||||
|                 exclude_unset=True, | ||||
|                 exclude_none=True, | ||||
|             ), | ||||
|         ) | ||||
|         return self.patch_compare_users(group) | ||||
|  | ||||
|     def _update_put(self, group: Group, scim_group: SCIMGroupSchema, connection: SCIMProviderGroup): | ||||
|         """Update a group via PUT request""" | ||||
|         try: | ||||
|             self._request( | ||||
|                 "PUT", | ||||
| @ -120,33 +155,25 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]): | ||||
|                     exclude_unset=True, | ||||
|                 ), | ||||
|             ) | ||||
|             users = list(group.users.order_by("id").values_list("id", flat=True)) | ||||
|             return self._patch_add_users(group, users) | ||||
|         except NotFoundSyncException: | ||||
|             # Resource missing is handled by self.write, which will re-create the group | ||||
|             raise | ||||
|             return self.patch_compare_users(group) | ||||
|         except (SCIMRequestException, ObjectExistsSyncException): | ||||
|             # Some providers don't support PUT on groups, so this is mainly a fix for the initial | ||||
|             # sync, send patch add requests for all the users the group currently has | ||||
|             users = list(group.users.order_by("id").values_list("id", flat=True)) | ||||
|             self._patch_add_users(group, users) | ||||
|             # Also update the group name | ||||
|             return self._patch( | ||||
|                 scim_group.id, | ||||
|                 PatchOperation( | ||||
|                     op=PatchOp.replace, | ||||
|                     path="displayName", | ||||
|                     value=scim_group.displayName, | ||||
|                 ), | ||||
|             ) | ||||
|             return self._update_patch(group, scim_group, connection) | ||||
|  | ||||
|     def update_group(self, group: Group, action: Direction, users_set: set[int]): | ||||
|         """Update a group, either using PUT to replace it or PATCH if supported""" | ||||
|         scim_group = SCIMProviderGroup.objects.filter(provider=self.provider, group=group).first() | ||||
|         if not scim_group: | ||||
|             self.logger.warning( | ||||
|                 "could not sync group membership, group does not exist", group=group | ||||
|             ) | ||||
|             return | ||||
|         if self._config.patch.supported: | ||||
|             if action == Direction.add: | ||||
|                 return self._patch_add_users(group, users_set) | ||||
|                 return self._patch_add_users(scim_group, users_set) | ||||
|             if action == Direction.remove: | ||||
|                 return self._patch_remove_users(group, users_set) | ||||
|                 return self._patch_remove_users(scim_group, users_set) | ||||
|         try: | ||||
|             return self.write(group) | ||||
|         except SCIMRequestException as exc: | ||||
| @ -154,19 +181,24 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]): | ||||
|                 # Assume that provider does not support PUT and also doesn't support | ||||
|                 # ServiceProviderConfig, so try PATCH as a fallback | ||||
|                 if action == Direction.add: | ||||
|                     return self._patch_add_users(group, users_set) | ||||
|                     return self._patch_add_users(scim_group, users_set) | ||||
|                 if action == Direction.remove: | ||||
|                     return self._patch_remove_users(group, users_set) | ||||
|                     return self._patch_remove_users(scim_group, users_set) | ||||
|             raise exc | ||||
|  | ||||
|     def _patch( | ||||
|     def _patch_chunked( | ||||
|         self, | ||||
|         group_id: str, | ||||
|         *ops: PatchOperation, | ||||
|     ): | ||||
|         """Helper function that chunks patch requests based on the maxOperations attribute. | ||||
|         This is not strictly according to specs but there's nothing in the schema that allows the | ||||
|         us to know what the maximum patch operations per request should be.""" | ||||
|         chunk_size = self._config.bulk.maxOperations | ||||
|         if chunk_size < 1: | ||||
|             chunk_size = len(ops) | ||||
|         if len(ops) < 1: | ||||
|             return | ||||
|         for chunk in batched(ops, chunk_size): | ||||
|             req = PatchRequest(Operations=list(chunk)) | ||||
|             self._request( | ||||
| @ -177,16 +209,70 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]): | ||||
|                 ), | ||||
|             ) | ||||
|  | ||||
|     def _patch_add_users(self, group: Group, users_set: set[int]): | ||||
|         """Add users in users_set to group""" | ||||
|         if len(users_set) < 1: | ||||
|             return | ||||
|     @transaction.atomic | ||||
|     def patch_compare_users(self, group: Group): | ||||
|         """Compare users with a SCIM group and add/remove any differences""" | ||||
|         # Get scim group first | ||||
|         scim_group = SCIMProviderGroup.objects.filter(provider=self.provider, group=group).first() | ||||
|         if not scim_group: | ||||
|             self.logger.warning( | ||||
|                 "could not sync group membership, group does not exist", group=group | ||||
|             ) | ||||
|             return | ||||
|         # Get a list of all users in the authentik group | ||||
|         raw_users_should = list(group.users.order_by("id").values_list("id", flat=True)) | ||||
|         # Lookup the SCIM IDs of the users | ||||
|         users_should: list[str] = list( | ||||
|             SCIMProviderUser.objects.filter( | ||||
|                 user__pk__in=raw_users_should, provider=self.provider | ||||
|             ).values_list("scim_id", flat=True) | ||||
|         ) | ||||
|         if len(raw_users_should) != len(users_should): | ||||
|             self.logger.warning( | ||||
|                 "User count mismatch, not all users in the group are synced to SCIM yet.", | ||||
|                 group=group, | ||||
|             ) | ||||
|         # Get current group status | ||||
|         current_group = SCIMGroupSchema.model_validate( | ||||
|             self._request("GET", f"/Groups/{scim_group.scim_id}") | ||||
|         ) | ||||
|         users_to_add = [] | ||||
|         users_to_remove = [] | ||||
|         # Check users currently in group and if they shouldn't be in the group and remove them | ||||
|         for user in current_group.members or []: | ||||
|             if user.value not in users_should: | ||||
|                 users_to_remove.append(user.value) | ||||
|         # Check users that should be in the group and add them | ||||
|         for user in users_should: | ||||
|             if len([x for x in current_group.members if x.value == user]) < 1: | ||||
|                 users_to_add.append(user) | ||||
|         # Only send request if we need to make changes | ||||
|         if len(users_to_add) < 1 and len(users_to_remove) < 1: | ||||
|             return | ||||
|         return self._patch_chunked( | ||||
|             scim_group.scim_id, | ||||
|             *[ | ||||
|                 PatchOperation( | ||||
|                     op=PatchOp.add, | ||||
|                     path="members", | ||||
|                     value=[{"value": x}], | ||||
|                 ) | ||||
|                 for x in users_to_add | ||||
|             ], | ||||
|             *[ | ||||
|                 PatchOperation( | ||||
|                     op=PatchOp.remove, | ||||
|                     path="members", | ||||
|                     value=[{"value": x}], | ||||
|                 ) | ||||
|                 for x in users_to_remove | ||||
|             ], | ||||
|         ) | ||||
|  | ||||
|     def _patch_add_users(self, scim_group: SCIMProviderGroup, users_set: set[int]): | ||||
|         """Add users in users_set to group""" | ||||
|         if len(users_set) < 1: | ||||
|             return | ||||
|         user_ids = list( | ||||
|             SCIMProviderUser.objects.filter( | ||||
|                 user__pk__in=users_set, provider=self.provider | ||||
| @ -194,7 +280,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]): | ||||
|         ) | ||||
|         if len(user_ids) < 1: | ||||
|             return | ||||
|         self._patch( | ||||
|         self._patch_chunked( | ||||
|             scim_group.scim_id, | ||||
|             *[ | ||||
|                 PatchOperation( | ||||
| @ -206,16 +292,10 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]): | ||||
|             ], | ||||
|         ) | ||||
|  | ||||
|     def _patch_remove_users(self, group: Group, users_set: set[int]): | ||||
|     def _patch_remove_users(self, scim_group: SCIMProviderGroup, users_set: set[int]): | ||||
|         """Remove users in users_set from group""" | ||||
|         if len(users_set) < 1: | ||||
|             return | ||||
|         scim_group = SCIMProviderGroup.objects.filter(provider=self.provider, group=group).first() | ||||
|         if not scim_group: | ||||
|             self.logger.warning( | ||||
|                 "could not sync group membership, group does not exist", group=group | ||||
|             ) | ||||
|             return | ||||
|         user_ids = list( | ||||
|             SCIMProviderUser.objects.filter( | ||||
|                 user__pk__in=users_set, provider=self.provider | ||||
| @ -223,7 +303,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]): | ||||
|         ) | ||||
|         if len(user_ids) < 1: | ||||
|             return | ||||
|         self._patch( | ||||
|         self._patch_chunked( | ||||
|             scim_group.scim_id, | ||||
|             *[ | ||||
|                 PatchOperation( | ||||
|  | ||||
| @ -2,6 +2,7 @@ | ||||
|  | ||||
| from pydantic import Field | ||||
| from pydanticscim.group import Group as BaseGroup | ||||
| from pydanticscim.responses import PatchOperation as BasePatchOperation | ||||
| from pydanticscim.responses import PatchRequest as BasePatchRequest | ||||
| from pydanticscim.responses import SCIMError as BaseSCIMError | ||||
| from pydanticscim.service_provider import Bulk as BaseBulk | ||||
| @ -68,6 +69,12 @@ class PatchRequest(BasePatchRequest): | ||||
|     schemas: tuple[str] = ("urn:ietf:params:scim:api:messages:2.0:PatchOp",) | ||||
|  | ||||
|  | ||||
| class PatchOperation(BasePatchOperation): | ||||
|     """PatchOperation with optional path""" | ||||
|  | ||||
|     path: str | None | ||||
|  | ||||
|  | ||||
| class SCIMError(BaseSCIMError): | ||||
|     """SCIM error with optional status code""" | ||||
|  | ||||
|  | ||||
| @ -252,3 +252,118 @@ class SCIMMembershipTests(TestCase): | ||||
|                     ], | ||||
|                 }, | ||||
|             ) | ||||
|  | ||||
|     def test_member_add_save(self): | ||||
|         """Test member add + save""" | ||||
|         config = ServiceProviderConfiguration.default() | ||||
|  | ||||
|         config.patch.supported = True | ||||
|         user_scim_id = generate_id() | ||||
|         group_scim_id = generate_id() | ||||
|         uid = generate_id() | ||||
|         group = Group.objects.create( | ||||
|             name=uid, | ||||
|         ) | ||||
|  | ||||
|         user = User.objects.create(username=generate_id()) | ||||
|  | ||||
|         # Test initial sync of group creation | ||||
|         with Mocker() as mocker: | ||||
|             mocker.get( | ||||
|                 "https://localhost/ServiceProviderConfig", | ||||
|                 json=config.model_dump(), | ||||
|             ) | ||||
|             mocker.post( | ||||
|                 "https://localhost/Users", | ||||
|                 json={ | ||||
|                     "id": user_scim_id, | ||||
|                 }, | ||||
|             ) | ||||
|             mocker.post( | ||||
|                 "https://localhost/Groups", | ||||
|                 json={ | ||||
|                     "id": group_scim_id, | ||||
|                 }, | ||||
|             ) | ||||
|  | ||||
|             self.configure() | ||||
|             sync_tasks.trigger_single_task(self.provider, scim_sync).get() | ||||
|  | ||||
|             self.assertEqual(mocker.call_count, 6) | ||||
|             self.assertEqual(mocker.request_history[0].method, "GET") | ||||
|             self.assertEqual(mocker.request_history[1].method, "GET") | ||||
|             self.assertEqual(mocker.request_history[2].method, "GET") | ||||
|             self.assertEqual(mocker.request_history[3].method, "POST") | ||||
|             self.assertEqual(mocker.request_history[4].method, "GET") | ||||
|             self.assertEqual(mocker.request_history[5].method, "POST") | ||||
|             self.assertJSONEqual( | ||||
|                 mocker.request_history[3].body, | ||||
|                 { | ||||
|                     "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"], | ||||
|                     "emails": [], | ||||
|                     "active": True, | ||||
|                     "externalId": user.uid, | ||||
|                     "name": {"familyName": " ", "formatted": " ", "givenName": ""}, | ||||
|                     "displayName": "", | ||||
|                     "userName": user.username, | ||||
|                 }, | ||||
|             ) | ||||
|             self.assertJSONEqual( | ||||
|                 mocker.request_history[5].body, | ||||
|                 { | ||||
|                     "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"], | ||||
|                     "externalId": str(group.pk), | ||||
|                     "displayName": group.name, | ||||
|                 }, | ||||
|             ) | ||||
|  | ||||
|         with Mocker() as mocker: | ||||
|             mocker.get( | ||||
|                 "https://localhost/ServiceProviderConfig", | ||||
|                 json=config.model_dump(), | ||||
|             ) | ||||
|             mocker.get( | ||||
|                 f"https://localhost/Groups/{group_scim_id}", | ||||
|                 json={}, | ||||
|             ) | ||||
|             mocker.patch( | ||||
|                 f"https://localhost/Groups/{group_scim_id}", | ||||
|                 json={}, | ||||
|             ) | ||||
|             group.users.add(user) | ||||
|             group.save() | ||||
|             self.assertEqual(mocker.call_count, 5) | ||||
|             self.assertEqual(mocker.request_history[0].method, "GET") | ||||
|             self.assertEqual(mocker.request_history[1].method, "PATCH") | ||||
|             self.assertEqual(mocker.request_history[2].method, "GET") | ||||
|             self.assertEqual(mocker.request_history[3].method, "PATCH") | ||||
|             self.assertEqual(mocker.request_history[4].method, "GET") | ||||
|             self.assertJSONEqual( | ||||
|                 mocker.request_history[1].body, | ||||
|                 { | ||||
|                     "schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"], | ||||
|                     "Operations": [ | ||||
|                         { | ||||
|                             "op": "add", | ||||
|                             "path": "members", | ||||
|                             "value": [{"value": user_scim_id}], | ||||
|                         } | ||||
|                     ], | ||||
|                 }, | ||||
|             ) | ||||
|             self.assertJSONEqual( | ||||
|                 mocker.request_history[3].body, | ||||
|                 { | ||||
|                     "Operations": [ | ||||
|                         { | ||||
|                             "op": "replace", | ||||
|                             "value": { | ||||
|                                 "id": group_scim_id, | ||||
|                                 "displayName": group.name, | ||||
|                                 "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"], | ||||
|                                 "externalId": str(group.pk), | ||||
|                             }, | ||||
|                         } | ||||
|                     ] | ||||
|                 }, | ||||
|             ) | ||||
|  | ||||
| @ -87,7 +87,11 @@ def task_error_hook(task_id: str, exception: Exception, traceback, *args, **kwar | ||||
|  | ||||
| def _get_startup_tasks_default_tenant() -> list[Callable]: | ||||
|     """Get all tasks to be run on startup for the default tenant""" | ||||
|     return [] | ||||
|     from authentik.outposts.tasks import outpost_connection_discovery | ||||
|  | ||||
|     return [ | ||||
|         outpost_connection_discovery, | ||||
|     ] | ||||
|  | ||||
|  | ||||
| def _get_startup_tasks_all_tenants() -> list[Callable]: | ||||
|  | ||||
| @ -2,6 +2,7 @@ | ||||
|  | ||||
| from collections.abc import Callable | ||||
| from hashlib import sha512 | ||||
| from ipaddress import ip_address | ||||
| from time import perf_counter, time | ||||
| from typing import Any | ||||
|  | ||||
| @ -174,6 +175,7 @@ class ClientIPMiddleware: | ||||
|  | ||||
|     def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]): | ||||
|         self.get_response = get_response | ||||
|         self.logger = get_logger().bind() | ||||
|  | ||||
|     def _get_client_ip_from_meta(self, meta: dict[str, Any]) -> str: | ||||
|         """Attempt to get the client's IP by checking common HTTP Headers. | ||||
| @ -185,10 +187,15 @@ class ClientIPMiddleware: | ||||
|             "HTTP_X_FORWARDED_FOR", | ||||
|             "REMOTE_ADDR", | ||||
|         ) | ||||
|         try: | ||||
|             for _header in headers: | ||||
|                 if _header in meta: | ||||
|                     ips: list[str] = meta.get(_header).split(",") | ||||
|                 return ips[0].strip() | ||||
|                     # Ensure the IP parses as a valid IP | ||||
|                     return str(ip_address(ips[0].strip())) | ||||
|             return self.default_ip | ||||
|         except ValueError as exc: | ||||
|             self.logger.debug("Invalid remote IP", exc=exc) | ||||
|             return self.default_ip | ||||
|  | ||||
|     # FIXME: this should probably not be in `root` but rather in a middleware in `outposts` | ||||
| @ -226,7 +233,11 @@ class ClientIPMiddleware: | ||||
|         Scope.get_isolation_scope().set_user(user) | ||||
|         # Set the outpost service account on the request | ||||
|         setattr(request, self.request_attr_outpost_user, user) | ||||
|         return delegated_ip | ||||
|         try: | ||||
|             return str(ip_address(delegated_ip)) | ||||
|         except ValueError as exc: | ||||
|             self.logger.debug("Invalid remote IP from Outpost", exc=exc) | ||||
|             return None | ||||
|  | ||||
|     def _get_client_ip(self, request: HttpRequest | None) -> str: | ||||
|         """Attempt to get the client's IP by checking common HTTP Headers. | ||||
|  | ||||
| @ -1,6 +1,8 @@ | ||||
| """Metrics view""" | ||||
|  | ||||
| from base64 import b64encode | ||||
| from hmac import compare_digest | ||||
| from pathlib import Path | ||||
| from tempfile import gettempdir | ||||
|  | ||||
| from django.conf import settings | ||||
| from django.db import connections | ||||
| @ -16,22 +18,21 @@ monitoring_set = Signal() | ||||
|  | ||||
|  | ||||
| class MetricsView(View): | ||||
|     """Wrapper around ExportToDjangoView, using http-basic auth""" | ||||
|     """Wrapper around ExportToDjangoView with authentication, accessed by the authentik router""" | ||||
|  | ||||
|     def __init__(self, **kwargs): | ||||
|         _tmp = Path(gettempdir()) | ||||
|         with open(_tmp / "authentik-core-metrics.key") as _f: | ||||
|             self.monitoring_key = _f.read() | ||||
|  | ||||
|     def get(self, request: HttpRequest) -> HttpResponse: | ||||
|         """Check for HTTP-Basic auth""" | ||||
|         auth_header = request.META.get("HTTP_AUTHORIZATION", "") | ||||
|         auth_type, _, given_credentials = auth_header.partition(" ") | ||||
|         credentials = f"monitor:{settings.SECRET_KEY}" | ||||
|         expected = b64encode(str.encode(credentials)).decode() | ||||
|         authed = auth_type == "Basic" and given_credentials == expected | ||||
|         authed = auth_type == "Bearer" and compare_digest(given_credentials, self.monitoring_key) | ||||
|         if not authed and not settings.DEBUG: | ||||
|             response = HttpResponse(status=401) | ||||
|             response["WWW-Authenticate"] = 'Basic realm="authentik-monitoring"' | ||||
|             return response | ||||
|  | ||||
|             return HttpResponse(status=401) | ||||
|         monitoring_set.send_robust(self) | ||||
|  | ||||
|         return ExportToDjangoView(request) | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -1,6 +1,7 @@ | ||||
| """authentik storage backends""" | ||||
|  | ||||
| import os | ||||
| from urllib.parse import parse_qsl, urlsplit | ||||
|  | ||||
| from django.conf import settings | ||||
| from django.core.exceptions import SuspiciousOperation | ||||
| @ -110,3 +111,34 @@ class S3Storage(BaseS3Storage): | ||||
|         if self.querystring_auth: | ||||
|             return url | ||||
|         return self._strip_signing_parameters(url) | ||||
|  | ||||
|     def _strip_signing_parameters(self, url): | ||||
|         # Boto3 does not currently support generating URLs that are unsigned. Instead | ||||
|         # we take the signed URLs and strip any querystring params related to signing | ||||
|         # and expiration. | ||||
|         # Note that this may end up with URLs that are still invalid, especially if | ||||
|         # params are passed in that only work with signed URLs, e.g. response header | ||||
|         # params. | ||||
|         # The code attempts to strip all query parameters that match names of known | ||||
|         # parameters from v2 and v4 signatures, regardless of the actual signature | ||||
|         # version used. | ||||
|         split_url = urlsplit(url) | ||||
|         qs = parse_qsl(split_url.query, keep_blank_values=True) | ||||
|         blacklist = { | ||||
|             "x-amz-algorithm", | ||||
|             "x-amz-credential", | ||||
|             "x-amz-date", | ||||
|             "x-amz-expires", | ||||
|             "x-amz-signedheaders", | ||||
|             "x-amz-signature", | ||||
|             "x-amz-security-token", | ||||
|             "awsaccesskeyid", | ||||
|             "expires", | ||||
|             "signature", | ||||
|         } | ||||
|         filtered_qs = ((key, val) for key, val in qs if key.lower() not in blacklist) | ||||
|         # Note: Parameters that did not have a value in the original query string will | ||||
|         # have an '=' sign appended to it, e.g ?foo&bar becomes ?foo=&bar= | ||||
|         joined_qs = ("=".join(keyval) for keyval in filtered_qs) | ||||
|         split_url = split_url._replace(query="&".join(joined_qs)) | ||||
|         return split_url.geturl() | ||||
|  | ||||
| @ -1,8 +1,9 @@ | ||||
| """root tests""" | ||||
|  | ||||
| from base64 import b64encode | ||||
| from pathlib import Path | ||||
| from secrets import token_urlsafe | ||||
| from tempfile import gettempdir | ||||
|  | ||||
| from django.conf import settings | ||||
| from django.test import TestCase | ||||
| from django.urls import reverse | ||||
|  | ||||
| @ -10,6 +11,16 @@ from django.urls import reverse | ||||
| class TestRoot(TestCase): | ||||
|     """Test root application""" | ||||
|  | ||||
|     def setUp(self): | ||||
|         _tmp = Path(gettempdir()) | ||||
|         self.token = token_urlsafe(32) | ||||
|         with open(_tmp / "authentik-core-metrics.key", "w") as _f: | ||||
|             _f.write(self.token) | ||||
|  | ||||
|     def tearDown(self): | ||||
|         _tmp = Path(gettempdir()) | ||||
|         (_tmp / "authentik-core-metrics.key").unlink() | ||||
|  | ||||
|     def test_monitoring_error(self): | ||||
|         """Test monitoring without any credentials""" | ||||
|         response = self.client.get(reverse("metrics")) | ||||
| @ -17,8 +28,7 @@ class TestRoot(TestCase): | ||||
|  | ||||
|     def test_monitoring_ok(self): | ||||
|         """Test monitoring with credentials""" | ||||
|         creds = "Basic " + b64encode(f"monitor:{settings.SECRET_KEY}".encode()).decode("utf-8") | ||||
|         auth_headers = {"HTTP_AUTHORIZATION": creds} | ||||
|         auth_headers = {"HTTP_AUTHORIZATION": f"Bearer {self.token}"} | ||||
|         response = self.client.get(reverse("metrics"), **auth_headers) | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|  | ||||
|  | ||||
| @ -3,6 +3,7 @@ | ||||
| from typing import Any | ||||
|  | ||||
| from django.core.cache import cache | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
| from drf_spectacular.utils import extend_schema, inline_serializer | ||||
| from guardian.shortcuts import get_objects_for_user | ||||
| from rest_framework.decorators import action | ||||
| @ -39,9 +40,8 @@ class LDAPSourceSerializer(SourceSerializer): | ||||
|         """Get cached source connectivity""" | ||||
|         return cache.get(CACHE_KEY_STATUS + source.slug, None) | ||||
|  | ||||
|     def validate(self, attrs: dict[str, Any]) -> dict[str, Any]: | ||||
|     def validate_sync_users_password(self, sync_users_password: bool) -> bool: | ||||
|         """Check that only a single source has password_sync on""" | ||||
|         sync_users_password = attrs.get("sync_users_password", True) | ||||
|         if sync_users_password: | ||||
|             sources = LDAPSource.objects.filter(sync_users_password=True) | ||||
|             if self.instance: | ||||
| @ -49,11 +49,31 @@ class LDAPSourceSerializer(SourceSerializer): | ||||
|             if sources.exists(): | ||||
|                 raise ValidationError( | ||||
|                     { | ||||
|                         "sync_users_password": ( | ||||
|                         "sync_users_password": _( | ||||
|                             "Only a single LDAP Source with password synchronization is allowed" | ||||
|                         ) | ||||
|                     } | ||||
|                 ) | ||||
|         return sync_users_password | ||||
|  | ||||
|     def validate(self, attrs: dict[str, Any]) -> dict[str, Any]: | ||||
|         """Validate property mappings with sync_ flags""" | ||||
|         types = ["user", "group"] | ||||
|         for type in types: | ||||
|             toggle_value = attrs.get(f"sync_{type}s", False) | ||||
|             mappings_field = f"{type}_property_mappings" | ||||
|             mappings_value = attrs.get(mappings_field, []) | ||||
|             if toggle_value and len(mappings_value) == 0: | ||||
|                 raise ValidationError( | ||||
|                     { | ||||
|                         mappings_field: _( | ||||
|                             ( | ||||
|                                 "When 'Sync {type}s' is enabled, '{type}s property " | ||||
|                                 "mappings' cannot be empty." | ||||
|                             ).format(type=type) | ||||
|                         ) | ||||
|                     } | ||||
|                 ) | ||||
|         return super().validate(attrs) | ||||
|  | ||||
|     class Meta: | ||||
| @ -166,7 +186,8 @@ class LDAPSourceViewSet(UsedByMixin, ModelViewSet): | ||||
|         for sync_class in SYNC_CLASSES: | ||||
|             class_name = sync_class.name() | ||||
|             all_objects.setdefault(class_name, []) | ||||
|             for obj in sync_class(source).get_objects(size_limit=10): | ||||
|             for page in sync_class(source).get_objects(size_limit=10): | ||||
|                 for obj in page: | ||||
|                     obj: dict | ||||
|                     obj.pop("raw_attributes", None) | ||||
|                     obj.pop("raw_dn", None) | ||||
|  | ||||
| @ -26,17 +26,16 @@ def sync_ldap_source_on_save(sender, instance: LDAPSource, **_): | ||||
|     """Ensure that source is synced on save (if enabled)""" | ||||
|     if not instance.enabled: | ||||
|         return | ||||
|     ldap_connectivity_check.delay(instance.pk) | ||||
|     # Don't sync sources when they don't have any property mappings. This will only happen if: | ||||
|     # - the user forgets to set them or | ||||
|     # - the source is newly created, this is the first save event | ||||
|     #   and the mappings are created with an m2m event | ||||
|     if ( | ||||
|         not instance.user_property_mappings.exists() | ||||
|         or not instance.group_property_mappings.exists() | ||||
|     ): | ||||
|     if instance.sync_users and not instance.user_property_mappings.exists(): | ||||
|         return | ||||
|     if instance.sync_groups and not instance.group_property_mappings.exists(): | ||||
|         return | ||||
|     ldap_sync_single.delay(instance.pk) | ||||
|     ldap_connectivity_check.delay(instance.pk) | ||||
|  | ||||
|  | ||||
| @receiver(password_validate) | ||||
|  | ||||
| @ -38,7 +38,11 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer): | ||||
|             search_base=self.base_dn_groups, | ||||
|             search_filter=self._source.group_object_filter, | ||||
|             search_scope=SUBTREE, | ||||
|             attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES], | ||||
|             attributes=[ | ||||
|                 ALL_ATTRIBUTES, | ||||
|                 ALL_OPERATIONAL_ATTRIBUTES, | ||||
|                 self._source.object_uniqueness_field, | ||||
|             ], | ||||
|             **kwargs, | ||||
|         ) | ||||
|  | ||||
| @ -53,9 +57,9 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer): | ||||
|                 continue | ||||
|             attributes = group.get("attributes", {}) | ||||
|             group_dn = flatten(flatten(group.get("entryDN", group.get("dn")))) | ||||
|             if self._source.object_uniqueness_field not in attributes: | ||||
|             if not attributes.get(self._source.object_uniqueness_field): | ||||
|                 self.message( | ||||
|                     f"Cannot find uniqueness field in attributes: '{group_dn}'", | ||||
|                     f"Uniqueness field not found/not set in attributes: '{group_dn}'", | ||||
|                     attributes=attributes.keys(), | ||||
|                     dn=group_dn, | ||||
|                 ) | ||||
|  | ||||
| @ -40,7 +40,11 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer): | ||||
|             search_base=self.base_dn_users, | ||||
|             search_filter=self._source.user_object_filter, | ||||
|             search_scope=SUBTREE, | ||||
|             attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES], | ||||
|             attributes=[ | ||||
|                 ALL_ATTRIBUTES, | ||||
|                 ALL_OPERATIONAL_ATTRIBUTES, | ||||
|                 self._source.object_uniqueness_field, | ||||
|             ], | ||||
|             **kwargs, | ||||
|         ) | ||||
|  | ||||
| @ -55,9 +59,9 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer): | ||||
|                 continue | ||||
|             attributes = user.get("attributes", {}) | ||||
|             user_dn = flatten(user.get("entryDN", user.get("dn"))) | ||||
|             if self._source.object_uniqueness_field not in attributes: | ||||
|             if not attributes.get(self._source.object_uniqueness_field): | ||||
|                 self.message( | ||||
|                     f"Cannot find uniqueness field in attributes: '{user_dn}'", | ||||
|                     f"Uniqueness field not found/not set in attributes: '{user_dn}'", | ||||
|                     attributes=attributes.keys(), | ||||
|                     dn=user_dn, | ||||
|                 ) | ||||
|  | ||||
							
								
								
									
										4
									
								
								authentik/sources/ldap/sync/vendor/ms_ad.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								authentik/sources/ldap/sync/vendor/ms_ad.py
									
									
									
									
										vendored
									
									
								
							| @ -78,7 +78,9 @@ class MicrosoftActiveDirectory(BaseLDAPSynchronizer): | ||||
|         #   /useraccountcontrol-manipulate-account-properties | ||||
|         uac_bit = attributes.get("userAccountControl", 512) | ||||
|         uac = UserAccountControl(uac_bit) | ||||
|         is_active = UserAccountControl.ACCOUNTDISABLE not in uac | ||||
|         is_active = ( | ||||
|             UserAccountControl.ACCOUNTDISABLE not in uac and UserAccountControl.LOCKOUT not in uac | ||||
|         ) | ||||
|         if is_active != user.is_active: | ||||
|             user.is_active = is_active | ||||
|             user.save() | ||||
|  | ||||
| @ -50,3 +50,35 @@ class LDAPAPITests(APITestCase): | ||||
|             } | ||||
|         ) | ||||
|         self.assertFalse(serializer.is_valid()) | ||||
|  | ||||
|     def test_sync_users_mapping_empty(self): | ||||
|         """Check that when sync_users is enabled, property mappings must be set""" | ||||
|         serializer = LDAPSourceSerializer( | ||||
|             data={ | ||||
|                 "name": "foo", | ||||
|                 "slug": " foo", | ||||
|                 "server_uri": "ldaps://1.2.3.4", | ||||
|                 "bind_cn": "", | ||||
|                 "bind_password": LDAP_PASSWORD, | ||||
|                 "base_dn": "dc=foo", | ||||
|                 "sync_users": True, | ||||
|                 "user_property_mappings": [], | ||||
|             } | ||||
|         ) | ||||
|         self.assertFalse(serializer.is_valid()) | ||||
|  | ||||
|     def test_sync_groups_mapping_empty(self): | ||||
|         """Check that when sync_groups is enabled, property mappings must be set""" | ||||
|         serializer = LDAPSourceSerializer( | ||||
|             data={ | ||||
|                 "name": "foo", | ||||
|                 "slug": " foo", | ||||
|                 "server_uri": "ldaps://1.2.3.4", | ||||
|                 "bind_cn": "", | ||||
|                 "bind_password": LDAP_PASSWORD, | ||||
|                 "base_dn": "dc=foo", | ||||
|                 "sync_groups": True, | ||||
|                 "group_property_mappings": [], | ||||
|             } | ||||
|         ) | ||||
|         self.assertFalse(serializer.is_valid()) | ||||
|  | ||||
| @ -30,7 +30,9 @@ class TestMetadataProcessor(TestCase): | ||||
|         xml = MetadataProcessor(self.source, request).build_entity_descriptor() | ||||
|         metadata = lxml_from_string(xml) | ||||
|  | ||||
|         schema = etree.XMLSchema(etree.parse("schemas/saml-schema-metadata-2.0.xsd"))  # nosec | ||||
|         schema = etree.XMLSchema( | ||||
|             etree.parse("schemas/saml-schema-metadata-2.0.xsd", parser=etree.XMLParser())  # nosec | ||||
|         ) | ||||
|         self.assertTrue(schema.validate(metadata)) | ||||
|  | ||||
|     def test_metadata_consistent(self): | ||||
|  | ||||
| @ -82,3 +82,5 @@ entries: | ||||
|     order: 10 | ||||
|     target: !KeyOf default-authentication-flow-password-binding | ||||
|     policy: !KeyOf default-authentication-flow-password-optional | ||||
|   attrs: | ||||
|     failure_result: true | ||||
|  | ||||
| @ -2,7 +2,7 @@ | ||||
|     "$schema": "http://json-schema.org/draft-07/schema", | ||||
|     "$id": "https://goauthentik.io/blueprints/schema.json", | ||||
|     "type": "object", | ||||
|     "title": "authentik 2024.6.4 Blueprint schema", | ||||
|     "title": "authentik 2024.8.5 Blueprint schema", | ||||
|     "required": [ | ||||
|         "version", | ||||
|         "entries" | ||||
| @ -5345,9 +5345,30 @@ | ||||
|                     "description": "Key used to sign the tokens. Only required when JWT Algorithm is set to RS256." | ||||
|                 }, | ||||
|                 "redirect_uris": { | ||||
|                     "type": "array", | ||||
|                     "items": { | ||||
|                         "type": "object", | ||||
|                         "properties": { | ||||
|                             "matching_mode": { | ||||
|                                 "type": "string", | ||||
|                     "title": "Redirect URIs", | ||||
|                     "description": "Enter each URI on a new line." | ||||
|                                 "enum": [ | ||||
|                                     "strict", | ||||
|                                     "regex" | ||||
|                                 ], | ||||
|                                 "title": "Matching mode" | ||||
|                             }, | ||||
|                             "url": { | ||||
|                                 "type": "string", | ||||
|                                 "minLength": 1, | ||||
|                                 "title": "Url" | ||||
|                             } | ||||
|                         }, | ||||
|                         "required": [ | ||||
|                             "matching_mode", | ||||
|                             "url" | ||||
|                         ] | ||||
|                     }, | ||||
|                     "title": "Redirect uris" | ||||
|                 }, | ||||
|                 "sub_mode": { | ||||
|                     "type": "string", | ||||
|  | ||||
| @ -14,11 +14,7 @@ entries: | ||||
|       expression: | | ||||
|         # This mapping is used by the authentik proxy. It passes extra user attributes, | ||||
|         # which are used for example for the HTTP-Basic Authentication mapping. | ||||
|         session_id = None | ||||
|         if "token" in request.context: | ||||
|             session_id = request.context.get("token").session_id | ||||
|         return { | ||||
|             "sid": session_id, | ||||
|             "ak_proxy": { | ||||
|                 "user_attributes": request.user.group_attributes(request), | ||||
|                 "is_superuser": request.user.is_superuser, | ||||
|  | ||||
| @ -31,7 +31,7 @@ services: | ||||
|     volumes: | ||||
|       - redis:/data | ||||
|   server: | ||||
|     image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2024.6.4} | ||||
|     image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2024.8.5} | ||||
|     restart: unless-stopped | ||||
|     command: server | ||||
|     environment: | ||||
| @ -52,7 +52,7 @@ services: | ||||
|       - postgresql | ||||
|       - redis | ||||
|   worker: | ||||
|     image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2024.6.4} | ||||
|     image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2024.8.5} | ||||
|     restart: unless-stopped | ||||
|     command: worker | ||||
|     environment: | ||||
|  | ||||
| @ -29,4 +29,4 @@ func UserAgent() string { | ||||
| 	return fmt.Sprintf("authentik@%s", FullVersion()) | ||||
| } | ||||
|  | ||||
| const VERSION = "2024.6.4" | ||||
| const VERSION = "2024.8.5" | ||||
|  | ||||
| @ -35,10 +35,11 @@ func Paginator[Tobj any, Treq any, Tres PaginatorResponse[Tobj]]( | ||||
| 	req PaginatorRequest[Treq, Tres], | ||||
| 	opts PaginatorOptions, | ||||
| ) ([]Tobj, error) { | ||||
| 	var bfreq, cfreq interface{} | ||||
| 	fetchOffset := func(page int32) (Tres, error) { | ||||
| 		req.Page(page) | ||||
| 		req.PageSize(int32(opts.PageSize)) | ||||
| 		res, _, err := req.Execute() | ||||
| 		bfreq = req.Page(page) | ||||
| 		cfreq = bfreq.(PaginatorRequest[Treq, Tres]).PageSize(int32(opts.PageSize)) | ||||
| 		res, _, err := cfreq.(PaginatorRequest[Treq, Tres]).Execute() | ||||
| 		if err != nil { | ||||
| 			opts.Logger.WithError(err).WithField("page", page).Warning("failed to fetch page") | ||||
| 		} | ||||
|  | ||||
							
								
								
									
										26
									
								
								internal/outpost/ak/api_utils_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								internal/outpost/ak/api_utils_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,26 @@ | ||||
| package ak | ||||
|  | ||||
| // func Test_PaginatorCompile(t *testing.T) { | ||||
| // 	req := api.ApiCoreUsersListRequest{} | ||||
| // 	Paginator(req, PaginatorOptions{ | ||||
| // 		PageSize: 100, | ||||
| // 	}) | ||||
| // } | ||||
|  | ||||
| // func Test_PaginatorCompileExplicit(t *testing.T) { | ||||
| // 	req := api.ApiCoreUsersListRequest{} | ||||
| // 	Paginator[ | ||||
| // 		api.User, | ||||
| // 		api.ApiCoreUsersListRequest, | ||||
| // 		*api.PaginatedUserList, | ||||
| // 	](req, PaginatorOptions{ | ||||
| // 		PageSize: 100, | ||||
| // 	}) | ||||
| // } | ||||
|  | ||||
| // func Test_PaginatorCompileOther(t *testing.T) { | ||||
| // 	req := api.ApiOutpostsProxyListRequest{} | ||||
| // 	Paginator(req, PaginatorOptions{ | ||||
| // 		PageSize: 100, | ||||
| // 	}) | ||||
| // } | ||||
| @ -96,7 +96,7 @@ func (db *DirectBinder) Bind(username string, req *bind.Request) (ldap.LDAPResul | ||||
| 		return ldap.LDAPResultOperationsError, nil | ||||
| 	} | ||||
| 	flags.UserPk = userInfo.User.Pk | ||||
| 	flags.CanSearch = access.HasSearchPermission != nil | ||||
| 	flags.CanSearch = access.GetHasSearchPermission() | ||||
| 	db.si.SetFlags(req.BindDN, &flags) | ||||
| 	if flags.CanSearch { | ||||
| 		req.Log().Debug("Allowed access to search") | ||||
|  | ||||
| @ -193,7 +193,17 @@ func NewApplication(p api.ProxyOutpostConfig, c *http.Client, server Server) (*A | ||||
| 	}) | ||||
|  | ||||
| 	mux.HandleFunc("/outpost.goauthentik.io/start", func(w http.ResponseWriter, r *http.Request) { | ||||
| 		a.handleAuthStart(w, r, "") | ||||
| 		fwd := "" | ||||
| 		// This should only really be hit for nginx forward_auth | ||||
| 		// as for that the auth start redirect URL is generated by the | ||||
| 		// reverse proxy, and as such we won't have a request we just | ||||
| 		// denied to reference for final URL | ||||
| 		rd, ok := a.checkRedirectParam(r) | ||||
| 		if ok { | ||||
| 			a.log.WithField("rd", rd).Trace("Setting redirect") | ||||
| 			fwd = rd | ||||
| 		} | ||||
| 		a.handleAuthStart(w, r, fwd) | ||||
| 	}) | ||||
| 	mux.HandleFunc("/outpost.goauthentik.io/callback", a.handleAuthCallback) | ||||
| 	mux.HandleFunc("/outpost.goauthentik.io/sign_out", a.handleSignOut) | ||||
|  | ||||
| @ -15,36 +15,6 @@ const ( | ||||
| 	LogoutSignature   = "X-authentik-logout" | ||||
| ) | ||||
|  | ||||
| func (a *Application) checkRedirectParam(r *http.Request) (string, bool) { | ||||
| 	rd := r.URL.Query().Get(redirectParam) | ||||
| 	if rd == "" { | ||||
| 		return "", false | ||||
| 	} | ||||
| 	u, err := url.Parse(rd) | ||||
| 	if err != nil { | ||||
| 		a.log.WithError(err).Warning("Failed to parse redirect URL") | ||||
| 		return "", false | ||||
| 	} | ||||
| 	// Check to make sure we only redirect to allowed places | ||||
| 	if a.Mode() == api.PROXYMODE_PROXY || a.Mode() == api.PROXYMODE_FORWARD_SINGLE { | ||||
| 		ext, err := url.Parse(a.proxyConfig.ExternalHost) | ||||
| 		if err != nil { | ||||
| 			return "", false | ||||
| 		} | ||||
| 		ext.Scheme = "" | ||||
| 		if !strings.Contains(u.String(), ext.String()) { | ||||
| 			a.log.WithField("url", u.String()).WithField("ext", ext.String()).Warning("redirect URI did not contain external host") | ||||
| 			return "", false | ||||
| 		} | ||||
| 	} else { | ||||
| 		if !strings.HasSuffix(u.Host, *a.proxyConfig.CookieDomain) { | ||||
| 			a.log.WithField("host", u.Host).WithField("dom", *a.proxyConfig.CookieDomain).Warning("redirect URI Host was not included in cookie domain") | ||||
| 			return "", false | ||||
| 		} | ||||
| 	} | ||||
| 	return u.String(), true | ||||
| } | ||||
|  | ||||
| func (a *Application) handleAuthStart(rw http.ResponseWriter, r *http.Request, fwd string) { | ||||
| 	state, err := a.createState(r, fwd) | ||||
| 	if err != nil { | ||||
|  | ||||
| @ -5,10 +5,13 @@ import ( | ||||
| 	"encoding/base64" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/golang-jwt/jwt/v5" | ||||
| 	"github.com/gorilla/securecookie" | ||||
| 	"github.com/mitchellh/mapstructure" | ||||
| 	"goauthentik.io/api/v3" | ||||
| ) | ||||
|  | ||||
| type OAuthState struct { | ||||
| @ -27,6 +30,44 @@ func (oas *OAuthState) GetAudience() (jwt.ClaimStrings, error)       { return ni | ||||
|  | ||||
| var base32RawStdEncoding = base32.StdEncoding.WithPadding(base32.NoPadding) | ||||
|  | ||||
| // Validate that the given redirect parameter (?rd=...) is valid and can be used | ||||
| // For proxy/forward_single this checks that if the `rd` param has a Hostname (and is a full URL) | ||||
| // the hostname matches what's configured, or no hostname must be given | ||||
| // For forward_domain this checks if the domain of the URL in `rd` ends with the configured domain | ||||
| func (a *Application) checkRedirectParam(r *http.Request) (string, bool) { | ||||
| 	rd := r.URL.Query().Get(redirectParam) | ||||
| 	if rd == "" { | ||||
| 		return "", false | ||||
| 	} | ||||
| 	u, err := url.Parse(rd) | ||||
| 	if err != nil { | ||||
| 		a.log.WithError(err).Warning("Failed to parse redirect URL") | ||||
| 		return "", false | ||||
| 	} | ||||
| 	// Check to make sure we only redirect to allowed places | ||||
| 	if a.Mode() == api.PROXYMODE_PROXY || a.Mode() == api.PROXYMODE_FORWARD_SINGLE { | ||||
| 		ext, err := url.Parse(a.proxyConfig.ExternalHost) | ||||
| 		if err != nil { | ||||
| 			return "", false | ||||
| 		} | ||||
| 		// Either hostname needs to match the configured domain, or host name must be empty for just a path | ||||
| 		if u.Host == "" { | ||||
| 			u.Host = ext.Host | ||||
| 			u.Scheme = ext.Scheme | ||||
| 		} | ||||
| 		if u.Host != ext.Host { | ||||
| 			a.log.WithField("url", u.String()).WithField("ext", ext.String()).Warning("redirect URI did not contain external host") | ||||
| 			return "", false | ||||
| 		} | ||||
| 	} else { | ||||
| 		if !strings.HasSuffix(u.Host, *a.proxyConfig.CookieDomain) { | ||||
| 			a.log.WithField("host", u.Host).WithField("dom", *a.proxyConfig.CookieDomain).Warning("redirect URI Host was not included in cookie domain") | ||||
| 			return "", false | ||||
| 		} | ||||
| 	} | ||||
| 	return u.String(), true | ||||
| } | ||||
|  | ||||
| func (a *Application) createState(r *http.Request, fwd string) (string, error) { | ||||
| 	s, _ := a.sessions.Get(r, a.SessionName()) | ||||
| 	if s.ID == "" { | ||||
| @ -39,17 +80,6 @@ func (a *Application) createState(r *http.Request, fwd string) (string, error) { | ||||
| 		SessionID: s.ID, | ||||
| 		Redirect:  fwd, | ||||
| 	} | ||||
| 	if fwd == "" { | ||||
| 		// This should only really be hit for nginx forward_auth | ||||
| 		// as for that the auth start redirect URL is generated by the | ||||
| 		// reverse proxy, and as such we won't have a request we just | ||||
| 		// denied to reference for final URL | ||||
| 		rd, ok := a.checkRedirectParam(r) | ||||
| 		if ok { | ||||
| 			a.log.WithField("rd", rd).Trace("Setting redirect") | ||||
| 			st.Redirect = rd | ||||
| 		} | ||||
| 	} | ||||
| 	token := jwt.NewWithClaims(jwt.SigningMethodHS256, st) | ||||
| 	tokenString, err := token.SignedString([]byte(a.proxyConfig.GetCookieSecret())) | ||||
| 	if err != nil { | ||||
|  | ||||
| @ -8,25 +8,45 @@ import ( | ||||
| 	"goauthentik.io/api/v3" | ||||
| ) | ||||
|  | ||||
| func TestCheckRedirectParam(t *testing.T) { | ||||
| func TestCheckRedirectParam_None(t *testing.T) { | ||||
| 	a := newTestApplication() | ||||
| 	// Test no rd param | ||||
| 	req, _ := http.NewRequest("GET", "/outpost.goauthentik.io/auth/start", nil) | ||||
|  | ||||
| 	rd, ok := a.checkRedirectParam(req) | ||||
|  | ||||
| 	assert.Equal(t, false, ok) | ||||
| 	assert.Equal(t, "", rd) | ||||
| } | ||||
|  | ||||
| 	req, _ = http.NewRequest("GET", "/outpost.goauthentik.io/auth/start?rd=https://google.com", nil) | ||||
| func TestCheckRedirectParam_Invalid(t *testing.T) { | ||||
| 	a := newTestApplication() | ||||
| 	// Test invalid rd param | ||||
| 	req, _ := http.NewRequest("GET", "/outpost.goauthentik.io/auth/start?rd=https://google.com", nil) | ||||
|  | ||||
| 	rd, ok = a.checkRedirectParam(req) | ||||
| 	rd, ok := a.checkRedirectParam(req) | ||||
|  | ||||
| 	assert.Equal(t, false, ok) | ||||
| 	assert.Equal(t, "", rd) | ||||
| } | ||||
|  | ||||
| 	req, _ = http.NewRequest("GET", "/outpost.goauthentik.io/auth/start?rd=https://ext.t.goauthentik.io/test?foo", nil) | ||||
| func TestCheckRedirectParam_ValidFull(t *testing.T) { | ||||
| 	a := newTestApplication() | ||||
| 	// Test valid full rd param | ||||
| 	req, _ := http.NewRequest("GET", "/outpost.goauthentik.io/auth/start?rd=https://ext.t.goauthentik.io/test?foo", nil) | ||||
|  | ||||
| 	rd, ok = a.checkRedirectParam(req) | ||||
| 	rd, ok := a.checkRedirectParam(req) | ||||
|  | ||||
| 	assert.Equal(t, true, ok) | ||||
| 	assert.Equal(t, "https://ext.t.goauthentik.io/test?foo", rd) | ||||
| } | ||||
|  | ||||
| func TestCheckRedirectParam_ValidPartial(t *testing.T) { | ||||
| 	a := newTestApplication() | ||||
| 	// Test valid partial rd param | ||||
| 	req, _ := http.NewRequest("GET", "/outpost.goauthentik.io/auth/start?rd=/test?foo", nil) | ||||
|  | ||||
| 	rd, ok := a.checkRedirectParam(req) | ||||
|  | ||||
| 	assert.Equal(t, true, ok) | ||||
| 	assert.Equal(t, "https://ext.t.goauthentik.io/test?foo", rd) | ||||
|  | ||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user
	