Compare commits
	
		
			75 Commits
		
	
	
		
			version/20
			...
			version/20
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 8ab8090b60 | |||
| f563a0eb36 | |||
| 5dab92f0d1 | |||
| 180578559d | |||
| 48854fa72a | |||
| c99a33baee | |||
| b17d482e50 | |||
| 524d46ad7c | |||
| f90d6bb3d9 | |||
| 2340bced63 | |||
| 0a51e1b696 | |||
| 13636c0efe | |||
| e7f49d97a8 | |||
| 736240f60d | |||
| e8b5e4c127 | |||
| 81ec98b198 | |||
| c46ab19e79 | |||
| de9fc5de6b | |||
| eab3d9b411 | |||
| 7cb40d786f | |||
| b4fce08bbc | |||
| 8a2ba1c518 | |||
| 25b4306693 | |||
| 1e279950f1 | |||
| 960429355f | |||
| b4f3748353 | |||
| 91d2445c61 | |||
| dd8f809161 | |||
| 57a31b5dd1 | |||
| 09125b6236 | |||
| 832126c6fe | |||
| 25fe489b34 | |||
| 18078fd68f | |||
| 4fa71d995d | |||
| 22cec64234 | |||
| a87cc27366 | |||
| ad7ad1fa78 | |||
| c70e609e50 | |||
| 5f08485fff | |||
| 3a2ed11821 | |||
| ee04f39e28 | |||
| 2c6aa72f3c | |||
| bd0afef790 | |||
| fc11cc0a1a | |||
| fb78303e8f | |||
| 2ea04440db | |||
| 96e1636be3 | |||
| c546451a73 | |||
| 61778053b4 | |||
| f5580d311d | |||
| 99d292bce0 | |||
| b2801641bc | |||
| bfaa1046b2 | |||
| 95c30400cc | |||
| e77480ee1d | |||
| 905800e535 | |||
| fadeaef4c6 | |||
| 437efda649 | |||
| dd75d5f54b | |||
| 392a2e582e | |||
| a1da183721 | |||
| feea2df0b1 | |||
| b47acd8c76 | |||
| 6fd87d9ced | |||
| acbb065808 | |||
| 2fb097061d | |||
| 8962d17e03 | |||
| 8326e1490c | |||
| 091e4d3e4c | |||
| 6ee77edcbb | |||
| 763e2288bf | |||
| 9cdb177ca7 | |||
| 6070508058 | |||
| ec13a5d84d | |||
| 057de82b01 | 
@ -1,5 +1,5 @@
 | 
				
			|||||||
[bumpversion]
 | 
					[bumpversion]
 | 
				
			||||||
current_version = 2024.6.4
 | 
					current_version = 2024.8.6
 | 
				
			||||||
tag = True
 | 
					tag = True
 | 
				
			||||||
commit = True
 | 
					commit = True
 | 
				
			||||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))?
 | 
					parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))?
 | 
				
			||||||
 | 
				
			|||||||
@ -29,9 +29,9 @@ outputs:
 | 
				
			|||||||
  imageTags:
 | 
					  imageTags:
 | 
				
			||||||
    description: "Docker image tags"
 | 
					    description: "Docker image tags"
 | 
				
			||||||
    value: ${{ steps.ev.outputs.imageTags }}
 | 
					    value: ${{ steps.ev.outputs.imageTags }}
 | 
				
			||||||
  imageNames:
 | 
					  attestImageNames:
 | 
				
			||||||
    description: "Docker image names"
 | 
					    description: "Docker image names used for attestation"
 | 
				
			||||||
    value: ${{ steps.ev.outputs.imageNames }}
 | 
					    value: ${{ steps.ev.outputs.attestImageNames }}
 | 
				
			||||||
  imageMainTag:
 | 
					  imageMainTag:
 | 
				
			||||||
    description: "Docker image main tag"
 | 
					    description: "Docker image main tag"
 | 
				
			||||||
    value: ${{ steps.ev.outputs.imageMainTag }}
 | 
					    value: ${{ steps.ev.outputs.imageMainTag }}
 | 
				
			||||||
 | 
				
			|||||||
@ -51,15 +51,24 @@ else:
 | 
				
			|||||||
        ]
 | 
					        ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
image_main_tag = image_tags[0].split(":")[-1]
 | 
					image_main_tag = image_tags[0].split(":")[-1]
 | 
				
			||||||
image_tags_rendered = ",".join(image_tags)
 | 
					
 | 
				
			||||||
image_names_rendered = ",".join(set(name.split(":")[0] for name in image_tags))
 | 
					
 | 
				
			||||||
 | 
					def get_attest_image_names(image_with_tags: list[str]):
 | 
				
			||||||
 | 
					    """Attestation only for GHCR"""
 | 
				
			||||||
 | 
					    image_tags = []
 | 
				
			||||||
 | 
					    for image_name in set(name.split(":")[0] for name in image_with_tags):
 | 
				
			||||||
 | 
					        if not image_name.startswith("ghcr.io"):
 | 
				
			||||||
 | 
					            continue
 | 
				
			||||||
 | 
					        image_tags.append(image_name)
 | 
				
			||||||
 | 
					    return ",".join(set(image_tags))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
with open(os.environ["GITHUB_OUTPUT"], "a+", encoding="utf-8") as _output:
 | 
					with open(os.environ["GITHUB_OUTPUT"], "a+", encoding="utf-8") as _output:
 | 
				
			||||||
    print(f"shouldBuild={should_build}", file=_output)
 | 
					    print(f"shouldBuild={should_build}", file=_output)
 | 
				
			||||||
    print(f"sha={sha}", file=_output)
 | 
					    print(f"sha={sha}", file=_output)
 | 
				
			||||||
    print(f"version={version}", file=_output)
 | 
					    print(f"version={version}", file=_output)
 | 
				
			||||||
    print(f"prerelease={prerelease}", file=_output)
 | 
					    print(f"prerelease={prerelease}", file=_output)
 | 
				
			||||||
    print(f"imageTags={image_tags_rendered}", file=_output)
 | 
					    print(f"imageTags={','.join(image_tags)}", file=_output)
 | 
				
			||||||
    print(f"imageNames={image_names_rendered}", file=_output)
 | 
					    print(f"attestImageNames={get_attest_image_names(image_tags)}", file=_output)
 | 
				
			||||||
    print(f"imageMainTag={image_main_tag}", file=_output)
 | 
					    print(f"imageMainTag={image_main_tag}", file=_output)
 | 
				
			||||||
    print(f"imageMainName={image_tags[0]}", file=_output)
 | 
					    print(f"imageMainName={image_tags[0]}", file=_output)
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										2
									
								
								.github/workflows/ci-main.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/ci-main.yml
									
									
									
									
										vendored
									
									
								
							@ -261,7 +261,7 @@ jobs:
 | 
				
			|||||||
        id: attest
 | 
					        id: attest
 | 
				
			||||||
        if: ${{ steps.ev.outputs.shouldBuild == 'true' }}
 | 
					        if: ${{ steps.ev.outputs.shouldBuild == 'true' }}
 | 
				
			||||||
        with:
 | 
					        with:
 | 
				
			||||||
          subject-name: ${{ steps.ev.outputs.imageNames }}
 | 
					          subject-name: ${{ steps.ev.outputs.attestImageNames }}
 | 
				
			||||||
          subject-digest: ${{ steps.push.outputs.digest }}
 | 
					          subject-digest: ${{ steps.push.outputs.digest }}
 | 
				
			||||||
          push-to-registry: true
 | 
					          push-to-registry: true
 | 
				
			||||||
  pr-comment:
 | 
					  pr-comment:
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										2
									
								
								.github/workflows/ci-outpost.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/ci-outpost.yml
									
									
									
									
										vendored
									
									
								
							@ -115,7 +115,7 @@ jobs:
 | 
				
			|||||||
        id: attest
 | 
					        id: attest
 | 
				
			||||||
        if: ${{ steps.ev.outputs.shouldBuild == 'true' }}
 | 
					        if: ${{ steps.ev.outputs.shouldBuild == 'true' }}
 | 
				
			||||||
        with:
 | 
					        with:
 | 
				
			||||||
          subject-name: ${{ steps.ev.outputs.imageNames }}
 | 
					          subject-name: ${{ steps.ev.outputs.attestImageNames }}
 | 
				
			||||||
          subject-digest: ${{ steps.push.outputs.digest }}
 | 
					          subject-digest: ${{ steps.push.outputs.digest }}
 | 
				
			||||||
          push-to-registry: true
 | 
					          push-to-registry: true
 | 
				
			||||||
  build-binary:
 | 
					  build-binary:
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										4
									
								
								.github/workflows/release-publish.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.github/workflows/release-publish.yml
									
									
									
									
										vendored
									
									
								
							@ -58,7 +58,7 @@ jobs:
 | 
				
			|||||||
      - uses: actions/attest-build-provenance@v1
 | 
					      - uses: actions/attest-build-provenance@v1
 | 
				
			||||||
        id: attest
 | 
					        id: attest
 | 
				
			||||||
        with:
 | 
					        with:
 | 
				
			||||||
          subject-name: ${{ steps.ev.outputs.imageNames }}
 | 
					          subject-name: ${{ steps.ev.outputs.attestImageNames }}
 | 
				
			||||||
          subject-digest: ${{ steps.push.outputs.digest }}
 | 
					          subject-digest: ${{ steps.push.outputs.digest }}
 | 
				
			||||||
          push-to-registry: true
 | 
					          push-to-registry: true
 | 
				
			||||||
  build-outpost:
 | 
					  build-outpost:
 | 
				
			||||||
@ -122,7 +122,7 @@ jobs:
 | 
				
			|||||||
      - uses: actions/attest-build-provenance@v1
 | 
					      - uses: actions/attest-build-provenance@v1
 | 
				
			||||||
        id: attest
 | 
					        id: attest
 | 
				
			||||||
        with:
 | 
					        with:
 | 
				
			||||||
          subject-name: ${{ steps.ev.outputs.imageNames }}
 | 
					          subject-name: ${{ steps.ev.outputs.attestImageNames }}
 | 
				
			||||||
          subject-digest: ${{ steps.push.outputs.digest }}
 | 
					          subject-digest: ${{ steps.push.outputs.digest }}
 | 
				
			||||||
          push-to-registry: true
 | 
					          push-to-registry: true
 | 
				
			||||||
  build-outpost-binary:
 | 
					  build-outpost-binary:
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										2
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								Makefile
									
									
									
									
									
								
							@ -205,7 +205,7 @@ gen: gen-build gen-client-ts
 | 
				
			|||||||
web-build: web-install  ## Build the Authentik UI
 | 
					web-build: web-install  ## Build the Authentik UI
 | 
				
			||||||
	cd web && npm run build
 | 
						cd web && npm run build
 | 
				
			||||||
 | 
					
 | 
				
			||||||
web: web-lint-fix web-lint web-check-compile web-test  ## Automatically fix formatting issues in the Authentik UI source code, lint the code, and compile it
 | 
					web: web-lint-fix web-lint web-check-compile  ## Automatically fix formatting issues in the Authentik UI source code, lint the code, and compile it
 | 
				
			||||||
 | 
					
 | 
				
			||||||
web-install:  ## Install the necessary libraries to build the Authentik UI
 | 
					web-install:  ## Install the necessary libraries to build the Authentik UI
 | 
				
			||||||
	cd web && npm ci
 | 
						cd web && npm ci
 | 
				
			||||||
 | 
				
			|||||||
@ -2,7 +2,7 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from os import environ
 | 
					from os import environ
 | 
				
			||||||
 | 
					
 | 
				
			||||||
__version__ = "2024.6.4"
 | 
					__version__ = "2024.8.6"
 | 
				
			||||||
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"
 | 
					ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -51,9 +51,11 @@ class BlueprintInstanceSerializer(ModelSerializer):
 | 
				
			|||||||
        context = self.instance.context if self.instance else {}
 | 
					        context = self.instance.context if self.instance else {}
 | 
				
			||||||
        valid, logs = Importer.from_string(content, context).validate()
 | 
					        valid, logs = Importer.from_string(content, context).validate()
 | 
				
			||||||
        if not valid:
 | 
					        if not valid:
 | 
				
			||||||
            text_logs = "\n".join([x["event"] for x in logs])
 | 
					 | 
				
			||||||
            raise ValidationError(
 | 
					            raise ValidationError(
 | 
				
			||||||
                _("Failed to validate blueprint: {logs}".format_map({"logs": text_logs}))
 | 
					                [
 | 
				
			||||||
 | 
					                    _("Failed to validate blueprint"),
 | 
				
			||||||
 | 
					                    *[f"- {x.event}" for x in logs],
 | 
				
			||||||
 | 
					                ]
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
        return content
 | 
					        return content
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -78,5 +78,5 @@ class TestBlueprintsV1API(APITestCase):
 | 
				
			|||||||
        self.assertEqual(res.status_code, 400)
 | 
					        self.assertEqual(res.status_code, 400)
 | 
				
			||||||
        self.assertJSONEqual(
 | 
					        self.assertJSONEqual(
 | 
				
			||||||
            res.content.decode(),
 | 
					            res.content.decode(),
 | 
				
			||||||
            {"content": ["Failed to validate blueprint: Invalid blueprint version"]},
 | 
					            {"content": ["Failed to validate blueprint", "- Invalid blueprint version"]},
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
				
			|||||||
@ -429,7 +429,7 @@ class Importer:
 | 
				
			|||||||
        orig_import = deepcopy(self._import)
 | 
					        orig_import = deepcopy(self._import)
 | 
				
			||||||
        if self._import.version != 1:
 | 
					        if self._import.version != 1:
 | 
				
			||||||
            self.logger.warning("Invalid blueprint version")
 | 
					            self.logger.warning("Invalid blueprint version")
 | 
				
			||||||
            return False, [{"event": "Invalid blueprint version"}]
 | 
					            return False, [LogEvent("Invalid blueprint version", log_level="warning", logger=None)]
 | 
				
			||||||
        with (
 | 
					        with (
 | 
				
			||||||
            transaction_rollback(),
 | 
					            transaction_rollback(),
 | 
				
			||||||
            capture_logs() as logs,
 | 
					            capture_logs() as logs,
 | 
				
			||||||
 | 
				
			|||||||
@ -30,8 +30,10 @@ from authentik.core.api.utils import (
 | 
				
			|||||||
    PassiveSerializer,
 | 
					    PassiveSerializer,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from authentik.core.expression.evaluator import PropertyMappingEvaluator
 | 
					from authentik.core.expression.evaluator import PropertyMappingEvaluator
 | 
				
			||||||
 | 
					from authentik.core.expression.exceptions import PropertyMappingExpressionException
 | 
				
			||||||
from authentik.core.models import Group, PropertyMapping, User
 | 
					from authentik.core.models import Group, PropertyMapping, User
 | 
				
			||||||
from authentik.events.utils import sanitize_item
 | 
					from authentik.events.utils import sanitize_item
 | 
				
			||||||
 | 
					from authentik.lib.utils.errors import exception_to_string
 | 
				
			||||||
from authentik.policies.api.exec import PolicyTestSerializer
 | 
					from authentik.policies.api.exec import PolicyTestSerializer
 | 
				
			||||||
from authentik.rbac.decorators import permission_required
 | 
					from authentik.rbac.decorators import permission_required
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -162,12 +164,15 @@ class PropertyMappingViewSet(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        response_data = {"successful": True, "result": ""}
 | 
					        response_data = {"successful": True, "result": ""}
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            result = mapping.evaluate(**context)
 | 
					            result = mapping.evaluate(dry_run=True, **context)
 | 
				
			||||||
            response_data["result"] = dumps(
 | 
					            response_data["result"] = dumps(
 | 
				
			||||||
                sanitize_item(result), indent=(4 if format_result else None)
 | 
					                sanitize_item(result), indent=(4 if format_result else None)
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					        except PropertyMappingExpressionException as exc:
 | 
				
			||||||
 | 
					            response_data["result"] = exception_to_string(exc.exc)
 | 
				
			||||||
 | 
					            response_data["successful"] = False
 | 
				
			||||||
        except Exception as exc:
 | 
					        except Exception as exc:
 | 
				
			||||||
            response_data["result"] = str(exc)
 | 
					            response_data["result"] = exception_to_string(exc)
 | 
				
			||||||
            response_data["successful"] = False
 | 
					            response_data["successful"] = False
 | 
				
			||||||
        response = PropertyMappingTestResultSerializer(response_data)
 | 
					        response = PropertyMappingTestResultSerializer(response_data)
 | 
				
			||||||
        return Response(response.data)
 | 
					        return Response(response.data)
 | 
				
			||||||
 | 
				
			|||||||
@ -678,10 +678,13 @@ class UserViewSet(UsedByMixin, ModelViewSet):
 | 
				
			|||||||
        if not request.tenant.impersonation:
 | 
					        if not request.tenant.impersonation:
 | 
				
			||||||
            LOGGER.debug("User attempted to impersonate", user=request.user)
 | 
					            LOGGER.debug("User attempted to impersonate", user=request.user)
 | 
				
			||||||
            return Response(status=401)
 | 
					            return Response(status=401)
 | 
				
			||||||
        if not request.user.has_perm("impersonate"):
 | 
					        user_to_be = self.get_object()
 | 
				
			||||||
 | 
					        # Check both object-level perms and global perms
 | 
				
			||||||
 | 
					        if not request.user.has_perm(
 | 
				
			||||||
 | 
					            "authentik_core.impersonate", user_to_be
 | 
				
			||||||
 | 
					        ) and not request.user.has_perm("authentik_core.impersonate"):
 | 
				
			||||||
            LOGGER.debug("User attempted to impersonate without permissions", user=request.user)
 | 
					            LOGGER.debug("User attempted to impersonate without permissions", user=request.user)
 | 
				
			||||||
            return Response(status=401)
 | 
					            return Response(status=401)
 | 
				
			||||||
        user_to_be = self.get_object()
 | 
					 | 
				
			||||||
        if user_to_be.pk == self.request.user.pk:
 | 
					        if user_to_be.pk == self.request.user.pk:
 | 
				
			||||||
            LOGGER.debug("User attempted to impersonate themselves", user=request.user)
 | 
					            LOGGER.debug("User attempted to impersonate themselves", user=request.user)
 | 
				
			||||||
            return Response(status=401)
 | 
					            return Response(status=401)
 | 
				
			||||||
 | 
				
			|||||||
@ -9,10 +9,11 @@ class Command(TenantCommand):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def add_arguments(self, parser):
 | 
					    def add_arguments(self, parser):
 | 
				
			||||||
        parser.add_argument("--type", type=str, required=True)
 | 
					        parser.add_argument("--type", type=str, required=True)
 | 
				
			||||||
        parser.add_argument("--all", action="store_true")
 | 
					        parser.add_argument("--all", action="store_true", default=False)
 | 
				
			||||||
        parser.add_argument("usernames", nargs="+", type=str)
 | 
					        parser.add_argument("usernames", nargs="*", type=str)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def handle_per_tenant(self, **options):
 | 
					    def handle_per_tenant(self, **options):
 | 
				
			||||||
 | 
					        print(options)
 | 
				
			||||||
        new_type = UserTypes(options["type"])
 | 
					        new_type = UserTypes(options["type"])
 | 
				
			||||||
        qs = (
 | 
					        qs = (
 | 
				
			||||||
            User.objects.exclude_anonymous()
 | 
					            User.objects.exclude_anonymous()
 | 
				
			||||||
@ -22,6 +23,9 @@ class Command(TenantCommand):
 | 
				
			|||||||
        if options["usernames"] and options["all"]:
 | 
					        if options["usernames"] and options["all"]:
 | 
				
			||||||
            self.stderr.write("--all and usernames specified, only one can be specified")
 | 
					            self.stderr.write("--all and usernames specified, only one can be specified")
 | 
				
			||||||
            return
 | 
					            return
 | 
				
			||||||
 | 
					        if not options["usernames"] and not options["all"]:
 | 
				
			||||||
 | 
					            self.stderr.write("--all or usernames must be specified")
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
        if options["usernames"] and not options["all"]:
 | 
					        if options["usernames"] and not options["all"]:
 | 
				
			||||||
            qs = qs.filter(username__in=options["usernames"])
 | 
					            qs = qs.filter(username__in=options["usernames"])
 | 
				
			||||||
        updated = qs.update(type=new_type)
 | 
					        updated = qs.update(type=new_type)
 | 
				
			||||||
 | 
				
			|||||||
@ -466,8 +466,6 @@ class ApplicationQuerySet(QuerySet):
 | 
				
			|||||||
    def with_provider(self) -> "QuerySet[Application]":
 | 
					    def with_provider(self) -> "QuerySet[Application]":
 | 
				
			||||||
        qs = self.select_related("provider")
 | 
					        qs = self.select_related("provider")
 | 
				
			||||||
        for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider):
 | 
					        for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider):
 | 
				
			||||||
            if LOOKUP_SEP in subclass:
 | 
					 | 
				
			||||||
                continue
 | 
					 | 
				
			||||||
            qs = qs.select_related(f"provider__{subclass}")
 | 
					            qs = qs.select_related(f"provider__{subclass}")
 | 
				
			||||||
        return qs
 | 
					        return qs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -545,15 +543,24 @@ class Application(SerializerModel, PolicyBindingModel):
 | 
				
			|||||||
        if not self.provider:
 | 
					        if not self.provider:
 | 
				
			||||||
            return None
 | 
					            return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider):
 | 
					        candidates = []
 | 
				
			||||||
            # We don't care about recursion, skip nested models
 | 
					        base_class = Provider
 | 
				
			||||||
            if LOOKUP_SEP in subclass:
 | 
					        for subclass in base_class.objects.get_queryset()._get_subclasses_recurse(base_class):
 | 
				
			||||||
 | 
					            parent = self.provider
 | 
				
			||||||
 | 
					            for level in subclass.split(LOOKUP_SEP):
 | 
				
			||||||
 | 
					                try:
 | 
				
			||||||
 | 
					                    parent = getattr(parent, level)
 | 
				
			||||||
 | 
					                except AttributeError:
 | 
				
			||||||
 | 
					                    break
 | 
				
			||||||
 | 
					            if parent in candidates:
 | 
				
			||||||
                continue
 | 
					                continue
 | 
				
			||||||
            try:
 | 
					            idx = subclass.count(LOOKUP_SEP)
 | 
				
			||||||
                return getattr(self.provider, subclass)
 | 
					            if type(parent) is not base_class:
 | 
				
			||||||
            except AttributeError:
 | 
					                idx += 1
 | 
				
			||||||
                pass
 | 
					            candidates.insert(idx, parent)
 | 
				
			||||||
        return None
 | 
					        if not candidates:
 | 
				
			||||||
 | 
					            return None
 | 
				
			||||||
 | 
					        return candidates[-1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __str__(self):
 | 
					    def __str__(self):
 | 
				
			||||||
        return str(self.name)
 | 
					        return str(self.name)
 | 
				
			||||||
@ -901,7 +908,7 @@ class PropertyMapping(SerializerModel, ManagedModel):
 | 
				
			|||||||
        except ControlFlowException as exc:
 | 
					        except ControlFlowException as exc:
 | 
				
			||||||
            raise exc
 | 
					            raise exc
 | 
				
			||||||
        except Exception as exc:
 | 
					        except Exception as exc:
 | 
				
			||||||
            raise PropertyMappingExpressionException(self, exc) from exc
 | 
					            raise PropertyMappingExpressionException(exc, self) from exc
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __str__(self):
 | 
					    def __str__(self):
 | 
				
			||||||
        return f"Property Mapping {self.name}"
 | 
					        return f"Property Mapping {self.name}"
 | 
				
			||||||
 | 
				
			|||||||
@ -9,9 +9,12 @@ from rest_framework.test import APITestCase
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from authentik.core.models import Application
 | 
					from authentik.core.models import Application
 | 
				
			||||||
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
 | 
					from authentik.core.tests.utils import create_test_admin_user, create_test_flow
 | 
				
			||||||
 | 
					from authentik.lib.generators import generate_id
 | 
				
			||||||
from authentik.policies.dummy.models import DummyPolicy
 | 
					from authentik.policies.dummy.models import DummyPolicy
 | 
				
			||||||
from authentik.policies.models import PolicyBinding
 | 
					from authentik.policies.models import PolicyBinding
 | 
				
			||||||
from authentik.providers.oauth2.models import OAuth2Provider
 | 
					from authentik.providers.oauth2.models import OAuth2Provider, RedirectURI, RedirectURIMatchingMode
 | 
				
			||||||
 | 
					from authentik.providers.proxy.models import ProxyProvider
 | 
				
			||||||
 | 
					from authentik.providers.saml.models import SAMLProvider
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TestApplicationsAPI(APITestCase):
 | 
					class TestApplicationsAPI(APITestCase):
 | 
				
			||||||
@ -21,7 +24,7 @@ class TestApplicationsAPI(APITestCase):
 | 
				
			|||||||
        self.user = create_test_admin_user()
 | 
					        self.user = create_test_admin_user()
 | 
				
			||||||
        self.provider = OAuth2Provider.objects.create(
 | 
					        self.provider = OAuth2Provider.objects.create(
 | 
				
			||||||
            name="test",
 | 
					            name="test",
 | 
				
			||||||
            redirect_uris="http://some-other-domain",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://some-other-domain")],
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        self.allowed: Application = Application.objects.create(
 | 
					        self.allowed: Application = Application.objects.create(
 | 
				
			||||||
@ -222,3 +225,31 @@ class TestApplicationsAPI(APITestCase):
 | 
				
			|||||||
                ],
 | 
					                ],
 | 
				
			||||||
            },
 | 
					            },
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_get_provider(self):
 | 
				
			||||||
 | 
					        """Ensure that proxy providers (at the time of writing that is the only provider
 | 
				
			||||||
 | 
					        that inherits from another proxy type (OAuth) instead of inheriting from the root
 | 
				
			||||||
 | 
					        provider class) is correctly looked up and selected from the database"""
 | 
				
			||||||
 | 
					        slug = generate_id()
 | 
				
			||||||
 | 
					        provider = ProxyProvider.objects.create(name=generate_id())
 | 
				
			||||||
 | 
					        Application.objects.create(
 | 
				
			||||||
 | 
					            name=generate_id(),
 | 
				
			||||||
 | 
					            slug=slug,
 | 
				
			||||||
 | 
					            provider=provider,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(Application.objects.get(slug=slug).get_provider(), provider)
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            Application.objects.with_provider().get(slug=slug).get_provider(), provider
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        slug = generate_id()
 | 
				
			||||||
 | 
					        provider = SAMLProvider.objects.create(name=generate_id())
 | 
				
			||||||
 | 
					        Application.objects.create(
 | 
				
			||||||
 | 
					            name=generate_id(),
 | 
				
			||||||
 | 
					            slug=slug,
 | 
				
			||||||
 | 
					            provider=provider,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(Application.objects.get(slug=slug).get_provider(), provider)
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            Application.objects.with_provider().get(slug=slug).get_provider(), provider
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
				
			|||||||
@ -3,10 +3,10 @@
 | 
				
			|||||||
from json import loads
 | 
					from json import loads
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from django.urls import reverse
 | 
					from django.urls import reverse
 | 
				
			||||||
 | 
					from guardian.shortcuts import assign_perm
 | 
				
			||||||
from rest_framework.test import APITestCase
 | 
					from rest_framework.test import APITestCase
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.core.models import User
 | 
					from authentik.core.tests.utils import create_test_admin_user, create_test_user
 | 
				
			||||||
from authentik.core.tests.utils import create_test_admin_user
 | 
					 | 
				
			||||||
from authentik.tenants.utils import get_current_tenant
 | 
					from authentik.tenants.utils import get_current_tenant
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -15,7 +15,7 @@ class TestImpersonation(APITestCase):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def setUp(self) -> None:
 | 
					    def setUp(self) -> None:
 | 
				
			||||||
        super().setUp()
 | 
					        super().setUp()
 | 
				
			||||||
        self.other_user = User.objects.create(username="to-impersonate")
 | 
					        self.other_user = create_test_user()
 | 
				
			||||||
        self.user = create_test_admin_user()
 | 
					        self.user = create_test_admin_user()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_impersonate_simple(self):
 | 
					    def test_impersonate_simple(self):
 | 
				
			||||||
@ -44,6 +44,46 @@ class TestImpersonation(APITestCase):
 | 
				
			|||||||
        self.assertEqual(response_body["user"]["username"], self.user.username)
 | 
					        self.assertEqual(response_body["user"]["username"], self.user.username)
 | 
				
			||||||
        self.assertNotIn("original", response_body)
 | 
					        self.assertNotIn("original", response_body)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_impersonate_global(self):
 | 
				
			||||||
 | 
					        """Test impersonation with global permissions"""
 | 
				
			||||||
 | 
					        new_user = create_test_user()
 | 
				
			||||||
 | 
					        assign_perm("authentik_core.impersonate", new_user)
 | 
				
			||||||
 | 
					        assign_perm("authentik_core.view_user", new_user)
 | 
				
			||||||
 | 
					        self.client.force_login(new_user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        response = self.client.post(
 | 
				
			||||||
 | 
					            reverse(
 | 
				
			||||||
 | 
					                "authentik_api:user-impersonate",
 | 
				
			||||||
 | 
					                kwargs={"pk": self.other_user.pk},
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(response.status_code, 201)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        response = self.client.get(reverse("authentik_api:user-me"))
 | 
				
			||||||
 | 
					        response_body = loads(response.content.decode())
 | 
				
			||||||
 | 
					        self.assertEqual(response_body["user"]["username"], self.other_user.username)
 | 
				
			||||||
 | 
					        self.assertEqual(response_body["original"]["username"], new_user.username)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_impersonate_scoped(self):
 | 
				
			||||||
 | 
					        """Test impersonation with scoped permissions"""
 | 
				
			||||||
 | 
					        new_user = create_test_user()
 | 
				
			||||||
 | 
					        assign_perm("authentik_core.impersonate", new_user, self.other_user)
 | 
				
			||||||
 | 
					        assign_perm("authentik_core.view_user", new_user, self.other_user)
 | 
				
			||||||
 | 
					        self.client.force_login(new_user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        response = self.client.post(
 | 
				
			||||||
 | 
					            reverse(
 | 
				
			||||||
 | 
					                "authentik_api:user-impersonate",
 | 
				
			||||||
 | 
					                kwargs={"pk": self.other_user.pk},
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(response.status_code, 201)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        response = self.client.get(reverse("authentik_api:user-me"))
 | 
				
			||||||
 | 
					        response_body = loads(response.content.decode())
 | 
				
			||||||
 | 
					        self.assertEqual(response_body["user"]["username"], self.other_user.username)
 | 
				
			||||||
 | 
					        self.assertEqual(response_body["original"]["username"], new_user.username)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_impersonate_denied(self):
 | 
					    def test_impersonate_denied(self):
 | 
				
			||||||
        """test impersonation without permissions"""
 | 
					        """test impersonation without permissions"""
 | 
				
			||||||
        self.client.force_login(self.other_user)
 | 
					        self.client.force_login(self.other_user)
 | 
				
			||||||
 | 
				
			|||||||
@ -31,6 +31,7 @@ class TestTransactionalApplicationsAPI(APITestCase):
 | 
				
			|||||||
                "provider": {
 | 
					                "provider": {
 | 
				
			||||||
                    "name": uid,
 | 
					                    "name": uid,
 | 
				
			||||||
                    "authorization_flow": str(authorization_flow.pk),
 | 
					                    "authorization_flow": str(authorization_flow.pk),
 | 
				
			||||||
 | 
					                    "redirect_uris": [],
 | 
				
			||||||
                },
 | 
					                },
 | 
				
			||||||
            },
 | 
					            },
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
@ -56,6 +57,7 @@ class TestTransactionalApplicationsAPI(APITestCase):
 | 
				
			|||||||
                "provider": {
 | 
					                "provider": {
 | 
				
			||||||
                    "name": uid,
 | 
					                    "name": uid,
 | 
				
			||||||
                    "authorization_flow": "",
 | 
					                    "authorization_flow": "",
 | 
				
			||||||
 | 
					                    "redirect_uris": [],
 | 
				
			||||||
                },
 | 
					                },
 | 
				
			||||||
            },
 | 
					            },
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
				
			|||||||
@ -18,7 +18,7 @@ from authentik.crypto.models import CertificateKeyPair
 | 
				
			|||||||
from authentik.crypto.tasks import MANAGED_DISCOVERED, certificate_discovery
 | 
					from authentik.crypto.tasks import MANAGED_DISCOVERED, certificate_discovery
 | 
				
			||||||
from authentik.lib.config import CONFIG
 | 
					from authentik.lib.config import CONFIG
 | 
				
			||||||
from authentik.lib.generators import generate_id, generate_key
 | 
					from authentik.lib.generators import generate_id, generate_key
 | 
				
			||||||
from authentik.providers.oauth2.models import OAuth2Provider
 | 
					from authentik.providers.oauth2.models import OAuth2Provider, RedirectURI, RedirectURIMatchingMode
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TestCrypto(APITestCase):
 | 
					class TestCrypto(APITestCase):
 | 
				
			||||||
@ -263,7 +263,7 @@ class TestCrypto(APITestCase):
 | 
				
			|||||||
            client_id="test",
 | 
					            client_id="test",
 | 
				
			||||||
            client_secret=generate_key(),
 | 
					            client_secret=generate_key(),
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris="http://localhost",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
 | 
				
			||||||
            signing_key=keypair,
 | 
					            signing_key=keypair,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        response = self.client.get(
 | 
					        response = self.client.get(
 | 
				
			||||||
@ -295,7 +295,7 @@ class TestCrypto(APITestCase):
 | 
				
			|||||||
            client_id="test",
 | 
					            client_id="test",
 | 
				
			||||||
            client_secret=generate_key(),
 | 
					            client_secret=generate_key(),
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris="http://localhost",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
 | 
				
			||||||
            signing_key=keypair,
 | 
					            signing_key=keypair,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        response = self.client.get(
 | 
					        response = self.client.get(
 | 
				
			||||||
 | 
				
			|||||||
@ -18,7 +18,7 @@ from authentik.core.api.used_by import UsedByMixin
 | 
				
			|||||||
from authentik.core.api.utils import ModelSerializer, PassiveSerializer
 | 
					from authentik.core.api.utils import ModelSerializer, PassiveSerializer
 | 
				
			||||||
from authentik.core.models import User, UserTypes
 | 
					from authentik.core.models import User, UserTypes
 | 
				
			||||||
from authentik.enterprise.license import LicenseKey, LicenseSummarySerializer
 | 
					from authentik.enterprise.license import LicenseKey, LicenseSummarySerializer
 | 
				
			||||||
from authentik.enterprise.models import License, LicenseUsageStatus
 | 
					from authentik.enterprise.models import License
 | 
				
			||||||
from authentik.rbac.decorators import permission_required
 | 
					from authentik.rbac.decorators import permission_required
 | 
				
			||||||
from authentik.tenants.utils import get_unique_identifier
 | 
					from authentik.tenants.utils import get_unique_identifier
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -29,7 +29,7 @@ class EnterpriseRequiredMixin:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def validate(self, attrs: dict) -> dict:
 | 
					    def validate(self, attrs: dict) -> dict:
 | 
				
			||||||
        """Check that a valid license exists"""
 | 
					        """Check that a valid license exists"""
 | 
				
			||||||
        if LicenseKey.cached_summary().status != LicenseUsageStatus.UNLICENSED:
 | 
					        if not LicenseKey.cached_summary().status.is_valid:
 | 
				
			||||||
            raise ValidationError(_("Enterprise is required to create/update this object."))
 | 
					            raise ValidationError(_("Enterprise is required to create/update this object."))
 | 
				
			||||||
        return super().validate(attrs)
 | 
					        return super().validate(attrs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -25,4 +25,4 @@ class AuthentikEnterpriseConfig(EnterpriseConfig):
 | 
				
			|||||||
        """Actual enterprise check, cached"""
 | 
					        """Actual enterprise check, cached"""
 | 
				
			||||||
        from authentik.enterprise.license import LicenseKey
 | 
					        from authentik.enterprise.license import LicenseKey
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return LicenseKey.cached_summary().status
 | 
					        return LicenseKey.cached_summary().status.is_valid
 | 
				
			||||||
 | 
				
			|||||||
@ -117,10 +117,13 @@ class LicenseKey:
 | 
				
			|||||||
                    our_cert.public_key(),
 | 
					                    our_cert.public_key(),
 | 
				
			||||||
                    algorithms=["ES512"],
 | 
					                    algorithms=["ES512"],
 | 
				
			||||||
                    audience=get_license_aud(),
 | 
					                    audience=get_license_aud(),
 | 
				
			||||||
                    options={"verify_exp": check_expiry},
 | 
					                    options={"verify_exp": check_expiry, "verify_signature": check_expiry},
 | 
				
			||||||
                ),
 | 
					                ),
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
        except PyJWTError:
 | 
					        except PyJWTError:
 | 
				
			||||||
 | 
					            unverified = decode(jwt, options={"verify_signature": False})
 | 
				
			||||||
 | 
					            if unverified["aud"] != get_license_aud():
 | 
				
			||||||
 | 
					                raise ValidationError("Invalid Install ID in license") from None
 | 
				
			||||||
            raise ValidationError("Unable to verify license") from None
 | 
					            raise ValidationError("Unable to verify license") from None
 | 
				
			||||||
        return body
 | 
					        return body
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -134,7 +137,7 @@ class LicenseKey:
 | 
				
			|||||||
            exp_ts = int(mktime(lic.expiry.timetuple()))
 | 
					            exp_ts = int(mktime(lic.expiry.timetuple()))
 | 
				
			||||||
            if total.exp == 0:
 | 
					            if total.exp == 0:
 | 
				
			||||||
                total.exp = exp_ts
 | 
					                total.exp = exp_ts
 | 
				
			||||||
            total.exp = min(total.exp, exp_ts)
 | 
					            total.exp = max(total.exp, exp_ts)
 | 
				
			||||||
            total.license_flags.extend(lic.status.license_flags)
 | 
					            total.license_flags.extend(lic.status.license_flags)
 | 
				
			||||||
        return total
 | 
					        return total
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -3,7 +3,7 @@
 | 
				
			|||||||
from datetime import datetime
 | 
					from datetime import datetime
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from django.core.cache import cache
 | 
					from django.core.cache import cache
 | 
				
			||||||
from django.db.models.signals import post_save, pre_save
 | 
					from django.db.models.signals import post_delete, post_save, pre_save
 | 
				
			||||||
from django.dispatch import receiver
 | 
					from django.dispatch import receiver
 | 
				
			||||||
from django.utils.timezone import get_current_timezone
 | 
					from django.utils.timezone import get_current_timezone
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -27,3 +27,9 @@ def post_save_license(sender: type[License], instance: License, **_):
 | 
				
			|||||||
    """Trigger license usage calculation when license is saved"""
 | 
					    """Trigger license usage calculation when license is saved"""
 | 
				
			||||||
    cache.delete(CACHE_KEY_ENTERPRISE_LICENSE)
 | 
					    cache.delete(CACHE_KEY_ENTERPRISE_LICENSE)
 | 
				
			||||||
    enterprise_update_usage.delay()
 | 
					    enterprise_update_usage.delay()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@receiver(post_delete, sender=License)
 | 
				
			||||||
 | 
					def post_delete_license(sender: type[License], instance: License, **_):
 | 
				
			||||||
 | 
					    """Clear license cache when license is deleted"""
 | 
				
			||||||
 | 
					    cache.delete(CACHE_KEY_ENTERPRISE_LICENSE)
 | 
				
			||||||
 | 
				
			|||||||
@ -69,8 +69,5 @@ class NotificationViewSet(
 | 
				
			|||||||
    @action(detail=False, methods=["post"])
 | 
					    @action(detail=False, methods=["post"])
 | 
				
			||||||
    def mark_all_seen(self, request: Request) -> Response:
 | 
					    def mark_all_seen(self, request: Request) -> Response:
 | 
				
			||||||
        """Mark all the user's notifications as seen"""
 | 
					        """Mark all the user's notifications as seen"""
 | 
				
			||||||
        notifications = Notification.objects.filter(user=request.user)
 | 
					        Notification.objects.filter(user=request.user, seen=False).update(seen=True)
 | 
				
			||||||
        for notification in notifications:
 | 
					 | 
				
			||||||
            notification.seen = True
 | 
					 | 
				
			||||||
        Notification.objects.bulk_update(notifications, ["seen"])
 | 
					 | 
				
			||||||
        return Response({}, status=204)
 | 
					        return Response({}, status=204)
 | 
				
			||||||
 | 
				
			|||||||
@ -49,6 +49,7 @@ from authentik.policies.models import PolicyBindingModel
 | 
				
			|||||||
from authentik.root.middleware import ClientIPMiddleware
 | 
					from authentik.root.middleware import ClientIPMiddleware
 | 
				
			||||||
from authentik.stages.email.utils import TemplateEmailMessage
 | 
					from authentik.stages.email.utils import TemplateEmailMessage
 | 
				
			||||||
from authentik.tenants.models import Tenant
 | 
					from authentik.tenants.models import Tenant
 | 
				
			||||||
 | 
					from authentik.tenants.utils import get_current_tenant
 | 
				
			||||||
 | 
					
 | 
				
			||||||
LOGGER = get_logger()
 | 
					LOGGER = get_logger()
 | 
				
			||||||
DISCORD_FIELD_LIMIT = 25
 | 
					DISCORD_FIELD_LIMIT = 25
 | 
				
			||||||
@ -58,7 +59,11 @@ NOTIFICATION_SUMMARY_LENGTH = 75
 | 
				
			|||||||
def default_event_duration():
 | 
					def default_event_duration():
 | 
				
			||||||
    """Default duration an Event is saved.
 | 
					    """Default duration an Event is saved.
 | 
				
			||||||
    This is used as a fallback when no brand is available"""
 | 
					    This is used as a fallback when no brand is available"""
 | 
				
			||||||
    return now() + timedelta(days=365)
 | 
					    try:
 | 
				
			||||||
 | 
					        tenant = get_current_tenant()
 | 
				
			||||||
 | 
					        return now() + timedelta_from_string(tenant.event_retention)
 | 
				
			||||||
 | 
					    except Tenant.DoesNotExist:
 | 
				
			||||||
 | 
					        return now() + timedelta(days=365)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def default_brand():
 | 
					def default_brand():
 | 
				
			||||||
@ -245,12 +250,6 @@ class Event(SerializerModel, ExpiringModel):
 | 
				
			|||||||
            if QS_QUERY in self.context["http_request"]["args"]:
 | 
					            if QS_QUERY in self.context["http_request"]["args"]:
 | 
				
			||||||
                wrapped = self.context["http_request"]["args"][QS_QUERY]
 | 
					                wrapped = self.context["http_request"]["args"][QS_QUERY]
 | 
				
			||||||
                self.context["http_request"]["args"] = cleanse_dict(QueryDict(wrapped))
 | 
					                self.context["http_request"]["args"] = cleanse_dict(QueryDict(wrapped))
 | 
				
			||||||
        if hasattr(request, "tenant"):
 | 
					 | 
				
			||||||
            tenant: Tenant = request.tenant
 | 
					 | 
				
			||||||
            # Because self.created only gets set on save, we can't use it's value here
 | 
					 | 
				
			||||||
            # hence we set self.created to now and then use it
 | 
					 | 
				
			||||||
            self.created = now()
 | 
					 | 
				
			||||||
            self.expires = self.created + timedelta_from_string(tenant.event_retention)
 | 
					 | 
				
			||||||
        if hasattr(request, "brand"):
 | 
					        if hasattr(request, "brand"):
 | 
				
			||||||
            brand: Brand = request.brand
 | 
					            brand: Brand = request.brand
 | 
				
			||||||
            self.brand = sanitize_dict(model_to_dict(brand))
 | 
					            self.brand = sanitize_dict(model_to_dict(brand))
 | 
				
			||||||
 | 
				
			|||||||
@ -1,13 +1,16 @@
 | 
				
			|||||||
"""authentik events signal listener"""
 | 
					"""authentik events signal listener"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from importlib import import_module
 | 
				
			||||||
from typing import Any
 | 
					from typing import Any
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from django.conf import settings
 | 
				
			||||||
from django.contrib.auth.signals import user_logged_in, user_logged_out
 | 
					from django.contrib.auth.signals import user_logged_in, user_logged_out
 | 
				
			||||||
from django.db.models.signals import post_save, pre_delete
 | 
					from django.db.models.signals import post_save, pre_delete
 | 
				
			||||||
from django.dispatch import receiver
 | 
					from django.dispatch import receiver
 | 
				
			||||||
from django.http import HttpRequest
 | 
					from django.http import HttpRequest
 | 
				
			||||||
 | 
					from rest_framework.request import Request
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.core.models import User
 | 
					from authentik.core.models import AuthenticatedSession, User
 | 
				
			||||||
from authentik.core.signals import login_failed, password_changed
 | 
					from authentik.core.signals import login_failed, password_changed
 | 
				
			||||||
from authentik.events.apps import SYSTEM_TASK_STATUS
 | 
					from authentik.events.apps import SYSTEM_TASK_STATUS
 | 
				
			||||||
from authentik.events.models import Event, EventAction, SystemTask
 | 
					from authentik.events.models import Event, EventAction, SystemTask
 | 
				
			||||||
@ -23,6 +26,7 @@ from authentik.stages.user_write.signals import user_write
 | 
				
			|||||||
from authentik.tenants.utils import get_current_tenant
 | 
					from authentik.tenants.utils import get_current_tenant
 | 
				
			||||||
 | 
					
 | 
				
			||||||
SESSION_LOGIN_EVENT = "login_event"
 | 
					SESSION_LOGIN_EVENT = "login_event"
 | 
				
			||||||
 | 
					_session_engine = import_module(settings.SESSION_ENGINE)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@receiver(user_logged_in)
 | 
					@receiver(user_logged_in)
 | 
				
			||||||
@ -40,11 +44,20 @@ def on_user_logged_in(sender, request: HttpRequest, user: User, **_):
 | 
				
			|||||||
            kwargs[PLAN_CONTEXT_METHOD_ARGS] = flow_plan.context.get(PLAN_CONTEXT_METHOD_ARGS, {})
 | 
					            kwargs[PLAN_CONTEXT_METHOD_ARGS] = flow_plan.context.get(PLAN_CONTEXT_METHOD_ARGS, {})
 | 
				
			||||||
    event = Event.new(EventAction.LOGIN, **kwargs).from_http(request, user=user)
 | 
					    event = Event.new(EventAction.LOGIN, **kwargs).from_http(request, user=user)
 | 
				
			||||||
    request.session[SESSION_LOGIN_EVENT] = event
 | 
					    request.session[SESSION_LOGIN_EVENT] = event
 | 
				
			||||||
 | 
					    request.session.save()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_login_event(request: HttpRequest) -> Event | None:
 | 
					def get_login_event(request_or_session: HttpRequest | AuthenticatedSession | None) -> Event | None:
 | 
				
			||||||
    """Wrapper to get login event that can be mocked in tests"""
 | 
					    """Wrapper to get login event that can be mocked in tests"""
 | 
				
			||||||
    return request.session.get(SESSION_LOGIN_EVENT, None)
 | 
					    session = None
 | 
				
			||||||
 | 
					    if not request_or_session:
 | 
				
			||||||
 | 
					        return None
 | 
				
			||||||
 | 
					    if isinstance(request_or_session, HttpRequest | Request):
 | 
				
			||||||
 | 
					        session = request_or_session.session
 | 
				
			||||||
 | 
					    if isinstance(request_or_session, AuthenticatedSession):
 | 
				
			||||||
 | 
					        SessionStore = _session_engine.SessionStore
 | 
				
			||||||
 | 
					        session = SessionStore(request_or_session.session_key)
 | 
				
			||||||
 | 
					    return session.get(SESSION_LOGIN_EVENT, None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@receiver(user_logged_out)
 | 
					@receiver(user_logged_out)
 | 
				
			||||||
 | 
				
			|||||||
@ -6,6 +6,7 @@ from django.db.models import Model
 | 
				
			|||||||
from django.test import TestCase
 | 
					from django.test import TestCase
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.core.models import default_token_key
 | 
					from authentik.core.models import default_token_key
 | 
				
			||||||
 | 
					from authentik.events.models import default_event_duration
 | 
				
			||||||
from authentik.lib.utils.reflection import get_apps
 | 
					from authentik.lib.utils.reflection import get_apps
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -20,7 +21,7 @@ def model_tester_factory(test_model: type[Model]) -> Callable:
 | 
				
			|||||||
        allowed = 0
 | 
					        allowed = 0
 | 
				
			||||||
        # Token-like objects need to lookup the current tenant to get the default token length
 | 
					        # Token-like objects need to lookup the current tenant to get the default token length
 | 
				
			||||||
        for field in test_model._meta.fields:
 | 
					        for field in test_model._meta.fields:
 | 
				
			||||||
            if field.default == default_token_key:
 | 
					            if field.default in [default_token_key, default_event_duration]:
 | 
				
			||||||
                allowed += 1
 | 
					                allowed += 1
 | 
				
			||||||
        with self.assertNumQueries(allowed):
 | 
					        with self.assertNumQueries(allowed):
 | 
				
			||||||
            str(test_model())
 | 
					            str(test_model())
 | 
				
			||||||
 | 
				
			|||||||
@ -2,7 +2,8 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from unittest.mock import MagicMock, patch
 | 
					from unittest.mock import MagicMock, patch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from django.test import TestCase
 | 
					from django.urls import reverse
 | 
				
			||||||
 | 
					from rest_framework.test import APITestCase
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.core.models import Group, User
 | 
					from authentik.core.models import Group, User
 | 
				
			||||||
from authentik.events.models import (
 | 
					from authentik.events.models import (
 | 
				
			||||||
@ -10,6 +11,7 @@ from authentik.events.models import (
 | 
				
			|||||||
    EventAction,
 | 
					    EventAction,
 | 
				
			||||||
    Notification,
 | 
					    Notification,
 | 
				
			||||||
    NotificationRule,
 | 
					    NotificationRule,
 | 
				
			||||||
 | 
					    NotificationSeverity,
 | 
				
			||||||
    NotificationTransport,
 | 
					    NotificationTransport,
 | 
				
			||||||
    NotificationWebhookMapping,
 | 
					    NotificationWebhookMapping,
 | 
				
			||||||
    TransportMode,
 | 
					    TransportMode,
 | 
				
			||||||
@ -20,7 +22,7 @@ from authentik.policies.exceptions import PolicyException
 | 
				
			|||||||
from authentik.policies.models import PolicyBinding
 | 
					from authentik.policies.models import PolicyBinding
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TestEventsNotifications(TestCase):
 | 
					class TestEventsNotifications(APITestCase):
 | 
				
			||||||
    """Test Event Notifications"""
 | 
					    """Test Event Notifications"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def setUp(self) -> None:
 | 
					    def setUp(self) -> None:
 | 
				
			||||||
@ -131,3 +133,15 @@ class TestEventsNotifications(TestCase):
 | 
				
			|||||||
        Notification.objects.all().delete()
 | 
					        Notification.objects.all().delete()
 | 
				
			||||||
        Event.new(EventAction.CUSTOM_PREFIX).save()
 | 
					        Event.new(EventAction.CUSTOM_PREFIX).save()
 | 
				
			||||||
        self.assertEqual(Notification.objects.first().body, "foo")
 | 
					        self.assertEqual(Notification.objects.first().body, "foo")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_api_mark_all_seen(self):
 | 
				
			||||||
 | 
					        """Test mark_all_seen"""
 | 
				
			||||||
 | 
					        self.client.force_login(self.user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Notification.objects.create(
 | 
				
			||||||
 | 
					            severity=NotificationSeverity.NOTICE, body="foo", user=self.user, seen=False
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        response = self.client.post(reverse("authentik_api:notification-mark-all-seen"))
 | 
				
			||||||
 | 
					        self.assertEqual(response.status_code, 204)
 | 
				
			||||||
 | 
					        self.assertFalse(Notification.objects.filter(body="foo", seen=False).exists())
 | 
				
			||||||
 | 
				
			|||||||
@ -2,7 +2,6 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import re
 | 
					import re
 | 
				
			||||||
import socket
 | 
					import socket
 | 
				
			||||||
from collections.abc import Iterable
 | 
					 | 
				
			||||||
from ipaddress import ip_address, ip_network
 | 
					from ipaddress import ip_address, ip_network
 | 
				
			||||||
from textwrap import indent
 | 
					from textwrap import indent
 | 
				
			||||||
from types import CodeType
 | 
					from types import CodeType
 | 
				
			||||||
@ -28,6 +27,12 @@ from authentik.stages.authenticator import devices_for_user
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
LOGGER = get_logger()
 | 
					LOGGER = get_logger()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					ARG_SANITIZE = re.compile(r"[:.-]")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def sanitize_arg(arg_name: str) -> str:
 | 
				
			||||||
 | 
					    return re.sub(ARG_SANITIZE, "_", arg_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class BaseEvaluator:
 | 
					class BaseEvaluator:
 | 
				
			||||||
    """Validate and evaluate python-based expressions"""
 | 
					    """Validate and evaluate python-based expressions"""
 | 
				
			||||||
@ -177,9 +182,9 @@ class BaseEvaluator:
 | 
				
			|||||||
        proc = PolicyProcess(PolicyBinding(policy=policy), request=req, connection=None)
 | 
					        proc = PolicyProcess(PolicyBinding(policy=policy), request=req, connection=None)
 | 
				
			||||||
        return proc.profiling_wrapper()
 | 
					        return proc.profiling_wrapper()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def wrap_expression(self, expression: str, params: Iterable[str]) -> str:
 | 
					    def wrap_expression(self, expression: str) -> str:
 | 
				
			||||||
        """Wrap expression in a function, call it, and save the result as `result`"""
 | 
					        """Wrap expression in a function, call it, and save the result as `result`"""
 | 
				
			||||||
        handler_signature = ",".join(params)
 | 
					        handler_signature = ",".join(sanitize_arg(x) for x in self._context.keys())
 | 
				
			||||||
        full_expression = ""
 | 
					        full_expression = ""
 | 
				
			||||||
        full_expression += f"def handler({handler_signature}):\n"
 | 
					        full_expression += f"def handler({handler_signature}):\n"
 | 
				
			||||||
        full_expression += indent(expression, "    ")
 | 
					        full_expression += indent(expression, "    ")
 | 
				
			||||||
@ -188,8 +193,8 @@ class BaseEvaluator:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def compile(self, expression: str) -> CodeType:
 | 
					    def compile(self, expression: str) -> CodeType:
 | 
				
			||||||
        """Parse expression. Raises SyntaxError or ValueError if the syntax is incorrect."""
 | 
					        """Parse expression. Raises SyntaxError or ValueError if the syntax is incorrect."""
 | 
				
			||||||
        param_keys = self._context.keys()
 | 
					        expression = self.wrap_expression(expression)
 | 
				
			||||||
        return compile(self.wrap_expression(expression, param_keys), self._filename, "exec")
 | 
					        return compile(expression, self._filename, "exec")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def evaluate(self, expression_source: str) -> Any:
 | 
					    def evaluate(self, expression_source: str) -> Any:
 | 
				
			||||||
        """Parse and evaluate expression. If the syntax is incorrect, a SyntaxError is raised.
 | 
					        """Parse and evaluate expression. If the syntax is incorrect, a SyntaxError is raised.
 | 
				
			||||||
@ -205,7 +210,7 @@ class BaseEvaluator:
 | 
				
			|||||||
                self.handle_error(exc, expression_source)
 | 
					                self.handle_error(exc, expression_source)
 | 
				
			||||||
                raise exc
 | 
					                raise exc
 | 
				
			||||||
            try:
 | 
					            try:
 | 
				
			||||||
                _locals = self._context
 | 
					                _locals = {sanitize_arg(x): y for x, y in self._context.items()}
 | 
				
			||||||
                # Yes this is an exec, yes it is potentially bad. Since we limit what variables are
 | 
					                # Yes this is an exec, yes it is potentially bad. Since we limit what variables are
 | 
				
			||||||
                # available here, and these policies can only be edited by admins, this is a risk
 | 
					                # available here, and these policies can only be edited by admins, this is a risk
 | 
				
			||||||
                # we're willing to take.
 | 
					                # we're willing to take.
 | 
				
			||||||
 | 
				
			|||||||
@ -30,6 +30,11 @@ class TestHTTP(TestCase):
 | 
				
			|||||||
        request = self.factory.get("/", HTTP_X_FORWARDED_FOR="127.0.0.2")
 | 
					        request = self.factory.get("/", HTTP_X_FORWARDED_FOR="127.0.0.2")
 | 
				
			||||||
        self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.2")
 | 
					        self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.2")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_forward_for_invalid(self):
 | 
				
			||||||
 | 
					        """Test invalid forward for"""
 | 
				
			||||||
 | 
					        request = self.factory.get("/", HTTP_X_FORWARDED_FOR="foobar")
 | 
				
			||||||
 | 
					        self.assertEqual(ClientIPMiddleware.get_client_ip(request), ClientIPMiddleware.default_ip)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_fake_outpost(self):
 | 
					    def test_fake_outpost(self):
 | 
				
			||||||
        """Test faked IP which is overridden by an outpost"""
 | 
					        """Test faked IP which is overridden by an outpost"""
 | 
				
			||||||
        token = Token.objects.create(
 | 
					        token = Token.objects.create(
 | 
				
			||||||
@ -53,6 +58,17 @@ class TestHTTP(TestCase):
 | 
				
			|||||||
            },
 | 
					            },
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.1")
 | 
					        self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.1")
 | 
				
			||||||
 | 
					        # Invalid, not a real IP
 | 
				
			||||||
 | 
					        self.user.type = UserTypes.INTERNAL_SERVICE_ACCOUNT
 | 
				
			||||||
 | 
					        self.user.save()
 | 
				
			||||||
 | 
					        request = self.factory.get(
 | 
				
			||||||
 | 
					            "/",
 | 
				
			||||||
 | 
					            **{
 | 
				
			||||||
 | 
					                ClientIPMiddleware.outpost_remote_ip_header: "foobar",
 | 
				
			||||||
 | 
					                ClientIPMiddleware.outpost_token_header: token.key,
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.1")
 | 
				
			||||||
        # Valid
 | 
					        # Valid
 | 
				
			||||||
        self.user.type = UserTypes.INTERNAL_SERVICE_ACCOUNT
 | 
					        self.user.type = UserTypes.INTERNAL_SERVICE_ACCOUNT
 | 
				
			||||||
        self.user.save()
 | 
					        self.user.save()
 | 
				
			||||||
 | 
				
			|||||||
@ -21,7 +21,14 @@ class DebugSession(Session):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def send(self, req: PreparedRequest, *args, **kwargs):
 | 
					    def send(self, req: PreparedRequest, *args, **kwargs):
 | 
				
			||||||
        request_id = str(uuid4())
 | 
					        request_id = str(uuid4())
 | 
				
			||||||
        LOGGER.debug("HTTP request sent", uid=request_id, path=req.path_url, headers=req.headers)
 | 
					        LOGGER.debug(
 | 
				
			||||||
 | 
					            "HTTP request sent",
 | 
				
			||||||
 | 
					            uid=request_id,
 | 
				
			||||||
 | 
					            url=req.url,
 | 
				
			||||||
 | 
					            method=req.method,
 | 
				
			||||||
 | 
					            headers=req.headers,
 | 
				
			||||||
 | 
					            body=req.body,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        resp = super().send(req, *args, **kwargs)
 | 
					        resp = super().send(req, *args, **kwargs)
 | 
				
			||||||
        LOGGER.debug(
 | 
					        LOGGER.debug(
 | 
				
			||||||
            "HTTP response received",
 | 
					            "HTTP response received",
 | 
				
			||||||
 | 
				
			|||||||
@ -108,7 +108,7 @@ class EventMatcherPolicy(Policy):
 | 
				
			|||||||
                result=result,
 | 
					                result=result,
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            matches.append(result)
 | 
					            matches.append(result)
 | 
				
			||||||
        passing = any(x.passing for x in matches)
 | 
					        passing = all(x.passing for x in matches)
 | 
				
			||||||
        messages = chain(*[x.messages for x in matches])
 | 
					        messages = chain(*[x.messages for x in matches])
 | 
				
			||||||
        result = PolicyResult(passing, *messages)
 | 
					        result = PolicyResult(passing, *messages)
 | 
				
			||||||
        result.source_results = matches
 | 
					        result.source_results = matches
 | 
				
			||||||
 | 
				
			|||||||
@ -77,11 +77,24 @@ class TestEventMatcherPolicy(TestCase):
 | 
				
			|||||||
        request = PolicyRequest(get_anonymous_user())
 | 
					        request = PolicyRequest(get_anonymous_user())
 | 
				
			||||||
        request.context["event"] = event
 | 
					        request.context["event"] = event
 | 
				
			||||||
        policy: EventMatcherPolicy = EventMatcherPolicy.objects.create(
 | 
					        policy: EventMatcherPolicy = EventMatcherPolicy.objects.create(
 | 
				
			||||||
            client_ip="1.2.3.5", app="bar"
 | 
					            client_ip="1.2.3.5", app="foo"
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        response = policy.passes(request)
 | 
					        response = policy.passes(request)
 | 
				
			||||||
        self.assertFalse(response.passing)
 | 
					        self.assertFalse(response.passing)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_multiple(self):
 | 
				
			||||||
 | 
					        """Test multiple"""
 | 
				
			||||||
 | 
					        event = Event.new(EventAction.LOGIN)
 | 
				
			||||||
 | 
					        event.app = "foo"
 | 
				
			||||||
 | 
					        event.client_ip = "1.2.3.4"
 | 
				
			||||||
 | 
					        request = PolicyRequest(get_anonymous_user())
 | 
				
			||||||
 | 
					        request.context["event"] = event
 | 
				
			||||||
 | 
					        policy: EventMatcherPolicy = EventMatcherPolicy.objects.create(
 | 
				
			||||||
 | 
					            client_ip="1.2.3.4", app="foo"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        response = policy.passes(request)
 | 
				
			||||||
 | 
					        self.assertTrue(response.passing)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_invalid(self):
 | 
					    def test_invalid(self):
 | 
				
			||||||
        """Test passing event"""
 | 
					        """Test passing event"""
 | 
				
			||||||
        request = PolicyRequest(get_anonymous_user())
 | 
					        request = PolicyRequest(get_anonymous_user())
 | 
				
			||||||
 | 
				
			|||||||
@ -4,13 +4,13 @@ from django.apps.registry import Apps
 | 
				
			|||||||
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
 | 
					from django.db.backends.base.schema import BaseDatabaseSchemaEditor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from django.db import migrations
 | 
					from django.db import migrations
 | 
				
			||||||
from django.contrib.auth.management import create_permissions
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def migrate_search_group(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
 | 
					def migrate_search_group(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
 | 
				
			||||||
    from guardian.shortcuts import assign_perm
 | 
					 | 
				
			||||||
    from authentik.core.models import User
 | 
					    from authentik.core.models import User
 | 
				
			||||||
    from django.apps import apps as real_apps
 | 
					    from django.apps import apps as real_apps
 | 
				
			||||||
 | 
					    from django.contrib.auth.management import create_permissions
 | 
				
			||||||
 | 
					    from guardian.shortcuts import UserObjectPermission
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    db_alias = schema_editor.connection.alias
 | 
					    db_alias = schema_editor.connection.alias
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -20,14 +20,25 @@ def migrate_search_group(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
 | 
				
			|||||||
    create_permissions(real_apps.get_app_config("authentik_providers_ldap"), using=db_alias)
 | 
					    create_permissions(real_apps.get_app_config("authentik_providers_ldap"), using=db_alias)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    LDAPProvider = apps.get_model("authentik_providers_ldap", "ldapprovider")
 | 
					    LDAPProvider = apps.get_model("authentik_providers_ldap", "ldapprovider")
 | 
				
			||||||
 | 
					    Permission = apps.get_model("auth", "Permission")
 | 
				
			||||||
 | 
					    UserObjectPermission = apps.get_model("guardian", "UserObjectPermission")
 | 
				
			||||||
 | 
					    ContentType = apps.get_model("contenttypes", "ContentType")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    new_prem = Permission.objects.using(db_alias).get(codename="search_full_directory")
 | 
				
			||||||
 | 
					    ct = ContentType.objects.using(db_alias).get(
 | 
				
			||||||
 | 
					        app_label="authentik_providers_ldap",
 | 
				
			||||||
 | 
					        model="ldapprovider",
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    for provider in LDAPProvider.objects.using(db_alias).all():
 | 
					    for provider in LDAPProvider.objects.using(db_alias).all():
 | 
				
			||||||
        for user_pk in (
 | 
					        if not provider.search_group:
 | 
				
			||||||
            provider.search_group.users.using(db_alias).all().values_list("pk", flat=True)
 | 
					            continue
 | 
				
			||||||
        ):
 | 
					        for user in provider.search_group.users.using(db_alias).all():
 | 
				
			||||||
            # We need the correct user model instance to assign the permission
 | 
					            UserObjectPermission.objects.using(db_alias).create(
 | 
				
			||||||
            assign_perm(
 | 
					                user=user,
 | 
				
			||||||
                "search_full_directory", User.objects.using(db_alias).get(pk=user_pk), provider
 | 
					                permission=new_prem,
 | 
				
			||||||
 | 
					                object_pk=provider.pk,
 | 
				
			||||||
 | 
					                content_type=ct,
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -35,6 +46,7 @@ class Migration(migrations.Migration):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    dependencies = [
 | 
					    dependencies = [
 | 
				
			||||||
        ("authentik_providers_ldap", "0003_ldapprovider_mfa_support_and_more"),
 | 
					        ("authentik_providers_ldap", "0003_ldapprovider_mfa_support_and_more"),
 | 
				
			||||||
 | 
					        ("guardian", "0002_generic_permissions_index"),
 | 
				
			||||||
    ]
 | 
					    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    operations = [
 | 
					    operations = [
 | 
				
			||||||
 | 
				
			|||||||
@ -1,15 +1,18 @@
 | 
				
			|||||||
"""OAuth2Provider API Views"""
 | 
					"""OAuth2Provider API Views"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from copy import copy
 | 
					from copy import copy
 | 
				
			||||||
 | 
					from re import compile
 | 
				
			||||||
 | 
					from re import error as RegexError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from django.urls import reverse
 | 
					from django.urls import reverse
 | 
				
			||||||
from django.utils import timezone
 | 
					from django.utils import timezone
 | 
				
			||||||
 | 
					from django.utils.translation import gettext_lazy as _
 | 
				
			||||||
from drf_spectacular.types import OpenApiTypes
 | 
					from drf_spectacular.types import OpenApiTypes
 | 
				
			||||||
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
 | 
					from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
 | 
				
			||||||
from guardian.shortcuts import get_objects_for_user
 | 
					from guardian.shortcuts import get_objects_for_user
 | 
				
			||||||
from rest_framework.decorators import action
 | 
					from rest_framework.decorators import action
 | 
				
			||||||
from rest_framework.exceptions import ValidationError
 | 
					from rest_framework.exceptions import ValidationError
 | 
				
			||||||
from rest_framework.fields import CharField
 | 
					from rest_framework.fields import CharField, ChoiceField
 | 
				
			||||||
from rest_framework.generics import get_object_or_404
 | 
					from rest_framework.generics import get_object_or_404
 | 
				
			||||||
from rest_framework.request import Request
 | 
					from rest_framework.request import Request
 | 
				
			||||||
from rest_framework.response import Response
 | 
					from rest_framework.response import Response
 | 
				
			||||||
@ -20,13 +23,39 @@ from authentik.core.api.used_by import UsedByMixin
 | 
				
			|||||||
from authentik.core.api.utils import PassiveSerializer, PropertyMappingPreviewSerializer
 | 
					from authentik.core.api.utils import PassiveSerializer, PropertyMappingPreviewSerializer
 | 
				
			||||||
from authentik.core.models import Provider
 | 
					from authentik.core.models import Provider
 | 
				
			||||||
from authentik.providers.oauth2.id_token import IDToken
 | 
					from authentik.providers.oauth2.id_token import IDToken
 | 
				
			||||||
from authentik.providers.oauth2.models import AccessToken, OAuth2Provider, ScopeMapping
 | 
					from authentik.providers.oauth2.models import (
 | 
				
			||||||
 | 
					    AccessToken,
 | 
				
			||||||
 | 
					    OAuth2Provider,
 | 
				
			||||||
 | 
					    RedirectURIMatchingMode,
 | 
				
			||||||
 | 
					    ScopeMapping,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
from authentik.rbac.decorators import permission_required
 | 
					from authentik.rbac.decorators import permission_required
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class RedirectURISerializer(PassiveSerializer):
 | 
				
			||||||
 | 
					    """A single allowed redirect URI entry"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    matching_mode = ChoiceField(choices=RedirectURIMatchingMode.choices)
 | 
				
			||||||
 | 
					    url = CharField()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class OAuth2ProviderSerializer(ProviderSerializer):
 | 
					class OAuth2ProviderSerializer(ProviderSerializer):
 | 
				
			||||||
    """OAuth2Provider Serializer"""
 | 
					    """OAuth2Provider Serializer"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    redirect_uris = RedirectURISerializer(many=True, source="_redirect_uris")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def validate_redirect_uris(self, data: list) -> list:
 | 
				
			||||||
 | 
					        for entry in data:
 | 
				
			||||||
 | 
					            if entry.get("matching_mode") == RedirectURIMatchingMode.REGEX:
 | 
				
			||||||
 | 
					                url = entry.get("url")
 | 
				
			||||||
 | 
					                try:
 | 
				
			||||||
 | 
					                    compile(url)
 | 
				
			||||||
 | 
					                except RegexError:
 | 
				
			||||||
 | 
					                    raise ValidationError(
 | 
				
			||||||
 | 
					                        _("Invalid Regex Pattern: {url}".format(url=url))
 | 
				
			||||||
 | 
					                    ) from None
 | 
				
			||||||
 | 
					        return data
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    class Meta:
 | 
					    class Meta:
 | 
				
			||||||
        model = OAuth2Provider
 | 
					        model = OAuth2Provider
 | 
				
			||||||
        fields = ProviderSerializer.Meta.fields + [
 | 
					        fields = ProviderSerializer.Meta.fields + [
 | 
				
			||||||
@ -78,7 +107,6 @@ class OAuth2ProviderViewSet(UsedByMixin, ModelViewSet):
 | 
				
			|||||||
        "refresh_token_validity",
 | 
					        "refresh_token_validity",
 | 
				
			||||||
        "include_claims_in_id_token",
 | 
					        "include_claims_in_id_token",
 | 
				
			||||||
        "signing_key",
 | 
					        "signing_key",
 | 
				
			||||||
        "redirect_uris",
 | 
					 | 
				
			||||||
        "sub_mode",
 | 
					        "sub_mode",
 | 
				
			||||||
        "property_mappings",
 | 
					        "property_mappings",
 | 
				
			||||||
        "issuer_mode",
 | 
					        "issuer_mode",
 | 
				
			||||||
 | 
				
			|||||||
@ -7,7 +7,7 @@ from django.http import HttpRequest, HttpResponse, HttpResponseRedirect
 | 
				
			|||||||
from authentik.events.models import Event, EventAction
 | 
					from authentik.events.models import Event, EventAction
 | 
				
			||||||
from authentik.lib.sentry import SentryIgnoredException
 | 
					from authentik.lib.sentry import SentryIgnoredException
 | 
				
			||||||
from authentik.lib.views import bad_request_message
 | 
					from authentik.lib.views import bad_request_message
 | 
				
			||||||
from authentik.providers.oauth2.models import GrantTypes
 | 
					from authentik.providers.oauth2.models import GrantTypes, RedirectURI
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class OAuth2Error(SentryIgnoredException):
 | 
					class OAuth2Error(SentryIgnoredException):
 | 
				
			||||||
@ -46,9 +46,9 @@ class RedirectUriError(OAuth2Error):
 | 
				
			|||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    provided_uri: str
 | 
					    provided_uri: str
 | 
				
			||||||
    allowed_uris: list[str]
 | 
					    allowed_uris: list[RedirectURI]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, provided_uri: str, allowed_uris: list[str]) -> None:
 | 
					    def __init__(self, provided_uri: str, allowed_uris: list[RedirectURI]) -> None:
 | 
				
			||||||
        super().__init__()
 | 
					        super().__init__()
 | 
				
			||||||
        self.provided_uri = provided_uri
 | 
					        self.provided_uri = provided_uri
 | 
				
			||||||
        self.allowed_uris = allowed_uris
 | 
					        self.allowed_uris = allowed_uris
 | 
				
			||||||
 | 
				
			|||||||
@ -1,6 +1,7 @@
 | 
				
			|||||||
"""id_token utils"""
 | 
					"""id_token utils"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from dataclasses import asdict, dataclass, field
 | 
					from dataclasses import asdict, dataclass, field
 | 
				
			||||||
 | 
					from hashlib import sha256
 | 
				
			||||||
from typing import TYPE_CHECKING, Any
 | 
					from typing import TYPE_CHECKING, Any
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from django.db import models
 | 
					from django.db import models
 | 
				
			||||||
@ -23,8 +24,13 @@ if TYPE_CHECKING:
 | 
				
			|||||||
    from authentik.providers.oauth2.models import BaseGrantModel, OAuth2Provider
 | 
					    from authentik.providers.oauth2.models import BaseGrantModel, OAuth2Provider
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def hash_session_key(session_key: str) -> str:
 | 
				
			||||||
 | 
					    """Hash the session key for inclusion in JWTs as `sid`"""
 | 
				
			||||||
 | 
					    return sha256(session_key.encode("ascii")).hexdigest()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class SubModes(models.TextChoices):
 | 
					class SubModes(models.TextChoices):
 | 
				
			||||||
    """Mode after which 'sub' attribute is generateed, for compatibility reasons"""
 | 
					    """Mode after which 'sub' attribute is generated, for compatibility reasons"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    HASHED_USER_ID = "hashed_user_id", _("Based on the Hashed User ID")
 | 
					    HASHED_USER_ID = "hashed_user_id", _("Based on the Hashed User ID")
 | 
				
			||||||
    USER_ID = "user_id", _("Based on user ID")
 | 
					    USER_ID = "user_id", _("Based on user ID")
 | 
				
			||||||
@ -51,7 +57,8 @@ class IDToken:
 | 
				
			|||||||
    and potentially other requested Claims. The ID Token is represented as a
 | 
					    and potentially other requested Claims. The ID Token is represented as a
 | 
				
			||||||
    JSON Web Token (JWT) [JWT].
 | 
					    JSON Web Token (JWT) [JWT].
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    https://openid.net/specs/openid-connect-core-1_0.html#IDToken"""
 | 
					    https://openid.net/specs/openid-connect-core-1_0.html#IDToken
 | 
				
			||||||
 | 
					    https://www.iana.org/assignments/jwt/jwt.xhtml"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Issuer, https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.1
 | 
					    # Issuer, https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.1
 | 
				
			||||||
    iss: str | None = None
 | 
					    iss: str | None = None
 | 
				
			||||||
@ -79,6 +86,8 @@ class IDToken:
 | 
				
			|||||||
    nonce: str | None = None
 | 
					    nonce: str | None = None
 | 
				
			||||||
    # Access Token hash value, http://openid.net/specs/openid-connect-core-1_0.html
 | 
					    # Access Token hash value, http://openid.net/specs/openid-connect-core-1_0.html
 | 
				
			||||||
    at_hash: str | None = None
 | 
					    at_hash: str | None = None
 | 
				
			||||||
 | 
					    # Session ID, https://openid.net/specs/openid-connect-frontchannel-1_0.html#ClaimsContents
 | 
				
			||||||
 | 
					    sid: str | None = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    claims: dict[str, Any] = field(default_factory=dict)
 | 
					    claims: dict[str, Any] = field(default_factory=dict)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -116,9 +125,11 @@ class IDToken:
 | 
				
			|||||||
        now = timezone.now()
 | 
					        now = timezone.now()
 | 
				
			||||||
        id_token.iat = int(now.timestamp())
 | 
					        id_token.iat = int(now.timestamp())
 | 
				
			||||||
        id_token.auth_time = int(token.auth_time.timestamp())
 | 
					        id_token.auth_time = int(token.auth_time.timestamp())
 | 
				
			||||||
 | 
					        if token.session:
 | 
				
			||||||
 | 
					            id_token.sid = hash_session_key(token.session.session_key)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # We use the timestamp of the user's last successful login (EventAction.LOGIN) for auth_time
 | 
					        # We use the timestamp of the user's last successful login (EventAction.LOGIN) for auth_time
 | 
				
			||||||
        auth_event = get_login_event(request)
 | 
					        auth_event = get_login_event(token.session)
 | 
				
			||||||
        if auth_event:
 | 
					        if auth_event:
 | 
				
			||||||
            # Also check which method was used for authentication
 | 
					            # Also check which method was used for authentication
 | 
				
			||||||
            method = auth_event.context.get(PLAN_CONTEXT_METHOD, "")
 | 
					            method = auth_event.context.get(PLAN_CONTEXT_METHOD, "")
 | 
				
			||||||
 | 
				
			|||||||
@ -3,6 +3,7 @@
 | 
				
			|||||||
import django.db.models.deletion
 | 
					import django.db.models.deletion
 | 
				
			||||||
from django.apps.registry import Apps
 | 
					from django.apps.registry import Apps
 | 
				
			||||||
from django.db import migrations, models
 | 
					from django.db import migrations, models
 | 
				
			||||||
 | 
					from django.db.backends.base.schema import BaseDatabaseSchemaEditor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import authentik.lib.utils.time
 | 
					import authentik.lib.utils.time
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -14,7 +15,7 @@ scope_uid_map = {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def set_managed_flag(apps: Apps, schema_editor):
 | 
					def set_managed_flag(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
 | 
				
			||||||
    ScopeMapping = apps.get_model("authentik_providers_oauth2", "ScopeMapping")
 | 
					    ScopeMapping = apps.get_model("authentik_providers_oauth2", "ScopeMapping")
 | 
				
			||||||
    db_alias = schema_editor.connection.alias
 | 
					    db_alias = schema_editor.connection.alias
 | 
				
			||||||
    for mapping in ScopeMapping.objects.using(db_alias).filter(name__startswith="Autogenerated "):
 | 
					    for mapping in ScopeMapping.objects.using(db_alias).filter(name__startswith="Autogenerated "):
 | 
				
			||||||
 | 
				
			|||||||
@ -0,0 +1,26 @@
 | 
				
			|||||||
 | 
					# Generated by Django 5.0.9 on 2024-09-26 16:25
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from django.conf import settings
 | 
				
			||||||
 | 
					from django.db import migrations, models
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Migration(migrations.Migration):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    dependencies = [
 | 
				
			||||||
 | 
					        ("authentik_providers_oauth2", "0018_alter_accesstoken_expires_and_more"),
 | 
				
			||||||
 | 
					        migrations.swappable_dependency(settings.AUTH_USER_MODEL),
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Original preserved
 | 
				
			||||||
 | 
					    # See https://github.com/goauthentik/authentik/issues/11874
 | 
				
			||||||
 | 
					    # operations = [
 | 
				
			||||||
 | 
					    #     migrations.AddIndex(
 | 
				
			||||||
 | 
					    #         model_name="accesstoken",
 | 
				
			||||||
 | 
					    #         index=models.Index(fields=["token"], name="authentik_p_token_4bc870_idx"),
 | 
				
			||||||
 | 
					    #     ),
 | 
				
			||||||
 | 
					    #     migrations.AddIndex(
 | 
				
			||||||
 | 
					    #         model_name="refreshtoken",
 | 
				
			||||||
 | 
					    #         index=models.Index(fields=["token"], name="authentik_p_token_1a841f_idx"),
 | 
				
			||||||
 | 
					    #     ),
 | 
				
			||||||
 | 
					    # ]
 | 
				
			||||||
 | 
					    operations = []
 | 
				
			||||||
@ -0,0 +1,34 @@
 | 
				
			|||||||
 | 
					# Generated by Django 5.0.9 on 2024-09-27 14:50
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from django.conf import settings
 | 
				
			||||||
 | 
					from django.db import migrations, models
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Migration(migrations.Migration):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    dependencies = [
 | 
				
			||||||
 | 
					        ("authentik_providers_oauth2", "0019_accesstoken_authentik_p_token_4bc870_idx_and_more"),
 | 
				
			||||||
 | 
					        migrations.swappable_dependency(settings.AUTH_USER_MODEL),
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Original preserved
 | 
				
			||||||
 | 
					    # See https://github.com/goauthentik/authentik/issues/11874
 | 
				
			||||||
 | 
					    # operations = [
 | 
				
			||||||
 | 
					    #     migrations.RemoveIndex(
 | 
				
			||||||
 | 
					    #         model_name="accesstoken",
 | 
				
			||||||
 | 
					    #         name="authentik_p_token_4bc870_idx",
 | 
				
			||||||
 | 
					    #     ),
 | 
				
			||||||
 | 
					    #     migrations.RemoveIndex(
 | 
				
			||||||
 | 
					    #         model_name="refreshtoken",
 | 
				
			||||||
 | 
					    #         name="authentik_p_token_1a841f_idx",
 | 
				
			||||||
 | 
					    #     ),
 | 
				
			||||||
 | 
					    #     migrations.AddIndex(
 | 
				
			||||||
 | 
					    #         model_name="accesstoken",
 | 
				
			||||||
 | 
					    #         index=models.Index(fields=["token", "provider"], name="authentik_p_token_f99422_idx"),
 | 
				
			||||||
 | 
					    #     ),
 | 
				
			||||||
 | 
					    #     migrations.AddIndex(
 | 
				
			||||||
 | 
					    #         model_name="refreshtoken",
 | 
				
			||||||
 | 
					    #         index=models.Index(fields=["token", "provider"], name="authentik_p_token_a1d921_idx"),
 | 
				
			||||||
 | 
					    #     ),
 | 
				
			||||||
 | 
					    # ]
 | 
				
			||||||
 | 
					    operations = []
 | 
				
			||||||
@ -0,0 +1,42 @@
 | 
				
			|||||||
 | 
					# Generated by Django 5.0.9 on 2024-10-16 14:53
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import django.db.models.deletion
 | 
				
			||||||
 | 
					from django.db import migrations, models
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Migration(migrations.Migration):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    dependencies = [
 | 
				
			||||||
 | 
					        ("authentik_crypto", "0004_alter_certificatekeypair_name"),
 | 
				
			||||||
 | 
					        (
 | 
				
			||||||
 | 
					            "authentik_providers_oauth2",
 | 
				
			||||||
 | 
					            "0020_remove_accesstoken_authentik_p_token_4bc870_idx_and_more",
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    operations = [
 | 
				
			||||||
 | 
					        migrations.AddField(
 | 
				
			||||||
 | 
					            model_name="oauth2provider",
 | 
				
			||||||
 | 
					            name="encryption_key",
 | 
				
			||||||
 | 
					            field=models.ForeignKey(
 | 
				
			||||||
 | 
					                help_text="Key used to encrypt the tokens. When set, tokens will be encrypted and returned as JWEs.",
 | 
				
			||||||
 | 
					                null=True,
 | 
				
			||||||
 | 
					                on_delete=django.db.models.deletion.SET_NULL,
 | 
				
			||||||
 | 
					                related_name="oauth2provider_encryption_key_set",
 | 
				
			||||||
 | 
					                to="authentik_crypto.certificatekeypair",
 | 
				
			||||||
 | 
					                verbose_name="Encryption Key",
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        migrations.AlterField(
 | 
				
			||||||
 | 
					            model_name="oauth2provider",
 | 
				
			||||||
 | 
					            name="signing_key",
 | 
				
			||||||
 | 
					            field=models.ForeignKey(
 | 
				
			||||||
 | 
					                help_text="Key used to sign the tokens.",
 | 
				
			||||||
 | 
					                null=True,
 | 
				
			||||||
 | 
					                on_delete=django.db.models.deletion.SET_NULL,
 | 
				
			||||||
 | 
					                related_name="oauth2provider_signing_key_set",
 | 
				
			||||||
 | 
					                to="authentik_crypto.certificatekeypair",
 | 
				
			||||||
 | 
					                verbose_name="Signing Key",
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
@ -0,0 +1,113 @@
 | 
				
			|||||||
 | 
					# Generated by Django 5.0.9 on 2024-10-23 13:38
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from hashlib import sha256
 | 
				
			||||||
 | 
					import django.db.models.deletion
 | 
				
			||||||
 | 
					from django.db import migrations, models
 | 
				
			||||||
 | 
					from django.apps.registry import Apps
 | 
				
			||||||
 | 
					from django.db.backends.base.schema import BaseDatabaseSchemaEditor
 | 
				
			||||||
 | 
					from authentik.lib.migrations import progress_bar
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def migrate_session(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
 | 
				
			||||||
 | 
					    AuthenticatedSession = apps.get_model("authentik_core", "authenticatedsession")
 | 
				
			||||||
 | 
					    AuthorizationCode = apps.get_model("authentik_providers_oauth2", "authorizationcode")
 | 
				
			||||||
 | 
					    AccessToken = apps.get_model("authentik_providers_oauth2", "accesstoken")
 | 
				
			||||||
 | 
					    RefreshToken = apps.get_model("authentik_providers_oauth2", "refreshtoken")
 | 
				
			||||||
 | 
					    db_alias = schema_editor.connection.alias
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    print(f"\nFetching session keys, this might take a couple of minutes...")
 | 
				
			||||||
 | 
					    session_ids = {}
 | 
				
			||||||
 | 
					    for session in progress_bar(AuthenticatedSession.objects.using(db_alias).all()):
 | 
				
			||||||
 | 
					        session_ids[sha256(session.session_key.encode("ascii")).hexdigest()] = session.session_key
 | 
				
			||||||
 | 
					    for model in [AuthorizationCode, AccessToken, RefreshToken]:
 | 
				
			||||||
 | 
					        print(
 | 
				
			||||||
 | 
					            f"\nAdding session to {model._meta.verbose_name}, this might take a couple of minutes..."
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        for code in progress_bar(model.objects.using(db_alias).all()):
 | 
				
			||||||
 | 
					            if code.session_id_old not in session_ids:
 | 
				
			||||||
 | 
					                continue
 | 
				
			||||||
 | 
					            code.session = (
 | 
				
			||||||
 | 
					                AuthenticatedSession.objects.using(db_alias)
 | 
				
			||||||
 | 
					                .filter(session_key=session_ids[code.session_id_old])
 | 
				
			||||||
 | 
					                .first()
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            code.save()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Migration(migrations.Migration):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    dependencies = [
 | 
				
			||||||
 | 
					        ("authentik_core", "0039_source_group_matching_mode_alter_group_name_and_more"),
 | 
				
			||||||
 | 
					        ("authentik_providers_oauth2", "0021_oauth2provider_encryption_key_and_more"),
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    operations = [
 | 
				
			||||||
 | 
					        migrations.RenameField(
 | 
				
			||||||
 | 
					            model_name="accesstoken",
 | 
				
			||||||
 | 
					            old_name="session_id",
 | 
				
			||||||
 | 
					            new_name="session_id_old",
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        migrations.RenameField(
 | 
				
			||||||
 | 
					            model_name="authorizationcode",
 | 
				
			||||||
 | 
					            old_name="session_id",
 | 
				
			||||||
 | 
					            new_name="session_id_old",
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        migrations.RenameField(
 | 
				
			||||||
 | 
					            model_name="refreshtoken",
 | 
				
			||||||
 | 
					            old_name="session_id",
 | 
				
			||||||
 | 
					            new_name="session_id_old",
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        migrations.AddField(
 | 
				
			||||||
 | 
					            model_name="accesstoken",
 | 
				
			||||||
 | 
					            name="session",
 | 
				
			||||||
 | 
					            field=models.ForeignKey(
 | 
				
			||||||
 | 
					                default=None,
 | 
				
			||||||
 | 
					                null=True,
 | 
				
			||||||
 | 
					                on_delete=django.db.models.deletion.SET_DEFAULT,
 | 
				
			||||||
 | 
					                to="authentik_core.authenticatedsession",
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        migrations.AddField(
 | 
				
			||||||
 | 
					            model_name="authorizationcode",
 | 
				
			||||||
 | 
					            name="session",
 | 
				
			||||||
 | 
					            field=models.ForeignKey(
 | 
				
			||||||
 | 
					                default=None,
 | 
				
			||||||
 | 
					                null=True,
 | 
				
			||||||
 | 
					                on_delete=django.db.models.deletion.SET_DEFAULT,
 | 
				
			||||||
 | 
					                to="authentik_core.authenticatedsession",
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        migrations.AddField(
 | 
				
			||||||
 | 
					            model_name="devicetoken",
 | 
				
			||||||
 | 
					            name="session",
 | 
				
			||||||
 | 
					            field=models.ForeignKey(
 | 
				
			||||||
 | 
					                default=None,
 | 
				
			||||||
 | 
					                null=True,
 | 
				
			||||||
 | 
					                on_delete=django.db.models.deletion.SET_DEFAULT,
 | 
				
			||||||
 | 
					                to="authentik_core.authenticatedsession",
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        migrations.AddField(
 | 
				
			||||||
 | 
					            model_name="refreshtoken",
 | 
				
			||||||
 | 
					            name="session",
 | 
				
			||||||
 | 
					            field=models.ForeignKey(
 | 
				
			||||||
 | 
					                default=None,
 | 
				
			||||||
 | 
					                null=True,
 | 
				
			||||||
 | 
					                on_delete=django.db.models.deletion.SET_DEFAULT,
 | 
				
			||||||
 | 
					                to="authentik_core.authenticatedsession",
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        migrations.RunPython(migrate_session),
 | 
				
			||||||
 | 
					        migrations.RemoveField(
 | 
				
			||||||
 | 
					            model_name="accesstoken",
 | 
				
			||||||
 | 
					            name="session_id_old",
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        migrations.RemoveField(
 | 
				
			||||||
 | 
					            model_name="authorizationcode",
 | 
				
			||||||
 | 
					            name="session_id_old",
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        migrations.RemoveField(
 | 
				
			||||||
 | 
					            model_name="refreshtoken",
 | 
				
			||||||
 | 
					            name="session_id_old",
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
@ -0,0 +1,31 @@
 | 
				
			|||||||
 | 
					# Generated by Django 5.0.9 on 2024-10-31 14:28
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import django.contrib.postgres.indexes
 | 
				
			||||||
 | 
					from django.conf import settings
 | 
				
			||||||
 | 
					from django.db import migrations
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Migration(migrations.Migration):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    dependencies = [
 | 
				
			||||||
 | 
					        ("authentik_core", "0039_source_group_matching_mode_alter_group_name_and_more"),
 | 
				
			||||||
 | 
					        ("authentik_providers_oauth2", "0022_remove_accesstoken_session_id_and_more"),
 | 
				
			||||||
 | 
					        migrations.swappable_dependency(settings.AUTH_USER_MODEL),
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    operations = [
 | 
				
			||||||
 | 
					        migrations.RunSQL("DROP INDEX IF EXISTS authentik_p_token_f99422_idx;"),
 | 
				
			||||||
 | 
					        migrations.RunSQL("DROP INDEX IF EXISTS authentik_p_token_a1d921_idx;"),
 | 
				
			||||||
 | 
					        migrations.AddIndex(
 | 
				
			||||||
 | 
					            model_name="accesstoken",
 | 
				
			||||||
 | 
					            index=django.contrib.postgres.indexes.HashIndex(
 | 
				
			||||||
 | 
					                fields=["token"], name="authentik_p_token_e00883_hash"
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        migrations.AddIndex(
 | 
				
			||||||
 | 
					            model_name="refreshtoken",
 | 
				
			||||||
 | 
					            index=django.contrib.postgres.indexes.HashIndex(
 | 
				
			||||||
 | 
					                fields=["token"], name="authentik_p_token_32e2b7_hash"
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
@ -0,0 +1,49 @@
 | 
				
			|||||||
 | 
					# Generated by Django 5.0.9 on 2024-11-04 12:56
 | 
				
			||||||
 | 
					from dataclasses import asdict
 | 
				
			||||||
 | 
					from django.apps.registry import Apps
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from django.db.backends.base.schema import BaseDatabaseSchemaEditor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from django.db import migrations, models
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def migrate_redirect_uris(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
 | 
				
			||||||
 | 
					    from authentik.providers.oauth2.models import RedirectURI, RedirectURIMatchingMode
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    OAuth2Provider = apps.get_model("authentik_providers_oauth2", "oauth2provider")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    db_alias = schema_editor.connection.alias
 | 
				
			||||||
 | 
					    for provider in OAuth2Provider.objects.using(db_alias).all():
 | 
				
			||||||
 | 
					        uris = []
 | 
				
			||||||
 | 
					        for old in provider.old_redirect_uris.split("\n"):
 | 
				
			||||||
 | 
					            mode = RedirectURIMatchingMode.STRICT
 | 
				
			||||||
 | 
					            if old == "*" or old == ".*":
 | 
				
			||||||
 | 
					                mode = RedirectURIMatchingMode.REGEX
 | 
				
			||||||
 | 
					            uris.append(asdict(RedirectURI(mode, url=old)))
 | 
				
			||||||
 | 
					        provider._redirect_uris = uris
 | 
				
			||||||
 | 
					        provider.save()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Migration(migrations.Migration):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    dependencies = [
 | 
				
			||||||
 | 
					        ("authentik_providers_oauth2", "0023_alter_accesstoken_refreshtoken_use_hash_index"),
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    operations = [
 | 
				
			||||||
 | 
					        migrations.RenameField(
 | 
				
			||||||
 | 
					            model_name="oauth2provider",
 | 
				
			||||||
 | 
					            old_name="redirect_uris",
 | 
				
			||||||
 | 
					            new_name="old_redirect_uris",
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        migrations.AddField(
 | 
				
			||||||
 | 
					            model_name="oauth2provider",
 | 
				
			||||||
 | 
					            name="_redirect_uris",
 | 
				
			||||||
 | 
					            field=models.JSONField(default=dict, verbose_name="Redirect URIs"),
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        migrations.RunPython(migrate_redirect_uris, lambda *args: ...),
 | 
				
			||||||
 | 
					        migrations.RemoveField(
 | 
				
			||||||
 | 
					            model_name="oauth2provider",
 | 
				
			||||||
 | 
					            name="old_redirect_uris",
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
@ -3,7 +3,7 @@
 | 
				
			|||||||
import base64
 | 
					import base64
 | 
				
			||||||
import binascii
 | 
					import binascii
 | 
				
			||||||
import json
 | 
					import json
 | 
				
			||||||
from dataclasses import asdict
 | 
					from dataclasses import asdict, dataclass
 | 
				
			||||||
from functools import cached_property
 | 
					from functools import cached_property
 | 
				
			||||||
from hashlib import sha256
 | 
					from hashlib import sha256
 | 
				
			||||||
from typing import Any
 | 
					from typing import Any
 | 
				
			||||||
@ -12,6 +12,7 @@ from urllib.parse import urlparse, urlunparse
 | 
				
			|||||||
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey
 | 
					from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey
 | 
				
			||||||
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
 | 
					from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
 | 
				
			||||||
from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes
 | 
					from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes
 | 
				
			||||||
 | 
					from dacite import Config
 | 
				
			||||||
from dacite.core import from_dict
 | 
					from dacite.core import from_dict
 | 
				
			||||||
from django.db import models
 | 
					from django.db import models
 | 
				
			||||||
from django.http import HttpRequest
 | 
					from django.http import HttpRequest
 | 
				
			||||||
@ -23,7 +24,13 @@ from rest_framework.serializers import Serializer
 | 
				
			|||||||
from structlog.stdlib import get_logger
 | 
					from structlog.stdlib import get_logger
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.brands.models import WebfingerProvider
 | 
					from authentik.brands.models import WebfingerProvider
 | 
				
			||||||
from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User
 | 
					from authentik.core.models import (
 | 
				
			||||||
 | 
					    AuthenticatedSession,
 | 
				
			||||||
 | 
					    ExpiringModel,
 | 
				
			||||||
 | 
					    PropertyMapping,
 | 
				
			||||||
 | 
					    Provider,
 | 
				
			||||||
 | 
					    User,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
from authentik.crypto.models import CertificateKeyPair
 | 
					from authentik.crypto.models import CertificateKeyPair
 | 
				
			||||||
from authentik.lib.generators import generate_code_fixed_length, generate_id, generate_key
 | 
					from authentik.lib.generators import generate_code_fixed_length, generate_id, generate_key
 | 
				
			||||||
from authentik.lib.models import SerializerModel
 | 
					from authentik.lib.models import SerializerModel
 | 
				
			||||||
@ -67,11 +74,25 @@ class IssuerMode(models.TextChoices):
 | 
				
			|||||||
    """Configure how the `iss` field is created."""
 | 
					    """Configure how the `iss` field is created."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    GLOBAL = "global", _("Same identifier is used for all providers")
 | 
					    GLOBAL = "global", _("Same identifier is used for all providers")
 | 
				
			||||||
    PER_PROVIDER = "per_provider", _(
 | 
					    PER_PROVIDER = (
 | 
				
			||||||
        "Each provider has a different issuer, based on the application slug."
 | 
					        "per_provider",
 | 
				
			||||||
 | 
					        _("Each provider has a different issuer, based on the application slug."),
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class RedirectURIMatchingMode(models.TextChoices):
 | 
				
			||||||
 | 
					    STRICT = "strict", _("Strict URL comparison")
 | 
				
			||||||
 | 
					    REGEX = "regex", _("Regular Expression URL matching")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@dataclass
 | 
				
			||||||
 | 
					class RedirectURI:
 | 
				
			||||||
 | 
					    """A single redirect URI entry"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    matching_mode: RedirectURIMatchingMode
 | 
				
			||||||
 | 
					    url: str
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ResponseTypes(models.TextChoices):
 | 
					class ResponseTypes(models.TextChoices):
 | 
				
			||||||
    """Response Type required by the client."""
 | 
					    """Response Type required by the client."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -146,11 +167,9 @@ class OAuth2Provider(WebfingerProvider, Provider):
 | 
				
			|||||||
        verbose_name=_("Client Secret"),
 | 
					        verbose_name=_("Client Secret"),
 | 
				
			||||||
        default=generate_client_secret,
 | 
					        default=generate_client_secret,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    redirect_uris = models.TextField(
 | 
					    _redirect_uris = models.JSONField(
 | 
				
			||||||
        default="",
 | 
					        default=dict,
 | 
				
			||||||
        blank=True,
 | 
					 | 
				
			||||||
        verbose_name=_("Redirect URIs"),
 | 
					        verbose_name=_("Redirect URIs"),
 | 
				
			||||||
        help_text=_("Enter each URI on a new line."),
 | 
					 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    include_claims_in_id_token = models.BooleanField(
 | 
					    include_claims_in_id_token = models.BooleanField(
 | 
				
			||||||
@ -251,12 +270,33 @@ class OAuth2Provider(WebfingerProvider, Provider):
 | 
				
			|||||||
        except Provider.application.RelatedObjectDoesNotExist:
 | 
					        except Provider.application.RelatedObjectDoesNotExist:
 | 
				
			||||||
            return None
 | 
					            return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def redirect_uris(self) -> list[RedirectURI]:
 | 
				
			||||||
 | 
					        uris = []
 | 
				
			||||||
 | 
					        for entry in self._redirect_uris:
 | 
				
			||||||
 | 
					            uris.append(
 | 
				
			||||||
 | 
					                from_dict(
 | 
				
			||||||
 | 
					                    RedirectURI,
 | 
				
			||||||
 | 
					                    entry,
 | 
				
			||||||
 | 
					                    config=Config(type_hooks={RedirectURIMatchingMode: RedirectURIMatchingMode}),
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        return uris
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @redirect_uris.setter
 | 
				
			||||||
 | 
					    def redirect_uris(self, value: list[RedirectURI]):
 | 
				
			||||||
 | 
					        cleansed = []
 | 
				
			||||||
 | 
					        for entry in value:
 | 
				
			||||||
 | 
					            cleansed.append(asdict(entry))
 | 
				
			||||||
 | 
					        self._redirect_uris = cleansed
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def launch_url(self) -> str | None:
 | 
					    def launch_url(self) -> str | None:
 | 
				
			||||||
        """Guess launch_url based on first redirect_uri"""
 | 
					        """Guess launch_url based on first redirect_uri"""
 | 
				
			||||||
        if self.redirect_uris == "":
 | 
					        redirects = self.redirect_uris
 | 
				
			||||||
 | 
					        if len(redirects) < 1:
 | 
				
			||||||
            return None
 | 
					            return None
 | 
				
			||||||
        main_url = self.redirect_uris.split("\n", maxsplit=1)[0]
 | 
					        main_url = redirects[0].url
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            launch_url = urlparse(main_url)._replace(path="")
 | 
					            launch_url = urlparse(main_url)._replace(path="")
 | 
				
			||||||
            return urlunparse(launch_url)
 | 
					            return urlunparse(launch_url)
 | 
				
			||||||
@ -320,7 +360,9 @@ class BaseGrantModel(models.Model):
 | 
				
			|||||||
    revoked = models.BooleanField(default=False)
 | 
					    revoked = models.BooleanField(default=False)
 | 
				
			||||||
    _scope = models.TextField(default="", verbose_name=_("Scopes"))
 | 
					    _scope = models.TextField(default="", verbose_name=_("Scopes"))
 | 
				
			||||||
    auth_time = models.DateTimeField(verbose_name="Authentication time")
 | 
					    auth_time = models.DateTimeField(verbose_name="Authentication time")
 | 
				
			||||||
    session_id = models.CharField(default="", blank=True)
 | 
					    session = models.ForeignKey(
 | 
				
			||||||
 | 
					        AuthenticatedSession, null=True, on_delete=models.SET_DEFAULT, default=None
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    class Meta:
 | 
					    class Meta:
 | 
				
			||||||
        abstract = True
 | 
					        abstract = True
 | 
				
			||||||
@ -452,6 +494,9 @@ class DeviceToken(ExpiringModel):
 | 
				
			|||||||
    device_code = models.TextField(default=generate_key)
 | 
					    device_code = models.TextField(default=generate_key)
 | 
				
			||||||
    user_code = models.TextField(default=generate_code_fixed_length)
 | 
					    user_code = models.TextField(default=generate_code_fixed_length)
 | 
				
			||||||
    _scope = models.TextField(default="", verbose_name=_("Scopes"))
 | 
					    _scope = models.TextField(default="", verbose_name=_("Scopes"))
 | 
				
			||||||
 | 
					    session = models.ForeignKey(
 | 
				
			||||||
 | 
					        AuthenticatedSession, null=True, on_delete=models.SET_DEFAULT, default=None
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def scope(self) -> list[str]:
 | 
					    def scope(self) -> list[str]:
 | 
				
			||||||
 | 
				
			|||||||
@ -1,5 +1,3 @@
 | 
				
			|||||||
from hashlib import sha256
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from django.contrib.auth.signals import user_logged_out
 | 
					from django.contrib.auth.signals import user_logged_out
 | 
				
			||||||
from django.dispatch import receiver
 | 
					from django.dispatch import receiver
 | 
				
			||||||
from django.http import HttpRequest
 | 
					from django.http import HttpRequest
 | 
				
			||||||
@ -13,5 +11,4 @@ def user_logged_out_oauth_access_token(sender, request: HttpRequest, user: User,
 | 
				
			|||||||
    """Revoke access tokens upon user logout"""
 | 
					    """Revoke access tokens upon user logout"""
 | 
				
			||||||
    if not request.session or not request.session.session_key:
 | 
					    if not request.session or not request.session.session_key:
 | 
				
			||||||
        return
 | 
					        return
 | 
				
			||||||
    hashed_session_key = sha256(request.session.session_key.encode("ascii")).hexdigest()
 | 
					    AccessToken.objects.filter(user=user, session__session_key=request.session.session_key).delete()
 | 
				
			||||||
    AccessToken.objects.filter(user=user, session_id=hashed_session_key).delete()
 | 
					 | 
				
			||||||
 | 
				
			|||||||
@ -10,7 +10,13 @@ from rest_framework.test import APITestCase
 | 
				
			|||||||
from authentik.blueprints.tests import apply_blueprint
 | 
					from authentik.blueprints.tests import apply_blueprint
 | 
				
			||||||
from authentik.core.models import Application
 | 
					from authentik.core.models import Application
 | 
				
			||||||
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
 | 
					from authentik.core.tests.utils import create_test_admin_user, create_test_flow
 | 
				
			||||||
from authentik.providers.oauth2.models import OAuth2Provider, ScopeMapping
 | 
					from authentik.lib.generators import generate_id
 | 
				
			||||||
 | 
					from authentik.providers.oauth2.models import (
 | 
				
			||||||
 | 
					    OAuth2Provider,
 | 
				
			||||||
 | 
					    RedirectURI,
 | 
				
			||||||
 | 
					    RedirectURIMatchingMode,
 | 
				
			||||||
 | 
					    ScopeMapping,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TestAPI(APITestCase):
 | 
					class TestAPI(APITestCase):
 | 
				
			||||||
@ -21,7 +27,7 @@ class TestAPI(APITestCase):
 | 
				
			|||||||
        self.provider: OAuth2Provider = OAuth2Provider.objects.create(
 | 
					        self.provider: OAuth2Provider = OAuth2Provider.objects.create(
 | 
				
			||||||
            name="test",
 | 
					            name="test",
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris="http://testserver",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        self.provider.property_mappings.set(ScopeMapping.objects.all())
 | 
					        self.provider.property_mappings.set(ScopeMapping.objects.all())
 | 
				
			||||||
        self.app = Application.objects.create(name="test", slug="test", provider=self.provider)
 | 
					        self.app = Application.objects.create(name="test", slug="test", provider=self.provider)
 | 
				
			||||||
@ -50,9 +56,29 @@ class TestAPI(APITestCase):
 | 
				
			|||||||
    @skipUnless(version_info >= (3, 11, 4), "This behaviour is only Python 3.11.4 and up")
 | 
					    @skipUnless(version_info >= (3, 11, 4), "This behaviour is only Python 3.11.4 and up")
 | 
				
			||||||
    def test_launch_url(self):
 | 
					    def test_launch_url(self):
 | 
				
			||||||
        """Test launch_url"""
 | 
					        """Test launch_url"""
 | 
				
			||||||
        self.provider.redirect_uris = (
 | 
					        self.provider.redirect_uris = [
 | 
				
			||||||
            "https://[\\d\\w]+.pr.test.goauthentik.io/source/oauth/callback/authentik/\n"
 | 
					            RedirectURI(
 | 
				
			||||||
        )
 | 
					                RedirectURIMatchingMode.REGEX,
 | 
				
			||||||
 | 
					                "https://[\\d\\w]+.pr.test.goauthentik.io/source/oauth/callback/authentik/",
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					        ]
 | 
				
			||||||
        self.provider.save()
 | 
					        self.provider.save()
 | 
				
			||||||
        self.provider.refresh_from_db()
 | 
					        self.provider.refresh_from_db()
 | 
				
			||||||
        self.assertIsNone(self.provider.launch_url)
 | 
					        self.assertIsNone(self.provider.launch_url)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_validate_redirect_uris(self):
 | 
				
			||||||
 | 
					        """Test redirect_uris API"""
 | 
				
			||||||
 | 
					        response = self.client.post(
 | 
				
			||||||
 | 
					            reverse("authentik_api:oauth2provider-list"),
 | 
				
			||||||
 | 
					            data={
 | 
				
			||||||
 | 
					                "name": generate_id(),
 | 
				
			||||||
 | 
					                "authorization_flow": create_test_flow().pk,
 | 
				
			||||||
 | 
					                "invalidation_flow": create_test_flow().pk,
 | 
				
			||||||
 | 
					                "redirect_uris": [
 | 
				
			||||||
 | 
					                    {"matching_mode": "strict", "url": "http://goauthentik.io"},
 | 
				
			||||||
 | 
					                    {"matching_mode": "regex", "url": "**"},
 | 
				
			||||||
 | 
					                ],
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertJSONEqual(response.content, {"redirect_uris": ["Invalid Regex Pattern: **"]})
 | 
				
			||||||
 | 
					        self.assertEqual(response.status_code, 400)
 | 
				
			||||||
 | 
				
			|||||||
@ -19,6 +19,8 @@ from authentik.providers.oauth2.models import (
 | 
				
			|||||||
    AuthorizationCode,
 | 
					    AuthorizationCode,
 | 
				
			||||||
    GrantTypes,
 | 
					    GrantTypes,
 | 
				
			||||||
    OAuth2Provider,
 | 
					    OAuth2Provider,
 | 
				
			||||||
 | 
					    RedirectURI,
 | 
				
			||||||
 | 
					    RedirectURIMatchingMode,
 | 
				
			||||||
    ScopeMapping,
 | 
					    ScopeMapping,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from authentik.providers.oauth2.tests.utils import OAuthTestCase
 | 
					from authentik.providers.oauth2.tests.utils import OAuthTestCase
 | 
				
			||||||
@ -39,7 +41,7 @@ class TestAuthorize(OAuthTestCase):
 | 
				
			|||||||
            name=generate_id(),
 | 
					            name=generate_id(),
 | 
				
			||||||
            client_id="test",
 | 
					            client_id="test",
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris="http://local.invalid/Foo",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        with self.assertRaises(AuthorizeError):
 | 
					        with self.assertRaises(AuthorizeError):
 | 
				
			||||||
            request = self.factory.get(
 | 
					            request = self.factory.get(
 | 
				
			||||||
@ -64,7 +66,7 @@ class TestAuthorize(OAuthTestCase):
 | 
				
			|||||||
            name=generate_id(),
 | 
					            name=generate_id(),
 | 
				
			||||||
            client_id="test",
 | 
					            client_id="test",
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris="http://local.invalid/Foo",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        with self.assertRaises(AuthorizeError):
 | 
					        with self.assertRaises(AuthorizeError):
 | 
				
			||||||
            request = self.factory.get(
 | 
					            request = self.factory.get(
 | 
				
			||||||
@ -84,7 +86,7 @@ class TestAuthorize(OAuthTestCase):
 | 
				
			|||||||
            name=generate_id(),
 | 
					            name=generate_id(),
 | 
				
			||||||
            client_id="test",
 | 
					            client_id="test",
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris="http://local.invalid",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        with self.assertRaises(RedirectUriError):
 | 
					        with self.assertRaises(RedirectUriError):
 | 
				
			||||||
            request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
 | 
					            request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
 | 
				
			||||||
@ -106,7 +108,7 @@ class TestAuthorize(OAuthTestCase):
 | 
				
			|||||||
            name=generate_id(),
 | 
					            name=generate_id(),
 | 
				
			||||||
            client_id="test",
 | 
					            client_id="test",
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris="data:local.invalid",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "data:local.invalid")],
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        with self.assertRaises(RedirectUriError):
 | 
					        with self.assertRaises(RedirectUriError):
 | 
				
			||||||
            request = self.factory.get(
 | 
					            request = self.factory.get(
 | 
				
			||||||
@ -125,7 +127,7 @@ class TestAuthorize(OAuthTestCase):
 | 
				
			|||||||
            name=generate_id(),
 | 
					            name=generate_id(),
 | 
				
			||||||
            client_id="test",
 | 
					            client_id="test",
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris="",
 | 
					            redirect_uris=[],
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        with self.assertRaises(RedirectUriError):
 | 
					        with self.assertRaises(RedirectUriError):
 | 
				
			||||||
            request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
 | 
					            request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
 | 
				
			||||||
@ -140,7 +142,7 @@ class TestAuthorize(OAuthTestCase):
 | 
				
			|||||||
        )
 | 
					        )
 | 
				
			||||||
        OAuthAuthorizationParams.from_request(request)
 | 
					        OAuthAuthorizationParams.from_request(request)
 | 
				
			||||||
        provider.refresh_from_db()
 | 
					        provider.refresh_from_db()
 | 
				
			||||||
        self.assertEqual(provider.redirect_uris, "+")
 | 
					        self.assertEqual(provider.redirect_uris, [RedirectURI(RedirectURIMatchingMode.STRICT, "+")])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_invalid_redirect_uri_regex(self):
 | 
					    def test_invalid_redirect_uri_regex(self):
 | 
				
			||||||
        """test missing/invalid redirect URI"""
 | 
					        """test missing/invalid redirect URI"""
 | 
				
			||||||
@ -148,7 +150,7 @@ class TestAuthorize(OAuthTestCase):
 | 
				
			|||||||
            name=generate_id(),
 | 
					            name=generate_id(),
 | 
				
			||||||
            client_id="test",
 | 
					            client_id="test",
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris="http://local.invalid?",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid?")],
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        with self.assertRaises(RedirectUriError):
 | 
					        with self.assertRaises(RedirectUriError):
 | 
				
			||||||
            request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
 | 
					            request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
 | 
				
			||||||
@ -170,7 +172,7 @@ class TestAuthorize(OAuthTestCase):
 | 
				
			|||||||
            name=generate_id(),
 | 
					            name=generate_id(),
 | 
				
			||||||
            client_id="test",
 | 
					            client_id="test",
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris="+",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "+")],
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        with self.assertRaises(RedirectUriError):
 | 
					        with self.assertRaises(RedirectUriError):
 | 
				
			||||||
            request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
 | 
					            request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
 | 
				
			||||||
@ -213,7 +215,7 @@ class TestAuthorize(OAuthTestCase):
 | 
				
			|||||||
            name=generate_id(),
 | 
					            name=generate_id(),
 | 
				
			||||||
            client_id="test",
 | 
					            client_id="test",
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris="http://local.invalid/Foo",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        provider.property_mappings.set(
 | 
					        provider.property_mappings.set(
 | 
				
			||||||
            ScopeMapping.objects.filter(
 | 
					            ScopeMapping.objects.filter(
 | 
				
			||||||
@ -301,7 +303,7 @@ class TestAuthorize(OAuthTestCase):
 | 
				
			|||||||
            name=generate_id(),
 | 
					            name=generate_id(),
 | 
				
			||||||
            client_id="test",
 | 
					            client_id="test",
 | 
				
			||||||
            authorization_flow=flow,
 | 
					            authorization_flow=flow,
 | 
				
			||||||
            redirect_uris="foo://localhost",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
 | 
				
			||||||
            access_code_validity="seconds=100",
 | 
					            access_code_validity="seconds=100",
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        Application.objects.create(name="app", slug="app", provider=provider)
 | 
					        Application.objects.create(name="app", slug="app", provider=provider)
 | 
				
			||||||
@ -343,7 +345,7 @@ class TestAuthorize(OAuthTestCase):
 | 
				
			|||||||
            name=generate_id(),
 | 
					            name=generate_id(),
 | 
				
			||||||
            client_id="test",
 | 
					            client_id="test",
 | 
				
			||||||
            authorization_flow=flow,
 | 
					            authorization_flow=flow,
 | 
				
			||||||
            redirect_uris="http://localhost",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
 | 
				
			||||||
            signing_key=self.keypair,
 | 
					            signing_key=self.keypair,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        provider.property_mappings.set(
 | 
					        provider.property_mappings.set(
 | 
				
			||||||
@ -419,7 +421,7 @@ class TestAuthorize(OAuthTestCase):
 | 
				
			|||||||
            name=generate_id(),
 | 
					            name=generate_id(),
 | 
				
			||||||
            client_id="test",
 | 
					            client_id="test",
 | 
				
			||||||
            authorization_flow=flow,
 | 
					            authorization_flow=flow,
 | 
				
			||||||
            redirect_uris="http://localhost",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
 | 
				
			||||||
            signing_key=self.keypair,
 | 
					            signing_key=self.keypair,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        Application.objects.create(name="app", slug="app", provider=provider)
 | 
					        Application.objects.create(name="app", slug="app", provider=provider)
 | 
				
			||||||
@ -474,7 +476,7 @@ class TestAuthorize(OAuthTestCase):
 | 
				
			|||||||
            name=generate_id(),
 | 
					            name=generate_id(),
 | 
				
			||||||
            client_id=generate_id(),
 | 
					            client_id=generate_id(),
 | 
				
			||||||
            authorization_flow=flow,
 | 
					            authorization_flow=flow,
 | 
				
			||||||
            redirect_uris="http://localhost",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
 | 
				
			||||||
            signing_key=self.keypair,
 | 
					            signing_key=self.keypair,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        provider.property_mappings.set(
 | 
					        provider.property_mappings.set(
 | 
				
			||||||
@ -532,7 +534,7 @@ class TestAuthorize(OAuthTestCase):
 | 
				
			|||||||
            name=generate_id(),
 | 
					            name=generate_id(),
 | 
				
			||||||
            client_id=generate_id(),
 | 
					            client_id=generate_id(),
 | 
				
			||||||
            authorization_flow=flow,
 | 
					            authorization_flow=flow,
 | 
				
			||||||
            redirect_uris="http://localhost",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
 | 
				
			||||||
            signing_key=self.keypair,
 | 
					            signing_key=self.keypair,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        app = Application.objects.create(name=generate_id(), slug=generate_id(), provider=provider)
 | 
					        app = Application.objects.create(name=generate_id(), slug=generate_id(), provider=provider)
 | 
				
			||||||
 | 
				
			|||||||
@ -11,7 +11,14 @@ from authentik.core.models import Application
 | 
				
			|||||||
from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow
 | 
					from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow
 | 
				
			||||||
from authentik.lib.generators import generate_id
 | 
					from authentik.lib.generators import generate_id
 | 
				
			||||||
from authentik.providers.oauth2.constants import ACR_AUTHENTIK_DEFAULT
 | 
					from authentik.providers.oauth2.constants import ACR_AUTHENTIK_DEFAULT
 | 
				
			||||||
from authentik.providers.oauth2.models import AccessToken, IDToken, OAuth2Provider, RefreshToken
 | 
					from authentik.providers.oauth2.models import (
 | 
				
			||||||
 | 
					    AccessToken,
 | 
				
			||||||
 | 
					    IDToken,
 | 
				
			||||||
 | 
					    OAuth2Provider,
 | 
				
			||||||
 | 
					    RedirectURI,
 | 
				
			||||||
 | 
					    RedirectURIMatchingMode,
 | 
				
			||||||
 | 
					    RefreshToken,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
from authentik.providers.oauth2.tests.utils import OAuthTestCase
 | 
					from authentik.providers.oauth2.tests.utils import OAuthTestCase
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -23,13 +30,12 @@ class TesOAuth2Introspection(OAuthTestCase):
 | 
				
			|||||||
        self.provider: OAuth2Provider = OAuth2Provider.objects.create(
 | 
					        self.provider: OAuth2Provider = OAuth2Provider.objects.create(
 | 
				
			||||||
            name=generate_id(),
 | 
					            name=generate_id(),
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris="",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "")],
 | 
				
			||||||
            signing_key=create_test_cert(),
 | 
					            signing_key=create_test_cert(),
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        self.app = Application.objects.create(
 | 
					        self.app = Application.objects.create(
 | 
				
			||||||
            name=generate_id(), slug=generate_id(), provider=self.provider
 | 
					            name=generate_id(), slug=generate_id(), provider=self.provider
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        self.app.save()
 | 
					 | 
				
			||||||
        self.user = create_test_admin_user()
 | 
					        self.user = create_test_admin_user()
 | 
				
			||||||
        self.auth = b64encode(
 | 
					        self.auth = b64encode(
 | 
				
			||||||
            f"{self.provider.client_id}:{self.provider.client_secret}".encode()
 | 
					            f"{self.provider.client_id}:{self.provider.client_secret}".encode()
 | 
				
			||||||
@ -114,6 +120,41 @@ class TesOAuth2Introspection(OAuthTestCase):
 | 
				
			|||||||
            },
 | 
					            },
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_introspect_invalid_provider(self):
 | 
				
			||||||
 | 
					        """Test introspection (mismatched provider and token)"""
 | 
				
			||||||
 | 
					        provider: OAuth2Provider = OAuth2Provider.objects.create(
 | 
				
			||||||
 | 
					            name=generate_id(),
 | 
				
			||||||
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "")],
 | 
				
			||||||
 | 
					            signing_key=create_test_cert(),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        auth = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        token: AccessToken = AccessToken.objects.create(
 | 
				
			||||||
 | 
					            provider=self.provider,
 | 
				
			||||||
 | 
					            user=self.user,
 | 
				
			||||||
 | 
					            token=generate_id(),
 | 
				
			||||||
 | 
					            auth_time=timezone.now(),
 | 
				
			||||||
 | 
					            _scope="openid user profile",
 | 
				
			||||||
 | 
					            _id_token=json.dumps(
 | 
				
			||||||
 | 
					                asdict(
 | 
				
			||||||
 | 
					                    IDToken("foo", "bar"),
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        res = self.client.post(
 | 
				
			||||||
 | 
					            reverse("authentik_providers_oauth2:token-introspection"),
 | 
				
			||||||
 | 
					            HTTP_AUTHORIZATION=f"Basic {auth}",
 | 
				
			||||||
 | 
					            data={"token": token.token},
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(res.status_code, 200)
 | 
				
			||||||
 | 
					        self.assertJSONEqual(
 | 
				
			||||||
 | 
					            res.content.decode(),
 | 
				
			||||||
 | 
					            {
 | 
				
			||||||
 | 
					                "active": False,
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_introspect_invalid_auth(self):
 | 
					    def test_introspect_invalid_auth(self):
 | 
				
			||||||
        """Test introspect (invalid auth)"""
 | 
					        """Test introspect (invalid auth)"""
 | 
				
			||||||
        res = self.client.post(
 | 
					        res = self.client.post(
 | 
				
			||||||
 | 
				
			|||||||
@ -13,7 +13,7 @@ from authentik.core.tests.utils import create_test_cert, create_test_flow
 | 
				
			|||||||
from authentik.crypto.builder import PrivateKeyAlg
 | 
					from authentik.crypto.builder import PrivateKeyAlg
 | 
				
			||||||
from authentik.crypto.models import CertificateKeyPair
 | 
					from authentik.crypto.models import CertificateKeyPair
 | 
				
			||||||
from authentik.lib.generators import generate_id
 | 
					from authentik.lib.generators import generate_id
 | 
				
			||||||
from authentik.providers.oauth2.models import OAuth2Provider
 | 
					from authentik.providers.oauth2.models import OAuth2Provider, RedirectURI, RedirectURIMatchingMode
 | 
				
			||||||
from authentik.providers.oauth2.tests.utils import OAuthTestCase
 | 
					from authentik.providers.oauth2.tests.utils import OAuthTestCase
 | 
				
			||||||
 | 
					
 | 
				
			||||||
TEST_CORDS_CERT = """
 | 
					TEST_CORDS_CERT = """
 | 
				
			||||||
@ -49,7 +49,7 @@ class TestJWKS(OAuthTestCase):
 | 
				
			|||||||
            name="test",
 | 
					            name="test",
 | 
				
			||||||
            client_id="test",
 | 
					            client_id="test",
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris="http://local.invalid",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
 | 
				
			||||||
            signing_key=create_test_cert(),
 | 
					            signing_key=create_test_cert(),
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        app = Application.objects.create(name="test", slug="test", provider=provider)
 | 
					        app = Application.objects.create(name="test", slug="test", provider=provider)
 | 
				
			||||||
@ -68,7 +68,7 @@ class TestJWKS(OAuthTestCase):
 | 
				
			|||||||
            name="test",
 | 
					            name="test",
 | 
				
			||||||
            client_id="test",
 | 
					            client_id="test",
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris="http://local.invalid",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        app = Application.objects.create(name="test", slug="test", provider=provider)
 | 
					        app = Application.objects.create(name="test", slug="test", provider=provider)
 | 
				
			||||||
        response = self.client.get(
 | 
					        response = self.client.get(
 | 
				
			||||||
@ -82,7 +82,7 @@ class TestJWKS(OAuthTestCase):
 | 
				
			|||||||
            name="test",
 | 
					            name="test",
 | 
				
			||||||
            client_id="test",
 | 
					            client_id="test",
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris="http://local.invalid",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
 | 
				
			||||||
            signing_key=create_test_cert(PrivateKeyAlg.ECDSA),
 | 
					            signing_key=create_test_cert(PrivateKeyAlg.ECDSA),
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        app = Application.objects.create(name="test", slug="test", provider=provider)
 | 
					        app = Application.objects.create(name="test", slug="test", provider=provider)
 | 
				
			||||||
@ -104,7 +104,7 @@ class TestJWKS(OAuthTestCase):
 | 
				
			|||||||
            name="test",
 | 
					            name="test",
 | 
				
			||||||
            client_id="test",
 | 
					            client_id="test",
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris="http://local.invalid",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
 | 
				
			||||||
            signing_key=cert,
 | 
					            signing_key=cert,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        app = Application.objects.create(name="test", slug="test", provider=provider)
 | 
					        app = Application.objects.create(name="test", slug="test", provider=provider)
 | 
				
			||||||
 | 
				
			|||||||
@ -10,7 +10,14 @@ from django.utils import timezone
 | 
				
			|||||||
from authentik.core.models import Application
 | 
					from authentik.core.models import Application
 | 
				
			||||||
from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow
 | 
					from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow
 | 
				
			||||||
from authentik.lib.generators import generate_id
 | 
					from authentik.lib.generators import generate_id
 | 
				
			||||||
from authentik.providers.oauth2.models import AccessToken, IDToken, OAuth2Provider, RefreshToken
 | 
					from authentik.providers.oauth2.models import (
 | 
				
			||||||
 | 
					    AccessToken,
 | 
				
			||||||
 | 
					    IDToken,
 | 
				
			||||||
 | 
					    OAuth2Provider,
 | 
				
			||||||
 | 
					    RedirectURI,
 | 
				
			||||||
 | 
					    RedirectURIMatchingMode,
 | 
				
			||||||
 | 
					    RefreshToken,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
from authentik.providers.oauth2.tests.utils import OAuthTestCase
 | 
					from authentik.providers.oauth2.tests.utils import OAuthTestCase
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -22,7 +29,7 @@ class TesOAuth2Revoke(OAuthTestCase):
 | 
				
			|||||||
        self.provider: OAuth2Provider = OAuth2Provider.objects.create(
 | 
					        self.provider: OAuth2Provider = OAuth2Provider.objects.create(
 | 
				
			||||||
            name=generate_id(),
 | 
					            name=generate_id(),
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris="",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "")],
 | 
				
			||||||
            signing_key=create_test_cert(),
 | 
					            signing_key=create_test_cert(),
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        self.app = Application.objects.create(
 | 
					        self.app = Application.objects.create(
 | 
				
			||||||
 | 
				
			|||||||
@ -22,6 +22,8 @@ from authentik.providers.oauth2.models import (
 | 
				
			|||||||
    AccessToken,
 | 
					    AccessToken,
 | 
				
			||||||
    AuthorizationCode,
 | 
					    AuthorizationCode,
 | 
				
			||||||
    OAuth2Provider,
 | 
					    OAuth2Provider,
 | 
				
			||||||
 | 
					    RedirectURI,
 | 
				
			||||||
 | 
					    RedirectURIMatchingMode,
 | 
				
			||||||
    RefreshToken,
 | 
					    RefreshToken,
 | 
				
			||||||
    ScopeMapping,
 | 
					    ScopeMapping,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@ -42,7 +44,7 @@ class TestToken(OAuthTestCase):
 | 
				
			|||||||
        provider = OAuth2Provider.objects.create(
 | 
					        provider = OAuth2Provider.objects.create(
 | 
				
			||||||
            name=generate_id(),
 | 
					            name=generate_id(),
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris="http://TestServer",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://TestServer")],
 | 
				
			||||||
            signing_key=self.keypair,
 | 
					            signing_key=self.keypair,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
 | 
					        header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
 | 
				
			||||||
@ -69,7 +71,7 @@ class TestToken(OAuthTestCase):
 | 
				
			|||||||
        provider = OAuth2Provider.objects.create(
 | 
					        provider = OAuth2Provider.objects.create(
 | 
				
			||||||
            name=generate_id(),
 | 
					            name=generate_id(),
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris="http://testserver",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
 | 
				
			||||||
            signing_key=self.keypair,
 | 
					            signing_key=self.keypair,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
 | 
					        header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
 | 
				
			||||||
@ -90,7 +92,7 @@ class TestToken(OAuthTestCase):
 | 
				
			|||||||
        provider = OAuth2Provider.objects.create(
 | 
					        provider = OAuth2Provider.objects.create(
 | 
				
			||||||
            name=generate_id(),
 | 
					            name=generate_id(),
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris="http://local.invalid",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
 | 
				
			||||||
            signing_key=self.keypair,
 | 
					            signing_key=self.keypair,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
 | 
					        header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
 | 
				
			||||||
@ -118,7 +120,7 @@ class TestToken(OAuthTestCase):
 | 
				
			|||||||
        provider = OAuth2Provider.objects.create(
 | 
					        provider = OAuth2Provider.objects.create(
 | 
				
			||||||
            name=generate_id(),
 | 
					            name=generate_id(),
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris="http://local.invalid",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
 | 
				
			||||||
            signing_key=self.keypair,
 | 
					            signing_key=self.keypair,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        # Needs to be assigned to an application for iss to be set
 | 
					        # Needs to be assigned to an application for iss to be set
 | 
				
			||||||
@ -158,7 +160,7 @@ class TestToken(OAuthTestCase):
 | 
				
			|||||||
        provider = OAuth2Provider.objects.create(
 | 
					        provider = OAuth2Provider.objects.create(
 | 
				
			||||||
            name=generate_id(),
 | 
					            name=generate_id(),
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris="http://local.invalid",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
 | 
				
			||||||
            signing_key=self.keypair,
 | 
					            signing_key=self.keypair,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        provider.property_mappings.set(
 | 
					        provider.property_mappings.set(
 | 
				
			||||||
@ -220,7 +222,7 @@ class TestToken(OAuthTestCase):
 | 
				
			|||||||
        provider = OAuth2Provider.objects.create(
 | 
					        provider = OAuth2Provider.objects.create(
 | 
				
			||||||
            name=generate_id(),
 | 
					            name=generate_id(),
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris="http://local.invalid",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
 | 
				
			||||||
            signing_key=self.keypair,
 | 
					            signing_key=self.keypair,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        provider.property_mappings.set(
 | 
					        provider.property_mappings.set(
 | 
				
			||||||
@ -278,7 +280,7 @@ class TestToken(OAuthTestCase):
 | 
				
			|||||||
        provider = OAuth2Provider.objects.create(
 | 
					        provider = OAuth2Provider.objects.create(
 | 
				
			||||||
            name=generate_id(),
 | 
					            name=generate_id(),
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris="http://testserver",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
 | 
				
			||||||
            signing_key=self.keypair,
 | 
					            signing_key=self.keypair,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        provider.property_mappings.set(
 | 
					        provider.property_mappings.set(
 | 
				
			||||||
 | 
				
			|||||||
@ -19,7 +19,12 @@ from authentik.providers.oauth2.constants import (
 | 
				
			|||||||
    SCOPE_OPENID_PROFILE,
 | 
					    SCOPE_OPENID_PROFILE,
 | 
				
			||||||
    TOKEN_TYPE,
 | 
					    TOKEN_TYPE,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from authentik.providers.oauth2.models import OAuth2Provider, ScopeMapping
 | 
					from authentik.providers.oauth2.models import (
 | 
				
			||||||
 | 
					    OAuth2Provider,
 | 
				
			||||||
 | 
					    RedirectURI,
 | 
				
			||||||
 | 
					    RedirectURIMatchingMode,
 | 
				
			||||||
 | 
					    ScopeMapping,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
from authentik.providers.oauth2.tests.utils import OAuthTestCase
 | 
					from authentik.providers.oauth2.tests.utils import OAuthTestCase
 | 
				
			||||||
from authentik.providers.oauth2.views.jwks import JWKSView
 | 
					from authentik.providers.oauth2.views.jwks import JWKSView
 | 
				
			||||||
from authentik.sources.oauth.models import OAuthSource
 | 
					from authentik.sources.oauth.models import OAuthSource
 | 
				
			||||||
@ -54,7 +59,7 @@ class TestTokenClientCredentialsJWTSource(OAuthTestCase):
 | 
				
			|||||||
        self.provider: OAuth2Provider = OAuth2Provider.objects.create(
 | 
					        self.provider: OAuth2Provider = OAuth2Provider.objects.create(
 | 
				
			||||||
            name="test",
 | 
					            name="test",
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris="http://testserver",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
 | 
				
			||||||
            signing_key=self.cert,
 | 
					            signing_key=self.cert,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        self.provider.jwks_sources.add(self.source)
 | 
					        self.provider.jwks_sources.add(self.source)
 | 
				
			||||||
 | 
				
			|||||||
@ -19,7 +19,13 @@ from authentik.providers.oauth2.constants import (
 | 
				
			|||||||
    TOKEN_TYPE,
 | 
					    TOKEN_TYPE,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from authentik.providers.oauth2.errors import TokenError
 | 
					from authentik.providers.oauth2.errors import TokenError
 | 
				
			||||||
from authentik.providers.oauth2.models import OAuth2Provider, ScopeMapping
 | 
					from authentik.providers.oauth2.models import (
 | 
				
			||||||
 | 
					    AccessToken,
 | 
				
			||||||
 | 
					    OAuth2Provider,
 | 
				
			||||||
 | 
					    RedirectURI,
 | 
				
			||||||
 | 
					    RedirectURIMatchingMode,
 | 
				
			||||||
 | 
					    ScopeMapping,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
from authentik.providers.oauth2.tests.utils import OAuthTestCase
 | 
					from authentik.providers.oauth2.tests.utils import OAuthTestCase
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -33,7 +39,7 @@ class TestTokenClientCredentialsStandard(OAuthTestCase):
 | 
				
			|||||||
        self.provider = OAuth2Provider.objects.create(
 | 
					        self.provider = OAuth2Provider.objects.create(
 | 
				
			||||||
            name="test",
 | 
					            name="test",
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris="http://testserver",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
 | 
				
			||||||
            signing_key=create_test_cert(),
 | 
					            signing_key=create_test_cert(),
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        self.provider.property_mappings.set(ScopeMapping.objects.all())
 | 
					        self.provider.property_mappings.set(ScopeMapping.objects.all())
 | 
				
			||||||
@ -107,6 +113,48 @@ class TestTokenClientCredentialsStandard(OAuthTestCase):
 | 
				
			|||||||
            {"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
 | 
					            {"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_incorrect_scopes(self):
 | 
				
			||||||
 | 
					        """test scope that isn't configured"""
 | 
				
			||||||
 | 
					        response = self.client.post(
 | 
				
			||||||
 | 
					            reverse("authentik_providers_oauth2:token"),
 | 
				
			||||||
 | 
					            {
 | 
				
			||||||
 | 
					                "grant_type": GRANT_TYPE_CLIENT_CREDENTIALS,
 | 
				
			||||||
 | 
					                "scope": f"{SCOPE_OPENID} {SCOPE_OPENID_EMAIL} {SCOPE_OPENID_PROFILE} extra_scope",
 | 
				
			||||||
 | 
					                "client_id": self.provider.client_id,
 | 
				
			||||||
 | 
					                "client_secret": self.provider.client_secret,
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(response.status_code, 200)
 | 
				
			||||||
 | 
					        body = loads(response.content.decode())
 | 
				
			||||||
 | 
					        self.assertEqual(body["token_type"], TOKEN_TYPE)
 | 
				
			||||||
 | 
					        token = AccessToken.objects.filter(
 | 
				
			||||||
 | 
					            provider=self.provider, token=body["access_token"]
 | 
				
			||||||
 | 
					        ).first()
 | 
				
			||||||
 | 
					        self.assertSetEqual(
 | 
				
			||||||
 | 
					            set(token.scope), {SCOPE_OPENID, SCOPE_OPENID_EMAIL, SCOPE_OPENID_PROFILE}
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        _, alg = self.provider.jwt_key
 | 
				
			||||||
 | 
					        jwt = decode(
 | 
				
			||||||
 | 
					            body["access_token"],
 | 
				
			||||||
 | 
					            key=self.provider.signing_key.public_key,
 | 
				
			||||||
 | 
					            algorithms=[alg],
 | 
				
			||||||
 | 
					            audience=self.provider.client_id,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            jwt["given_name"], "Autogenerated user from application test (client credentials)"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(jwt["preferred_username"], "ak-test-client_credentials")
 | 
				
			||||||
 | 
					        jwt = decode(
 | 
				
			||||||
 | 
					            body["id_token"],
 | 
				
			||||||
 | 
					            key=self.provider.signing_key.public_key,
 | 
				
			||||||
 | 
					            algorithms=[alg],
 | 
				
			||||||
 | 
					            audience=self.provider.client_id,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            jwt["given_name"], "Autogenerated user from application test (client credentials)"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(jwt["preferred_username"], "ak-test-client_credentials")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_successful(self):
 | 
					    def test_successful(self):
 | 
				
			||||||
        """test successful"""
 | 
					        """test successful"""
 | 
				
			||||||
        response = self.client.post(
 | 
					        response = self.client.post(
 | 
				
			||||||
 | 
				
			|||||||
@ -20,7 +20,12 @@ from authentik.providers.oauth2.constants import (
 | 
				
			|||||||
    TOKEN_TYPE,
 | 
					    TOKEN_TYPE,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from authentik.providers.oauth2.errors import TokenError
 | 
					from authentik.providers.oauth2.errors import TokenError
 | 
				
			||||||
from authentik.providers.oauth2.models import OAuth2Provider, ScopeMapping
 | 
					from authentik.providers.oauth2.models import (
 | 
				
			||||||
 | 
					    OAuth2Provider,
 | 
				
			||||||
 | 
					    RedirectURI,
 | 
				
			||||||
 | 
					    RedirectURIMatchingMode,
 | 
				
			||||||
 | 
					    ScopeMapping,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
from authentik.providers.oauth2.tests.utils import OAuthTestCase
 | 
					from authentik.providers.oauth2.tests.utils import OAuthTestCase
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -34,7 +39,7 @@ class TestTokenClientCredentialsStandardCompat(OAuthTestCase):
 | 
				
			|||||||
        self.provider = OAuth2Provider.objects.create(
 | 
					        self.provider = OAuth2Provider.objects.create(
 | 
				
			||||||
            name="test",
 | 
					            name="test",
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris="http://testserver",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
 | 
				
			||||||
            signing_key=create_test_cert(),
 | 
					            signing_key=create_test_cert(),
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        self.provider.property_mappings.set(ScopeMapping.objects.all())
 | 
					        self.provider.property_mappings.set(ScopeMapping.objects.all())
 | 
				
			||||||
 | 
				
			|||||||
@ -19,7 +19,12 @@ from authentik.providers.oauth2.constants import (
 | 
				
			|||||||
    TOKEN_TYPE,
 | 
					    TOKEN_TYPE,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from authentik.providers.oauth2.errors import TokenError
 | 
					from authentik.providers.oauth2.errors import TokenError
 | 
				
			||||||
from authentik.providers.oauth2.models import OAuth2Provider, ScopeMapping
 | 
					from authentik.providers.oauth2.models import (
 | 
				
			||||||
 | 
					    OAuth2Provider,
 | 
				
			||||||
 | 
					    RedirectURI,
 | 
				
			||||||
 | 
					    RedirectURIMatchingMode,
 | 
				
			||||||
 | 
					    ScopeMapping,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
from authentik.providers.oauth2.tests.utils import OAuthTestCase
 | 
					from authentik.providers.oauth2.tests.utils import OAuthTestCase
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -33,7 +38,7 @@ class TestTokenClientCredentialsUserNamePassword(OAuthTestCase):
 | 
				
			|||||||
        self.provider = OAuth2Provider.objects.create(
 | 
					        self.provider = OAuth2Provider.objects.create(
 | 
				
			||||||
            name="test",
 | 
					            name="test",
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris="http://testserver",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
 | 
				
			||||||
            signing_key=create_test_cert(),
 | 
					            signing_key=create_test_cert(),
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        self.provider.property_mappings.set(ScopeMapping.objects.all())
 | 
					        self.provider.property_mappings.set(ScopeMapping.objects.all())
 | 
				
			||||||
 | 
				
			|||||||
@ -9,8 +9,19 @@ from authentik.blueprints.tests import apply_blueprint
 | 
				
			|||||||
from authentik.core.models import Application
 | 
					from authentik.core.models import Application
 | 
				
			||||||
from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow
 | 
					from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow
 | 
				
			||||||
from authentik.lib.generators import generate_code_fixed_length, generate_id
 | 
					from authentik.lib.generators import generate_code_fixed_length, generate_id
 | 
				
			||||||
from authentik.providers.oauth2.constants import GRANT_TYPE_DEVICE_CODE
 | 
					from authentik.providers.oauth2.constants import (
 | 
				
			||||||
from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider, ScopeMapping
 | 
					    GRANT_TYPE_DEVICE_CODE,
 | 
				
			||||||
 | 
					    SCOPE_OPENID,
 | 
				
			||||||
 | 
					    SCOPE_OPENID_EMAIL,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					from authentik.providers.oauth2.models import (
 | 
				
			||||||
 | 
					    AccessToken,
 | 
				
			||||||
 | 
					    DeviceToken,
 | 
				
			||||||
 | 
					    OAuth2Provider,
 | 
				
			||||||
 | 
					    RedirectURI,
 | 
				
			||||||
 | 
					    RedirectURIMatchingMode,
 | 
				
			||||||
 | 
					    ScopeMapping,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
from authentik.providers.oauth2.tests.utils import OAuthTestCase
 | 
					from authentik.providers.oauth2.tests.utils import OAuthTestCase
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -24,7 +35,7 @@ class TestTokenDeviceCode(OAuthTestCase):
 | 
				
			|||||||
        self.provider = OAuth2Provider.objects.create(
 | 
					        self.provider = OAuth2Provider.objects.create(
 | 
				
			||||||
            name="test",
 | 
					            name="test",
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris="http://testserver",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
 | 
				
			||||||
            signing_key=create_test_cert(),
 | 
					            signing_key=create_test_cert(),
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        self.provider.property_mappings.set(ScopeMapping.objects.all())
 | 
					        self.provider.property_mappings.set(ScopeMapping.objects.all())
 | 
				
			||||||
@ -80,3 +91,28 @@ class TestTokenDeviceCode(OAuthTestCase):
 | 
				
			|||||||
            },
 | 
					            },
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        self.assertEqual(res.status_code, 200)
 | 
					        self.assertEqual(res.status_code, 200)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_code_mismatched_scope(self):
 | 
				
			||||||
 | 
					        """Test code with user (mismatched scopes)"""
 | 
				
			||||||
 | 
					        device_token = DeviceToken.objects.create(
 | 
				
			||||||
 | 
					            provider=self.provider,
 | 
				
			||||||
 | 
					            user_code=generate_code_fixed_length(),
 | 
				
			||||||
 | 
					            device_code=generate_id(),
 | 
				
			||||||
 | 
					            user=self.user,
 | 
				
			||||||
 | 
					            scope=[SCOPE_OPENID, SCOPE_OPENID_EMAIL],
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        res = self.client.post(
 | 
				
			||||||
 | 
					            reverse("authentik_providers_oauth2:token"),
 | 
				
			||||||
 | 
					            data={
 | 
				
			||||||
 | 
					                "client_id": self.provider.client_id,
 | 
				
			||||||
 | 
					                "grant_type": GRANT_TYPE_DEVICE_CODE,
 | 
				
			||||||
 | 
					                "device_code": device_token.device_code,
 | 
				
			||||||
 | 
					                "scope": f"{SCOPE_OPENID} {SCOPE_OPENID_EMAIL} invalid",
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(res.status_code, 200)
 | 
				
			||||||
 | 
					        body = loads(res.content)
 | 
				
			||||||
 | 
					        token = AccessToken.objects.filter(
 | 
				
			||||||
 | 
					            provider=self.provider, token=body["access_token"]
 | 
				
			||||||
 | 
					        ).first()
 | 
				
			||||||
 | 
					        self.assertSetEqual(set(token.scope), {SCOPE_OPENID, SCOPE_OPENID_EMAIL})
 | 
				
			||||||
 | 
				
			|||||||
@ -10,7 +10,12 @@ from authentik.core.models import Application
 | 
				
			|||||||
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
 | 
					from authentik.core.tests.utils import create_test_admin_user, create_test_flow
 | 
				
			||||||
from authentik.lib.generators import generate_id
 | 
					from authentik.lib.generators import generate_id
 | 
				
			||||||
from authentik.providers.oauth2.constants import GRANT_TYPE_AUTHORIZATION_CODE
 | 
					from authentik.providers.oauth2.constants import GRANT_TYPE_AUTHORIZATION_CODE
 | 
				
			||||||
from authentik.providers.oauth2.models import AuthorizationCode, OAuth2Provider
 | 
					from authentik.providers.oauth2.models import (
 | 
				
			||||||
 | 
					    AuthorizationCode,
 | 
				
			||||||
 | 
					    OAuth2Provider,
 | 
				
			||||||
 | 
					    RedirectURI,
 | 
				
			||||||
 | 
					    RedirectURIMatchingMode,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
from authentik.providers.oauth2.tests.utils import OAuthTestCase
 | 
					from authentik.providers.oauth2.tests.utils import OAuthTestCase
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -30,7 +35,7 @@ class TestTokenPKCE(OAuthTestCase):
 | 
				
			|||||||
            name=generate_id(),
 | 
					            name=generate_id(),
 | 
				
			||||||
            client_id="test",
 | 
					            client_id="test",
 | 
				
			||||||
            authorization_flow=flow,
 | 
					            authorization_flow=flow,
 | 
				
			||||||
            redirect_uris="foo://localhost",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
 | 
				
			||||||
            access_code_validity="seconds=100",
 | 
					            access_code_validity="seconds=100",
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        Application.objects.create(name="app", slug="app", provider=provider)
 | 
					        Application.objects.create(name="app", slug="app", provider=provider)
 | 
				
			||||||
@ -93,7 +98,7 @@ class TestTokenPKCE(OAuthTestCase):
 | 
				
			|||||||
            name=generate_id(),
 | 
					            name=generate_id(),
 | 
				
			||||||
            client_id="test",
 | 
					            client_id="test",
 | 
				
			||||||
            authorization_flow=flow,
 | 
					            authorization_flow=flow,
 | 
				
			||||||
            redirect_uris="foo://localhost",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
 | 
				
			||||||
            access_code_validity="seconds=100",
 | 
					            access_code_validity="seconds=100",
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        Application.objects.create(name="app", slug="app", provider=provider)
 | 
					        Application.objects.create(name="app", slug="app", provider=provider)
 | 
				
			||||||
@ -154,7 +159,7 @@ class TestTokenPKCE(OAuthTestCase):
 | 
				
			|||||||
            name=generate_id(),
 | 
					            name=generate_id(),
 | 
				
			||||||
            client_id="test",
 | 
					            client_id="test",
 | 
				
			||||||
            authorization_flow=flow,
 | 
					            authorization_flow=flow,
 | 
				
			||||||
            redirect_uris="foo://localhost",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
 | 
				
			||||||
            access_code_validity="seconds=100",
 | 
					            access_code_validity="seconds=100",
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        Application.objects.create(name="app", slug="app", provider=provider)
 | 
					        Application.objects.create(name="app", slug="app", provider=provider)
 | 
				
			||||||
@ -210,7 +215,7 @@ class TestTokenPKCE(OAuthTestCase):
 | 
				
			|||||||
            name=generate_id(),
 | 
					            name=generate_id(),
 | 
				
			||||||
            client_id="test",
 | 
					            client_id="test",
 | 
				
			||||||
            authorization_flow=flow,
 | 
					            authorization_flow=flow,
 | 
				
			||||||
            redirect_uris="foo://localhost",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
 | 
				
			||||||
            access_code_validity="seconds=100",
 | 
					            access_code_validity="seconds=100",
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        Application.objects.create(name="app", slug="app", provider=provider)
 | 
					        Application.objects.create(name="app", slug="app", provider=provider)
 | 
				
			||||||
 | 
				
			|||||||
@ -11,7 +11,14 @@ from authentik.core.models import Application
 | 
				
			|||||||
from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow
 | 
					from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow
 | 
				
			||||||
from authentik.events.models import Event, EventAction
 | 
					from authentik.events.models import Event, EventAction
 | 
				
			||||||
from authentik.lib.generators import generate_id
 | 
					from authentik.lib.generators import generate_id
 | 
				
			||||||
from authentik.providers.oauth2.models import AccessToken, IDToken, OAuth2Provider, ScopeMapping
 | 
					from authentik.providers.oauth2.models import (
 | 
				
			||||||
 | 
					    AccessToken,
 | 
				
			||||||
 | 
					    IDToken,
 | 
				
			||||||
 | 
					    OAuth2Provider,
 | 
				
			||||||
 | 
					    RedirectURI,
 | 
				
			||||||
 | 
					    RedirectURIMatchingMode,
 | 
				
			||||||
 | 
					    ScopeMapping,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
from authentik.providers.oauth2.tests.utils import OAuthTestCase
 | 
					from authentik.providers.oauth2.tests.utils import OAuthTestCase
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -25,7 +32,7 @@ class TestUserinfo(OAuthTestCase):
 | 
				
			|||||||
        self.provider: OAuth2Provider = OAuth2Provider.objects.create(
 | 
					        self.provider: OAuth2Provider = OAuth2Provider.objects.create(
 | 
				
			||||||
            name=generate_id(),
 | 
					            name=generate_id(),
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris="",
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "")],
 | 
				
			||||||
            signing_key=create_test_cert(),
 | 
					            signing_key=create_test_cert(),
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        self.provider.property_mappings.set(ScopeMapping.objects.all())
 | 
					        self.provider.property_mappings.set(ScopeMapping.objects.all())
 | 
				
			||||||
 | 
				
			|||||||
@ -2,7 +2,6 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from dataclasses import InitVar, dataclass, field
 | 
					from dataclasses import InitVar, dataclass, field
 | 
				
			||||||
from datetime import timedelta
 | 
					from datetime import timedelta
 | 
				
			||||||
from hashlib import sha256
 | 
					 | 
				
			||||||
from json import dumps
 | 
					from json import dumps
 | 
				
			||||||
from re import error as RegexError
 | 
					from re import error as RegexError
 | 
				
			||||||
from re import fullmatch
 | 
					from re import fullmatch
 | 
				
			||||||
@ -16,7 +15,7 @@ from django.utils import timezone
 | 
				
			|||||||
from django.utils.translation import gettext as _
 | 
					from django.utils.translation import gettext as _
 | 
				
			||||||
from structlog.stdlib import get_logger
 | 
					from structlog.stdlib import get_logger
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.core.models import Application
 | 
					from authentik.core.models import Application, AuthenticatedSession
 | 
				
			||||||
from authentik.events.models import Event, EventAction
 | 
					from authentik.events.models import Event, EventAction
 | 
				
			||||||
from authentik.events.signals import get_login_event
 | 
					from authentik.events.signals import get_login_event
 | 
				
			||||||
from authentik.flows.challenge import (
 | 
					from authentik.flows.challenge import (
 | 
				
			||||||
@ -57,6 +56,8 @@ from authentik.providers.oauth2.models import (
 | 
				
			|||||||
    AuthorizationCode,
 | 
					    AuthorizationCode,
 | 
				
			||||||
    GrantTypes,
 | 
					    GrantTypes,
 | 
				
			||||||
    OAuth2Provider,
 | 
					    OAuth2Provider,
 | 
				
			||||||
 | 
					    RedirectURI,
 | 
				
			||||||
 | 
					    RedirectURIMatchingMode,
 | 
				
			||||||
    ResponseMode,
 | 
					    ResponseMode,
 | 
				
			||||||
    ResponseTypes,
 | 
					    ResponseTypes,
 | 
				
			||||||
    ScopeMapping,
 | 
					    ScopeMapping,
 | 
				
			||||||
@ -188,40 +189,39 @@ class OAuthAuthorizationParams:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def check_redirect_uri(self):
 | 
					    def check_redirect_uri(self):
 | 
				
			||||||
        """Redirect URI validation."""
 | 
					        """Redirect URI validation."""
 | 
				
			||||||
        allowed_redirect_urls = self.provider.redirect_uris.split()
 | 
					        allowed_redirect_urls = self.provider.redirect_uris
 | 
				
			||||||
        if not self.redirect_uri:
 | 
					        if not self.redirect_uri:
 | 
				
			||||||
            LOGGER.warning("Missing redirect uri.")
 | 
					            LOGGER.warning("Missing redirect uri.")
 | 
				
			||||||
            raise RedirectUriError("", allowed_redirect_urls)
 | 
					            raise RedirectUriError("", allowed_redirect_urls)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if self.provider.redirect_uris == "":
 | 
					        if len(allowed_redirect_urls) < 1:
 | 
				
			||||||
            LOGGER.info("Setting redirect for blank redirect_uris", redirect=self.redirect_uri)
 | 
					            LOGGER.info("Setting redirect for blank redirect_uris", redirect=self.redirect_uri)
 | 
				
			||||||
            self.provider.redirect_uris = self.redirect_uri
 | 
					            self.provider.redirect_uris = [
 | 
				
			||||||
 | 
					                RedirectURI(RedirectURIMatchingMode.STRICT, self.redirect_uri)
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
            self.provider.save()
 | 
					            self.provider.save()
 | 
				
			||||||
            allowed_redirect_urls = self.provider.redirect_uris.split()
 | 
					            allowed_redirect_urls = self.provider.redirect_uris
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if self.provider.redirect_uris == "*":
 | 
					        match_found = False
 | 
				
			||||||
            LOGGER.info("Converting redirect_uris to regex", redirect=self.redirect_uri)
 | 
					        for allowed in allowed_redirect_urls:
 | 
				
			||||||
            self.provider.redirect_uris = ".*"
 | 
					            if allowed.matching_mode == RedirectURIMatchingMode.STRICT:
 | 
				
			||||||
            self.provider.save()
 | 
					                if self.redirect_uri == allowed.url:
 | 
				
			||||||
            allowed_redirect_urls = self.provider.redirect_uris.split()
 | 
					                    match_found = True
 | 
				
			||||||
 | 
					                    break
 | 
				
			||||||
        try:
 | 
					            if allowed.matching_mode == RedirectURIMatchingMode.REGEX:
 | 
				
			||||||
            if not any(fullmatch(x, self.redirect_uri) for x in allowed_redirect_urls):
 | 
					                try:
 | 
				
			||||||
                LOGGER.warning(
 | 
					                    if fullmatch(allowed.url, self.redirect_uri):
 | 
				
			||||||
                    "Invalid redirect uri (regex comparison)",
 | 
					                        match_found = True
 | 
				
			||||||
                    redirect_uri_given=self.redirect_uri,
 | 
					                        break
 | 
				
			||||||
                    redirect_uri_expected=allowed_redirect_urls,
 | 
					                except RegexError as exc:
 | 
				
			||||||
                )
 | 
					                    LOGGER.warning(
 | 
				
			||||||
                raise RedirectUriError(self.redirect_uri, allowed_redirect_urls)
 | 
					                        "Failed to parse regular expression",
 | 
				
			||||||
        except RegexError as exc:
 | 
					                        exc=exc,
 | 
				
			||||||
            LOGGER.info("Failed to parse regular expression, checking directly", exc=exc)
 | 
					                        url=allowed.url,
 | 
				
			||||||
            if not any(x == self.redirect_uri for x in allowed_redirect_urls):
 | 
					                        provider=self.provider,
 | 
				
			||||||
                LOGGER.warning(
 | 
					                    )
 | 
				
			||||||
                    "Invalid redirect uri (strict comparison)",
 | 
					        if not match_found:
 | 
				
			||||||
                    redirect_uri_given=self.redirect_uri,
 | 
					            raise RedirectUriError(self.redirect_uri, allowed_redirect_urls)
 | 
				
			||||||
                    redirect_uri_expected=allowed_redirect_urls,
 | 
					 | 
				
			||||||
                )
 | 
					 | 
				
			||||||
                raise RedirectUriError(self.redirect_uri, allowed_redirect_urls) from None
 | 
					 | 
				
			||||||
        # Check against forbidden schemes
 | 
					        # Check against forbidden schemes
 | 
				
			||||||
        if urlparse(self.redirect_uri).scheme in FORBIDDEN_URI_SCHEMES:
 | 
					        if urlparse(self.redirect_uri).scheme in FORBIDDEN_URI_SCHEMES:
 | 
				
			||||||
            raise RedirectUriError(self.redirect_uri, allowed_redirect_urls)
 | 
					            raise RedirectUriError(self.redirect_uri, allowed_redirect_urls)
 | 
				
			||||||
@ -318,7 +318,9 @@ class OAuthAuthorizationParams:
 | 
				
			|||||||
            expires=now + timedelta_from_string(self.provider.access_code_validity),
 | 
					            expires=now + timedelta_from_string(self.provider.access_code_validity),
 | 
				
			||||||
            scope=self.scope,
 | 
					            scope=self.scope,
 | 
				
			||||||
            nonce=self.nonce,
 | 
					            nonce=self.nonce,
 | 
				
			||||||
            session_id=sha256(request.session.session_key.encode("ascii")).hexdigest(),
 | 
					            session=AuthenticatedSession.objects.filter(
 | 
				
			||||||
 | 
					                session_key=request.session.session_key
 | 
				
			||||||
 | 
					            ).first(),
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if self.code_challenge and self.code_challenge_method:
 | 
					        if self.code_challenge and self.code_challenge_method:
 | 
				
			||||||
@ -610,7 +612,9 @@ class OAuthFulfillmentStage(StageView):
 | 
				
			|||||||
            expires=access_token_expiry,
 | 
					            expires=access_token_expiry,
 | 
				
			||||||
            provider=self.provider,
 | 
					            provider=self.provider,
 | 
				
			||||||
            auth_time=auth_event.created if auth_event else now,
 | 
					            auth_time=auth_event.created if auth_event else now,
 | 
				
			||||||
            session_id=sha256(self.request.session.session_key.encode("ascii")).hexdigest(),
 | 
					            session=AuthenticatedSession.objects.filter(
 | 
				
			||||||
 | 
					                session_key=self.request.session.session_key
 | 
				
			||||||
 | 
					            ).first(),
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        id_token = IDToken.new(self.provider, token, self.request)
 | 
					        id_token = IDToken.new(self.provider, token, self.request)
 | 
				
			||||||
 | 
				
			|||||||
@ -46,10 +46,10 @@ class TokenIntrospectionParams:
 | 
				
			|||||||
        if not provider:
 | 
					        if not provider:
 | 
				
			||||||
            raise TokenIntrospectionError
 | 
					            raise TokenIntrospectionError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        access_token = AccessToken.objects.filter(token=raw_token).first()
 | 
					        access_token = AccessToken.objects.filter(token=raw_token, provider=provider).first()
 | 
				
			||||||
        if access_token:
 | 
					        if access_token:
 | 
				
			||||||
            return TokenIntrospectionParams(access_token, provider)
 | 
					            return TokenIntrospectionParams(access_token, provider)
 | 
				
			||||||
        refresh_token = RefreshToken.objects.filter(token=raw_token).first()
 | 
					        refresh_token = RefreshToken.objects.filter(token=raw_token, provider=provider).first()
 | 
				
			||||||
        if refresh_token:
 | 
					        if refresh_token:
 | 
				
			||||||
            return TokenIntrospectionParams(refresh_token, provider)
 | 
					            return TokenIntrospectionParams(refresh_token, provider)
 | 
				
			||||||
        LOGGER.debug("Token does not exist", token=raw_token)
 | 
					        LOGGER.debug("Token does not exist", token=raw_token)
 | 
				
			||||||
 | 
				
			|||||||
@ -158,5 +158,5 @@ class ProviderInfoView(View):
 | 
				
			|||||||
            OAuth2Provider, pk=application.provider_id
 | 
					            OAuth2Provider, pk=application.provider_id
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        response = super().dispatch(request, *args, **kwargs)
 | 
					        response = super().dispatch(request, *args, **kwargs)
 | 
				
			||||||
        cors_allow(request, response, *self.provider.redirect_uris.split("\n"))
 | 
					        cors_allow(request, response, *[x.url for x in self.provider.redirect_uris])
 | 
				
			||||||
        return response
 | 
					        return response
 | 
				
			||||||
 | 
				
			|||||||
@ -58,7 +58,9 @@ from authentik.providers.oauth2.models import (
 | 
				
			|||||||
    ClientTypes,
 | 
					    ClientTypes,
 | 
				
			||||||
    DeviceToken,
 | 
					    DeviceToken,
 | 
				
			||||||
    OAuth2Provider,
 | 
					    OAuth2Provider,
 | 
				
			||||||
 | 
					    RedirectURIMatchingMode,
 | 
				
			||||||
    RefreshToken,
 | 
					    RefreshToken,
 | 
				
			||||||
 | 
					    ScopeMapping,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from authentik.providers.oauth2.utils import TokenResponse, cors_allow, extract_client_auth
 | 
					from authentik.providers.oauth2.utils import TokenResponse, cors_allow, extract_client_auth
 | 
				
			||||||
from authentik.providers.oauth2.views.authorize import FORBIDDEN_URI_SCHEMES
 | 
					from authentik.providers.oauth2.views.authorize import FORBIDDEN_URI_SCHEMES
 | 
				
			||||||
@ -77,7 +79,7 @@ class TokenParams:
 | 
				
			|||||||
    redirect_uri: str
 | 
					    redirect_uri: str
 | 
				
			||||||
    grant_type: str
 | 
					    grant_type: str
 | 
				
			||||||
    state: str
 | 
					    state: str
 | 
				
			||||||
    scope: list[str]
 | 
					    scope: set[str]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    provider: OAuth2Provider
 | 
					    provider: OAuth2Provider
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -112,11 +114,26 @@ class TokenParams:
 | 
				
			|||||||
            redirect_uri=request.POST.get("redirect_uri", ""),
 | 
					            redirect_uri=request.POST.get("redirect_uri", ""),
 | 
				
			||||||
            grant_type=request.POST.get("grant_type", ""),
 | 
					            grant_type=request.POST.get("grant_type", ""),
 | 
				
			||||||
            state=request.POST.get("state", ""),
 | 
					            state=request.POST.get("state", ""),
 | 
				
			||||||
            scope=request.POST.get("scope", "").split(),
 | 
					            scope=set(request.POST.get("scope", "").split()),
 | 
				
			||||||
            # PKCE parameter.
 | 
					            # PKCE parameter.
 | 
				
			||||||
            code_verifier=request.POST.get("code_verifier"),
 | 
					            code_verifier=request.POST.get("code_verifier"),
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __check_scopes(self):
 | 
				
			||||||
 | 
					        allowed_scope_names = set(
 | 
				
			||||||
 | 
					            ScopeMapping.objects.filter(provider__in=[self.provider]).values_list(
 | 
				
			||||||
 | 
					                "scope_name", flat=True
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        scopes_to_check = self.scope
 | 
				
			||||||
 | 
					        if not scopes_to_check.issubset(allowed_scope_names):
 | 
				
			||||||
 | 
					            LOGGER.info(
 | 
				
			||||||
 | 
					                "Application requested scopes not configured, setting to overlap",
 | 
				
			||||||
 | 
					                scope_allowed=allowed_scope_names,
 | 
				
			||||||
 | 
					                scope_given=self.scope,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            self.scope = self.scope.intersection(allowed_scope_names)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __check_policy_access(self, app: Application, request: HttpRequest, **kwargs):
 | 
					    def __check_policy_access(self, app: Application, request: HttpRequest, **kwargs):
 | 
				
			||||||
        with start_span(
 | 
					        with start_span(
 | 
				
			||||||
            op="authentik.providers.oauth2.token.policy",
 | 
					            op="authentik.providers.oauth2.token.policy",
 | 
				
			||||||
@ -149,7 +166,7 @@ class TokenParams:
 | 
				
			|||||||
                    client_id=self.provider.client_id,
 | 
					                    client_id=self.provider.client_id,
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
                raise TokenError("invalid_client")
 | 
					                raise TokenError("invalid_client")
 | 
				
			||||||
 | 
					        self.__check_scopes()
 | 
				
			||||||
        if self.grant_type == GRANT_TYPE_AUTHORIZATION_CODE:
 | 
					        if self.grant_type == GRANT_TYPE_AUTHORIZATION_CODE:
 | 
				
			||||||
            with start_span(
 | 
					            with start_span(
 | 
				
			||||||
                op="authentik.providers.oauth2.post.parse.code",
 | 
					                op="authentik.providers.oauth2.post.parse.code",
 | 
				
			||||||
@ -179,42 +196,7 @@ class TokenParams:
 | 
				
			|||||||
            LOGGER.warning("Missing authorization code")
 | 
					            LOGGER.warning("Missing authorization code")
 | 
				
			||||||
            raise TokenError("invalid_grant")
 | 
					            raise TokenError("invalid_grant")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        allowed_redirect_urls = self.provider.redirect_uris.split()
 | 
					        self.__check_redirect_uri(request)
 | 
				
			||||||
        # At this point, no provider should have a blank redirect_uri, in case they do
 | 
					 | 
				
			||||||
        # this will check an empty array and raise an error
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            if not any(fullmatch(x, self.redirect_uri) for x in allowed_redirect_urls):
 | 
					 | 
				
			||||||
                LOGGER.warning(
 | 
					 | 
				
			||||||
                    "Invalid redirect uri (regex comparison)",
 | 
					 | 
				
			||||||
                    redirect_uri=self.redirect_uri,
 | 
					 | 
				
			||||||
                    expected=allowed_redirect_urls,
 | 
					 | 
				
			||||||
                )
 | 
					 | 
				
			||||||
                Event.new(
 | 
					 | 
				
			||||||
                    EventAction.CONFIGURATION_ERROR,
 | 
					 | 
				
			||||||
                    message="Invalid redirect URI used by provider",
 | 
					 | 
				
			||||||
                    provider=self.provider,
 | 
					 | 
				
			||||||
                    redirect_uri=self.redirect_uri,
 | 
					 | 
				
			||||||
                    expected=allowed_redirect_urls,
 | 
					 | 
				
			||||||
                ).from_http(request)
 | 
					 | 
				
			||||||
                raise TokenError("invalid_client")
 | 
					 | 
				
			||||||
        except RegexError as exc:
 | 
					 | 
				
			||||||
            LOGGER.info("Failed to parse regular expression, checking directly", exc=exc)
 | 
					 | 
				
			||||||
            if not any(x == self.redirect_uri for x in allowed_redirect_urls):
 | 
					 | 
				
			||||||
                LOGGER.warning(
 | 
					 | 
				
			||||||
                    "Invalid redirect uri (strict comparison)",
 | 
					 | 
				
			||||||
                    redirect_uri=self.redirect_uri,
 | 
					 | 
				
			||||||
                    expected=allowed_redirect_urls,
 | 
					 | 
				
			||||||
                )
 | 
					 | 
				
			||||||
                Event.new(
 | 
					 | 
				
			||||||
                    EventAction.CONFIGURATION_ERROR,
 | 
					 | 
				
			||||||
                    message="Invalid redirect_uri configured",
 | 
					 | 
				
			||||||
                    provider=self.provider,
 | 
					 | 
				
			||||||
                ).from_http(request)
 | 
					 | 
				
			||||||
                raise TokenError("invalid_client") from None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Check against forbidden schemes
 | 
					 | 
				
			||||||
        if urlparse(self.redirect_uri).scheme in FORBIDDEN_URI_SCHEMES:
 | 
					 | 
				
			||||||
            raise TokenError("invalid_request")
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.authorization_code = AuthorizationCode.objects.filter(code=raw_code).first()
 | 
					        self.authorization_code = AuthorizationCode.objects.filter(code=raw_code).first()
 | 
				
			||||||
        if not self.authorization_code:
 | 
					        if not self.authorization_code:
 | 
				
			||||||
@ -254,6 +236,48 @@ class TokenParams:
 | 
				
			|||||||
        if not self.authorization_code.code_challenge and self.code_verifier:
 | 
					        if not self.authorization_code.code_challenge and self.code_verifier:
 | 
				
			||||||
            raise TokenError("invalid_grant")
 | 
					            raise TokenError("invalid_grant")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __check_redirect_uri(self, request: HttpRequest):
 | 
				
			||||||
 | 
					        allowed_redirect_urls = self.provider.redirect_uris
 | 
				
			||||||
 | 
					        # At this point, no provider should have a blank redirect_uri, in case they do
 | 
				
			||||||
 | 
					        # this will check an empty array and raise an error
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        match_found = False
 | 
				
			||||||
 | 
					        for allowed in allowed_redirect_urls:
 | 
				
			||||||
 | 
					            if allowed.matching_mode == RedirectURIMatchingMode.STRICT:
 | 
				
			||||||
 | 
					                if self.redirect_uri == allowed.url:
 | 
				
			||||||
 | 
					                    match_found = True
 | 
				
			||||||
 | 
					                    break
 | 
				
			||||||
 | 
					            if allowed.matching_mode == RedirectURIMatchingMode.REGEX:
 | 
				
			||||||
 | 
					                try:
 | 
				
			||||||
 | 
					                    if fullmatch(allowed.url, self.redirect_uri):
 | 
				
			||||||
 | 
					                        match_found = True
 | 
				
			||||||
 | 
					                        break
 | 
				
			||||||
 | 
					                except RegexError as exc:
 | 
				
			||||||
 | 
					                    LOGGER.warning(
 | 
				
			||||||
 | 
					                        "Failed to parse regular expression",
 | 
				
			||||||
 | 
					                        exc=exc,
 | 
				
			||||||
 | 
					                        url=allowed.url,
 | 
				
			||||||
 | 
					                        provider=self.provider,
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                    Event.new(
 | 
				
			||||||
 | 
					                        EventAction.CONFIGURATION_ERROR,
 | 
				
			||||||
 | 
					                        message="Invalid redirect_uri configured",
 | 
				
			||||||
 | 
					                        provider=self.provider,
 | 
				
			||||||
 | 
					                    ).from_http(request)
 | 
				
			||||||
 | 
					        if not match_found:
 | 
				
			||||||
 | 
					            Event.new(
 | 
				
			||||||
 | 
					                EventAction.CONFIGURATION_ERROR,
 | 
				
			||||||
 | 
					                message="Invalid redirect URI used by provider",
 | 
				
			||||||
 | 
					                provider=self.provider,
 | 
				
			||||||
 | 
					                redirect_uri=self.redirect_uri,
 | 
				
			||||||
 | 
					                expected=allowed_redirect_urls,
 | 
				
			||||||
 | 
					            ).from_http(request)
 | 
				
			||||||
 | 
					            raise TokenError("invalid_client")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Check against forbidden schemes
 | 
				
			||||||
 | 
					        if urlparse(self.redirect_uri).scheme in FORBIDDEN_URI_SCHEMES:
 | 
				
			||||||
 | 
					            raise TokenError("invalid_request")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __post_init_refresh(self, raw_token: str, request: HttpRequest):
 | 
					    def __post_init_refresh(self, raw_token: str, request: HttpRequest):
 | 
				
			||||||
        if not raw_token:
 | 
					        if not raw_token:
 | 
				
			||||||
            LOGGER.warning("Missing refresh token")
 | 
					            LOGGER.warning("Missing refresh token")
 | 
				
			||||||
@ -433,20 +457,20 @@ class TokenParams:
 | 
				
			|||||||
        app = Application.objects.filter(provider=self.provider).first()
 | 
					        app = Application.objects.filter(provider=self.provider).first()
 | 
				
			||||||
        if not app or not app.provider:
 | 
					        if not app or not app.provider:
 | 
				
			||||||
            raise TokenError("invalid_grant")
 | 
					            raise TokenError("invalid_grant")
 | 
				
			||||||
        self.user, _ = User.objects.update_or_create(
 | 
					        with audit_ignore():
 | 
				
			||||||
            # trim username to ensure the entire username is max 150 chars
 | 
					            self.user, _ = User.objects.update_or_create(
 | 
				
			||||||
            # (22 chars being the length of the "template")
 | 
					                # trim username to ensure the entire username is max 150 chars
 | 
				
			||||||
            username=f"ak-{self.provider.name[:150-22]}-client_credentials",
 | 
					                # (22 chars being the length of the "template")
 | 
				
			||||||
            defaults={
 | 
					                username=f"ak-{self.provider.name[:150-22]}-client_credentials",
 | 
				
			||||||
                "attributes": {
 | 
					                defaults={
 | 
				
			||||||
                    USER_ATTRIBUTE_GENERATED: True,
 | 
					                    "last_login": timezone.now(),
 | 
				
			||||||
 | 
					                    "name": f"Autogenerated user from application {app.name} (client credentials)",
 | 
				
			||||||
 | 
					                    "path": f"{USER_PATH_SYSTEM_PREFIX}/apps/{app.slug}",
 | 
				
			||||||
 | 
					                    "type": UserTypes.SERVICE_ACCOUNT,
 | 
				
			||||||
                },
 | 
					                },
 | 
				
			||||||
                "last_login": timezone.now(),
 | 
					            )
 | 
				
			||||||
                "name": f"Autogenerated user from application {app.name} (client credentials)",
 | 
					            self.user.attributes[USER_ATTRIBUTE_GENERATED] = True
 | 
				
			||||||
                "path": f"{USER_PATH_SYSTEM_PREFIX}/apps/{app.slug}",
 | 
					            self.user.save()
 | 
				
			||||||
                "type": UserTypes.SERVICE_ACCOUNT,
 | 
					 | 
				
			||||||
            },
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.__check_policy_access(app, request)
 | 
					        self.__check_policy_access(app, request)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        Event.new(
 | 
					        Event.new(
 | 
				
			||||||
@ -470,9 +494,6 @@ class TokenParams:
 | 
				
			|||||||
            self.user, created = User.objects.update_or_create(
 | 
					            self.user, created = User.objects.update_or_create(
 | 
				
			||||||
                username=f"{self.provider.name}-{token.get('sub')}",
 | 
					                username=f"{self.provider.name}-{token.get('sub')}",
 | 
				
			||||||
                defaults={
 | 
					                defaults={
 | 
				
			||||||
                    "attributes": {
 | 
					 | 
				
			||||||
                        USER_ATTRIBUTE_GENERATED: True,
 | 
					 | 
				
			||||||
                    },
 | 
					 | 
				
			||||||
                    "last_login": timezone.now(),
 | 
					                    "last_login": timezone.now(),
 | 
				
			||||||
                    "name": (
 | 
					                    "name": (
 | 
				
			||||||
                        f"Autogenerated user from application {app.name} (client credentials JWT)"
 | 
					                        f"Autogenerated user from application {app.name} (client credentials JWT)"
 | 
				
			||||||
@ -481,6 +502,8 @@ class TokenParams:
 | 
				
			|||||||
                    "type": UserTypes.SERVICE_ACCOUNT,
 | 
					                    "type": UserTypes.SERVICE_ACCOUNT,
 | 
				
			||||||
                },
 | 
					                },
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					            self.user.attributes[USER_ATTRIBUTE_GENERATED] = True
 | 
				
			||||||
 | 
					            self.user.save()
 | 
				
			||||||
            exp = token.get("exp")
 | 
					            exp = token.get("exp")
 | 
				
			||||||
            if created and exp:
 | 
					            if created and exp:
 | 
				
			||||||
                self.user.attributes[USER_ATTRIBUTE_EXPIRES] = exp
 | 
					                self.user.attributes[USER_ATTRIBUTE_EXPIRES] = exp
 | 
				
			||||||
@ -498,7 +521,7 @@ class TokenView(View):
 | 
				
			|||||||
        response = super().dispatch(request, *args, **kwargs)
 | 
					        response = super().dispatch(request, *args, **kwargs)
 | 
				
			||||||
        allowed_origins = []
 | 
					        allowed_origins = []
 | 
				
			||||||
        if self.provider:
 | 
					        if self.provider:
 | 
				
			||||||
            allowed_origins = self.provider.redirect_uris.split("\n")
 | 
					            allowed_origins = [x.url for x in self.provider.redirect_uris]
 | 
				
			||||||
        cors_allow(self.request, response, *allowed_origins)
 | 
					        cors_allow(self.request, response, *allowed_origins)
 | 
				
			||||||
        return response
 | 
					        return response
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -551,7 +574,7 @@ class TokenView(View):
 | 
				
			|||||||
            # Keep same scopes as previous token
 | 
					            # Keep same scopes as previous token
 | 
				
			||||||
            scope=self.params.authorization_code.scope,
 | 
					            scope=self.params.authorization_code.scope,
 | 
				
			||||||
            auth_time=self.params.authorization_code.auth_time,
 | 
					            auth_time=self.params.authorization_code.auth_time,
 | 
				
			||||||
            session_id=self.params.authorization_code.session_id,
 | 
					            session=self.params.authorization_code.session,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        access_id_token = IDToken.new(
 | 
					        access_id_token = IDToken.new(
 | 
				
			||||||
            self.provider,
 | 
					            self.provider,
 | 
				
			||||||
@ -579,7 +602,7 @@ class TokenView(View):
 | 
				
			|||||||
                expires=refresh_token_expiry,
 | 
					                expires=refresh_token_expiry,
 | 
				
			||||||
                provider=self.provider,
 | 
					                provider=self.provider,
 | 
				
			||||||
                auth_time=self.params.authorization_code.auth_time,
 | 
					                auth_time=self.params.authorization_code.auth_time,
 | 
				
			||||||
                session_id=self.params.authorization_code.session_id,
 | 
					                session=self.params.authorization_code.session,
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            id_token = IDToken.new(
 | 
					            id_token = IDToken.new(
 | 
				
			||||||
                self.provider,
 | 
					                self.provider,
 | 
				
			||||||
@ -612,7 +635,7 @@ class TokenView(View):
 | 
				
			|||||||
            # Keep same scopes as previous token
 | 
					            # Keep same scopes as previous token
 | 
				
			||||||
            scope=self.params.refresh_token.scope,
 | 
					            scope=self.params.refresh_token.scope,
 | 
				
			||||||
            auth_time=self.params.refresh_token.auth_time,
 | 
					            auth_time=self.params.refresh_token.auth_time,
 | 
				
			||||||
            session_id=self.params.refresh_token.session_id,
 | 
					            session=self.params.refresh_token.session,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        access_token.id_token = IDToken.new(
 | 
					        access_token.id_token = IDToken.new(
 | 
				
			||||||
            self.provider,
 | 
					            self.provider,
 | 
				
			||||||
@ -628,7 +651,7 @@ class TokenView(View):
 | 
				
			|||||||
            expires=refresh_token_expiry,
 | 
					            expires=refresh_token_expiry,
 | 
				
			||||||
            provider=self.provider,
 | 
					            provider=self.provider,
 | 
				
			||||||
            auth_time=self.params.refresh_token.auth_time,
 | 
					            auth_time=self.params.refresh_token.auth_time,
 | 
				
			||||||
            session_id=self.params.refresh_token.session_id,
 | 
					            session=self.params.refresh_token.session,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        id_token = IDToken.new(
 | 
					        id_token = IDToken.new(
 | 
				
			||||||
            self.provider,
 | 
					            self.provider,
 | 
				
			||||||
@ -686,13 +709,14 @@ class TokenView(View):
 | 
				
			|||||||
            raise DeviceCodeError("authorization_pending")
 | 
					            raise DeviceCodeError("authorization_pending")
 | 
				
			||||||
        now = timezone.now()
 | 
					        now = timezone.now()
 | 
				
			||||||
        access_token_expiry = now + timedelta_from_string(self.provider.access_token_validity)
 | 
					        access_token_expiry = now + timedelta_from_string(self.provider.access_token_validity)
 | 
				
			||||||
        auth_event = get_login_event(self.request)
 | 
					        auth_event = get_login_event(self.params.device_code.session)
 | 
				
			||||||
        access_token = AccessToken(
 | 
					        access_token = AccessToken(
 | 
				
			||||||
            provider=self.provider,
 | 
					            provider=self.provider,
 | 
				
			||||||
            user=self.params.device_code.user,
 | 
					            user=self.params.device_code.user,
 | 
				
			||||||
            expires=access_token_expiry,
 | 
					            expires=access_token_expiry,
 | 
				
			||||||
            scope=self.params.device_code.scope,
 | 
					            scope=self.params.device_code.scope,
 | 
				
			||||||
            auth_time=auth_event.created if auth_event else now,
 | 
					            auth_time=auth_event.created if auth_event else now,
 | 
				
			||||||
 | 
					            session=self.params.device_code.session,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        access_token.id_token = IDToken.new(
 | 
					        access_token.id_token = IDToken.new(
 | 
				
			||||||
            self.provider,
 | 
					            self.provider,
 | 
				
			||||||
@ -710,7 +734,7 @@ class TokenView(View):
 | 
				
			|||||||
            "id_token": access_token.id_token.to_jwt(self.provider),
 | 
					            "id_token": access_token.id_token.to_jwt(self.provider),
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if SCOPE_OFFLINE_ACCESS in self.params.scope:
 | 
					        if SCOPE_OFFLINE_ACCESS in self.params.device_code.scope:
 | 
				
			||||||
            refresh_token_expiry = now + timedelta_from_string(self.provider.refresh_token_validity)
 | 
					            refresh_token_expiry = now + timedelta_from_string(self.provider.refresh_token_validity)
 | 
				
			||||||
            refresh_token = RefreshToken(
 | 
					            refresh_token = RefreshToken(
 | 
				
			||||||
                user=self.params.device_code.user,
 | 
					                user=self.params.device_code.user,
 | 
				
			||||||
 | 
				
			|||||||
@ -108,7 +108,7 @@ class UserInfoView(View):
 | 
				
			|||||||
        response = super().dispatch(request, *args, **kwargs)
 | 
					        response = super().dispatch(request, *args, **kwargs)
 | 
				
			||||||
        allowed_origins = []
 | 
					        allowed_origins = []
 | 
				
			||||||
        if self.token:
 | 
					        if self.token:
 | 
				
			||||||
            allowed_origins = self.token.provider.redirect_uris.split("\n")
 | 
					            allowed_origins = [x.url for x in self.token.provider.redirect_uris]
 | 
				
			||||||
        cors_allow(self.request, response, *allowed_origins)
 | 
					        cors_allow(self.request, response, *allowed_origins)
 | 
				
			||||||
        return response
 | 
					        return response
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -13,6 +13,7 @@ from authentik.core.api.providers import ProviderSerializer
 | 
				
			|||||||
from authentik.core.api.used_by import UsedByMixin
 | 
					from authentik.core.api.used_by import UsedByMixin
 | 
				
			||||||
from authentik.core.api.utils import ModelSerializer, PassiveSerializer
 | 
					from authentik.core.api.utils import ModelSerializer, PassiveSerializer
 | 
				
			||||||
from authentik.lib.utils.time import timedelta_from_string
 | 
					from authentik.lib.utils.time import timedelta_from_string
 | 
				
			||||||
 | 
					from authentik.providers.oauth2.api.providers import RedirectURISerializer
 | 
				
			||||||
from authentik.providers.oauth2.models import ScopeMapping
 | 
					from authentik.providers.oauth2.models import ScopeMapping
 | 
				
			||||||
from authentik.providers.oauth2.views.provider import ProviderInfoView
 | 
					from authentik.providers.oauth2.views.provider import ProviderInfoView
 | 
				
			||||||
from authentik.providers.proxy.models import ProxyMode, ProxyProvider
 | 
					from authentik.providers.proxy.models import ProxyMode, ProxyProvider
 | 
				
			||||||
@ -39,7 +40,7 @@ class ProxyProviderSerializer(ProviderSerializer):
 | 
				
			|||||||
    """ProxyProvider Serializer"""
 | 
					    """ProxyProvider Serializer"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    client_id = CharField(read_only=True)
 | 
					    client_id = CharField(read_only=True)
 | 
				
			||||||
    redirect_uris = CharField(read_only=True)
 | 
					    redirect_uris = RedirectURISerializer(many=True, read_only=True, source="_redirect_uris")
 | 
				
			||||||
    outpost_set = ListField(child=CharField(), read_only=True, source="outpost_set.all")
 | 
					    outpost_set = ListField(child=CharField(), read_only=True, source="outpost_set.all")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def validate_basic_auth_enabled(self, value: bool) -> bool:
 | 
					    def validate_basic_auth_enabled(self, value: bool) -> bool:
 | 
				
			||||||
@ -121,7 +122,6 @@ class ProxyProviderViewSet(UsedByMixin, ModelViewSet):
 | 
				
			|||||||
        "basic_auth_password_attribute": ["iexact"],
 | 
					        "basic_auth_password_attribute": ["iexact"],
 | 
				
			||||||
        "basic_auth_user_attribute": ["iexact"],
 | 
					        "basic_auth_user_attribute": ["iexact"],
 | 
				
			||||||
        "mode": ["iexact"],
 | 
					        "mode": ["iexact"],
 | 
				
			||||||
        "redirect_uris": ["iexact"],
 | 
					 | 
				
			||||||
        "cookie_domain": ["iexact"],
 | 
					        "cookie_domain": ["iexact"],
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    search_fields = ["name"]
 | 
					    search_fields = ["name"]
 | 
				
			||||||
 | 
				
			|||||||
@ -28,7 +28,7 @@ class ProxyDockerController(DockerController):
 | 
				
			|||||||
        labels = super()._get_labels()
 | 
					        labels = super()._get_labels()
 | 
				
			||||||
        labels["traefik.enable"] = "true"
 | 
					        labels["traefik.enable"] = "true"
 | 
				
			||||||
        labels[f"traefik.http.routers.{traefik_name}-router.rule"] = (
 | 
					        labels[f"traefik.http.routers.{traefik_name}-router.rule"] = (
 | 
				
			||||||
            f"({' || '.join([f'Host(`{host}`)' for host in hosts])})"
 | 
					            f"({' || '.join([f'Host({host})' for host in hosts])})"
 | 
				
			||||||
            f" && PathPrefix(`/outpost.goauthentik.io`)"
 | 
					            f" && PathPrefix(`/outpost.goauthentik.io`)"
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        labels[f"traefik.http.routers.{traefik_name}-router.tls"] = "true"
 | 
					        labels[f"traefik.http.routers.{traefik_name}-router.tls"] = "true"
 | 
				
			||||||
 | 
				
			|||||||
@ -13,7 +13,13 @@ from rest_framework.serializers import Serializer
 | 
				
			|||||||
from authentik.crypto.models import CertificateKeyPair
 | 
					from authentik.crypto.models import CertificateKeyPair
 | 
				
			||||||
from authentik.lib.models import DomainlessURLValidator
 | 
					from authentik.lib.models import DomainlessURLValidator
 | 
				
			||||||
from authentik.outposts.models import OutpostModel
 | 
					from authentik.outposts.models import OutpostModel
 | 
				
			||||||
from authentik.providers.oauth2.models import ClientTypes, OAuth2Provider, ScopeMapping
 | 
					from authentik.providers.oauth2.models import (
 | 
				
			||||||
 | 
					    ClientTypes,
 | 
				
			||||||
 | 
					    OAuth2Provider,
 | 
				
			||||||
 | 
					    RedirectURI,
 | 
				
			||||||
 | 
					    RedirectURIMatchingMode,
 | 
				
			||||||
 | 
					    ScopeMapping,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
SCOPE_AK_PROXY = "ak_proxy"
 | 
					SCOPE_AK_PROXY = "ak_proxy"
 | 
				
			||||||
OUTPOST_CALLBACK_SIGNATURE = "X-authentik-auth-callback"
 | 
					OUTPOST_CALLBACK_SIGNATURE = "X-authentik-auth-callback"
 | 
				
			||||||
@ -24,14 +30,14 @@ def get_cookie_secret():
 | 
				
			|||||||
    return "".join(SystemRandom().choice(string.ascii_uppercase + string.digits) for _ in range(32))
 | 
					    return "".join(SystemRandom().choice(string.ascii_uppercase + string.digits) for _ in range(32))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def _get_callback_url(uri: str) -> str:
 | 
					def _get_callback_url(uri: str) -> list[RedirectURI]:
 | 
				
			||||||
    return "\n".join(
 | 
					    return [
 | 
				
			||||||
        [
 | 
					        RedirectURI(
 | 
				
			||||||
            urljoin(uri, "outpost.goauthentik.io/callback")
 | 
					            RedirectURIMatchingMode.STRICT,
 | 
				
			||||||
            + f"\\?{OUTPOST_CALLBACK_SIGNATURE}=true",
 | 
					            urljoin(uri, "outpost.goauthentik.io/callback") + f"?{OUTPOST_CALLBACK_SIGNATURE}=true",
 | 
				
			||||||
            uri + f"\\?{OUTPOST_CALLBACK_SIGNATURE}=true",
 | 
					        ),
 | 
				
			||||||
        ]
 | 
					        RedirectURI(RedirectURIMatchingMode.STRICT, uri + f"?{OUTPOST_CALLBACK_SIGNATURE}=true"),
 | 
				
			||||||
    )
 | 
					    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ProxyMode(models.TextChoices):
 | 
					class ProxyMode(models.TextChoices):
 | 
				
			||||||
 | 
				
			|||||||
@ -1,13 +1,12 @@
 | 
				
			|||||||
"""proxy provider tasks"""
 | 
					"""proxy provider tasks"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from hashlib import sha256
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from asgiref.sync import async_to_sync
 | 
					from asgiref.sync import async_to_sync
 | 
				
			||||||
from channels.layers import get_channel_layer
 | 
					from channels.layers import get_channel_layer
 | 
				
			||||||
from django.db import DatabaseError, InternalError, ProgrammingError
 | 
					from django.db import DatabaseError, InternalError, ProgrammingError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.outposts.consumer import OUTPOST_GROUP
 | 
					from authentik.outposts.consumer import OUTPOST_GROUP
 | 
				
			||||||
from authentik.outposts.models import Outpost, OutpostType
 | 
					from authentik.outposts.models import Outpost, OutpostType
 | 
				
			||||||
 | 
					from authentik.providers.oauth2.id_token import hash_session_key
 | 
				
			||||||
from authentik.providers.proxy.models import ProxyProvider
 | 
					from authentik.providers.proxy.models import ProxyProvider
 | 
				
			||||||
from authentik.root.celery import CELERY_APP
 | 
					from authentik.root.celery import CELERY_APP
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -26,7 +25,7 @@ def proxy_set_defaults():
 | 
				
			|||||||
def proxy_on_logout(session_id: str):
 | 
					def proxy_on_logout(session_id: str):
 | 
				
			||||||
    """Update outpost instances connected to a single outpost"""
 | 
					    """Update outpost instances connected to a single outpost"""
 | 
				
			||||||
    layer = get_channel_layer()
 | 
					    layer = get_channel_layer()
 | 
				
			||||||
    hashed_session_id = sha256(session_id.encode("ascii")).hexdigest()
 | 
					    hashed_session_id = hash_session_key(session_id)
 | 
				
			||||||
    for outpost in Outpost.objects.filter(type=OutpostType.PROXY):
 | 
					    for outpost in Outpost.objects.filter(type=OutpostType.PROXY):
 | 
				
			||||||
        group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
 | 
					        group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
 | 
				
			||||||
        async_to_sync(layer.group_send)(
 | 
					        async_to_sync(layer.group_send)(
 | 
				
			||||||
 | 
				
			|||||||
@ -164,7 +164,7 @@ class SAMLProvider(Provider):
 | 
				
			|||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    sign_assertion = models.BooleanField(default=True)
 | 
					    sign_assertion = models.BooleanField(default=True)
 | 
				
			||||||
    sign_response = models.BooleanField(default=True)
 | 
					    sign_response = models.BooleanField(default=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def launch_url(self) -> str | None:
 | 
					    def launch_url(self) -> str | None:
 | 
				
			||||||
 | 
				
			|||||||
@ -50,6 +50,7 @@ class AssertionProcessor:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    _issue_instant: str
 | 
					    _issue_instant: str
 | 
				
			||||||
    _assertion_id: str
 | 
					    _assertion_id: str
 | 
				
			||||||
 | 
					    _response_id: str
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    _valid_not_before: str
 | 
					    _valid_not_before: str
 | 
				
			||||||
    _session_not_on_or_after: str
 | 
					    _session_not_on_or_after: str
 | 
				
			||||||
@ -62,6 +63,7 @@ class AssertionProcessor:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        self._issue_instant = get_time_string()
 | 
					        self._issue_instant = get_time_string()
 | 
				
			||||||
        self._assertion_id = get_random_id()
 | 
					        self._assertion_id = get_random_id()
 | 
				
			||||||
 | 
					        self._response_id = get_random_id()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self._valid_not_before = get_time_string(
 | 
					        self._valid_not_before = get_time_string(
 | 
				
			||||||
            timedelta_from_string(self.provider.assertion_valid_not_before)
 | 
					            timedelta_from_string(self.provider.assertion_valid_not_before)
 | 
				
			||||||
@ -130,7 +132,9 @@ class AssertionProcessor:
 | 
				
			|||||||
        """Generate AuthnStatement with AuthnContext and ContextClassRef Elements."""
 | 
					        """Generate AuthnStatement with AuthnContext and ContextClassRef Elements."""
 | 
				
			||||||
        auth_n_statement = Element(f"{{{NS_SAML_ASSERTION}}}AuthnStatement")
 | 
					        auth_n_statement = Element(f"{{{NS_SAML_ASSERTION}}}AuthnStatement")
 | 
				
			||||||
        auth_n_statement.attrib["AuthnInstant"] = self._valid_not_before
 | 
					        auth_n_statement.attrib["AuthnInstant"] = self._valid_not_before
 | 
				
			||||||
        auth_n_statement.attrib["SessionIndex"] = self._assertion_id
 | 
					        auth_n_statement.attrib["SessionIndex"] = sha256(
 | 
				
			||||||
 | 
					            self.http_request.session.session_key.encode("ascii")
 | 
				
			||||||
 | 
					        ).hexdigest()
 | 
				
			||||||
        auth_n_statement.attrib["SessionNotOnOrAfter"] = self._session_not_on_or_after
 | 
					        auth_n_statement.attrib["SessionNotOnOrAfter"] = self._session_not_on_or_after
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        auth_n_context = SubElement(auth_n_statement, f"{{{NS_SAML_ASSERTION}}}AuthnContext")
 | 
					        auth_n_context = SubElement(auth_n_statement, f"{{{NS_SAML_ASSERTION}}}AuthnContext")
 | 
				
			||||||
@ -285,7 +289,7 @@ class AssertionProcessor:
 | 
				
			|||||||
        response.attrib["Version"] = "2.0"
 | 
					        response.attrib["Version"] = "2.0"
 | 
				
			||||||
        response.attrib["IssueInstant"] = self._issue_instant
 | 
					        response.attrib["IssueInstant"] = self._issue_instant
 | 
				
			||||||
        response.attrib["Destination"] = self.provider.acs_url
 | 
					        response.attrib["Destination"] = self.provider.acs_url
 | 
				
			||||||
        response.attrib["ID"] = get_random_id()
 | 
					        response.attrib["ID"] = self._response_id
 | 
				
			||||||
        if self.auth_n_request.id:
 | 
					        if self.auth_n_request.id:
 | 
				
			||||||
            response.attrib["InResponseTo"] = self.auth_n_request.id
 | 
					            response.attrib["InResponseTo"] = self.auth_n_request.id
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -308,7 +312,7 @@ class AssertionProcessor:
 | 
				
			|||||||
        ref = xmlsec.template.add_reference(
 | 
					        ref = xmlsec.template.add_reference(
 | 
				
			||||||
            signature_node,
 | 
					            signature_node,
 | 
				
			||||||
            digest_algorithm_transform,
 | 
					            digest_algorithm_transform,
 | 
				
			||||||
            uri="#" + self._assertion_id,
 | 
					            uri="#" + element.attrib["ID"],
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        xmlsec.template.add_transform(ref, xmlsec.constants.TransformEnveloped)
 | 
					        xmlsec.template.add_transform(ref, xmlsec.constants.TransformEnveloped)
 | 
				
			||||||
        xmlsec.template.add_transform(ref, xmlsec.constants.TransformExclC14N)
 | 
					        xmlsec.template.add_transform(ref, xmlsec.constants.TransformExclC14N)
 | 
				
			||||||
 | 
				
			|||||||
@ -180,6 +180,10 @@ class TestAuthNRequest(TestCase):
 | 
				
			|||||||
        # Now create a response and convert it to string (provider)
 | 
					        # Now create a response and convert it to string (provider)
 | 
				
			||||||
        response_proc = AssertionProcessor(self.provider, http_request, parsed_request)
 | 
					        response_proc = AssertionProcessor(self.provider, http_request, parsed_request)
 | 
				
			||||||
        response = response_proc.build_response()
 | 
					        response = response_proc.build_response()
 | 
				
			||||||
 | 
					        # Ensure both response and assertion ID are in the response twice (once as ID attribute,
 | 
				
			||||||
 | 
					        # once as ds:Reference URI)
 | 
				
			||||||
 | 
					        self.assertEqual(response.count(response_proc._assertion_id), 2)
 | 
				
			||||||
 | 
					        self.assertEqual(response.count(response_proc._response_id), 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Now parse the response (source)
 | 
					        # Now parse the response (source)
 | 
				
			||||||
        http_request.POST = QueryDict(mutable=True)
 | 
					        http_request.POST = QueryDict(mutable=True)
 | 
				
			||||||
 | 
				
			|||||||
@ -54,7 +54,11 @@ class TestServiceProviderMetadataParser(TestCase):
 | 
				
			|||||||
        request = self.factory.get("/")
 | 
					        request = self.factory.get("/")
 | 
				
			||||||
        metadata = lxml_from_string(MetadataProcessor(provider, request).build_entity_descriptor())
 | 
					        metadata = lxml_from_string(MetadataProcessor(provider, request).build_entity_descriptor())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        schema = etree.XMLSchema(etree.parse("schemas/saml-schema-metadata-2.0.xsd"))  # nosec
 | 
					        schema = etree.XMLSchema(
 | 
				
			||||||
 | 
					            etree.parse(
 | 
				
			||||||
 | 
					                source="schemas/saml-schema-metadata-2.0.xsd", parser=etree.XMLParser()
 | 
				
			||||||
 | 
					            )  # nosec
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        self.assertTrue(schema.validate(metadata))
 | 
					        self.assertTrue(schema.validate(metadata))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_schema_want_authn_requests_signed(self):
 | 
					    def test_schema_want_authn_requests_signed(self):
 | 
				
			||||||
 | 
				
			|||||||
@ -47,7 +47,9 @@ class TestSchema(TestCase):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        metadata = lxml_from_string(request)
 | 
					        metadata = lxml_from_string(request)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        schema = etree.XMLSchema(etree.parse("schemas/saml-schema-protocol-2.0.xsd"))  # nosec
 | 
					        schema = etree.XMLSchema(
 | 
				
			||||||
 | 
					            etree.parse("schemas/saml-schema-protocol-2.0.xsd", parser=etree.XMLParser())  # nosec
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        self.assertTrue(schema.validate(metadata))
 | 
					        self.assertTrue(schema.validate(metadata))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_response_schema(self):
 | 
					    def test_response_schema(self):
 | 
				
			||||||
@ -68,5 +70,7 @@ class TestSchema(TestCase):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        metadata = lxml_from_string(response)
 | 
					        metadata = lxml_from_string(response)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        schema = etree.XMLSchema(etree.parse("schemas/saml-schema-protocol-2.0.xsd"))  # nosec
 | 
					        schema = etree.XMLSchema(
 | 
				
			||||||
 | 
					            etree.parse("schemas/saml-schema-protocol-2.0.xsd", parser=etree.XMLParser())  # nosec
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        self.assertTrue(schema.validate(metadata))
 | 
					        self.assertTrue(schema.validate(metadata))
 | 
				
			||||||
 | 
				
			|||||||
@ -2,9 +2,10 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from itertools import batched
 | 
					from itertools import batched
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from django.db import transaction
 | 
				
			||||||
from pydantic import ValidationError
 | 
					from pydantic import ValidationError
 | 
				
			||||||
from pydanticscim.group import GroupMember
 | 
					from pydanticscim.group import GroupMember
 | 
				
			||||||
from pydanticscim.responses import PatchOp, PatchOperation
 | 
					from pydanticscim.responses import PatchOp
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.core.models import Group
 | 
					from authentik.core.models import Group
 | 
				
			||||||
from authentik.lib.sync.mapper import PropertyMappingManager
 | 
					from authentik.lib.sync.mapper import PropertyMappingManager
 | 
				
			||||||
@ -19,7 +20,7 @@ from authentik.providers.scim.clients.base import SCIMClient
 | 
				
			|||||||
from authentik.providers.scim.clients.exceptions import (
 | 
					from authentik.providers.scim.clients.exceptions import (
 | 
				
			||||||
    SCIMRequestException,
 | 
					    SCIMRequestException,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchRequest
 | 
					from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchOperation, PatchRequest
 | 
				
			||||||
from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema
 | 
					from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema
 | 
				
			||||||
from authentik.providers.scim.models import (
 | 
					from authentik.providers.scim.models import (
 | 
				
			||||||
    SCIMMapping,
 | 
					    SCIMMapping,
 | 
				
			||||||
@ -104,13 +105,47 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
 | 
				
			|||||||
            provider=self.provider, group=group, scim_id=scim_id
 | 
					            provider=self.provider, group=group, scim_id=scim_id
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        users = list(group.users.order_by("id").values_list("id", flat=True))
 | 
					        users = list(group.users.order_by("id").values_list("id", flat=True))
 | 
				
			||||||
        self._patch_add_users(group, users)
 | 
					        self._patch_add_users(connection, users)
 | 
				
			||||||
        return connection
 | 
					        return connection
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def update(self, group: Group, connection: SCIMProviderGroup):
 | 
					    def update(self, group: Group, connection: SCIMProviderGroup):
 | 
				
			||||||
        """Update existing group"""
 | 
					        """Update existing group"""
 | 
				
			||||||
        scim_group = self.to_schema(group, connection)
 | 
					        scim_group = self.to_schema(group, connection)
 | 
				
			||||||
        scim_group.id = connection.scim_id
 | 
					        scim_group.id = connection.scim_id
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            if self._config.patch.supported:
 | 
				
			||||||
 | 
					                return self._update_patch(group, scim_group, connection)
 | 
				
			||||||
 | 
					            return self._update_put(group, scim_group, connection)
 | 
				
			||||||
 | 
					        except NotFoundSyncException:
 | 
				
			||||||
 | 
					            # Resource missing is handled by self.write, which will re-create the group
 | 
				
			||||||
 | 
					            raise
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _update_patch(
 | 
				
			||||||
 | 
					        self, group: Group, scim_group: SCIMGroupSchema, connection: SCIMProviderGroup
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        """Update a group via PATCH request"""
 | 
				
			||||||
 | 
					        # Patch group's attributes instead of replacing it and re-adding users if we can
 | 
				
			||||||
 | 
					        self._request(
 | 
				
			||||||
 | 
					            "PATCH",
 | 
				
			||||||
 | 
					            f"/Groups/{connection.scim_id}",
 | 
				
			||||||
 | 
					            json=PatchRequest(
 | 
				
			||||||
 | 
					                Operations=[
 | 
				
			||||||
 | 
					                    PatchOperation(
 | 
				
			||||||
 | 
					                        op=PatchOp.replace,
 | 
				
			||||||
 | 
					                        path=None,
 | 
				
			||||||
 | 
					                        value=scim_group.model_dump(mode="json", exclude_unset=True),
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                ]
 | 
				
			||||||
 | 
					            ).model_dump(
 | 
				
			||||||
 | 
					                mode="json",
 | 
				
			||||||
 | 
					                exclude_unset=True,
 | 
				
			||||||
 | 
					                exclude_none=True,
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        return self.patch_compare_users(group)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _update_put(self, group: Group, scim_group: SCIMGroupSchema, connection: SCIMProviderGroup):
 | 
				
			||||||
 | 
					        """Update a group via PUT request"""
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            self._request(
 | 
					            self._request(
 | 
				
			||||||
                "PUT",
 | 
					                "PUT",
 | 
				
			||||||
@ -120,33 +155,25 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
 | 
				
			|||||||
                    exclude_unset=True,
 | 
					                    exclude_unset=True,
 | 
				
			||||||
                ),
 | 
					                ),
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            users = list(group.users.order_by("id").values_list("id", flat=True))
 | 
					            return self.patch_compare_users(group)
 | 
				
			||||||
            return self._patch_add_users(group, users)
 | 
					 | 
				
			||||||
        except NotFoundSyncException:
 | 
					 | 
				
			||||||
            # Resource missing is handled by self.write, which will re-create the group
 | 
					 | 
				
			||||||
            raise
 | 
					 | 
				
			||||||
        except (SCIMRequestException, ObjectExistsSyncException):
 | 
					        except (SCIMRequestException, ObjectExistsSyncException):
 | 
				
			||||||
            # Some providers don't support PUT on groups, so this is mainly a fix for the initial
 | 
					            # Some providers don't support PUT on groups, so this is mainly a fix for the initial
 | 
				
			||||||
            # sync, send patch add requests for all the users the group currently has
 | 
					            # sync, send patch add requests for all the users the group currently has
 | 
				
			||||||
            users = list(group.users.order_by("id").values_list("id", flat=True))
 | 
					            return self._update_patch(group, scim_group, connection)
 | 
				
			||||||
            self._patch_add_users(group, users)
 | 
					 | 
				
			||||||
            # Also update the group name
 | 
					 | 
				
			||||||
            return self._patch(
 | 
					 | 
				
			||||||
                scim_group.id,
 | 
					 | 
				
			||||||
                PatchOperation(
 | 
					 | 
				
			||||||
                    op=PatchOp.replace,
 | 
					 | 
				
			||||||
                    path="displayName",
 | 
					 | 
				
			||||||
                    value=scim_group.displayName,
 | 
					 | 
				
			||||||
                ),
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def update_group(self, group: Group, action: Direction, users_set: set[int]):
 | 
					    def update_group(self, group: Group, action: Direction, users_set: set[int]):
 | 
				
			||||||
        """Update a group, either using PUT to replace it or PATCH if supported"""
 | 
					        """Update a group, either using PUT to replace it or PATCH if supported"""
 | 
				
			||||||
 | 
					        scim_group = SCIMProviderGroup.objects.filter(provider=self.provider, group=group).first()
 | 
				
			||||||
 | 
					        if not scim_group:
 | 
				
			||||||
 | 
					            self.logger.warning(
 | 
				
			||||||
 | 
					                "could not sync group membership, group does not exist", group=group
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
        if self._config.patch.supported:
 | 
					        if self._config.patch.supported:
 | 
				
			||||||
            if action == Direction.add:
 | 
					            if action == Direction.add:
 | 
				
			||||||
                return self._patch_add_users(group, users_set)
 | 
					                return self._patch_add_users(scim_group, users_set)
 | 
				
			||||||
            if action == Direction.remove:
 | 
					            if action == Direction.remove:
 | 
				
			||||||
                return self._patch_remove_users(group, users_set)
 | 
					                return self._patch_remove_users(scim_group, users_set)
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            return self.write(group)
 | 
					            return self.write(group)
 | 
				
			||||||
        except SCIMRequestException as exc:
 | 
					        except SCIMRequestException as exc:
 | 
				
			||||||
@ -154,19 +181,24 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
 | 
				
			|||||||
                # Assume that provider does not support PUT and also doesn't support
 | 
					                # Assume that provider does not support PUT and also doesn't support
 | 
				
			||||||
                # ServiceProviderConfig, so try PATCH as a fallback
 | 
					                # ServiceProviderConfig, so try PATCH as a fallback
 | 
				
			||||||
                if action == Direction.add:
 | 
					                if action == Direction.add:
 | 
				
			||||||
                    return self._patch_add_users(group, users_set)
 | 
					                    return self._patch_add_users(scim_group, users_set)
 | 
				
			||||||
                if action == Direction.remove:
 | 
					                if action == Direction.remove:
 | 
				
			||||||
                    return self._patch_remove_users(group, users_set)
 | 
					                    return self._patch_remove_users(scim_group, users_set)
 | 
				
			||||||
            raise exc
 | 
					            raise exc
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _patch(
 | 
					    def _patch_chunked(
 | 
				
			||||||
        self,
 | 
					        self,
 | 
				
			||||||
        group_id: str,
 | 
					        group_id: str,
 | 
				
			||||||
        *ops: PatchOperation,
 | 
					        *ops: PatchOperation,
 | 
				
			||||||
    ):
 | 
					    ):
 | 
				
			||||||
 | 
					        """Helper function that chunks patch requests based on the maxOperations attribute.
 | 
				
			||||||
 | 
					        This is not strictly according to specs but there's nothing in the schema that allows the
 | 
				
			||||||
 | 
					        us to know what the maximum patch operations per request should be."""
 | 
				
			||||||
        chunk_size = self._config.bulk.maxOperations
 | 
					        chunk_size = self._config.bulk.maxOperations
 | 
				
			||||||
        if chunk_size < 1:
 | 
					        if chunk_size < 1:
 | 
				
			||||||
            chunk_size = len(ops)
 | 
					            chunk_size = len(ops)
 | 
				
			||||||
 | 
					        if len(ops) < 1:
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
        for chunk in batched(ops, chunk_size):
 | 
					        for chunk in batched(ops, chunk_size):
 | 
				
			||||||
            req = PatchRequest(Operations=list(chunk))
 | 
					            req = PatchRequest(Operations=list(chunk))
 | 
				
			||||||
            self._request(
 | 
					            self._request(
 | 
				
			||||||
@ -177,16 +209,70 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
 | 
				
			|||||||
                ),
 | 
					                ),
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _patch_add_users(self, group: Group, users_set: set[int]):
 | 
					    @transaction.atomic
 | 
				
			||||||
        """Add users in users_set to group"""
 | 
					    def patch_compare_users(self, group: Group):
 | 
				
			||||||
        if len(users_set) < 1:
 | 
					        """Compare users with a SCIM group and add/remove any differences"""
 | 
				
			||||||
            return
 | 
					        # Get scim group first
 | 
				
			||||||
        scim_group = SCIMProviderGroup.objects.filter(provider=self.provider, group=group).first()
 | 
					        scim_group = SCIMProviderGroup.objects.filter(provider=self.provider, group=group).first()
 | 
				
			||||||
        if not scim_group:
 | 
					        if not scim_group:
 | 
				
			||||||
            self.logger.warning(
 | 
					            self.logger.warning(
 | 
				
			||||||
                "could not sync group membership, group does not exist", group=group
 | 
					                "could not sync group membership, group does not exist", group=group
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            return
 | 
					            return
 | 
				
			||||||
 | 
					        # Get a list of all users in the authentik group
 | 
				
			||||||
 | 
					        raw_users_should = list(group.users.order_by("id").values_list("id", flat=True))
 | 
				
			||||||
 | 
					        # Lookup the SCIM IDs of the users
 | 
				
			||||||
 | 
					        users_should: list[str] = list(
 | 
				
			||||||
 | 
					            SCIMProviderUser.objects.filter(
 | 
				
			||||||
 | 
					                user__pk__in=raw_users_should, provider=self.provider
 | 
				
			||||||
 | 
					            ).values_list("scim_id", flat=True)
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        if len(raw_users_should) != len(users_should):
 | 
				
			||||||
 | 
					            self.logger.warning(
 | 
				
			||||||
 | 
					                "User count mismatch, not all users in the group are synced to SCIM yet.",
 | 
				
			||||||
 | 
					                group=group,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        # Get current group status
 | 
				
			||||||
 | 
					        current_group = SCIMGroupSchema.model_validate(
 | 
				
			||||||
 | 
					            self._request("GET", f"/Groups/{scim_group.scim_id}")
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        users_to_add = []
 | 
				
			||||||
 | 
					        users_to_remove = []
 | 
				
			||||||
 | 
					        # Check users currently in group and if they shouldn't be in the group and remove them
 | 
				
			||||||
 | 
					        for user in current_group.members or []:
 | 
				
			||||||
 | 
					            if user.value not in users_should:
 | 
				
			||||||
 | 
					                users_to_remove.append(user.value)
 | 
				
			||||||
 | 
					        # Check users that should be in the group and add them
 | 
				
			||||||
 | 
					        for user in users_should:
 | 
				
			||||||
 | 
					            if len([x for x in current_group.members if x.value == user]) < 1:
 | 
				
			||||||
 | 
					                users_to_add.append(user)
 | 
				
			||||||
 | 
					        # Only send request if we need to make changes
 | 
				
			||||||
 | 
					        if len(users_to_add) < 1 and len(users_to_remove) < 1:
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					        return self._patch_chunked(
 | 
				
			||||||
 | 
					            scim_group.scim_id,
 | 
				
			||||||
 | 
					            *[
 | 
				
			||||||
 | 
					                PatchOperation(
 | 
				
			||||||
 | 
					                    op=PatchOp.add,
 | 
				
			||||||
 | 
					                    path="members",
 | 
				
			||||||
 | 
					                    value=[{"value": x}],
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					                for x in users_to_add
 | 
				
			||||||
 | 
					            ],
 | 
				
			||||||
 | 
					            *[
 | 
				
			||||||
 | 
					                PatchOperation(
 | 
				
			||||||
 | 
					                    op=PatchOp.remove,
 | 
				
			||||||
 | 
					                    path="members",
 | 
				
			||||||
 | 
					                    value=[{"value": x}],
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					                for x in users_to_remove
 | 
				
			||||||
 | 
					            ],
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _patch_add_users(self, scim_group: SCIMProviderGroup, users_set: set[int]):
 | 
				
			||||||
 | 
					        """Add users in users_set to group"""
 | 
				
			||||||
 | 
					        if len(users_set) < 1:
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
        user_ids = list(
 | 
					        user_ids = list(
 | 
				
			||||||
            SCIMProviderUser.objects.filter(
 | 
					            SCIMProviderUser.objects.filter(
 | 
				
			||||||
                user__pk__in=users_set, provider=self.provider
 | 
					                user__pk__in=users_set, provider=self.provider
 | 
				
			||||||
@ -194,7 +280,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
 | 
				
			|||||||
        )
 | 
					        )
 | 
				
			||||||
        if len(user_ids) < 1:
 | 
					        if len(user_ids) < 1:
 | 
				
			||||||
            return
 | 
					            return
 | 
				
			||||||
        self._patch(
 | 
					        self._patch_chunked(
 | 
				
			||||||
            scim_group.scim_id,
 | 
					            scim_group.scim_id,
 | 
				
			||||||
            *[
 | 
					            *[
 | 
				
			||||||
                PatchOperation(
 | 
					                PatchOperation(
 | 
				
			||||||
@ -206,16 +292,10 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
 | 
				
			|||||||
            ],
 | 
					            ],
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _patch_remove_users(self, group: Group, users_set: set[int]):
 | 
					    def _patch_remove_users(self, scim_group: SCIMProviderGroup, users_set: set[int]):
 | 
				
			||||||
        """Remove users in users_set from group"""
 | 
					        """Remove users in users_set from group"""
 | 
				
			||||||
        if len(users_set) < 1:
 | 
					        if len(users_set) < 1:
 | 
				
			||||||
            return
 | 
					            return
 | 
				
			||||||
        scim_group = SCIMProviderGroup.objects.filter(provider=self.provider, group=group).first()
 | 
					 | 
				
			||||||
        if not scim_group:
 | 
					 | 
				
			||||||
            self.logger.warning(
 | 
					 | 
				
			||||||
                "could not sync group membership, group does not exist", group=group
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            return
 | 
					 | 
				
			||||||
        user_ids = list(
 | 
					        user_ids = list(
 | 
				
			||||||
            SCIMProviderUser.objects.filter(
 | 
					            SCIMProviderUser.objects.filter(
 | 
				
			||||||
                user__pk__in=users_set, provider=self.provider
 | 
					                user__pk__in=users_set, provider=self.provider
 | 
				
			||||||
@ -223,7 +303,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
 | 
				
			|||||||
        )
 | 
					        )
 | 
				
			||||||
        if len(user_ids) < 1:
 | 
					        if len(user_ids) < 1:
 | 
				
			||||||
            return
 | 
					            return
 | 
				
			||||||
        self._patch(
 | 
					        self._patch_chunked(
 | 
				
			||||||
            scim_group.scim_id,
 | 
					            scim_group.scim_id,
 | 
				
			||||||
            *[
 | 
					            *[
 | 
				
			||||||
                PatchOperation(
 | 
					                PatchOperation(
 | 
				
			||||||
 | 
				
			|||||||
@ -2,6 +2,7 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from pydantic import Field
 | 
					from pydantic import Field
 | 
				
			||||||
from pydanticscim.group import Group as BaseGroup
 | 
					from pydanticscim.group import Group as BaseGroup
 | 
				
			||||||
 | 
					from pydanticscim.responses import PatchOperation as BasePatchOperation
 | 
				
			||||||
from pydanticscim.responses import PatchRequest as BasePatchRequest
 | 
					from pydanticscim.responses import PatchRequest as BasePatchRequest
 | 
				
			||||||
from pydanticscim.responses import SCIMError as BaseSCIMError
 | 
					from pydanticscim.responses import SCIMError as BaseSCIMError
 | 
				
			||||||
from pydanticscim.service_provider import Bulk as BaseBulk
 | 
					from pydanticscim.service_provider import Bulk as BaseBulk
 | 
				
			||||||
@ -68,6 +69,12 @@ class PatchRequest(BasePatchRequest):
 | 
				
			|||||||
    schemas: tuple[str] = ("urn:ietf:params:scim:api:messages:2.0:PatchOp",)
 | 
					    schemas: tuple[str] = ("urn:ietf:params:scim:api:messages:2.0:PatchOp",)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class PatchOperation(BasePatchOperation):
 | 
				
			||||||
 | 
					    """PatchOperation with optional path"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    path: str | None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class SCIMError(BaseSCIMError):
 | 
					class SCIMError(BaseSCIMError):
 | 
				
			||||||
    """SCIM error with optional status code"""
 | 
					    """SCIM error with optional status code"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -252,3 +252,118 @@ class SCIMMembershipTests(TestCase):
 | 
				
			|||||||
                    ],
 | 
					                    ],
 | 
				
			||||||
                },
 | 
					                },
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_member_add_save(self):
 | 
				
			||||||
 | 
					        """Test member add + save"""
 | 
				
			||||||
 | 
					        config = ServiceProviderConfiguration.default()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        config.patch.supported = True
 | 
				
			||||||
 | 
					        user_scim_id = generate_id()
 | 
				
			||||||
 | 
					        group_scim_id = generate_id()
 | 
				
			||||||
 | 
					        uid = generate_id()
 | 
				
			||||||
 | 
					        group = Group.objects.create(
 | 
				
			||||||
 | 
					            name=uid,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        user = User.objects.create(username=generate_id())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Test initial sync of group creation
 | 
				
			||||||
 | 
					        with Mocker() as mocker:
 | 
				
			||||||
 | 
					            mocker.get(
 | 
				
			||||||
 | 
					                "https://localhost/ServiceProviderConfig",
 | 
				
			||||||
 | 
					                json=config.model_dump(),
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            mocker.post(
 | 
				
			||||||
 | 
					                "https://localhost/Users",
 | 
				
			||||||
 | 
					                json={
 | 
				
			||||||
 | 
					                    "id": user_scim_id,
 | 
				
			||||||
 | 
					                },
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            mocker.post(
 | 
				
			||||||
 | 
					                "https://localhost/Groups",
 | 
				
			||||||
 | 
					                json={
 | 
				
			||||||
 | 
					                    "id": group_scim_id,
 | 
				
			||||||
 | 
					                },
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            self.configure()
 | 
				
			||||||
 | 
					            sync_tasks.trigger_single_task(self.provider, scim_sync).get()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            self.assertEqual(mocker.call_count, 6)
 | 
				
			||||||
 | 
					            self.assertEqual(mocker.request_history[0].method, "GET")
 | 
				
			||||||
 | 
					            self.assertEqual(mocker.request_history[1].method, "GET")
 | 
				
			||||||
 | 
					            self.assertEqual(mocker.request_history[2].method, "GET")
 | 
				
			||||||
 | 
					            self.assertEqual(mocker.request_history[3].method, "POST")
 | 
				
			||||||
 | 
					            self.assertEqual(mocker.request_history[4].method, "GET")
 | 
				
			||||||
 | 
					            self.assertEqual(mocker.request_history[5].method, "POST")
 | 
				
			||||||
 | 
					            self.assertJSONEqual(
 | 
				
			||||||
 | 
					                mocker.request_history[3].body,
 | 
				
			||||||
 | 
					                {
 | 
				
			||||||
 | 
					                    "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
 | 
				
			||||||
 | 
					                    "emails": [],
 | 
				
			||||||
 | 
					                    "active": True,
 | 
				
			||||||
 | 
					                    "externalId": user.uid,
 | 
				
			||||||
 | 
					                    "name": {"familyName": " ", "formatted": " ", "givenName": ""},
 | 
				
			||||||
 | 
					                    "displayName": "",
 | 
				
			||||||
 | 
					                    "userName": user.username,
 | 
				
			||||||
 | 
					                },
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            self.assertJSONEqual(
 | 
				
			||||||
 | 
					                mocker.request_history[5].body,
 | 
				
			||||||
 | 
					                {
 | 
				
			||||||
 | 
					                    "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
 | 
				
			||||||
 | 
					                    "externalId": str(group.pk),
 | 
				
			||||||
 | 
					                    "displayName": group.name,
 | 
				
			||||||
 | 
					                },
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with Mocker() as mocker:
 | 
				
			||||||
 | 
					            mocker.get(
 | 
				
			||||||
 | 
					                "https://localhost/ServiceProviderConfig",
 | 
				
			||||||
 | 
					                json=config.model_dump(),
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            mocker.get(
 | 
				
			||||||
 | 
					                f"https://localhost/Groups/{group_scim_id}",
 | 
				
			||||||
 | 
					                json={},
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            mocker.patch(
 | 
				
			||||||
 | 
					                f"https://localhost/Groups/{group_scim_id}",
 | 
				
			||||||
 | 
					                json={},
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            group.users.add(user)
 | 
				
			||||||
 | 
					            group.save()
 | 
				
			||||||
 | 
					            self.assertEqual(mocker.call_count, 5)
 | 
				
			||||||
 | 
					            self.assertEqual(mocker.request_history[0].method, "GET")
 | 
				
			||||||
 | 
					            self.assertEqual(mocker.request_history[1].method, "PATCH")
 | 
				
			||||||
 | 
					            self.assertEqual(mocker.request_history[2].method, "GET")
 | 
				
			||||||
 | 
					            self.assertEqual(mocker.request_history[3].method, "PATCH")
 | 
				
			||||||
 | 
					            self.assertEqual(mocker.request_history[4].method, "GET")
 | 
				
			||||||
 | 
					            self.assertJSONEqual(
 | 
				
			||||||
 | 
					                mocker.request_history[1].body,
 | 
				
			||||||
 | 
					                {
 | 
				
			||||||
 | 
					                    "schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
 | 
				
			||||||
 | 
					                    "Operations": [
 | 
				
			||||||
 | 
					                        {
 | 
				
			||||||
 | 
					                            "op": "add",
 | 
				
			||||||
 | 
					                            "path": "members",
 | 
				
			||||||
 | 
					                            "value": [{"value": user_scim_id}],
 | 
				
			||||||
 | 
					                        }
 | 
				
			||||||
 | 
					                    ],
 | 
				
			||||||
 | 
					                },
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            self.assertJSONEqual(
 | 
				
			||||||
 | 
					                mocker.request_history[3].body,
 | 
				
			||||||
 | 
					                {
 | 
				
			||||||
 | 
					                    "Operations": [
 | 
				
			||||||
 | 
					                        {
 | 
				
			||||||
 | 
					                            "op": "replace",
 | 
				
			||||||
 | 
					                            "value": {
 | 
				
			||||||
 | 
					                                "id": group_scim_id,
 | 
				
			||||||
 | 
					                                "displayName": group.name,
 | 
				
			||||||
 | 
					                                "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
 | 
				
			||||||
 | 
					                                "externalId": str(group.pk),
 | 
				
			||||||
 | 
					                            },
 | 
				
			||||||
 | 
					                        }
 | 
				
			||||||
 | 
					                    ]
 | 
				
			||||||
 | 
					                },
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
				
			|||||||
@ -87,7 +87,11 @@ def task_error_hook(task_id: str, exception: Exception, traceback, *args, **kwar
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
def _get_startup_tasks_default_tenant() -> list[Callable]:
 | 
					def _get_startup_tasks_default_tenant() -> list[Callable]:
 | 
				
			||||||
    """Get all tasks to be run on startup for the default tenant"""
 | 
					    """Get all tasks to be run on startup for the default tenant"""
 | 
				
			||||||
    return []
 | 
					    from authentik.outposts.tasks import outpost_connection_discovery
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return [
 | 
				
			||||||
 | 
					        outpost_connection_discovery,
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def _get_startup_tasks_all_tenants() -> list[Callable]:
 | 
					def _get_startup_tasks_all_tenants() -> list[Callable]:
 | 
				
			||||||
 | 
				
			|||||||
@ -2,6 +2,7 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from collections.abc import Callable
 | 
					from collections.abc import Callable
 | 
				
			||||||
from hashlib import sha512
 | 
					from hashlib import sha512
 | 
				
			||||||
 | 
					from ipaddress import ip_address
 | 
				
			||||||
from time import perf_counter, time
 | 
					from time import perf_counter, time
 | 
				
			||||||
from typing import Any
 | 
					from typing import Any
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -174,6 +175,7 @@ class ClientIPMiddleware:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]):
 | 
					    def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]):
 | 
				
			||||||
        self.get_response = get_response
 | 
					        self.get_response = get_response
 | 
				
			||||||
 | 
					        self.logger = get_logger().bind()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _get_client_ip_from_meta(self, meta: dict[str, Any]) -> str:
 | 
					    def _get_client_ip_from_meta(self, meta: dict[str, Any]) -> str:
 | 
				
			||||||
        """Attempt to get the client's IP by checking common HTTP Headers.
 | 
					        """Attempt to get the client's IP by checking common HTTP Headers.
 | 
				
			||||||
@ -185,11 +187,16 @@ class ClientIPMiddleware:
 | 
				
			|||||||
            "HTTP_X_FORWARDED_FOR",
 | 
					            "HTTP_X_FORWARDED_FOR",
 | 
				
			||||||
            "REMOTE_ADDR",
 | 
					            "REMOTE_ADDR",
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        for _header in headers:
 | 
					        try:
 | 
				
			||||||
            if _header in meta:
 | 
					            for _header in headers:
 | 
				
			||||||
                ips: list[str] = meta.get(_header).split(",")
 | 
					                if _header in meta:
 | 
				
			||||||
                return ips[0].strip()
 | 
					                    ips: list[str] = meta.get(_header).split(",")
 | 
				
			||||||
        return self.default_ip
 | 
					                    # Ensure the IP parses as a valid IP
 | 
				
			||||||
 | 
					                    return str(ip_address(ips[0].strip()))
 | 
				
			||||||
 | 
					            return self.default_ip
 | 
				
			||||||
 | 
					        except ValueError as exc:
 | 
				
			||||||
 | 
					            self.logger.debug("Invalid remote IP", exc=exc)
 | 
				
			||||||
 | 
					            return self.default_ip
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # FIXME: this should probably not be in `root` but rather in a middleware in `outposts`
 | 
					    # FIXME: this should probably not be in `root` but rather in a middleware in `outposts`
 | 
				
			||||||
    # but for now it's fine
 | 
					    # but for now it's fine
 | 
				
			||||||
@ -226,7 +233,11 @@ class ClientIPMiddleware:
 | 
				
			|||||||
        Scope.get_isolation_scope().set_user(user)
 | 
					        Scope.get_isolation_scope().set_user(user)
 | 
				
			||||||
        # Set the outpost service account on the request
 | 
					        # Set the outpost service account on the request
 | 
				
			||||||
        setattr(request, self.request_attr_outpost_user, user)
 | 
					        setattr(request, self.request_attr_outpost_user, user)
 | 
				
			||||||
        return delegated_ip
 | 
					        try:
 | 
				
			||||||
 | 
					            return str(ip_address(delegated_ip))
 | 
				
			||||||
 | 
					        except ValueError as exc:
 | 
				
			||||||
 | 
					            self.logger.debug("Invalid remote IP from Outpost", exc=exc)
 | 
				
			||||||
 | 
					            return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _get_client_ip(self, request: HttpRequest | None) -> str:
 | 
					    def _get_client_ip(self, request: HttpRequest | None) -> str:
 | 
				
			||||||
        """Attempt to get the client's IP by checking common HTTP Headers.
 | 
					        """Attempt to get the client's IP by checking common HTTP Headers.
 | 
				
			||||||
 | 
				
			|||||||
@ -1,6 +1,8 @@
 | 
				
			|||||||
"""Metrics view"""
 | 
					"""Metrics view"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from base64 import b64encode
 | 
					from hmac import compare_digest
 | 
				
			||||||
 | 
					from pathlib import Path
 | 
				
			||||||
 | 
					from tempfile import gettempdir
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from django.conf import settings
 | 
					from django.conf import settings
 | 
				
			||||||
from django.db import connections
 | 
					from django.db import connections
 | 
				
			||||||
@ -16,22 +18,21 @@ monitoring_set = Signal()
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class MetricsView(View):
 | 
					class MetricsView(View):
 | 
				
			||||||
    """Wrapper around ExportToDjangoView, using http-basic auth"""
 | 
					    """Wrapper around ExportToDjangoView with authentication, accessed by the authentik router"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, **kwargs):
 | 
				
			||||||
 | 
					        _tmp = Path(gettempdir())
 | 
				
			||||||
 | 
					        with open(_tmp / "authentik-core-metrics.key") as _f:
 | 
				
			||||||
 | 
					            self.monitoring_key = _f.read()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get(self, request: HttpRequest) -> HttpResponse:
 | 
					    def get(self, request: HttpRequest) -> HttpResponse:
 | 
				
			||||||
        """Check for HTTP-Basic auth"""
 | 
					        """Check for HTTP-Basic auth"""
 | 
				
			||||||
        auth_header = request.META.get("HTTP_AUTHORIZATION", "")
 | 
					        auth_header = request.META.get("HTTP_AUTHORIZATION", "")
 | 
				
			||||||
        auth_type, _, given_credentials = auth_header.partition(" ")
 | 
					        auth_type, _, given_credentials = auth_header.partition(" ")
 | 
				
			||||||
        credentials = f"monitor:{settings.SECRET_KEY}"
 | 
					        authed = auth_type == "Bearer" and compare_digest(given_credentials, self.monitoring_key)
 | 
				
			||||||
        expected = b64encode(str.encode(credentials)).decode()
 | 
					 | 
				
			||||||
        authed = auth_type == "Basic" and given_credentials == expected
 | 
					 | 
				
			||||||
        if not authed and not settings.DEBUG:
 | 
					        if not authed and not settings.DEBUG:
 | 
				
			||||||
            response = HttpResponse(status=401)
 | 
					            return HttpResponse(status=401)
 | 
				
			||||||
            response["WWW-Authenticate"] = 'Basic realm="authentik-monitoring"'
 | 
					 | 
				
			||||||
            return response
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        monitoring_set.send_robust(self)
 | 
					        monitoring_set.send_robust(self)
 | 
				
			||||||
 | 
					 | 
				
			||||||
        return ExportToDjangoView(request)
 | 
					        return ExportToDjangoView(request)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -1,6 +1,7 @@
 | 
				
			|||||||
"""authentik storage backends"""
 | 
					"""authentik storage backends"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
 | 
					from urllib.parse import parse_qsl, urlsplit
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from django.conf import settings
 | 
					from django.conf import settings
 | 
				
			||||||
from django.core.exceptions import SuspiciousOperation
 | 
					from django.core.exceptions import SuspiciousOperation
 | 
				
			||||||
@ -110,3 +111,34 @@ class S3Storage(BaseS3Storage):
 | 
				
			|||||||
        if self.querystring_auth:
 | 
					        if self.querystring_auth:
 | 
				
			||||||
            return url
 | 
					            return url
 | 
				
			||||||
        return self._strip_signing_parameters(url)
 | 
					        return self._strip_signing_parameters(url)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _strip_signing_parameters(self, url):
 | 
				
			||||||
 | 
					        # Boto3 does not currently support generating URLs that are unsigned. Instead
 | 
				
			||||||
 | 
					        # we take the signed URLs and strip any querystring params related to signing
 | 
				
			||||||
 | 
					        # and expiration.
 | 
				
			||||||
 | 
					        # Note that this may end up with URLs that are still invalid, especially if
 | 
				
			||||||
 | 
					        # params are passed in that only work with signed URLs, e.g. response header
 | 
				
			||||||
 | 
					        # params.
 | 
				
			||||||
 | 
					        # The code attempts to strip all query parameters that match names of known
 | 
				
			||||||
 | 
					        # parameters from v2 and v4 signatures, regardless of the actual signature
 | 
				
			||||||
 | 
					        # version used.
 | 
				
			||||||
 | 
					        split_url = urlsplit(url)
 | 
				
			||||||
 | 
					        qs = parse_qsl(split_url.query, keep_blank_values=True)
 | 
				
			||||||
 | 
					        blacklist = {
 | 
				
			||||||
 | 
					            "x-amz-algorithm",
 | 
				
			||||||
 | 
					            "x-amz-credential",
 | 
				
			||||||
 | 
					            "x-amz-date",
 | 
				
			||||||
 | 
					            "x-amz-expires",
 | 
				
			||||||
 | 
					            "x-amz-signedheaders",
 | 
				
			||||||
 | 
					            "x-amz-signature",
 | 
				
			||||||
 | 
					            "x-amz-security-token",
 | 
				
			||||||
 | 
					            "awsaccesskeyid",
 | 
				
			||||||
 | 
					            "expires",
 | 
				
			||||||
 | 
					            "signature",
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        filtered_qs = ((key, val) for key, val in qs if key.lower() not in blacklist)
 | 
				
			||||||
 | 
					        # Note: Parameters that did not have a value in the original query string will
 | 
				
			||||||
 | 
					        # have an '=' sign appended to it, e.g ?foo&bar becomes ?foo=&bar=
 | 
				
			||||||
 | 
					        joined_qs = ("=".join(keyval) for keyval in filtered_qs)
 | 
				
			||||||
 | 
					        split_url = split_url._replace(query="&".join(joined_qs))
 | 
				
			||||||
 | 
					        return split_url.geturl()
 | 
				
			||||||
 | 
				
			|||||||
@ -1,8 +1,9 @@
 | 
				
			|||||||
"""root tests"""
 | 
					"""root tests"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from base64 import b64encode
 | 
					from pathlib import Path
 | 
				
			||||||
 | 
					from secrets import token_urlsafe
 | 
				
			||||||
 | 
					from tempfile import gettempdir
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from django.conf import settings
 | 
					 | 
				
			||||||
from django.test import TestCase
 | 
					from django.test import TestCase
 | 
				
			||||||
from django.urls import reverse
 | 
					from django.urls import reverse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -10,6 +11,16 @@ from django.urls import reverse
 | 
				
			|||||||
class TestRoot(TestCase):
 | 
					class TestRoot(TestCase):
 | 
				
			||||||
    """Test root application"""
 | 
					    """Test root application"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def setUp(self):
 | 
				
			||||||
 | 
					        _tmp = Path(gettempdir())
 | 
				
			||||||
 | 
					        self.token = token_urlsafe(32)
 | 
				
			||||||
 | 
					        with open(_tmp / "authentik-core-metrics.key", "w") as _f:
 | 
				
			||||||
 | 
					            _f.write(self.token)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def tearDown(self):
 | 
				
			||||||
 | 
					        _tmp = Path(gettempdir())
 | 
				
			||||||
 | 
					        (_tmp / "authentik-core-metrics.key").unlink()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_monitoring_error(self):
 | 
					    def test_monitoring_error(self):
 | 
				
			||||||
        """Test monitoring without any credentials"""
 | 
					        """Test monitoring without any credentials"""
 | 
				
			||||||
        response = self.client.get(reverse("metrics"))
 | 
					        response = self.client.get(reverse("metrics"))
 | 
				
			||||||
@ -17,8 +28,7 @@ class TestRoot(TestCase):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def test_monitoring_ok(self):
 | 
					    def test_monitoring_ok(self):
 | 
				
			||||||
        """Test monitoring with credentials"""
 | 
					        """Test monitoring with credentials"""
 | 
				
			||||||
        creds = "Basic " + b64encode(f"monitor:{settings.SECRET_KEY}".encode()).decode("utf-8")
 | 
					        auth_headers = {"HTTP_AUTHORIZATION": f"Bearer {self.token}"}
 | 
				
			||||||
        auth_headers = {"HTTP_AUTHORIZATION": creds}
 | 
					 | 
				
			||||||
        response = self.client.get(reverse("metrics"), **auth_headers)
 | 
					        response = self.client.get(reverse("metrics"), **auth_headers)
 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					        self.assertEqual(response.status_code, 200)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -3,6 +3,7 @@
 | 
				
			|||||||
from typing import Any
 | 
					from typing import Any
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from django.core.cache import cache
 | 
					from django.core.cache import cache
 | 
				
			||||||
 | 
					from django.utils.translation import gettext_lazy as _
 | 
				
			||||||
from drf_spectacular.utils import extend_schema, inline_serializer
 | 
					from drf_spectacular.utils import extend_schema, inline_serializer
 | 
				
			||||||
from guardian.shortcuts import get_objects_for_user
 | 
					from guardian.shortcuts import get_objects_for_user
 | 
				
			||||||
from rest_framework.decorators import action
 | 
					from rest_framework.decorators import action
 | 
				
			||||||
@ -39,9 +40,8 @@ class LDAPSourceSerializer(SourceSerializer):
 | 
				
			|||||||
        """Get cached source connectivity"""
 | 
					        """Get cached source connectivity"""
 | 
				
			||||||
        return cache.get(CACHE_KEY_STATUS + source.slug, None)
 | 
					        return cache.get(CACHE_KEY_STATUS + source.slug, None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
 | 
					    def validate_sync_users_password(self, sync_users_password: bool) -> bool:
 | 
				
			||||||
        """Check that only a single source has password_sync on"""
 | 
					        """Check that only a single source has password_sync on"""
 | 
				
			||||||
        sync_users_password = attrs.get("sync_users_password", True)
 | 
					 | 
				
			||||||
        if sync_users_password:
 | 
					        if sync_users_password:
 | 
				
			||||||
            sources = LDAPSource.objects.filter(sync_users_password=True)
 | 
					            sources = LDAPSource.objects.filter(sync_users_password=True)
 | 
				
			||||||
            if self.instance:
 | 
					            if self.instance:
 | 
				
			||||||
@ -49,11 +49,31 @@ class LDAPSourceSerializer(SourceSerializer):
 | 
				
			|||||||
            if sources.exists():
 | 
					            if sources.exists():
 | 
				
			||||||
                raise ValidationError(
 | 
					                raise ValidationError(
 | 
				
			||||||
                    {
 | 
					                    {
 | 
				
			||||||
                        "sync_users_password": (
 | 
					                        "sync_users_password": _(
 | 
				
			||||||
                            "Only a single LDAP Source with password synchronization is allowed"
 | 
					                            "Only a single LDAP Source with password synchronization is allowed"
 | 
				
			||||||
                        )
 | 
					                        )
 | 
				
			||||||
                    }
 | 
					                    }
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
 | 
					        return sync_users_password
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
 | 
				
			||||||
 | 
					        """Validate property mappings with sync_ flags"""
 | 
				
			||||||
 | 
					        types = ["user", "group"]
 | 
				
			||||||
 | 
					        for type in types:
 | 
				
			||||||
 | 
					            toggle_value = attrs.get(f"sync_{type}s", False)
 | 
				
			||||||
 | 
					            mappings_field = f"{type}_property_mappings"
 | 
				
			||||||
 | 
					            mappings_value = attrs.get(mappings_field, [])
 | 
				
			||||||
 | 
					            if toggle_value and len(mappings_value) == 0:
 | 
				
			||||||
 | 
					                raise ValidationError(
 | 
				
			||||||
 | 
					                    {
 | 
				
			||||||
 | 
					                        mappings_field: _(
 | 
				
			||||||
 | 
					                            (
 | 
				
			||||||
 | 
					                                "When 'Sync {type}s' is enabled, '{type}s property "
 | 
				
			||||||
 | 
					                                "mappings' cannot be empty."
 | 
				
			||||||
 | 
					                            ).format(type=type)
 | 
				
			||||||
 | 
					                        )
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
        return super().validate(attrs)
 | 
					        return super().validate(attrs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    class Meta:
 | 
					    class Meta:
 | 
				
			||||||
@ -166,11 +186,12 @@ class LDAPSourceViewSet(UsedByMixin, ModelViewSet):
 | 
				
			|||||||
        for sync_class in SYNC_CLASSES:
 | 
					        for sync_class in SYNC_CLASSES:
 | 
				
			||||||
            class_name = sync_class.name()
 | 
					            class_name = sync_class.name()
 | 
				
			||||||
            all_objects.setdefault(class_name, [])
 | 
					            all_objects.setdefault(class_name, [])
 | 
				
			||||||
            for obj in sync_class(source).get_objects(size_limit=10):
 | 
					            for page in sync_class(source).get_objects(size_limit=10):
 | 
				
			||||||
                obj: dict
 | 
					                for obj in page:
 | 
				
			||||||
                obj.pop("raw_attributes", None)
 | 
					                    obj: dict
 | 
				
			||||||
                obj.pop("raw_dn", None)
 | 
					                    obj.pop("raw_attributes", None)
 | 
				
			||||||
                all_objects[class_name].append(obj)
 | 
					                    obj.pop("raw_dn", None)
 | 
				
			||||||
 | 
					                    all_objects[class_name].append(obj)
 | 
				
			||||||
        return Response(data=all_objects)
 | 
					        return Response(data=all_objects)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -26,17 +26,16 @@ def sync_ldap_source_on_save(sender, instance: LDAPSource, **_):
 | 
				
			|||||||
    """Ensure that source is synced on save (if enabled)"""
 | 
					    """Ensure that source is synced on save (if enabled)"""
 | 
				
			||||||
    if not instance.enabled:
 | 
					    if not instance.enabled:
 | 
				
			||||||
        return
 | 
					        return
 | 
				
			||||||
 | 
					    ldap_connectivity_check.delay(instance.pk)
 | 
				
			||||||
    # Don't sync sources when they don't have any property mappings. This will only happen if:
 | 
					    # Don't sync sources when they don't have any property mappings. This will only happen if:
 | 
				
			||||||
    # - the user forgets to set them or
 | 
					    # - the user forgets to set them or
 | 
				
			||||||
    # - the source is newly created, this is the first save event
 | 
					    # - the source is newly created, this is the first save event
 | 
				
			||||||
    #   and the mappings are created with an m2m event
 | 
					    #   and the mappings are created with an m2m event
 | 
				
			||||||
    if (
 | 
					    if instance.sync_users and not instance.user_property_mappings.exists():
 | 
				
			||||||
        not instance.user_property_mappings.exists()
 | 
					        return
 | 
				
			||||||
        or not instance.group_property_mappings.exists()
 | 
					    if instance.sync_groups and not instance.group_property_mappings.exists():
 | 
				
			||||||
    ):
 | 
					 | 
				
			||||||
        return
 | 
					        return
 | 
				
			||||||
    ldap_sync_single.delay(instance.pk)
 | 
					    ldap_sync_single.delay(instance.pk)
 | 
				
			||||||
    ldap_connectivity_check.delay(instance.pk)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@receiver(password_validate)
 | 
					@receiver(password_validate)
 | 
				
			||||||
 | 
				
			|||||||
@ -38,7 +38,11 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer):
 | 
				
			|||||||
            search_base=self.base_dn_groups,
 | 
					            search_base=self.base_dn_groups,
 | 
				
			||||||
            search_filter=self._source.group_object_filter,
 | 
					            search_filter=self._source.group_object_filter,
 | 
				
			||||||
            search_scope=SUBTREE,
 | 
					            search_scope=SUBTREE,
 | 
				
			||||||
            attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES],
 | 
					            attributes=[
 | 
				
			||||||
 | 
					                ALL_ATTRIBUTES,
 | 
				
			||||||
 | 
					                ALL_OPERATIONAL_ATTRIBUTES,
 | 
				
			||||||
 | 
					                self._source.object_uniqueness_field,
 | 
				
			||||||
 | 
					            ],
 | 
				
			||||||
            **kwargs,
 | 
					            **kwargs,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -53,9 +57,9 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer):
 | 
				
			|||||||
                continue
 | 
					                continue
 | 
				
			||||||
            attributes = group.get("attributes", {})
 | 
					            attributes = group.get("attributes", {})
 | 
				
			||||||
            group_dn = flatten(flatten(group.get("entryDN", group.get("dn"))))
 | 
					            group_dn = flatten(flatten(group.get("entryDN", group.get("dn"))))
 | 
				
			||||||
            if self._source.object_uniqueness_field not in attributes:
 | 
					            if not attributes.get(self._source.object_uniqueness_field):
 | 
				
			||||||
                self.message(
 | 
					                self.message(
 | 
				
			||||||
                    f"Cannot find uniqueness field in attributes: '{group_dn}'",
 | 
					                    f"Uniqueness field not found/not set in attributes: '{group_dn}'",
 | 
				
			||||||
                    attributes=attributes.keys(),
 | 
					                    attributes=attributes.keys(),
 | 
				
			||||||
                    dn=group_dn,
 | 
					                    dn=group_dn,
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
 | 
				
			|||||||
@ -40,7 +40,11 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer):
 | 
				
			|||||||
            search_base=self.base_dn_users,
 | 
					            search_base=self.base_dn_users,
 | 
				
			||||||
            search_filter=self._source.user_object_filter,
 | 
					            search_filter=self._source.user_object_filter,
 | 
				
			||||||
            search_scope=SUBTREE,
 | 
					            search_scope=SUBTREE,
 | 
				
			||||||
            attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES],
 | 
					            attributes=[
 | 
				
			||||||
 | 
					                ALL_ATTRIBUTES,
 | 
				
			||||||
 | 
					                ALL_OPERATIONAL_ATTRIBUTES,
 | 
				
			||||||
 | 
					                self._source.object_uniqueness_field,
 | 
				
			||||||
 | 
					            ],
 | 
				
			||||||
            **kwargs,
 | 
					            **kwargs,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -55,9 +59,9 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer):
 | 
				
			|||||||
                continue
 | 
					                continue
 | 
				
			||||||
            attributes = user.get("attributes", {})
 | 
					            attributes = user.get("attributes", {})
 | 
				
			||||||
            user_dn = flatten(user.get("entryDN", user.get("dn")))
 | 
					            user_dn = flatten(user.get("entryDN", user.get("dn")))
 | 
				
			||||||
            if self._source.object_uniqueness_field not in attributes:
 | 
					            if not attributes.get(self._source.object_uniqueness_field):
 | 
				
			||||||
                self.message(
 | 
					                self.message(
 | 
				
			||||||
                    f"Cannot find uniqueness field in attributes: '{user_dn}'",
 | 
					                    f"Uniqueness field not found/not set in attributes: '{user_dn}'",
 | 
				
			||||||
                    attributes=attributes.keys(),
 | 
					                    attributes=attributes.keys(),
 | 
				
			||||||
                    dn=user_dn,
 | 
					                    dn=user_dn,
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										4
									
								
								authentik/sources/ldap/sync/vendor/ms_ad.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								authentik/sources/ldap/sync/vendor/ms_ad.py
									
									
									
									
										vendored
									
									
								
							@ -78,7 +78,9 @@ class MicrosoftActiveDirectory(BaseLDAPSynchronizer):
 | 
				
			|||||||
        #   /useraccountcontrol-manipulate-account-properties
 | 
					        #   /useraccountcontrol-manipulate-account-properties
 | 
				
			||||||
        uac_bit = attributes.get("userAccountControl", 512)
 | 
					        uac_bit = attributes.get("userAccountControl", 512)
 | 
				
			||||||
        uac = UserAccountControl(uac_bit)
 | 
					        uac = UserAccountControl(uac_bit)
 | 
				
			||||||
        is_active = UserAccountControl.ACCOUNTDISABLE not in uac
 | 
					        is_active = (
 | 
				
			||||||
 | 
					            UserAccountControl.ACCOUNTDISABLE not in uac and UserAccountControl.LOCKOUT not in uac
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        if is_active != user.is_active:
 | 
					        if is_active != user.is_active:
 | 
				
			||||||
            user.is_active = is_active
 | 
					            user.is_active = is_active
 | 
				
			||||||
            user.save()
 | 
					            user.save()
 | 
				
			||||||
 | 
				
			|||||||
@ -50,3 +50,35 @@ class LDAPAPITests(APITestCase):
 | 
				
			|||||||
            }
 | 
					            }
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        self.assertFalse(serializer.is_valid())
 | 
					        self.assertFalse(serializer.is_valid())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_sync_users_mapping_empty(self):
 | 
				
			||||||
 | 
					        """Check that when sync_users is enabled, property mappings must be set"""
 | 
				
			||||||
 | 
					        serializer = LDAPSourceSerializer(
 | 
				
			||||||
 | 
					            data={
 | 
				
			||||||
 | 
					                "name": "foo",
 | 
				
			||||||
 | 
					                "slug": " foo",
 | 
				
			||||||
 | 
					                "server_uri": "ldaps://1.2.3.4",
 | 
				
			||||||
 | 
					                "bind_cn": "",
 | 
				
			||||||
 | 
					                "bind_password": LDAP_PASSWORD,
 | 
				
			||||||
 | 
					                "base_dn": "dc=foo",
 | 
				
			||||||
 | 
					                "sync_users": True,
 | 
				
			||||||
 | 
					                "user_property_mappings": [],
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertFalse(serializer.is_valid())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_sync_groups_mapping_empty(self):
 | 
				
			||||||
 | 
					        """Check that when sync_groups is enabled, property mappings must be set"""
 | 
				
			||||||
 | 
					        serializer = LDAPSourceSerializer(
 | 
				
			||||||
 | 
					            data={
 | 
				
			||||||
 | 
					                "name": "foo",
 | 
				
			||||||
 | 
					                "slug": " foo",
 | 
				
			||||||
 | 
					                "server_uri": "ldaps://1.2.3.4",
 | 
				
			||||||
 | 
					                "bind_cn": "",
 | 
				
			||||||
 | 
					                "bind_password": LDAP_PASSWORD,
 | 
				
			||||||
 | 
					                "base_dn": "dc=foo",
 | 
				
			||||||
 | 
					                "sync_groups": True,
 | 
				
			||||||
 | 
					                "group_property_mappings": [],
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertFalse(serializer.is_valid())
 | 
				
			||||||
 | 
				
			|||||||
@ -30,7 +30,9 @@ class TestMetadataProcessor(TestCase):
 | 
				
			|||||||
        xml = MetadataProcessor(self.source, request).build_entity_descriptor()
 | 
					        xml = MetadataProcessor(self.source, request).build_entity_descriptor()
 | 
				
			||||||
        metadata = lxml_from_string(xml)
 | 
					        metadata = lxml_from_string(xml)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        schema = etree.XMLSchema(etree.parse("schemas/saml-schema-metadata-2.0.xsd"))  # nosec
 | 
					        schema = etree.XMLSchema(
 | 
				
			||||||
 | 
					            etree.parse("schemas/saml-schema-metadata-2.0.xsd", parser=etree.XMLParser())  # nosec
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        self.assertTrue(schema.validate(metadata))
 | 
					        self.assertTrue(schema.validate(metadata))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_metadata_consistent(self):
 | 
					    def test_metadata_consistent(self):
 | 
				
			||||||
 | 
				
			|||||||
@ -82,3 +82,5 @@ entries:
 | 
				
			|||||||
    order: 10
 | 
					    order: 10
 | 
				
			||||||
    target: !KeyOf default-authentication-flow-password-binding
 | 
					    target: !KeyOf default-authentication-flow-password-binding
 | 
				
			||||||
    policy: !KeyOf default-authentication-flow-password-optional
 | 
					    policy: !KeyOf default-authentication-flow-password-optional
 | 
				
			||||||
 | 
					  attrs:
 | 
				
			||||||
 | 
					    failure_result: true
 | 
				
			||||||
 | 
				
			|||||||
@ -2,7 +2,7 @@
 | 
				
			|||||||
    "$schema": "http://json-schema.org/draft-07/schema",
 | 
					    "$schema": "http://json-schema.org/draft-07/schema",
 | 
				
			||||||
    "$id": "https://goauthentik.io/blueprints/schema.json",
 | 
					    "$id": "https://goauthentik.io/blueprints/schema.json",
 | 
				
			||||||
    "type": "object",
 | 
					    "type": "object",
 | 
				
			||||||
    "title": "authentik 2024.6.4 Blueprint schema",
 | 
					    "title": "authentik 2024.8.6 Blueprint schema",
 | 
				
			||||||
    "required": [
 | 
					    "required": [
 | 
				
			||||||
        "version",
 | 
					        "version",
 | 
				
			||||||
        "entries"
 | 
					        "entries"
 | 
				
			||||||
@ -5345,9 +5345,30 @@
 | 
				
			|||||||
                    "description": "Key used to sign the tokens. Only required when JWT Algorithm is set to RS256."
 | 
					                    "description": "Key used to sign the tokens. Only required when JWT Algorithm is set to RS256."
 | 
				
			||||||
                },
 | 
					                },
 | 
				
			||||||
                "redirect_uris": {
 | 
					                "redirect_uris": {
 | 
				
			||||||
                    "type": "string",
 | 
					                    "type": "array",
 | 
				
			||||||
                    "title": "Redirect URIs",
 | 
					                    "items": {
 | 
				
			||||||
                    "description": "Enter each URI on a new line."
 | 
					                        "type": "object",
 | 
				
			||||||
 | 
					                        "properties": {
 | 
				
			||||||
 | 
					                            "matching_mode": {
 | 
				
			||||||
 | 
					                                "type": "string",
 | 
				
			||||||
 | 
					                                "enum": [
 | 
				
			||||||
 | 
					                                    "strict",
 | 
				
			||||||
 | 
					                                    "regex"
 | 
				
			||||||
 | 
					                                ],
 | 
				
			||||||
 | 
					                                "title": "Matching mode"
 | 
				
			||||||
 | 
					                            },
 | 
				
			||||||
 | 
					                            "url": {
 | 
				
			||||||
 | 
					                                "type": "string",
 | 
				
			||||||
 | 
					                                "minLength": 1,
 | 
				
			||||||
 | 
					                                "title": "Url"
 | 
				
			||||||
 | 
					                            }
 | 
				
			||||||
 | 
					                        },
 | 
				
			||||||
 | 
					                        "required": [
 | 
				
			||||||
 | 
					                            "matching_mode",
 | 
				
			||||||
 | 
					                            "url"
 | 
				
			||||||
 | 
					                        ]
 | 
				
			||||||
 | 
					                    },
 | 
				
			||||||
 | 
					                    "title": "Redirect uris"
 | 
				
			||||||
                },
 | 
					                },
 | 
				
			||||||
                "sub_mode": {
 | 
					                "sub_mode": {
 | 
				
			||||||
                    "type": "string",
 | 
					                    "type": "string",
 | 
				
			||||||
 | 
				
			|||||||
@ -14,11 +14,7 @@ entries:
 | 
				
			|||||||
      expression: |
 | 
					      expression: |
 | 
				
			||||||
        # This mapping is used by the authentik proxy. It passes extra user attributes,
 | 
					        # This mapping is used by the authentik proxy. It passes extra user attributes,
 | 
				
			||||||
        # which are used for example for the HTTP-Basic Authentication mapping.
 | 
					        # which are used for example for the HTTP-Basic Authentication mapping.
 | 
				
			||||||
        session_id = None
 | 
					 | 
				
			||||||
        if "token" in request.context:
 | 
					 | 
				
			||||||
            session_id = request.context.get("token").session_id
 | 
					 | 
				
			||||||
        return {
 | 
					        return {
 | 
				
			||||||
            "sid": session_id,
 | 
					 | 
				
			||||||
            "ak_proxy": {
 | 
					            "ak_proxy": {
 | 
				
			||||||
                "user_attributes": request.user.group_attributes(request),
 | 
					                "user_attributes": request.user.group_attributes(request),
 | 
				
			||||||
                "is_superuser": request.user.is_superuser,
 | 
					                "is_superuser": request.user.is_superuser,
 | 
				
			||||||
 | 
				
			|||||||
@ -31,7 +31,7 @@ services:
 | 
				
			|||||||
    volumes:
 | 
					    volumes:
 | 
				
			||||||
      - redis:/data
 | 
					      - redis:/data
 | 
				
			||||||
  server:
 | 
					  server:
 | 
				
			||||||
    image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2024.6.4}
 | 
					    image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2024.8.6}
 | 
				
			||||||
    restart: unless-stopped
 | 
					    restart: unless-stopped
 | 
				
			||||||
    command: server
 | 
					    command: server
 | 
				
			||||||
    environment:
 | 
					    environment:
 | 
				
			||||||
@ -52,7 +52,7 @@ services:
 | 
				
			|||||||
      - postgresql
 | 
					      - postgresql
 | 
				
			||||||
      - redis
 | 
					      - redis
 | 
				
			||||||
  worker:
 | 
					  worker:
 | 
				
			||||||
    image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2024.6.4}
 | 
					    image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2024.8.6}
 | 
				
			||||||
    restart: unless-stopped
 | 
					    restart: unless-stopped
 | 
				
			||||||
    command: worker
 | 
					    command: worker
 | 
				
			||||||
    environment:
 | 
					    environment:
 | 
				
			||||||
 | 
				
			|||||||
@ -29,4 +29,4 @@ func UserAgent() string {
 | 
				
			|||||||
	return fmt.Sprintf("authentik@%s", FullVersion())
 | 
						return fmt.Sprintf("authentik@%s", FullVersion())
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const VERSION = "2024.6.4"
 | 
					const VERSION = "2024.8.6"
 | 
				
			||||||
 | 
				
			|||||||
@ -35,10 +35,11 @@ func Paginator[Tobj any, Treq any, Tres PaginatorResponse[Tobj]](
 | 
				
			|||||||
	req PaginatorRequest[Treq, Tres],
 | 
						req PaginatorRequest[Treq, Tres],
 | 
				
			||||||
	opts PaginatorOptions,
 | 
						opts PaginatorOptions,
 | 
				
			||||||
) ([]Tobj, error) {
 | 
					) ([]Tobj, error) {
 | 
				
			||||||
 | 
						var bfreq, cfreq interface{}
 | 
				
			||||||
	fetchOffset := func(page int32) (Tres, error) {
 | 
						fetchOffset := func(page int32) (Tres, error) {
 | 
				
			||||||
		req.Page(page)
 | 
							bfreq = req.Page(page)
 | 
				
			||||||
		req.PageSize(int32(opts.PageSize))
 | 
							cfreq = bfreq.(PaginatorRequest[Treq, Tres]).PageSize(int32(opts.PageSize))
 | 
				
			||||||
		res, _, err := req.Execute()
 | 
							res, _, err := cfreq.(PaginatorRequest[Treq, Tres]).Execute()
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			opts.Logger.WithError(err).WithField("page", page).Warning("failed to fetch page")
 | 
								opts.Logger.WithError(err).WithField("page", page).Warning("failed to fetch page")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										26
									
								
								internal/outpost/ak/api_utils_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								internal/outpost/ak/api_utils_test.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,26 @@
 | 
				
			|||||||
 | 
					package ak
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// func Test_PaginatorCompile(t *testing.T) {
 | 
				
			||||||
 | 
					// 	req := api.ApiCoreUsersListRequest{}
 | 
				
			||||||
 | 
					// 	Paginator(req, PaginatorOptions{
 | 
				
			||||||
 | 
					// 		PageSize: 100,
 | 
				
			||||||
 | 
					// 	})
 | 
				
			||||||
 | 
					// }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// func Test_PaginatorCompileExplicit(t *testing.T) {
 | 
				
			||||||
 | 
					// 	req := api.ApiCoreUsersListRequest{}
 | 
				
			||||||
 | 
					// 	Paginator[
 | 
				
			||||||
 | 
					// 		api.User,
 | 
				
			||||||
 | 
					// 		api.ApiCoreUsersListRequest,
 | 
				
			||||||
 | 
					// 		*api.PaginatedUserList,
 | 
				
			||||||
 | 
					// 	](req, PaginatorOptions{
 | 
				
			||||||
 | 
					// 		PageSize: 100,
 | 
				
			||||||
 | 
					// 	})
 | 
				
			||||||
 | 
					// }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// func Test_PaginatorCompileOther(t *testing.T) {
 | 
				
			||||||
 | 
					// 	req := api.ApiOutpostsProxyListRequest{}
 | 
				
			||||||
 | 
					// 	Paginator(req, PaginatorOptions{
 | 
				
			||||||
 | 
					// 		PageSize: 100,
 | 
				
			||||||
 | 
					// 	})
 | 
				
			||||||
 | 
					// }
 | 
				
			||||||
@ -96,7 +96,7 @@ func (db *DirectBinder) Bind(username string, req *bind.Request) (ldap.LDAPResul
 | 
				
			|||||||
		return ldap.LDAPResultOperationsError, nil
 | 
							return ldap.LDAPResultOperationsError, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	flags.UserPk = userInfo.User.Pk
 | 
						flags.UserPk = userInfo.User.Pk
 | 
				
			||||||
	flags.CanSearch = access.HasSearchPermission != nil
 | 
						flags.CanSearch = access.GetHasSearchPermission()
 | 
				
			||||||
	db.si.SetFlags(req.BindDN, &flags)
 | 
						db.si.SetFlags(req.BindDN, &flags)
 | 
				
			||||||
	if flags.CanSearch {
 | 
						if flags.CanSearch {
 | 
				
			||||||
		req.Log().Debug("Allowed access to search")
 | 
							req.Log().Debug("Allowed access to search")
 | 
				
			||||||
 | 
				
			|||||||
@ -193,7 +193,17 @@ func NewApplication(p api.ProxyOutpostConfig, c *http.Client, server Server) (*A
 | 
				
			|||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	mux.HandleFunc("/outpost.goauthentik.io/start", func(w http.ResponseWriter, r *http.Request) {
 | 
						mux.HandleFunc("/outpost.goauthentik.io/start", func(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
		a.handleAuthStart(w, r, "")
 | 
							fwd := ""
 | 
				
			||||||
 | 
							// This should only really be hit for nginx forward_auth
 | 
				
			||||||
 | 
							// as for that the auth start redirect URL is generated by the
 | 
				
			||||||
 | 
							// reverse proxy, and as such we won't have a request we just
 | 
				
			||||||
 | 
							// denied to reference for final URL
 | 
				
			||||||
 | 
							rd, ok := a.checkRedirectParam(r)
 | 
				
			||||||
 | 
							if ok {
 | 
				
			||||||
 | 
								a.log.WithField("rd", rd).Trace("Setting redirect")
 | 
				
			||||||
 | 
								fwd = rd
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							a.handleAuthStart(w, r, fwd)
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
	mux.HandleFunc("/outpost.goauthentik.io/callback", a.handleAuthCallback)
 | 
						mux.HandleFunc("/outpost.goauthentik.io/callback", a.handleAuthCallback)
 | 
				
			||||||
	mux.HandleFunc("/outpost.goauthentik.io/sign_out", a.handleSignOut)
 | 
						mux.HandleFunc("/outpost.goauthentik.io/sign_out", a.handleSignOut)
 | 
				
			||||||
 | 
				
			|||||||
@ -15,36 +15,6 @@ const (
 | 
				
			|||||||
	LogoutSignature   = "X-authentik-logout"
 | 
						LogoutSignature   = "X-authentik-logout"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (a *Application) checkRedirectParam(r *http.Request) (string, bool) {
 | 
					 | 
				
			||||||
	rd := r.URL.Query().Get(redirectParam)
 | 
					 | 
				
			||||||
	if rd == "" {
 | 
					 | 
				
			||||||
		return "", false
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	u, err := url.Parse(rd)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		a.log.WithError(err).Warning("Failed to parse redirect URL")
 | 
					 | 
				
			||||||
		return "", false
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	// Check to make sure we only redirect to allowed places
 | 
					 | 
				
			||||||
	if a.Mode() == api.PROXYMODE_PROXY || a.Mode() == api.PROXYMODE_FORWARD_SINGLE {
 | 
					 | 
				
			||||||
		ext, err := url.Parse(a.proxyConfig.ExternalHost)
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			return "", false
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		ext.Scheme = ""
 | 
					 | 
				
			||||||
		if !strings.Contains(u.String(), ext.String()) {
 | 
					 | 
				
			||||||
			a.log.WithField("url", u.String()).WithField("ext", ext.String()).Warning("redirect URI did not contain external host")
 | 
					 | 
				
			||||||
			return "", false
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	} else {
 | 
					 | 
				
			||||||
		if !strings.HasSuffix(u.Host, *a.proxyConfig.CookieDomain) {
 | 
					 | 
				
			||||||
			a.log.WithField("host", u.Host).WithField("dom", *a.proxyConfig.CookieDomain).Warning("redirect URI Host was not included in cookie domain")
 | 
					 | 
				
			||||||
			return "", false
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return u.String(), true
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (a *Application) handleAuthStart(rw http.ResponseWriter, r *http.Request, fwd string) {
 | 
					func (a *Application) handleAuthStart(rw http.ResponseWriter, r *http.Request, fwd string) {
 | 
				
			||||||
	state, err := a.createState(r, fwd)
 | 
						state, err := a.createState(r, fwd)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
 | 
				
			|||||||
@ -5,10 +5,13 @@ import (
 | 
				
			|||||||
	"encoding/base64"
 | 
						"encoding/base64"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
 | 
						"net/url"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/golang-jwt/jwt/v5"
 | 
						"github.com/golang-jwt/jwt/v5"
 | 
				
			||||||
	"github.com/gorilla/securecookie"
 | 
						"github.com/gorilla/securecookie"
 | 
				
			||||||
	"github.com/mitchellh/mapstructure"
 | 
						"github.com/mitchellh/mapstructure"
 | 
				
			||||||
 | 
						"goauthentik.io/api/v3"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type OAuthState struct {
 | 
					type OAuthState struct {
 | 
				
			||||||
@ -27,6 +30,44 @@ func (oas *OAuthState) GetAudience() (jwt.ClaimStrings, error)       { return ni
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
var base32RawStdEncoding = base32.StdEncoding.WithPadding(base32.NoPadding)
 | 
					var base32RawStdEncoding = base32.StdEncoding.WithPadding(base32.NoPadding)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Validate that the given redirect parameter (?rd=...) is valid and can be used
 | 
				
			||||||
 | 
					// For proxy/forward_single this checks that if the `rd` param has a Hostname (and is a full URL)
 | 
				
			||||||
 | 
					// the hostname matches what's configured, or no hostname must be given
 | 
				
			||||||
 | 
					// For forward_domain this checks if the domain of the URL in `rd` ends with the configured domain
 | 
				
			||||||
 | 
					func (a *Application) checkRedirectParam(r *http.Request) (string, bool) {
 | 
				
			||||||
 | 
						rd := r.URL.Query().Get(redirectParam)
 | 
				
			||||||
 | 
						if rd == "" {
 | 
				
			||||||
 | 
							return "", false
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						u, err := url.Parse(rd)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							a.log.WithError(err).Warning("Failed to parse redirect URL")
 | 
				
			||||||
 | 
							return "", false
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						// Check to make sure we only redirect to allowed places
 | 
				
			||||||
 | 
						if a.Mode() == api.PROXYMODE_PROXY || a.Mode() == api.PROXYMODE_FORWARD_SINGLE {
 | 
				
			||||||
 | 
							ext, err := url.Parse(a.proxyConfig.ExternalHost)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return "", false
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							// Either hostname needs to match the configured domain, or host name must be empty for just a path
 | 
				
			||||||
 | 
							if u.Host == "" {
 | 
				
			||||||
 | 
								u.Host = ext.Host
 | 
				
			||||||
 | 
								u.Scheme = ext.Scheme
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if u.Host != ext.Host {
 | 
				
			||||||
 | 
								a.log.WithField("url", u.String()).WithField("ext", ext.String()).Warning("redirect URI did not contain external host")
 | 
				
			||||||
 | 
								return "", false
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							if !strings.HasSuffix(u.Host, *a.proxyConfig.CookieDomain) {
 | 
				
			||||||
 | 
								a.log.WithField("host", u.Host).WithField("dom", *a.proxyConfig.CookieDomain).Warning("redirect URI Host was not included in cookie domain")
 | 
				
			||||||
 | 
								return "", false
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return u.String(), true
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (a *Application) createState(r *http.Request, fwd string) (string, error) {
 | 
					func (a *Application) createState(r *http.Request, fwd string) (string, error) {
 | 
				
			||||||
	s, _ := a.sessions.Get(r, a.SessionName())
 | 
						s, _ := a.sessions.Get(r, a.SessionName())
 | 
				
			||||||
	if s.ID == "" {
 | 
						if s.ID == "" {
 | 
				
			||||||
@ -39,17 +80,6 @@ func (a *Application) createState(r *http.Request, fwd string) (string, error) {
 | 
				
			|||||||
		SessionID: s.ID,
 | 
							SessionID: s.ID,
 | 
				
			||||||
		Redirect:  fwd,
 | 
							Redirect:  fwd,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if fwd == "" {
 | 
					 | 
				
			||||||
		// This should only really be hit for nginx forward_auth
 | 
					 | 
				
			||||||
		// as for that the auth start redirect URL is generated by the
 | 
					 | 
				
			||||||
		// reverse proxy, and as such we won't have a request we just
 | 
					 | 
				
			||||||
		// denied to reference for final URL
 | 
					 | 
				
			||||||
		rd, ok := a.checkRedirectParam(r)
 | 
					 | 
				
			||||||
		if ok {
 | 
					 | 
				
			||||||
			a.log.WithField("rd", rd).Trace("Setting redirect")
 | 
					 | 
				
			||||||
			st.Redirect = rd
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, st)
 | 
						token := jwt.NewWithClaims(jwt.SigningMethodHS256, st)
 | 
				
			||||||
	tokenString, err := token.SignedString([]byte(a.proxyConfig.GetCookieSecret()))
 | 
						tokenString, err := token.SignedString([]byte(a.proxyConfig.GetCookieSecret()))
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
 | 
				
			|||||||
@ -8,25 +8,45 @@ import (
 | 
				
			|||||||
	"goauthentik.io/api/v3"
 | 
						"goauthentik.io/api/v3"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestCheckRedirectParam(t *testing.T) {
 | 
					func TestCheckRedirectParam_None(t *testing.T) {
 | 
				
			||||||
	a := newTestApplication()
 | 
						a := newTestApplication()
 | 
				
			||||||
 | 
						// Test no rd param
 | 
				
			||||||
	req, _ := http.NewRequest("GET", "/outpost.goauthentik.io/auth/start", nil)
 | 
						req, _ := http.NewRequest("GET", "/outpost.goauthentik.io/auth/start", nil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	rd, ok := a.checkRedirectParam(req)
 | 
						rd, ok := a.checkRedirectParam(req)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	assert.Equal(t, false, ok)
 | 
						assert.Equal(t, false, ok)
 | 
				
			||||||
	assert.Equal(t, "", rd)
 | 
						assert.Equal(t, "", rd)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	req, _ = http.NewRequest("GET", "/outpost.goauthentik.io/auth/start?rd=https://google.com", nil)
 | 
					func TestCheckRedirectParam_Invalid(t *testing.T) {
 | 
				
			||||||
 | 
						a := newTestApplication()
 | 
				
			||||||
 | 
						// Test invalid rd param
 | 
				
			||||||
 | 
						req, _ := http.NewRequest("GET", "/outpost.goauthentik.io/auth/start?rd=https://google.com", nil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	rd, ok = a.checkRedirectParam(req)
 | 
						rd, ok := a.checkRedirectParam(req)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	assert.Equal(t, false, ok)
 | 
						assert.Equal(t, false, ok)
 | 
				
			||||||
	assert.Equal(t, "", rd)
 | 
						assert.Equal(t, "", rd)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	req, _ = http.NewRequest("GET", "/outpost.goauthentik.io/auth/start?rd=https://ext.t.goauthentik.io/test?foo", nil)
 | 
					func TestCheckRedirectParam_ValidFull(t *testing.T) {
 | 
				
			||||||
 | 
						a := newTestApplication()
 | 
				
			||||||
 | 
						// Test valid full rd param
 | 
				
			||||||
 | 
						req, _ := http.NewRequest("GET", "/outpost.goauthentik.io/auth/start?rd=https://ext.t.goauthentik.io/test?foo", nil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	rd, ok = a.checkRedirectParam(req)
 | 
						rd, ok := a.checkRedirectParam(req)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						assert.Equal(t, true, ok)
 | 
				
			||||||
 | 
						assert.Equal(t, "https://ext.t.goauthentik.io/test?foo", rd)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestCheckRedirectParam_ValidPartial(t *testing.T) {
 | 
				
			||||||
 | 
						a := newTestApplication()
 | 
				
			||||||
 | 
						// Test valid partial rd param
 | 
				
			||||||
 | 
						req, _ := http.NewRequest("GET", "/outpost.goauthentik.io/auth/start?rd=/test?foo", nil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						rd, ok := a.checkRedirectParam(req)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	assert.Equal(t, true, ok)
 | 
						assert.Equal(t, true, ok)
 | 
				
			||||||
	assert.Equal(t, "https://ext.t.goauthentik.io/test?foo", rd)
 | 
						assert.Equal(t, "https://ext.t.goauthentik.io/test?foo", rd)
 | 
				
			||||||
 | 
				
			|||||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user