Compare commits
	
		
			104 Commits
		
	
	
		
			version/20
			...
			version-te
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 5b6b059b40 | |||
| 060cea219b | |||
| af9d82c02d | |||
| cc8fb66da2 | |||
| f0edc7b931 | |||
| b39632abb0 | |||
| c59b859ec0 | |||
| a46939b591 | |||
| bfb4a25026 | |||
| 646276b37c | |||
| 58f9d86d0b | |||
| cf0a268fb1 | |||
| ec783ae587 | |||
| f50d44792c | |||
| b225b0200e | |||
| 507f9b7ae2 | |||
| 5991b82cde | |||
| f38bc8d09e | |||
| 9824f283de | |||
| 341d866c00 | |||
| 965ddcb564 | |||
| a0a1a101e8 | |||
| 277c922ec3 | |||
| f372627d61 | |||
| 1be86325d5 | |||
| 6d71454aa0 | |||
| 75d6aab0bb | |||
| 496dce093a | |||
| f740ba0ffe | |||
| a82af054a4 | |||
| c80e3da644 | |||
| af9bb566f8 | |||
| 5ca929417b | |||
| 3c1c44bda1 | |||
| c05977f144 | |||
| 55333ef1ac | |||
| 49ad6d2aa8 | |||
| b7e4373d6e | |||
| 699c074816 | |||
| c26855f953 | |||
| 1457b38e7e | |||
| 55d08c5be3 | |||
| ffbfbd43cb | |||
| cb24fe5c5d | |||
| aa81d8f12d | |||
| 2ee1a0241b | |||
| 89bc7a037d | |||
| a21683555a | |||
| 5a98235ee0 | |||
| 3ce836fd8b | |||
| 5a5f7814ab | |||
| 907d475897 | |||
| 41503fc0b2 | |||
| cfc7646a5a | |||
| 7103336456 | |||
| 48db4af56d | |||
| 8285b5d9a7 | |||
| 43218bd027 | |||
| 042fae143d | |||
| f6f997525f | |||
| 753fb5e1b2 | |||
| 06a42df732 | |||
| 66a2a62c7b | |||
| 41bbbde232 | |||
| 373c0ff7d0 | |||
| 30345d450c | |||
| b9dc83466d | |||
| f26175a99f | |||
| c7881e6eb4 | |||
| 97b98a4192 | |||
| fc65d3f43a | |||
| aa87695f3c | |||
| c3fb84397a | |||
| 8d78cd97d0 | |||
| 24d2c4089c | |||
| 38f47c65a1 | |||
| 896096374c | |||
| 0e2326ed06 | |||
| a07db454be | |||
| 87a4a81798 | |||
| f0ee743ea1 | |||
| fbac1e9d95 | |||
| d8536ed78e | |||
| 848dae52ab | |||
| f62a470dfa | |||
| 16a8409014 | |||
| dfa5b8aba5 | |||
| 54270e960f | |||
| 6541b7fcef | |||
| 19af49a49b | |||
| 99e189cae3 | |||
| 6f68563df2 | |||
| df03b2a156 | |||
| e1211ba01b | |||
| 24ea3f0ee8 | |||
| 79045ab283 | |||
| e27189364e | |||
| ba224e4eb9 | |||
| 336950628e | |||
| 6ede552292 | |||
| 07b6356b38 | |||
| 4c5730a222 | |||
| 8ab84c8d91 | |||
| 89ef82337d | 
| @ -1,5 +1,5 @@ | ||||
| [bumpversion] | ||||
| current_version = 2024.2.4 | ||||
| current_version = 2024.2.1 | ||||
| tag = True | ||||
| commit = True | ||||
| parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))? | ||||
|  | ||||
| @ -11,10 +11,6 @@ inputs: | ||||
|     description: "Docker image arch" | ||||
|  | ||||
| outputs: | ||||
|   shouldBuild: | ||||
|     description: "Whether to build image or not" | ||||
|     value: ${{ steps.ev.outputs.shouldBuild }} | ||||
|  | ||||
|   sha: | ||||
|     description: "sha" | ||||
|     value: ${{ steps.ev.outputs.sha }} | ||||
|  | ||||
| @ -7,8 +7,6 @@ from time import time | ||||
| parser = configparser.ConfigParser() | ||||
| parser.read(".bumpversion.cfg") | ||||
|  | ||||
| should_build = str(os.environ.get("DOCKER_USERNAME", None) is not None).lower() | ||||
|  | ||||
| branch_name = os.environ["GITHUB_REF"] | ||||
| if os.environ.get("GITHUB_HEAD_REF", "") != "": | ||||
|     branch_name = os.environ["GITHUB_HEAD_REF"] | ||||
| @ -54,7 +52,6 @@ image_main_tag = image_tags[0] | ||||
| image_tags_rendered = ",".join(image_tags) | ||||
|  | ||||
| with open(os.environ["GITHUB_OUTPUT"], "a+", encoding="utf-8") as _output: | ||||
|     print("shouldBuild=%s" % should_build, file=_output) | ||||
|     print("sha=%s" % sha, file=_output) | ||||
|     print("version=%s" % version, file=_output) | ||||
|     print("prerelease=%s" % prerelease, file=_output) | ||||
|  | ||||
							
								
								
									
										11
									
								
								.github/workflows/ci-main.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										11
									
								
								.github/workflows/ci-main.yml
									
									
									
									
										vendored
									
									
								
							| @ -28,10 +28,7 @@ jobs: | ||||
|           - bandit | ||||
|           - black | ||||
|           - codespell | ||||
|           - isort | ||||
|           - pending-migrations | ||||
|           # - pylint | ||||
|           - pyright | ||||
|           - ruff | ||||
|     runs-on: ubuntu-latest | ||||
|     steps: | ||||
| @ -219,6 +216,7 @@ jobs: | ||||
|       # Needed to upload contianer images to ghcr.io | ||||
|       packages: write | ||||
|     timeout-minutes: 120 | ||||
|     if: "github.repository == 'goauthentik/authentik'" | ||||
|     steps: | ||||
|       - uses: actions/checkout@v4 | ||||
|         with: | ||||
| @ -230,13 +228,10 @@ jobs: | ||||
|       - name: prepare variables | ||||
|         uses: ./.github/actions/docker-push-variables | ||||
|         id: ev | ||||
|         env: | ||||
|           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} | ||||
|         with: | ||||
|           image-name: ghcr.io/goauthentik/dev-server | ||||
|           image-arch: ${{ matrix.arch }} | ||||
|       - name: Login to Container Registry | ||||
|         if: ${{ steps.ev.outputs.shouldBuild == 'true' }} | ||||
|         uses: docker/login-action@v3 | ||||
|         with: | ||||
|           registry: ghcr.io | ||||
| @ -252,7 +247,7 @@ jobs: | ||||
|             GEOIPUPDATE_ACCOUNT_ID=${{ secrets.GEOIPUPDATE_ACCOUNT_ID }} | ||||
|             GEOIPUPDATE_LICENSE_KEY=${{ secrets.GEOIPUPDATE_LICENSE_KEY }} | ||||
|           tags: ${{ steps.ev.outputs.imageTags }} | ||||
|           push: ${{ steps.ev.outputs.shouldBuild == 'true' }} | ||||
|           push: true | ||||
|           build-args: | | ||||
|             GIT_BUILD_HASH=${{ steps.ev.outputs.sha }} | ||||
|           cache-from: type=gha | ||||
| @ -274,8 +269,6 @@ jobs: | ||||
|       - name: prepare variables | ||||
|         uses: ./.github/actions/docker-push-variables | ||||
|         id: ev | ||||
|         env: | ||||
|           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} | ||||
|         with: | ||||
|           image-name: ghcr.io/goauthentik/dev-server | ||||
|       - name: Comment on PR | ||||
|  | ||||
							
								
								
									
										6
									
								
								.github/workflows/ci-outpost.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										6
									
								
								.github/workflows/ci-outpost.yml
									
									
									
									
										vendored
									
									
								
							| @ -71,6 +71,7 @@ jobs: | ||||
|     permissions: | ||||
|       # Needed to upload contianer images to ghcr.io | ||||
|       packages: write | ||||
|     if: "github.repository == 'goauthentik/authentik'" | ||||
|     steps: | ||||
|       - uses: actions/checkout@v4 | ||||
|         with: | ||||
| @ -82,12 +83,9 @@ jobs: | ||||
|       - name: prepare variables | ||||
|         uses: ./.github/actions/docker-push-variables | ||||
|         id: ev | ||||
|         env: | ||||
|           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} | ||||
|         with: | ||||
|           image-name: ghcr.io/goauthentik/dev-${{ matrix.type }} | ||||
|       - name: Login to Container Registry | ||||
|         if: ${{ steps.ev.outputs.shouldBuild == 'true' }} | ||||
|         uses: docker/login-action@v3 | ||||
|         with: | ||||
|           registry: ghcr.io | ||||
| @ -100,7 +98,7 @@ jobs: | ||||
|         with: | ||||
|           tags: ${{ steps.ev.outputs.imageTags }} | ||||
|           file: ${{ matrix.type }}.Dockerfile | ||||
|           push: ${{ steps.ev.outputs.shouldBuild == 'true' }} | ||||
|           push: true | ||||
|           build-args: | | ||||
|             GIT_BUILD_HASH=${{ steps.ev.outputs.sha }} | ||||
|           platforms: linux/amd64,linux/arm64 | ||||
|  | ||||
							
								
								
									
										6
									
								
								.github/workflows/release-publish.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										6
									
								
								.github/workflows/release-publish.yml
									
									
									
									
										vendored
									
									
								
							| @ -20,8 +20,6 @@ jobs: | ||||
|       - name: prepare variables | ||||
|         uses: ./.github/actions/docker-push-variables | ||||
|         id: ev | ||||
|         env: | ||||
|           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} | ||||
|         with: | ||||
|           image-name: ghcr.io/goauthentik/server,beryju/authentik | ||||
|       - name: Docker Login Registry | ||||
| @ -74,8 +72,6 @@ jobs: | ||||
|       - name: prepare variables | ||||
|         uses: ./.github/actions/docker-push-variables | ||||
|         id: ev | ||||
|         env: | ||||
|           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} | ||||
|         with: | ||||
|           image-name: ghcr.io/goauthentik/${{ matrix.type }},beryju/authentik-${{ matrix.type }} | ||||
|       - name: make empty clients | ||||
| @ -172,8 +168,6 @@ jobs: | ||||
|       - name: prepare variables | ||||
|         uses: ./.github/actions/docker-push-variables | ||||
|         id: ev | ||||
|         env: | ||||
|           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} | ||||
|         with: | ||||
|           image-name: ghcr.io/goauthentik/server | ||||
|       - name: Get static files from docker image | ||||
|  | ||||
							
								
								
									
										2
									
								
								.github/workflows/release-tag.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/release-tag.yml
									
									
									
									
										vendored
									
									
								
							| @ -32,8 +32,6 @@ jobs: | ||||
|       - name: prepare variables | ||||
|         uses: ./.github/actions/docker-push-variables | ||||
|         id: ev | ||||
|         env: | ||||
|           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} | ||||
|         with: | ||||
|           image-name: ghcr.io/goauthentik/server | ||||
|       - name: Create Release | ||||
|  | ||||
							
								
								
									
										3
									
								
								.vscode/extensions.json
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.vscode/extensions.json
									
									
									
									
										vendored
									
									
								
							| @ -10,8 +10,7 @@ | ||||
|         "Gruntfuggly.todo-tree", | ||||
|         "mechatroner.rainbow-csv", | ||||
|         "ms-python.black-formatter", | ||||
|         "ms-python.isort", | ||||
|         "ms-python.pylint", | ||||
|         "charliermarsh.ruff", | ||||
|         "ms-python.python", | ||||
|         "ms-python.vscode-pylance", | ||||
|         "ms-python.black-formatter", | ||||
|  | ||||
| @ -103,10 +103,9 @@ RUN --mount=type=bind,target=./pyproject.toml,src=./pyproject.toml \ | ||||
|     --mount=type=cache,target=/root/.cache/pip \ | ||||
|     --mount=type=cache,target=/root/.cache/pypoetry \ | ||||
|     python -m venv /ak-root/venv/ && \ | ||||
|     bash -c "source ${VENV_PATH}/bin/activate && \ | ||||
|     pip3 install --upgrade pip && \ | ||||
|     pip3 install poetry && \ | ||||
|         poetry install --only=main --no-ansi --no-interaction --no-root" | ||||
|     poetry install --only=main --no-ansi --no-interaction | ||||
|  | ||||
| # Stage 6: Run | ||||
| FROM docker.io/python:3.12.2-slim-bookworm AS final-image | ||||
|  | ||||
							
								
								
									
										14
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								Makefile
									
									
									
									
									
								
							| @ -59,15 +59,12 @@ test: ## Run the server tests and produce a coverage report (locally) | ||||
| 	coverage report | ||||
|  | ||||
| lint-fix:  ## Lint and automatically fix errors in the python source code. Reports spelling errors. | ||||
| 	isort $(PY_SOURCES) | ||||
| 	black $(PY_SOURCES) | ||||
| 	ruff --fix $(PY_SOURCES) | ||||
| 	ruff check --fix $(PY_SOURCES) | ||||
| 	codespell -w $(CODESPELL_ARGS) | ||||
|  | ||||
| lint: ## Lint the python and golang sources | ||||
| 	bandit -r $(PY_SOURCES) -x node_modules | ||||
| 	./web/node_modules/.bin/pyright $(PY_SOURCES) | ||||
| 	pylint $(PY_SOURCES) | ||||
| 	golangci-lint run -v | ||||
|  | ||||
| core-install: | ||||
| @ -249,9 +246,6 @@ ci--meta-debug: | ||||
| 	python -V | ||||
| 	node --version | ||||
|  | ||||
| ci-pylint: ci--meta-debug | ||||
| 	pylint $(PY_SOURCES) | ||||
|  | ||||
| ci-black: ci--meta-debug | ||||
| 	black --check $(PY_SOURCES) | ||||
|  | ||||
| @ -261,14 +255,8 @@ ci-ruff: ci--meta-debug | ||||
| ci-codespell: ci--meta-debug | ||||
| 	codespell $(CODESPELL_ARGS) -s | ||||
|  | ||||
| ci-isort: ci--meta-debug | ||||
| 	isort --check $(PY_SOURCES) | ||||
|  | ||||
| ci-bandit: ci--meta-debug | ||||
| 	bandit -r $(PY_SOURCES) | ||||
|  | ||||
| ci-pyright: ci--meta-debug | ||||
| 	./web/node_modules/.bin/pyright $(PY_SOURCES) | ||||
|  | ||||
| ci-pending-migrations: ci--meta-debug | ||||
| 	ak makemigrations --check | ||||
|  | ||||
| @ -1,13 +1,12 @@ | ||||
| """authentik root module""" | ||||
|  | ||||
| from os import environ | ||||
| from typing import Optional | ||||
|  | ||||
| __version__ = "2024.2.4" | ||||
| __version__ = "2024.2.1" | ||||
| ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" | ||||
|  | ||||
|  | ||||
| def get_build_hash(fallback: Optional[str] = None) -> str: | ||||
| def get_build_hash(fallback: str | None = None) -> str: | ||||
|     """Get build hash""" | ||||
|     build_hash = environ.get(ENV_GIT_HASH_KEY, fallback if fallback else "") | ||||
|     return fallback if build_hash == "" and fallback else build_hash | ||||
|  | ||||
| @ -18,7 +18,7 @@ class AuthentikAPIConfig(AppConfig): | ||||
|  | ||||
|         # Class is defined here as it needs to be created early enough that drf-spectacular will | ||||
|         # find it, but also won't cause any import issues | ||||
|         # pylint: disable=unused-variable | ||||
|  | ||||
|         class TokenSchema(OpenApiAuthenticationExtension): | ||||
|             """Auth schema""" | ||||
|  | ||||
|  | ||||
| @ -1,7 +1,7 @@ | ||||
| """API Authentication""" | ||||
|  | ||||
| from hmac import compare_digest | ||||
| from typing import Any, Optional | ||||
| from typing import Any | ||||
|  | ||||
| from django.conf import settings | ||||
| from rest_framework.authentication import BaseAuthentication, get_authorization_header | ||||
| @ -17,7 +17,7 @@ from authentik.providers.oauth2.constants import SCOPE_AUTHENTIK_API | ||||
| LOGGER = get_logger() | ||||
|  | ||||
|  | ||||
| def validate_auth(header: bytes) -> Optional[str]: | ||||
| def validate_auth(header: bytes) -> str | None: | ||||
|     """Validate that the header is in a correct format, | ||||
|     returns type and credentials""" | ||||
|     auth_credentials = header.decode().strip() | ||||
| @ -32,7 +32,7 @@ def validate_auth(header: bytes) -> Optional[str]: | ||||
|     return auth_credentials | ||||
|  | ||||
|  | ||||
| def bearer_auth(raw_header: bytes) -> Optional[User]: | ||||
| def bearer_auth(raw_header: bytes) -> User | None: | ||||
|     """raw_header in the Format of `Bearer ....`""" | ||||
|     user = auth_user_lookup(raw_header) | ||||
|     if not user: | ||||
| @ -42,7 +42,7 @@ def bearer_auth(raw_header: bytes) -> Optional[User]: | ||||
|     return user | ||||
|  | ||||
|  | ||||
| def auth_user_lookup(raw_header: bytes) -> Optional[User]: | ||||
| def auth_user_lookup(raw_header: bytes) -> User | None: | ||||
|     """raw_header in the Format of `Bearer ....`""" | ||||
|     from authentik.providers.oauth2.models import AccessToken | ||||
|  | ||||
| @ -75,7 +75,7 @@ def auth_user_lookup(raw_header: bytes) -> Optional[User]: | ||||
|     raise AuthenticationFailed("Token invalid/expired") | ||||
|  | ||||
|  | ||||
| def token_secret_key(value: str) -> Optional[User]: | ||||
| def token_secret_key(value: str) -> User | None: | ||||
|     """Check if the token is the secret key | ||||
|     and return the service account for the managed outpost""" | ||||
|     from authentik.outposts.apps import MANAGED_OUTPOST | ||||
|  | ||||
| @ -25,17 +25,17 @@ class TestAPIAuth(TestCase): | ||||
|     def test_invalid_type(self): | ||||
|         """Test invalid type""" | ||||
|         with self.assertRaises(AuthenticationFailed): | ||||
|             bearer_auth("foo bar".encode()) | ||||
|             bearer_auth(b"foo bar") | ||||
|  | ||||
|     def test_invalid_empty(self): | ||||
|         """Test invalid type""" | ||||
|         self.assertIsNone(bearer_auth("Bearer ".encode())) | ||||
|         self.assertIsNone(bearer_auth("".encode())) | ||||
|         self.assertIsNone(bearer_auth(b"Bearer ")) | ||||
|         self.assertIsNone(bearer_auth(b"")) | ||||
|  | ||||
|     def test_invalid_no_token(self): | ||||
|         """Test invalid with no token""" | ||||
|         with self.assertRaises(AuthenticationFailed): | ||||
|             auth = b64encode(":abc".encode()).decode() | ||||
|             auth = b64encode(b":abc").decode() | ||||
|             self.assertIsNone(bearer_auth(f"Basic :{auth}".encode())) | ||||
|  | ||||
|     def test_bearer_valid(self): | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| """authentik API Modelviewset tests""" | ||||
|  | ||||
| from typing import Callable | ||||
| from collections.abc import Callable | ||||
|  | ||||
| from django.test import TestCase | ||||
| from rest_framework.viewsets import ModelViewSet, ReadOnlyModelViewSet | ||||
| @ -26,6 +26,6 @@ def viewset_tester_factory(test_viewset: type[ModelViewSet]) -> Callable: | ||||
|  | ||||
|  | ||||
| for _, viewset, _ in router.registry: | ||||
|     if not issubclass(viewset, (ModelViewSet, ReadOnlyModelViewSet)): | ||||
|     if not issubclass(viewset, ModelViewSet | ReadOnlyModelViewSet): | ||||
|         continue | ||||
|     setattr(TestModelViewSets, f"test_viewset_{viewset.__name__}", viewset_tester_factory(viewset)) | ||||
|  | ||||
| @ -68,11 +68,7 @@ class ConfigView(APIView): | ||||
|         """Get all capabilities this server instance supports""" | ||||
|         caps = [] | ||||
|         deb_test = settings.DEBUG or settings.TEST | ||||
|         if ( | ||||
|             CONFIG.get("storage.media.backend", "file") == "s3" | ||||
|             or Path(settings.STORAGES["default"]["OPTIONS"]["location"]).is_mount() | ||||
|             or deb_test | ||||
|         ): | ||||
|         if Path(settings.MEDIA_ROOT).is_mount() or deb_test: | ||||
|             caps.append(Capabilities.CAN_SAVE_MEDIA) | ||||
|         for processor in get_context_processors(): | ||||
|             if cap := processor.capability(): | ||||
|  | ||||
| @ -33,7 +33,7 @@ for _authentik_app in get_apps(): | ||||
|             app_name=_authentik_app.name, | ||||
|         ) | ||||
|         continue | ||||
|     urls: list = getattr(api_urls, "api_urlpatterns") | ||||
|     urls: list = api_urls.api_urlpatterns | ||||
|     for url in urls: | ||||
|         if isinstance(url, URLPattern): | ||||
|             _other_urls.append(url) | ||||
|  | ||||
| @ -52,7 +52,9 @@ class BlueprintInstanceSerializer(ModelSerializer): | ||||
|         valid, logs = Importer.from_string(content, context).validate() | ||||
|         if not valid: | ||||
|             text_logs = "\n".join([x["event"] for x in logs]) | ||||
|             raise ValidationError(_("Failed to validate blueprint: %(logs)s" % {"logs": text_logs})) | ||||
|             raise ValidationError( | ||||
|                 _("Failed to validate blueprint: {logs}".format_map({"logs": text_logs})) | ||||
|             ) | ||||
|         return content | ||||
|  | ||||
|     def validate(self, attrs: dict) -> dict: | ||||
|  | ||||
| @ -1,5 +1,6 @@ | ||||
| """authentik Blueprints app""" | ||||
|  | ||||
| from collections.abc import Callable | ||||
| from importlib import import_module | ||||
| from inspect import ismethod | ||||
|  | ||||
| @ -13,8 +14,8 @@ class ManagedAppConfig(AppConfig): | ||||
|  | ||||
|     logger: BoundLogger | ||||
|  | ||||
|     RECONCILE_GLOBAL_PREFIX: str = "reconcile_global_" | ||||
|     RECONCILE_TENANT_PREFIX: str = "reconcile_tenant_" | ||||
|     RECONCILE_GLOBAL_CATEGORY: str = "global" | ||||
|     RECONCILE_TENANT_CATEGORY: str = "tenant" | ||||
|  | ||||
|     def __init__(self, app_name: str, *args, **kwargs) -> None: | ||||
|         super().__init__(app_name, *args, **kwargs) | ||||
| @ -22,8 +23,8 @@ class ManagedAppConfig(AppConfig): | ||||
|  | ||||
|     def ready(self) -> None: | ||||
|         self.import_related() | ||||
|         self.reconcile_global() | ||||
|         self.reconcile_tenant() | ||||
|         self._reconcile_global() | ||||
|         self._reconcile_tenant() | ||||
|         return super().ready() | ||||
|  | ||||
|     def import_related(self): | ||||
| @ -51,7 +52,8 @@ class ManagedAppConfig(AppConfig): | ||||
|             meth = getattr(self, meth_name) | ||||
|             if not ismethod(meth): | ||||
|                 continue | ||||
|             if not meth_name.startswith(prefix): | ||||
|             category = getattr(meth, "_authentik_managed_reconcile", None) | ||||
|             if category != prefix: | ||||
|                 continue | ||||
|             name = meth_name.replace(prefix, "") | ||||
|             try: | ||||
| @ -61,7 +63,19 @@ class ManagedAppConfig(AppConfig): | ||||
|             except (DatabaseError, ProgrammingError, InternalError) as exc: | ||||
|                 self.logger.warning("Failed to run reconcile", name=name, exc=exc) | ||||
|  | ||||
|     def reconcile_tenant(self) -> None: | ||||
|     @staticmethod | ||||
|     def reconcile_tenant(func: Callable): | ||||
|         """Mark a function to be called on startup (for each tenant)""" | ||||
|         func._authentik_managed_reconcile = ManagedAppConfig.RECONCILE_TENANT_CATEGORY | ||||
|         return func | ||||
|  | ||||
|     @staticmethod | ||||
|     def reconcile_global(func: Callable): | ||||
|         """Mark a function to be called on startup (globally)""" | ||||
|         func._authentik_managed_reconcile = ManagedAppConfig.RECONCILE_GLOBAL_CATEGORY | ||||
|         return func | ||||
|  | ||||
|     def _reconcile_tenant(self) -> None: | ||||
|         """reconcile ourselves for tenanted methods""" | ||||
|         from authentik.tenants.models import Tenant | ||||
|  | ||||
| @ -72,9 +86,9 @@ class ManagedAppConfig(AppConfig): | ||||
|             return | ||||
|         for tenant in tenants: | ||||
|             with tenant: | ||||
|                 self._reconcile(self.RECONCILE_TENANT_PREFIX) | ||||
|                 self._reconcile(self.RECONCILE_TENANT_CATEGORY) | ||||
|  | ||||
|     def reconcile_global(self) -> None: | ||||
|     def _reconcile_global(self) -> None: | ||||
|         """ | ||||
|         reconcile ourselves for global methods. | ||||
|         Used for signals, tasks, etc. Database queries should not be made in here. | ||||
| @ -82,7 +96,7 @@ class ManagedAppConfig(AppConfig): | ||||
|         from django_tenants.utils import get_public_schema_name, schema_context | ||||
|  | ||||
|         with schema_context(get_public_schema_name()): | ||||
|             self._reconcile(self.RECONCILE_GLOBAL_PREFIX) | ||||
|             self._reconcile(self.RECONCILE_GLOBAL_CATEGORY) | ||||
|  | ||||
|  | ||||
| class AuthentikBlueprintsConfig(ManagedAppConfig): | ||||
| @ -93,11 +107,13 @@ class AuthentikBlueprintsConfig(ManagedAppConfig): | ||||
|     verbose_name = "authentik Blueprints" | ||||
|     default = True | ||||
|  | ||||
|     def reconcile_global_load_blueprints_v1_tasks(self): | ||||
|     @ManagedAppConfig.reconcile_global | ||||
|     def load_blueprints_v1_tasks(self): | ||||
|         """Load v1 tasks""" | ||||
|         self.import_module("authentik.blueprints.v1.tasks") | ||||
|  | ||||
|     def reconcile_tenant_blueprints_discovery(self): | ||||
|     @ManagedAppConfig.reconcile_tenant | ||||
|     def blueprints_discovery(self): | ||||
|         """Run blueprint discovery""" | ||||
|         from authentik.blueprints.v1.tasks import blueprints_discovery, clear_failed_blueprints | ||||
|  | ||||
|  | ||||
| @ -71,6 +71,19 @@ class BlueprintInstance(SerializerModel, ManagedModel, CreatedUpdatedModel): | ||||
|     enabled = models.BooleanField(default=True) | ||||
|     managed_models = ArrayField(models.TextField(), default=list) | ||||
|  | ||||
|     class Meta: | ||||
|         verbose_name = _("Blueprint Instance") | ||||
|         verbose_name_plural = _("Blueprint Instances") | ||||
|         unique_together = ( | ||||
|             ( | ||||
|                 "name", | ||||
|                 "path", | ||||
|             ), | ||||
|         ) | ||||
|  | ||||
|     def __str__(self) -> str: | ||||
|         return f"Blueprint Instance {self.name}" | ||||
|  | ||||
|     def retrieve_oci(self) -> str: | ||||
|         """Get blueprint from an OCI registry""" | ||||
|         client = BlueprintOCIClient(self.path.replace(OCI_PREFIX, "https://")) | ||||
| @ -89,7 +102,7 @@ class BlueprintInstance(SerializerModel, ManagedModel, CreatedUpdatedModel): | ||||
|                 raise BlueprintRetrievalFailed("Invalid blueprint path") | ||||
|             with full_path.open("r", encoding="utf-8") as _file: | ||||
|                 return _file.read() | ||||
|         except (IOError, OSError) as exc: | ||||
|         except OSError as exc: | ||||
|             raise BlueprintRetrievalFailed(exc) from exc | ||||
|  | ||||
|     def retrieve(self) -> str: | ||||
| @ -105,16 +118,3 @@ class BlueprintInstance(SerializerModel, ManagedModel, CreatedUpdatedModel): | ||||
|         from authentik.blueprints.api import BlueprintInstanceSerializer | ||||
|  | ||||
|         return BlueprintInstanceSerializer | ||||
|  | ||||
|     def __str__(self) -> str: | ||||
|         return f"Blueprint Instance {self.name}" | ||||
|  | ||||
|     class Meta: | ||||
|         verbose_name = _("Blueprint Instance") | ||||
|         verbose_name_plural = _("Blueprint Instances") | ||||
|         unique_together = ( | ||||
|             ( | ||||
|                 "name", | ||||
|                 "path", | ||||
|             ), | ||||
|         ) | ||||
|  | ||||
| @ -1,7 +1,7 @@ | ||||
| """Blueprint helpers""" | ||||
|  | ||||
| from collections.abc import Callable | ||||
| from functools import wraps | ||||
| from typing import Callable | ||||
|  | ||||
| from django.apps import apps | ||||
|  | ||||
|  | ||||
| @ -1,7 +1,7 @@ | ||||
| """test packaged blueprints""" | ||||
|  | ||||
| from collections.abc import Callable | ||||
| from pathlib import Path | ||||
| from typing import Callable | ||||
|  | ||||
| from django.test import TransactionTestCase | ||||
|  | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| """authentik managed models tests""" | ||||
|  | ||||
| from typing import Callable, Type | ||||
| from collections.abc import Callable | ||||
|  | ||||
| from django.apps import apps | ||||
| from django.test import TestCase | ||||
| @ -14,7 +14,7 @@ class TestModels(TestCase): | ||||
|     """Test Models""" | ||||
|  | ||||
|  | ||||
| def serializer_tester_factory(test_model: Type[SerializerModel]) -> Callable: | ||||
| def serializer_tester_factory(test_model: type[SerializerModel]) -> Callable: | ||||
|     """Test serializer""" | ||||
|  | ||||
|     def tester(self: TestModels): | ||||
|  | ||||
| @ -54,7 +54,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase): | ||||
|             file.seek(0) | ||||
|             file_hash = sha512(file.read().encode()).hexdigest() | ||||
|             file.flush() | ||||
|             blueprints_discovery()  # pylint: disable=no-value-for-parameter | ||||
|             blueprints_discovery() | ||||
|             instance = BlueprintInstance.objects.filter(name=blueprint_id).first() | ||||
|             self.assertEqual(instance.last_applied_hash, file_hash) | ||||
|             self.assertEqual( | ||||
| @ -82,7 +82,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase): | ||||
|                 ) | ||||
|             ) | ||||
|             file.flush() | ||||
|             blueprints_discovery()  # pylint: disable=no-value-for-parameter | ||||
|             blueprints_discovery() | ||||
|             blueprint = BlueprintInstance.objects.filter(name="foo").first() | ||||
|             self.assertEqual( | ||||
|                 blueprint.last_applied_hash, | ||||
| @ -107,7 +107,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase): | ||||
|                 ) | ||||
|             ) | ||||
|             file.flush() | ||||
|             blueprints_discovery()  # pylint: disable=no-value-for-parameter | ||||
|             blueprints_discovery() | ||||
|             blueprint.refresh_from_db() | ||||
|             self.assertEqual( | ||||
|                 blueprint.last_applied_hash, | ||||
| @ -149,7 +149,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase): | ||||
|                 instance.status, | ||||
|                 BlueprintInstanceStatus.UNKNOWN, | ||||
|             ) | ||||
|             apply_blueprint(instance.pk)  # pylint: disable=no-value-for-parameter | ||||
|             apply_blueprint(instance.pk) | ||||
|             instance.refresh_from_db() | ||||
|             self.assertEqual(instance.last_applied_hash, "") | ||||
|             self.assertEqual( | ||||
|  | ||||
| @ -1,13 +1,14 @@ | ||||
| """transfer common classes""" | ||||
|  | ||||
| from collections import OrderedDict | ||||
| from collections.abc import Iterable, Mapping | ||||
| from copy import copy | ||||
| from dataclasses import asdict, dataclass, field, is_dataclass | ||||
| from enum import Enum | ||||
| from functools import reduce | ||||
| from operator import ixor | ||||
| from os import getenv | ||||
| from typing import Any, Iterable, Literal, Mapping, Optional, Union | ||||
| from typing import Any, Literal, Union | ||||
| from uuid import UUID | ||||
|  | ||||
| from deepmerge import always_merger | ||||
| @ -45,7 +46,7 @@ def get_attrs(obj: SerializerModel) -> dict[str, Any]: | ||||
| class BlueprintEntryState: | ||||
|     """State of a single instance""" | ||||
|  | ||||
|     instance: Optional[Model] = None | ||||
|     instance: Model | None = None | ||||
|  | ||||
|  | ||||
| class BlueprintEntryDesiredState(Enum): | ||||
| @ -67,9 +68,9 @@ class BlueprintEntry: | ||||
|     ) | ||||
|     conditions: list[Any] = field(default_factory=list) | ||||
|     identifiers: dict[str, Any] = field(default_factory=dict) | ||||
|     attrs: Optional[dict[str, Any]] = field(default_factory=dict) | ||||
|     attrs: dict[str, Any] | None = field(default_factory=dict) | ||||
|  | ||||
|     id: Optional[str] = None | ||||
|     id: str | None = None | ||||
|  | ||||
|     _state: BlueprintEntryState = field(default_factory=BlueprintEntryState) | ||||
|  | ||||
| @ -92,10 +93,10 @@ class BlueprintEntry: | ||||
|             attrs=all_attrs, | ||||
|         ) | ||||
|  | ||||
|     def _get_tag_context( | ||||
|     def get_tag_context( | ||||
|         self, | ||||
|         depth: int = 0, | ||||
|         context_tag_type: Optional[type["YAMLTagContext"] | tuple["YAMLTagContext", ...]] = None, | ||||
|         context_tag_type: type["YAMLTagContext"] | tuple["YAMLTagContext", ...] | None = None, | ||||
|     ) -> "YAMLTagContext": | ||||
|         """Get a YAMLTagContext object located at a certain depth in the tag tree""" | ||||
|         if depth < 0: | ||||
| @ -108,8 +109,8 @@ class BlueprintEntry: | ||||
|  | ||||
|         try: | ||||
|             return contexts[-(depth + 1)] | ||||
|         except IndexError: | ||||
|             raise ValueError(f"invalid depth: {depth}. Max depth: {len(contexts) - 1}") | ||||
|         except IndexError as exc: | ||||
|             raise ValueError(f"invalid depth: {depth}. Max depth: {len(contexts) - 1}") from exc | ||||
|  | ||||
|     def tag_resolver(self, value: Any, blueprint: "Blueprint") -> Any: | ||||
|         """Check if we have any special tags that need handling""" | ||||
| @ -170,7 +171,7 @@ class Blueprint: | ||||
|     entries: list[BlueprintEntry] = field(default_factory=list) | ||||
|     context: dict = field(default_factory=dict) | ||||
|  | ||||
|     metadata: Optional[BlueprintMetadata] = field(default=None) | ||||
|     metadata: BlueprintMetadata | None = field(default=None) | ||||
|  | ||||
|  | ||||
| class YAMLTag: | ||||
| @ -218,7 +219,7 @@ class Env(YAMLTag): | ||||
|     """Lookup environment variable with optional default""" | ||||
|  | ||||
|     key: str | ||||
|     default: Optional[Any] | ||||
|     default: Any | None | ||||
|  | ||||
|     def __init__(self, loader: "BlueprintLoader", node: ScalarNode | SequenceNode) -> None: | ||||
|         super().__init__() | ||||
| @ -237,7 +238,7 @@ class Context(YAMLTag): | ||||
|     """Lookup key from instance context""" | ||||
|  | ||||
|     key: str | ||||
|     default: Optional[Any] | ||||
|     default: Any | None | ||||
|  | ||||
|     def __init__(self, loader: "BlueprintLoader", node: ScalarNode | SequenceNode) -> None: | ||||
|         super().__init__() | ||||
| @ -281,7 +282,7 @@ class Format(YAMLTag): | ||||
|         try: | ||||
|             return self.format_string % tuple(args) | ||||
|         except TypeError as exc: | ||||
|             raise EntryInvalidError.from_entry(exc, entry) | ||||
|             raise EntryInvalidError.from_entry(exc, entry) from exc | ||||
|  | ||||
|  | ||||
| class Find(YAMLTag): | ||||
| @ -366,7 +367,7 @@ class Condition(YAMLTag): | ||||
|             comparator = self._COMPARATORS[self.mode.upper()] | ||||
|             return comparator(tuple(bool(x) for x in args)) | ||||
|         except (TypeError, KeyError) as exc: | ||||
|             raise EntryInvalidError.from_entry(exc, entry) | ||||
|             raise EntryInvalidError.from_entry(exc, entry) from exc | ||||
|  | ||||
|  | ||||
| class If(YAMLTag): | ||||
| @ -398,7 +399,7 @@ class If(YAMLTag): | ||||
|                 blueprint, | ||||
|             ) | ||||
|         except TypeError as exc: | ||||
|             raise EntryInvalidError.from_entry(exc, entry) | ||||
|             raise EntryInvalidError.from_entry(exc, entry) from exc | ||||
|  | ||||
|  | ||||
| class Enumerate(YAMLTag, YAMLTagContext): | ||||
| @ -412,9 +413,7 @@ class Enumerate(YAMLTag, YAMLTagContext): | ||||
|         "SEQ": (list, lambda a, b: [*a, b]), | ||||
|         "MAP": ( | ||||
|             dict, | ||||
|             lambda a, b: always_merger.merge( | ||||
|                 a, {b[0]: b[1]} if isinstance(b, (tuple, list)) else b | ||||
|             ), | ||||
|             lambda a, b: always_merger.merge(a, {b[0]: b[1]} if isinstance(b, tuple | list) else b), | ||||
|         ), | ||||
|     } | ||||
|  | ||||
| @ -456,7 +455,7 @@ class Enumerate(YAMLTag, YAMLTagContext): | ||||
|         try: | ||||
|             output_class, add_fn = self._OUTPUT_BODIES[self.output_body.upper()] | ||||
|         except KeyError as exc: | ||||
|             raise EntryInvalidError.from_entry(exc, entry) | ||||
|             raise EntryInvalidError.from_entry(exc, entry) from exc | ||||
|  | ||||
|         result = output_class() | ||||
|  | ||||
| @ -484,13 +483,13 @@ class EnumeratedItem(YAMLTag): | ||||
|  | ||||
|     _SUPPORTED_CONTEXT_TAGS = (Enumerate,) | ||||
|  | ||||
|     def __init__(self, loader: "BlueprintLoader", node: ScalarNode) -> None: | ||||
|     def __init__(self, _loader: "BlueprintLoader", node: ScalarNode) -> None: | ||||
|         super().__init__() | ||||
|         self.depth = int(node.value) | ||||
|  | ||||
|     def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any: | ||||
|         try: | ||||
|             context_tag: Enumerate = entry._get_tag_context( | ||||
|             context_tag: Enumerate = entry.get_tag_context( | ||||
|                 depth=self.depth, | ||||
|                 context_tag_type=EnumeratedItem._SUPPORTED_CONTEXT_TAGS, | ||||
|             ) | ||||
| @ -500,9 +499,11 @@ class EnumeratedItem(YAMLTag): | ||||
|                     f"{self.__class__.__name__} tags are only usable " | ||||
|                     f"inside an {Enumerate.__name__} tag", | ||||
|                     entry, | ||||
|                 ) | ||||
|                 ) from exc | ||||
|  | ||||
|             raise EntryInvalidError.from_entry(f"{self.__class__.__name__} tag: {exc}", entry) | ||||
|             raise EntryInvalidError.from_entry( | ||||
|                 f"{self.__class__.__name__} tag: {exc}", entry | ||||
|             ) from exc | ||||
|  | ||||
|         return context_tag.get_context(entry, blueprint) | ||||
|  | ||||
| @ -515,8 +516,8 @@ class Index(EnumeratedItem): | ||||
|  | ||||
|         try: | ||||
|             return context[0] | ||||
|         except IndexError:  # pragma: no cover | ||||
|             raise EntryInvalidError.from_entry(f"Empty/invalid context: {context}", entry) | ||||
|         except IndexError as exc:  # pragma: no cover | ||||
|             raise EntryInvalidError.from_entry(f"Empty/invalid context: {context}", entry) from exc | ||||
|  | ||||
|  | ||||
| class Value(EnumeratedItem): | ||||
| @ -527,8 +528,8 @@ class Value(EnumeratedItem): | ||||
|  | ||||
|         try: | ||||
|             return context[1] | ||||
|         except IndexError:  # pragma: no cover | ||||
|             raise EntryInvalidError.from_entry(f"Empty/invalid context: {context}", entry) | ||||
|         except IndexError as exc:  # pragma: no cover | ||||
|             raise EntryInvalidError.from_entry(f"Empty/invalid context: {context}", entry) from exc | ||||
|  | ||||
|  | ||||
| class BlueprintDumper(SafeDumper): | ||||
| @ -582,13 +583,13 @@ class BlueprintLoader(SafeLoader): | ||||
| class EntryInvalidError(SentryIgnoredException): | ||||
|     """Error raised when an entry is invalid""" | ||||
|  | ||||
|     entry_model: Optional[str] | ||||
|     entry_id: Optional[str] | ||||
|     validation_error: Optional[ValidationError] | ||||
|     serializer: Optional[Serializer] = None | ||||
|     entry_model: str | None | ||||
|     entry_id: str | None | ||||
|     validation_error: ValidationError | None | ||||
|     serializer: Serializer | None = None | ||||
|  | ||||
|     def __init__( | ||||
|         self, *args: object, validation_error: Optional[ValidationError] = None, **kwargs | ||||
|         self, *args: object, validation_error: ValidationError | None = None, **kwargs | ||||
|     ) -> None: | ||||
|         super().__init__(*args) | ||||
|         self.entry_model = None | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| """Blueprint exporter""" | ||||
|  | ||||
| from typing import Iterable | ||||
| from collections.abc import Iterable | ||||
| from uuid import UUID | ||||
|  | ||||
| from django.apps import apps | ||||
| @ -59,7 +59,7 @@ class Exporter: | ||||
|         blueprint = Blueprint() | ||||
|         self._pre_export(blueprint) | ||||
|         blueprint.metadata = BlueprintMetadata( | ||||
|             name=_("authentik Export - %(date)s" % {"date": str(now())}), | ||||
|             name=_("authentik Export - {date}".format_map({"date": str(now())})), | ||||
|             labels={ | ||||
|                 LABEL_AUTHENTIK_GENERATED: "true", | ||||
|             }, | ||||
|  | ||||
| @ -2,7 +2,7 @@ | ||||
|  | ||||
| from contextlib import contextmanager | ||||
| from copy import deepcopy | ||||
| from typing import Any, Optional | ||||
| from typing import Any | ||||
|  | ||||
| from dacite.config import Config | ||||
| from dacite.core import from_dict | ||||
| @ -62,7 +62,7 @@ SERIALIZER_CONTEXT_BLUEPRINT = "blueprint_entry" | ||||
| def excluded_models() -> list[type[Model]]: | ||||
|     """Return a list of all excluded models that shouldn't be exposed via API | ||||
|     or other means (internal only, base classes, non-used objects, etc)""" | ||||
|     # pylint: disable=imported-auth-user | ||||
|  | ||||
|     from django.contrib.auth.models import Group as DjangoGroup | ||||
|     from django.contrib.auth.models import User as DjangoUser | ||||
|  | ||||
| @ -101,7 +101,7 @@ def excluded_models() -> list[type[Model]]: | ||||
|  | ||||
| def is_model_allowed(model: type[Model]) -> bool: | ||||
|     """Check if model is allowed""" | ||||
|     return model not in excluded_models() and issubclass(model, (SerializerModel, BaseMetaModel)) | ||||
|     return model not in excluded_models() and issubclass(model, SerializerModel | BaseMetaModel) | ||||
|  | ||||
|  | ||||
| class DoRollback(SentryIgnoredException): | ||||
| @ -125,7 +125,7 @@ class Importer: | ||||
|     logger: BoundLogger | ||||
|     _import: Blueprint | ||||
|  | ||||
|     def __init__(self, blueprint: Blueprint, context: Optional[dict] = None): | ||||
|     def __init__(self, blueprint: Blueprint, context: dict | None = None): | ||||
|         self.__pk_map: dict[Any, Model] = {} | ||||
|         self._import = blueprint | ||||
|         self.logger = get_logger() | ||||
| @ -168,7 +168,7 @@ class Importer: | ||||
|         for key, value in attrs.items(): | ||||
|             try: | ||||
|                 if isinstance(value, dict): | ||||
|                     for idx, _inner_key in enumerate(value): | ||||
|                     for _, _inner_key in enumerate(value): | ||||
|                         value[_inner_key] = updater(value[_inner_key]) | ||||
|                 elif isinstance(value, list): | ||||
|                     for idx, _inner_value in enumerate(value): | ||||
| @ -197,8 +197,7 @@ class Importer: | ||||
|  | ||||
|         return main_query | sub_query | ||||
|  | ||||
|     # pylint: disable-msg=too-many-locals | ||||
|     def _validate_single(self, entry: BlueprintEntry) -> Optional[BaseSerializer]: | ||||
|     def _validate_single(self, entry: BlueprintEntry) -> BaseSerializer | None: | ||||
|         """Validate a single entry""" | ||||
|         if not entry.check_all_conditions_match(self._import): | ||||
|             self.logger.debug("One or more conditions of this entry are not fulfilled, skipping") | ||||
| @ -369,7 +368,7 @@ class Importer: | ||||
|                     self.__pk_map[entry.identifiers["pk"]] = instance.pk | ||||
|                 entry._state = BlueprintEntryState(instance) | ||||
|             elif state == BlueprintEntryDesiredState.ABSENT: | ||||
|                 instance: Optional[Model] = serializer.instance | ||||
|                 instance: Model | None = serializer.instance | ||||
|                 if instance.pk: | ||||
|                     instance.delete() | ||||
|                     self.logger.debug("deleted model", mode=instance) | ||||
|  | ||||
| @ -43,7 +43,7 @@ class ApplyBlueprintMetaSerializer(PassiveSerializer): | ||||
|             LOGGER.info("Blueprint does not exist, but not required") | ||||
|             return MetaResult() | ||||
|         LOGGER.debug("Applying blueprint from meta model", blueprint=self.blueprint_instance) | ||||
|         # pylint: disable=no-value-for-parameter | ||||
|  | ||||
|         apply_blueprint(str(self.blueprint_instance.pk)) | ||||
|         return MetaResult() | ||||
|  | ||||
|  | ||||
| @ -8,15 +8,15 @@ from rest_framework.serializers import Serializer | ||||
| class BaseMetaModel(Model): | ||||
|     """Base models""" | ||||
|  | ||||
|     class Meta: | ||||
|         abstract = True | ||||
|  | ||||
|     @staticmethod | ||||
|     def serializer() -> Serializer: | ||||
|         """Serializer similar to SerializerModel, but as a static method since | ||||
|         this is an abstract model""" | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     class Meta: | ||||
|         abstract = True | ||||
|  | ||||
|  | ||||
| class MetaResult: | ||||
|     """Result returned by Meta Models' serializers. Empty class but we can't return none as | ||||
|  | ||||
| @ -4,7 +4,6 @@ from dataclasses import asdict, dataclass, field | ||||
| from hashlib import sha512 | ||||
| from pathlib import Path | ||||
| from sys import platform | ||||
| from typing import Optional | ||||
|  | ||||
| from dacite.core import from_dict | ||||
| from django.db import DatabaseError, InternalError, ProgrammingError | ||||
| @ -50,14 +49,14 @@ class BlueprintFile: | ||||
|     version: int | ||||
|     hash: str | ||||
|     last_m: int | ||||
|     meta: Optional[BlueprintMetadata] = field(default=None) | ||||
|     meta: BlueprintMetadata | None = field(default=None) | ||||
|  | ||||
|  | ||||
| def start_blueprint_watcher(): | ||||
|     """Start blueprint watcher, if it's not running already.""" | ||||
|     # This function might be called twice since it's called on celery startup | ||||
|     # pylint: disable=global-statement | ||||
|     global _file_watcher_started | ||||
|  | ||||
|     global _file_watcher_started  # noqa: PLW0603 | ||||
|     if _file_watcher_started: | ||||
|         return | ||||
|     observer = Observer() | ||||
| @ -126,7 +125,7 @@ def blueprints_find() -> list[BlueprintFile]: | ||||
|         # Check if any part in the path starts with a dot and assume a hidden file | ||||
|         if any(part for part in path.parts if part.startswith(".")): | ||||
|             continue | ||||
|         with open(path, "r", encoding="utf-8") as blueprint_file: | ||||
|         with open(path, encoding="utf-8") as blueprint_file: | ||||
|             try: | ||||
|                 raw_blueprint = load(blueprint_file.read(), BlueprintLoader) | ||||
|             except YAMLError as exc: | ||||
| @ -150,7 +149,7 @@ def blueprints_find() -> list[BlueprintFile]: | ||||
|     throws=(DatabaseError, ProgrammingError, InternalError), base=SystemTask, bind=True | ||||
| ) | ||||
| @prefill_task | ||||
| def blueprints_discovery(self: SystemTask, path: Optional[str] = None): | ||||
| def blueprints_discovery(self: SystemTask, path: str | None = None): | ||||
|     """Find blueprints and check if they need to be created in the database""" | ||||
|     count = 0 | ||||
|     for blueprint in blueprints_find(): | ||||
| @ -197,7 +196,7 @@ def check_blueprint_v1_file(blueprint: BlueprintFile): | ||||
| def apply_blueprint(self: SystemTask, instance_pk: str): | ||||
|     """Apply single blueprint""" | ||||
|     self.save_on_success = False | ||||
|     instance: Optional[BlueprintInstance] = None | ||||
|     instance: BlueprintInstance | None = None | ||||
|     try: | ||||
|         instance: BlueprintInstance = BlueprintInstance.objects.filter(pk=instance_pk).first() | ||||
|         if not instance or not instance.enabled: | ||||
| @ -225,10 +224,10 @@ def apply_blueprint(self: SystemTask, instance_pk: str): | ||||
|         instance.last_applied = now() | ||||
|         self.set_status(TaskStatus.SUCCESSFUL) | ||||
|     except ( | ||||
|         OSError, | ||||
|         DatabaseError, | ||||
|         ProgrammingError, | ||||
|         InternalError, | ||||
|         IOError, | ||||
|         BlueprintRetrievalFailed, | ||||
|         EntryInvalidError, | ||||
|     ) as exc: | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| """Inject brand into current request""" | ||||
|  | ||||
| from typing import Callable | ||||
| from collections.abc import Callable | ||||
|  | ||||
| from django.http.request import HttpRequest | ||||
| from django.http.response import HttpResponse | ||||
| @ -20,7 +20,7 @@ class BrandMiddleware: | ||||
|     def __call__(self, request: HttpRequest) -> HttpResponse: | ||||
|         if not hasattr(request, "brand"): | ||||
|             brand = get_brand_for_request(request) | ||||
|             setattr(request, "brand", brand) | ||||
|             request.brand = brand | ||||
|             locale = brand.default_locale | ||||
|             if locale != "": | ||||
|                 activate(locale) | ||||
|  | ||||
| @ -71,7 +71,7 @@ class Brand(SerializerModel): | ||||
|         """Get default locale""" | ||||
|         try: | ||||
|             return self.attributes.get("settings", {}).get("locale", "") | ||||
|         # pylint: disable=broad-except | ||||
|  | ||||
|         except Exception as exc: | ||||
|             LOGGER.warning("Failed to get default locale", exc=exc) | ||||
|             return "" | ||||
|  | ||||
| @ -1,8 +1,8 @@ | ||||
| """Application API Views""" | ||||
|  | ||||
| from collections.abc import Iterator | ||||
| from copy import copy | ||||
| from datetime import timedelta | ||||
| from typing import Iterator, Optional | ||||
|  | ||||
| from django.core.cache import cache | ||||
| from django.db.models import QuerySet | ||||
| @ -60,7 +60,7 @@ class ApplicationSerializer(ModelSerializer): | ||||
|  | ||||
|     meta_icon = ReadOnlyField(source="get_meta_icon") | ||||
|  | ||||
|     def get_launch_url(self, app: Application) -> Optional[str]: | ||||
|     def get_launch_url(self, app: Application) -> str | None: | ||||
|         """Allow formatting of launch URL""" | ||||
|         user = None | ||||
|         if "request" in self.context: | ||||
| @ -100,7 +100,6 @@ class ApplicationSerializer(ModelSerializer): | ||||
| class ApplicationViewSet(UsedByMixin, ModelViewSet): | ||||
|     """Application Viewset""" | ||||
|  | ||||
|     # pylint: disable=no-member | ||||
|     queryset = Application.objects.all().prefetch_related("provider") | ||||
|     serializer_class = ApplicationSerializer | ||||
|     search_fields = [ | ||||
| @ -131,7 +130,7 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet): | ||||
|         return queryset | ||||
|  | ||||
|     def _get_allowed_applications( | ||||
|         self, pagined_apps: Iterator[Application], user: Optional[User] = None | ||||
|         self, pagined_apps: Iterator[Application], user: User | None = None | ||||
|     ) -> list[Application]: | ||||
|         applications = [] | ||||
|         request = self.request._request | ||||
| @ -169,7 +168,7 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet): | ||||
|             try: | ||||
|                 for_user = User.objects.filter(pk=request.query_params.get("for_user")).first() | ||||
|             except ValueError: | ||||
|                 raise ValidationError({"for_user": "for_user must be numerical"}) | ||||
|                 raise ValidationError({"for_user": "for_user must be numerical"}) from None | ||||
|             if not for_user: | ||||
|                 raise ValidationError({"for_user": "User not found"}) | ||||
|         engine = PolicyEngine(application, for_user, request) | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| """AuthenticatedSessions API Viewset""" | ||||
|  | ||||
| from typing import Optional, TypedDict | ||||
| from typing import TypedDict | ||||
|  | ||||
| from django_filters.rest_framework import DjangoFilterBackend | ||||
| from guardian.utils import get_anonymous_user | ||||
| @ -72,11 +72,11 @@ class AuthenticatedSessionSerializer(ModelSerializer): | ||||
|         """Get parsed user agent""" | ||||
|         return user_agent_parser.Parse(instance.last_user_agent) | ||||
|  | ||||
|     def get_geo_ip(self, instance: AuthenticatedSession) -> Optional[GeoIPDict]:  # pragma: no cover | ||||
|     def get_geo_ip(self, instance: AuthenticatedSession) -> GeoIPDict | None:  # pragma: no cover | ||||
|         """Get GeoIP Data""" | ||||
|         return GEOIP_CONTEXT_PROCESSOR.city_dict(instance.last_ip) | ||||
|  | ||||
|     def get_asn(self, instance: AuthenticatedSession) -> Optional[ASNDict]:  # pragma: no cover | ||||
|     def get_asn(self, instance: AuthenticatedSession) -> ASNDict | None:  # pragma: no cover | ||||
|         """Get ASN Data""" | ||||
|         return ASN_CONTEXT_PROCESSOR.asn_dict(instance.last_ip) | ||||
|  | ||||
|  | ||||
| @ -1,7 +1,6 @@ | ||||
| """Groups API Viewset""" | ||||
|  | ||||
| from json import loads | ||||
| from typing import Optional | ||||
|  | ||||
| from django.http import Http404 | ||||
| from django_filters.filters import CharFilter, ModelMultipleChoiceFilter | ||||
| @ -59,7 +58,7 @@ class GroupSerializer(ModelSerializer): | ||||
|  | ||||
|     num_pk = IntegerField(read_only=True) | ||||
|  | ||||
|     def validate_parent(self, parent: Optional[Group]): | ||||
|     def validate_parent(self, parent: Group | None): | ||||
|         """Validate group parent (if set), ensuring the parent isn't itself""" | ||||
|         if not self.instance or not parent: | ||||
|             return parent | ||||
| @ -114,7 +113,7 @@ class GroupFilter(FilterSet): | ||||
|         try: | ||||
|             value = loads(value) | ||||
|         except ValueError: | ||||
|             raise ValidationError(detail="filter: failed to parse JSON") | ||||
|             raise ValidationError(detail="filter: failed to parse JSON") from None | ||||
|         if not isinstance(value, dict): | ||||
|             raise ValidationError(detail="filter: value must be key:value mapping") | ||||
|         qs = {} | ||||
| @ -140,7 +139,6 @@ class UserAccountSerializer(PassiveSerializer): | ||||
| class GroupViewSet(UsedByMixin, ModelViewSet): | ||||
|     """Group Viewset""" | ||||
|  | ||||
|     # pylint: disable=no-member | ||||
|     queryset = Group.objects.all().select_related("parent").prefetch_related("users") | ||||
|     serializer_class = GroupSerializer | ||||
|     search_fields = ["name", "is_superuser"] | ||||
|  | ||||
| @ -146,7 +146,7 @@ class PropertyMappingViewSet( | ||||
|             response_data["result"] = dumps( | ||||
|                 sanitize_item(result), indent=(4 if format_result else None) | ||||
|             ) | ||||
|         except Exception as exc:  # pylint: disable=broad-except | ||||
|         except Exception as exc: | ||||
|             response_data["result"] = str(exc) | ||||
|             response_data["successful"] = False | ||||
|         response = PropertyMappingTestResultSerializer(response_data) | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| """Source API Views""" | ||||
|  | ||||
| from typing import Iterable | ||||
| from collections.abc import Iterable | ||||
|  | ||||
| from django_filters.rest_framework import DjangoFilterBackend | ||||
| from drf_spectacular.utils import OpenApiResponse, extend_schema | ||||
|  | ||||
| @ -20,7 +20,7 @@ from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT | ||||
| from authentik.core.api.used_by import UsedByMixin | ||||
| from authentik.core.api.users import UserSerializer | ||||
| from authentik.core.api.utils import PassiveSerializer | ||||
| from authentik.core.models import USER_ATTRIBUTE_TOKEN_EXPIRING, Token, TokenIntents, User | ||||
| from authentik.core.models import USER_ATTRIBUTE_TOKEN_EXPIRING, Token, TokenIntents | ||||
| from authentik.events.models import Event, EventAction | ||||
| from authentik.events.utils import model_to_dict | ||||
| from authentik.rbac.decorators import permission_required | ||||
| @ -36,13 +36,6 @@ class TokenSerializer(ManagedSerializer, ModelSerializer): | ||||
|         if SERIALIZER_CONTEXT_BLUEPRINT in self.context: | ||||
|             self.fields["key"] = CharField(required=False) | ||||
|  | ||||
|     def validate_user(self, user: User): | ||||
|         """Ensure user of token cannot be changed""" | ||||
|         if self.instance and self.instance.user_id: | ||||
|             if user.pk != self.instance.user_id: | ||||
|                 raise ValidationError("User cannot be changed") | ||||
|         return user | ||||
|  | ||||
|     def validate(self, attrs: dict[Any, str]) -> dict[Any, str]: | ||||
|         """Ensure only API or App password tokens are created.""" | ||||
|         request: Request = self.context.get("request") | ||||
|  | ||||
| @ -65,7 +65,7 @@ class TransactionApplicationSerializer(PassiveSerializer): | ||||
|                 raise ValidationError("Invalid provider model") | ||||
|             self._provider_model = model | ||||
|         except LookupError: | ||||
|             raise ValidationError("Invalid provider model") | ||||
|             raise ValidationError("Invalid provider model") from None | ||||
|         return fq_model_name | ||||
|  | ||||
|     def validate(self, attrs: dict) -> dict: | ||||
| @ -106,7 +106,7 @@ class TransactionApplicationSerializer(PassiveSerializer): | ||||
|                 { | ||||
|                     exc.entry_id: exc.validation_error.detail, | ||||
|                 } | ||||
|             ) | ||||
|             ) from None | ||||
|         return blueprint | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -54,7 +54,6 @@ class UsedByMixin: | ||||
|         responses={200: UsedBySerializer(many=True)}, | ||||
|     ) | ||||
|     @action(detail=True, pagination_class=None, filter_backends=[]) | ||||
|     # pylint: disable=too-many-locals | ||||
|     def used_by(self, request: Request, *args, **kwargs) -> Response: | ||||
|         """Get a list of all objects that use this object""" | ||||
|         model: Model = self.get_object() | ||||
|  | ||||
| @ -2,7 +2,7 @@ | ||||
|  | ||||
| from datetime import timedelta | ||||
| from json import loads | ||||
| from typing import Any, Optional | ||||
| from typing import Any | ||||
|  | ||||
| from django.contrib.auth import update_session_auth_hash | ||||
| from django.contrib.sessions.backends.cache import KEY_PREFIX | ||||
| @ -142,7 +142,7 @@ class UserSerializer(ModelSerializer): | ||||
|         self._set_password(instance, password) | ||||
|         return instance | ||||
|  | ||||
|     def _set_password(self, instance: User, password: Optional[str]): | ||||
|     def _set_password(self, instance: User, password: str | None): | ||||
|         """Set password of user if we're in a blueprint context, and if it's an empty | ||||
|         string then use an unusable password""" | ||||
|         if SERIALIZER_CONTEXT_BLUEPRINT in self.context and password: | ||||
| @ -358,7 +358,7 @@ class UsersFilter(FilterSet): | ||||
|         try: | ||||
|             value = loads(value) | ||||
|         except ValueError: | ||||
|             raise ValidationError(detail="filter: failed to parse JSON") | ||||
|             raise ValidationError(detail="filter: failed to parse JSON") from None | ||||
|         if not isinstance(value, dict): | ||||
|             raise ValidationError(detail="filter: value must be key:value mapping") | ||||
|         qs = {} | ||||
| @ -397,15 +397,14 @@ class UserViewSet(UsedByMixin, ModelViewSet): | ||||
|     def get_queryset(self):  # pragma: no cover | ||||
|         return User.objects.all().exclude_anonymous().prefetch_related("ak_groups") | ||||
|  | ||||
|     def _create_recovery_link(self) -> tuple[Optional[str], Optional[Token]]: | ||||
|     def _create_recovery_link(self) -> tuple[str, Token]: | ||||
|         """Create a recovery link (when the current brand has a recovery flow set), | ||||
|         that can either be shown to an admin or sent to the user directly""" | ||||
|         brand: Brand = self.request._request.brand | ||||
|         # Check that there is a recovery flow, if not return an error | ||||
|         flow = brand.flow_recovery | ||||
|         if not flow: | ||||
|             LOGGER.debug("No recovery flow set") | ||||
|             return None, None | ||||
|             raise ValidationError({"non_field_errors": "No recovery flow set."}) | ||||
|         user: User = self.get_object() | ||||
|         planner = FlowPlanner(flow) | ||||
|         planner.allow_empty_flows = True | ||||
| @ -417,8 +416,9 @@ class UserViewSet(UsedByMixin, ModelViewSet): | ||||
|                 }, | ||||
|             ) | ||||
|         except FlowNonApplicableException: | ||||
|             LOGGER.warning("Recovery flow not applicable to user") | ||||
|             return None, None | ||||
|             raise ValidationError( | ||||
|                 {"non_field_errors": "Recovery flow not applicable to user"} | ||||
|             ) from None | ||||
|         token, __ = FlowToken.objects.update_or_create( | ||||
|             identifier=f"{user.uid}-password-reset", | ||||
|             defaults={ | ||||
| @ -563,16 +563,13 @@ class UserViewSet(UsedByMixin, ModelViewSet): | ||||
|     @extend_schema( | ||||
|         responses={ | ||||
|             "200": LinkSerializer(many=False), | ||||
|             "404": LinkSerializer(many=False), | ||||
|         }, | ||||
|         request=None, | ||||
|     ) | ||||
|     @action(detail=True, pagination_class=None, filter_backends=[]) | ||||
|     @action(detail=True, pagination_class=None, filter_backends=[], methods=["POST"]) | ||||
|     def recovery(self, request: Request, pk: int) -> Response: | ||||
|         """Create a temporary link that a user can use to recover their accounts""" | ||||
|         link, _ = self._create_recovery_link() | ||||
|         if not link: | ||||
|             LOGGER.debug("Couldn't create token") | ||||
|             return Response({"link": ""}, status=404) | ||||
|         return Response({"link": link}) | ||||
|  | ||||
|     @permission_required("authentik_core.reset_user_password") | ||||
| @ -587,31 +584,28 @@ class UserViewSet(UsedByMixin, ModelViewSet): | ||||
|         ], | ||||
|         responses={ | ||||
|             "204": OpenApiResponse(description="Successfully sent recover email"), | ||||
|             "404": OpenApiResponse(description="Bad request"), | ||||
|         }, | ||||
|         request=None, | ||||
|     ) | ||||
|     @action(detail=True, pagination_class=None, filter_backends=[]) | ||||
|     @action(detail=True, pagination_class=None, filter_backends=[], methods=["POST"]) | ||||
|     def recovery_email(self, request: Request, pk: int) -> Response: | ||||
|         """Create a temporary link that a user can use to recover their accounts""" | ||||
|         for_user: User = self.get_object() | ||||
|         if for_user.email == "": | ||||
|             LOGGER.debug("User doesn't have an email address") | ||||
|             return Response(status=404) | ||||
|             raise ValidationError({"non_field_errors": "User does not have an email address set."}) | ||||
|         link, token = self._create_recovery_link() | ||||
|         if not link: | ||||
|             LOGGER.debug("Couldn't create token") | ||||
|             return Response(status=404) | ||||
|         # Lookup the email stage to assure the current user can access it | ||||
|         stages = get_objects_for_user( | ||||
|             request.user, "authentik_stages_email.view_emailstage" | ||||
|         ).filter(pk=request.query_params.get("email_stage")) | ||||
|         if not stages.exists(): | ||||
|             LOGGER.debug("Email stage does not exist/user has no permissions") | ||||
|             return Response(status=404) | ||||
|             raise ValidationError({"non_field_errors": "Email stage does not exist."}) | ||||
|         email_stage: EmailStage = stages.first() | ||||
|         message = TemplateEmailMessage( | ||||
|             subject=_(email_stage.subject), | ||||
|             to=[(for_user.name, for_user.email)], | ||||
|             to=[for_user.email], | ||||
|             template_name=email_stage.template, | ||||
|             language=for_user.locale(request), | ||||
|             template_context={ | ||||
|  | ||||
| @ -14,14 +14,16 @@ class AuthentikCoreConfig(ManagedAppConfig): | ||||
|     mountpoint = "" | ||||
|     default = True | ||||
|  | ||||
|     def reconcile_global_debug_worker_hook(self): | ||||
|     @ManagedAppConfig.reconcile_global | ||||
|     def debug_worker_hook(self): | ||||
|         """Dispatch startup tasks inline when debugging""" | ||||
|         if settings.DEBUG: | ||||
|             from authentik.root.celery import worker_ready_hook | ||||
|  | ||||
|             worker_ready_hook() | ||||
|  | ||||
|     def reconcile_tenant_source_inbuilt(self): | ||||
|     @ManagedAppConfig.reconcile_tenant | ||||
|     def source_inbuilt(self): | ||||
|         """Reconcile inbuilt source""" | ||||
|         from authentik.core.models import Source | ||||
|  | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| """Authenticate with tokens""" | ||||
|  | ||||
| from typing import Any, Optional | ||||
| from typing import Any | ||||
|  | ||||
| from django.contrib.auth.backends import ModelBackend | ||||
| from django.http.request import HttpRequest | ||||
| @ -16,15 +16,15 @@ class InbuiltBackend(ModelBackend): | ||||
|     """Inbuilt backend""" | ||||
|  | ||||
|     def authenticate( | ||||
|         self, request: HttpRequest, username: Optional[str], password: Optional[str], **kwargs: Any | ||||
|     ) -> Optional[User]: | ||||
|         self, request: HttpRequest, username: str | None, password: str | None, **kwargs: Any | ||||
|     ) -> User | None: | ||||
|         user = super().authenticate(request, username=username, password=password, **kwargs) | ||||
|         if not user: | ||||
|             return None | ||||
|         self.set_method("password", request) | ||||
|         return user | ||||
|  | ||||
|     def set_method(self, method: str, request: Optional[HttpRequest], **kwargs): | ||||
|     def set_method(self, method: str, request: HttpRequest | None, **kwargs): | ||||
|         """Set method data on current flow, if possbiel""" | ||||
|         if not request: | ||||
|             return | ||||
| @ -40,18 +40,18 @@ class TokenBackend(InbuiltBackend): | ||||
|     """Authenticate with token""" | ||||
|  | ||||
|     def authenticate( | ||||
|         self, request: HttpRequest, username: Optional[str], password: Optional[str], **kwargs: Any | ||||
|     ) -> Optional[User]: | ||||
|         self, request: HttpRequest, username: str | None, password: str | None, **kwargs: Any | ||||
|     ) -> User | None: | ||||
|         try: | ||||
|             # pylint: disable=no-member | ||||
|  | ||||
|             user = User._default_manager.get_by_natural_key(username) | ||||
|         # pylint: disable=no-member | ||||
|  | ||||
|         except User.DoesNotExist: | ||||
|             # Run the default password hasher once to reduce the timing | ||||
|             # difference between an existing and a nonexistent user (#20760). | ||||
|             User().set_password(password) | ||||
|             return None | ||||
|         # pylint: disable=no-member | ||||
|  | ||||
|         tokens = Token.filter_not_expired( | ||||
|             user=user, key=password, intent=TokenIntents.INTENT_APP_PASSWORD | ||||
|         ) | ||||
|  | ||||
| @ -38,6 +38,6 @@ class TokenOutpostMiddleware: | ||||
|                 raise DenyConnection() | ||||
|         except AuthenticationFailed as exc: | ||||
|             LOGGER.warning("Failed to authenticate", exc=exc) | ||||
|             raise DenyConnection() | ||||
|             raise DenyConnection() from None | ||||
|  | ||||
|         scope["user"] = user | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| """Property Mapping Evaluator""" | ||||
|  | ||||
| from typing import Any, Optional | ||||
| from typing import Any | ||||
|  | ||||
| from django.db.models import Model | ||||
| from django.http import HttpRequest | ||||
| @ -27,9 +27,9 @@ class PropertyMappingEvaluator(BaseEvaluator): | ||||
|     def __init__( | ||||
|         self, | ||||
|         model: Model, | ||||
|         user: Optional[User] = None, | ||||
|         request: Optional[HttpRequest] = None, | ||||
|         dry_run: Optional[bool] = False, | ||||
|         user: User | None = None, | ||||
|         request: HttpRequest | None = None, | ||||
|         dry_run: bool | None = False, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         if hasattr(model, "name"): | ||||
|  | ||||
| @ -16,13 +16,8 @@ from authentik.events.middleware import should_log_model | ||||
| from authentik.events.models import Event, EventAction | ||||
| from authentik.events.utils import model_to_dict | ||||
|  | ||||
| BANNER_TEXT = """### authentik shell ({authentik}) | ||||
| ### Node {node} | Arch {arch} | Python {python} """.format( | ||||
|     node=platform.node(), | ||||
|     python=platform.python_version(), | ||||
|     arch=platform.machine(), | ||||
|     authentik=get_full_version(), | ||||
| ) | ||||
| BANNER_TEXT = f"""### authentik shell ({get_full_version()}) | ||||
| ### Node {platform.node()} | Arch {platform.machine()} | Python {platform.python_version()} """ | ||||
|  | ||||
|  | ||||
| class Command(BaseCommand): | ||||
| @ -86,7 +81,7 @@ class Command(BaseCommand): | ||||
|  | ||||
|         # If Python code has been passed, execute it and exit. | ||||
|         if options["command"]: | ||||
|             # pylint: disable=exec-used | ||||
|  | ||||
|             exec(options["command"], namespace)  # nosec # noqa | ||||
|             return | ||||
|  | ||||
| @ -99,7 +94,7 @@ class Command(BaseCommand): | ||||
|         else: | ||||
|             try: | ||||
|                 hook() | ||||
|             except Exception:  # pylint: disable=broad-except | ||||
|             except Exception: | ||||
|                 # Match the behavior of the cpython shell where an error in | ||||
|                 # sys.__interactivehook__ prints a warning and the exception | ||||
|                 # and continues. | ||||
|  | ||||
| @ -1,7 +1,7 @@ | ||||
| """authentik admin Middleware to impersonate users""" | ||||
|  | ||||
| from collections.abc import Callable | ||||
| from contextvars import ContextVar | ||||
| from typing import Callable, Optional | ||||
| from uuid import uuid4 | ||||
|  | ||||
| from django.http import HttpRequest, HttpResponse | ||||
| @ -15,9 +15,9 @@ RESPONSE_HEADER_ID = "X-authentik-id" | ||||
| KEY_AUTH_VIA = "auth_via" | ||||
| KEY_USER = "user" | ||||
|  | ||||
| CTX_REQUEST_ID = ContextVar[Optional[str]](STRUCTLOG_KEY_PREFIX + "request_id", default=None) | ||||
| CTX_HOST = ContextVar[Optional[str]](STRUCTLOG_KEY_PREFIX + "host", default=None) | ||||
| CTX_AUTH_VIA = ContextVar[Optional[str]](STRUCTLOG_KEY_PREFIX + KEY_AUTH_VIA, default=None) | ||||
| CTX_REQUEST_ID = ContextVar[str | None](STRUCTLOG_KEY_PREFIX + "request_id", default=None) | ||||
| CTX_HOST = ContextVar[str | None](STRUCTLOG_KEY_PREFIX + "host", default=None) | ||||
| CTX_AUTH_VIA = ContextVar[str | None](STRUCTLOG_KEY_PREFIX + KEY_AUTH_VIA, default=None) | ||||
|  | ||||
|  | ||||
| class ImpersonateMiddleware: | ||||
| @ -55,7 +55,7 @@ class RequestIDMiddleware: | ||||
|     def __call__(self, request: HttpRequest) -> HttpResponse: | ||||
|         if not hasattr(request, "request_id"): | ||||
|             request_id = uuid4().hex | ||||
|             setattr(request, "request_id", request_id) | ||||
|             request.request_id = request_id | ||||
|             CTX_REQUEST_ID.set(request_id) | ||||
|             CTX_HOST.set(request.get_host()) | ||||
|             set_tag("authentik.request_id", request_id) | ||||
| @ -67,7 +67,7 @@ class RequestIDMiddleware: | ||||
|         response = self.get_response(request) | ||||
|  | ||||
|         response[RESPONSE_HEADER_ID] = request.request_id | ||||
|         setattr(response, "ak_context", {}) | ||||
|         response.ak_context = {} | ||||
|         response.ak_context["request_id"] = CTX_REQUEST_ID.get() | ||||
|         response.ak_context["host"] = CTX_HOST.get() | ||||
|         response.ak_context[KEY_AUTH_VIA] = CTX_AUTH_VIA.get() | ||||
|  | ||||
| @ -222,7 +222,7 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser): | ||||
|         there are at most 3 queries done""" | ||||
|         return Group.children_recursive(self.ak_groups.all()) | ||||
|  | ||||
|     def group_attributes(self, request: Optional[HttpRequest] = None) -> dict[str, Any]: | ||||
|     def group_attributes(self, request: HttpRequest | None = None) -> dict[str, Any]: | ||||
|         """Get a dictionary containing the attributes from all groups the user belongs to, | ||||
|         including the users attributes""" | ||||
|         final_attributes = {} | ||||
| @ -278,11 +278,11 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser): | ||||
|         """Generate a globally unique UID, based on the user ID and the hashed secret key""" | ||||
|         return sha256(f"{self.id}-{get_install_id()}".encode("ascii")).hexdigest() | ||||
|  | ||||
|     def locale(self, request: Optional[HttpRequest] = None) -> str: | ||||
|     def locale(self, request: HttpRequest | None = None) -> str: | ||||
|         """Get the locale the user has configured""" | ||||
|         try: | ||||
|             return self.attributes.get("settings", {}).get("locale", "") | ||||
|         # pylint: disable=broad-except | ||||
|  | ||||
|         except Exception as exc: | ||||
|             LOGGER.warning("Failed to get default locale", exc=exc) | ||||
|         if request: | ||||
| @ -358,7 +358,7 @@ class Provider(SerializerModel): | ||||
|     objects = InheritanceManager() | ||||
|  | ||||
|     @property | ||||
|     def launch_url(self) -> Optional[str]: | ||||
|     def launch_url(self) -> str | None: | ||||
|         """URL to this provider and initiate authorization for the user. | ||||
|         Can return None for providers that are not URL-based""" | ||||
|         return None | ||||
| @ -435,7 +435,7 @@ class Application(SerializerModel, PolicyBindingModel): | ||||
|         return ApplicationSerializer | ||||
|  | ||||
|     @property | ||||
|     def get_meta_icon(self) -> Optional[str]: | ||||
|     def get_meta_icon(self) -> str | None: | ||||
|         """Get the URL to the App Icon image. If the name is /static or starts with http | ||||
|         it is returned as-is""" | ||||
|         if not self.meta_icon: | ||||
| @ -444,7 +444,7 @@ class Application(SerializerModel, PolicyBindingModel): | ||||
|             return self.meta_icon.name | ||||
|         return self.meta_icon.url | ||||
|  | ||||
|     def get_launch_url(self, user: Optional["User"] = None) -> Optional[str]: | ||||
|     def get_launch_url(self, user: Optional["User"] = None) -> str | None: | ||||
|         """Get launch URL if set, otherwise attempt to get launch URL based on provider.""" | ||||
|         url = None | ||||
|         if self.meta_launch_url: | ||||
| @ -457,13 +457,13 @@ class Application(SerializerModel, PolicyBindingModel): | ||||
|                 user = user._wrapped | ||||
|             try: | ||||
|                 return url % user.__dict__ | ||||
|             # pylint: disable=broad-except | ||||
|  | ||||
|             except Exception as exc: | ||||
|                 LOGGER.warning("Failed to format launch url", exc=exc) | ||||
|                 return url | ||||
|         return url | ||||
|  | ||||
|     def get_provider(self) -> Optional[Provider]: | ||||
|     def get_provider(self) -> Provider | None: | ||||
|         """Get casted provider instance""" | ||||
|         if not self.provider: | ||||
|             return None | ||||
| @ -551,7 +551,7 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel): | ||||
|     objects = InheritanceManager() | ||||
|  | ||||
|     @property | ||||
|     def icon_url(self) -> Optional[str]: | ||||
|     def icon_url(self) -> str | None: | ||||
|         """Get the URL to the Icon. If the name is /static or | ||||
|         starts with http it is returned as-is""" | ||||
|         if not self.icon: | ||||
| @ -566,7 +566,7 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel): | ||||
|             return self.user_path_template % { | ||||
|                 "slug": self.slug, | ||||
|             } | ||||
|         # pylint: disable=broad-except | ||||
|  | ||||
|         except Exception as exc: | ||||
|             LOGGER.warning("Failed to template user path", exc=exc, source=self) | ||||
|             return User.default_path() | ||||
| @ -576,12 +576,12 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel): | ||||
|         """Return component used to edit this object""" | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def ui_login_button(self, request: HttpRequest) -> Optional[UILoginButton]: | ||||
|     def ui_login_button(self, request: HttpRequest) -> UILoginButton | None: | ||||
|         """If source uses a http-based flow, return UI Information about the login | ||||
|         button. If source doesn't use http-based flow, return None.""" | ||||
|         return None | ||||
|  | ||||
|     def ui_user_settings(self) -> Optional[UserSettingSerializer]: | ||||
|     def ui_user_settings(self) -> UserSettingSerializer | None: | ||||
|         """Entrypoint to integrate with User settings. Can either return None if no | ||||
|         user settings are available, or UserSettingSerializer.""" | ||||
|         return None | ||||
| @ -627,6 +627,9 @@ class ExpiringModel(models.Model): | ||||
|     expires = models.DateTimeField(default=default_token_duration) | ||||
|     expiring = models.BooleanField(default=True) | ||||
|  | ||||
|     class Meta: | ||||
|         abstract = True | ||||
|  | ||||
|     def expire_action(self, *args, **kwargs): | ||||
|         """Handler which is called when this object is expired. By | ||||
|         default the object is deleted. This is less efficient compared | ||||
| @ -649,9 +652,6 @@ class ExpiringModel(models.Model): | ||||
|             return False | ||||
|         return now() > self.expires | ||||
|  | ||||
|     class Meta: | ||||
|         abstract = True | ||||
|  | ||||
|  | ||||
| class TokenIntents(models.TextChoices): | ||||
|     """Intents a Token can be created for.""" | ||||
| @ -681,6 +681,21 @@ class Token(SerializerModel, ManagedModel, ExpiringModel): | ||||
|     user = models.ForeignKey("User", on_delete=models.CASCADE, related_name="+") | ||||
|     description = models.TextField(default="", blank=True) | ||||
|  | ||||
|     class Meta: | ||||
|         verbose_name = _("Token") | ||||
|         verbose_name_plural = _("Tokens") | ||||
|         indexes = [ | ||||
|             models.Index(fields=["identifier"]), | ||||
|             models.Index(fields=["key"]), | ||||
|         ] | ||||
|         permissions = [("view_token_key", _("View token's key"))] | ||||
|  | ||||
|     def __str__(self): | ||||
|         description = f"{self.identifier}" | ||||
|         if self.expiring: | ||||
|             description += f" (expires={self.expires})" | ||||
|         return description | ||||
|  | ||||
|     @property | ||||
|     def serializer(self) -> type[Serializer]: | ||||
|         from authentik.core.api.tokens import TokenSerializer | ||||
| @ -708,21 +723,6 @@ class Token(SerializerModel, ManagedModel, ExpiringModel): | ||||
|             message=f"Token {self.identifier}'s secret was rotated.", | ||||
|         ).save() | ||||
|  | ||||
|     def __str__(self): | ||||
|         description = f"{self.identifier}" | ||||
|         if self.expiring: | ||||
|             description += f" (expires={self.expires})" | ||||
|         return description | ||||
|  | ||||
|     class Meta: | ||||
|         verbose_name = _("Token") | ||||
|         verbose_name_plural = _("Tokens") | ||||
|         indexes = [ | ||||
|             models.Index(fields=["identifier"]), | ||||
|             models.Index(fields=["key"]), | ||||
|         ] | ||||
|         permissions = [("view_token_key", _("View token's key"))] | ||||
|  | ||||
|  | ||||
| class PropertyMapping(SerializerModel, ManagedModel): | ||||
|     """User-defined key -> x mapping which can be used by providers to expose extra data.""" | ||||
| @ -743,7 +743,7 @@ class PropertyMapping(SerializerModel, ManagedModel): | ||||
|         """Get serializer for this model""" | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def evaluate(self, user: Optional[User], request: Optional[HttpRequest], **kwargs) -> Any: | ||||
|     def evaluate(self, user: User | None, request: HttpRequest | None, **kwargs) -> Any: | ||||
|         """Evaluate `self.expression` using `**kwargs` as Context.""" | ||||
|         from authentik.core.expression.evaluator import PropertyMappingEvaluator | ||||
|  | ||||
| @ -779,6 +779,13 @@ class AuthenticatedSession(ExpiringModel): | ||||
|     last_user_agent = models.TextField(blank=True) | ||||
|     last_used = models.DateTimeField(auto_now=True) | ||||
|  | ||||
|     class Meta: | ||||
|         verbose_name = _("Authenticated Session") | ||||
|         verbose_name_plural = _("Authenticated Sessions") | ||||
|  | ||||
|     def __str__(self) -> str: | ||||
|         return f"Authenticated Session {self.session_key[:10]}" | ||||
|  | ||||
|     @staticmethod | ||||
|     def from_request(request: HttpRequest, user: User) -> Optional["AuthenticatedSession"]: | ||||
|         """Create a new session from a http request""" | ||||
| @ -793,7 +800,3 @@ class AuthenticatedSession(ExpiringModel): | ||||
|             last_user_agent=request.META.get("HTTP_USER_AGENT", ""), | ||||
|             expires=request.session.get_expiry_date(), | ||||
|         ) | ||||
|  | ||||
|     class Meta: | ||||
|         verbose_name = _("Authenticated Session") | ||||
|         verbose_name_plural = _("Authenticated Sessions") | ||||
|  | ||||
| @ -1,7 +1,7 @@ | ||||
| """Source decision helper""" | ||||
|  | ||||
| from enum import Enum | ||||
| from typing import Any, Optional | ||||
| from typing import Any | ||||
|  | ||||
| from django.contrib import messages | ||||
| from django.db import IntegrityError | ||||
| @ -90,15 +90,14 @@ class SourceFlowManager: | ||||
|         self._logger = get_logger().bind(source=source, identifier=identifier) | ||||
|         self.policy_context = {} | ||||
|  | ||||
|     # pylint: disable=too-many-return-statements | ||||
|     def get_action(self, **kwargs) -> tuple[Action, Optional[UserSourceConnection]]: | ||||
|     def get_action(self, **kwargs) -> tuple[Action, UserSourceConnection | None]:  # noqa: PLR0911 | ||||
|         """decide which action should be taken""" | ||||
|         new_connection = self.connection_type(source=self.source, identifier=self.identifier) | ||||
|         # When request is authenticated, always link | ||||
|         if self.request.user.is_authenticated: | ||||
|             new_connection.user = self.request.user | ||||
|             new_connection = self.update_connection(new_connection, **kwargs) | ||||
|             # pylint: disable=no-member | ||||
|  | ||||
|             new_connection.save() | ||||
|             return Action.LINK, new_connection | ||||
|  | ||||
| @ -188,8 +187,10 @@ class SourceFlowManager: | ||||
|         # Default case, assume deny | ||||
|         error = Exception( | ||||
|             _( | ||||
|                 "Request to authenticate with %(source)s has been denied. Please authenticate " | ||||
|                 "with the source you've previously signed up with." % {"source": self.source.name} | ||||
|                 "Request to authenticate with {source} has been denied. Please authenticate " | ||||
|                 "with the source you've previously signed up with.".format_map( | ||||
|                     {"source": self.source.name} | ||||
|                 ) | ||||
|             ), | ||||
|         ) | ||||
|         return self.error_handler(error) | ||||
| @ -217,7 +218,7 @@ class SourceFlowManager: | ||||
|         self, | ||||
|         flow: Flow, | ||||
|         connection: UserSourceConnection, | ||||
|         stages: Optional[list[StageView]] = None, | ||||
|         stages: list[StageView] | None = None, | ||||
|         **kwargs, | ||||
|     ) -> HttpResponse: | ||||
|         """Prepare Authentication Plan, redirect user FlowExecutor""" | ||||
| @ -270,7 +271,9 @@ class SourceFlowManager: | ||||
|                 in_memory_stage( | ||||
|                     MessageStage, | ||||
|                     message=_( | ||||
|                         "Successfully authenticated with %(source)s!" % {"source": self.source.name} | ||||
|                         "Successfully authenticated with {source}!".format_map( | ||||
|                             {"source": self.source.name} | ||||
|                         ) | ||||
|                     ), | ||||
|                 ) | ||||
|             ], | ||||
| @ -294,7 +297,7 @@ class SourceFlowManager: | ||||
|         ).from_http(self.request) | ||||
|         messages.success( | ||||
|             self.request, | ||||
|             _("Successfully linked %(source)s!" % {"source": self.source.name}), | ||||
|             _("Successfully linked {source}!".format_map({"source": self.source.name})), | ||||
|         ) | ||||
|         return redirect( | ||||
|             reverse( | ||||
| @ -322,7 +325,9 @@ class SourceFlowManager: | ||||
|                 in_memory_stage( | ||||
|                     MessageStage, | ||||
|                     message=_( | ||||
|                         "Successfully authenticated with %(source)s!" % {"source": self.source.name} | ||||
|                         "Successfully authenticated with {source}!".format_map( | ||||
|                             {"source": self.source.name} | ||||
|                         ) | ||||
|                     ), | ||||
|                 ) | ||||
|             ], | ||||
|  | ||||
| @ -37,20 +37,20 @@ def clean_expired_models(self: SystemTask): | ||||
|         messages.append(f"Expired {amount} {cls._meta.verbose_name_plural}") | ||||
|     # Special case | ||||
|     amount = 0 | ||||
|     # pylint: disable=no-member | ||||
|  | ||||
|     for session in AuthenticatedSession.objects.all(): | ||||
|         cache_key = f"{KEY_PREFIX}{session.session_key}" | ||||
|         value = None | ||||
|         try: | ||||
|             value = cache.get(cache_key) | ||||
|         # pylint: disable=broad-except | ||||
|  | ||||
|         except Exception as exc: | ||||
|             LOGGER.debug("Failed to get session from cache", exc=exc) | ||||
|         if not value: | ||||
|             session.delete() | ||||
|             amount += 1 | ||||
|     LOGGER.debug("Expired sessions", model=AuthenticatedSession, amount=amount) | ||||
|     # pylint: disable=no-member | ||||
|  | ||||
|     messages.append(f"Expired {amount} {AuthenticatedSession._meta.verbose_name_plural}") | ||||
|     self.set_status(TaskStatus.SUCCESSFUL, *messages) | ||||
|  | ||||
|  | ||||
| @ -1,7 +1,7 @@ | ||||
| """authentik core models tests""" | ||||
|  | ||||
| from collections.abc import Callable | ||||
| from time import sleep | ||||
| from typing import Callable | ||||
|  | ||||
| from django.test import RequestFactory, TestCase | ||||
| from django.utils.timezone import now | ||||
|  | ||||
| @ -173,5 +173,5 @@ class TestSourceFlowManager(TestCase): | ||||
|         self.assertEqual(action, Action.ENROLL) | ||||
|         response = flow_manager.get_flow() | ||||
|         self.assertIsInstance(response, AccessDeniedResponse) | ||||
|         # pylint: disable=no-member | ||||
|  | ||||
|         self.assertEqual(response.error_message, "foo") | ||||
|  | ||||
| @ -7,8 +7,8 @@ from guardian.shortcuts import get_anonymous_user | ||||
| from rest_framework.test import APITestCase | ||||
|  | ||||
| from authentik.core.api.tokens import TokenSerializer | ||||
| from authentik.core.models import USER_ATTRIBUTE_TOKEN_EXPIRING, Token, TokenIntents | ||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_user | ||||
| from authentik.core.models import USER_ATTRIBUTE_TOKEN_EXPIRING, Token, TokenIntents, User | ||||
| from authentik.core.tests.utils import create_test_admin_user | ||||
| from authentik.lib.generators import generate_id | ||||
|  | ||||
|  | ||||
| @ -17,7 +17,7 @@ class TestTokenAPI(APITestCase): | ||||
|  | ||||
|     def setUp(self) -> None: | ||||
|         super().setUp() | ||||
|         self.user = create_test_user() | ||||
|         self.user = User.objects.create(username="testuser") | ||||
|         self.admin = create_test_admin_user() | ||||
|         self.client.force_login(self.user) | ||||
|  | ||||
| @ -76,24 +76,6 @@ class TestTokenAPI(APITestCase): | ||||
|         self.assertEqual(token.intent, TokenIntents.INTENT_API) | ||||
|         self.assertEqual(token.expiring, False) | ||||
|  | ||||
|     def test_token_change_user(self): | ||||
|         """Test creating a token and then changing the user""" | ||||
|         ident = generate_id() | ||||
|         response = self.client.post(reverse("authentik_api:token-list"), {"identifier": ident}) | ||||
|         self.assertEqual(response.status_code, 201) | ||||
|         token = Token.objects.get(identifier=ident) | ||||
|         self.assertEqual(token.user, self.user) | ||||
|         self.assertEqual(token.intent, TokenIntents.INTENT_API) | ||||
|         self.assertEqual(token.expiring, True) | ||||
|         self.assertTrue(self.user.has_perm("authentik_core.view_token_key", token)) | ||||
|         response = self.client.put( | ||||
|             reverse("authentik_api:token-detail", kwargs={"identifier": ident}), | ||||
|             data={"identifier": "user_token_poc_v3", "intent": "api", "user": self.admin.pk}, | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 400) | ||||
|         token.refresh_from_db() | ||||
|         self.assertEqual(token.user, self.user) | ||||
|  | ||||
|     def test_list(self): | ||||
|         """Test Token List (Test normal authentication)""" | ||||
|         Token.objects.all().delete() | ||||
|  | ||||
| @ -60,10 +60,11 @@ class TestUsersAPI(APITestCase): | ||||
|     def test_recovery_no_flow(self): | ||||
|         """Test user recovery link (no recovery flow set)""" | ||||
|         self.client.force_login(self.admin) | ||||
|         response = self.client.get( | ||||
|         response = self.client.post( | ||||
|             reverse("authentik_api:user-recovery", kwargs={"pk": self.user.pk}) | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 404) | ||||
|         self.assertEqual(response.status_code, 400) | ||||
|         self.assertJSONEqual(response.content, {"non_field_errors": "No recovery flow set."}) | ||||
|  | ||||
|     def test_set_password(self): | ||||
|         """Test Direct password set""" | ||||
| @ -84,7 +85,7 @@ class TestUsersAPI(APITestCase): | ||||
|         brand.flow_recovery = flow | ||||
|         brand.save() | ||||
|         self.client.force_login(self.admin) | ||||
|         response = self.client.get( | ||||
|         response = self.client.post( | ||||
|             reverse("authentik_api:user-recovery", kwargs={"pk": self.user.pk}) | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 200) | ||||
| @ -92,16 +93,20 @@ class TestUsersAPI(APITestCase): | ||||
|     def test_recovery_email_no_flow(self): | ||||
|         """Test user recovery link (no recovery flow set)""" | ||||
|         self.client.force_login(self.admin) | ||||
|         response = self.client.get( | ||||
|         response = self.client.post( | ||||
|             reverse("authentik_api:user-recovery-email", kwargs={"pk": self.user.pk}) | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 404) | ||||
|         self.assertEqual(response.status_code, 400) | ||||
|         self.assertJSONEqual( | ||||
|             response.content, {"non_field_errors": "User does not have an email address set."} | ||||
|         ) | ||||
|         self.user.email = "foo@bar.baz" | ||||
|         self.user.save() | ||||
|         response = self.client.get( | ||||
|         response = self.client.post( | ||||
|             reverse("authentik_api:user-recovery-email", kwargs={"pk": self.user.pk}) | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 404) | ||||
|         self.assertEqual(response.status_code, 400) | ||||
|         self.assertJSONEqual(response.content, {"non_field_errors": "No recovery flow set."}) | ||||
|  | ||||
|     def test_recovery_email_no_stage(self): | ||||
|         """Test user recovery link (no email stage)""" | ||||
| @ -112,10 +117,11 @@ class TestUsersAPI(APITestCase): | ||||
|         brand.flow_recovery = flow | ||||
|         brand.save() | ||||
|         self.client.force_login(self.admin) | ||||
|         response = self.client.get( | ||||
|         response = self.client.post( | ||||
|             reverse("authentik_api:user-recovery-email", kwargs={"pk": self.user.pk}) | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 404) | ||||
|         self.assertEqual(response.status_code, 400) | ||||
|         self.assertJSONEqual(response.content, {"non_field_errors": "Email stage does not exist."}) | ||||
|  | ||||
|     def test_recovery_email(self): | ||||
|         """Test user recovery link""" | ||||
| @ -129,7 +135,7 @@ class TestUsersAPI(APITestCase): | ||||
|         stage = EmailStage.objects.create(name="email") | ||||
|  | ||||
|         self.client.force_login(self.admin) | ||||
|         response = self.client.get( | ||||
|         response = self.client.post( | ||||
|             reverse( | ||||
|                 "authentik_api:user-recovery-email", | ||||
|                 kwargs={"pk": self.user.pk}, | ||||
|  | ||||
| @ -1,7 +1,5 @@ | ||||
| """Test Utils""" | ||||
|  | ||||
| from typing import Optional | ||||
|  | ||||
| from django.utils.text import slugify | ||||
|  | ||||
| from authentik.brands.models import Brand | ||||
| @ -22,7 +20,7 @@ def create_test_flow( | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def create_test_user(name: Optional[str] = None, **kwargs) -> User: | ||||
| def create_test_user(name: str | None = None, **kwargs) -> User: | ||||
|     """Generate a test user""" | ||||
|     uid = generate_id(20) if not name else name | ||||
|     kwargs.setdefault("email", f"{uid}@goauthentik.io") | ||||
| @ -36,7 +34,7 @@ def create_test_user(name: Optional[str] = None, **kwargs) -> User: | ||||
|     return user | ||||
|  | ||||
|  | ||||
| def create_test_admin_user(name: Optional[str] = None, **kwargs) -> User: | ||||
| def create_test_admin_user(name: str | None = None, **kwargs) -> User: | ||||
|     """Generate a test-admin user""" | ||||
|     user = create_test_user(name, **kwargs) | ||||
|     group = Group.objects.create(name=user.name or name, is_superuser=True) | ||||
|  | ||||
| @ -1,7 +1,6 @@ | ||||
| """authentik core dataclasses""" | ||||
|  | ||||
| from dataclasses import dataclass | ||||
| from typing import Optional | ||||
|  | ||||
| from rest_framework.fields import CharField | ||||
|  | ||||
| @ -20,7 +19,7 @@ class UILoginButton: | ||||
|     challenge: Challenge | ||||
|  | ||||
|     # Icon URL, used as-is | ||||
|     icon_url: Optional[str] = None | ||||
|     icon_url: str | None = None | ||||
|  | ||||
|  | ||||
| class UserSettingSerializer(PassiveSerializer): | ||||
|  | ||||
| @ -57,7 +57,7 @@ class RedirectToAppLaunch(View): | ||||
|                 }, | ||||
|             ) | ||||
|         except FlowNonApplicableException: | ||||
|             raise Http404 | ||||
|             raise Http404 from None | ||||
|         plan.insert_stage(in_memory_stage(RedirectToAppStage)) | ||||
|         request.session[SESSION_KEY_PLAN] = plan | ||||
|         return redirect_with_qs("authentik_core:if-flow", request.GET, flow_slug=flow.slug) | ||||
|  | ||||
| @ -61,7 +61,6 @@ class ServerErrorView(TemplateView): | ||||
|     response_class = ServerErrorTemplateResponse | ||||
|     template_name = "if/error.html" | ||||
|  | ||||
|     # pylint: disable=useless-super-delegation | ||||
|     def dispatch(self, *args, **kwargs):  # pragma: no cover | ||||
|         """Little wrapper so django accepts this function""" | ||||
|         return super().dispatch(*args, **kwargs) | ||||
|  | ||||
| @ -1,7 +1,6 @@ | ||||
| """Crypto API Views""" | ||||
|  | ||||
| from datetime import datetime | ||||
| from typing import Optional | ||||
|  | ||||
| from cryptography.hazmat.backends import default_backend | ||||
| from cryptography.hazmat.primitives.serialization import load_pem_private_key | ||||
| @ -56,25 +55,25 @@ class CertificateKeyPairSerializer(ModelSerializer): | ||||
|             return True | ||||
|         return str(request.query_params.get("include_details", "true")).lower() == "true" | ||||
|  | ||||
|     def get_fingerprint_sha256(self, instance: CertificateKeyPair) -> Optional[str]: | ||||
|     def get_fingerprint_sha256(self, instance: CertificateKeyPair) -> str | None: | ||||
|         "Get certificate Hash (SHA256)" | ||||
|         if not self._should_include_details: | ||||
|             return None | ||||
|         return instance.fingerprint_sha256 | ||||
|  | ||||
|     def get_fingerprint_sha1(self, instance: CertificateKeyPair) -> Optional[str]: | ||||
|     def get_fingerprint_sha1(self, instance: CertificateKeyPair) -> str | None: | ||||
|         "Get certificate Hash (SHA1)" | ||||
|         if not self._should_include_details: | ||||
|             return None | ||||
|         return instance.fingerprint_sha1 | ||||
|  | ||||
|     def get_cert_expiry(self, instance: CertificateKeyPair) -> Optional[datetime]: | ||||
|     def get_cert_expiry(self, instance: CertificateKeyPair) -> datetime | None: | ||||
|         "Get certificate expiry" | ||||
|         if not self._should_include_details: | ||||
|             return None | ||||
|         return DateTimeField().to_representation(instance.certificate.not_valid_after) | ||||
|  | ||||
|     def get_cert_subject(self, instance: CertificateKeyPair) -> Optional[str]: | ||||
|     def get_cert_subject(self, instance: CertificateKeyPair) -> str | None: | ||||
|         """Get certificate subject as full rfc4514""" | ||||
|         if not self._should_include_details: | ||||
|             return None | ||||
| @ -84,7 +83,7 @@ class CertificateKeyPairSerializer(ModelSerializer): | ||||
|         """Show if this keypair has a private key configured or not""" | ||||
|         return instance.key_data != "" and instance.key_data is not None | ||||
|  | ||||
|     def get_private_key_type(self, instance: CertificateKeyPair) -> Optional[str]: | ||||
|     def get_private_key_type(self, instance: CertificateKeyPair) -> str | None: | ||||
|         """Get the private key's type, if set""" | ||||
|         if not self._should_include_details: | ||||
|             return None | ||||
| @ -121,7 +120,7 @@ class CertificateKeyPairSerializer(ModelSerializer): | ||||
|             str(load_pem_x509_certificate(value.encode("utf-8"), default_backend())) | ||||
|         except ValueError as exc: | ||||
|             LOGGER.warning("Failed to load certificate", exc=exc) | ||||
|             raise ValidationError("Unable to load certificate.") | ||||
|             raise ValidationError("Unable to load certificate.") from None | ||||
|         return value | ||||
|  | ||||
|     def validate_key_data(self, value: str) -> str: | ||||
| @ -140,7 +139,7 @@ class CertificateKeyPairSerializer(ModelSerializer): | ||||
|                 ) | ||||
|             except (ValueError, TypeError) as exc: | ||||
|                 LOGGER.warning("Failed to load private key", exc=exc) | ||||
|                 raise ValidationError("Unable to load private key (possibly encrypted?).") | ||||
|                 raise ValidationError("Unable to load private key (possibly encrypted?).") from None | ||||
|         return value | ||||
|  | ||||
|     class Meta: | ||||
|  | ||||
| @ -1,7 +1,6 @@ | ||||
| """authentik crypto app config""" | ||||
|  | ||||
| from datetime import datetime, timezone | ||||
| from typing import Optional | ||||
| from datetime import UTC, datetime | ||||
|  | ||||
| from authentik.blueprints.apps import ManagedAppConfig | ||||
| from authentik.lib.generators import generate_id | ||||
| @ -36,20 +35,22 @@ class AuthentikCryptoConfig(ManagedAppConfig): | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def reconcile_tenant_managed_jwt_cert(self): | ||||
|     @ManagedAppConfig.reconcile_tenant | ||||
|     def managed_jwt_cert(self): | ||||
|         """Ensure managed JWT certificate""" | ||||
|         from authentik.crypto.models import CertificateKeyPair | ||||
|  | ||||
|         cert: Optional[CertificateKeyPair] = CertificateKeyPair.objects.filter( | ||||
|         cert: CertificateKeyPair | None = CertificateKeyPair.objects.filter( | ||||
|             managed=MANAGED_KEY | ||||
|         ).first() | ||||
|         now = datetime.now(tz=timezone.utc) | ||||
|         now = datetime.now(tz=UTC) | ||||
|         if not cert or ( | ||||
|             now < cert.certificate.not_valid_after_utc or now > cert.certificate.not_valid_after_utc | ||||
|         ): | ||||
|             self._create_update_cert() | ||||
|  | ||||
|     def reconcile_tenant_self_signed(self): | ||||
|     @ManagedAppConfig.reconcile_tenant | ||||
|     def self_signed(self): | ||||
|         """Create self-signed keypair""" | ||||
|         from authentik.crypto.builder import CertificateBuilder | ||||
|         from authentik.crypto.models import CertificateKeyPair | ||||
|  | ||||
| @ -2,7 +2,6 @@ | ||||
|  | ||||
| import datetime | ||||
| import uuid | ||||
| from typing import Optional | ||||
|  | ||||
| from cryptography import x509 | ||||
| from cryptography.hazmat.backends import default_backend | ||||
| @ -44,7 +43,7 @@ class CertificateBuilder: | ||||
|     def generate_private_key(self) -> PrivateKeyTypes: | ||||
|         """Generate private key""" | ||||
|         if self._use_ec_private_key: | ||||
|             return ec.generate_private_key(curve=ec.SECP256R1) | ||||
|             return ec.generate_private_key(curve=ec.SECP256R1()) | ||||
|         return rsa.generate_private_key( | ||||
|             public_exponent=65537, key_size=4096, backend=default_backend() | ||||
|         ) | ||||
| @ -52,7 +51,7 @@ class CertificateBuilder: | ||||
|     def build( | ||||
|         self, | ||||
|         validity_days: int = 365, | ||||
|         subject_alt_names: Optional[list[str]] = None, | ||||
|         subject_alt_names: list[str] | None = None, | ||||
|     ): | ||||
|         """Build self-signed certificate""" | ||||
|         one_day = datetime.timedelta(1, 0, 0) | ||||
|  | ||||
| @ -24,13 +24,13 @@ class Command(TenantCommand): | ||||
|         if not keypair: | ||||
|             keypair = CertificateKeyPair(name=options["name"]) | ||||
|             dirty = True | ||||
|         with open(options["certificate"], mode="r", encoding="utf-8") as _cert: | ||||
|         with open(options["certificate"], encoding="utf-8") as _cert: | ||||
|             cert_data = _cert.read() | ||||
|             if keypair.certificate_data != cert_data: | ||||
|                 dirty = True | ||||
|             keypair.certificate_data = cert_data | ||||
|         if options["private_key"]: | ||||
|             with open(options["private_key"], mode="r", encoding="utf-8") as _key: | ||||
|             with open(options["private_key"], encoding="utf-8") as _key: | ||||
|                 key_data = _key.read() | ||||
|                 if keypair.key_data != key_data: | ||||
|                     dirty = True | ||||
|  | ||||
| @ -2,7 +2,6 @@ | ||||
|  | ||||
| from binascii import hexlify | ||||
| from hashlib import md5 | ||||
| from typing import Optional | ||||
| from uuid import uuid4 | ||||
|  | ||||
| from cryptography.hazmat.backends import default_backend | ||||
| @ -37,9 +36,9 @@ class CertificateKeyPair(SerializerModel, ManagedModel, CreatedUpdatedModel): | ||||
|         default="", | ||||
|     ) | ||||
|  | ||||
|     _cert: Optional[Certificate] = None | ||||
|     _private_key: Optional[PrivateKeyTypes] = None | ||||
|     _public_key: Optional[PublicKeyTypes] = None | ||||
|     _cert: Certificate | None = None | ||||
|     _private_key: PrivateKeyTypes | None = None | ||||
|     _public_key: PublicKeyTypes | None = None | ||||
|  | ||||
|     @property | ||||
|     def serializer(self) -> Serializer: | ||||
| @ -57,7 +56,7 @@ class CertificateKeyPair(SerializerModel, ManagedModel, CreatedUpdatedModel): | ||||
|         return self._cert | ||||
|  | ||||
|     @property | ||||
|     def public_key(self) -> Optional[PublicKeyTypes]: | ||||
|     def public_key(self) -> PublicKeyTypes | None: | ||||
|         """Get public key of the private key""" | ||||
|         if not self._public_key: | ||||
|             self._public_key = self.private_key.public_key() | ||||
| @ -66,7 +65,7 @@ class CertificateKeyPair(SerializerModel, ManagedModel, CreatedUpdatedModel): | ||||
|     @property | ||||
|     def private_key( | ||||
|         self, | ||||
|     ) -> Optional[PrivateKeyTypes]: | ||||
|     ) -> PrivateKeyTypes | None: | ||||
|         """Get python cryptography PrivateKey instance""" | ||||
|         if not self._private_key and self.key_data != "": | ||||
|             try: | ||||
|  | ||||
| @ -58,7 +58,7 @@ def certificate_discovery(self: SystemTask): | ||||
|         else: | ||||
|             cert_name = path.name.replace(path.suffix, "") | ||||
|         try: | ||||
|             with open(path, "r", encoding="utf-8") as _file: | ||||
|             with open(path, encoding="utf-8") as _file: | ||||
|                 body = _file.read() | ||||
|                 if "PRIVATE KEY" in body: | ||||
|                     private_keys[cert_name] = ensure_private_key_valid(body) | ||||
|  | ||||
| @ -267,7 +267,7 @@ class TestCrypto(APITestCase): | ||||
|             with open(f"{temp_dir}/foo.bar/privkey.pem", "w+", encoding="utf-8") as _key: | ||||
|                 _key.write(builder.private_key) | ||||
|             with CONFIG.patch("cert_discovery_dir", temp_dir): | ||||
|                 certificate_discovery()  # pylint: disable=no-value-for-parameter | ||||
|                 certificate_discovery() | ||||
|         keypair: CertificateKeyPair = CertificateKeyPair.objects.filter( | ||||
|             managed=MANAGED_DISCOVERED % "foo" | ||||
|         ).first() | ||||
|  | ||||
| @ -31,7 +31,7 @@ class EnterpriseRequiredMixin: | ||||
|  | ||||
|     def validate(self, attrs: dict) -> dict: | ||||
|         """Check that a valid license exists""" | ||||
|         if not LicenseKey.cached_summary().has_license: | ||||
|         if not LicenseKey.cached_summary().valid: | ||||
|             raise ValidationError(_("Enterprise is required to create/update this object.")) | ||||
|         return super().validate(attrs) | ||||
|  | ||||
|  | ||||
| @ -13,7 +13,8 @@ class AuthentikEnterpriseAuditConfig(EnterpriseConfig): | ||||
|     verbose_name = "authentik Enterprise.Audit" | ||||
|     default = True | ||||
|  | ||||
|     def reconcile_global_install_middleware(self): | ||||
|     @EnterpriseConfig.reconcile_global | ||||
|     def install_middleware(self): | ||||
|         """Install enterprise audit middleware""" | ||||
|         orig_import = "authentik.events.middleware.AuditMiddleware" | ||||
|         new_import = "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware" | ||||
|  | ||||
| @ -11,6 +11,7 @@ from django.db.models.expressions import BaseExpression, Combinable | ||||
| from django.db.models.signals import post_init | ||||
| from django.http import HttpRequest | ||||
|  | ||||
| from authentik.core.models import User | ||||
| from authentik.events.middleware import AuditMiddleware, should_log_model | ||||
| from authentik.events.utils import cleanse_dict, sanitize_item | ||||
|  | ||||
| @ -27,10 +28,13 @@ class EnterpriseAuditMiddleware(AuditMiddleware): | ||||
|         super().connect(request) | ||||
|         if not self.enabled: | ||||
|             return | ||||
|         user = getattr(request, "user", self.anonymous_user) | ||||
|         if not user.is_authenticated: | ||||
|             user = self.anonymous_user | ||||
|         if not hasattr(request, "request_id"): | ||||
|             return | ||||
|         post_init.connect( | ||||
|             partial(self.post_init_handler, request=request), | ||||
|             partial(self.post_init_handler, user=user, request=request), | ||||
|             dispatch_uid=request.request_id, | ||||
|             weak=False, | ||||
|         ) | ||||
| @ -58,7 +62,7 @@ class EnterpriseAuditMiddleware(AuditMiddleware): | ||||
|                 field_value = value.name | ||||
|  | ||||
|             # If current field value is an expression, we are not evaluating it | ||||
|             if isinstance(field_value, (BaseExpression, Combinable)): | ||||
|             if isinstance(field_value, BaseExpression | Combinable): | ||||
|                 continue | ||||
|             field_value = field.to_python(field_value) | ||||
|             data[field.name] = deepcopy(field_value) | ||||
| @ -72,21 +76,21 @@ class EnterpriseAuditMiddleware(AuditMiddleware): | ||||
|                 diff[key] = {"previous_value": value, "new_value": after.get(key)} | ||||
|         return sanitize_item(diff) | ||||
|  | ||||
|     def post_init_handler(self, request: HttpRequest, sender, instance: Model, **_): | ||||
|     def post_init_handler(self, user: User, request: HttpRequest, sender, instance: Model, **_): | ||||
|         """post_init django model handler""" | ||||
|         if not should_log_model(instance): | ||||
|             return | ||||
|         if hasattr(instance, "_previous_state"): | ||||
|             return | ||||
|         before = len(connection.queries) | ||||
|         setattr(instance, "_previous_state", self.serialize_simple(instance)) | ||||
|         instance._previous_state = self.serialize_simple(instance) | ||||
|         after = len(connection.queries) | ||||
|         if after > before: | ||||
|             raise AssertionError("More queries generated by serialize_simple") | ||||
|  | ||||
|     # pylint: disable=too-many-arguments | ||||
|     def post_save_handler( | ||||
|         self, | ||||
|         user: User, | ||||
|         request: HttpRequest, | ||||
|         sender, | ||||
|         instance: Model, | ||||
| @ -108,4 +112,6 @@ class EnterpriseAuditMiddleware(AuditMiddleware): | ||||
|                 for field_set in ignored_field_sets: | ||||
|                     if set(diff.keys()) == set(field_set): | ||||
|                         return None | ||||
|         return super().post_save_handler(request, sender, instance, created, thread_kwargs, **_) | ||||
|         return super().post_save_handler( | ||||
|             user, request, sender, instance, created, thread_kwargs, **_ | ||||
|         ) | ||||
|  | ||||
| @ -27,7 +27,7 @@ CACHE_KEY_ENTERPRISE_LICENSE = "goauthentik.io/enterprise/license" | ||||
| CACHE_EXPIRY_ENTERPRISE_LICENSE = 3 * 60 * 60  # 2 Hours | ||||
|  | ||||
|  | ||||
| @lru_cache() | ||||
| @lru_cache | ||||
| def get_licensing_key() -> Certificate: | ||||
|     """Get Root CA PEM""" | ||||
|     with open("authentik/enterprise/public.pem", "rb") as _key: | ||||
| @ -88,7 +88,7 @@ class LicenseKey: | ||||
|         try: | ||||
|             headers = get_unverified_header(jwt) | ||||
|         except PyJWTError: | ||||
|             raise ValidationError("Unable to verify license") | ||||
|             raise ValidationError("Unable to verify license") from None | ||||
|         x5c: list[str] = headers.get("x5c", []) | ||||
|         if len(x5c) < 1: | ||||
|             raise ValidationError("Unable to verify license") | ||||
| @ -98,7 +98,7 @@ class LicenseKey: | ||||
|             our_cert.verify_directly_issued_by(intermediate) | ||||
|             intermediate.verify_directly_issued_by(get_licensing_key()) | ||||
|         except (InvalidSignature, TypeError, ValueError, Error): | ||||
|             raise ValidationError("Unable to verify license") | ||||
|             raise ValidationError("Unable to verify license") from None | ||||
|         try: | ||||
|             body = from_dict( | ||||
|                 LicenseKey, | ||||
| @ -110,7 +110,7 @@ class LicenseKey: | ||||
|                 ), | ||||
|             ) | ||||
|         except PyJWTError: | ||||
|             raise ValidationError("Unable to verify license") | ||||
|             raise ValidationError("Unable to verify license") from None | ||||
|         return body | ||||
|  | ||||
|     @staticmethod | ||||
| @ -188,21 +188,20 @@ class LicenseKey: | ||||
|  | ||||
|     def summary(self) -> LicenseSummary: | ||||
|         """Summary of license status""" | ||||
|         has_license = License.objects.all().count() > 0 | ||||
|         last_valid = LicenseKey.last_valid_date() | ||||
|         show_admin_warning = last_valid < now() - timedelta(weeks=2) | ||||
|         show_user_warning = last_valid < now() - timedelta(weeks=4) | ||||
|         read_only = last_valid < now() - timedelta(weeks=6) | ||||
|         latest_valid = datetime.fromtimestamp(self.exp) | ||||
|         return LicenseSummary( | ||||
|             show_admin_warning=show_admin_warning and has_license, | ||||
|             show_user_warning=show_user_warning and has_license, | ||||
|             read_only=read_only and has_license, | ||||
|             show_admin_warning=show_admin_warning, | ||||
|             show_user_warning=show_user_warning, | ||||
|             read_only=read_only, | ||||
|             latest_valid=latest_valid, | ||||
|             internal_users=self.internal_users, | ||||
|             external_users=self.external_users, | ||||
|             valid=self.is_valid(), | ||||
|             has_license=has_license, | ||||
|             has_license=License.objects.all().count() > 0, | ||||
|         ) | ||||
|  | ||||
|     @staticmethod | ||||
|  | ||||
| @ -1,7 +1,5 @@ | ||||
| """Enterprise license policies""" | ||||
|  | ||||
| from typing import Optional | ||||
|  | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
|  | ||||
| from authentik.core.models import User, UserTypes | ||||
| @ -21,7 +19,7 @@ class EnterprisePolicyAccessView(PolicyAccessView): | ||||
|             return PolicyResult(False, _("Feature only accessible for internal users.")) | ||||
|         return PolicyResult(True) | ||||
|  | ||||
|     def user_has_access(self, user: Optional[User] = None) -> PolicyResult: | ||||
|     def user_has_access(self, user: User | None = None) -> PolicyResult: | ||||
|         user = user or self.request.user | ||||
|         request = PolicyRequest(user) | ||||
|         request.http_request = self.request | ||||
|  | ||||
| @ -6,13 +6,13 @@ from rest_framework.filters import OrderingFilter, SearchFilter | ||||
| from rest_framework.serializers import ModelSerializer | ||||
| from rest_framework.viewsets import GenericViewSet | ||||
|  | ||||
| from authentik.api.authorization import OwnerFilter, OwnerSuperuserPermissions | ||||
| from authentik.api.authorization import OwnerFilter, OwnerPermissions | ||||
| from authentik.core.api.groups import GroupMemberSerializer | ||||
| from authentik.core.api.used_by import UsedByMixin | ||||
| from authentik.enterprise.api import EnterpriseRequiredMixin | ||||
| from authentik.enterprise.providers.rac.api.endpoints import EndpointSerializer | ||||
| from authentik.enterprise.providers.rac.api.providers import RACProviderSerializer | ||||
| from authentik.enterprise.providers.rac.models import ConnectionToken | ||||
| from authentik.enterprise.providers.rac.models import ConnectionToken, Endpoint | ||||
|  | ||||
|  | ||||
| class ConnectionTokenSerializer(EnterpriseRequiredMixin, ModelSerializer): | ||||
| @ -23,7 +23,7 @@ class ConnectionTokenSerializer(EnterpriseRequiredMixin, ModelSerializer): | ||||
|     user = GroupMemberSerializer(source="session.user", read_only=True) | ||||
|  | ||||
|     class Meta: | ||||
|         model = ConnectionToken | ||||
|         model = Endpoint | ||||
|         fields = [ | ||||
|             "pk", | ||||
|             "provider", | ||||
| @ -49,5 +49,5 @@ class ConnectionTokenViewSet( | ||||
|     filterset_fields = ["endpoint", "session__user", "provider"] | ||||
|     search_fields = ["endpoint__name", "provider__name"] | ||||
|     ordering = ["endpoint__name", "provider__name"] | ||||
|     permission_classes = [OwnerSuperuserPermissions] | ||||
|     permission_classes = [OwnerPermissions] | ||||
|     filter_backends = [OwnerFilter, DjangoFilterBackend, OrderingFilter, SearchFilter] | ||||
|  | ||||
| @ -1,7 +1,5 @@ | ||||
| """RAC Provider API Views""" | ||||
|  | ||||
| from typing import Optional | ||||
|  | ||||
| from django.core.cache import cache | ||||
| from django.db.models import QuerySet | ||||
| from django.urls import reverse | ||||
| @ -36,11 +34,11 @@ class EndpointSerializer(EnterpriseRequiredMixin, ModelSerializer): | ||||
|     provider_obj = RACProviderSerializer(source="provider", read_only=True) | ||||
|     launch_url = SerializerMethodField() | ||||
|  | ||||
|     def get_launch_url(self, endpoint: Endpoint) -> Optional[str]: | ||||
|     def get_launch_url(self, endpoint: Endpoint) -> str | None: | ||||
|         """Build actual launch URL (the provider itself does not have one, just | ||||
|         individual endpoints)""" | ||||
|         try: | ||||
|             # pylint: disable=no-member | ||||
|  | ||||
|             return reverse( | ||||
|                 "authentik_providers_rac:start", | ||||
|                 kwargs={"app": endpoint.provider.application.slug, "endpoint": endpoint.pk}, | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| """RAC Models""" | ||||
|  | ||||
| from typing import Any, Optional | ||||
| from typing import Any | ||||
| from uuid import uuid4 | ||||
|  | ||||
| from deepmerge import always_merger | ||||
| @ -58,7 +58,7 @@ class RACProvider(Provider): | ||||
|     ) | ||||
|  | ||||
|     @property | ||||
|     def launch_url(self) -> Optional[str]: | ||||
|     def launch_url(self) -> str | None: | ||||
|         """URL to this provider and initiate authorization for the user. | ||||
|         Can return None for providers that are not URL-based""" | ||||
|         return "goauthentik.io://providers/rac/launch" | ||||
| @ -112,7 +112,7 @@ class RACPropertyMapping(PropertyMapping): | ||||
|  | ||||
|     static_settings = models.JSONField(default=dict) | ||||
|  | ||||
|     def evaluate(self, user: Optional[User], request: Optional[HttpRequest], **kwargs) -> Any: | ||||
|     def evaluate(self, user: User | None, request: HttpRequest | None, **kwargs) -> Any: | ||||
|         """Evaluate `self.expression` using `**kwargs` as Context.""" | ||||
|         if len(self.static_settings) > 0: | ||||
|             return self.static_settings | ||||
|  | ||||
| @ -47,7 +47,7 @@ class RACStartView(EnterprisePolicyAccessView): | ||||
|                 }, | ||||
|             ) | ||||
|         except FlowNonApplicableException: | ||||
|             raise Http404 | ||||
|             raise Http404 from None | ||||
|         plan.insert_stage( | ||||
|             in_memory_stage( | ||||
|                 RACFinalStage, | ||||
| @ -132,16 +132,7 @@ class RACFinalStage(RedirectStage): | ||||
|             flow=self.executor.plan.flow_pk, | ||||
|             endpoint=self.endpoint.name, | ||||
|         ).from_http(self.request) | ||||
|         setattr( | ||||
|             self.executor.current_stage, | ||||
|             "destination", | ||||
|             self.request.build_absolute_uri( | ||||
|                 reverse( | ||||
|                     "authentik_providers_rac:if-rac", | ||||
|                     kwargs={ | ||||
|                         "token": str(token.token), | ||||
|                     }, | ||||
|                 ) | ||||
|             ), | ||||
|         self.executor.current_stage.destination = self.request.build_absolute_uri( | ||||
|             reverse("authentik_providers_rac:if-rac", kwargs={"token": str(token.token)}) | ||||
|         ) | ||||
|         return super().get_challenge(*args, **kwargs) | ||||
|  | ||||
| @ -2,14 +2,11 @@ | ||||
|  | ||||
| from datetime import datetime | ||||
|  | ||||
| from django.core.cache import cache | ||||
| from django.db.models.signals import post_save, pre_save | ||||
| from django.db.models.signals import pre_save | ||||
| from django.dispatch import receiver | ||||
| from django.utils.timezone import get_current_timezone | ||||
|  | ||||
| from authentik.enterprise.license import CACHE_KEY_ENTERPRISE_LICENSE | ||||
| from authentik.enterprise.models import License | ||||
| from authentik.enterprise.tasks import enterprise_update_usage | ||||
|  | ||||
|  | ||||
| @receiver(pre_save, sender=License) | ||||
| @ -20,10 +17,3 @@ def pre_save_license(sender: type[License], instance: License, **_): | ||||
|     instance.internal_users = status.internal_users | ||||
|     instance.external_users = status.external_users | ||||
|     instance.expiry = datetime.fromtimestamp(status.exp, tz=get_current_timezone()) | ||||
|  | ||||
|  | ||||
| @receiver(post_save, sender=License) | ||||
| def post_save_license(sender: type[License], instance: License, **_): | ||||
|     """Trigger license usage calculation when license is saved""" | ||||
|     cache.delete(CACHE_KEY_ENTERPRISE_LICENSE) | ||||
|     enterprise_update_usage.delay() | ||||
|  | ||||
| @ -92,7 +92,7 @@ class SystemTaskViewSet(ReadOnlyModelViewSet): | ||||
|             task_func.delay(*task.task_call_args, **task.task_call_kwargs) | ||||
|             messages.success( | ||||
|                 self.request, | ||||
|                 _("Successfully started task %(name)s." % {"name": task.name}), | ||||
|                 _("Successfully started task {name}.".format_map({"name": task.name})), | ||||
|             ) | ||||
|             return Response(status=204) | ||||
|         except (ImportError, AttributeError) as exc:  # pragma: no cover | ||||
|  | ||||
| @ -35,7 +35,8 @@ class AuthentikEventsConfig(ManagedAppConfig): | ||||
|     verbose_name = "authentik Events" | ||||
|     default = True | ||||
|  | ||||
|     def reconcile_global_check_deprecations(self): | ||||
|     @ManagedAppConfig.reconcile_global | ||||
|     def check_deprecations(self): | ||||
|         """Check for config deprecations""" | ||||
|         from authentik.events.models import Event, EventAction | ||||
|  | ||||
| @ -56,7 +57,8 @@ class AuthentikEventsConfig(ManagedAppConfig): | ||||
|                 message=msg, | ||||
|             ).save() | ||||
|  | ||||
|     def reconcile_tenant_prefill_tasks(self): | ||||
|     @ManagedAppConfig.reconcile_tenant | ||||
|     def prefill_tasks(self): | ||||
|         """Prefill tasks""" | ||||
|         from authentik.events.models import SystemTask | ||||
|         from authentik.events.system_tasks import _prefill_tasks | ||||
| @ -67,7 +69,8 @@ class AuthentikEventsConfig(ManagedAppConfig): | ||||
|             task.save() | ||||
|             self.logger.debug("prefilled task", task_name=task.name) | ||||
|  | ||||
|     def reconcile_tenant_run_scheduled_tasks(self): | ||||
|     @ManagedAppConfig.reconcile_tenant | ||||
|     def run_scheduled_tasks(self): | ||||
|         """Run schedule tasks which are behind schedule (only applies | ||||
|         to tasks of which we keep metrics)""" | ||||
|         from authentik.events.models import TaskStatus | ||||
|  | ||||
| @ -46,7 +46,7 @@ class ASNContextProcessor(MMDBContextProcessor): | ||||
|             "asn": self.asn_dict(ClientIPMiddleware.get_client_ip(request)), | ||||
|         } | ||||
|  | ||||
|     def asn(self, ip_address: str) -> Optional[ASN]: | ||||
|     def asn(self, ip_address: str) -> ASN | None: | ||||
|         """Wrapper for Reader.asn""" | ||||
|         with Hub.current.start_span( | ||||
|             op="authentik.events.asn.asn", | ||||
| @ -71,7 +71,7 @@ class ASNContextProcessor(MMDBContextProcessor): | ||||
|         } | ||||
|         return asn_dict | ||||
|  | ||||
|     def asn_dict(self, ip_address: str) -> Optional[ASNDict]: | ||||
|     def asn_dict(self, ip_address: str) -> ASNDict | None: | ||||
|         """Wrapper for self.asn that returns a dict""" | ||||
|         asn = self.asn(ip_address) | ||||
|         if not asn: | ||||
|  | ||||
| @ -47,7 +47,7 @@ class GeoIPContextProcessor(MMDBContextProcessor): | ||||
|         # Different key `geoip` vs `geo` for legacy reasons | ||||
|         return {"geoip": self.city(ClientIPMiddleware.get_client_ip(request))} | ||||
|  | ||||
|     def city(self, ip_address: str) -> Optional[City]: | ||||
|     def city(self, ip_address: str) -> City | None: | ||||
|         """Wrapper for Reader.city""" | ||||
|         with Hub.current.start_span( | ||||
|             op="authentik.events.geo.city", | ||||
| @ -76,7 +76,7 @@ class GeoIPContextProcessor(MMDBContextProcessor): | ||||
|             city_dict["city"] = city.city.name | ||||
|         return city_dict | ||||
|  | ||||
|     def city_dict(self, ip_address: str) -> Optional[GeoIPDict]: | ||||
|     def city_dict(self, ip_address: str) -> GeoIPDict | None: | ||||
|         """Wrapper for self.city that returns a dict""" | ||||
|         city = self.city(ip_address) | ||||
|         if not city: | ||||
|  | ||||
| @ -1,7 +1,6 @@ | ||||
| """Common logic for reading MMDB files""" | ||||
|  | ||||
| from pathlib import Path | ||||
| from typing import Optional | ||||
|  | ||||
| from geoip2.database import Reader | ||||
| from structlog.stdlib import get_logger | ||||
| @ -13,7 +12,7 @@ class MMDBContextProcessor(EventContextProcessor): | ||||
|     """Common logic for reading MaxMind DB files, including re-loading if the file has changed""" | ||||
|  | ||||
|     def __init__(self): | ||||
|         self.reader: Optional[Reader] = None | ||||
|         self.reader: Reader | None = None | ||||
|         self._last_mtime: float = 0.0 | ||||
|         self.logger = get_logger() | ||||
|         self.open() | ||||
|  | ||||
| @ -1,8 +1,9 @@ | ||||
| """Events middleware""" | ||||
|  | ||||
| from collections.abc import Callable | ||||
| from functools import partial | ||||
| from threading import Thread | ||||
| from typing import Any, Callable, Optional | ||||
| from typing import Any | ||||
|  | ||||
| from django.conf import settings | ||||
| from django.contrib.sessions.models import Session | ||||
| @ -49,9 +50,9 @@ class EventNewThread(Thread): | ||||
|     action: str | ||||
|     request: HttpRequest | ||||
|     kwargs: dict[str, Any] | ||||
|     user: Optional[User] = None | ||||
|     user: User | None = None | ||||
|  | ||||
|     def __init__(self, action: str, request: HttpRequest, user: Optional[User] = None, **kwargs): | ||||
|     def __init__(self, action: str, request: HttpRequest, user: User | None = None, **kwargs): | ||||
|         super().__init__() | ||||
|         self.action = action | ||||
|         self.request = request | ||||
| @ -82,29 +83,26 @@ class AuditMiddleware: | ||||
|  | ||||
|         self.anonymous_user = get_anonymous_user() | ||||
|  | ||||
|     def get_user(self, request: HttpRequest) -> User: | ||||
|         user = getattr(request, "user", self.anonymous_user) | ||||
|         if not user.is_authenticated: | ||||
|             return self.anonymous_user | ||||
|         return user | ||||
|  | ||||
|     def connect(self, request: HttpRequest): | ||||
|         """Connect signal for automatic logging""" | ||||
|         self._ensure_fallback_user() | ||||
|         user = getattr(request, "user", self.anonymous_user) | ||||
|         if not user.is_authenticated: | ||||
|             user = self.anonymous_user | ||||
|         if not hasattr(request, "request_id"): | ||||
|             return | ||||
|         post_save.connect( | ||||
|             partial(self.post_save_handler, request=request), | ||||
|             partial(self.post_save_handler, user=user, request=request), | ||||
|             dispatch_uid=request.request_id, | ||||
|             weak=False, | ||||
|         ) | ||||
|         pre_delete.connect( | ||||
|             partial(self.pre_delete_handler, request=request), | ||||
|             partial(self.pre_delete_handler, user=user, request=request), | ||||
|             dispatch_uid=request.request_id, | ||||
|             weak=False, | ||||
|         ) | ||||
|         m2m_changed.connect( | ||||
|             partial(self.m2m_changed_handler, request=request), | ||||
|             partial(self.m2m_changed_handler, user=user, request=request), | ||||
|             dispatch_uid=request.request_id, | ||||
|             weak=False, | ||||
|         ) | ||||
| @ -147,31 +145,29 @@ class AuditMiddleware: | ||||
|             ) | ||||
|             thread.run() | ||||
|  | ||||
|     # pylint: disable=too-many-arguments | ||||
|     def post_save_handler( | ||||
|         self, | ||||
|         user: User, | ||||
|         request: HttpRequest, | ||||
|         sender, | ||||
|         instance: Model, | ||||
|         created: bool, | ||||
|         thread_kwargs: Optional[dict] = None, | ||||
|         thread_kwargs: dict | None = None, | ||||
|         **_, | ||||
|     ): | ||||
|         """Signal handler for all object's post_save""" | ||||
|         if not should_log_model(instance): | ||||
|             return | ||||
|         user = self.get_user(request) | ||||
|  | ||||
|         action = EventAction.MODEL_CREATED if created else EventAction.MODEL_UPDATED | ||||
|         thread = EventNewThread(action, request, user=user, model=model_to_dict(instance)) | ||||
|         thread.kwargs.update(thread_kwargs or {}) | ||||
|         thread.run() | ||||
|  | ||||
|     def pre_delete_handler(self, request: HttpRequest, sender, instance: Model, **_): | ||||
|     def pre_delete_handler(self, user: User, request: HttpRequest, sender, instance: Model, **_): | ||||
|         """Signal handler for all object's pre_delete""" | ||||
|         if not should_log_model(instance):  # pragma: no cover | ||||
|             return | ||||
|         user = self.get_user(request) | ||||
|  | ||||
|         EventNewThread( | ||||
|             EventAction.MODEL_DELETED, | ||||
| @ -180,13 +176,14 @@ class AuditMiddleware: | ||||
|             model=model_to_dict(instance), | ||||
|         ).run() | ||||
|  | ||||
|     def m2m_changed_handler(self, request: HttpRequest, sender, instance: Model, action: str, **_): | ||||
|     def m2m_changed_handler( | ||||
|         self, user: User, request: HttpRequest, sender, instance: Model, action: str, **_ | ||||
|     ): | ||||
|         """Signal handler for all object's m2m_changed""" | ||||
|         if action not in ["pre_add", "pre_remove", "post_clear"]: | ||||
|             return | ||||
|         if not should_log_m2m(instance): | ||||
|             return | ||||
|         user = self.get_user(request) | ||||
|  | ||||
|         EventNewThread( | ||||
|             EventAction.MODEL_UPDATED, | ||||
|  | ||||
| @ -7,7 +7,6 @@ from difflib import get_close_matches | ||||
| from functools import lru_cache | ||||
| from inspect import currentframe | ||||
| from smtplib import SMTPException | ||||
| from typing import Optional | ||||
| from uuid import uuid4 | ||||
|  | ||||
| from django.apps import apps | ||||
| @ -52,6 +51,8 @@ from authentik.stages.email.utils import TemplateEmailMessage | ||||
| from authentik.tenants.models import Tenant | ||||
|  | ||||
| LOGGER = get_logger() | ||||
| DISCORD_FIELD_LIMIT = 25 | ||||
| NOTIFICATION_SUMMARY_LENGTH = 75 | ||||
|  | ||||
|  | ||||
| def default_event_duration(): | ||||
| @ -65,7 +66,7 @@ def default_brand(): | ||||
|     return sanitize_dict(model_to_dict(DEFAULT_BRAND)) | ||||
|  | ||||
|  | ||||
| @lru_cache() | ||||
| @lru_cache | ||||
| def django_app_names() -> list[str]: | ||||
|     """Get a cached list of all django apps' names (not labels)""" | ||||
|     return [x.name for x in apps.app_configs.values()] | ||||
| @ -198,7 +199,7 @@ class Event(SerializerModel, ExpiringModel): | ||||
|     @staticmethod | ||||
|     def new( | ||||
|         action: str | EventAction, | ||||
|         app: Optional[str] = None, | ||||
|         app: str | None = None, | ||||
|         **kwargs, | ||||
|     ) -> "Event": | ||||
|         """Create new Event instance from arguments. Instance is NOT saved.""" | ||||
| @ -224,7 +225,7 @@ class Event(SerializerModel, ExpiringModel): | ||||
|         self.user = get_user(user) | ||||
|         return self | ||||
|  | ||||
|     def from_http(self, request: HttpRequest, user: Optional[User] = None) -> "Event": | ||||
|     def from_http(self, request: HttpRequest, user: User | None = None) -> "Event": | ||||
|         """Add data from a Django-HttpRequest, allowing the creation of | ||||
|         Events independently from requests. | ||||
|         `user` arguments optionally overrides user from requests.""" | ||||
| @ -418,7 +419,7 @@ class NotificationTransport(SerializerModel): | ||||
|                 if not isinstance(value, str): | ||||
|                     continue | ||||
|                 # https://birdie0.github.io/discord-webhooks-guide/other/field_limits.html | ||||
|                 if len(fields) >= 25: | ||||
|                 if len(fields) >= DISCORD_FIELD_LIMIT: | ||||
|                     continue | ||||
|                 fields.append({"title": key[:256], "value": value[:1024]}) | ||||
|         body = { | ||||
| @ -451,13 +452,6 @@ class NotificationTransport(SerializerModel): | ||||
|  | ||||
|     def send_email(self, notification: "Notification") -> list[str]: | ||||
|         """Send notification via global email configuration""" | ||||
|         if notification.user.email.strip() == "": | ||||
|             LOGGER.info( | ||||
|                 "Discarding notification as user has no email address", | ||||
|                 user=notification.user, | ||||
|                 notification=notification, | ||||
|             ) | ||||
|             return None | ||||
|         subject_prefix = "authentik Notification: " | ||||
|         context = { | ||||
|             "key_value": { | ||||
| @ -479,7 +473,7 @@ class NotificationTransport(SerializerModel): | ||||
|                     continue | ||||
|                 context["key_value"][key] = value | ||||
|         else: | ||||
|             context["title"] += notification.body[:75] | ||||
|             context["title"] += notification.body[:NOTIFICATION_SUMMARY_LENGTH] | ||||
|         # TODO: improve permission check | ||||
|         if notification.user.is_superuser: | ||||
|             context["source"] = { | ||||
| @ -487,7 +481,7 @@ class NotificationTransport(SerializerModel): | ||||
|             } | ||||
|         mail = TemplateEmailMessage( | ||||
|             subject=subject_prefix + context["title"], | ||||
|             to=[(notification.user.name, notification.user.email)], | ||||
|             to=[f"{notification.user.name} <{notification.user.email}>"], | ||||
|             language=notification.user.locale(), | ||||
|             template_name="email/event_notification.html", | ||||
|             template_context=context, | ||||
| @ -496,7 +490,7 @@ class NotificationTransport(SerializerModel): | ||||
|         try: | ||||
|             from authentik.stages.email.tasks import send_mail | ||||
|  | ||||
|             return send_mail(mail.__dict__)  # pylint: disable=no-value-for-parameter | ||||
|             return send_mail(mail.__dict__) | ||||
|         except (SMTPException, ConnectionError, OSError) as exc: | ||||
|             raise NotificationTransportError(exc) from exc | ||||
|  | ||||
| @ -540,7 +534,11 @@ class Notification(SerializerModel): | ||||
|         return NotificationSerializer | ||||
|  | ||||
|     def __str__(self) -> str: | ||||
|         body_trunc = (self.body[:75] + "..") if len(self.body) > 75 else self.body | ||||
|         body_trunc = ( | ||||
|             (self.body[:NOTIFICATION_SUMMARY_LENGTH] + "..") | ||||
|             if len(self.body) > NOTIFICATION_SUMMARY_LENGTH | ||||
|             else self.body | ||||
|         ) | ||||
|         return f"Notification for user {self.user}: {body_trunc}" | ||||
|  | ||||
|     class Meta: | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| """authentik events signal listener""" | ||||
|  | ||||
| from typing import Any, Optional | ||||
| from typing import Any | ||||
|  | ||||
| from django.contrib.auth.signals import user_logged_in, user_logged_out | ||||
| from django.db.models.signals import post_save, pre_delete | ||||
| @ -42,7 +42,7 @@ def on_user_logged_in(sender, request: HttpRequest, user: User, **_): | ||||
|     request.session[SESSION_LOGIN_EVENT] = event | ||||
|  | ||||
|  | ||||
| def get_login_event(request: HttpRequest) -> Optional[Event]: | ||||
| def get_login_event(request: HttpRequest) -> Event | None: | ||||
|     """Wrapper to get login event that can be mocked in tests""" | ||||
|     return request.session.get(SESSION_LOGIN_EVENT, None) | ||||
|  | ||||
| @ -71,7 +71,7 @@ def on_login_failed( | ||||
|     sender, | ||||
|     credentials: dict[str, str], | ||||
|     request: HttpRequest, | ||||
|     stage: Optional[Stage] = None, | ||||
|     stage: Stage | None = None, | ||||
|     **kwargs, | ||||
| ): | ||||
|     """Failed Login, authentik custom event""" | ||||
|  | ||||
| @ -2,16 +2,15 @@ | ||||
|  | ||||
| from datetime import datetime, timedelta | ||||
| from time import perf_counter | ||||
| from typing import Any, Optional | ||||
| from typing import Any | ||||
|  | ||||
| from django.utils.timezone import now | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
| from structlog.stdlib import get_logger | ||||
| from tenant_schemas_celery.task import TenantTask | ||||
|  | ||||
| from authentik.events.models import Event, EventAction | ||||
| from authentik.events.models import Event, EventAction, TaskStatus | ||||
| from authentik.events.models import SystemTask as DBSystemTask | ||||
| from authentik.events.models import TaskStatus | ||||
| from authentik.events.utils import sanitize_item | ||||
| from authentik.lib.utils.errors import exception_to_string | ||||
|  | ||||
| @ -27,10 +26,10 @@ class SystemTask(TenantTask): | ||||
|     _status: TaskStatus | ||||
|     _messages: list[str] | ||||
|  | ||||
|     _uid: Optional[str] | ||||
|     _uid: str | None | ||||
|     # Precise start time from perf_counter | ||||
|     _start_precise: Optional[float] = None | ||||
|     _start: Optional[datetime] = None | ||||
|     _start_precise: float | None = None | ||||
|     _start: datetime | None = None | ||||
|  | ||||
|     def __init__(self, *args, **kwargs) -> None: | ||||
|         super().__init__(*args, **kwargs) | ||||
| @ -60,14 +59,13 @@ class SystemTask(TenantTask): | ||||
|         self._start = now() | ||||
|         return super().before_start(task_id, args, kwargs) | ||||
|  | ||||
|     def db(self) -> Optional[DBSystemTask]: | ||||
|     def db(self) -> DBSystemTask | None: | ||||
|         """Get DB object for latest task""" | ||||
|         return DBSystemTask.objects.filter( | ||||
|             name=self.__name__, | ||||
|             uid=self._uid, | ||||
|         ).first() | ||||
|  | ||||
|     # pylint: disable=too-many-arguments | ||||
|     def after_return(self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo): | ||||
|         super().after_return(status, retval, task_id, args, kwargs, einfo=einfo) | ||||
|         if not self._status: | ||||
| @ -97,7 +95,6 @@ class SystemTask(TenantTask): | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     # pylint: disable=too-many-arguments | ||||
|     def on_failure(self, exc, task_id, args, kwargs, einfo): | ||||
|         super().on_failure(exc, task_id, args, kwargs, einfo=einfo) | ||||
|         if not self._status: | ||||
|  | ||||
| @ -1,7 +1,5 @@ | ||||
| """Event notification tasks""" | ||||
|  | ||||
| from typing import Optional | ||||
|  | ||||
| from django.db.models.query_utils import Q | ||||
| from guardian.shortcuts import get_anonymous_user | ||||
| from structlog.stdlib import get_logger | ||||
| @ -38,7 +36,7 @@ def event_trigger_handler(event_uuid: str, trigger_name: str): | ||||
|     if not event: | ||||
|         LOGGER.warning("event doesn't exist yet or anymore", event_uuid=event_uuid) | ||||
|         return | ||||
|     trigger: Optional[NotificationRule] = NotificationRule.objects.filter(name=trigger_name).first() | ||||
|     trigger: NotificationRule | None = NotificationRule.objects.filter(name=trigger_name).first() | ||||
|     if not trigger: | ||||
|         return | ||||
|  | ||||
|  | ||||
| @ -105,7 +105,7 @@ class TestEvents(TestCase): | ||||
|         # Test brand | ||||
|         request = self.factory.get("/") | ||||
|         brand = Brand(domain="test-brand") | ||||
|         setattr(request, "brand", brand) | ||||
|         request.brand = brand | ||||
|         event = Event.new("unittest").from_http(request) | ||||
|         self.assertEqual( | ||||
|             event.brand, | ||||
|  | ||||
| @ -3,10 +3,9 @@ | ||||
| from django.urls import reverse | ||||
| from rest_framework.test import APITestCase | ||||
|  | ||||
| from authentik.core.models import Application, Token, TokenIntents | ||||
| from authentik.core.models import Application | ||||
| from authentik.core.tests.utils import create_test_admin_user | ||||
| from authentik.events.models import Event, EventAction | ||||
| from authentik.lib.generators import generate_id | ||||
|  | ||||
|  | ||||
| class TestEventsMiddleware(APITestCase): | ||||
| @ -48,30 +47,3 @@ class TestEventsMiddleware(APITestCase): | ||||
|                 context__model__name="test-delete", | ||||
|             ).exists() | ||||
|         ) | ||||
|  | ||||
|     def test_create_with_api(self): | ||||
|         """Test model creation event (with API token auth)""" | ||||
|         self.client.logout() | ||||
|         token = Token.objects.create(user=self.user, intent=TokenIntents.INTENT_API, expiring=False) | ||||
|         uid = generate_id() | ||||
|         self.client.post( | ||||
|             reverse("authentik_api:application-list"), | ||||
|             data={"name": uid, "slug": uid}, | ||||
|             HTTP_AUTHORIZATION=f"Bearer {token.key}", | ||||
|         ) | ||||
|         self.assertTrue(Application.objects.filter(name=uid).exists()) | ||||
|         event = Event.objects.filter( | ||||
|             action=EventAction.MODEL_CREATED, | ||||
|             context__model__model_name="application", | ||||
|             context__model__app="authentik_core", | ||||
|             context__model__name=uid, | ||||
|         ).first() | ||||
|         self.assertIsNotNone(event) | ||||
|         self.assertEqual( | ||||
|             event.user, | ||||
|             { | ||||
|                 "pk": self.user.pk, | ||||
|                 "email": self.user.email, | ||||
|                 "username": self.user.username, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
| @ -7,7 +7,7 @@ from datetime import date, datetime, time, timedelta | ||||
| from enum import Enum | ||||
| from pathlib import Path | ||||
| from types import GeneratorType, NoneType | ||||
| from typing import Any, Optional | ||||
| from typing import Any | ||||
| from uuid import UUID | ||||
|  | ||||
| from django.contrib.auth.models import AnonymousUser | ||||
| @ -37,7 +37,7 @@ def cleanse_item(key: str, value: Any) -> Any: | ||||
|     """Cleanse a single item""" | ||||
|     if isinstance(value, dict): | ||||
|         return cleanse_dict(value) | ||||
|     if isinstance(value, (list, tuple, set)): | ||||
|     if isinstance(value, list | tuple | set): | ||||
|         for idx, item in enumerate(value): | ||||
|             value[idx] = cleanse_item(key, item) | ||||
|         return value | ||||
| @ -74,7 +74,7 @@ def model_to_dict(model: Model) -> dict[str, Any]: | ||||
|     } | ||||
|  | ||||
|  | ||||
| def get_user(user: User | AnonymousUser, original_user: Optional[User] = None) -> dict[str, Any]: | ||||
| def get_user(user: User | AnonymousUser, original_user: User | None = None) -> dict[str, Any]: | ||||
|     """Convert user object to dictionary, optionally including the original user""" | ||||
|     if isinstance(user, AnonymousUser): | ||||
|         try: | ||||
| @ -95,8 +95,7 @@ def get_user(user: User | AnonymousUser, original_user: Optional[User] = None) - | ||||
|     return user_data | ||||
|  | ||||
|  | ||||
| # pylint: disable=too-many-return-statements,too-many-branches | ||||
| def sanitize_item(value: Any) -> Any: | ||||
| def sanitize_item(value: Any) -> Any:  # noqa: PLR0911, PLR0912 | ||||
|     """Sanitize a single item, ensure it is JSON parsable""" | ||||
|     if is_dataclass(value): | ||||
|         # Because asdict calls `copy.deepcopy(obj)` on everything that's not tuple/dict, | ||||
| @ -115,20 +114,20 @@ def sanitize_item(value: Any) -> Any: | ||||
|         return sanitize_dict(value) | ||||
|     if isinstance(value, GeneratorType): | ||||
|         return sanitize_item(list(value)) | ||||
|     if isinstance(value, (list, tuple, set)): | ||||
|     if isinstance(value, list | tuple | set): | ||||
|         new_values = [] | ||||
|         for item in value: | ||||
|             new_value = sanitize_item(item) | ||||
|             if new_value: | ||||
|                 new_values.append(new_value) | ||||
|         return new_values | ||||
|     if isinstance(value, (User, AnonymousUser)): | ||||
|     if isinstance(value, User | AnonymousUser): | ||||
|         return sanitize_dict(get_user(value)) | ||||
|     if isinstance(value, models.Model): | ||||
|         return sanitize_dict(model_to_dict(value)) | ||||
|     if isinstance(value, UUID): | ||||
|         return value.hex | ||||
|     if isinstance(value, (HttpRequest, WSGIRequest)): | ||||
|     if isinstance(value, HttpRequest | WSGIRequest): | ||||
|         return ... | ||||
|     if isinstance(value, City): | ||||
|         return GEOIP_CONTEXT_PROCESSOR.city_to_dict(value) | ||||
| @ -171,7 +170,7 @@ def sanitize_item(value: Any) -> Any: | ||||
|             "module": value.__module__, | ||||
|         } | ||||
|     # List taken from the stdlib's JSON encoder (_make_iterencode, encoder.py:415) | ||||
|     if isinstance(value, (bool, int, float, NoneType, list, tuple, dict)): | ||||
|     if isinstance(value, bool | int | float | NoneType | list | tuple | dict): | ||||
|         return value | ||||
|     try: | ||||
|         return DjangoJSONEncoder().default(value) | ||||
|  | ||||
| @ -114,7 +114,6 @@ class FlowImportResultSerializer(PassiveSerializer): | ||||
| class FlowViewSet(UsedByMixin, ModelViewSet): | ||||
|     """Flow Viewset""" | ||||
|  | ||||
|     # pylint: disable=no-member | ||||
|     queryset = Flow.objects.all().prefetch_related("stages", "policies") | ||||
|     serializer_class = FlowSerializer | ||||
|     lookup_field = "slug" | ||||
| @ -279,7 +278,7 @@ class FlowViewSet(UsedByMixin, ModelViewSet): | ||||
|         }, | ||||
|     ) | ||||
|     @action(detail=True, pagination_class=None, filter_backends=[]) | ||||
|     def execute(self, request: Request, slug: str): | ||||
|     def execute(self, request: Request, _slug: str): | ||||
|         """Execute flow for current user""" | ||||
|         # Because we pre-plan the flow here, and not in the planner, we need to manually clear | ||||
|         # the history of the inspector | ||||
| @ -294,8 +293,9 @@ class FlowViewSet(UsedByMixin, ModelViewSet): | ||||
|             return bad_request_message( | ||||
|                 request, | ||||
|                 _( | ||||
|                     "Flow not applicable to current user/request: %(messages)s" | ||||
|                     % {"messages": exc.messages} | ||||
|                     "Flow not applicable to current user/request: {messages}".format_map( | ||||
|                         {"messages": exc.messages} | ||||
|                     ) | ||||
|                 ), | ||||
|             ) | ||||
|         return Response( | ||||
|  | ||||
| @ -1,7 +1,6 @@ | ||||
| """Flows Diagram API""" | ||||
|  | ||||
| from dataclasses import dataclass, field | ||||
| from typing import Optional | ||||
|  | ||||
| from django.utils.translation import gettext as _ | ||||
| from guardian.shortcuts import get_objects_for_user | ||||
| @ -18,8 +17,8 @@ class DiagramElement: | ||||
|  | ||||
|     identifier: str | ||||
|     description: str | ||||
|     action: Optional[str] = None | ||||
|     source: Optional[list["DiagramElement"]] = None | ||||
|     action: str | None = None | ||||
|     source: list["DiagramElement"] | None = None | ||||
|  | ||||
|     style: list[str] = field(default_factory=lambda: ["[", "]"]) | ||||
|  | ||||
| @ -66,10 +65,10 @@ class FlowDiagram: | ||||
|         ): | ||||
|             element = DiagramElement( | ||||
|                 f"flow_policy_{p_index}", | ||||
|                 _("Policy (%(type)s)" % {"type": policy_binding.policy._meta.verbose_name}) | ||||
|                 _("Policy ({type})".format_map({"type": policy_binding.policy._meta.verbose_name})) | ||||
|                 + "\n" | ||||
|                 + policy_binding.policy.name, | ||||
|                 _("Binding %(order)d" % {"order": policy_binding.order}), | ||||
|                 _("Binding {order}".format_map({"order": policy_binding.order})), | ||||
|                 parent_elements, | ||||
|                 style=["{{", "}}"], | ||||
|             ) | ||||
| @ -92,7 +91,7 @@ class FlowDiagram: | ||||
|         ): | ||||
|             element = DiagramElement( | ||||
|                 f"stage_{stage_index}_policy_{p_index}", | ||||
|                 _("Policy (%(type)s)" % {"type": policy_binding.policy._meta.verbose_name}) | ||||
|                 _("Policy ({type})".format_map({"type": policy_binding.policy._meta.verbose_name})) | ||||
|                 + "\n" | ||||
|                 + policy_binding.policy.name, | ||||
|                 "", | ||||
| @ -120,7 +119,7 @@ class FlowDiagram: | ||||
|  | ||||
|             element = DiagramElement( | ||||
|                 f"stage_{s_index}", | ||||
|                 _("Stage (%(type)s)" % {"type": stage_binding.stage._meta.verbose_name}) | ||||
|                 _("Stage ({type})".format_map({"type": stage_binding.stage._meta.verbose_name})) | ||||
|                 + "\n" | ||||
|                 + stage_binding.stage.name, | ||||
|                 action, | ||||
|  | ||||
| @ -31,9 +31,10 @@ class AuthentikFlowsConfig(ManagedAppConfig): | ||||
|     verbose_name = "authentik Flows" | ||||
|     default = True | ||||
|  | ||||
|     def reconcile_global_load_stages(self): | ||||
|     @ManagedAppConfig.reconcile_global | ||||
|     def load_stages(self): | ||||
|         """Ensure all stages are loaded""" | ||||
|         from authentik.flows.models import Stage | ||||
|  | ||||
|         for stage in all_subclasses(Stage): | ||||
|             _ = stage().type | ||||
|             _ = stage().view | ||||
|  | ||||
| @ -104,7 +104,7 @@ class FlowErrorChallenge(Challenge): | ||||
|     error = CharField(required=False) | ||||
|     traceback = CharField(required=False) | ||||
|  | ||||
|     def __init__(self, request: Optional[Request] = None, error: Optional[Exception] = None): | ||||
|     def __init__(self, request: Request | None = None, error: Exception | None = None): | ||||
|         super().__init__(data={}) | ||||
|         if not request or not error: | ||||
|             return | ||||
|  | ||||
| @ -1,7 +1,5 @@ | ||||
| """flow exceptions""" | ||||
|  | ||||
| from typing import Optional | ||||
|  | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
|  | ||||
| from authentik.lib.sentry import SentryIgnoredException | ||||
| @ -11,7 +9,7 @@ from authentik.policies.types import PolicyResult | ||||
| class FlowNonApplicableException(SentryIgnoredException): | ||||
|     """Flow does not apply to current user (denied by policy, or otherwise).""" | ||||
|  | ||||
|     policy_result: Optional[PolicyResult] = None | ||||
|     policy_result: PolicyResult | None = None | ||||
|  | ||||
|     @property | ||||
|     def messages(self) -> str: | ||||
|  | ||||
| @ -1,7 +1,7 @@ | ||||
| """Stage Markers""" | ||||
|  | ||||
| from dataclasses import dataclass | ||||
| from typing import TYPE_CHECKING, Optional | ||||
| from typing import TYPE_CHECKING | ||||
|  | ||||
| from django.http.request import HttpRequest | ||||
| from structlog.stdlib import get_logger | ||||
| @ -25,7 +25,7 @@ class StageMarker: | ||||
|         plan: "FlowPlan", | ||||
|         binding: FlowStageBinding, | ||||
|         http_request: HttpRequest, | ||||
|     ) -> Optional[FlowStageBinding]: | ||||
|     ) -> FlowStageBinding | None: | ||||
|         """Process callback for this marker. This should be overridden by sub-classes. | ||||
|         If a stage should be removed, return None.""" | ||||
|         return binding | ||||
| @ -42,7 +42,7 @@ class ReevaluateMarker(StageMarker): | ||||
|         plan: "FlowPlan", | ||||
|         binding: FlowStageBinding, | ||||
|         http_request: HttpRequest, | ||||
|     ) -> Optional[FlowStageBinding]: | ||||
|     ) -> FlowStageBinding | None: | ||||
|         """Re-evaluate policies bound to stage, and if they fail, remove from plan""" | ||||
|         from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER | ||||
|  | ||||
|  | ||||
| @ -16,7 +16,7 @@ def set_oobe_flow_authentication(apps: Apps, schema_editor: BaseDatabaseSchemaEd | ||||
|     users = User.objects.using(db_alias).exclude(username="akadmin") | ||||
|     try: | ||||
|         users = users.exclude(pk=get_anonymous_user().pk) | ||||
|     # pylint: disable=broad-except | ||||
|  | ||||
|     except Exception:  # nosec | ||||
|         pass | ||||
|  | ||||
|  | ||||
| @ -2,7 +2,7 @@ | ||||
|  | ||||
| from base64 import b64decode, b64encode | ||||
| from pickle import dumps, loads  # nosec | ||||
| from typing import TYPE_CHECKING, Optional | ||||
| from typing import TYPE_CHECKING | ||||
| from uuid import uuid4 | ||||
|  | ||||
| from django.db import models | ||||
| @ -83,7 +83,7 @@ class Stage(SerializerModel): | ||||
|     objects = InheritanceManager() | ||||
|  | ||||
|     @property | ||||
|     def type(self) -> type["StageView"]: | ||||
|     def view(self) -> type["StageView"]: | ||||
|         """Return StageView class that implements logic for this stage""" | ||||
|         # This is a bit of a workaround, since we can't set class methods with setattr | ||||
|         if hasattr(self, "__in_memory_type"): | ||||
| @ -95,7 +95,7 @@ class Stage(SerializerModel): | ||||
|         """Return component used to edit this object""" | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def ui_user_settings(self) -> Optional[UserSettingSerializer]: | ||||
|     def ui_user_settings(self) -> UserSettingSerializer | None: | ||||
|         """Entrypoint to integrate with User settings. Can either return None if no | ||||
|         user settings are available, or a challenge.""" | ||||
|         return None | ||||
| @ -113,8 +113,8 @@ def in_memory_stage(view: type["StageView"], **kwargs) -> Stage: | ||||
|     # we set the view as a separate property and reference a generic function | ||||
|     # that returns that member | ||||
|     setattr(stage, "__in_memory_type", view) | ||||
|     setattr(stage, "name", _("Dynamic In-memory stage: %(doc)s" % {"doc": view.__doc__})) | ||||
|     setattr(stage._meta, "verbose_name", class_to_path(view)) | ||||
|     stage.name = _("Dynamic In-memory stage: {doc}".format_map({"doc": view.__doc__})) | ||||
|     stage._meta.verbose_name = class_to_path(view) | ||||
|     for key, value in kwargs.items(): | ||||
|         setattr(stage, key, value) | ||||
|     return stage | ||||
|  | ||||
| @ -1,7 +1,7 @@ | ||||
| """Flows Planner""" | ||||
|  | ||||
| from dataclasses import dataclass, field | ||||
| from typing import Any, Optional | ||||
| from typing import Any | ||||
|  | ||||
| from django.core.cache import cache | ||||
| from django.http import HttpRequest | ||||
| @ -39,7 +39,7 @@ CACHE_TIMEOUT = CONFIG.get_int("cache.timeout_flows") | ||||
| CACHE_PREFIX = "goauthentik.io/flows/planner/" | ||||
|  | ||||
|  | ||||
| def cache_key(flow: Flow, user: Optional[User] = None) -> str: | ||||
| def cache_key(flow: Flow, user: User | None = None) -> str: | ||||
|     """Generate Cache key for flow""" | ||||
|     prefix = CACHE_PREFIX + str(flow.pk) | ||||
|     if user: | ||||
| @ -58,16 +58,16 @@ class FlowPlan: | ||||
|     context: dict[str, Any] = field(default_factory=dict) | ||||
|     markers: list[StageMarker] = field(default_factory=list) | ||||
|  | ||||
|     def append_stage(self, stage: Stage, marker: Optional[StageMarker] = None): | ||||
|     def append_stage(self, stage: Stage, marker: StageMarker | None = None): | ||||
|         """Append `stage` to all stages, optionally with stage marker""" | ||||
|         return self.append(FlowStageBinding(stage=stage), marker) | ||||
|  | ||||
|     def append(self, binding: FlowStageBinding, marker: Optional[StageMarker] = None): | ||||
|     def append(self, binding: FlowStageBinding, marker: StageMarker | None = None): | ||||
|         """Append `stage` to all stages, optionally with stage marker""" | ||||
|         self.bindings.append(binding) | ||||
|         self.markers.append(marker or StageMarker()) | ||||
|  | ||||
|     def insert_stage(self, stage: Stage, marker: Optional[StageMarker] = None): | ||||
|     def insert_stage(self, stage: Stage, marker: StageMarker | None = None): | ||||
|         """Insert stage into plan, as immediate next stage""" | ||||
|         self.bindings.insert(1, FlowStageBinding(stage=stage, order=0)) | ||||
|         self.markers.insert(1, marker or StageMarker()) | ||||
| @ -78,7 +78,7 @@ class FlowPlan: | ||||
|  | ||||
|         self.insert_stage(in_memory_stage(RedirectStage, destination=destination)) | ||||
|  | ||||
|     def next(self, http_request: Optional[HttpRequest]) -> Optional[FlowStageBinding]: | ||||
|     def next(self, http_request: HttpRequest | None) -> FlowStageBinding | None: | ||||
|         """Return next pending stage from the bottom of the list""" | ||||
|         if not self.has_stages: | ||||
|             return None | ||||
| @ -94,7 +94,7 @@ class FlowPlan: | ||||
|             self.markers.remove(marker) | ||||
|             if not self.has_stages: | ||||
|                 return None | ||||
|             # pylint: disable=not-callable | ||||
|  | ||||
|             return self.next(http_request) | ||||
|         return marked_stage | ||||
|  | ||||
| @ -148,9 +148,7 @@ class FlowPlanner: | ||||
|             if not outpost_user: | ||||
|                 raise FlowNonApplicableException() | ||||
|  | ||||
|     def plan( | ||||
|         self, request: HttpRequest, default_context: Optional[dict[str, Any]] = None | ||||
|     ) -> FlowPlan: | ||||
|     def plan(self, request: HttpRequest, default_context: dict[str, Any] | None = None) -> FlowPlan: | ||||
|         """Check each of the flows' policies, check policies for each stage with PolicyBinding | ||||
|         and return ordered list""" | ||||
|         with Hub.current.start_span( | ||||
| @ -214,7 +212,7 @@ class FlowPlanner: | ||||
|         self, | ||||
|         user: User, | ||||
|         request: HttpRequest, | ||||
|         default_context: Optional[dict[str, Any]], | ||||
|         default_context: dict[str, Any] | None, | ||||
|     ) -> FlowPlan: | ||||
|         """Build flow plan by checking each stage in their respective | ||||
|         order and checking the applied policies""" | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| """authentik stage Base view""" | ||||
|  | ||||
| from typing import TYPE_CHECKING, Optional | ||||
| from typing import TYPE_CHECKING | ||||
|  | ||||
| from django.contrib.auth.models import AnonymousUser | ||||
| from django.http import HttpRequest | ||||
| @ -153,7 +153,7 @@ class ChallengeStageView(StageView): | ||||
|                 "app": self.executor.plan.context.get(PLAN_CONTEXT_APPLICATION, ""), | ||||
|                 "user": self.get_pending_user(for_display=True), | ||||
|             } | ||||
|         # pylint: disable=broad-except | ||||
|  | ||||
|         except Exception as exc: | ||||
|             self.logger.warning("failed to template title", exc=exc) | ||||
|             return self.executor.flow.title | ||||
| @ -234,9 +234,9 @@ class ChallengeStageView(StageView): | ||||
| class AccessDeniedChallengeView(ChallengeStageView): | ||||
|     """Used internally by FlowExecutor's stage_invalid()""" | ||||
|  | ||||
|     error_message: Optional[str] | ||||
|     error_message: str | None | ||||
|  | ||||
|     def __init__(self, executor: "FlowExecutorView", error_message: Optional[str] = None, **kwargs): | ||||
|     def __init__(self, executor: "FlowExecutorView", error_message: str | None = None, **kwargs): | ||||
|         super().__init__(executor, **kwargs) | ||||
|         self.error_message = error_message | ||||
|  | ||||
|  | ||||
| @ -1,7 +1,7 @@ | ||||
| """Test helpers""" | ||||
|  | ||||
| from json import loads | ||||
| from typing import Any, Optional | ||||
| from typing import Any | ||||
|  | ||||
| from django.http.response import HttpResponse | ||||
| from django.urls.base import reverse | ||||
| @ -15,12 +15,11 @@ from authentik.flows.models import Flow | ||||
| class FlowTestCase(APITestCase): | ||||
|     """Helpers for testing flows and stages.""" | ||||
|  | ||||
|     # pylint: disable=invalid-name | ||||
|     def assertStageResponse( | ||||
|         self, | ||||
|         response: HttpResponse, | ||||
|         flow: Optional[Flow] = None, | ||||
|         user: Optional[User] = None, | ||||
|         flow: Flow | None = None, | ||||
|         user: User | None = None, | ||||
|         **kwargs, | ||||
|     ) -> dict[str, Any]: | ||||
|         """Assert various attributes of a stage response""" | ||||
| @ -45,7 +44,6 @@ class FlowTestCase(APITestCase): | ||||
|             self.assertEqual(raw_response[key], expected) | ||||
|         return raw_response | ||||
|  | ||||
|     # pylint: disable=invalid-name | ||||
|     def assertStageRedirects(self, response: HttpResponse, to: str) -> dict[str, Any]: | ||||
|         """Wrapper around assertStageResponse that checks for a redirect""" | ||||
|         return self.assertStageResponse( | ||||
|  | ||||
| @ -1,7 +1,6 @@ | ||||
| """flow views tests""" | ||||
|  | ||||
| from unittest.mock import MagicMock, PropertyMock, patch | ||||
| from urllib.parse import urlencode | ||||
|  | ||||
| from django.http import HttpRequest, HttpResponse | ||||
| from django.test.client import RequestFactory | ||||
| @ -19,12 +18,7 @@ from authentik.flows.models import ( | ||||
| from authentik.flows.planner import FlowPlan, FlowPlanner | ||||
| from authentik.flows.stage import PLAN_CONTEXT_PENDING_USER_IDENTIFIER, StageView | ||||
| from authentik.flows.tests import FlowTestCase | ||||
| from authentik.flows.views.executor import ( | ||||
|     NEXT_ARG_NAME, | ||||
|     QS_QUERY, | ||||
|     SESSION_KEY_PLAN, | ||||
|     FlowExecutorView, | ||||
| ) | ||||
| from authentik.flows.views.executor import NEXT_ARG_NAME, SESSION_KEY_PLAN, FlowExecutorView | ||||
| from authentik.lib.generators import generate_id | ||||
| from authentik.policies.dummy.models import DummyPolicy | ||||
| from authentik.policies.models import PolicyBinding | ||||
| @ -127,73 +121,16 @@ class TestFlowExecutor(FlowTestCase): | ||||
|         TO_STAGE_RESPONSE_MOCK, | ||||
|     ) | ||||
|     def test_invalid_flow_redirect(self): | ||||
|         """Test invalid flow with valid redirect destination""" | ||||
|         """Tests that an invalid flow still redirects""" | ||||
|         flow = create_test_flow( | ||||
|             FlowDesignation.AUTHENTICATION, | ||||
|         ) | ||||
|  | ||||
|         dest = "/unique-string" | ||||
|         url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) | ||||
|         response = self.client.get(url + f"?{QS_QUERY}={urlencode({NEXT_ARG_NAME: dest})}") | ||||
|         response = self.client.get(url + f"?{NEXT_ARG_NAME}={dest}") | ||||
|         self.assertEqual(response.status_code, 302) | ||||
|         self.assertEqual(response.url, "/unique-string") | ||||
|  | ||||
|     @patch( | ||||
|         "authentik.flows.views.executor.to_stage_response", | ||||
|         TO_STAGE_RESPONSE_MOCK, | ||||
|     ) | ||||
|     def test_invalid_flow_invalid_redirect(self): | ||||
|         """Test invalid flow redirect with an invalid URL""" | ||||
|         flow = create_test_flow( | ||||
|             FlowDesignation.AUTHENTICATION, | ||||
|         ) | ||||
|  | ||||
|         dest = "http://something.example.com/unique-string" | ||||
|         url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) | ||||
|  | ||||
|         response = self.client.get(url + f"?{QS_QUERY}={urlencode({NEXT_ARG_NAME: dest})}") | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|         self.assertStageResponse( | ||||
|             response, | ||||
|             flow, | ||||
|             component="ak-stage-access-denied", | ||||
|             error_message="Invalid next URL", | ||||
|         ) | ||||
|  | ||||
|     @patch( | ||||
|         "authentik.flows.views.executor.to_stage_response", | ||||
|         TO_STAGE_RESPONSE_MOCK, | ||||
|     ) | ||||
|     def test_valid_flow_redirect(self): | ||||
|         """Test valid flow with valid redirect destination""" | ||||
|         flow = create_test_flow() | ||||
|  | ||||
|         dest = "/unique-string" | ||||
|         url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) | ||||
|  | ||||
|         response = self.client.get(url + f"?{QS_QUERY}={urlencode({NEXT_ARG_NAME: dest})}") | ||||
|         self.assertEqual(response.status_code, 302) | ||||
|         self.assertEqual(response.url, "/unique-string") | ||||
|  | ||||
|     @patch( | ||||
|         "authentik.flows.views.executor.to_stage_response", | ||||
|         TO_STAGE_RESPONSE_MOCK, | ||||
|     ) | ||||
|     def test_valid_flow_invalid_redirect(self): | ||||
|         """Test valid flow redirect with an invalid URL""" | ||||
|         flow = create_test_flow() | ||||
|  | ||||
|         dest = "http://something.example.com/unique-string" | ||||
|         url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) | ||||
|  | ||||
|         response = self.client.get(url + f"?{QS_QUERY}={urlencode({NEXT_ARG_NAME: dest})}") | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|         self.assertStageResponse( | ||||
|             response, | ||||
|             flow, | ||||
|             component="ak-stage-access-denied", | ||||
|             error_message="Invalid next URL", | ||||
|         ) | ||||
|         self.assertEqual(response.url, reverse("authentik_core:root-redirect")) | ||||
|  | ||||
|     @patch( | ||||
|         "authentik.flows.views.executor.to_stage_response", | ||||
|  | ||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user
	