Compare commits
	
		
			104 Commits
		
	
	
		
			version/20
			...
			version-te
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 5b6b059b40 | |||
| 060cea219b | |||
| af9d82c02d | |||
| cc8fb66da2 | |||
| f0edc7b931 | |||
| b39632abb0 | |||
| c59b859ec0 | |||
| a46939b591 | |||
| bfb4a25026 | |||
| 646276b37c | |||
| 58f9d86d0b | |||
| cf0a268fb1 | |||
| ec783ae587 | |||
| f50d44792c | |||
| b225b0200e | |||
| 507f9b7ae2 | |||
| 5991b82cde | |||
| f38bc8d09e | |||
| 9824f283de | |||
| 341d866c00 | |||
| 965ddcb564 | |||
| a0a1a101e8 | |||
| 277c922ec3 | |||
| f372627d61 | |||
| 1be86325d5 | |||
| 6d71454aa0 | |||
| 75d6aab0bb | |||
| 496dce093a | |||
| f740ba0ffe | |||
| a82af054a4 | |||
| c80e3da644 | |||
| af9bb566f8 | |||
| 5ca929417b | |||
| 3c1c44bda1 | |||
| c05977f144 | |||
| 55333ef1ac | |||
| 49ad6d2aa8 | |||
| b7e4373d6e | |||
| 699c074816 | |||
| c26855f953 | |||
| 1457b38e7e | |||
| 55d08c5be3 | |||
| ffbfbd43cb | |||
| cb24fe5c5d | |||
| aa81d8f12d | |||
| 2ee1a0241b | |||
| 89bc7a037d | |||
| a21683555a | |||
| 5a98235ee0 | |||
| 3ce836fd8b | |||
| 5a5f7814ab | |||
| 907d475897 | |||
| 41503fc0b2 | |||
| cfc7646a5a | |||
| 7103336456 | |||
| 48db4af56d | |||
| 8285b5d9a7 | |||
| 43218bd027 | |||
| 042fae143d | |||
| f6f997525f | |||
| 753fb5e1b2 | |||
| 06a42df732 | |||
| 66a2a62c7b | |||
| 41bbbde232 | |||
| 373c0ff7d0 | |||
| 30345d450c | |||
| b9dc83466d | |||
| f26175a99f | |||
| c7881e6eb4 | |||
| 97b98a4192 | |||
| fc65d3f43a | |||
| aa87695f3c | |||
| c3fb84397a | |||
| 8d78cd97d0 | |||
| 24d2c4089c | |||
| 38f47c65a1 | |||
| 896096374c | |||
| 0e2326ed06 | |||
| a07db454be | |||
| 87a4a81798 | |||
| f0ee743ea1 | |||
| fbac1e9d95 | |||
| d8536ed78e | |||
| 848dae52ab | |||
| f62a470dfa | |||
| 16a8409014 | |||
| dfa5b8aba5 | |||
| 54270e960f | |||
| 6541b7fcef | |||
| 19af49a49b | |||
| 99e189cae3 | |||
| 6f68563df2 | |||
| df03b2a156 | |||
| e1211ba01b | |||
| 24ea3f0ee8 | |||
| 79045ab283 | |||
| e27189364e | |||
| ba224e4eb9 | |||
| 336950628e | |||
| 6ede552292 | |||
| 07b6356b38 | |||
| 4c5730a222 | |||
| 8ab84c8d91 | |||
| 89ef82337d | 
| @ -1,5 +1,5 @@ | |||||||
| [bumpversion] | [bumpversion] | ||||||
| current_version = 2024.2.2 | current_version = 2024.2.1 | ||||||
| tag = True | tag = True | ||||||
| commit = True | commit = True | ||||||
| parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))? | parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))? | ||||||
|  | |||||||
| @ -11,10 +11,6 @@ inputs: | |||||||
|     description: "Docker image arch" |     description: "Docker image arch" | ||||||
|  |  | ||||||
| outputs: | outputs: | ||||||
|   shouldBuild: |  | ||||||
|     description: "Whether to build image or not" |  | ||||||
|     value: ${{ steps.ev.outputs.shouldBuild }} |  | ||||||
|  |  | ||||||
|   sha: |   sha: | ||||||
|     description: "sha" |     description: "sha" | ||||||
|     value: ${{ steps.ev.outputs.sha }} |     value: ${{ steps.ev.outputs.sha }} | ||||||
|  | |||||||
| @ -7,8 +7,6 @@ from time import time | |||||||
| parser = configparser.ConfigParser() | parser = configparser.ConfigParser() | ||||||
| parser.read(".bumpversion.cfg") | parser.read(".bumpversion.cfg") | ||||||
|  |  | ||||||
| should_build = str(os.environ.get("DOCKER_USERNAME", None) is not None).lower() |  | ||||||
|  |  | ||||||
| branch_name = os.environ["GITHUB_REF"] | branch_name = os.environ["GITHUB_REF"] | ||||||
| if os.environ.get("GITHUB_HEAD_REF", "") != "": | if os.environ.get("GITHUB_HEAD_REF", "") != "": | ||||||
|     branch_name = os.environ["GITHUB_HEAD_REF"] |     branch_name = os.environ["GITHUB_HEAD_REF"] | ||||||
| @ -54,7 +52,6 @@ image_main_tag = image_tags[0] | |||||||
| image_tags_rendered = ",".join(image_tags) | image_tags_rendered = ",".join(image_tags) | ||||||
|  |  | ||||||
| with open(os.environ["GITHUB_OUTPUT"], "a+", encoding="utf-8") as _output: | with open(os.environ["GITHUB_OUTPUT"], "a+", encoding="utf-8") as _output: | ||||||
|     print("shouldBuild=%s" % should_build, file=_output) |  | ||||||
|     print("sha=%s" % sha, file=_output) |     print("sha=%s" % sha, file=_output) | ||||||
|     print("version=%s" % version, file=_output) |     print("version=%s" % version, file=_output) | ||||||
|     print("prerelease=%s" % prerelease, file=_output) |     print("prerelease=%s" % prerelease, file=_output) | ||||||
|  | |||||||
							
								
								
									
										11
									
								
								.github/workflows/ci-main.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										11
									
								
								.github/workflows/ci-main.yml
									
									
									
									
										vendored
									
									
								
							| @ -28,10 +28,7 @@ jobs: | |||||||
|           - bandit |           - bandit | ||||||
|           - black |           - black | ||||||
|           - codespell |           - codespell | ||||||
|           - isort |  | ||||||
|           - pending-migrations |           - pending-migrations | ||||||
|           # - pylint |  | ||||||
|           - pyright |  | ||||||
|           - ruff |           - ruff | ||||||
|     runs-on: ubuntu-latest |     runs-on: ubuntu-latest | ||||||
|     steps: |     steps: | ||||||
| @ -219,6 +216,7 @@ jobs: | |||||||
|       # Needed to upload contianer images to ghcr.io |       # Needed to upload contianer images to ghcr.io | ||||||
|       packages: write |       packages: write | ||||||
|     timeout-minutes: 120 |     timeout-minutes: 120 | ||||||
|  |     if: "github.repository == 'goauthentik/authentik'" | ||||||
|     steps: |     steps: | ||||||
|       - uses: actions/checkout@v4 |       - uses: actions/checkout@v4 | ||||||
|         with: |         with: | ||||||
| @ -230,13 +228,10 @@ jobs: | |||||||
|       - name: prepare variables |       - name: prepare variables | ||||||
|         uses: ./.github/actions/docker-push-variables |         uses: ./.github/actions/docker-push-variables | ||||||
|         id: ev |         id: ev | ||||||
|         env: |  | ||||||
|           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} |  | ||||||
|         with: |         with: | ||||||
|           image-name: ghcr.io/goauthentik/dev-server |           image-name: ghcr.io/goauthentik/dev-server | ||||||
|           image-arch: ${{ matrix.arch }} |           image-arch: ${{ matrix.arch }} | ||||||
|       - name: Login to Container Registry |       - name: Login to Container Registry | ||||||
|         if: ${{ steps.ev.outputs.shouldBuild == 'true' }} |  | ||||||
|         uses: docker/login-action@v3 |         uses: docker/login-action@v3 | ||||||
|         with: |         with: | ||||||
|           registry: ghcr.io |           registry: ghcr.io | ||||||
| @ -252,7 +247,7 @@ jobs: | |||||||
|             GEOIPUPDATE_ACCOUNT_ID=${{ secrets.GEOIPUPDATE_ACCOUNT_ID }} |             GEOIPUPDATE_ACCOUNT_ID=${{ secrets.GEOIPUPDATE_ACCOUNT_ID }} | ||||||
|             GEOIPUPDATE_LICENSE_KEY=${{ secrets.GEOIPUPDATE_LICENSE_KEY }} |             GEOIPUPDATE_LICENSE_KEY=${{ secrets.GEOIPUPDATE_LICENSE_KEY }} | ||||||
|           tags: ${{ steps.ev.outputs.imageTags }} |           tags: ${{ steps.ev.outputs.imageTags }} | ||||||
|           push: ${{ steps.ev.outputs.shouldBuild == 'true' }} |           push: true | ||||||
|           build-args: | |           build-args: | | ||||||
|             GIT_BUILD_HASH=${{ steps.ev.outputs.sha }} |             GIT_BUILD_HASH=${{ steps.ev.outputs.sha }} | ||||||
|           cache-from: type=gha |           cache-from: type=gha | ||||||
| @ -274,8 +269,6 @@ jobs: | |||||||
|       - name: prepare variables |       - name: prepare variables | ||||||
|         uses: ./.github/actions/docker-push-variables |         uses: ./.github/actions/docker-push-variables | ||||||
|         id: ev |         id: ev | ||||||
|         env: |  | ||||||
|           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} |  | ||||||
|         with: |         with: | ||||||
|           image-name: ghcr.io/goauthentik/dev-server |           image-name: ghcr.io/goauthentik/dev-server | ||||||
|       - name: Comment on PR |       - name: Comment on PR | ||||||
|  | |||||||
							
								
								
									
										6
									
								
								.github/workflows/ci-outpost.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										6
									
								
								.github/workflows/ci-outpost.yml
									
									
									
									
										vendored
									
									
								
							| @ -71,6 +71,7 @@ jobs: | |||||||
|     permissions: |     permissions: | ||||||
|       # Needed to upload contianer images to ghcr.io |       # Needed to upload contianer images to ghcr.io | ||||||
|       packages: write |       packages: write | ||||||
|  |     if: "github.repository == 'goauthentik/authentik'" | ||||||
|     steps: |     steps: | ||||||
|       - uses: actions/checkout@v4 |       - uses: actions/checkout@v4 | ||||||
|         with: |         with: | ||||||
| @ -82,12 +83,9 @@ jobs: | |||||||
|       - name: prepare variables |       - name: prepare variables | ||||||
|         uses: ./.github/actions/docker-push-variables |         uses: ./.github/actions/docker-push-variables | ||||||
|         id: ev |         id: ev | ||||||
|         env: |  | ||||||
|           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} |  | ||||||
|         with: |         with: | ||||||
|           image-name: ghcr.io/goauthentik/dev-${{ matrix.type }} |           image-name: ghcr.io/goauthentik/dev-${{ matrix.type }} | ||||||
|       - name: Login to Container Registry |       - name: Login to Container Registry | ||||||
|         if: ${{ steps.ev.outputs.shouldBuild == 'true' }} |  | ||||||
|         uses: docker/login-action@v3 |         uses: docker/login-action@v3 | ||||||
|         with: |         with: | ||||||
|           registry: ghcr.io |           registry: ghcr.io | ||||||
| @ -100,7 +98,7 @@ jobs: | |||||||
|         with: |         with: | ||||||
|           tags: ${{ steps.ev.outputs.imageTags }} |           tags: ${{ steps.ev.outputs.imageTags }} | ||||||
|           file: ${{ matrix.type }}.Dockerfile |           file: ${{ matrix.type }}.Dockerfile | ||||||
|           push: ${{ steps.ev.outputs.shouldBuild == 'true' }} |           push: true | ||||||
|           build-args: | |           build-args: | | ||||||
|             GIT_BUILD_HASH=${{ steps.ev.outputs.sha }} |             GIT_BUILD_HASH=${{ steps.ev.outputs.sha }} | ||||||
|           platforms: linux/amd64,linux/arm64 |           platforms: linux/amd64,linux/arm64 | ||||||
|  | |||||||
							
								
								
									
										6
									
								
								.github/workflows/release-publish.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										6
									
								
								.github/workflows/release-publish.yml
									
									
									
									
										vendored
									
									
								
							| @ -20,8 +20,6 @@ jobs: | |||||||
|       - name: prepare variables |       - name: prepare variables | ||||||
|         uses: ./.github/actions/docker-push-variables |         uses: ./.github/actions/docker-push-variables | ||||||
|         id: ev |         id: ev | ||||||
|         env: |  | ||||||
|           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} |  | ||||||
|         with: |         with: | ||||||
|           image-name: ghcr.io/goauthentik/server,beryju/authentik |           image-name: ghcr.io/goauthentik/server,beryju/authentik | ||||||
|       - name: Docker Login Registry |       - name: Docker Login Registry | ||||||
| @ -74,8 +72,6 @@ jobs: | |||||||
|       - name: prepare variables |       - name: prepare variables | ||||||
|         uses: ./.github/actions/docker-push-variables |         uses: ./.github/actions/docker-push-variables | ||||||
|         id: ev |         id: ev | ||||||
|         env: |  | ||||||
|           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} |  | ||||||
|         with: |         with: | ||||||
|           image-name: ghcr.io/goauthentik/${{ matrix.type }},beryju/authentik-${{ matrix.type }} |           image-name: ghcr.io/goauthentik/${{ matrix.type }},beryju/authentik-${{ matrix.type }} | ||||||
|       - name: make empty clients |       - name: make empty clients | ||||||
| @ -172,8 +168,6 @@ jobs: | |||||||
|       - name: prepare variables |       - name: prepare variables | ||||||
|         uses: ./.github/actions/docker-push-variables |         uses: ./.github/actions/docker-push-variables | ||||||
|         id: ev |         id: ev | ||||||
|         env: |  | ||||||
|           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} |  | ||||||
|         with: |         with: | ||||||
|           image-name: ghcr.io/goauthentik/server |           image-name: ghcr.io/goauthentik/server | ||||||
|       - name: Get static files from docker image |       - name: Get static files from docker image | ||||||
|  | |||||||
							
								
								
									
										2
									
								
								.github/workflows/release-tag.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/release-tag.yml
									
									
									
									
										vendored
									
									
								
							| @ -32,8 +32,6 @@ jobs: | |||||||
|       - name: prepare variables |       - name: prepare variables | ||||||
|         uses: ./.github/actions/docker-push-variables |         uses: ./.github/actions/docker-push-variables | ||||||
|         id: ev |         id: ev | ||||||
|         env: |  | ||||||
|           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} |  | ||||||
|         with: |         with: | ||||||
|           image-name: ghcr.io/goauthentik/server |           image-name: ghcr.io/goauthentik/server | ||||||
|       - name: Create Release |       - name: Create Release | ||||||
|  | |||||||
							
								
								
									
										3
									
								
								.vscode/extensions.json
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.vscode/extensions.json
									
									
									
									
										vendored
									
									
								
							| @ -10,8 +10,7 @@ | |||||||
|         "Gruntfuggly.todo-tree", |         "Gruntfuggly.todo-tree", | ||||||
|         "mechatroner.rainbow-csv", |         "mechatroner.rainbow-csv", | ||||||
|         "ms-python.black-formatter", |         "ms-python.black-formatter", | ||||||
|         "ms-python.isort", |         "charliermarsh.ruff", | ||||||
|         "ms-python.pylint", |  | ||||||
|         "ms-python.python", |         "ms-python.python", | ||||||
|         "ms-python.vscode-pylance", |         "ms-python.vscode-pylance", | ||||||
|         "ms-python.black-formatter", |         "ms-python.black-formatter", | ||||||
|  | |||||||
| @ -103,10 +103,9 @@ RUN --mount=type=bind,target=./pyproject.toml,src=./pyproject.toml \ | |||||||
|     --mount=type=cache,target=/root/.cache/pip \ |     --mount=type=cache,target=/root/.cache/pip \ | ||||||
|     --mount=type=cache,target=/root/.cache/pypoetry \ |     --mount=type=cache,target=/root/.cache/pypoetry \ | ||||||
|     python -m venv /ak-root/venv/ && \ |     python -m venv /ak-root/venv/ && \ | ||||||
|     bash -c "source ${VENV_PATH}/bin/activate && \ |  | ||||||
|     pip3 install --upgrade pip && \ |     pip3 install --upgrade pip && \ | ||||||
|     pip3 install poetry && \ |     pip3 install poetry && \ | ||||||
|         poetry install --only=main --no-ansi --no-interaction --no-root" |     poetry install --only=main --no-ansi --no-interaction | ||||||
|  |  | ||||||
| # Stage 6: Run | # Stage 6: Run | ||||||
| FROM docker.io/python:3.12.2-slim-bookworm AS final-image | FROM docker.io/python:3.12.2-slim-bookworm AS final-image | ||||||
|  | |||||||
							
								
								
									
										14
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								Makefile
									
									
									
									
									
								
							| @ -59,15 +59,12 @@ test: ## Run the server tests and produce a coverage report (locally) | |||||||
| 	coverage report | 	coverage report | ||||||
|  |  | ||||||
| lint-fix:  ## Lint and automatically fix errors in the python source code. Reports spelling errors. | lint-fix:  ## Lint and automatically fix errors in the python source code. Reports spelling errors. | ||||||
| 	isort $(PY_SOURCES) |  | ||||||
| 	black $(PY_SOURCES) | 	black $(PY_SOURCES) | ||||||
| 	ruff --fix $(PY_SOURCES) | 	ruff check --fix $(PY_SOURCES) | ||||||
| 	codespell -w $(CODESPELL_ARGS) | 	codespell -w $(CODESPELL_ARGS) | ||||||
|  |  | ||||||
| lint: ## Lint the python and golang sources | lint: ## Lint the python and golang sources | ||||||
| 	bandit -r $(PY_SOURCES) -x node_modules | 	bandit -r $(PY_SOURCES) -x node_modules | ||||||
| 	./web/node_modules/.bin/pyright $(PY_SOURCES) |  | ||||||
| 	pylint $(PY_SOURCES) |  | ||||||
| 	golangci-lint run -v | 	golangci-lint run -v | ||||||
|  |  | ||||||
| core-install: | core-install: | ||||||
| @ -249,9 +246,6 @@ ci--meta-debug: | |||||||
| 	python -V | 	python -V | ||||||
| 	node --version | 	node --version | ||||||
|  |  | ||||||
| ci-pylint: ci--meta-debug |  | ||||||
| 	pylint $(PY_SOURCES) |  | ||||||
|  |  | ||||||
| ci-black: ci--meta-debug | ci-black: ci--meta-debug | ||||||
| 	black --check $(PY_SOURCES) | 	black --check $(PY_SOURCES) | ||||||
|  |  | ||||||
| @ -261,14 +255,8 @@ ci-ruff: ci--meta-debug | |||||||
| ci-codespell: ci--meta-debug | ci-codespell: ci--meta-debug | ||||||
| 	codespell $(CODESPELL_ARGS) -s | 	codespell $(CODESPELL_ARGS) -s | ||||||
|  |  | ||||||
| ci-isort: ci--meta-debug |  | ||||||
| 	isort --check $(PY_SOURCES) |  | ||||||
|  |  | ||||||
| ci-bandit: ci--meta-debug | ci-bandit: ci--meta-debug | ||||||
| 	bandit -r $(PY_SOURCES) | 	bandit -r $(PY_SOURCES) | ||||||
|  |  | ||||||
| ci-pyright: ci--meta-debug |  | ||||||
| 	./web/node_modules/.bin/pyright $(PY_SOURCES) |  | ||||||
|  |  | ||||||
| ci-pending-migrations: ci--meta-debug | ci-pending-migrations: ci--meta-debug | ||||||
| 	ak makemigrations --check | 	ak makemigrations --check | ||||||
|  | |||||||
| @ -1,13 +1,12 @@ | |||||||
| """authentik root module""" | """authentik root module""" | ||||||
|  |  | ||||||
| from os import environ | from os import environ | ||||||
| from typing import Optional |  | ||||||
|  |  | ||||||
| __version__ = "2024.2.2" | __version__ = "2024.2.1" | ||||||
| ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" | ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_build_hash(fallback: Optional[str] = None) -> str: | def get_build_hash(fallback: str | None = None) -> str: | ||||||
|     """Get build hash""" |     """Get build hash""" | ||||||
|     build_hash = environ.get(ENV_GIT_HASH_KEY, fallback if fallback else "") |     build_hash = environ.get(ENV_GIT_HASH_KEY, fallback if fallback else "") | ||||||
|     return fallback if build_hash == "" and fallback else build_hash |     return fallback if build_hash == "" and fallback else build_hash | ||||||
|  | |||||||
| @ -18,7 +18,7 @@ class AuthentikAPIConfig(AppConfig): | |||||||
|  |  | ||||||
|         # Class is defined here as it needs to be created early enough that drf-spectacular will |         # Class is defined here as it needs to be created early enough that drf-spectacular will | ||||||
|         # find it, but also won't cause any import issues |         # find it, but also won't cause any import issues | ||||||
|         # pylint: disable=unused-variable |  | ||||||
|         class TokenSchema(OpenApiAuthenticationExtension): |         class TokenSchema(OpenApiAuthenticationExtension): | ||||||
|             """Auth schema""" |             """Auth schema""" | ||||||
|  |  | ||||||
|  | |||||||
| @ -1,7 +1,7 @@ | |||||||
| """API Authentication""" | """API Authentication""" | ||||||
|  |  | ||||||
| from hmac import compare_digest | from hmac import compare_digest | ||||||
| from typing import Any, Optional | from typing import Any | ||||||
|  |  | ||||||
| from django.conf import settings | from django.conf import settings | ||||||
| from rest_framework.authentication import BaseAuthentication, get_authorization_header | from rest_framework.authentication import BaseAuthentication, get_authorization_header | ||||||
| @ -17,7 +17,7 @@ from authentik.providers.oauth2.constants import SCOPE_AUTHENTIK_API | |||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  |  | ||||||
| def validate_auth(header: bytes) -> Optional[str]: | def validate_auth(header: bytes) -> str | None: | ||||||
|     """Validate that the header is in a correct format, |     """Validate that the header is in a correct format, | ||||||
|     returns type and credentials""" |     returns type and credentials""" | ||||||
|     auth_credentials = header.decode().strip() |     auth_credentials = header.decode().strip() | ||||||
| @ -32,7 +32,7 @@ def validate_auth(header: bytes) -> Optional[str]: | |||||||
|     return auth_credentials |     return auth_credentials | ||||||
|  |  | ||||||
|  |  | ||||||
| def bearer_auth(raw_header: bytes) -> Optional[User]: | def bearer_auth(raw_header: bytes) -> User | None: | ||||||
|     """raw_header in the Format of `Bearer ....`""" |     """raw_header in the Format of `Bearer ....`""" | ||||||
|     user = auth_user_lookup(raw_header) |     user = auth_user_lookup(raw_header) | ||||||
|     if not user: |     if not user: | ||||||
| @ -42,7 +42,7 @@ def bearer_auth(raw_header: bytes) -> Optional[User]: | |||||||
|     return user |     return user | ||||||
|  |  | ||||||
|  |  | ||||||
| def auth_user_lookup(raw_header: bytes) -> Optional[User]: | def auth_user_lookup(raw_header: bytes) -> User | None: | ||||||
|     """raw_header in the Format of `Bearer ....`""" |     """raw_header in the Format of `Bearer ....`""" | ||||||
|     from authentik.providers.oauth2.models import AccessToken |     from authentik.providers.oauth2.models import AccessToken | ||||||
|  |  | ||||||
| @ -75,7 +75,7 @@ def auth_user_lookup(raw_header: bytes) -> Optional[User]: | |||||||
|     raise AuthenticationFailed("Token invalid/expired") |     raise AuthenticationFailed("Token invalid/expired") | ||||||
|  |  | ||||||
|  |  | ||||||
| def token_secret_key(value: str) -> Optional[User]: | def token_secret_key(value: str) -> User | None: | ||||||
|     """Check if the token is the secret key |     """Check if the token is the secret key | ||||||
|     and return the service account for the managed outpost""" |     and return the service account for the managed outpost""" | ||||||
|     from authentik.outposts.apps import MANAGED_OUTPOST |     from authentik.outposts.apps import MANAGED_OUTPOST | ||||||
|  | |||||||
| @ -25,17 +25,17 @@ class TestAPIAuth(TestCase): | |||||||
|     def test_invalid_type(self): |     def test_invalid_type(self): | ||||||
|         """Test invalid type""" |         """Test invalid type""" | ||||||
|         with self.assertRaises(AuthenticationFailed): |         with self.assertRaises(AuthenticationFailed): | ||||||
|             bearer_auth("foo bar".encode()) |             bearer_auth(b"foo bar") | ||||||
|  |  | ||||||
|     def test_invalid_empty(self): |     def test_invalid_empty(self): | ||||||
|         """Test invalid type""" |         """Test invalid type""" | ||||||
|         self.assertIsNone(bearer_auth("Bearer ".encode())) |         self.assertIsNone(bearer_auth(b"Bearer ")) | ||||||
|         self.assertIsNone(bearer_auth("".encode())) |         self.assertIsNone(bearer_auth(b"")) | ||||||
|  |  | ||||||
|     def test_invalid_no_token(self): |     def test_invalid_no_token(self): | ||||||
|         """Test invalid with no token""" |         """Test invalid with no token""" | ||||||
|         with self.assertRaises(AuthenticationFailed): |         with self.assertRaises(AuthenticationFailed): | ||||||
|             auth = b64encode(":abc".encode()).decode() |             auth = b64encode(b":abc").decode() | ||||||
|             self.assertIsNone(bearer_auth(f"Basic :{auth}".encode())) |             self.assertIsNone(bearer_auth(f"Basic :{auth}".encode())) | ||||||
|  |  | ||||||
|     def test_bearer_valid(self): |     def test_bearer_valid(self): | ||||||
|  | |||||||
| @ -1,6 +1,6 @@ | |||||||
| """authentik API Modelviewset tests""" | """authentik API Modelviewset tests""" | ||||||
|  |  | ||||||
| from typing import Callable | from collections.abc import Callable | ||||||
|  |  | ||||||
| from django.test import TestCase | from django.test import TestCase | ||||||
| from rest_framework.viewsets import ModelViewSet, ReadOnlyModelViewSet | from rest_framework.viewsets import ModelViewSet, ReadOnlyModelViewSet | ||||||
| @ -26,6 +26,6 @@ def viewset_tester_factory(test_viewset: type[ModelViewSet]) -> Callable: | |||||||
|  |  | ||||||
|  |  | ||||||
| for _, viewset, _ in router.registry: | for _, viewset, _ in router.registry: | ||||||
|     if not issubclass(viewset, (ModelViewSet, ReadOnlyModelViewSet)): |     if not issubclass(viewset, ModelViewSet | ReadOnlyModelViewSet): | ||||||
|         continue |         continue | ||||||
|     setattr(TestModelViewSets, f"test_viewset_{viewset.__name__}", viewset_tester_factory(viewset)) |     setattr(TestModelViewSets, f"test_viewset_{viewset.__name__}", viewset_tester_factory(viewset)) | ||||||
|  | |||||||
| @ -33,7 +33,7 @@ for _authentik_app in get_apps(): | |||||||
|             app_name=_authentik_app.name, |             app_name=_authentik_app.name, | ||||||
|         ) |         ) | ||||||
|         continue |         continue | ||||||
|     urls: list = getattr(api_urls, "api_urlpatterns") |     urls: list = api_urls.api_urlpatterns | ||||||
|     for url in urls: |     for url in urls: | ||||||
|         if isinstance(url, URLPattern): |         if isinstance(url, URLPattern): | ||||||
|             _other_urls.append(url) |             _other_urls.append(url) | ||||||
|  | |||||||
| @ -52,7 +52,9 @@ class BlueprintInstanceSerializer(ModelSerializer): | |||||||
|         valid, logs = Importer.from_string(content, context).validate() |         valid, logs = Importer.from_string(content, context).validate() | ||||||
|         if not valid: |         if not valid: | ||||||
|             text_logs = "\n".join([x["event"] for x in logs]) |             text_logs = "\n".join([x["event"] for x in logs]) | ||||||
|             raise ValidationError(_("Failed to validate blueprint: %(logs)s" % {"logs": text_logs})) |             raise ValidationError( | ||||||
|  |                 _("Failed to validate blueprint: {logs}".format_map({"logs": text_logs})) | ||||||
|  |             ) | ||||||
|         return content |         return content | ||||||
|  |  | ||||||
|     def validate(self, attrs: dict) -> dict: |     def validate(self, attrs: dict) -> dict: | ||||||
|  | |||||||
| @ -1,5 +1,6 @@ | |||||||
| """authentik Blueprints app""" | """authentik Blueprints app""" | ||||||
|  |  | ||||||
|  | from collections.abc import Callable | ||||||
| from importlib import import_module | from importlib import import_module | ||||||
| from inspect import ismethod | from inspect import ismethod | ||||||
|  |  | ||||||
| @ -13,8 +14,8 @@ class ManagedAppConfig(AppConfig): | |||||||
|  |  | ||||||
|     logger: BoundLogger |     logger: BoundLogger | ||||||
|  |  | ||||||
|     RECONCILE_GLOBAL_PREFIX: str = "reconcile_global_" |     RECONCILE_GLOBAL_CATEGORY: str = "global" | ||||||
|     RECONCILE_TENANT_PREFIX: str = "reconcile_tenant_" |     RECONCILE_TENANT_CATEGORY: str = "tenant" | ||||||
|  |  | ||||||
|     def __init__(self, app_name: str, *args, **kwargs) -> None: |     def __init__(self, app_name: str, *args, **kwargs) -> None: | ||||||
|         super().__init__(app_name, *args, **kwargs) |         super().__init__(app_name, *args, **kwargs) | ||||||
| @ -22,8 +23,8 @@ class ManagedAppConfig(AppConfig): | |||||||
|  |  | ||||||
|     def ready(self) -> None: |     def ready(self) -> None: | ||||||
|         self.import_related() |         self.import_related() | ||||||
|         self.reconcile_global() |         self._reconcile_global() | ||||||
|         self.reconcile_tenant() |         self._reconcile_tenant() | ||||||
|         return super().ready() |         return super().ready() | ||||||
|  |  | ||||||
|     def import_related(self): |     def import_related(self): | ||||||
| @ -51,7 +52,8 @@ class ManagedAppConfig(AppConfig): | |||||||
|             meth = getattr(self, meth_name) |             meth = getattr(self, meth_name) | ||||||
|             if not ismethod(meth): |             if not ismethod(meth): | ||||||
|                 continue |                 continue | ||||||
|             if not meth_name.startswith(prefix): |             category = getattr(meth, "_authentik_managed_reconcile", None) | ||||||
|  |             if category != prefix: | ||||||
|                 continue |                 continue | ||||||
|             name = meth_name.replace(prefix, "") |             name = meth_name.replace(prefix, "") | ||||||
|             try: |             try: | ||||||
| @ -61,7 +63,19 @@ class ManagedAppConfig(AppConfig): | |||||||
|             except (DatabaseError, ProgrammingError, InternalError) as exc: |             except (DatabaseError, ProgrammingError, InternalError) as exc: | ||||||
|                 self.logger.warning("Failed to run reconcile", name=name, exc=exc) |                 self.logger.warning("Failed to run reconcile", name=name, exc=exc) | ||||||
|  |  | ||||||
|     def reconcile_tenant(self) -> None: |     @staticmethod | ||||||
|  |     def reconcile_tenant(func: Callable): | ||||||
|  |         """Mark a function to be called on startup (for each tenant)""" | ||||||
|  |         func._authentik_managed_reconcile = ManagedAppConfig.RECONCILE_TENANT_CATEGORY | ||||||
|  |         return func | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def reconcile_global(func: Callable): | ||||||
|  |         """Mark a function to be called on startup (globally)""" | ||||||
|  |         func._authentik_managed_reconcile = ManagedAppConfig.RECONCILE_GLOBAL_CATEGORY | ||||||
|  |         return func | ||||||
|  |  | ||||||
|  |     def _reconcile_tenant(self) -> None: | ||||||
|         """reconcile ourselves for tenanted methods""" |         """reconcile ourselves for tenanted methods""" | ||||||
|         from authentik.tenants.models import Tenant |         from authentik.tenants.models import Tenant | ||||||
|  |  | ||||||
| @ -72,9 +86,9 @@ class ManagedAppConfig(AppConfig): | |||||||
|             return |             return | ||||||
|         for tenant in tenants: |         for tenant in tenants: | ||||||
|             with tenant: |             with tenant: | ||||||
|                 self._reconcile(self.RECONCILE_TENANT_PREFIX) |                 self._reconcile(self.RECONCILE_TENANT_CATEGORY) | ||||||
|  |  | ||||||
|     def reconcile_global(self) -> None: |     def _reconcile_global(self) -> None: | ||||||
|         """ |         """ | ||||||
|         reconcile ourselves for global methods. |         reconcile ourselves for global methods. | ||||||
|         Used for signals, tasks, etc. Database queries should not be made in here. |         Used for signals, tasks, etc. Database queries should not be made in here. | ||||||
| @ -82,7 +96,7 @@ class ManagedAppConfig(AppConfig): | |||||||
|         from django_tenants.utils import get_public_schema_name, schema_context |         from django_tenants.utils import get_public_schema_name, schema_context | ||||||
|  |  | ||||||
|         with schema_context(get_public_schema_name()): |         with schema_context(get_public_schema_name()): | ||||||
|             self._reconcile(self.RECONCILE_GLOBAL_PREFIX) |             self._reconcile(self.RECONCILE_GLOBAL_CATEGORY) | ||||||
|  |  | ||||||
|  |  | ||||||
| class AuthentikBlueprintsConfig(ManagedAppConfig): | class AuthentikBlueprintsConfig(ManagedAppConfig): | ||||||
| @ -93,11 +107,13 @@ class AuthentikBlueprintsConfig(ManagedAppConfig): | |||||||
|     verbose_name = "authentik Blueprints" |     verbose_name = "authentik Blueprints" | ||||||
|     default = True |     default = True | ||||||
|  |  | ||||||
|     def reconcile_global_load_blueprints_v1_tasks(self): |     @ManagedAppConfig.reconcile_global | ||||||
|  |     def load_blueprints_v1_tasks(self): | ||||||
|         """Load v1 tasks""" |         """Load v1 tasks""" | ||||||
|         self.import_module("authentik.blueprints.v1.tasks") |         self.import_module("authentik.blueprints.v1.tasks") | ||||||
|  |  | ||||||
|     def reconcile_tenant_blueprints_discovery(self): |     @ManagedAppConfig.reconcile_tenant | ||||||
|  |     def blueprints_discovery(self): | ||||||
|         """Run blueprint discovery""" |         """Run blueprint discovery""" | ||||||
|         from authentik.blueprints.v1.tasks import blueprints_discovery, clear_failed_blueprints |         from authentik.blueprints.v1.tasks import blueprints_discovery, clear_failed_blueprints | ||||||
|  |  | ||||||
|  | |||||||
| @ -71,6 +71,19 @@ class BlueprintInstance(SerializerModel, ManagedModel, CreatedUpdatedModel): | |||||||
|     enabled = models.BooleanField(default=True) |     enabled = models.BooleanField(default=True) | ||||||
|     managed_models = ArrayField(models.TextField(), default=list) |     managed_models = ArrayField(models.TextField(), default=list) | ||||||
|  |  | ||||||
|  |     class Meta: | ||||||
|  |         verbose_name = _("Blueprint Instance") | ||||||
|  |         verbose_name_plural = _("Blueprint Instances") | ||||||
|  |         unique_together = ( | ||||||
|  |             ( | ||||||
|  |                 "name", | ||||||
|  |                 "path", | ||||||
|  |             ), | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def __str__(self) -> str: | ||||||
|  |         return f"Blueprint Instance {self.name}" | ||||||
|  |  | ||||||
|     def retrieve_oci(self) -> str: |     def retrieve_oci(self) -> str: | ||||||
|         """Get blueprint from an OCI registry""" |         """Get blueprint from an OCI registry""" | ||||||
|         client = BlueprintOCIClient(self.path.replace(OCI_PREFIX, "https://")) |         client = BlueprintOCIClient(self.path.replace(OCI_PREFIX, "https://")) | ||||||
| @ -89,7 +102,7 @@ class BlueprintInstance(SerializerModel, ManagedModel, CreatedUpdatedModel): | |||||||
|                 raise BlueprintRetrievalFailed("Invalid blueprint path") |                 raise BlueprintRetrievalFailed("Invalid blueprint path") | ||||||
|             with full_path.open("r", encoding="utf-8") as _file: |             with full_path.open("r", encoding="utf-8") as _file: | ||||||
|                 return _file.read() |                 return _file.read() | ||||||
|         except (IOError, OSError) as exc: |         except OSError as exc: | ||||||
|             raise BlueprintRetrievalFailed(exc) from exc |             raise BlueprintRetrievalFailed(exc) from exc | ||||||
|  |  | ||||||
|     def retrieve(self) -> str: |     def retrieve(self) -> str: | ||||||
| @ -105,16 +118,3 @@ class BlueprintInstance(SerializerModel, ManagedModel, CreatedUpdatedModel): | |||||||
|         from authentik.blueprints.api import BlueprintInstanceSerializer |         from authentik.blueprints.api import BlueprintInstanceSerializer | ||||||
|  |  | ||||||
|         return BlueprintInstanceSerializer |         return BlueprintInstanceSerializer | ||||||
|  |  | ||||||
|     def __str__(self) -> str: |  | ||||||
|         return f"Blueprint Instance {self.name}" |  | ||||||
|  |  | ||||||
|     class Meta: |  | ||||||
|         verbose_name = _("Blueprint Instance") |  | ||||||
|         verbose_name_plural = _("Blueprint Instances") |  | ||||||
|         unique_together = ( |  | ||||||
|             ( |  | ||||||
|                 "name", |  | ||||||
|                 "path", |  | ||||||
|             ), |  | ||||||
|         ) |  | ||||||
|  | |||||||
| @ -1,7 +1,7 @@ | |||||||
| """Blueprint helpers""" | """Blueprint helpers""" | ||||||
|  |  | ||||||
|  | from collections.abc import Callable | ||||||
| from functools import wraps | from functools import wraps | ||||||
| from typing import Callable |  | ||||||
|  |  | ||||||
| from django.apps import apps | from django.apps import apps | ||||||
|  |  | ||||||
|  | |||||||
| @ -1,7 +1,7 @@ | |||||||
| """test packaged blueprints""" | """test packaged blueprints""" | ||||||
|  |  | ||||||
|  | from collections.abc import Callable | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from typing import Callable |  | ||||||
|  |  | ||||||
| from django.test import TransactionTestCase | from django.test import TransactionTestCase | ||||||
|  |  | ||||||
|  | |||||||
| @ -1,6 +1,6 @@ | |||||||
| """authentik managed models tests""" | """authentik managed models tests""" | ||||||
|  |  | ||||||
| from typing import Callable, Type | from collections.abc import Callable | ||||||
|  |  | ||||||
| from django.apps import apps | from django.apps import apps | ||||||
| from django.test import TestCase | from django.test import TestCase | ||||||
| @ -14,7 +14,7 @@ class TestModels(TestCase): | |||||||
|     """Test Models""" |     """Test Models""" | ||||||
|  |  | ||||||
|  |  | ||||||
| def serializer_tester_factory(test_model: Type[SerializerModel]) -> Callable: | def serializer_tester_factory(test_model: type[SerializerModel]) -> Callable: | ||||||
|     """Test serializer""" |     """Test serializer""" | ||||||
|  |  | ||||||
|     def tester(self: TestModels): |     def tester(self: TestModels): | ||||||
|  | |||||||
| @ -54,7 +54,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase): | |||||||
|             file.seek(0) |             file.seek(0) | ||||||
|             file_hash = sha512(file.read().encode()).hexdigest() |             file_hash = sha512(file.read().encode()).hexdigest() | ||||||
|             file.flush() |             file.flush() | ||||||
|             blueprints_discovery()  # pylint: disable=no-value-for-parameter |             blueprints_discovery() | ||||||
|             instance = BlueprintInstance.objects.filter(name=blueprint_id).first() |             instance = BlueprintInstance.objects.filter(name=blueprint_id).first() | ||||||
|             self.assertEqual(instance.last_applied_hash, file_hash) |             self.assertEqual(instance.last_applied_hash, file_hash) | ||||||
|             self.assertEqual( |             self.assertEqual( | ||||||
| @ -82,7 +82,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase): | |||||||
|                 ) |                 ) | ||||||
|             ) |             ) | ||||||
|             file.flush() |             file.flush() | ||||||
|             blueprints_discovery()  # pylint: disable=no-value-for-parameter |             blueprints_discovery() | ||||||
|             blueprint = BlueprintInstance.objects.filter(name="foo").first() |             blueprint = BlueprintInstance.objects.filter(name="foo").first() | ||||||
|             self.assertEqual( |             self.assertEqual( | ||||||
|                 blueprint.last_applied_hash, |                 blueprint.last_applied_hash, | ||||||
| @ -107,7 +107,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase): | |||||||
|                 ) |                 ) | ||||||
|             ) |             ) | ||||||
|             file.flush() |             file.flush() | ||||||
|             blueprints_discovery()  # pylint: disable=no-value-for-parameter |             blueprints_discovery() | ||||||
|             blueprint.refresh_from_db() |             blueprint.refresh_from_db() | ||||||
|             self.assertEqual( |             self.assertEqual( | ||||||
|                 blueprint.last_applied_hash, |                 blueprint.last_applied_hash, | ||||||
| @ -149,7 +149,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase): | |||||||
|                 instance.status, |                 instance.status, | ||||||
|                 BlueprintInstanceStatus.UNKNOWN, |                 BlueprintInstanceStatus.UNKNOWN, | ||||||
|             ) |             ) | ||||||
|             apply_blueprint(instance.pk)  # pylint: disable=no-value-for-parameter |             apply_blueprint(instance.pk) | ||||||
|             instance.refresh_from_db() |             instance.refresh_from_db() | ||||||
|             self.assertEqual(instance.last_applied_hash, "") |             self.assertEqual(instance.last_applied_hash, "") | ||||||
|             self.assertEqual( |             self.assertEqual( | ||||||
|  | |||||||
| @ -1,13 +1,14 @@ | |||||||
| """transfer common classes""" | """transfer common classes""" | ||||||
|  |  | ||||||
| from collections import OrderedDict | from collections import OrderedDict | ||||||
|  | from collections.abc import Iterable, Mapping | ||||||
| from copy import copy | from copy import copy | ||||||
| from dataclasses import asdict, dataclass, field, is_dataclass | from dataclasses import asdict, dataclass, field, is_dataclass | ||||||
| from enum import Enum | from enum import Enum | ||||||
| from functools import reduce | from functools import reduce | ||||||
| from operator import ixor | from operator import ixor | ||||||
| from os import getenv | from os import getenv | ||||||
| from typing import Any, Iterable, Literal, Mapping, Optional, Union | from typing import Any, Literal, Union | ||||||
| from uuid import UUID | from uuid import UUID | ||||||
|  |  | ||||||
| from deepmerge import always_merger | from deepmerge import always_merger | ||||||
| @ -45,7 +46,7 @@ def get_attrs(obj: SerializerModel) -> dict[str, Any]: | |||||||
| class BlueprintEntryState: | class BlueprintEntryState: | ||||||
|     """State of a single instance""" |     """State of a single instance""" | ||||||
|  |  | ||||||
|     instance: Optional[Model] = None |     instance: Model | None = None | ||||||
|  |  | ||||||
|  |  | ||||||
| class BlueprintEntryDesiredState(Enum): | class BlueprintEntryDesiredState(Enum): | ||||||
| @ -67,9 +68,9 @@ class BlueprintEntry: | |||||||
|     ) |     ) | ||||||
|     conditions: list[Any] = field(default_factory=list) |     conditions: list[Any] = field(default_factory=list) | ||||||
|     identifiers: dict[str, Any] = field(default_factory=dict) |     identifiers: dict[str, Any] = field(default_factory=dict) | ||||||
|     attrs: Optional[dict[str, Any]] = field(default_factory=dict) |     attrs: dict[str, Any] | None = field(default_factory=dict) | ||||||
|  |  | ||||||
|     id: Optional[str] = None |     id: str | None = None | ||||||
|  |  | ||||||
|     _state: BlueprintEntryState = field(default_factory=BlueprintEntryState) |     _state: BlueprintEntryState = field(default_factory=BlueprintEntryState) | ||||||
|  |  | ||||||
| @ -92,10 +93,10 @@ class BlueprintEntry: | |||||||
|             attrs=all_attrs, |             attrs=all_attrs, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def _get_tag_context( |     def get_tag_context( | ||||||
|         self, |         self, | ||||||
|         depth: int = 0, |         depth: int = 0, | ||||||
|         context_tag_type: Optional[type["YAMLTagContext"] | tuple["YAMLTagContext", ...]] = None, |         context_tag_type: type["YAMLTagContext"] | tuple["YAMLTagContext", ...] | None = None, | ||||||
|     ) -> "YAMLTagContext": |     ) -> "YAMLTagContext": | ||||||
|         """Get a YAMLTagContext object located at a certain depth in the tag tree""" |         """Get a YAMLTagContext object located at a certain depth in the tag tree""" | ||||||
|         if depth < 0: |         if depth < 0: | ||||||
| @ -108,8 +109,8 @@ class BlueprintEntry: | |||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             return contexts[-(depth + 1)] |             return contexts[-(depth + 1)] | ||||||
|         except IndexError: |         except IndexError as exc: | ||||||
|             raise ValueError(f"invalid depth: {depth}. Max depth: {len(contexts) - 1}") |             raise ValueError(f"invalid depth: {depth}. Max depth: {len(contexts) - 1}") from exc | ||||||
|  |  | ||||||
|     def tag_resolver(self, value: Any, blueprint: "Blueprint") -> Any: |     def tag_resolver(self, value: Any, blueprint: "Blueprint") -> Any: | ||||||
|         """Check if we have any special tags that need handling""" |         """Check if we have any special tags that need handling""" | ||||||
| @ -170,7 +171,7 @@ class Blueprint: | |||||||
|     entries: list[BlueprintEntry] = field(default_factory=list) |     entries: list[BlueprintEntry] = field(default_factory=list) | ||||||
|     context: dict = field(default_factory=dict) |     context: dict = field(default_factory=dict) | ||||||
|  |  | ||||||
|     metadata: Optional[BlueprintMetadata] = field(default=None) |     metadata: BlueprintMetadata | None = field(default=None) | ||||||
|  |  | ||||||
|  |  | ||||||
| class YAMLTag: | class YAMLTag: | ||||||
| @ -218,7 +219,7 @@ class Env(YAMLTag): | |||||||
|     """Lookup environment variable with optional default""" |     """Lookup environment variable with optional default""" | ||||||
|  |  | ||||||
|     key: str |     key: str | ||||||
|     default: Optional[Any] |     default: Any | None | ||||||
|  |  | ||||||
|     def __init__(self, loader: "BlueprintLoader", node: ScalarNode | SequenceNode) -> None: |     def __init__(self, loader: "BlueprintLoader", node: ScalarNode | SequenceNode) -> None: | ||||||
|         super().__init__() |         super().__init__() | ||||||
| @ -237,7 +238,7 @@ class Context(YAMLTag): | |||||||
|     """Lookup key from instance context""" |     """Lookup key from instance context""" | ||||||
|  |  | ||||||
|     key: str |     key: str | ||||||
|     default: Optional[Any] |     default: Any | None | ||||||
|  |  | ||||||
|     def __init__(self, loader: "BlueprintLoader", node: ScalarNode | SequenceNode) -> None: |     def __init__(self, loader: "BlueprintLoader", node: ScalarNode | SequenceNode) -> None: | ||||||
|         super().__init__() |         super().__init__() | ||||||
| @ -281,7 +282,7 @@ class Format(YAMLTag): | |||||||
|         try: |         try: | ||||||
|             return self.format_string % tuple(args) |             return self.format_string % tuple(args) | ||||||
|         except TypeError as exc: |         except TypeError as exc: | ||||||
|             raise EntryInvalidError.from_entry(exc, entry) |             raise EntryInvalidError.from_entry(exc, entry) from exc | ||||||
|  |  | ||||||
|  |  | ||||||
| class Find(YAMLTag): | class Find(YAMLTag): | ||||||
| @ -366,7 +367,7 @@ class Condition(YAMLTag): | |||||||
|             comparator = self._COMPARATORS[self.mode.upper()] |             comparator = self._COMPARATORS[self.mode.upper()] | ||||||
|             return comparator(tuple(bool(x) for x in args)) |             return comparator(tuple(bool(x) for x in args)) | ||||||
|         except (TypeError, KeyError) as exc: |         except (TypeError, KeyError) as exc: | ||||||
|             raise EntryInvalidError.from_entry(exc, entry) |             raise EntryInvalidError.from_entry(exc, entry) from exc | ||||||
|  |  | ||||||
|  |  | ||||||
| class If(YAMLTag): | class If(YAMLTag): | ||||||
| @ -398,7 +399,7 @@ class If(YAMLTag): | |||||||
|                 blueprint, |                 blueprint, | ||||||
|             ) |             ) | ||||||
|         except TypeError as exc: |         except TypeError as exc: | ||||||
|             raise EntryInvalidError.from_entry(exc, entry) |             raise EntryInvalidError.from_entry(exc, entry) from exc | ||||||
|  |  | ||||||
|  |  | ||||||
| class Enumerate(YAMLTag, YAMLTagContext): | class Enumerate(YAMLTag, YAMLTagContext): | ||||||
| @ -412,9 +413,7 @@ class Enumerate(YAMLTag, YAMLTagContext): | |||||||
|         "SEQ": (list, lambda a, b: [*a, b]), |         "SEQ": (list, lambda a, b: [*a, b]), | ||||||
|         "MAP": ( |         "MAP": ( | ||||||
|             dict, |             dict, | ||||||
|             lambda a, b: always_merger.merge( |             lambda a, b: always_merger.merge(a, {b[0]: b[1]} if isinstance(b, tuple | list) else b), | ||||||
|                 a, {b[0]: b[1]} if isinstance(b, (tuple, list)) else b |  | ||||||
|             ), |  | ||||||
|         ), |         ), | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @ -456,7 +455,7 @@ class Enumerate(YAMLTag, YAMLTagContext): | |||||||
|         try: |         try: | ||||||
|             output_class, add_fn = self._OUTPUT_BODIES[self.output_body.upper()] |             output_class, add_fn = self._OUTPUT_BODIES[self.output_body.upper()] | ||||||
|         except KeyError as exc: |         except KeyError as exc: | ||||||
|             raise EntryInvalidError.from_entry(exc, entry) |             raise EntryInvalidError.from_entry(exc, entry) from exc | ||||||
|  |  | ||||||
|         result = output_class() |         result = output_class() | ||||||
|  |  | ||||||
| @ -484,13 +483,13 @@ class EnumeratedItem(YAMLTag): | |||||||
|  |  | ||||||
|     _SUPPORTED_CONTEXT_TAGS = (Enumerate,) |     _SUPPORTED_CONTEXT_TAGS = (Enumerate,) | ||||||
|  |  | ||||||
|     def __init__(self, loader: "BlueprintLoader", node: ScalarNode) -> None: |     def __init__(self, _loader: "BlueprintLoader", node: ScalarNode) -> None: | ||||||
|         super().__init__() |         super().__init__() | ||||||
|         self.depth = int(node.value) |         self.depth = int(node.value) | ||||||
|  |  | ||||||
|     def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any: |     def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any: | ||||||
|         try: |         try: | ||||||
|             context_tag: Enumerate = entry._get_tag_context( |             context_tag: Enumerate = entry.get_tag_context( | ||||||
|                 depth=self.depth, |                 depth=self.depth, | ||||||
|                 context_tag_type=EnumeratedItem._SUPPORTED_CONTEXT_TAGS, |                 context_tag_type=EnumeratedItem._SUPPORTED_CONTEXT_TAGS, | ||||||
|             ) |             ) | ||||||
| @ -500,9 +499,11 @@ class EnumeratedItem(YAMLTag): | |||||||
|                     f"{self.__class__.__name__} tags are only usable " |                     f"{self.__class__.__name__} tags are only usable " | ||||||
|                     f"inside an {Enumerate.__name__} tag", |                     f"inside an {Enumerate.__name__} tag", | ||||||
|                     entry, |                     entry, | ||||||
|                 ) |                 ) from exc | ||||||
|  |  | ||||||
|             raise EntryInvalidError.from_entry(f"{self.__class__.__name__} tag: {exc}", entry) |             raise EntryInvalidError.from_entry( | ||||||
|  |                 f"{self.__class__.__name__} tag: {exc}", entry | ||||||
|  |             ) from exc | ||||||
|  |  | ||||||
|         return context_tag.get_context(entry, blueprint) |         return context_tag.get_context(entry, blueprint) | ||||||
|  |  | ||||||
| @ -515,8 +516,8 @@ class Index(EnumeratedItem): | |||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             return context[0] |             return context[0] | ||||||
|         except IndexError:  # pragma: no cover |         except IndexError as exc:  # pragma: no cover | ||||||
|             raise EntryInvalidError.from_entry(f"Empty/invalid context: {context}", entry) |             raise EntryInvalidError.from_entry(f"Empty/invalid context: {context}", entry) from exc | ||||||
|  |  | ||||||
|  |  | ||||||
| class Value(EnumeratedItem): | class Value(EnumeratedItem): | ||||||
| @ -527,8 +528,8 @@ class Value(EnumeratedItem): | |||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             return context[1] |             return context[1] | ||||||
|         except IndexError:  # pragma: no cover |         except IndexError as exc:  # pragma: no cover | ||||||
|             raise EntryInvalidError.from_entry(f"Empty/invalid context: {context}", entry) |             raise EntryInvalidError.from_entry(f"Empty/invalid context: {context}", entry) from exc | ||||||
|  |  | ||||||
|  |  | ||||||
| class BlueprintDumper(SafeDumper): | class BlueprintDumper(SafeDumper): | ||||||
| @ -582,13 +583,13 @@ class BlueprintLoader(SafeLoader): | |||||||
| class EntryInvalidError(SentryIgnoredException): | class EntryInvalidError(SentryIgnoredException): | ||||||
|     """Error raised when an entry is invalid""" |     """Error raised when an entry is invalid""" | ||||||
|  |  | ||||||
|     entry_model: Optional[str] |     entry_model: str | None | ||||||
|     entry_id: Optional[str] |     entry_id: str | None | ||||||
|     validation_error: Optional[ValidationError] |     validation_error: ValidationError | None | ||||||
|     serializer: Optional[Serializer] = None |     serializer: Serializer | None = None | ||||||
|  |  | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, *args: object, validation_error: Optional[ValidationError] = None, **kwargs |         self, *args: object, validation_error: ValidationError | None = None, **kwargs | ||||||
|     ) -> None: |     ) -> None: | ||||||
|         super().__init__(*args) |         super().__init__(*args) | ||||||
|         self.entry_model = None |         self.entry_model = None | ||||||
|  | |||||||
| @ -1,6 +1,6 @@ | |||||||
| """Blueprint exporter""" | """Blueprint exporter""" | ||||||
|  |  | ||||||
| from typing import Iterable | from collections.abc import Iterable | ||||||
| from uuid import UUID | from uuid import UUID | ||||||
|  |  | ||||||
| from django.apps import apps | from django.apps import apps | ||||||
| @ -59,7 +59,7 @@ class Exporter: | |||||||
|         blueprint = Blueprint() |         blueprint = Blueprint() | ||||||
|         self._pre_export(blueprint) |         self._pre_export(blueprint) | ||||||
|         blueprint.metadata = BlueprintMetadata( |         blueprint.metadata = BlueprintMetadata( | ||||||
|             name=_("authentik Export - %(date)s" % {"date": str(now())}), |             name=_("authentik Export - {date}".format_map({"date": str(now())})), | ||||||
|             labels={ |             labels={ | ||||||
|                 LABEL_AUTHENTIK_GENERATED: "true", |                 LABEL_AUTHENTIK_GENERATED: "true", | ||||||
|             }, |             }, | ||||||
|  | |||||||
| @ -2,7 +2,7 @@ | |||||||
|  |  | ||||||
| from contextlib import contextmanager | from contextlib import contextmanager | ||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
| from typing import Any, Optional | from typing import Any | ||||||
|  |  | ||||||
| from dacite.config import Config | from dacite.config import Config | ||||||
| from dacite.core import from_dict | from dacite.core import from_dict | ||||||
| @ -62,7 +62,7 @@ SERIALIZER_CONTEXT_BLUEPRINT = "blueprint_entry" | |||||||
| def excluded_models() -> list[type[Model]]: | def excluded_models() -> list[type[Model]]: | ||||||
|     """Return a list of all excluded models that shouldn't be exposed via API |     """Return a list of all excluded models that shouldn't be exposed via API | ||||||
|     or other means (internal only, base classes, non-used objects, etc)""" |     or other means (internal only, base classes, non-used objects, etc)""" | ||||||
|     # pylint: disable=imported-auth-user |  | ||||||
|     from django.contrib.auth.models import Group as DjangoGroup |     from django.contrib.auth.models import Group as DjangoGroup | ||||||
|     from django.contrib.auth.models import User as DjangoUser |     from django.contrib.auth.models import User as DjangoUser | ||||||
|  |  | ||||||
| @ -101,7 +101,7 @@ def excluded_models() -> list[type[Model]]: | |||||||
|  |  | ||||||
| def is_model_allowed(model: type[Model]) -> bool: | def is_model_allowed(model: type[Model]) -> bool: | ||||||
|     """Check if model is allowed""" |     """Check if model is allowed""" | ||||||
|     return model not in excluded_models() and issubclass(model, (SerializerModel, BaseMetaModel)) |     return model not in excluded_models() and issubclass(model, SerializerModel | BaseMetaModel) | ||||||
|  |  | ||||||
|  |  | ||||||
| class DoRollback(SentryIgnoredException): | class DoRollback(SentryIgnoredException): | ||||||
| @ -125,7 +125,7 @@ class Importer: | |||||||
|     logger: BoundLogger |     logger: BoundLogger | ||||||
|     _import: Blueprint |     _import: Blueprint | ||||||
|  |  | ||||||
|     def __init__(self, blueprint: Blueprint, context: Optional[dict] = None): |     def __init__(self, blueprint: Blueprint, context: dict | None = None): | ||||||
|         self.__pk_map: dict[Any, Model] = {} |         self.__pk_map: dict[Any, Model] = {} | ||||||
|         self._import = blueprint |         self._import = blueprint | ||||||
|         self.logger = get_logger() |         self.logger = get_logger() | ||||||
| @ -168,7 +168,7 @@ class Importer: | |||||||
|         for key, value in attrs.items(): |         for key, value in attrs.items(): | ||||||
|             try: |             try: | ||||||
|                 if isinstance(value, dict): |                 if isinstance(value, dict): | ||||||
|                     for idx, _inner_key in enumerate(value): |                     for _, _inner_key in enumerate(value): | ||||||
|                         value[_inner_key] = updater(value[_inner_key]) |                         value[_inner_key] = updater(value[_inner_key]) | ||||||
|                 elif isinstance(value, list): |                 elif isinstance(value, list): | ||||||
|                     for idx, _inner_value in enumerate(value): |                     for idx, _inner_value in enumerate(value): | ||||||
| @ -197,8 +197,7 @@ class Importer: | |||||||
|  |  | ||||||
|         return main_query | sub_query |         return main_query | sub_query | ||||||
|  |  | ||||||
|     # pylint: disable-msg=too-many-locals |     def _validate_single(self, entry: BlueprintEntry) -> BaseSerializer | None: | ||||||
|     def _validate_single(self, entry: BlueprintEntry) -> Optional[BaseSerializer]: |  | ||||||
|         """Validate a single entry""" |         """Validate a single entry""" | ||||||
|         if not entry.check_all_conditions_match(self._import): |         if not entry.check_all_conditions_match(self._import): | ||||||
|             self.logger.debug("One or more conditions of this entry are not fulfilled, skipping") |             self.logger.debug("One or more conditions of this entry are not fulfilled, skipping") | ||||||
| @ -369,7 +368,7 @@ class Importer: | |||||||
|                     self.__pk_map[entry.identifiers["pk"]] = instance.pk |                     self.__pk_map[entry.identifiers["pk"]] = instance.pk | ||||||
|                 entry._state = BlueprintEntryState(instance) |                 entry._state = BlueprintEntryState(instance) | ||||||
|             elif state == BlueprintEntryDesiredState.ABSENT: |             elif state == BlueprintEntryDesiredState.ABSENT: | ||||||
|                 instance: Optional[Model] = serializer.instance |                 instance: Model | None = serializer.instance | ||||||
|                 if instance.pk: |                 if instance.pk: | ||||||
|                     instance.delete() |                     instance.delete() | ||||||
|                     self.logger.debug("deleted model", mode=instance) |                     self.logger.debug("deleted model", mode=instance) | ||||||
|  | |||||||
| @ -43,7 +43,7 @@ class ApplyBlueprintMetaSerializer(PassiveSerializer): | |||||||
|             LOGGER.info("Blueprint does not exist, but not required") |             LOGGER.info("Blueprint does not exist, but not required") | ||||||
|             return MetaResult() |             return MetaResult() | ||||||
|         LOGGER.debug("Applying blueprint from meta model", blueprint=self.blueprint_instance) |         LOGGER.debug("Applying blueprint from meta model", blueprint=self.blueprint_instance) | ||||||
|         # pylint: disable=no-value-for-parameter |  | ||||||
|         apply_blueprint(str(self.blueprint_instance.pk)) |         apply_blueprint(str(self.blueprint_instance.pk)) | ||||||
|         return MetaResult() |         return MetaResult() | ||||||
|  |  | ||||||
|  | |||||||
| @ -8,15 +8,15 @@ from rest_framework.serializers import Serializer | |||||||
| class BaseMetaModel(Model): | class BaseMetaModel(Model): | ||||||
|     """Base models""" |     """Base models""" | ||||||
|  |  | ||||||
|  |     class Meta: | ||||||
|  |         abstract = True | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def serializer() -> Serializer: |     def serializer() -> Serializer: | ||||||
|         """Serializer similar to SerializerModel, but as a static method since |         """Serializer similar to SerializerModel, but as a static method since | ||||||
|         this is an abstract model""" |         this is an abstract model""" | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     class Meta: |  | ||||||
|         abstract = True |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class MetaResult: | class MetaResult: | ||||||
|     """Result returned by Meta Models' serializers. Empty class but we can't return none as |     """Result returned by Meta Models' serializers. Empty class but we can't return none as | ||||||
|  | |||||||
| @ -4,7 +4,6 @@ from dataclasses import asdict, dataclass, field | |||||||
| from hashlib import sha512 | from hashlib import sha512 | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from sys import platform | from sys import platform | ||||||
| from typing import Optional |  | ||||||
|  |  | ||||||
| from dacite.core import from_dict | from dacite.core import from_dict | ||||||
| from django.db import DatabaseError, InternalError, ProgrammingError | from django.db import DatabaseError, InternalError, ProgrammingError | ||||||
| @ -50,14 +49,14 @@ class BlueprintFile: | |||||||
|     version: int |     version: int | ||||||
|     hash: str |     hash: str | ||||||
|     last_m: int |     last_m: int | ||||||
|     meta: Optional[BlueprintMetadata] = field(default=None) |     meta: BlueprintMetadata | None = field(default=None) | ||||||
|  |  | ||||||
|  |  | ||||||
| def start_blueprint_watcher(): | def start_blueprint_watcher(): | ||||||
|     """Start blueprint watcher, if it's not running already.""" |     """Start blueprint watcher, if it's not running already.""" | ||||||
|     # This function might be called twice since it's called on celery startup |     # This function might be called twice since it's called on celery startup | ||||||
|     # pylint: disable=global-statement |  | ||||||
|     global _file_watcher_started |     global _file_watcher_started  # noqa: PLW0603 | ||||||
|     if _file_watcher_started: |     if _file_watcher_started: | ||||||
|         return |         return | ||||||
|     observer = Observer() |     observer = Observer() | ||||||
| @ -126,7 +125,7 @@ def blueprints_find() -> list[BlueprintFile]: | |||||||
|         # Check if any part in the path starts with a dot and assume a hidden file |         # Check if any part in the path starts with a dot and assume a hidden file | ||||||
|         if any(part for part in path.parts if part.startswith(".")): |         if any(part for part in path.parts if part.startswith(".")): | ||||||
|             continue |             continue | ||||||
|         with open(path, "r", encoding="utf-8") as blueprint_file: |         with open(path, encoding="utf-8") as blueprint_file: | ||||||
|             try: |             try: | ||||||
|                 raw_blueprint = load(blueprint_file.read(), BlueprintLoader) |                 raw_blueprint = load(blueprint_file.read(), BlueprintLoader) | ||||||
|             except YAMLError as exc: |             except YAMLError as exc: | ||||||
| @ -150,7 +149,7 @@ def blueprints_find() -> list[BlueprintFile]: | |||||||
|     throws=(DatabaseError, ProgrammingError, InternalError), base=SystemTask, bind=True |     throws=(DatabaseError, ProgrammingError, InternalError), base=SystemTask, bind=True | ||||||
| ) | ) | ||||||
| @prefill_task | @prefill_task | ||||||
| def blueprints_discovery(self: SystemTask, path: Optional[str] = None): | def blueprints_discovery(self: SystemTask, path: str | None = None): | ||||||
|     """Find blueprints and check if they need to be created in the database""" |     """Find blueprints and check if they need to be created in the database""" | ||||||
|     count = 0 |     count = 0 | ||||||
|     for blueprint in blueprints_find(): |     for blueprint in blueprints_find(): | ||||||
| @ -197,7 +196,7 @@ def check_blueprint_v1_file(blueprint: BlueprintFile): | |||||||
| def apply_blueprint(self: SystemTask, instance_pk: str): | def apply_blueprint(self: SystemTask, instance_pk: str): | ||||||
|     """Apply single blueprint""" |     """Apply single blueprint""" | ||||||
|     self.save_on_success = False |     self.save_on_success = False | ||||||
|     instance: Optional[BlueprintInstance] = None |     instance: BlueprintInstance | None = None | ||||||
|     try: |     try: | ||||||
|         instance: BlueprintInstance = BlueprintInstance.objects.filter(pk=instance_pk).first() |         instance: BlueprintInstance = BlueprintInstance.objects.filter(pk=instance_pk).first() | ||||||
|         if not instance or not instance.enabled: |         if not instance or not instance.enabled: | ||||||
| @ -225,10 +224,10 @@ def apply_blueprint(self: SystemTask, instance_pk: str): | |||||||
|         instance.last_applied = now() |         instance.last_applied = now() | ||||||
|         self.set_status(TaskStatus.SUCCESSFUL) |         self.set_status(TaskStatus.SUCCESSFUL) | ||||||
|     except ( |     except ( | ||||||
|  |         OSError, | ||||||
|         DatabaseError, |         DatabaseError, | ||||||
|         ProgrammingError, |         ProgrammingError, | ||||||
|         InternalError, |         InternalError, | ||||||
|         IOError, |  | ||||||
|         BlueprintRetrievalFailed, |         BlueprintRetrievalFailed, | ||||||
|         EntryInvalidError, |         EntryInvalidError, | ||||||
|     ) as exc: |     ) as exc: | ||||||
|  | |||||||
| @ -1,6 +1,6 @@ | |||||||
| """Inject brand into current request""" | """Inject brand into current request""" | ||||||
|  |  | ||||||
| from typing import Callable | from collections.abc import Callable | ||||||
|  |  | ||||||
| from django.http.request import HttpRequest | from django.http.request import HttpRequest | ||||||
| from django.http.response import HttpResponse | from django.http.response import HttpResponse | ||||||
| @ -20,7 +20,7 @@ class BrandMiddleware: | |||||||
|     def __call__(self, request: HttpRequest) -> HttpResponse: |     def __call__(self, request: HttpRequest) -> HttpResponse: | ||||||
|         if not hasattr(request, "brand"): |         if not hasattr(request, "brand"): | ||||||
|             brand = get_brand_for_request(request) |             brand = get_brand_for_request(request) | ||||||
|             setattr(request, "brand", brand) |             request.brand = brand | ||||||
|             locale = brand.default_locale |             locale = brand.default_locale | ||||||
|             if locale != "": |             if locale != "": | ||||||
|                 activate(locale) |                 activate(locale) | ||||||
|  | |||||||
| @ -71,7 +71,7 @@ class Brand(SerializerModel): | |||||||
|         """Get default locale""" |         """Get default locale""" | ||||||
|         try: |         try: | ||||||
|             return self.attributes.get("settings", {}).get("locale", "") |             return self.attributes.get("settings", {}).get("locale", "") | ||||||
|         # pylint: disable=broad-except |  | ||||||
|         except Exception as exc: |         except Exception as exc: | ||||||
|             LOGGER.warning("Failed to get default locale", exc=exc) |             LOGGER.warning("Failed to get default locale", exc=exc) | ||||||
|             return "" |             return "" | ||||||
|  | |||||||
| @ -1,8 +1,8 @@ | |||||||
| """Application API Views""" | """Application API Views""" | ||||||
|  |  | ||||||
|  | from collections.abc import Iterator | ||||||
| from copy import copy | from copy import copy | ||||||
| from datetime import timedelta | from datetime import timedelta | ||||||
| from typing import Iterator, Optional |  | ||||||
|  |  | ||||||
| from django.core.cache import cache | from django.core.cache import cache | ||||||
| from django.db.models import QuerySet | from django.db.models import QuerySet | ||||||
| @ -60,7 +60,7 @@ class ApplicationSerializer(ModelSerializer): | |||||||
|  |  | ||||||
|     meta_icon = ReadOnlyField(source="get_meta_icon") |     meta_icon = ReadOnlyField(source="get_meta_icon") | ||||||
|  |  | ||||||
|     def get_launch_url(self, app: Application) -> Optional[str]: |     def get_launch_url(self, app: Application) -> str | None: | ||||||
|         """Allow formatting of launch URL""" |         """Allow formatting of launch URL""" | ||||||
|         user = None |         user = None | ||||||
|         if "request" in self.context: |         if "request" in self.context: | ||||||
| @ -100,7 +100,6 @@ class ApplicationSerializer(ModelSerializer): | |||||||
| class ApplicationViewSet(UsedByMixin, ModelViewSet): | class ApplicationViewSet(UsedByMixin, ModelViewSet): | ||||||
|     """Application Viewset""" |     """Application Viewset""" | ||||||
|  |  | ||||||
|     # pylint: disable=no-member |  | ||||||
|     queryset = Application.objects.all().prefetch_related("provider") |     queryset = Application.objects.all().prefetch_related("provider") | ||||||
|     serializer_class = ApplicationSerializer |     serializer_class = ApplicationSerializer | ||||||
|     search_fields = [ |     search_fields = [ | ||||||
| @ -131,7 +130,7 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet): | |||||||
|         return queryset |         return queryset | ||||||
|  |  | ||||||
|     def _get_allowed_applications( |     def _get_allowed_applications( | ||||||
|         self, pagined_apps: Iterator[Application], user: Optional[User] = None |         self, pagined_apps: Iterator[Application], user: User | None = None | ||||||
|     ) -> list[Application]: |     ) -> list[Application]: | ||||||
|         applications = [] |         applications = [] | ||||||
|         request = self.request._request |         request = self.request._request | ||||||
| @ -169,7 +168,7 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet): | |||||||
|             try: |             try: | ||||||
|                 for_user = User.objects.filter(pk=request.query_params.get("for_user")).first() |                 for_user = User.objects.filter(pk=request.query_params.get("for_user")).first() | ||||||
|             except ValueError: |             except ValueError: | ||||||
|                 raise ValidationError({"for_user": "for_user must be numerical"}) |                 raise ValidationError({"for_user": "for_user must be numerical"}) from None | ||||||
|             if not for_user: |             if not for_user: | ||||||
|                 raise ValidationError({"for_user": "User not found"}) |                 raise ValidationError({"for_user": "User not found"}) | ||||||
|         engine = PolicyEngine(application, for_user, request) |         engine = PolicyEngine(application, for_user, request) | ||||||
|  | |||||||
| @ -1,6 +1,6 @@ | |||||||
| """AuthenticatedSessions API Viewset""" | """AuthenticatedSessions API Viewset""" | ||||||
|  |  | ||||||
| from typing import Optional, TypedDict | from typing import TypedDict | ||||||
|  |  | ||||||
| from django_filters.rest_framework import DjangoFilterBackend | from django_filters.rest_framework import DjangoFilterBackend | ||||||
| from guardian.utils import get_anonymous_user | from guardian.utils import get_anonymous_user | ||||||
| @ -72,11 +72,11 @@ class AuthenticatedSessionSerializer(ModelSerializer): | |||||||
|         """Get parsed user agent""" |         """Get parsed user agent""" | ||||||
|         return user_agent_parser.Parse(instance.last_user_agent) |         return user_agent_parser.Parse(instance.last_user_agent) | ||||||
|  |  | ||||||
|     def get_geo_ip(self, instance: AuthenticatedSession) -> Optional[GeoIPDict]:  # pragma: no cover |     def get_geo_ip(self, instance: AuthenticatedSession) -> GeoIPDict | None:  # pragma: no cover | ||||||
|         """Get GeoIP Data""" |         """Get GeoIP Data""" | ||||||
|         return GEOIP_CONTEXT_PROCESSOR.city_dict(instance.last_ip) |         return GEOIP_CONTEXT_PROCESSOR.city_dict(instance.last_ip) | ||||||
|  |  | ||||||
|     def get_asn(self, instance: AuthenticatedSession) -> Optional[ASNDict]:  # pragma: no cover |     def get_asn(self, instance: AuthenticatedSession) -> ASNDict | None:  # pragma: no cover | ||||||
|         """Get ASN Data""" |         """Get ASN Data""" | ||||||
|         return ASN_CONTEXT_PROCESSOR.asn_dict(instance.last_ip) |         return ASN_CONTEXT_PROCESSOR.asn_dict(instance.last_ip) | ||||||
|  |  | ||||||
|  | |||||||
| @ -1,7 +1,6 @@ | |||||||
| """Groups API Viewset""" | """Groups API Viewset""" | ||||||
|  |  | ||||||
| from json import loads | from json import loads | ||||||
| from typing import Optional |  | ||||||
|  |  | ||||||
| from django.http import Http404 | from django.http import Http404 | ||||||
| from django_filters.filters import CharFilter, ModelMultipleChoiceFilter | from django_filters.filters import CharFilter, ModelMultipleChoiceFilter | ||||||
| @ -59,7 +58,7 @@ class GroupSerializer(ModelSerializer): | |||||||
|  |  | ||||||
|     num_pk = IntegerField(read_only=True) |     num_pk = IntegerField(read_only=True) | ||||||
|  |  | ||||||
|     def validate_parent(self, parent: Optional[Group]): |     def validate_parent(self, parent: Group | None): | ||||||
|         """Validate group parent (if set), ensuring the parent isn't itself""" |         """Validate group parent (if set), ensuring the parent isn't itself""" | ||||||
|         if not self.instance or not parent: |         if not self.instance or not parent: | ||||||
|             return parent |             return parent | ||||||
| @ -114,7 +113,7 @@ class GroupFilter(FilterSet): | |||||||
|         try: |         try: | ||||||
|             value = loads(value) |             value = loads(value) | ||||||
|         except ValueError: |         except ValueError: | ||||||
|             raise ValidationError(detail="filter: failed to parse JSON") |             raise ValidationError(detail="filter: failed to parse JSON") from None | ||||||
|         if not isinstance(value, dict): |         if not isinstance(value, dict): | ||||||
|             raise ValidationError(detail="filter: value must be key:value mapping") |             raise ValidationError(detail="filter: value must be key:value mapping") | ||||||
|         qs = {} |         qs = {} | ||||||
| @ -140,7 +139,6 @@ class UserAccountSerializer(PassiveSerializer): | |||||||
| class GroupViewSet(UsedByMixin, ModelViewSet): | class GroupViewSet(UsedByMixin, ModelViewSet): | ||||||
|     """Group Viewset""" |     """Group Viewset""" | ||||||
|  |  | ||||||
|     # pylint: disable=no-member |  | ||||||
|     queryset = Group.objects.all().select_related("parent").prefetch_related("users") |     queryset = Group.objects.all().select_related("parent").prefetch_related("users") | ||||||
|     serializer_class = GroupSerializer |     serializer_class = GroupSerializer | ||||||
|     search_fields = ["name", "is_superuser"] |     search_fields = ["name", "is_superuser"] | ||||||
|  | |||||||
| @ -146,7 +146,7 @@ class PropertyMappingViewSet( | |||||||
|             response_data["result"] = dumps( |             response_data["result"] = dumps( | ||||||
|                 sanitize_item(result), indent=(4 if format_result else None) |                 sanitize_item(result), indent=(4 if format_result else None) | ||||||
|             ) |             ) | ||||||
|         except Exception as exc:  # pylint: disable=broad-except |         except Exception as exc: | ||||||
|             response_data["result"] = str(exc) |             response_data["result"] = str(exc) | ||||||
|             response_data["successful"] = False |             response_data["successful"] = False | ||||||
|         response = PropertyMappingTestResultSerializer(response_data) |         response = PropertyMappingTestResultSerializer(response_data) | ||||||
|  | |||||||
| @ -1,6 +1,6 @@ | |||||||
| """Source API Views""" | """Source API Views""" | ||||||
|  |  | ||||||
| from typing import Iterable | from collections.abc import Iterable | ||||||
|  |  | ||||||
| from django_filters.rest_framework import DjangoFilterBackend | from django_filters.rest_framework import DjangoFilterBackend | ||||||
| from drf_spectacular.utils import OpenApiResponse, extend_schema | from drf_spectacular.utils import OpenApiResponse, extend_schema | ||||||
|  | |||||||
| @ -65,7 +65,7 @@ class TransactionApplicationSerializer(PassiveSerializer): | |||||||
|                 raise ValidationError("Invalid provider model") |                 raise ValidationError("Invalid provider model") | ||||||
|             self._provider_model = model |             self._provider_model = model | ||||||
|         except LookupError: |         except LookupError: | ||||||
|             raise ValidationError("Invalid provider model") |             raise ValidationError("Invalid provider model") from None | ||||||
|         return fq_model_name |         return fq_model_name | ||||||
|  |  | ||||||
|     def validate(self, attrs: dict) -> dict: |     def validate(self, attrs: dict) -> dict: | ||||||
| @ -106,7 +106,7 @@ class TransactionApplicationSerializer(PassiveSerializer): | |||||||
|                 { |                 { | ||||||
|                     exc.entry_id: exc.validation_error.detail, |                     exc.entry_id: exc.validation_error.detail, | ||||||
|                 } |                 } | ||||||
|             ) |             ) from None | ||||||
|         return blueprint |         return blueprint | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| @ -54,7 +54,6 @@ class UsedByMixin: | |||||||
|         responses={200: UsedBySerializer(many=True)}, |         responses={200: UsedBySerializer(many=True)}, | ||||||
|     ) |     ) | ||||||
|     @action(detail=True, pagination_class=None, filter_backends=[]) |     @action(detail=True, pagination_class=None, filter_backends=[]) | ||||||
|     # pylint: disable=too-many-locals |  | ||||||
|     def used_by(self, request: Request, *args, **kwargs) -> Response: |     def used_by(self, request: Request, *args, **kwargs) -> Response: | ||||||
|         """Get a list of all objects that use this object""" |         """Get a list of all objects that use this object""" | ||||||
|         model: Model = self.get_object() |         model: Model = self.get_object() | ||||||
|  | |||||||
| @ -2,7 +2,7 @@ | |||||||
|  |  | ||||||
| from datetime import timedelta | from datetime import timedelta | ||||||
| from json import loads | from json import loads | ||||||
| from typing import Any, Optional | from typing import Any | ||||||
|  |  | ||||||
| from django.contrib.auth import update_session_auth_hash | from django.contrib.auth import update_session_auth_hash | ||||||
| from django.contrib.sessions.backends.cache import KEY_PREFIX | from django.contrib.sessions.backends.cache import KEY_PREFIX | ||||||
| @ -142,7 +142,7 @@ class UserSerializer(ModelSerializer): | |||||||
|         self._set_password(instance, password) |         self._set_password(instance, password) | ||||||
|         return instance |         return instance | ||||||
|  |  | ||||||
|     def _set_password(self, instance: User, password: Optional[str]): |     def _set_password(self, instance: User, password: str | None): | ||||||
|         """Set password of user if we're in a blueprint context, and if it's an empty |         """Set password of user if we're in a blueprint context, and if it's an empty | ||||||
|         string then use an unusable password""" |         string then use an unusable password""" | ||||||
|         if SERIALIZER_CONTEXT_BLUEPRINT in self.context and password: |         if SERIALIZER_CONTEXT_BLUEPRINT in self.context and password: | ||||||
| @ -358,7 +358,7 @@ class UsersFilter(FilterSet): | |||||||
|         try: |         try: | ||||||
|             value = loads(value) |             value = loads(value) | ||||||
|         except ValueError: |         except ValueError: | ||||||
|             raise ValidationError(detail="filter: failed to parse JSON") |             raise ValidationError(detail="filter: failed to parse JSON") from None | ||||||
|         if not isinstance(value, dict): |         if not isinstance(value, dict): | ||||||
|             raise ValidationError(detail="filter: value must be key:value mapping") |             raise ValidationError(detail="filter: value must be key:value mapping") | ||||||
|         qs = {} |         qs = {} | ||||||
| @ -397,15 +397,14 @@ class UserViewSet(UsedByMixin, ModelViewSet): | |||||||
|     def get_queryset(self):  # pragma: no cover |     def get_queryset(self):  # pragma: no cover | ||||||
|         return User.objects.all().exclude_anonymous().prefetch_related("ak_groups") |         return User.objects.all().exclude_anonymous().prefetch_related("ak_groups") | ||||||
|  |  | ||||||
|     def _create_recovery_link(self) -> tuple[Optional[str], Optional[Token]]: |     def _create_recovery_link(self) -> tuple[str, Token]: | ||||||
|         """Create a recovery link (when the current brand has a recovery flow set), |         """Create a recovery link (when the current brand has a recovery flow set), | ||||||
|         that can either be shown to an admin or sent to the user directly""" |         that can either be shown to an admin or sent to the user directly""" | ||||||
|         brand: Brand = self.request._request.brand |         brand: Brand = self.request._request.brand | ||||||
|         # Check that there is a recovery flow, if not return an error |         # Check that there is a recovery flow, if not return an error | ||||||
|         flow = brand.flow_recovery |         flow = brand.flow_recovery | ||||||
|         if not flow: |         if not flow: | ||||||
|             LOGGER.debug("No recovery flow set") |             raise ValidationError({"non_field_errors": "No recovery flow set."}) | ||||||
|             return None, None |  | ||||||
|         user: User = self.get_object() |         user: User = self.get_object() | ||||||
|         planner = FlowPlanner(flow) |         planner = FlowPlanner(flow) | ||||||
|         planner.allow_empty_flows = True |         planner.allow_empty_flows = True | ||||||
| @ -417,8 +416,9 @@ class UserViewSet(UsedByMixin, ModelViewSet): | |||||||
|                 }, |                 }, | ||||||
|             ) |             ) | ||||||
|         except FlowNonApplicableException: |         except FlowNonApplicableException: | ||||||
|             LOGGER.warning("Recovery flow not applicable to user") |             raise ValidationError( | ||||||
|             return None, None |                 {"non_field_errors": "Recovery flow not applicable to user"} | ||||||
|  |             ) from None | ||||||
|         token, __ = FlowToken.objects.update_or_create( |         token, __ = FlowToken.objects.update_or_create( | ||||||
|             identifier=f"{user.uid}-password-reset", |             identifier=f"{user.uid}-password-reset", | ||||||
|             defaults={ |             defaults={ | ||||||
| @ -563,16 +563,13 @@ class UserViewSet(UsedByMixin, ModelViewSet): | |||||||
|     @extend_schema( |     @extend_schema( | ||||||
|         responses={ |         responses={ | ||||||
|             "200": LinkSerializer(many=False), |             "200": LinkSerializer(many=False), | ||||||
|             "404": LinkSerializer(many=False), |  | ||||||
|         }, |         }, | ||||||
|  |         request=None, | ||||||
|     ) |     ) | ||||||
|     @action(detail=True, pagination_class=None, filter_backends=[]) |     @action(detail=True, pagination_class=None, filter_backends=[], methods=["POST"]) | ||||||
|     def recovery(self, request: Request, pk: int) -> Response: |     def recovery(self, request: Request, pk: int) -> Response: | ||||||
|         """Create a temporary link that a user can use to recover their accounts""" |         """Create a temporary link that a user can use to recover their accounts""" | ||||||
|         link, _ = self._create_recovery_link() |         link, _ = self._create_recovery_link() | ||||||
|         if not link: |  | ||||||
|             LOGGER.debug("Couldn't create token") |  | ||||||
|             return Response({"link": ""}, status=404) |  | ||||||
|         return Response({"link": link}) |         return Response({"link": link}) | ||||||
|  |  | ||||||
|     @permission_required("authentik_core.reset_user_password") |     @permission_required("authentik_core.reset_user_password") | ||||||
| @ -587,27 +584,24 @@ class UserViewSet(UsedByMixin, ModelViewSet): | |||||||
|         ], |         ], | ||||||
|         responses={ |         responses={ | ||||||
|             "204": OpenApiResponse(description="Successfully sent recover email"), |             "204": OpenApiResponse(description="Successfully sent recover email"), | ||||||
|             "404": OpenApiResponse(description="Bad request"), |  | ||||||
|         }, |         }, | ||||||
|  |         request=None, | ||||||
|     ) |     ) | ||||||
|     @action(detail=True, pagination_class=None, filter_backends=[]) |     @action(detail=True, pagination_class=None, filter_backends=[], methods=["POST"]) | ||||||
|     def recovery_email(self, request: Request, pk: int) -> Response: |     def recovery_email(self, request: Request, pk: int) -> Response: | ||||||
|         """Create a temporary link that a user can use to recover their accounts""" |         """Create a temporary link that a user can use to recover their accounts""" | ||||||
|         for_user: User = self.get_object() |         for_user: User = self.get_object() | ||||||
|         if for_user.email == "": |         if for_user.email == "": | ||||||
|             LOGGER.debug("User doesn't have an email address") |             LOGGER.debug("User doesn't have an email address") | ||||||
|             return Response(status=404) |             raise ValidationError({"non_field_errors": "User does not have an email address set."}) | ||||||
|         link, token = self._create_recovery_link() |         link, token = self._create_recovery_link() | ||||||
|         if not link: |  | ||||||
|             LOGGER.debug("Couldn't create token") |  | ||||||
|             return Response(status=404) |  | ||||||
|         # Lookup the email stage to assure the current user can access it |         # Lookup the email stage to assure the current user can access it | ||||||
|         stages = get_objects_for_user( |         stages = get_objects_for_user( | ||||||
|             request.user, "authentik_stages_email.view_emailstage" |             request.user, "authentik_stages_email.view_emailstage" | ||||||
|         ).filter(pk=request.query_params.get("email_stage")) |         ).filter(pk=request.query_params.get("email_stage")) | ||||||
|         if not stages.exists(): |         if not stages.exists(): | ||||||
|             LOGGER.debug("Email stage does not exist/user has no permissions") |             LOGGER.debug("Email stage does not exist/user has no permissions") | ||||||
|             return Response(status=404) |             raise ValidationError({"non_field_errors": "Email stage does not exist."}) | ||||||
|         email_stage: EmailStage = stages.first() |         email_stage: EmailStage = stages.first() | ||||||
|         message = TemplateEmailMessage( |         message = TemplateEmailMessage( | ||||||
|             subject=_(email_stage.subject), |             subject=_(email_stage.subject), | ||||||
|  | |||||||
| @ -14,14 +14,16 @@ class AuthentikCoreConfig(ManagedAppConfig): | |||||||
|     mountpoint = "" |     mountpoint = "" | ||||||
|     default = True |     default = True | ||||||
|  |  | ||||||
|     def reconcile_global_debug_worker_hook(self): |     @ManagedAppConfig.reconcile_global | ||||||
|  |     def debug_worker_hook(self): | ||||||
|         """Dispatch startup tasks inline when debugging""" |         """Dispatch startup tasks inline when debugging""" | ||||||
|         if settings.DEBUG: |         if settings.DEBUG: | ||||||
|             from authentik.root.celery import worker_ready_hook |             from authentik.root.celery import worker_ready_hook | ||||||
|  |  | ||||||
|             worker_ready_hook() |             worker_ready_hook() | ||||||
|  |  | ||||||
|     def reconcile_tenant_source_inbuilt(self): |     @ManagedAppConfig.reconcile_tenant | ||||||
|  |     def source_inbuilt(self): | ||||||
|         """Reconcile inbuilt source""" |         """Reconcile inbuilt source""" | ||||||
|         from authentik.core.models import Source |         from authentik.core.models import Source | ||||||
|  |  | ||||||
|  | |||||||
| @ -1,6 +1,6 @@ | |||||||
| """Authenticate with tokens""" | """Authenticate with tokens""" | ||||||
|  |  | ||||||
| from typing import Any, Optional | from typing import Any | ||||||
|  |  | ||||||
| from django.contrib.auth.backends import ModelBackend | from django.contrib.auth.backends import ModelBackend | ||||||
| from django.http.request import HttpRequest | from django.http.request import HttpRequest | ||||||
| @ -16,15 +16,15 @@ class InbuiltBackend(ModelBackend): | |||||||
|     """Inbuilt backend""" |     """Inbuilt backend""" | ||||||
|  |  | ||||||
|     def authenticate( |     def authenticate( | ||||||
|         self, request: HttpRequest, username: Optional[str], password: Optional[str], **kwargs: Any |         self, request: HttpRequest, username: str | None, password: str | None, **kwargs: Any | ||||||
|     ) -> Optional[User]: |     ) -> User | None: | ||||||
|         user = super().authenticate(request, username=username, password=password, **kwargs) |         user = super().authenticate(request, username=username, password=password, **kwargs) | ||||||
|         if not user: |         if not user: | ||||||
|             return None |             return None | ||||||
|         self.set_method("password", request) |         self.set_method("password", request) | ||||||
|         return user |         return user | ||||||
|  |  | ||||||
|     def set_method(self, method: str, request: Optional[HttpRequest], **kwargs): |     def set_method(self, method: str, request: HttpRequest | None, **kwargs): | ||||||
|         """Set method data on current flow, if possbiel""" |         """Set method data on current flow, if possbiel""" | ||||||
|         if not request: |         if not request: | ||||||
|             return |             return | ||||||
| @ -40,18 +40,18 @@ class TokenBackend(InbuiltBackend): | |||||||
|     """Authenticate with token""" |     """Authenticate with token""" | ||||||
|  |  | ||||||
|     def authenticate( |     def authenticate( | ||||||
|         self, request: HttpRequest, username: Optional[str], password: Optional[str], **kwargs: Any |         self, request: HttpRequest, username: str | None, password: str | None, **kwargs: Any | ||||||
|     ) -> Optional[User]: |     ) -> User | None: | ||||||
|         try: |         try: | ||||||
|             # pylint: disable=no-member |  | ||||||
|             user = User._default_manager.get_by_natural_key(username) |             user = User._default_manager.get_by_natural_key(username) | ||||||
|         # pylint: disable=no-member |  | ||||||
|         except User.DoesNotExist: |         except User.DoesNotExist: | ||||||
|             # Run the default password hasher once to reduce the timing |             # Run the default password hasher once to reduce the timing | ||||||
|             # difference between an existing and a nonexistent user (#20760). |             # difference between an existing and a nonexistent user (#20760). | ||||||
|             User().set_password(password) |             User().set_password(password) | ||||||
|             return None |             return None | ||||||
|         # pylint: disable=no-member |  | ||||||
|         tokens = Token.filter_not_expired( |         tokens = Token.filter_not_expired( | ||||||
|             user=user, key=password, intent=TokenIntents.INTENT_APP_PASSWORD |             user=user, key=password, intent=TokenIntents.INTENT_APP_PASSWORD | ||||||
|         ) |         ) | ||||||
|  | |||||||
| @ -38,6 +38,6 @@ class TokenOutpostMiddleware: | |||||||
|                 raise DenyConnection() |                 raise DenyConnection() | ||||||
|         except AuthenticationFailed as exc: |         except AuthenticationFailed as exc: | ||||||
|             LOGGER.warning("Failed to authenticate", exc=exc) |             LOGGER.warning("Failed to authenticate", exc=exc) | ||||||
|             raise DenyConnection() |             raise DenyConnection() from None | ||||||
|  |  | ||||||
|         scope["user"] = user |         scope["user"] = user | ||||||
|  | |||||||
| @ -1,6 +1,6 @@ | |||||||
| """Property Mapping Evaluator""" | """Property Mapping Evaluator""" | ||||||
|  |  | ||||||
| from typing import Any, Optional | from typing import Any | ||||||
|  |  | ||||||
| from django.db.models import Model | from django.db.models import Model | ||||||
| from django.http import HttpRequest | from django.http import HttpRequest | ||||||
| @ -27,9 +27,9 @@ class PropertyMappingEvaluator(BaseEvaluator): | |||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
|         model: Model, |         model: Model, | ||||||
|         user: Optional[User] = None, |         user: User | None = None, | ||||||
|         request: Optional[HttpRequest] = None, |         request: HttpRequest | None = None, | ||||||
|         dry_run: Optional[bool] = False, |         dry_run: bool | None = False, | ||||||
|         **kwargs, |         **kwargs, | ||||||
|     ): |     ): | ||||||
|         if hasattr(model, "name"): |         if hasattr(model, "name"): | ||||||
|  | |||||||
| @ -16,13 +16,8 @@ from authentik.events.middleware import should_log_model | |||||||
| from authentik.events.models import Event, EventAction | from authentik.events.models import Event, EventAction | ||||||
| from authentik.events.utils import model_to_dict | from authentik.events.utils import model_to_dict | ||||||
|  |  | ||||||
| BANNER_TEXT = """### authentik shell ({authentik}) | BANNER_TEXT = f"""### authentik shell ({get_full_version()}) | ||||||
| ### Node {node} | Arch {arch} | Python {python} """.format( | ### Node {platform.node()} | Arch {platform.machine()} | Python {platform.python_version()} """ | ||||||
|     node=platform.node(), |  | ||||||
|     python=platform.python_version(), |  | ||||||
|     arch=platform.machine(), |  | ||||||
|     authentik=get_full_version(), |  | ||||||
| ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Command(BaseCommand): | class Command(BaseCommand): | ||||||
| @ -86,7 +81,7 @@ class Command(BaseCommand): | |||||||
|  |  | ||||||
|         # If Python code has been passed, execute it and exit. |         # If Python code has been passed, execute it and exit. | ||||||
|         if options["command"]: |         if options["command"]: | ||||||
|             # pylint: disable=exec-used |  | ||||||
|             exec(options["command"], namespace)  # nosec # noqa |             exec(options["command"], namespace)  # nosec # noqa | ||||||
|             return |             return | ||||||
|  |  | ||||||
| @ -99,7 +94,7 @@ class Command(BaseCommand): | |||||||
|         else: |         else: | ||||||
|             try: |             try: | ||||||
|                 hook() |                 hook() | ||||||
|             except Exception:  # pylint: disable=broad-except |             except Exception: | ||||||
|                 # Match the behavior of the cpython shell where an error in |                 # Match the behavior of the cpython shell where an error in | ||||||
|                 # sys.__interactivehook__ prints a warning and the exception |                 # sys.__interactivehook__ prints a warning and the exception | ||||||
|                 # and continues. |                 # and continues. | ||||||
|  | |||||||
| @ -1,7 +1,7 @@ | |||||||
| """authentik admin Middleware to impersonate users""" | """authentik admin Middleware to impersonate users""" | ||||||
|  |  | ||||||
|  | from collections.abc import Callable | ||||||
| from contextvars import ContextVar | from contextvars import ContextVar | ||||||
| from typing import Callable, Optional |  | ||||||
| from uuid import uuid4 | from uuid import uuid4 | ||||||
|  |  | ||||||
| from django.http import HttpRequest, HttpResponse | from django.http import HttpRequest, HttpResponse | ||||||
| @ -15,9 +15,9 @@ RESPONSE_HEADER_ID = "X-authentik-id" | |||||||
| KEY_AUTH_VIA = "auth_via" | KEY_AUTH_VIA = "auth_via" | ||||||
| KEY_USER = "user" | KEY_USER = "user" | ||||||
|  |  | ||||||
| CTX_REQUEST_ID = ContextVar[Optional[str]](STRUCTLOG_KEY_PREFIX + "request_id", default=None) | CTX_REQUEST_ID = ContextVar[str | None](STRUCTLOG_KEY_PREFIX + "request_id", default=None) | ||||||
| CTX_HOST = ContextVar[Optional[str]](STRUCTLOG_KEY_PREFIX + "host", default=None) | CTX_HOST = ContextVar[str | None](STRUCTLOG_KEY_PREFIX + "host", default=None) | ||||||
| CTX_AUTH_VIA = ContextVar[Optional[str]](STRUCTLOG_KEY_PREFIX + KEY_AUTH_VIA, default=None) | CTX_AUTH_VIA = ContextVar[str | None](STRUCTLOG_KEY_PREFIX + KEY_AUTH_VIA, default=None) | ||||||
|  |  | ||||||
|  |  | ||||||
| class ImpersonateMiddleware: | class ImpersonateMiddleware: | ||||||
| @ -55,7 +55,7 @@ class RequestIDMiddleware: | |||||||
|     def __call__(self, request: HttpRequest) -> HttpResponse: |     def __call__(self, request: HttpRequest) -> HttpResponse: | ||||||
|         if not hasattr(request, "request_id"): |         if not hasattr(request, "request_id"): | ||||||
|             request_id = uuid4().hex |             request_id = uuid4().hex | ||||||
|             setattr(request, "request_id", request_id) |             request.request_id = request_id | ||||||
|             CTX_REQUEST_ID.set(request_id) |             CTX_REQUEST_ID.set(request_id) | ||||||
|             CTX_HOST.set(request.get_host()) |             CTX_HOST.set(request.get_host()) | ||||||
|             set_tag("authentik.request_id", request_id) |             set_tag("authentik.request_id", request_id) | ||||||
| @ -67,7 +67,7 @@ class RequestIDMiddleware: | |||||||
|         response = self.get_response(request) |         response = self.get_response(request) | ||||||
|  |  | ||||||
|         response[RESPONSE_HEADER_ID] = request.request_id |         response[RESPONSE_HEADER_ID] = request.request_id | ||||||
|         setattr(response, "ak_context", {}) |         response.ak_context = {} | ||||||
|         response.ak_context["request_id"] = CTX_REQUEST_ID.get() |         response.ak_context["request_id"] = CTX_REQUEST_ID.get() | ||||||
|         response.ak_context["host"] = CTX_HOST.get() |         response.ak_context["host"] = CTX_HOST.get() | ||||||
|         response.ak_context[KEY_AUTH_VIA] = CTX_AUTH_VIA.get() |         response.ak_context[KEY_AUTH_VIA] = CTX_AUTH_VIA.get() | ||||||
|  | |||||||
| @ -222,7 +222,7 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser): | |||||||
|         there are at most 3 queries done""" |         there are at most 3 queries done""" | ||||||
|         return Group.children_recursive(self.ak_groups.all()) |         return Group.children_recursive(self.ak_groups.all()) | ||||||
|  |  | ||||||
|     def group_attributes(self, request: Optional[HttpRequest] = None) -> dict[str, Any]: |     def group_attributes(self, request: HttpRequest | None = None) -> dict[str, Any]: | ||||||
|         """Get a dictionary containing the attributes from all groups the user belongs to, |         """Get a dictionary containing the attributes from all groups the user belongs to, | ||||||
|         including the users attributes""" |         including the users attributes""" | ||||||
|         final_attributes = {} |         final_attributes = {} | ||||||
| @ -278,11 +278,11 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser): | |||||||
|         """Generate a globally unique UID, based on the user ID and the hashed secret key""" |         """Generate a globally unique UID, based on the user ID and the hashed secret key""" | ||||||
|         return sha256(f"{self.id}-{get_install_id()}".encode("ascii")).hexdigest() |         return sha256(f"{self.id}-{get_install_id()}".encode("ascii")).hexdigest() | ||||||
|  |  | ||||||
|     def locale(self, request: Optional[HttpRequest] = None) -> str: |     def locale(self, request: HttpRequest | None = None) -> str: | ||||||
|         """Get the locale the user has configured""" |         """Get the locale the user has configured""" | ||||||
|         try: |         try: | ||||||
|             return self.attributes.get("settings", {}).get("locale", "") |             return self.attributes.get("settings", {}).get("locale", "") | ||||||
|         # pylint: disable=broad-except |  | ||||||
|         except Exception as exc: |         except Exception as exc: | ||||||
|             LOGGER.warning("Failed to get default locale", exc=exc) |             LOGGER.warning("Failed to get default locale", exc=exc) | ||||||
|         if request: |         if request: | ||||||
| @ -358,7 +358,7 @@ class Provider(SerializerModel): | |||||||
|     objects = InheritanceManager() |     objects = InheritanceManager() | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def launch_url(self) -> Optional[str]: |     def launch_url(self) -> str | None: | ||||||
|         """URL to this provider and initiate authorization for the user. |         """URL to this provider and initiate authorization for the user. | ||||||
|         Can return None for providers that are not URL-based""" |         Can return None for providers that are not URL-based""" | ||||||
|         return None |         return None | ||||||
| @ -435,7 +435,7 @@ class Application(SerializerModel, PolicyBindingModel): | |||||||
|         return ApplicationSerializer |         return ApplicationSerializer | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def get_meta_icon(self) -> Optional[str]: |     def get_meta_icon(self) -> str | None: | ||||||
|         """Get the URL to the App Icon image. If the name is /static or starts with http |         """Get the URL to the App Icon image. If the name is /static or starts with http | ||||||
|         it is returned as-is""" |         it is returned as-is""" | ||||||
|         if not self.meta_icon: |         if not self.meta_icon: | ||||||
| @ -444,7 +444,7 @@ class Application(SerializerModel, PolicyBindingModel): | |||||||
|             return self.meta_icon.name |             return self.meta_icon.name | ||||||
|         return self.meta_icon.url |         return self.meta_icon.url | ||||||
|  |  | ||||||
|     def get_launch_url(self, user: Optional["User"] = None) -> Optional[str]: |     def get_launch_url(self, user: Optional["User"] = None) -> str | None: | ||||||
|         """Get launch URL if set, otherwise attempt to get launch URL based on provider.""" |         """Get launch URL if set, otherwise attempt to get launch URL based on provider.""" | ||||||
|         url = None |         url = None | ||||||
|         if self.meta_launch_url: |         if self.meta_launch_url: | ||||||
| @ -457,13 +457,13 @@ class Application(SerializerModel, PolicyBindingModel): | |||||||
|                 user = user._wrapped |                 user = user._wrapped | ||||||
|             try: |             try: | ||||||
|                 return url % user.__dict__ |                 return url % user.__dict__ | ||||||
|             # pylint: disable=broad-except |  | ||||||
|             except Exception as exc: |             except Exception as exc: | ||||||
|                 LOGGER.warning("Failed to format launch url", exc=exc) |                 LOGGER.warning("Failed to format launch url", exc=exc) | ||||||
|                 return url |                 return url | ||||||
|         return url |         return url | ||||||
|  |  | ||||||
|     def get_provider(self) -> Optional[Provider]: |     def get_provider(self) -> Provider | None: | ||||||
|         """Get casted provider instance""" |         """Get casted provider instance""" | ||||||
|         if not self.provider: |         if not self.provider: | ||||||
|             return None |             return None | ||||||
| @ -551,7 +551,7 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel): | |||||||
|     objects = InheritanceManager() |     objects = InheritanceManager() | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def icon_url(self) -> Optional[str]: |     def icon_url(self) -> str | None: | ||||||
|         """Get the URL to the Icon. If the name is /static or |         """Get the URL to the Icon. If the name is /static or | ||||||
|         starts with http it is returned as-is""" |         starts with http it is returned as-is""" | ||||||
|         if not self.icon: |         if not self.icon: | ||||||
| @ -566,7 +566,7 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel): | |||||||
|             return self.user_path_template % { |             return self.user_path_template % { | ||||||
|                 "slug": self.slug, |                 "slug": self.slug, | ||||||
|             } |             } | ||||||
|         # pylint: disable=broad-except |  | ||||||
|         except Exception as exc: |         except Exception as exc: | ||||||
|             LOGGER.warning("Failed to template user path", exc=exc, source=self) |             LOGGER.warning("Failed to template user path", exc=exc, source=self) | ||||||
|             return User.default_path() |             return User.default_path() | ||||||
| @ -576,12 +576,12 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel): | |||||||
|         """Return component used to edit this object""" |         """Return component used to edit this object""" | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     def ui_login_button(self, request: HttpRequest) -> Optional[UILoginButton]: |     def ui_login_button(self, request: HttpRequest) -> UILoginButton | None: | ||||||
|         """If source uses a http-based flow, return UI Information about the login |         """If source uses a http-based flow, return UI Information about the login | ||||||
|         button. If source doesn't use http-based flow, return None.""" |         button. If source doesn't use http-based flow, return None.""" | ||||||
|         return None |         return None | ||||||
|  |  | ||||||
|     def ui_user_settings(self) -> Optional[UserSettingSerializer]: |     def ui_user_settings(self) -> UserSettingSerializer | None: | ||||||
|         """Entrypoint to integrate with User settings. Can either return None if no |         """Entrypoint to integrate with User settings. Can either return None if no | ||||||
|         user settings are available, or UserSettingSerializer.""" |         user settings are available, or UserSettingSerializer.""" | ||||||
|         return None |         return None | ||||||
| @ -627,6 +627,9 @@ class ExpiringModel(models.Model): | |||||||
|     expires = models.DateTimeField(default=default_token_duration) |     expires = models.DateTimeField(default=default_token_duration) | ||||||
|     expiring = models.BooleanField(default=True) |     expiring = models.BooleanField(default=True) | ||||||
|  |  | ||||||
|  |     class Meta: | ||||||
|  |         abstract = True | ||||||
|  |  | ||||||
|     def expire_action(self, *args, **kwargs): |     def expire_action(self, *args, **kwargs): | ||||||
|         """Handler which is called when this object is expired. By |         """Handler which is called when this object is expired. By | ||||||
|         default the object is deleted. This is less efficient compared |         default the object is deleted. This is less efficient compared | ||||||
| @ -649,9 +652,6 @@ class ExpiringModel(models.Model): | |||||||
|             return False |             return False | ||||||
|         return now() > self.expires |         return now() > self.expires | ||||||
|  |  | ||||||
|     class Meta: |  | ||||||
|         abstract = True |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class TokenIntents(models.TextChoices): | class TokenIntents(models.TextChoices): | ||||||
|     """Intents a Token can be created for.""" |     """Intents a Token can be created for.""" | ||||||
| @ -681,6 +681,21 @@ class Token(SerializerModel, ManagedModel, ExpiringModel): | |||||||
|     user = models.ForeignKey("User", on_delete=models.CASCADE, related_name="+") |     user = models.ForeignKey("User", on_delete=models.CASCADE, related_name="+") | ||||||
|     description = models.TextField(default="", blank=True) |     description = models.TextField(default="", blank=True) | ||||||
|  |  | ||||||
|  |     class Meta: | ||||||
|  |         verbose_name = _("Token") | ||||||
|  |         verbose_name_plural = _("Tokens") | ||||||
|  |         indexes = [ | ||||||
|  |             models.Index(fields=["identifier"]), | ||||||
|  |             models.Index(fields=["key"]), | ||||||
|  |         ] | ||||||
|  |         permissions = [("view_token_key", _("View token's key"))] | ||||||
|  |  | ||||||
|  |     def __str__(self): | ||||||
|  |         description = f"{self.identifier}" | ||||||
|  |         if self.expiring: | ||||||
|  |             description += f" (expires={self.expires})" | ||||||
|  |         return description | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def serializer(self) -> type[Serializer]: |     def serializer(self) -> type[Serializer]: | ||||||
|         from authentik.core.api.tokens import TokenSerializer |         from authentik.core.api.tokens import TokenSerializer | ||||||
| @ -708,21 +723,6 @@ class Token(SerializerModel, ManagedModel, ExpiringModel): | |||||||
|             message=f"Token {self.identifier}'s secret was rotated.", |             message=f"Token {self.identifier}'s secret was rotated.", | ||||||
|         ).save() |         ).save() | ||||||
|  |  | ||||||
|     def __str__(self): |  | ||||||
|         description = f"{self.identifier}" |  | ||||||
|         if self.expiring: |  | ||||||
|             description += f" (expires={self.expires})" |  | ||||||
|         return description |  | ||||||
|  |  | ||||||
|     class Meta: |  | ||||||
|         verbose_name = _("Token") |  | ||||||
|         verbose_name_plural = _("Tokens") |  | ||||||
|         indexes = [ |  | ||||||
|             models.Index(fields=["identifier"]), |  | ||||||
|             models.Index(fields=["key"]), |  | ||||||
|         ] |  | ||||||
|         permissions = [("view_token_key", _("View token's key"))] |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class PropertyMapping(SerializerModel, ManagedModel): | class PropertyMapping(SerializerModel, ManagedModel): | ||||||
|     """User-defined key -> x mapping which can be used by providers to expose extra data.""" |     """User-defined key -> x mapping which can be used by providers to expose extra data.""" | ||||||
| @ -743,7 +743,7 @@ class PropertyMapping(SerializerModel, ManagedModel): | |||||||
|         """Get serializer for this model""" |         """Get serializer for this model""" | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     def evaluate(self, user: Optional[User], request: Optional[HttpRequest], **kwargs) -> Any: |     def evaluate(self, user: User | None, request: HttpRequest | None, **kwargs) -> Any: | ||||||
|         """Evaluate `self.expression` using `**kwargs` as Context.""" |         """Evaluate `self.expression` using `**kwargs` as Context.""" | ||||||
|         from authentik.core.expression.evaluator import PropertyMappingEvaluator |         from authentik.core.expression.evaluator import PropertyMappingEvaluator | ||||||
|  |  | ||||||
| @ -779,6 +779,13 @@ class AuthenticatedSession(ExpiringModel): | |||||||
|     last_user_agent = models.TextField(blank=True) |     last_user_agent = models.TextField(blank=True) | ||||||
|     last_used = models.DateTimeField(auto_now=True) |     last_used = models.DateTimeField(auto_now=True) | ||||||
|  |  | ||||||
|  |     class Meta: | ||||||
|  |         verbose_name = _("Authenticated Session") | ||||||
|  |         verbose_name_plural = _("Authenticated Sessions") | ||||||
|  |  | ||||||
|  |     def __str__(self) -> str: | ||||||
|  |         return f"Authenticated Session {self.session_key[:10]}" | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def from_request(request: HttpRequest, user: User) -> Optional["AuthenticatedSession"]: |     def from_request(request: HttpRequest, user: User) -> Optional["AuthenticatedSession"]: | ||||||
|         """Create a new session from a http request""" |         """Create a new session from a http request""" | ||||||
| @ -793,7 +800,3 @@ class AuthenticatedSession(ExpiringModel): | |||||||
|             last_user_agent=request.META.get("HTTP_USER_AGENT", ""), |             last_user_agent=request.META.get("HTTP_USER_AGENT", ""), | ||||||
|             expires=request.session.get_expiry_date(), |             expires=request.session.get_expiry_date(), | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     class Meta: |  | ||||||
|         verbose_name = _("Authenticated Session") |  | ||||||
|         verbose_name_plural = _("Authenticated Sessions") |  | ||||||
|  | |||||||
| @ -1,7 +1,7 @@ | |||||||
| """Source decision helper""" | """Source decision helper""" | ||||||
|  |  | ||||||
| from enum import Enum | from enum import Enum | ||||||
| from typing import Any, Optional | from typing import Any | ||||||
|  |  | ||||||
| from django.contrib import messages | from django.contrib import messages | ||||||
| from django.db import IntegrityError | from django.db import IntegrityError | ||||||
| @ -90,15 +90,14 @@ class SourceFlowManager: | |||||||
|         self._logger = get_logger().bind(source=source, identifier=identifier) |         self._logger = get_logger().bind(source=source, identifier=identifier) | ||||||
|         self.policy_context = {} |         self.policy_context = {} | ||||||
|  |  | ||||||
|     # pylint: disable=too-many-return-statements |     def get_action(self, **kwargs) -> tuple[Action, UserSourceConnection | None]:  # noqa: PLR0911 | ||||||
|     def get_action(self, **kwargs) -> tuple[Action, Optional[UserSourceConnection]]: |  | ||||||
|         """decide which action should be taken""" |         """decide which action should be taken""" | ||||||
|         new_connection = self.connection_type(source=self.source, identifier=self.identifier) |         new_connection = self.connection_type(source=self.source, identifier=self.identifier) | ||||||
|         # When request is authenticated, always link |         # When request is authenticated, always link | ||||||
|         if self.request.user.is_authenticated: |         if self.request.user.is_authenticated: | ||||||
|             new_connection.user = self.request.user |             new_connection.user = self.request.user | ||||||
|             new_connection = self.update_connection(new_connection, **kwargs) |             new_connection = self.update_connection(new_connection, **kwargs) | ||||||
|             # pylint: disable=no-member |  | ||||||
|             new_connection.save() |             new_connection.save() | ||||||
|             return Action.LINK, new_connection |             return Action.LINK, new_connection | ||||||
|  |  | ||||||
| @ -188,8 +187,10 @@ class SourceFlowManager: | |||||||
|         # Default case, assume deny |         # Default case, assume deny | ||||||
|         error = Exception( |         error = Exception( | ||||||
|             _( |             _( | ||||||
|                 "Request to authenticate with %(source)s has been denied. Please authenticate " |                 "Request to authenticate with {source} has been denied. Please authenticate " | ||||||
|                 "with the source you've previously signed up with." % {"source": self.source.name} |                 "with the source you've previously signed up with.".format_map( | ||||||
|  |                     {"source": self.source.name} | ||||||
|  |                 ) | ||||||
|             ), |             ), | ||||||
|         ) |         ) | ||||||
|         return self.error_handler(error) |         return self.error_handler(error) | ||||||
| @ -217,7 +218,7 @@ class SourceFlowManager: | |||||||
|         self, |         self, | ||||||
|         flow: Flow, |         flow: Flow, | ||||||
|         connection: UserSourceConnection, |         connection: UserSourceConnection, | ||||||
|         stages: Optional[list[StageView]] = None, |         stages: list[StageView] | None = None, | ||||||
|         **kwargs, |         **kwargs, | ||||||
|     ) -> HttpResponse: |     ) -> HttpResponse: | ||||||
|         """Prepare Authentication Plan, redirect user FlowExecutor""" |         """Prepare Authentication Plan, redirect user FlowExecutor""" | ||||||
| @ -270,7 +271,9 @@ class SourceFlowManager: | |||||||
|                 in_memory_stage( |                 in_memory_stage( | ||||||
|                     MessageStage, |                     MessageStage, | ||||||
|                     message=_( |                     message=_( | ||||||
|                         "Successfully authenticated with %(source)s!" % {"source": self.source.name} |                         "Successfully authenticated with {source}!".format_map( | ||||||
|  |                             {"source": self.source.name} | ||||||
|  |                         ) | ||||||
|                     ), |                     ), | ||||||
|                 ) |                 ) | ||||||
|             ], |             ], | ||||||
| @ -294,7 +297,7 @@ class SourceFlowManager: | |||||||
|         ).from_http(self.request) |         ).from_http(self.request) | ||||||
|         messages.success( |         messages.success( | ||||||
|             self.request, |             self.request, | ||||||
|             _("Successfully linked %(source)s!" % {"source": self.source.name}), |             _("Successfully linked {source}!".format_map({"source": self.source.name})), | ||||||
|         ) |         ) | ||||||
|         return redirect( |         return redirect( | ||||||
|             reverse( |             reverse( | ||||||
| @ -322,7 +325,9 @@ class SourceFlowManager: | |||||||
|                 in_memory_stage( |                 in_memory_stage( | ||||||
|                     MessageStage, |                     MessageStage, | ||||||
|                     message=_( |                     message=_( | ||||||
|                         "Successfully authenticated with %(source)s!" % {"source": self.source.name} |                         "Successfully authenticated with {source}!".format_map( | ||||||
|  |                             {"source": self.source.name} | ||||||
|  |                         ) | ||||||
|                     ), |                     ), | ||||||
|                 ) |                 ) | ||||||
|             ], |             ], | ||||||
|  | |||||||
| @ -37,20 +37,20 @@ def clean_expired_models(self: SystemTask): | |||||||
|         messages.append(f"Expired {amount} {cls._meta.verbose_name_plural}") |         messages.append(f"Expired {amount} {cls._meta.verbose_name_plural}") | ||||||
|     # Special case |     # Special case | ||||||
|     amount = 0 |     amount = 0 | ||||||
|     # pylint: disable=no-member |  | ||||||
|     for session in AuthenticatedSession.objects.all(): |     for session in AuthenticatedSession.objects.all(): | ||||||
|         cache_key = f"{KEY_PREFIX}{session.session_key}" |         cache_key = f"{KEY_PREFIX}{session.session_key}" | ||||||
|         value = None |         value = None | ||||||
|         try: |         try: | ||||||
|             value = cache.get(cache_key) |             value = cache.get(cache_key) | ||||||
|         # pylint: disable=broad-except |  | ||||||
|         except Exception as exc: |         except Exception as exc: | ||||||
|             LOGGER.debug("Failed to get session from cache", exc=exc) |             LOGGER.debug("Failed to get session from cache", exc=exc) | ||||||
|         if not value: |         if not value: | ||||||
|             session.delete() |             session.delete() | ||||||
|             amount += 1 |             amount += 1 | ||||||
|     LOGGER.debug("Expired sessions", model=AuthenticatedSession, amount=amount) |     LOGGER.debug("Expired sessions", model=AuthenticatedSession, amount=amount) | ||||||
|     # pylint: disable=no-member |  | ||||||
|     messages.append(f"Expired {amount} {AuthenticatedSession._meta.verbose_name_plural}") |     messages.append(f"Expired {amount} {AuthenticatedSession._meta.verbose_name_plural}") | ||||||
|     self.set_status(TaskStatus.SUCCESSFUL, *messages) |     self.set_status(TaskStatus.SUCCESSFUL, *messages) | ||||||
|  |  | ||||||
|  | |||||||
| @ -1,7 +1,7 @@ | |||||||
| """authentik core models tests""" | """authentik core models tests""" | ||||||
|  |  | ||||||
|  | from collections.abc import Callable | ||||||
| from time import sleep | from time import sleep | ||||||
| from typing import Callable |  | ||||||
|  |  | ||||||
| from django.test import RequestFactory, TestCase | from django.test import RequestFactory, TestCase | ||||||
| from django.utils.timezone import now | from django.utils.timezone import now | ||||||
|  | |||||||
| @ -173,5 +173,5 @@ class TestSourceFlowManager(TestCase): | |||||||
|         self.assertEqual(action, Action.ENROLL) |         self.assertEqual(action, Action.ENROLL) | ||||||
|         response = flow_manager.get_flow() |         response = flow_manager.get_flow() | ||||||
|         self.assertIsInstance(response, AccessDeniedResponse) |         self.assertIsInstance(response, AccessDeniedResponse) | ||||||
|         # pylint: disable=no-member |  | ||||||
|         self.assertEqual(response.error_message, "foo") |         self.assertEqual(response.error_message, "foo") | ||||||
|  | |||||||
| @ -60,10 +60,11 @@ class TestUsersAPI(APITestCase): | |||||||
|     def test_recovery_no_flow(self): |     def test_recovery_no_flow(self): | ||||||
|         """Test user recovery link (no recovery flow set)""" |         """Test user recovery link (no recovery flow set)""" | ||||||
|         self.client.force_login(self.admin) |         self.client.force_login(self.admin) | ||||||
|         response = self.client.get( |         response = self.client.post( | ||||||
|             reverse("authentik_api:user-recovery", kwargs={"pk": self.user.pk}) |             reverse("authentik_api:user-recovery", kwargs={"pk": self.user.pk}) | ||||||
|         ) |         ) | ||||||
|         self.assertEqual(response.status_code, 404) |         self.assertEqual(response.status_code, 400) | ||||||
|  |         self.assertJSONEqual(response.content, {"non_field_errors": "No recovery flow set."}) | ||||||
|  |  | ||||||
|     def test_set_password(self): |     def test_set_password(self): | ||||||
|         """Test Direct password set""" |         """Test Direct password set""" | ||||||
| @ -84,7 +85,7 @@ class TestUsersAPI(APITestCase): | |||||||
|         brand.flow_recovery = flow |         brand.flow_recovery = flow | ||||||
|         brand.save() |         brand.save() | ||||||
|         self.client.force_login(self.admin) |         self.client.force_login(self.admin) | ||||||
|         response = self.client.get( |         response = self.client.post( | ||||||
|             reverse("authentik_api:user-recovery", kwargs={"pk": self.user.pk}) |             reverse("authentik_api:user-recovery", kwargs={"pk": self.user.pk}) | ||||||
|         ) |         ) | ||||||
|         self.assertEqual(response.status_code, 200) |         self.assertEqual(response.status_code, 200) | ||||||
| @ -92,16 +93,20 @@ class TestUsersAPI(APITestCase): | |||||||
|     def test_recovery_email_no_flow(self): |     def test_recovery_email_no_flow(self): | ||||||
|         """Test user recovery link (no recovery flow set)""" |         """Test user recovery link (no recovery flow set)""" | ||||||
|         self.client.force_login(self.admin) |         self.client.force_login(self.admin) | ||||||
|         response = self.client.get( |         response = self.client.post( | ||||||
|             reverse("authentik_api:user-recovery-email", kwargs={"pk": self.user.pk}) |             reverse("authentik_api:user-recovery-email", kwargs={"pk": self.user.pk}) | ||||||
|         ) |         ) | ||||||
|         self.assertEqual(response.status_code, 404) |         self.assertEqual(response.status_code, 400) | ||||||
|  |         self.assertJSONEqual( | ||||||
|  |             response.content, {"non_field_errors": "User does not have an email address set."} | ||||||
|  |         ) | ||||||
|         self.user.email = "foo@bar.baz" |         self.user.email = "foo@bar.baz" | ||||||
|         self.user.save() |         self.user.save() | ||||||
|         response = self.client.get( |         response = self.client.post( | ||||||
|             reverse("authentik_api:user-recovery-email", kwargs={"pk": self.user.pk}) |             reverse("authentik_api:user-recovery-email", kwargs={"pk": self.user.pk}) | ||||||
|         ) |         ) | ||||||
|         self.assertEqual(response.status_code, 404) |         self.assertEqual(response.status_code, 400) | ||||||
|  |         self.assertJSONEqual(response.content, {"non_field_errors": "No recovery flow set."}) | ||||||
|  |  | ||||||
|     def test_recovery_email_no_stage(self): |     def test_recovery_email_no_stage(self): | ||||||
|         """Test user recovery link (no email stage)""" |         """Test user recovery link (no email stage)""" | ||||||
| @ -112,10 +117,11 @@ class TestUsersAPI(APITestCase): | |||||||
|         brand.flow_recovery = flow |         brand.flow_recovery = flow | ||||||
|         brand.save() |         brand.save() | ||||||
|         self.client.force_login(self.admin) |         self.client.force_login(self.admin) | ||||||
|         response = self.client.get( |         response = self.client.post( | ||||||
|             reverse("authentik_api:user-recovery-email", kwargs={"pk": self.user.pk}) |             reverse("authentik_api:user-recovery-email", kwargs={"pk": self.user.pk}) | ||||||
|         ) |         ) | ||||||
|         self.assertEqual(response.status_code, 404) |         self.assertEqual(response.status_code, 400) | ||||||
|  |         self.assertJSONEqual(response.content, {"non_field_errors": "Email stage does not exist."}) | ||||||
|  |  | ||||||
|     def test_recovery_email(self): |     def test_recovery_email(self): | ||||||
|         """Test user recovery link""" |         """Test user recovery link""" | ||||||
| @ -129,7 +135,7 @@ class TestUsersAPI(APITestCase): | |||||||
|         stage = EmailStage.objects.create(name="email") |         stage = EmailStage.objects.create(name="email") | ||||||
|  |  | ||||||
|         self.client.force_login(self.admin) |         self.client.force_login(self.admin) | ||||||
|         response = self.client.get( |         response = self.client.post( | ||||||
|             reverse( |             reverse( | ||||||
|                 "authentik_api:user-recovery-email", |                 "authentik_api:user-recovery-email", | ||||||
|                 kwargs={"pk": self.user.pk}, |                 kwargs={"pk": self.user.pk}, | ||||||
|  | |||||||
| @ -1,7 +1,5 @@ | |||||||
| """Test Utils""" | """Test Utils""" | ||||||
|  |  | ||||||
| from typing import Optional |  | ||||||
|  |  | ||||||
| from django.utils.text import slugify | from django.utils.text import slugify | ||||||
|  |  | ||||||
| from authentik.brands.models import Brand | from authentik.brands.models import Brand | ||||||
| @ -22,7 +20,7 @@ def create_test_flow( | |||||||
|     ) |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
| def create_test_user(name: Optional[str] = None, **kwargs) -> User: | def create_test_user(name: str | None = None, **kwargs) -> User: | ||||||
|     """Generate a test user""" |     """Generate a test user""" | ||||||
|     uid = generate_id(20) if not name else name |     uid = generate_id(20) if not name else name | ||||||
|     kwargs.setdefault("email", f"{uid}@goauthentik.io") |     kwargs.setdefault("email", f"{uid}@goauthentik.io") | ||||||
| @ -36,7 +34,7 @@ def create_test_user(name: Optional[str] = None, **kwargs) -> User: | |||||||
|     return user |     return user | ||||||
|  |  | ||||||
|  |  | ||||||
| def create_test_admin_user(name: Optional[str] = None, **kwargs) -> User: | def create_test_admin_user(name: str | None = None, **kwargs) -> User: | ||||||
|     """Generate a test-admin user""" |     """Generate a test-admin user""" | ||||||
|     user = create_test_user(name, **kwargs) |     user = create_test_user(name, **kwargs) | ||||||
|     group = Group.objects.create(name=user.name or name, is_superuser=True) |     group = Group.objects.create(name=user.name or name, is_superuser=True) | ||||||
|  | |||||||
| @ -1,7 +1,6 @@ | |||||||
| """authentik core dataclasses""" | """authentik core dataclasses""" | ||||||
|  |  | ||||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||||
| from typing import Optional |  | ||||||
|  |  | ||||||
| from rest_framework.fields import CharField | from rest_framework.fields import CharField | ||||||
|  |  | ||||||
| @ -20,7 +19,7 @@ class UILoginButton: | |||||||
|     challenge: Challenge |     challenge: Challenge | ||||||
|  |  | ||||||
|     # Icon URL, used as-is |     # Icon URL, used as-is | ||||||
|     icon_url: Optional[str] = None |     icon_url: str | None = None | ||||||
|  |  | ||||||
|  |  | ||||||
| class UserSettingSerializer(PassiveSerializer): | class UserSettingSerializer(PassiveSerializer): | ||||||
|  | |||||||
| @ -57,7 +57,7 @@ class RedirectToAppLaunch(View): | |||||||
|                 }, |                 }, | ||||||
|             ) |             ) | ||||||
|         except FlowNonApplicableException: |         except FlowNonApplicableException: | ||||||
|             raise Http404 |             raise Http404 from None | ||||||
|         plan.insert_stage(in_memory_stage(RedirectToAppStage)) |         plan.insert_stage(in_memory_stage(RedirectToAppStage)) | ||||||
|         request.session[SESSION_KEY_PLAN] = plan |         request.session[SESSION_KEY_PLAN] = plan | ||||||
|         return redirect_with_qs("authentik_core:if-flow", request.GET, flow_slug=flow.slug) |         return redirect_with_qs("authentik_core:if-flow", request.GET, flow_slug=flow.slug) | ||||||
|  | |||||||
| @ -61,7 +61,6 @@ class ServerErrorView(TemplateView): | |||||||
|     response_class = ServerErrorTemplateResponse |     response_class = ServerErrorTemplateResponse | ||||||
|     template_name = "if/error.html" |     template_name = "if/error.html" | ||||||
|  |  | ||||||
|     # pylint: disable=useless-super-delegation |  | ||||||
|     def dispatch(self, *args, **kwargs):  # pragma: no cover |     def dispatch(self, *args, **kwargs):  # pragma: no cover | ||||||
|         """Little wrapper so django accepts this function""" |         """Little wrapper so django accepts this function""" | ||||||
|         return super().dispatch(*args, **kwargs) |         return super().dispatch(*args, **kwargs) | ||||||
|  | |||||||
| @ -1,7 +1,6 @@ | |||||||
| """Crypto API Views""" | """Crypto API Views""" | ||||||
|  |  | ||||||
| from datetime import datetime | from datetime import datetime | ||||||
| from typing import Optional |  | ||||||
|  |  | ||||||
| from cryptography.hazmat.backends import default_backend | from cryptography.hazmat.backends import default_backend | ||||||
| from cryptography.hazmat.primitives.serialization import load_pem_private_key | from cryptography.hazmat.primitives.serialization import load_pem_private_key | ||||||
| @ -56,25 +55,25 @@ class CertificateKeyPairSerializer(ModelSerializer): | |||||||
|             return True |             return True | ||||||
|         return str(request.query_params.get("include_details", "true")).lower() == "true" |         return str(request.query_params.get("include_details", "true")).lower() == "true" | ||||||
|  |  | ||||||
|     def get_fingerprint_sha256(self, instance: CertificateKeyPair) -> Optional[str]: |     def get_fingerprint_sha256(self, instance: CertificateKeyPair) -> str | None: | ||||||
|         "Get certificate Hash (SHA256)" |         "Get certificate Hash (SHA256)" | ||||||
|         if not self._should_include_details: |         if not self._should_include_details: | ||||||
|             return None |             return None | ||||||
|         return instance.fingerprint_sha256 |         return instance.fingerprint_sha256 | ||||||
|  |  | ||||||
|     def get_fingerprint_sha1(self, instance: CertificateKeyPair) -> Optional[str]: |     def get_fingerprint_sha1(self, instance: CertificateKeyPair) -> str | None: | ||||||
|         "Get certificate Hash (SHA1)" |         "Get certificate Hash (SHA1)" | ||||||
|         if not self._should_include_details: |         if not self._should_include_details: | ||||||
|             return None |             return None | ||||||
|         return instance.fingerprint_sha1 |         return instance.fingerprint_sha1 | ||||||
|  |  | ||||||
|     def get_cert_expiry(self, instance: CertificateKeyPair) -> Optional[datetime]: |     def get_cert_expiry(self, instance: CertificateKeyPair) -> datetime | None: | ||||||
|         "Get certificate expiry" |         "Get certificate expiry" | ||||||
|         if not self._should_include_details: |         if not self._should_include_details: | ||||||
|             return None |             return None | ||||||
|         return DateTimeField().to_representation(instance.certificate.not_valid_after) |         return DateTimeField().to_representation(instance.certificate.not_valid_after) | ||||||
|  |  | ||||||
|     def get_cert_subject(self, instance: CertificateKeyPair) -> Optional[str]: |     def get_cert_subject(self, instance: CertificateKeyPair) -> str | None: | ||||||
|         """Get certificate subject as full rfc4514""" |         """Get certificate subject as full rfc4514""" | ||||||
|         if not self._should_include_details: |         if not self._should_include_details: | ||||||
|             return None |             return None | ||||||
| @ -84,7 +83,7 @@ class CertificateKeyPairSerializer(ModelSerializer): | |||||||
|         """Show if this keypair has a private key configured or not""" |         """Show if this keypair has a private key configured or not""" | ||||||
|         return instance.key_data != "" and instance.key_data is not None |         return instance.key_data != "" and instance.key_data is not None | ||||||
|  |  | ||||||
|     def get_private_key_type(self, instance: CertificateKeyPair) -> Optional[str]: |     def get_private_key_type(self, instance: CertificateKeyPair) -> str | None: | ||||||
|         """Get the private key's type, if set""" |         """Get the private key's type, if set""" | ||||||
|         if not self._should_include_details: |         if not self._should_include_details: | ||||||
|             return None |             return None | ||||||
| @ -121,7 +120,7 @@ class CertificateKeyPairSerializer(ModelSerializer): | |||||||
|             str(load_pem_x509_certificate(value.encode("utf-8"), default_backend())) |             str(load_pem_x509_certificate(value.encode("utf-8"), default_backend())) | ||||||
|         except ValueError as exc: |         except ValueError as exc: | ||||||
|             LOGGER.warning("Failed to load certificate", exc=exc) |             LOGGER.warning("Failed to load certificate", exc=exc) | ||||||
|             raise ValidationError("Unable to load certificate.") |             raise ValidationError("Unable to load certificate.") from None | ||||||
|         return value |         return value | ||||||
|  |  | ||||||
|     def validate_key_data(self, value: str) -> str: |     def validate_key_data(self, value: str) -> str: | ||||||
| @ -140,7 +139,7 @@ class CertificateKeyPairSerializer(ModelSerializer): | |||||||
|                 ) |                 ) | ||||||
|             except (ValueError, TypeError) as exc: |             except (ValueError, TypeError) as exc: | ||||||
|                 LOGGER.warning("Failed to load private key", exc=exc) |                 LOGGER.warning("Failed to load private key", exc=exc) | ||||||
|                 raise ValidationError("Unable to load private key (possibly encrypted?).") |                 raise ValidationError("Unable to load private key (possibly encrypted?).") from None | ||||||
|         return value |         return value | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|  | |||||||
| @ -1,7 +1,6 @@ | |||||||
| """authentik crypto app config""" | """authentik crypto app config""" | ||||||
|  |  | ||||||
| from datetime import datetime, timezone | from datetime import UTC, datetime | ||||||
| from typing import Optional |  | ||||||
|  |  | ||||||
| from authentik.blueprints.apps import ManagedAppConfig | from authentik.blueprints.apps import ManagedAppConfig | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| @ -36,20 +35,22 @@ class AuthentikCryptoConfig(ManagedAppConfig): | |||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def reconcile_tenant_managed_jwt_cert(self): |     @ManagedAppConfig.reconcile_tenant | ||||||
|  |     def managed_jwt_cert(self): | ||||||
|         """Ensure managed JWT certificate""" |         """Ensure managed JWT certificate""" | ||||||
|         from authentik.crypto.models import CertificateKeyPair |         from authentik.crypto.models import CertificateKeyPair | ||||||
|  |  | ||||||
|         cert: Optional[CertificateKeyPair] = CertificateKeyPair.objects.filter( |         cert: CertificateKeyPair | None = CertificateKeyPair.objects.filter( | ||||||
|             managed=MANAGED_KEY |             managed=MANAGED_KEY | ||||||
|         ).first() |         ).first() | ||||||
|         now = datetime.now(tz=timezone.utc) |         now = datetime.now(tz=UTC) | ||||||
|         if not cert or ( |         if not cert or ( | ||||||
|             now < cert.certificate.not_valid_after_utc or now > cert.certificate.not_valid_after_utc |             now < cert.certificate.not_valid_after_utc or now > cert.certificate.not_valid_after_utc | ||||||
|         ): |         ): | ||||||
|             self._create_update_cert() |             self._create_update_cert() | ||||||
|  |  | ||||||
|     def reconcile_tenant_self_signed(self): |     @ManagedAppConfig.reconcile_tenant | ||||||
|  |     def self_signed(self): | ||||||
|         """Create self-signed keypair""" |         """Create self-signed keypair""" | ||||||
|         from authentik.crypto.builder import CertificateBuilder |         from authentik.crypto.builder import CertificateBuilder | ||||||
|         from authentik.crypto.models import CertificateKeyPair |         from authentik.crypto.models import CertificateKeyPair | ||||||
|  | |||||||
| @ -2,7 +2,6 @@ | |||||||
|  |  | ||||||
| import datetime | import datetime | ||||||
| import uuid | import uuid | ||||||
| from typing import Optional |  | ||||||
|  |  | ||||||
| from cryptography import x509 | from cryptography import x509 | ||||||
| from cryptography.hazmat.backends import default_backend | from cryptography.hazmat.backends import default_backend | ||||||
| @ -44,7 +43,7 @@ class CertificateBuilder: | |||||||
|     def generate_private_key(self) -> PrivateKeyTypes: |     def generate_private_key(self) -> PrivateKeyTypes: | ||||||
|         """Generate private key""" |         """Generate private key""" | ||||||
|         if self._use_ec_private_key: |         if self._use_ec_private_key: | ||||||
|             return ec.generate_private_key(curve=ec.SECP256R1) |             return ec.generate_private_key(curve=ec.SECP256R1()) | ||||||
|         return rsa.generate_private_key( |         return rsa.generate_private_key( | ||||||
|             public_exponent=65537, key_size=4096, backend=default_backend() |             public_exponent=65537, key_size=4096, backend=default_backend() | ||||||
|         ) |         ) | ||||||
| @ -52,7 +51,7 @@ class CertificateBuilder: | |||||||
|     def build( |     def build( | ||||||
|         self, |         self, | ||||||
|         validity_days: int = 365, |         validity_days: int = 365, | ||||||
|         subject_alt_names: Optional[list[str]] = None, |         subject_alt_names: list[str] | None = None, | ||||||
|     ): |     ): | ||||||
|         """Build self-signed certificate""" |         """Build self-signed certificate""" | ||||||
|         one_day = datetime.timedelta(1, 0, 0) |         one_day = datetime.timedelta(1, 0, 0) | ||||||
|  | |||||||
| @ -24,13 +24,13 @@ class Command(TenantCommand): | |||||||
|         if not keypair: |         if not keypair: | ||||||
|             keypair = CertificateKeyPair(name=options["name"]) |             keypair = CertificateKeyPair(name=options["name"]) | ||||||
|             dirty = True |             dirty = True | ||||||
|         with open(options["certificate"], mode="r", encoding="utf-8") as _cert: |         with open(options["certificate"], encoding="utf-8") as _cert: | ||||||
|             cert_data = _cert.read() |             cert_data = _cert.read() | ||||||
|             if keypair.certificate_data != cert_data: |             if keypair.certificate_data != cert_data: | ||||||
|                 dirty = True |                 dirty = True | ||||||
|             keypair.certificate_data = cert_data |             keypair.certificate_data = cert_data | ||||||
|         if options["private_key"]: |         if options["private_key"]: | ||||||
|             with open(options["private_key"], mode="r", encoding="utf-8") as _key: |             with open(options["private_key"], encoding="utf-8") as _key: | ||||||
|                 key_data = _key.read() |                 key_data = _key.read() | ||||||
|                 if keypair.key_data != key_data: |                 if keypair.key_data != key_data: | ||||||
|                     dirty = True |                     dirty = True | ||||||
|  | |||||||
| @ -2,7 +2,6 @@ | |||||||
|  |  | ||||||
| from binascii import hexlify | from binascii import hexlify | ||||||
| from hashlib import md5 | from hashlib import md5 | ||||||
| from typing import Optional |  | ||||||
| from uuid import uuid4 | from uuid import uuid4 | ||||||
|  |  | ||||||
| from cryptography.hazmat.backends import default_backend | from cryptography.hazmat.backends import default_backend | ||||||
| @ -37,9 +36,9 @@ class CertificateKeyPair(SerializerModel, ManagedModel, CreatedUpdatedModel): | |||||||
|         default="", |         default="", | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     _cert: Optional[Certificate] = None |     _cert: Certificate | None = None | ||||||
|     _private_key: Optional[PrivateKeyTypes] = None |     _private_key: PrivateKeyTypes | None = None | ||||||
|     _public_key: Optional[PublicKeyTypes] = None |     _public_key: PublicKeyTypes | None = None | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def serializer(self) -> Serializer: |     def serializer(self) -> Serializer: | ||||||
| @ -57,7 +56,7 @@ class CertificateKeyPair(SerializerModel, ManagedModel, CreatedUpdatedModel): | |||||||
|         return self._cert |         return self._cert | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def public_key(self) -> Optional[PublicKeyTypes]: |     def public_key(self) -> PublicKeyTypes | None: | ||||||
|         """Get public key of the private key""" |         """Get public key of the private key""" | ||||||
|         if not self._public_key: |         if not self._public_key: | ||||||
|             self._public_key = self.private_key.public_key() |             self._public_key = self.private_key.public_key() | ||||||
| @ -66,7 +65,7 @@ class CertificateKeyPair(SerializerModel, ManagedModel, CreatedUpdatedModel): | |||||||
|     @property |     @property | ||||||
|     def private_key( |     def private_key( | ||||||
|         self, |         self, | ||||||
|     ) -> Optional[PrivateKeyTypes]: |     ) -> PrivateKeyTypes | None: | ||||||
|         """Get python cryptography PrivateKey instance""" |         """Get python cryptography PrivateKey instance""" | ||||||
|         if not self._private_key and self.key_data != "": |         if not self._private_key and self.key_data != "": | ||||||
|             try: |             try: | ||||||
|  | |||||||
| @ -58,7 +58,7 @@ def certificate_discovery(self: SystemTask): | |||||||
|         else: |         else: | ||||||
|             cert_name = path.name.replace(path.suffix, "") |             cert_name = path.name.replace(path.suffix, "") | ||||||
|         try: |         try: | ||||||
|             with open(path, "r", encoding="utf-8") as _file: |             with open(path, encoding="utf-8") as _file: | ||||||
|                 body = _file.read() |                 body = _file.read() | ||||||
|                 if "PRIVATE KEY" in body: |                 if "PRIVATE KEY" in body: | ||||||
|                     private_keys[cert_name] = ensure_private_key_valid(body) |                     private_keys[cert_name] = ensure_private_key_valid(body) | ||||||
|  | |||||||
| @ -267,7 +267,7 @@ class TestCrypto(APITestCase): | |||||||
|             with open(f"{temp_dir}/foo.bar/privkey.pem", "w+", encoding="utf-8") as _key: |             with open(f"{temp_dir}/foo.bar/privkey.pem", "w+", encoding="utf-8") as _key: | ||||||
|                 _key.write(builder.private_key) |                 _key.write(builder.private_key) | ||||||
|             with CONFIG.patch("cert_discovery_dir", temp_dir): |             with CONFIG.patch("cert_discovery_dir", temp_dir): | ||||||
|                 certificate_discovery()  # pylint: disable=no-value-for-parameter |                 certificate_discovery() | ||||||
|         keypair: CertificateKeyPair = CertificateKeyPair.objects.filter( |         keypair: CertificateKeyPair = CertificateKeyPair.objects.filter( | ||||||
|             managed=MANAGED_DISCOVERED % "foo" |             managed=MANAGED_DISCOVERED % "foo" | ||||||
|         ).first() |         ).first() | ||||||
|  | |||||||
| @ -13,7 +13,8 @@ class AuthentikEnterpriseAuditConfig(EnterpriseConfig): | |||||||
|     verbose_name = "authentik Enterprise.Audit" |     verbose_name = "authentik Enterprise.Audit" | ||||||
|     default = True |     default = True | ||||||
|  |  | ||||||
|     def reconcile_global_install_middleware(self): |     @EnterpriseConfig.reconcile_global | ||||||
|  |     def install_middleware(self): | ||||||
|         """Install enterprise audit middleware""" |         """Install enterprise audit middleware""" | ||||||
|         orig_import = "authentik.events.middleware.AuditMiddleware" |         orig_import = "authentik.events.middleware.AuditMiddleware" | ||||||
|         new_import = "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware" |         new_import = "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware" | ||||||
|  | |||||||
| @ -62,7 +62,7 @@ class EnterpriseAuditMiddleware(AuditMiddleware): | |||||||
|                 field_value = value.name |                 field_value = value.name | ||||||
|  |  | ||||||
|             # If current field value is an expression, we are not evaluating it |             # If current field value is an expression, we are not evaluating it | ||||||
|             if isinstance(field_value, (BaseExpression, Combinable)): |             if isinstance(field_value, BaseExpression | Combinable): | ||||||
|                 continue |                 continue | ||||||
|             field_value = field.to_python(field_value) |             field_value = field.to_python(field_value) | ||||||
|             data[field.name] = deepcopy(field_value) |             data[field.name] = deepcopy(field_value) | ||||||
| @ -83,12 +83,11 @@ class EnterpriseAuditMiddleware(AuditMiddleware): | |||||||
|         if hasattr(instance, "_previous_state"): |         if hasattr(instance, "_previous_state"): | ||||||
|             return |             return | ||||||
|         before = len(connection.queries) |         before = len(connection.queries) | ||||||
|         setattr(instance, "_previous_state", self.serialize_simple(instance)) |         instance._previous_state = self.serialize_simple(instance) | ||||||
|         after = len(connection.queries) |         after = len(connection.queries) | ||||||
|         if after > before: |         if after > before: | ||||||
|             raise AssertionError("More queries generated by serialize_simple") |             raise AssertionError("More queries generated by serialize_simple") | ||||||
|  |  | ||||||
|     # pylint: disable=too-many-arguments |  | ||||||
|     def post_save_handler( |     def post_save_handler( | ||||||
|         self, |         self, | ||||||
|         user: User, |         user: User, | ||||||
|  | |||||||
| @ -27,7 +27,7 @@ CACHE_KEY_ENTERPRISE_LICENSE = "goauthentik.io/enterprise/license" | |||||||
| CACHE_EXPIRY_ENTERPRISE_LICENSE = 3 * 60 * 60  # 2 Hours | CACHE_EXPIRY_ENTERPRISE_LICENSE = 3 * 60 * 60  # 2 Hours | ||||||
|  |  | ||||||
|  |  | ||||||
| @lru_cache() | @lru_cache | ||||||
| def get_licensing_key() -> Certificate: | def get_licensing_key() -> Certificate: | ||||||
|     """Get Root CA PEM""" |     """Get Root CA PEM""" | ||||||
|     with open("authentik/enterprise/public.pem", "rb") as _key: |     with open("authentik/enterprise/public.pem", "rb") as _key: | ||||||
| @ -88,7 +88,7 @@ class LicenseKey: | |||||||
|         try: |         try: | ||||||
|             headers = get_unverified_header(jwt) |             headers = get_unverified_header(jwt) | ||||||
|         except PyJWTError: |         except PyJWTError: | ||||||
|             raise ValidationError("Unable to verify license") |             raise ValidationError("Unable to verify license") from None | ||||||
|         x5c: list[str] = headers.get("x5c", []) |         x5c: list[str] = headers.get("x5c", []) | ||||||
|         if len(x5c) < 1: |         if len(x5c) < 1: | ||||||
|             raise ValidationError("Unable to verify license") |             raise ValidationError("Unable to verify license") | ||||||
| @ -98,7 +98,7 @@ class LicenseKey: | |||||||
|             our_cert.verify_directly_issued_by(intermediate) |             our_cert.verify_directly_issued_by(intermediate) | ||||||
|             intermediate.verify_directly_issued_by(get_licensing_key()) |             intermediate.verify_directly_issued_by(get_licensing_key()) | ||||||
|         except (InvalidSignature, TypeError, ValueError, Error): |         except (InvalidSignature, TypeError, ValueError, Error): | ||||||
|             raise ValidationError("Unable to verify license") |             raise ValidationError("Unable to verify license") from None | ||||||
|         try: |         try: | ||||||
|             body = from_dict( |             body = from_dict( | ||||||
|                 LicenseKey, |                 LicenseKey, | ||||||
| @ -110,7 +110,7 @@ class LicenseKey: | |||||||
|                 ), |                 ), | ||||||
|             ) |             ) | ||||||
|         except PyJWTError: |         except PyJWTError: | ||||||
|             raise ValidationError("Unable to verify license") |             raise ValidationError("Unable to verify license") from None | ||||||
|         return body |         return body | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
| @ -188,21 +188,20 @@ class LicenseKey: | |||||||
|  |  | ||||||
|     def summary(self) -> LicenseSummary: |     def summary(self) -> LicenseSummary: | ||||||
|         """Summary of license status""" |         """Summary of license status""" | ||||||
|         has_license = License.objects.all().count() > 0 |  | ||||||
|         last_valid = LicenseKey.last_valid_date() |         last_valid = LicenseKey.last_valid_date() | ||||||
|         show_admin_warning = last_valid < now() - timedelta(weeks=2) |         show_admin_warning = last_valid < now() - timedelta(weeks=2) | ||||||
|         show_user_warning = last_valid < now() - timedelta(weeks=4) |         show_user_warning = last_valid < now() - timedelta(weeks=4) | ||||||
|         read_only = last_valid < now() - timedelta(weeks=6) |         read_only = last_valid < now() - timedelta(weeks=6) | ||||||
|         latest_valid = datetime.fromtimestamp(self.exp) |         latest_valid = datetime.fromtimestamp(self.exp) | ||||||
|         return LicenseSummary( |         return LicenseSummary( | ||||||
|             show_admin_warning=show_admin_warning and has_license, |             show_admin_warning=show_admin_warning, | ||||||
|             show_user_warning=show_user_warning and has_license, |             show_user_warning=show_user_warning, | ||||||
|             read_only=read_only and has_license, |             read_only=read_only, | ||||||
|             latest_valid=latest_valid, |             latest_valid=latest_valid, | ||||||
|             internal_users=self.internal_users, |             internal_users=self.internal_users, | ||||||
|             external_users=self.external_users, |             external_users=self.external_users, | ||||||
|             valid=self.is_valid(), |             valid=self.is_valid(), | ||||||
|             has_license=has_license, |             has_license=License.objects.all().count() > 0, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|  | |||||||
| @ -1,7 +1,5 @@ | |||||||
| """Enterprise license policies""" | """Enterprise license policies""" | ||||||
|  |  | ||||||
| from typing import Optional |  | ||||||
|  |  | ||||||
| from django.utils.translation import gettext_lazy as _ | from django.utils.translation import gettext_lazy as _ | ||||||
|  |  | ||||||
| from authentik.core.models import User, UserTypes | from authentik.core.models import User, UserTypes | ||||||
| @ -21,7 +19,7 @@ class EnterprisePolicyAccessView(PolicyAccessView): | |||||||
|             return PolicyResult(False, _("Feature only accessible for internal users.")) |             return PolicyResult(False, _("Feature only accessible for internal users.")) | ||||||
|         return PolicyResult(True) |         return PolicyResult(True) | ||||||
|  |  | ||||||
|     def user_has_access(self, user: Optional[User] = None) -> PolicyResult: |     def user_has_access(self, user: User | None = None) -> PolicyResult: | ||||||
|         user = user or self.request.user |         user = user or self.request.user | ||||||
|         request = PolicyRequest(user) |         request = PolicyRequest(user) | ||||||
|         request.http_request = self.request |         request.http_request = self.request | ||||||
|  | |||||||
| @ -1,7 +1,5 @@ | |||||||
| """RAC Provider API Views""" | """RAC Provider API Views""" | ||||||
|  |  | ||||||
| from typing import Optional |  | ||||||
|  |  | ||||||
| from django.core.cache import cache | from django.core.cache import cache | ||||||
| from django.db.models import QuerySet | from django.db.models import QuerySet | ||||||
| from django.urls import reverse | from django.urls import reverse | ||||||
| @ -36,11 +34,11 @@ class EndpointSerializer(EnterpriseRequiredMixin, ModelSerializer): | |||||||
|     provider_obj = RACProviderSerializer(source="provider", read_only=True) |     provider_obj = RACProviderSerializer(source="provider", read_only=True) | ||||||
|     launch_url = SerializerMethodField() |     launch_url = SerializerMethodField() | ||||||
|  |  | ||||||
|     def get_launch_url(self, endpoint: Endpoint) -> Optional[str]: |     def get_launch_url(self, endpoint: Endpoint) -> str | None: | ||||||
|         """Build actual launch URL (the provider itself does not have one, just |         """Build actual launch URL (the provider itself does not have one, just | ||||||
|         individual endpoints)""" |         individual endpoints)""" | ||||||
|         try: |         try: | ||||||
|             # pylint: disable=no-member |  | ||||||
|             return reverse( |             return reverse( | ||||||
|                 "authentik_providers_rac:start", |                 "authentik_providers_rac:start", | ||||||
|                 kwargs={"app": endpoint.provider.application.slug, "endpoint": endpoint.pk}, |                 kwargs={"app": endpoint.provider.application.slug, "endpoint": endpoint.pk}, | ||||||
|  | |||||||
| @ -1,6 +1,6 @@ | |||||||
| """RAC Models""" | """RAC Models""" | ||||||
|  |  | ||||||
| from typing import Any, Optional | from typing import Any | ||||||
| from uuid import uuid4 | from uuid import uuid4 | ||||||
|  |  | ||||||
| from deepmerge import always_merger | from deepmerge import always_merger | ||||||
| @ -58,7 +58,7 @@ class RACProvider(Provider): | |||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def launch_url(self) -> Optional[str]: |     def launch_url(self) -> str | None: | ||||||
|         """URL to this provider and initiate authorization for the user. |         """URL to this provider and initiate authorization for the user. | ||||||
|         Can return None for providers that are not URL-based""" |         Can return None for providers that are not URL-based""" | ||||||
|         return "goauthentik.io://providers/rac/launch" |         return "goauthentik.io://providers/rac/launch" | ||||||
| @ -112,7 +112,7 @@ class RACPropertyMapping(PropertyMapping): | |||||||
|  |  | ||||||
|     static_settings = models.JSONField(default=dict) |     static_settings = models.JSONField(default=dict) | ||||||
|  |  | ||||||
|     def evaluate(self, user: Optional[User], request: Optional[HttpRequest], **kwargs) -> Any: |     def evaluate(self, user: User | None, request: HttpRequest | None, **kwargs) -> Any: | ||||||
|         """Evaluate `self.expression` using `**kwargs` as Context.""" |         """Evaluate `self.expression` using `**kwargs` as Context.""" | ||||||
|         if len(self.static_settings) > 0: |         if len(self.static_settings) > 0: | ||||||
|             return self.static_settings |             return self.static_settings | ||||||
|  | |||||||
| @ -47,7 +47,7 @@ class RACStartView(EnterprisePolicyAccessView): | |||||||
|                 }, |                 }, | ||||||
|             ) |             ) | ||||||
|         except FlowNonApplicableException: |         except FlowNonApplicableException: | ||||||
|             raise Http404 |             raise Http404 from None | ||||||
|         plan.insert_stage( |         plan.insert_stage( | ||||||
|             in_memory_stage( |             in_memory_stage( | ||||||
|                 RACFinalStage, |                 RACFinalStage, | ||||||
| @ -132,16 +132,7 @@ class RACFinalStage(RedirectStage): | |||||||
|             flow=self.executor.plan.flow_pk, |             flow=self.executor.plan.flow_pk, | ||||||
|             endpoint=self.endpoint.name, |             endpoint=self.endpoint.name, | ||||||
|         ).from_http(self.request) |         ).from_http(self.request) | ||||||
|         setattr( |         self.executor.current_stage.destination = self.request.build_absolute_uri( | ||||||
|             self.executor.current_stage, |             reverse("authentik_providers_rac:if-rac", kwargs={"token": str(token.token)}) | ||||||
|             "destination", |  | ||||||
|             self.request.build_absolute_uri( |  | ||||||
|                 reverse( |  | ||||||
|                     "authentik_providers_rac:if-rac", |  | ||||||
|                     kwargs={ |  | ||||||
|                         "token": str(token.token), |  | ||||||
|                     }, |  | ||||||
|                 ) |  | ||||||
|             ), |  | ||||||
|         ) |         ) | ||||||
|         return super().get_challenge(*args, **kwargs) |         return super().get_challenge(*args, **kwargs) | ||||||
|  | |||||||
| @ -2,14 +2,11 @@ | |||||||
|  |  | ||||||
| from datetime import datetime | from datetime import datetime | ||||||
|  |  | ||||||
| from django.core.cache import cache | from django.db.models.signals import pre_save | ||||||
| from django.db.models.signals import post_save, pre_save |  | ||||||
| from django.dispatch import receiver | from django.dispatch import receiver | ||||||
| from django.utils.timezone import get_current_timezone | from django.utils.timezone import get_current_timezone | ||||||
|  |  | ||||||
| from authentik.enterprise.license import CACHE_KEY_ENTERPRISE_LICENSE |  | ||||||
| from authentik.enterprise.models import License | from authentik.enterprise.models import License | ||||||
| from authentik.enterprise.tasks import enterprise_update_usage |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(pre_save, sender=License) | @receiver(pre_save, sender=License) | ||||||
| @ -20,10 +17,3 @@ def pre_save_license(sender: type[License], instance: License, **_): | |||||||
|     instance.internal_users = status.internal_users |     instance.internal_users = status.internal_users | ||||||
|     instance.external_users = status.external_users |     instance.external_users = status.external_users | ||||||
|     instance.expiry = datetime.fromtimestamp(status.exp, tz=get_current_timezone()) |     instance.expiry = datetime.fromtimestamp(status.exp, tz=get_current_timezone()) | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(post_save, sender=License) |  | ||||||
| def post_save_license(sender: type[License], instance: License, **_): |  | ||||||
|     """Trigger license usage calculation when license is saved""" |  | ||||||
|     cache.delete(CACHE_KEY_ENTERPRISE_LICENSE) |  | ||||||
|     enterprise_update_usage.delay() |  | ||||||
|  | |||||||
| @ -92,7 +92,7 @@ class SystemTaskViewSet(ReadOnlyModelViewSet): | |||||||
|             task_func.delay(*task.task_call_args, **task.task_call_kwargs) |             task_func.delay(*task.task_call_args, **task.task_call_kwargs) | ||||||
|             messages.success( |             messages.success( | ||||||
|                 self.request, |                 self.request, | ||||||
|                 _("Successfully started task %(name)s." % {"name": task.name}), |                 _("Successfully started task {name}.".format_map({"name": task.name})), | ||||||
|             ) |             ) | ||||||
|             return Response(status=204) |             return Response(status=204) | ||||||
|         except (ImportError, AttributeError) as exc:  # pragma: no cover |         except (ImportError, AttributeError) as exc:  # pragma: no cover | ||||||
|  | |||||||
| @ -35,7 +35,8 @@ class AuthentikEventsConfig(ManagedAppConfig): | |||||||
|     verbose_name = "authentik Events" |     verbose_name = "authentik Events" | ||||||
|     default = True |     default = True | ||||||
|  |  | ||||||
|     def reconcile_global_check_deprecations(self): |     @ManagedAppConfig.reconcile_global | ||||||
|  |     def check_deprecations(self): | ||||||
|         """Check for config deprecations""" |         """Check for config deprecations""" | ||||||
|         from authentik.events.models import Event, EventAction |         from authentik.events.models import Event, EventAction | ||||||
|  |  | ||||||
| @ -56,7 +57,8 @@ class AuthentikEventsConfig(ManagedAppConfig): | |||||||
|                 message=msg, |                 message=msg, | ||||||
|             ).save() |             ).save() | ||||||
|  |  | ||||||
|     def reconcile_tenant_prefill_tasks(self): |     @ManagedAppConfig.reconcile_tenant | ||||||
|  |     def prefill_tasks(self): | ||||||
|         """Prefill tasks""" |         """Prefill tasks""" | ||||||
|         from authentik.events.models import SystemTask |         from authentik.events.models import SystemTask | ||||||
|         from authentik.events.system_tasks import _prefill_tasks |         from authentik.events.system_tasks import _prefill_tasks | ||||||
| @ -67,7 +69,8 @@ class AuthentikEventsConfig(ManagedAppConfig): | |||||||
|             task.save() |             task.save() | ||||||
|             self.logger.debug("prefilled task", task_name=task.name) |             self.logger.debug("prefilled task", task_name=task.name) | ||||||
|  |  | ||||||
|     def reconcile_tenant_run_scheduled_tasks(self): |     @ManagedAppConfig.reconcile_tenant | ||||||
|  |     def run_scheduled_tasks(self): | ||||||
|         """Run schedule tasks which are behind schedule (only applies |         """Run schedule tasks which are behind schedule (only applies | ||||||
|         to tasks of which we keep metrics)""" |         to tasks of which we keep metrics)""" | ||||||
|         from authentik.events.models import TaskStatus |         from authentik.events.models import TaskStatus | ||||||
|  | |||||||
| @ -46,7 +46,7 @@ class ASNContextProcessor(MMDBContextProcessor): | |||||||
|             "asn": self.asn_dict(ClientIPMiddleware.get_client_ip(request)), |             "asn": self.asn_dict(ClientIPMiddleware.get_client_ip(request)), | ||||||
|         } |         } | ||||||
|  |  | ||||||
|     def asn(self, ip_address: str) -> Optional[ASN]: |     def asn(self, ip_address: str) -> ASN | None: | ||||||
|         """Wrapper for Reader.asn""" |         """Wrapper for Reader.asn""" | ||||||
|         with Hub.current.start_span( |         with Hub.current.start_span( | ||||||
|             op="authentik.events.asn.asn", |             op="authentik.events.asn.asn", | ||||||
| @ -71,7 +71,7 @@ class ASNContextProcessor(MMDBContextProcessor): | |||||||
|         } |         } | ||||||
|         return asn_dict |         return asn_dict | ||||||
|  |  | ||||||
|     def asn_dict(self, ip_address: str) -> Optional[ASNDict]: |     def asn_dict(self, ip_address: str) -> ASNDict | None: | ||||||
|         """Wrapper for self.asn that returns a dict""" |         """Wrapper for self.asn that returns a dict""" | ||||||
|         asn = self.asn(ip_address) |         asn = self.asn(ip_address) | ||||||
|         if not asn: |         if not asn: | ||||||
|  | |||||||
| @ -47,7 +47,7 @@ class GeoIPContextProcessor(MMDBContextProcessor): | |||||||
|         # Different key `geoip` vs `geo` for legacy reasons |         # Different key `geoip` vs `geo` for legacy reasons | ||||||
|         return {"geoip": self.city(ClientIPMiddleware.get_client_ip(request))} |         return {"geoip": self.city(ClientIPMiddleware.get_client_ip(request))} | ||||||
|  |  | ||||||
|     def city(self, ip_address: str) -> Optional[City]: |     def city(self, ip_address: str) -> City | None: | ||||||
|         """Wrapper for Reader.city""" |         """Wrapper for Reader.city""" | ||||||
|         with Hub.current.start_span( |         with Hub.current.start_span( | ||||||
|             op="authentik.events.geo.city", |             op="authentik.events.geo.city", | ||||||
| @ -76,7 +76,7 @@ class GeoIPContextProcessor(MMDBContextProcessor): | |||||||
|             city_dict["city"] = city.city.name |             city_dict["city"] = city.city.name | ||||||
|         return city_dict |         return city_dict | ||||||
|  |  | ||||||
|     def city_dict(self, ip_address: str) -> Optional[GeoIPDict]: |     def city_dict(self, ip_address: str) -> GeoIPDict | None: | ||||||
|         """Wrapper for self.city that returns a dict""" |         """Wrapper for self.city that returns a dict""" | ||||||
|         city = self.city(ip_address) |         city = self.city(ip_address) | ||||||
|         if not city: |         if not city: | ||||||
|  | |||||||
| @ -1,7 +1,6 @@ | |||||||
| """Common logic for reading MMDB files""" | """Common logic for reading MMDB files""" | ||||||
|  |  | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from typing import Optional |  | ||||||
|  |  | ||||||
| from geoip2.database import Reader | from geoip2.database import Reader | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
| @ -13,7 +12,7 @@ class MMDBContextProcessor(EventContextProcessor): | |||||||
|     """Common logic for reading MaxMind DB files, including re-loading if the file has changed""" |     """Common logic for reading MaxMind DB files, including re-loading if the file has changed""" | ||||||
|  |  | ||||||
|     def __init__(self): |     def __init__(self): | ||||||
|         self.reader: Optional[Reader] = None |         self.reader: Reader | None = None | ||||||
|         self._last_mtime: float = 0.0 |         self._last_mtime: float = 0.0 | ||||||
|         self.logger = get_logger() |         self.logger = get_logger() | ||||||
|         self.open() |         self.open() | ||||||
|  | |||||||
| @ -1,8 +1,9 @@ | |||||||
| """Events middleware""" | """Events middleware""" | ||||||
|  |  | ||||||
|  | from collections.abc import Callable | ||||||
| from functools import partial | from functools import partial | ||||||
| from threading import Thread | from threading import Thread | ||||||
| from typing import Any, Callable, Optional | from typing import Any | ||||||
|  |  | ||||||
| from django.conf import settings | from django.conf import settings | ||||||
| from django.contrib.sessions.models import Session | from django.contrib.sessions.models import Session | ||||||
| @ -49,9 +50,9 @@ class EventNewThread(Thread): | |||||||
|     action: str |     action: str | ||||||
|     request: HttpRequest |     request: HttpRequest | ||||||
|     kwargs: dict[str, Any] |     kwargs: dict[str, Any] | ||||||
|     user: Optional[User] = None |     user: User | None = None | ||||||
|  |  | ||||||
|     def __init__(self, action: str, request: HttpRequest, user: Optional[User] = None, **kwargs): |     def __init__(self, action: str, request: HttpRequest, user: User | None = None, **kwargs): | ||||||
|         super().__init__() |         super().__init__() | ||||||
|         self.action = action |         self.action = action | ||||||
|         self.request = request |         self.request = request | ||||||
| @ -144,7 +145,6 @@ class AuditMiddleware: | |||||||
|             ) |             ) | ||||||
|             thread.run() |             thread.run() | ||||||
|  |  | ||||||
|     # pylint: disable=too-many-arguments |  | ||||||
|     def post_save_handler( |     def post_save_handler( | ||||||
|         self, |         self, | ||||||
|         user: User, |         user: User, | ||||||
| @ -152,7 +152,7 @@ class AuditMiddleware: | |||||||
|         sender, |         sender, | ||||||
|         instance: Model, |         instance: Model, | ||||||
|         created: bool, |         created: bool, | ||||||
|         thread_kwargs: Optional[dict] = None, |         thread_kwargs: dict | None = None, | ||||||
|         **_, |         **_, | ||||||
|     ): |     ): | ||||||
|         """Signal handler for all object's post_save""" |         """Signal handler for all object's post_save""" | ||||||
|  | |||||||
| @ -7,7 +7,6 @@ from difflib import get_close_matches | |||||||
| from functools import lru_cache | from functools import lru_cache | ||||||
| from inspect import currentframe | from inspect import currentframe | ||||||
| from smtplib import SMTPException | from smtplib import SMTPException | ||||||
| from typing import Optional |  | ||||||
| from uuid import uuid4 | from uuid import uuid4 | ||||||
|  |  | ||||||
| from django.apps import apps | from django.apps import apps | ||||||
| @ -52,6 +51,8 @@ from authentik.stages.email.utils import TemplateEmailMessage | |||||||
| from authentik.tenants.models import Tenant | from authentik.tenants.models import Tenant | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  | DISCORD_FIELD_LIMIT = 25 | ||||||
|  | NOTIFICATION_SUMMARY_LENGTH = 75 | ||||||
|  |  | ||||||
|  |  | ||||||
| def default_event_duration(): | def default_event_duration(): | ||||||
| @ -65,7 +66,7 @@ def default_brand(): | |||||||
|     return sanitize_dict(model_to_dict(DEFAULT_BRAND)) |     return sanitize_dict(model_to_dict(DEFAULT_BRAND)) | ||||||
|  |  | ||||||
|  |  | ||||||
| @lru_cache() | @lru_cache | ||||||
| def django_app_names() -> list[str]: | def django_app_names() -> list[str]: | ||||||
|     """Get a cached list of all django apps' names (not labels)""" |     """Get a cached list of all django apps' names (not labels)""" | ||||||
|     return [x.name for x in apps.app_configs.values()] |     return [x.name for x in apps.app_configs.values()] | ||||||
| @ -198,7 +199,7 @@ class Event(SerializerModel, ExpiringModel): | |||||||
|     @staticmethod |     @staticmethod | ||||||
|     def new( |     def new( | ||||||
|         action: str | EventAction, |         action: str | EventAction, | ||||||
|         app: Optional[str] = None, |         app: str | None = None, | ||||||
|         **kwargs, |         **kwargs, | ||||||
|     ) -> "Event": |     ) -> "Event": | ||||||
|         """Create new Event instance from arguments. Instance is NOT saved.""" |         """Create new Event instance from arguments. Instance is NOT saved.""" | ||||||
| @ -224,7 +225,7 @@ class Event(SerializerModel, ExpiringModel): | |||||||
|         self.user = get_user(user) |         self.user = get_user(user) | ||||||
|         return self |         return self | ||||||
|  |  | ||||||
|     def from_http(self, request: HttpRequest, user: Optional[User] = None) -> "Event": |     def from_http(self, request: HttpRequest, user: User | None = None) -> "Event": | ||||||
|         """Add data from a Django-HttpRequest, allowing the creation of |         """Add data from a Django-HttpRequest, allowing the creation of | ||||||
|         Events independently from requests. |         Events independently from requests. | ||||||
|         `user` arguments optionally overrides user from requests.""" |         `user` arguments optionally overrides user from requests.""" | ||||||
| @ -418,7 +419,7 @@ class NotificationTransport(SerializerModel): | |||||||
|                 if not isinstance(value, str): |                 if not isinstance(value, str): | ||||||
|                     continue |                     continue | ||||||
|                 # https://birdie0.github.io/discord-webhooks-guide/other/field_limits.html |                 # https://birdie0.github.io/discord-webhooks-guide/other/field_limits.html | ||||||
|                 if len(fields) >= 25: |                 if len(fields) >= DISCORD_FIELD_LIMIT: | ||||||
|                     continue |                     continue | ||||||
|                 fields.append({"title": key[:256], "value": value[:1024]}) |                 fields.append({"title": key[:256], "value": value[:1024]}) | ||||||
|         body = { |         body = { | ||||||
| @ -472,7 +473,7 @@ class NotificationTransport(SerializerModel): | |||||||
|                     continue |                     continue | ||||||
|                 context["key_value"][key] = value |                 context["key_value"][key] = value | ||||||
|         else: |         else: | ||||||
|             context["title"] += notification.body[:75] |             context["title"] += notification.body[:NOTIFICATION_SUMMARY_LENGTH] | ||||||
|         # TODO: improve permission check |         # TODO: improve permission check | ||||||
|         if notification.user.is_superuser: |         if notification.user.is_superuser: | ||||||
|             context["source"] = { |             context["source"] = { | ||||||
| @ -489,7 +490,7 @@ class NotificationTransport(SerializerModel): | |||||||
|         try: |         try: | ||||||
|             from authentik.stages.email.tasks import send_mail |             from authentik.stages.email.tasks import send_mail | ||||||
|  |  | ||||||
|             return send_mail(mail.__dict__)  # pylint: disable=no-value-for-parameter |             return send_mail(mail.__dict__) | ||||||
|         except (SMTPException, ConnectionError, OSError) as exc: |         except (SMTPException, ConnectionError, OSError) as exc: | ||||||
|             raise NotificationTransportError(exc) from exc |             raise NotificationTransportError(exc) from exc | ||||||
|  |  | ||||||
| @ -533,7 +534,11 @@ class Notification(SerializerModel): | |||||||
|         return NotificationSerializer |         return NotificationSerializer | ||||||
|  |  | ||||||
|     def __str__(self) -> str: |     def __str__(self) -> str: | ||||||
|         body_trunc = (self.body[:75] + "..") if len(self.body) > 75 else self.body |         body_trunc = ( | ||||||
|  |             (self.body[:NOTIFICATION_SUMMARY_LENGTH] + "..") | ||||||
|  |             if len(self.body) > NOTIFICATION_SUMMARY_LENGTH | ||||||
|  |             else self.body | ||||||
|  |         ) | ||||||
|         return f"Notification for user {self.user}: {body_trunc}" |         return f"Notification for user {self.user}: {body_trunc}" | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|  | |||||||
| @ -1,6 +1,6 @@ | |||||||
| """authentik events signal listener""" | """authentik events signal listener""" | ||||||
|  |  | ||||||
| from typing import Any, Optional | from typing import Any | ||||||
|  |  | ||||||
| from django.contrib.auth.signals import user_logged_in, user_logged_out | from django.contrib.auth.signals import user_logged_in, user_logged_out | ||||||
| from django.db.models.signals import post_save, pre_delete | from django.db.models.signals import post_save, pre_delete | ||||||
| @ -42,7 +42,7 @@ def on_user_logged_in(sender, request: HttpRequest, user: User, **_): | |||||||
|     request.session[SESSION_LOGIN_EVENT] = event |     request.session[SESSION_LOGIN_EVENT] = event | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_login_event(request: HttpRequest) -> Optional[Event]: | def get_login_event(request: HttpRequest) -> Event | None: | ||||||
|     """Wrapper to get login event that can be mocked in tests""" |     """Wrapper to get login event that can be mocked in tests""" | ||||||
|     return request.session.get(SESSION_LOGIN_EVENT, None) |     return request.session.get(SESSION_LOGIN_EVENT, None) | ||||||
|  |  | ||||||
| @ -71,7 +71,7 @@ def on_login_failed( | |||||||
|     sender, |     sender, | ||||||
|     credentials: dict[str, str], |     credentials: dict[str, str], | ||||||
|     request: HttpRequest, |     request: HttpRequest, | ||||||
|     stage: Optional[Stage] = None, |     stage: Stage | None = None, | ||||||
|     **kwargs, |     **kwargs, | ||||||
| ): | ): | ||||||
|     """Failed Login, authentik custom event""" |     """Failed Login, authentik custom event""" | ||||||
|  | |||||||
| @ -2,16 +2,15 @@ | |||||||
|  |  | ||||||
| from datetime import datetime, timedelta | from datetime import datetime, timedelta | ||||||
| from time import perf_counter | from time import perf_counter | ||||||
| from typing import Any, Optional | from typing import Any | ||||||
|  |  | ||||||
| from django.utils.timezone import now | from django.utils.timezone import now | ||||||
| from django.utils.translation import gettext_lazy as _ | from django.utils.translation import gettext_lazy as _ | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
| from tenant_schemas_celery.task import TenantTask | from tenant_schemas_celery.task import TenantTask | ||||||
|  |  | ||||||
| from authentik.events.models import Event, EventAction | from authentik.events.models import Event, EventAction, TaskStatus | ||||||
| from authentik.events.models import SystemTask as DBSystemTask | from authentik.events.models import SystemTask as DBSystemTask | ||||||
| from authentik.events.models import TaskStatus |  | ||||||
| from authentik.events.utils import sanitize_item | from authentik.events.utils import sanitize_item | ||||||
| from authentik.lib.utils.errors import exception_to_string | from authentik.lib.utils.errors import exception_to_string | ||||||
|  |  | ||||||
| @ -27,10 +26,10 @@ class SystemTask(TenantTask): | |||||||
|     _status: TaskStatus |     _status: TaskStatus | ||||||
|     _messages: list[str] |     _messages: list[str] | ||||||
|  |  | ||||||
|     _uid: Optional[str] |     _uid: str | None | ||||||
|     # Precise start time from perf_counter |     # Precise start time from perf_counter | ||||||
|     _start_precise: Optional[float] = None |     _start_precise: float | None = None | ||||||
|     _start: Optional[datetime] = None |     _start: datetime | None = None | ||||||
|  |  | ||||||
|     def __init__(self, *args, **kwargs) -> None: |     def __init__(self, *args, **kwargs) -> None: | ||||||
|         super().__init__(*args, **kwargs) |         super().__init__(*args, **kwargs) | ||||||
| @ -60,14 +59,13 @@ class SystemTask(TenantTask): | |||||||
|         self._start = now() |         self._start = now() | ||||||
|         return super().before_start(task_id, args, kwargs) |         return super().before_start(task_id, args, kwargs) | ||||||
|  |  | ||||||
|     def db(self) -> Optional[DBSystemTask]: |     def db(self) -> DBSystemTask | None: | ||||||
|         """Get DB object for latest task""" |         """Get DB object for latest task""" | ||||||
|         return DBSystemTask.objects.filter( |         return DBSystemTask.objects.filter( | ||||||
|             name=self.__name__, |             name=self.__name__, | ||||||
|             uid=self._uid, |             uid=self._uid, | ||||||
|         ).first() |         ).first() | ||||||
|  |  | ||||||
|     # pylint: disable=too-many-arguments |  | ||||||
|     def after_return(self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo): |     def after_return(self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo): | ||||||
|         super().after_return(status, retval, task_id, args, kwargs, einfo=einfo) |         super().after_return(status, retval, task_id, args, kwargs, einfo=einfo) | ||||||
|         if not self._status: |         if not self._status: | ||||||
| @ -97,7 +95,6 @@ class SystemTask(TenantTask): | |||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     # pylint: disable=too-many-arguments |  | ||||||
|     def on_failure(self, exc, task_id, args, kwargs, einfo): |     def on_failure(self, exc, task_id, args, kwargs, einfo): | ||||||
|         super().on_failure(exc, task_id, args, kwargs, einfo=einfo) |         super().on_failure(exc, task_id, args, kwargs, einfo=einfo) | ||||||
|         if not self._status: |         if not self._status: | ||||||
|  | |||||||
| @ -1,7 +1,5 @@ | |||||||
| """Event notification tasks""" | """Event notification tasks""" | ||||||
|  |  | ||||||
| from typing import Optional |  | ||||||
|  |  | ||||||
| from django.db.models.query_utils import Q | from django.db.models.query_utils import Q | ||||||
| from guardian.shortcuts import get_anonymous_user | from guardian.shortcuts import get_anonymous_user | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
| @ -38,7 +36,7 @@ def event_trigger_handler(event_uuid: str, trigger_name: str): | |||||||
|     if not event: |     if not event: | ||||||
|         LOGGER.warning("event doesn't exist yet or anymore", event_uuid=event_uuid) |         LOGGER.warning("event doesn't exist yet or anymore", event_uuid=event_uuid) | ||||||
|         return |         return | ||||||
|     trigger: Optional[NotificationRule] = NotificationRule.objects.filter(name=trigger_name).first() |     trigger: NotificationRule | None = NotificationRule.objects.filter(name=trigger_name).first() | ||||||
|     if not trigger: |     if not trigger: | ||||||
|         return |         return | ||||||
|  |  | ||||||
|  | |||||||
| @ -105,7 +105,7 @@ class TestEvents(TestCase): | |||||||
|         # Test brand |         # Test brand | ||||||
|         request = self.factory.get("/") |         request = self.factory.get("/") | ||||||
|         brand = Brand(domain="test-brand") |         brand = Brand(domain="test-brand") | ||||||
|         setattr(request, "brand", brand) |         request.brand = brand | ||||||
|         event = Event.new("unittest").from_http(request) |         event = Event.new("unittest").from_http(request) | ||||||
|         self.assertEqual( |         self.assertEqual( | ||||||
|             event.brand, |             event.brand, | ||||||
|  | |||||||
| @ -7,7 +7,7 @@ from datetime import date, datetime, time, timedelta | |||||||
| from enum import Enum | from enum import Enum | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from types import GeneratorType, NoneType | from types import GeneratorType, NoneType | ||||||
| from typing import Any, Optional | from typing import Any | ||||||
| from uuid import UUID | from uuid import UUID | ||||||
|  |  | ||||||
| from django.contrib.auth.models import AnonymousUser | from django.contrib.auth.models import AnonymousUser | ||||||
| @ -37,7 +37,7 @@ def cleanse_item(key: str, value: Any) -> Any: | |||||||
|     """Cleanse a single item""" |     """Cleanse a single item""" | ||||||
|     if isinstance(value, dict): |     if isinstance(value, dict): | ||||||
|         return cleanse_dict(value) |         return cleanse_dict(value) | ||||||
|     if isinstance(value, (list, tuple, set)): |     if isinstance(value, list | tuple | set): | ||||||
|         for idx, item in enumerate(value): |         for idx, item in enumerate(value): | ||||||
|             value[idx] = cleanse_item(key, item) |             value[idx] = cleanse_item(key, item) | ||||||
|         return value |         return value | ||||||
| @ -74,7 +74,7 @@ def model_to_dict(model: Model) -> dict[str, Any]: | |||||||
|     } |     } | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_user(user: User | AnonymousUser, original_user: Optional[User] = None) -> dict[str, Any]: | def get_user(user: User | AnonymousUser, original_user: User | None = None) -> dict[str, Any]: | ||||||
|     """Convert user object to dictionary, optionally including the original user""" |     """Convert user object to dictionary, optionally including the original user""" | ||||||
|     if isinstance(user, AnonymousUser): |     if isinstance(user, AnonymousUser): | ||||||
|         try: |         try: | ||||||
| @ -95,8 +95,7 @@ def get_user(user: User | AnonymousUser, original_user: Optional[User] = None) - | |||||||
|     return user_data |     return user_data | ||||||
|  |  | ||||||
|  |  | ||||||
| # pylint: disable=too-many-return-statements,too-many-branches | def sanitize_item(value: Any) -> Any:  # noqa: PLR0911, PLR0912 | ||||||
| def sanitize_item(value: Any) -> Any: |  | ||||||
|     """Sanitize a single item, ensure it is JSON parsable""" |     """Sanitize a single item, ensure it is JSON parsable""" | ||||||
|     if is_dataclass(value): |     if is_dataclass(value): | ||||||
|         # Because asdict calls `copy.deepcopy(obj)` on everything that's not tuple/dict, |         # Because asdict calls `copy.deepcopy(obj)` on everything that's not tuple/dict, | ||||||
| @ -115,20 +114,20 @@ def sanitize_item(value: Any) -> Any: | |||||||
|         return sanitize_dict(value) |         return sanitize_dict(value) | ||||||
|     if isinstance(value, GeneratorType): |     if isinstance(value, GeneratorType): | ||||||
|         return sanitize_item(list(value)) |         return sanitize_item(list(value)) | ||||||
|     if isinstance(value, (list, tuple, set)): |     if isinstance(value, list | tuple | set): | ||||||
|         new_values = [] |         new_values = [] | ||||||
|         for item in value: |         for item in value: | ||||||
|             new_value = sanitize_item(item) |             new_value = sanitize_item(item) | ||||||
|             if new_value: |             if new_value: | ||||||
|                 new_values.append(new_value) |                 new_values.append(new_value) | ||||||
|         return new_values |         return new_values | ||||||
|     if isinstance(value, (User, AnonymousUser)): |     if isinstance(value, User | AnonymousUser): | ||||||
|         return sanitize_dict(get_user(value)) |         return sanitize_dict(get_user(value)) | ||||||
|     if isinstance(value, models.Model): |     if isinstance(value, models.Model): | ||||||
|         return sanitize_dict(model_to_dict(value)) |         return sanitize_dict(model_to_dict(value)) | ||||||
|     if isinstance(value, UUID): |     if isinstance(value, UUID): | ||||||
|         return value.hex |         return value.hex | ||||||
|     if isinstance(value, (HttpRequest, WSGIRequest)): |     if isinstance(value, HttpRequest | WSGIRequest): | ||||||
|         return ... |         return ... | ||||||
|     if isinstance(value, City): |     if isinstance(value, City): | ||||||
|         return GEOIP_CONTEXT_PROCESSOR.city_to_dict(value) |         return GEOIP_CONTEXT_PROCESSOR.city_to_dict(value) | ||||||
| @ -171,7 +170,7 @@ def sanitize_item(value: Any) -> Any: | |||||||
|             "module": value.__module__, |             "module": value.__module__, | ||||||
|         } |         } | ||||||
|     # List taken from the stdlib's JSON encoder (_make_iterencode, encoder.py:415) |     # List taken from the stdlib's JSON encoder (_make_iterencode, encoder.py:415) | ||||||
|     if isinstance(value, (bool, int, float, NoneType, list, tuple, dict)): |     if isinstance(value, bool | int | float | NoneType | list | tuple | dict): | ||||||
|         return value |         return value | ||||||
|     try: |     try: | ||||||
|         return DjangoJSONEncoder().default(value) |         return DjangoJSONEncoder().default(value) | ||||||
|  | |||||||
| @ -114,7 +114,6 @@ class FlowImportResultSerializer(PassiveSerializer): | |||||||
| class FlowViewSet(UsedByMixin, ModelViewSet): | class FlowViewSet(UsedByMixin, ModelViewSet): | ||||||
|     """Flow Viewset""" |     """Flow Viewset""" | ||||||
|  |  | ||||||
|     # pylint: disable=no-member |  | ||||||
|     queryset = Flow.objects.all().prefetch_related("stages", "policies") |     queryset = Flow.objects.all().prefetch_related("stages", "policies") | ||||||
|     serializer_class = FlowSerializer |     serializer_class = FlowSerializer | ||||||
|     lookup_field = "slug" |     lookup_field = "slug" | ||||||
| @ -279,7 +278,7 @@ class FlowViewSet(UsedByMixin, ModelViewSet): | |||||||
|         }, |         }, | ||||||
|     ) |     ) | ||||||
|     @action(detail=True, pagination_class=None, filter_backends=[]) |     @action(detail=True, pagination_class=None, filter_backends=[]) | ||||||
|     def execute(self, request: Request, slug: str): |     def execute(self, request: Request, _slug: str): | ||||||
|         """Execute flow for current user""" |         """Execute flow for current user""" | ||||||
|         # Because we pre-plan the flow here, and not in the planner, we need to manually clear |         # Because we pre-plan the flow here, and not in the planner, we need to manually clear | ||||||
|         # the history of the inspector |         # the history of the inspector | ||||||
| @ -294,8 +293,9 @@ class FlowViewSet(UsedByMixin, ModelViewSet): | |||||||
|             return bad_request_message( |             return bad_request_message( | ||||||
|                 request, |                 request, | ||||||
|                 _( |                 _( | ||||||
|                     "Flow not applicable to current user/request: %(messages)s" |                     "Flow not applicable to current user/request: {messages}".format_map( | ||||||
|                     % {"messages": exc.messages} |                         {"messages": exc.messages} | ||||||
|  |                     ) | ||||||
|                 ), |                 ), | ||||||
|             ) |             ) | ||||||
|         return Response( |         return Response( | ||||||
|  | |||||||
| @ -1,7 +1,6 @@ | |||||||
| """Flows Diagram API""" | """Flows Diagram API""" | ||||||
|  |  | ||||||
| from dataclasses import dataclass, field | from dataclasses import dataclass, field | ||||||
| from typing import Optional |  | ||||||
|  |  | ||||||
| from django.utils.translation import gettext as _ | from django.utils.translation import gettext as _ | ||||||
| from guardian.shortcuts import get_objects_for_user | from guardian.shortcuts import get_objects_for_user | ||||||
| @ -18,8 +17,8 @@ class DiagramElement: | |||||||
|  |  | ||||||
|     identifier: str |     identifier: str | ||||||
|     description: str |     description: str | ||||||
|     action: Optional[str] = None |     action: str | None = None | ||||||
|     source: Optional[list["DiagramElement"]] = None |     source: list["DiagramElement"] | None = None | ||||||
|  |  | ||||||
|     style: list[str] = field(default_factory=lambda: ["[", "]"]) |     style: list[str] = field(default_factory=lambda: ["[", "]"]) | ||||||
|  |  | ||||||
| @ -66,10 +65,10 @@ class FlowDiagram: | |||||||
|         ): |         ): | ||||||
|             element = DiagramElement( |             element = DiagramElement( | ||||||
|                 f"flow_policy_{p_index}", |                 f"flow_policy_{p_index}", | ||||||
|                 _("Policy (%(type)s)" % {"type": policy_binding.policy._meta.verbose_name}) |                 _("Policy ({type})".format_map({"type": policy_binding.policy._meta.verbose_name})) | ||||||
|                 + "\n" |                 + "\n" | ||||||
|                 + policy_binding.policy.name, |                 + policy_binding.policy.name, | ||||||
|                 _("Binding %(order)d" % {"order": policy_binding.order}), |                 _("Binding {order}".format_map({"order": policy_binding.order})), | ||||||
|                 parent_elements, |                 parent_elements, | ||||||
|                 style=["{{", "}}"], |                 style=["{{", "}}"], | ||||||
|             ) |             ) | ||||||
| @ -92,7 +91,7 @@ class FlowDiagram: | |||||||
|         ): |         ): | ||||||
|             element = DiagramElement( |             element = DiagramElement( | ||||||
|                 f"stage_{stage_index}_policy_{p_index}", |                 f"stage_{stage_index}_policy_{p_index}", | ||||||
|                 _("Policy (%(type)s)" % {"type": policy_binding.policy._meta.verbose_name}) |                 _("Policy ({type})".format_map({"type": policy_binding.policy._meta.verbose_name})) | ||||||
|                 + "\n" |                 + "\n" | ||||||
|                 + policy_binding.policy.name, |                 + policy_binding.policy.name, | ||||||
|                 "", |                 "", | ||||||
| @ -120,7 +119,7 @@ class FlowDiagram: | |||||||
|  |  | ||||||
|             element = DiagramElement( |             element = DiagramElement( | ||||||
|                 f"stage_{s_index}", |                 f"stage_{s_index}", | ||||||
|                 _("Stage (%(type)s)" % {"type": stage_binding.stage._meta.verbose_name}) |                 _("Stage ({type})".format_map({"type": stage_binding.stage._meta.verbose_name})) | ||||||
|                 + "\n" |                 + "\n" | ||||||
|                 + stage_binding.stage.name, |                 + stage_binding.stage.name, | ||||||
|                 action, |                 action, | ||||||
|  | |||||||
| @ -31,9 +31,10 @@ class AuthentikFlowsConfig(ManagedAppConfig): | |||||||
|     verbose_name = "authentik Flows" |     verbose_name = "authentik Flows" | ||||||
|     default = True |     default = True | ||||||
|  |  | ||||||
|     def reconcile_global_load_stages(self): |     @ManagedAppConfig.reconcile_global | ||||||
|  |     def load_stages(self): | ||||||
|         """Ensure all stages are loaded""" |         """Ensure all stages are loaded""" | ||||||
|         from authentik.flows.models import Stage |         from authentik.flows.models import Stage | ||||||
|  |  | ||||||
|         for stage in all_subclasses(Stage): |         for stage in all_subclasses(Stage): | ||||||
|             _ = stage().type |             _ = stage().view | ||||||
|  | |||||||
| @ -104,7 +104,7 @@ class FlowErrorChallenge(Challenge): | |||||||
|     error = CharField(required=False) |     error = CharField(required=False) | ||||||
|     traceback = CharField(required=False) |     traceback = CharField(required=False) | ||||||
|  |  | ||||||
|     def __init__(self, request: Optional[Request] = None, error: Optional[Exception] = None): |     def __init__(self, request: Request | None = None, error: Exception | None = None): | ||||||
|         super().__init__(data={}) |         super().__init__(data={}) | ||||||
|         if not request or not error: |         if not request or not error: | ||||||
|             return |             return | ||||||
|  | |||||||
| @ -1,7 +1,5 @@ | |||||||
| """flow exceptions""" | """flow exceptions""" | ||||||
|  |  | ||||||
| from typing import Optional |  | ||||||
|  |  | ||||||
| from django.utils.translation import gettext_lazy as _ | from django.utils.translation import gettext_lazy as _ | ||||||
|  |  | ||||||
| from authentik.lib.sentry import SentryIgnoredException | from authentik.lib.sentry import SentryIgnoredException | ||||||
| @ -11,7 +9,7 @@ from authentik.policies.types import PolicyResult | |||||||
| class FlowNonApplicableException(SentryIgnoredException): | class FlowNonApplicableException(SentryIgnoredException): | ||||||
|     """Flow does not apply to current user (denied by policy, or otherwise).""" |     """Flow does not apply to current user (denied by policy, or otherwise).""" | ||||||
|  |  | ||||||
|     policy_result: Optional[PolicyResult] = None |     policy_result: PolicyResult | None = None | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def messages(self) -> str: |     def messages(self) -> str: | ||||||
|  | |||||||
| @ -1,7 +1,7 @@ | |||||||
| """Stage Markers""" | """Stage Markers""" | ||||||
|  |  | ||||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||||
| from typing import TYPE_CHECKING, Optional | from typing import TYPE_CHECKING | ||||||
|  |  | ||||||
| from django.http.request import HttpRequest | from django.http.request import HttpRequest | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
| @ -25,7 +25,7 @@ class StageMarker: | |||||||
|         plan: "FlowPlan", |         plan: "FlowPlan", | ||||||
|         binding: FlowStageBinding, |         binding: FlowStageBinding, | ||||||
|         http_request: HttpRequest, |         http_request: HttpRequest, | ||||||
|     ) -> Optional[FlowStageBinding]: |     ) -> FlowStageBinding | None: | ||||||
|         """Process callback for this marker. This should be overridden by sub-classes. |         """Process callback for this marker. This should be overridden by sub-classes. | ||||||
|         If a stage should be removed, return None.""" |         If a stage should be removed, return None.""" | ||||||
|         return binding |         return binding | ||||||
| @ -42,7 +42,7 @@ class ReevaluateMarker(StageMarker): | |||||||
|         plan: "FlowPlan", |         plan: "FlowPlan", | ||||||
|         binding: FlowStageBinding, |         binding: FlowStageBinding, | ||||||
|         http_request: HttpRequest, |         http_request: HttpRequest, | ||||||
|     ) -> Optional[FlowStageBinding]: |     ) -> FlowStageBinding | None: | ||||||
|         """Re-evaluate policies bound to stage, and if they fail, remove from plan""" |         """Re-evaluate policies bound to stage, and if they fail, remove from plan""" | ||||||
|         from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER |         from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER | ||||||
|  |  | ||||||
|  | |||||||
| @ -16,7 +16,7 @@ def set_oobe_flow_authentication(apps: Apps, schema_editor: BaseDatabaseSchemaEd | |||||||
|     users = User.objects.using(db_alias).exclude(username="akadmin") |     users = User.objects.using(db_alias).exclude(username="akadmin") | ||||||
|     try: |     try: | ||||||
|         users = users.exclude(pk=get_anonymous_user().pk) |         users = users.exclude(pk=get_anonymous_user().pk) | ||||||
|     # pylint: disable=broad-except |  | ||||||
|     except Exception:  # nosec |     except Exception:  # nosec | ||||||
|         pass |         pass | ||||||
|  |  | ||||||
|  | |||||||
| @ -2,7 +2,7 @@ | |||||||
|  |  | ||||||
| from base64 import b64decode, b64encode | from base64 import b64decode, b64encode | ||||||
| from pickle import dumps, loads  # nosec | from pickle import dumps, loads  # nosec | ||||||
| from typing import TYPE_CHECKING, Optional | from typing import TYPE_CHECKING | ||||||
| from uuid import uuid4 | from uuid import uuid4 | ||||||
|  |  | ||||||
| from django.db import models | from django.db import models | ||||||
| @ -83,7 +83,7 @@ class Stage(SerializerModel): | |||||||
|     objects = InheritanceManager() |     objects = InheritanceManager() | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def type(self) -> type["StageView"]: |     def view(self) -> type["StageView"]: | ||||||
|         """Return StageView class that implements logic for this stage""" |         """Return StageView class that implements logic for this stage""" | ||||||
|         # This is a bit of a workaround, since we can't set class methods with setattr |         # This is a bit of a workaround, since we can't set class methods with setattr | ||||||
|         if hasattr(self, "__in_memory_type"): |         if hasattr(self, "__in_memory_type"): | ||||||
| @ -95,7 +95,7 @@ class Stage(SerializerModel): | |||||||
|         """Return component used to edit this object""" |         """Return component used to edit this object""" | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     def ui_user_settings(self) -> Optional[UserSettingSerializer]: |     def ui_user_settings(self) -> UserSettingSerializer | None: | ||||||
|         """Entrypoint to integrate with User settings. Can either return None if no |         """Entrypoint to integrate with User settings. Can either return None if no | ||||||
|         user settings are available, or a challenge.""" |         user settings are available, or a challenge.""" | ||||||
|         return None |         return None | ||||||
| @ -113,8 +113,8 @@ def in_memory_stage(view: type["StageView"], **kwargs) -> Stage: | |||||||
|     # we set the view as a separate property and reference a generic function |     # we set the view as a separate property and reference a generic function | ||||||
|     # that returns that member |     # that returns that member | ||||||
|     setattr(stage, "__in_memory_type", view) |     setattr(stage, "__in_memory_type", view) | ||||||
|     setattr(stage, "name", _("Dynamic In-memory stage: %(doc)s" % {"doc": view.__doc__})) |     stage.name = _("Dynamic In-memory stage: {doc}".format_map({"doc": view.__doc__})) | ||||||
|     setattr(stage._meta, "verbose_name", class_to_path(view)) |     stage._meta.verbose_name = class_to_path(view) | ||||||
|     for key, value in kwargs.items(): |     for key, value in kwargs.items(): | ||||||
|         setattr(stage, key, value) |         setattr(stage, key, value) | ||||||
|     return stage |     return stage | ||||||
|  | |||||||
| @ -1,7 +1,7 @@ | |||||||
| """Flows Planner""" | """Flows Planner""" | ||||||
|  |  | ||||||
| from dataclasses import dataclass, field | from dataclasses import dataclass, field | ||||||
| from typing import Any, Optional | from typing import Any | ||||||
|  |  | ||||||
| from django.core.cache import cache | from django.core.cache import cache | ||||||
| from django.http import HttpRequest | from django.http import HttpRequest | ||||||
| @ -39,7 +39,7 @@ CACHE_TIMEOUT = CONFIG.get_int("cache.timeout_flows") | |||||||
| CACHE_PREFIX = "goauthentik.io/flows/planner/" | CACHE_PREFIX = "goauthentik.io/flows/planner/" | ||||||
|  |  | ||||||
|  |  | ||||||
| def cache_key(flow: Flow, user: Optional[User] = None) -> str: | def cache_key(flow: Flow, user: User | None = None) -> str: | ||||||
|     """Generate Cache key for flow""" |     """Generate Cache key for flow""" | ||||||
|     prefix = CACHE_PREFIX + str(flow.pk) |     prefix = CACHE_PREFIX + str(flow.pk) | ||||||
|     if user: |     if user: | ||||||
| @ -58,16 +58,16 @@ class FlowPlan: | |||||||
|     context: dict[str, Any] = field(default_factory=dict) |     context: dict[str, Any] = field(default_factory=dict) | ||||||
|     markers: list[StageMarker] = field(default_factory=list) |     markers: list[StageMarker] = field(default_factory=list) | ||||||
|  |  | ||||||
|     def append_stage(self, stage: Stage, marker: Optional[StageMarker] = None): |     def append_stage(self, stage: Stage, marker: StageMarker | None = None): | ||||||
|         """Append `stage` to all stages, optionally with stage marker""" |         """Append `stage` to all stages, optionally with stage marker""" | ||||||
|         return self.append(FlowStageBinding(stage=stage), marker) |         return self.append(FlowStageBinding(stage=stage), marker) | ||||||
|  |  | ||||||
|     def append(self, binding: FlowStageBinding, marker: Optional[StageMarker] = None): |     def append(self, binding: FlowStageBinding, marker: StageMarker | None = None): | ||||||
|         """Append `stage` to all stages, optionally with stage marker""" |         """Append `stage` to all stages, optionally with stage marker""" | ||||||
|         self.bindings.append(binding) |         self.bindings.append(binding) | ||||||
|         self.markers.append(marker or StageMarker()) |         self.markers.append(marker or StageMarker()) | ||||||
|  |  | ||||||
|     def insert_stage(self, stage: Stage, marker: Optional[StageMarker] = None): |     def insert_stage(self, stage: Stage, marker: StageMarker | None = None): | ||||||
|         """Insert stage into plan, as immediate next stage""" |         """Insert stage into plan, as immediate next stage""" | ||||||
|         self.bindings.insert(1, FlowStageBinding(stage=stage, order=0)) |         self.bindings.insert(1, FlowStageBinding(stage=stage, order=0)) | ||||||
|         self.markers.insert(1, marker or StageMarker()) |         self.markers.insert(1, marker or StageMarker()) | ||||||
| @ -78,7 +78,7 @@ class FlowPlan: | |||||||
|  |  | ||||||
|         self.insert_stage(in_memory_stage(RedirectStage, destination=destination)) |         self.insert_stage(in_memory_stage(RedirectStage, destination=destination)) | ||||||
|  |  | ||||||
|     def next(self, http_request: Optional[HttpRequest]) -> Optional[FlowStageBinding]: |     def next(self, http_request: HttpRequest | None) -> FlowStageBinding | None: | ||||||
|         """Return next pending stage from the bottom of the list""" |         """Return next pending stage from the bottom of the list""" | ||||||
|         if not self.has_stages: |         if not self.has_stages: | ||||||
|             return None |             return None | ||||||
| @ -94,7 +94,7 @@ class FlowPlan: | |||||||
|             self.markers.remove(marker) |             self.markers.remove(marker) | ||||||
|             if not self.has_stages: |             if not self.has_stages: | ||||||
|                 return None |                 return None | ||||||
|             # pylint: disable=not-callable |  | ||||||
|             return self.next(http_request) |             return self.next(http_request) | ||||||
|         return marked_stage |         return marked_stage | ||||||
|  |  | ||||||
| @ -148,9 +148,7 @@ class FlowPlanner: | |||||||
|             if not outpost_user: |             if not outpost_user: | ||||||
|                 raise FlowNonApplicableException() |                 raise FlowNonApplicableException() | ||||||
|  |  | ||||||
|     def plan( |     def plan(self, request: HttpRequest, default_context: dict[str, Any] | None = None) -> FlowPlan: | ||||||
|         self, request: HttpRequest, default_context: Optional[dict[str, Any]] = None |  | ||||||
|     ) -> FlowPlan: |  | ||||||
|         """Check each of the flows' policies, check policies for each stage with PolicyBinding |         """Check each of the flows' policies, check policies for each stage with PolicyBinding | ||||||
|         and return ordered list""" |         and return ordered list""" | ||||||
|         with Hub.current.start_span( |         with Hub.current.start_span( | ||||||
| @ -214,7 +212,7 @@ class FlowPlanner: | |||||||
|         self, |         self, | ||||||
|         user: User, |         user: User, | ||||||
|         request: HttpRequest, |         request: HttpRequest, | ||||||
|         default_context: Optional[dict[str, Any]], |         default_context: dict[str, Any] | None, | ||||||
|     ) -> FlowPlan: |     ) -> FlowPlan: | ||||||
|         """Build flow plan by checking each stage in their respective |         """Build flow plan by checking each stage in their respective | ||||||
|         order and checking the applied policies""" |         order and checking the applied policies""" | ||||||
|  | |||||||
| @ -1,6 +1,6 @@ | |||||||
| """authentik stage Base view""" | """authentik stage Base view""" | ||||||
|  |  | ||||||
| from typing import TYPE_CHECKING, Optional | from typing import TYPE_CHECKING | ||||||
|  |  | ||||||
| from django.contrib.auth.models import AnonymousUser | from django.contrib.auth.models import AnonymousUser | ||||||
| from django.http import HttpRequest | from django.http import HttpRequest | ||||||
| @ -153,7 +153,7 @@ class ChallengeStageView(StageView): | |||||||
|                 "app": self.executor.plan.context.get(PLAN_CONTEXT_APPLICATION, ""), |                 "app": self.executor.plan.context.get(PLAN_CONTEXT_APPLICATION, ""), | ||||||
|                 "user": self.get_pending_user(for_display=True), |                 "user": self.get_pending_user(for_display=True), | ||||||
|             } |             } | ||||||
|         # pylint: disable=broad-except |  | ||||||
|         except Exception as exc: |         except Exception as exc: | ||||||
|             self.logger.warning("failed to template title", exc=exc) |             self.logger.warning("failed to template title", exc=exc) | ||||||
|             return self.executor.flow.title |             return self.executor.flow.title | ||||||
| @ -234,9 +234,9 @@ class ChallengeStageView(StageView): | |||||||
| class AccessDeniedChallengeView(ChallengeStageView): | class AccessDeniedChallengeView(ChallengeStageView): | ||||||
|     """Used internally by FlowExecutor's stage_invalid()""" |     """Used internally by FlowExecutor's stage_invalid()""" | ||||||
|  |  | ||||||
|     error_message: Optional[str] |     error_message: str | None | ||||||
|  |  | ||||||
|     def __init__(self, executor: "FlowExecutorView", error_message: Optional[str] = None, **kwargs): |     def __init__(self, executor: "FlowExecutorView", error_message: str | None = None, **kwargs): | ||||||
|         super().__init__(executor, **kwargs) |         super().__init__(executor, **kwargs) | ||||||
|         self.error_message = error_message |         self.error_message = error_message | ||||||
|  |  | ||||||
|  | |||||||
| @ -1,7 +1,7 @@ | |||||||
| """Test helpers""" | """Test helpers""" | ||||||
|  |  | ||||||
| from json import loads | from json import loads | ||||||
| from typing import Any, Optional | from typing import Any | ||||||
|  |  | ||||||
| from django.http.response import HttpResponse | from django.http.response import HttpResponse | ||||||
| from django.urls.base import reverse | from django.urls.base import reverse | ||||||
| @ -15,12 +15,11 @@ from authentik.flows.models import Flow | |||||||
| class FlowTestCase(APITestCase): | class FlowTestCase(APITestCase): | ||||||
|     """Helpers for testing flows and stages.""" |     """Helpers for testing flows and stages.""" | ||||||
|  |  | ||||||
|     # pylint: disable=invalid-name |  | ||||||
|     def assertStageResponse( |     def assertStageResponse( | ||||||
|         self, |         self, | ||||||
|         response: HttpResponse, |         response: HttpResponse, | ||||||
|         flow: Optional[Flow] = None, |         flow: Flow | None = None, | ||||||
|         user: Optional[User] = None, |         user: User | None = None, | ||||||
|         **kwargs, |         **kwargs, | ||||||
|     ) -> dict[str, Any]: |     ) -> dict[str, Any]: | ||||||
|         """Assert various attributes of a stage response""" |         """Assert various attributes of a stage response""" | ||||||
| @ -45,7 +44,6 @@ class FlowTestCase(APITestCase): | |||||||
|             self.assertEqual(raw_response[key], expected) |             self.assertEqual(raw_response[key], expected) | ||||||
|         return raw_response |         return raw_response | ||||||
|  |  | ||||||
|     # pylint: disable=invalid-name |  | ||||||
|     def assertStageRedirects(self, response: HttpResponse, to: str) -> dict[str, Any]: |     def assertStageRedirects(self, response: HttpResponse, to: str) -> dict[str, Any]: | ||||||
|         """Wrapper around assertStageResponse that checks for a redirect""" |         """Wrapper around assertStageResponse that checks for a redirect""" | ||||||
|         return self.assertStageResponse( |         return self.assertStageResponse( | ||||||
|  | |||||||
| @ -1,7 +1,6 @@ | |||||||
| """flow views tests""" | """flow views tests""" | ||||||
|  |  | ||||||
| from unittest.mock import MagicMock, PropertyMock, patch | from unittest.mock import MagicMock, PropertyMock, patch | ||||||
| from urllib.parse import urlencode |  | ||||||
|  |  | ||||||
| from django.http import HttpRequest, HttpResponse | from django.http import HttpRequest, HttpResponse | ||||||
| from django.test.client import RequestFactory | from django.test.client import RequestFactory | ||||||
| @ -19,12 +18,7 @@ from authentik.flows.models import ( | |||||||
| from authentik.flows.planner import FlowPlan, FlowPlanner | from authentik.flows.planner import FlowPlan, FlowPlanner | ||||||
| from authentik.flows.stage import PLAN_CONTEXT_PENDING_USER_IDENTIFIER, StageView | from authentik.flows.stage import PLAN_CONTEXT_PENDING_USER_IDENTIFIER, StageView | ||||||
| from authentik.flows.tests import FlowTestCase | from authentik.flows.tests import FlowTestCase | ||||||
| from authentik.flows.views.executor import ( | from authentik.flows.views.executor import NEXT_ARG_NAME, SESSION_KEY_PLAN, FlowExecutorView | ||||||
|     NEXT_ARG_NAME, |  | ||||||
|     QS_QUERY, |  | ||||||
|     SESSION_KEY_PLAN, |  | ||||||
|     FlowExecutorView, |  | ||||||
| ) |  | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| from authentik.policies.dummy.models import DummyPolicy | from authentik.policies.dummy.models import DummyPolicy | ||||||
| from authentik.policies.models import PolicyBinding | from authentik.policies.models import PolicyBinding | ||||||
| @ -127,73 +121,16 @@ class TestFlowExecutor(FlowTestCase): | |||||||
|         TO_STAGE_RESPONSE_MOCK, |         TO_STAGE_RESPONSE_MOCK, | ||||||
|     ) |     ) | ||||||
|     def test_invalid_flow_redirect(self): |     def test_invalid_flow_redirect(self): | ||||||
|         """Test invalid flow with valid redirect destination""" |         """Tests that an invalid flow still redirects""" | ||||||
|         flow = create_test_flow( |         flow = create_test_flow( | ||||||
|             FlowDesignation.AUTHENTICATION, |             FlowDesignation.AUTHENTICATION, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         dest = "/unique-string" |         dest = "/unique-string" | ||||||
|         url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) |         url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) | ||||||
|         response = self.client.get(url + f"?{QS_QUERY}={urlencode({NEXT_ARG_NAME: dest})}") |         response = self.client.get(url + f"?{NEXT_ARG_NAME}={dest}") | ||||||
|         self.assertEqual(response.status_code, 302) |         self.assertEqual(response.status_code, 302) | ||||||
|         self.assertEqual(response.url, "/unique-string") |         self.assertEqual(response.url, reverse("authentik_core:root-redirect")) | ||||||
|  |  | ||||||
|     @patch( |  | ||||||
|         "authentik.flows.views.executor.to_stage_response", |  | ||||||
|         TO_STAGE_RESPONSE_MOCK, |  | ||||||
|     ) |  | ||||||
|     def test_invalid_flow_invalid_redirect(self): |  | ||||||
|         """Test invalid flow redirect with an invalid URL""" |  | ||||||
|         flow = create_test_flow( |  | ||||||
|             FlowDesignation.AUTHENTICATION, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         dest = "http://something.example.com/unique-string" |  | ||||||
|         url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) |  | ||||||
|  |  | ||||||
|         response = self.client.get(url + f"?{QS_QUERY}={urlencode({NEXT_ARG_NAME: dest})}") |  | ||||||
|         self.assertEqual(response.status_code, 200) |  | ||||||
|         self.assertStageResponse( |  | ||||||
|             response, |  | ||||||
|             flow, |  | ||||||
|             component="ak-stage-access-denied", |  | ||||||
|             error_message="Invalid next URL", |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     @patch( |  | ||||||
|         "authentik.flows.views.executor.to_stage_response", |  | ||||||
|         TO_STAGE_RESPONSE_MOCK, |  | ||||||
|     ) |  | ||||||
|     def test_valid_flow_redirect(self): |  | ||||||
|         """Test valid flow with valid redirect destination""" |  | ||||||
|         flow = create_test_flow() |  | ||||||
|  |  | ||||||
|         dest = "/unique-string" |  | ||||||
|         url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) |  | ||||||
|  |  | ||||||
|         response = self.client.get(url + f"?{QS_QUERY}={urlencode({NEXT_ARG_NAME: dest})}") |  | ||||||
|         self.assertEqual(response.status_code, 302) |  | ||||||
|         self.assertEqual(response.url, "/unique-string") |  | ||||||
|  |  | ||||||
|     @patch( |  | ||||||
|         "authentik.flows.views.executor.to_stage_response", |  | ||||||
|         TO_STAGE_RESPONSE_MOCK, |  | ||||||
|     ) |  | ||||||
|     def test_valid_flow_invalid_redirect(self): |  | ||||||
|         """Test valid flow redirect with an invalid URL""" |  | ||||||
|         flow = create_test_flow() |  | ||||||
|  |  | ||||||
|         dest = "http://something.example.com/unique-string" |  | ||||||
|         url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) |  | ||||||
|  |  | ||||||
|         response = self.client.get(url + f"?{QS_QUERY}={urlencode({NEXT_ARG_NAME: dest})}") |  | ||||||
|         self.assertEqual(response.status_code, 200) |  | ||||||
|         self.assertStageResponse( |  | ||||||
|             response, |  | ||||||
|             flow, |  | ||||||
|             component="ak-stage-access-denied", |  | ||||||
|             error_message="Invalid next URL", |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     @patch( |     @patch( | ||||||
|         "authentik.flows.views.executor.to_stage_response", |         "authentik.flows.views.executor.to_stage_response", | ||||||
|  | |||||||
| @ -1,6 +1,6 @@ | |||||||
| """base model tests""" | """base model tests""" | ||||||
|  |  | ||||||
| from typing import Callable | from collections.abc import Callable | ||||||
|  |  | ||||||
| from django.test import TestCase | from django.test import TestCase | ||||||
|  |  | ||||||
| @ -22,7 +22,7 @@ def model_tester_factory(test_model: type[Stage]) -> Callable: | |||||||
|             model_class = test_model.__bases__[0]() |             model_class = test_model.__bases__[0]() | ||||||
|         else: |         else: | ||||||
|             model_class = test_model() |             model_class = test_model() | ||||||
|         self.assertTrue(issubclass(model_class.type, StageView)) |         self.assertTrue(issubclass(model_class.view, StageView)) | ||||||
|         self.assertIsNotNone(test_model.component) |         self.assertIsNotNone(test_model.component) | ||||||
|         _ = model_class.ui_user_settings() |         _ = model_class.ui_user_settings() | ||||||
|  |  | ||||||
|  | |||||||
| @ -1,6 +1,6 @@ | |||||||
| """stage view tests""" | """stage view tests""" | ||||||
|  |  | ||||||
| from typing import Callable | from collections.abc import Callable | ||||||
|  |  | ||||||
| from django.test import RequestFactory, TestCase | from django.test import RequestFactory, TestCase | ||||||
|  |  | ||||||
|  | |||||||
| @ -1,7 +1,6 @@ | |||||||
| """authentik multi-stage authentication engine""" | """authentik multi-stage authentication engine""" | ||||||
|  |  | ||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
| from typing import Optional |  | ||||||
|  |  | ||||||
| from django.conf import settings | from django.conf import settings | ||||||
| from django.contrib.auth.mixins import LoginRequiredMixin | from django.contrib.auth.mixins import LoginRequiredMixin | ||||||
| @ -12,7 +11,6 @@ from django.shortcuts import get_object_or_404, redirect | |||||||
| from django.template.response import TemplateResponse | from django.template.response import TemplateResponse | ||||||
| from django.urls import reverse | from django.urls import reverse | ||||||
| from django.utils.decorators import method_decorator | from django.utils.decorators import method_decorator | ||||||
| from django.utils.translation import gettext as _ |  | ||||||
| from django.views.decorators.clickjacking import xframe_options_sameorigin | from django.views.decorators.clickjacking import xframe_options_sameorigin | ||||||
| from django.views.generic import View | from django.views.generic import View | ||||||
| from drf_spectacular.types import OpenApiTypes | from drf_spectacular.types import OpenApiTypes | ||||||
| @ -108,8 +106,8 @@ class FlowExecutorView(APIView): | |||||||
|  |  | ||||||
|     flow: Flow |     flow: Flow | ||||||
|  |  | ||||||
|     plan: Optional[FlowPlan] = None |     plan: FlowPlan | None = None | ||||||
|     current_binding: Optional[FlowStageBinding] = None |     current_binding: FlowStageBinding | None = None | ||||||
|     current_stage: Stage |     current_stage: Stage | ||||||
|     current_stage_view: View |     current_stage_view: View | ||||||
|  |  | ||||||
| @ -137,9 +135,9 @@ class FlowExecutorView(APIView): | |||||||
|             ) |             ) | ||||||
|         return to_stage_response(self.request, self.stage_invalid(error_message=exc.messages)) |         return to_stage_response(self.request, self.stage_invalid(error_message=exc.messages)) | ||||||
|  |  | ||||||
|     def _check_flow_token(self, key: str) -> Optional[FlowPlan]: |     def _check_flow_token(self, key: str) -> FlowPlan | None: | ||||||
|         """Check if the user is using a flow token to restore a plan""" |         """Check if the user is using a flow token to restore a plan""" | ||||||
|         token: Optional[FlowToken] = FlowToken.filter_not_expired(key=key).first() |         token: FlowToken | None = FlowToken.filter_not_expired(key=key).first() | ||||||
|         if not token: |         if not token: | ||||||
|             return None |             return None | ||||||
|         plan = None |         plan = None | ||||||
| @ -155,7 +153,6 @@ class FlowExecutorView(APIView): | |||||||
|         self._logger.debug("f(exec): restored flow plan from token", plan=plan) |         self._logger.debug("f(exec): restored flow plan from token", plan=plan) | ||||||
|         return plan |         return plan | ||||||
|  |  | ||||||
|     # pylint: disable=too-many-return-statements |  | ||||||
|     def dispatch(self, request: HttpRequest, flow_slug: str) -> HttpResponse: |     def dispatch(self, request: HttpRequest, flow_slug: str) -> HttpResponse: | ||||||
|         with Hub.current.start_span( |         with Hub.current.start_span( | ||||||
|             op="authentik.flow.executor.dispatch", description=self.flow.slug |             op="authentik.flow.executor.dispatch", description=self.flow.slug | ||||||
| @ -179,8 +176,6 @@ class FlowExecutorView(APIView): | |||||||
|                     self.cancel() |                     self.cancel() | ||||||
|                 self._logger.debug("f(exec): Continuing existing plan") |                 self._logger.debug("f(exec): Continuing existing plan") | ||||||
|  |  | ||||||
|             # Initial flow request, check if we have an upstream query string passed in |  | ||||||
|             request.session[SESSION_KEY_GET] = get_params |  | ||||||
|             # Don't check session again as we've either already loaded the plan or we need to plan |             # Don't check session again as we've either already loaded the plan or we need to plan | ||||||
|             if not self.plan: |             if not self.plan: | ||||||
|                 request.session[SESSION_KEY_HISTORY] = [] |                 request.session[SESSION_KEY_HISTORY] = [] | ||||||
| @ -195,6 +190,8 @@ class FlowExecutorView(APIView): | |||||||
|                     # To match behaviour with loading an empty flow plan from cache, |                     # To match behaviour with loading an empty flow plan from cache, | ||||||
|                     # we don't show an error message here, but rather call _flow_done() |                     # we don't show an error message here, but rather call _flow_done() | ||||||
|                     return self._flow_done() |                     return self._flow_done() | ||||||
|  |             # Initial flow request, check if we have an upstream query string passed in | ||||||
|  |             request.session[SESSION_KEY_GET] = get_params | ||||||
|             # We don't save the Plan after getting the next stage |             # We don't save the Plan after getting the next stage | ||||||
|             # as it hasn't been successfully passed yet |             # as it hasn't been successfully passed yet | ||||||
|             try: |             try: | ||||||
| @ -202,7 +199,7 @@ class FlowExecutorView(APIView): | |||||||
|                 # if the cached plan is from an older version, it might have different attributes |                 # if the cached plan is from an older version, it might have different attributes | ||||||
|                 # in which case we just delete the plan and invalidate everything |                 # in which case we just delete the plan and invalidate everything | ||||||
|                 next_binding = self.plan.next(self.request) |                 next_binding = self.plan.next(self.request) | ||||||
|             except Exception as exc:  # pylint: disable=broad-except |             except Exception as exc: | ||||||
|                 self._logger.warning( |                 self._logger.warning( | ||||||
|                     "f(exec): found incompatible flow plan, invalidating run", exc=exc |                     "f(exec): found incompatible flow plan, invalidating run", exc=exc | ||||||
|                 ) |                 ) | ||||||
| @ -220,7 +217,7 @@ class FlowExecutorView(APIView): | |||||||
|                 flow_slug=self.flow.slug, |                 flow_slug=self.flow.slug, | ||||||
|             ) |             ) | ||||||
|             try: |             try: | ||||||
|                 stage_cls = self.current_stage.type |                 stage_cls = self.current_stage.view | ||||||
|             except NotImplementedError as exc: |             except NotImplementedError as exc: | ||||||
|                 self._logger.debug("Error getting stage type", exc=exc) |                 self._logger.debug("Error getting stage type", exc=exc) | ||||||
|                 return self.stage_invalid() |                 return self.stage_invalid() | ||||||
| @ -291,7 +288,7 @@ class FlowExecutorView(APIView): | |||||||
|                 span.set_data("authentik Flow", self.flow.slug) |                 span.set_data("authentik Flow", self.flow.slug) | ||||||
|                 stage_response = self.current_stage_view.dispatch(request) |                 stage_response = self.current_stage_view.dispatch(request) | ||||||
|                 return to_stage_response(request, stage_response) |                 return to_stage_response(request, stage_response) | ||||||
|         except Exception as exc:  # pylint: disable=broad-except |         except Exception as exc: | ||||||
|             return self.handle_exception(exc) |             return self.handle_exception(exc) | ||||||
|  |  | ||||||
|     @extend_schema( |     @extend_schema( | ||||||
| @ -342,7 +339,7 @@ class FlowExecutorView(APIView): | |||||||
|                 span.set_data("authentik Flow", self.flow.slug) |                 span.set_data("authentik Flow", self.flow.slug) | ||||||
|                 stage_response = self.current_stage_view.dispatch(request) |                 stage_response = self.current_stage_view.dispatch(request) | ||||||
|                 return to_stage_response(request, stage_response) |                 return to_stage_response(request, stage_response) | ||||||
|         except Exception as exc:  # pylint: disable=broad-except |         except Exception as exc: | ||||||
|             return self.handle_exception(exc) |             return self.handle_exception(exc) | ||||||
|  |  | ||||||
|     def _initiate_plan(self) -> FlowPlan: |     def _initiate_plan(self) -> FlowPlan: | ||||||
| @ -354,7 +351,7 @@ class FlowExecutorView(APIView): | |||||||
|             # there are no issues with the class we might've gotten |             # there are no issues with the class we might've gotten | ||||||
|             # from the cache. If there are errors, just delete all cached flows |             # from the cache. If there are errors, just delete all cached flows | ||||||
|             _ = plan.has_stages |             _ = plan.has_stages | ||||||
|         except Exception:  # pylint: disable=broad-except |         except Exception: | ||||||
|             keys = cache.keys(f"{CACHE_PREFIX}*") |             keys = cache.keys(f"{CACHE_PREFIX}*") | ||||||
|             cache.delete_many(keys) |             cache.delete_many(keys) | ||||||
|             return self._initiate_plan() |             return self._initiate_plan() | ||||||
| @ -393,11 +390,7 @@ class FlowExecutorView(APIView): | |||||||
|             NEXT_ARG_NAME, "authentik_core:root-redirect" |             NEXT_ARG_NAME, "authentik_core:root-redirect" | ||||||
|         ) |         ) | ||||||
|         self.cancel() |         self.cancel() | ||||||
|         if next_param and not is_url_absolute(next_param): |  | ||||||
|         return to_stage_response(self.request, redirect_with_qs(next_param)) |         return to_stage_response(self.request, redirect_with_qs(next_param)) | ||||||
|         return to_stage_response( |  | ||||||
|             self.request, self.stage_invalid(error_message=_("Invalid next URL")) |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def stage_ok(self) -> HttpResponse: |     def stage_ok(self) -> HttpResponse: | ||||||
|         """Callback called by stages upon successful completion. |         """Callback called by stages upon successful completion. | ||||||
| @ -426,7 +419,7 @@ class FlowExecutorView(APIView): | |||||||
|         ) |         ) | ||||||
|         return self._flow_done() |         return self._flow_done() | ||||||
|  |  | ||||||
|     def stage_invalid(self, error_message: Optional[str] = None) -> HttpResponse: |     def stage_invalid(self, error_message: str | None = None) -> HttpResponse: | ||||||
|         """Callback used stage when data is correct but a policy denies access |         """Callback used stage when data is correct but a policy denies access | ||||||
|         or the user account is disabled. |         or the user account is disabled. | ||||||
|  |  | ||||||
| @ -484,9 +477,9 @@ class CancelView(View): | |||||||
| class ToDefaultFlow(View): | class ToDefaultFlow(View): | ||||||
|     """Redirect to default flow matching by designation""" |     """Redirect to default flow matching by designation""" | ||||||
|  |  | ||||||
|     designation: Optional[FlowDesignation] = None |     designation: FlowDesignation | None = None | ||||||
|  |  | ||||||
|     def flow_by_policy(self, request: HttpRequest, **flow_filter) -> Optional[Flow]: |     def flow_by_policy(self, request: HttpRequest, **flow_filter) -> Flow | None: | ||||||
|         """Get a Flow by `**flow_filter` and check if the request from `request` can access it.""" |         """Get a Flow by `**flow_filter` and check if the request from `request` can access it.""" | ||||||
|         flows = Flow.objects.filter(**flow_filter).order_by("slug") |         flows = Flow.objects.filter(**flow_filter).order_by("slug") | ||||||
|         for flow in flows: |         for flow in flows: | ||||||
| @ -508,9 +501,7 @@ class ToDefaultFlow(View): | |||||||
|         if self.designation == FlowDesignation.AUTHENTICATION: |         if self.designation == FlowDesignation.AUTHENTICATION: | ||||||
|             flow = brand.flow_authentication |             flow = brand.flow_authentication | ||||||
|             # Check if we have a default flow from application |             # Check if we have a default flow from application | ||||||
|             application: Optional[Application] = self.request.session.get( |             application: Application | None = self.request.session.get(SESSION_KEY_APPLICATION_PRE) | ||||||
|                 SESSION_KEY_APPLICATION_PRE |  | ||||||
|             ) |  | ||||||
|             if application and application.provider and application.provider.authentication_flow: |             if application and application.provider and application.provider.authentication_flow: | ||||||
|                 flow = application.provider.authentication_flow |                 flow = application.provider.authentication_flow | ||||||
|         elif self.designation == FlowDesignation.INVALIDATION: |         elif self.designation == FlowDesignation.INVALIDATION: | ||||||
| @ -540,7 +531,10 @@ class ToDefaultFlow(View): | |||||||
|  |  | ||||||
| def to_stage_response(request: HttpRequest, source: HttpResponse) -> HttpResponse: | def to_stage_response(request: HttpRequest, source: HttpResponse) -> HttpResponse: | ||||||
|     """Convert normal HttpResponse into JSON Response""" |     """Convert normal HttpResponse into JSON Response""" | ||||||
|     if isinstance(source, HttpResponseRedirect) or source.status_code == 302: |     if ( | ||||||
|  |         isinstance(source, HttpResponseRedirect) | ||||||
|  |         or source.status_code == HttpResponseRedirect.status_code | ||||||
|  |     ): | ||||||
|         redirect_url = source["Location"] |         redirect_url = source["Location"] | ||||||
|         # Redirects to the same URL usually indicate an Error within a form |         # Redirects to the same URL usually indicate an Error within a form | ||||||
|         if request.get_full_path() == redirect_url: |         if request.get_full_path() == redirect_url: | ||||||
| @ -604,7 +598,7 @@ class ConfigureFlowInitView(LoginRequiredMixin, View): | |||||||
|             ) |             ) | ||||||
|         except FlowNonApplicableException: |         except FlowNonApplicableException: | ||||||
|             LOGGER.warning("Flow not applicable to user") |             LOGGER.warning("Flow not applicable to user") | ||||||
|             raise Http404 |             raise Http404 from None | ||||||
|         request.session[SESSION_KEY_PLAN] = plan |         request.session[SESSION_KEY_PLAN] = plan | ||||||
|         return redirect_with_qs( |         return redirect_with_qs( | ||||||
|             "authentik_core:if-flow", |             "authentik_core:if-flow", | ||||||
|  | |||||||
| @ -26,6 +26,8 @@ from authentik.flows.planner import FlowPlan | |||||||
| from authentik.flows.views.executor import SESSION_KEY_HISTORY, SESSION_KEY_PLAN | from authentik.flows.views.executor import SESSION_KEY_HISTORY, SESSION_KEY_PLAN | ||||||
| from authentik.root.install_id import get_install_id | from authentik.root.install_id import get_install_id | ||||||
|  |  | ||||||
|  | MIN_FLOW_LENGTH = 2 | ||||||
|  |  | ||||||
|  |  | ||||||
| class FlowInspectorPlanSerializer(PassiveSerializer): | class FlowInspectorPlanSerializer(PassiveSerializer): | ||||||
|     """Serializer for an active FlowPlan""" |     """Serializer for an active FlowPlan""" | ||||||
| @ -41,7 +43,7 @@ class FlowInspectorPlanSerializer(PassiveSerializer): | |||||||
|  |  | ||||||
|     def get_next_planned_stage(self, plan: FlowPlan) -> FlowStageBindingSerializer: |     def get_next_planned_stage(self, plan: FlowPlan) -> FlowStageBindingSerializer: | ||||||
|         """Get the next planned stage""" |         """Get the next planned stage""" | ||||||
|         if len(plan.bindings) < 2: |         if len(plan.bindings) < MIN_FLOW_LENGTH: | ||||||
|             return FlowStageBindingSerializer().data |             return FlowStageBindingSerializer().data | ||||||
|         return FlowStageBindingSerializer(instance=plan.bindings[1]).data |         return FlowStageBindingSerializer(instance=plan.bindings[1]).data | ||||||
|  |  | ||||||
| @ -49,7 +51,7 @@ class FlowInspectorPlanSerializer(PassiveSerializer): | |||||||
|         """Get the plan's context, sanitized""" |         """Get the plan's context, sanitized""" | ||||||
|         return sanitize_dict(plan.context) |         return sanitize_dict(plan.context) | ||||||
|  |  | ||||||
|     def get_session_id(self, plan: FlowPlan) -> str: |     def get_session_id(self, _plan: FlowPlan) -> str: | ||||||
|         """Get a unique session ID""" |         """Get a unique session ID""" | ||||||
|         request: Request = self.context["request"] |         request: Request = self.context["request"] | ||||||
|         return sha256( |         return sha256( | ||||||
|  | |||||||
| @ -3,11 +3,11 @@ | |||||||
| from base64 import b64encode | from base64 import b64encode | ||||||
| from functools import cache as funccache | from functools import cache as funccache | ||||||
| from hashlib import md5 | from hashlib import md5 | ||||||
| from typing import TYPE_CHECKING, Optional | from typing import TYPE_CHECKING | ||||||
| from urllib.parse import urlencode | from urllib.parse import urlencode | ||||||
|  |  | ||||||
| from django.core.cache import cache | from django.core.cache import cache | ||||||
| from django.http import HttpRequest | from django.http import HttpRequest, HttpResponseNotFound | ||||||
| from django.templatetags.static import static | from django.templatetags.static import static | ||||||
| from lxml import etree  # nosec | from lxml import etree  # nosec | ||||||
| from lxml.etree import Element, SubElement  # nosec | from lxml.etree import Element, SubElement  # nosec | ||||||
| @ -37,18 +37,18 @@ SVG_FONTS = [ | |||||||
| ] | ] | ||||||
|  |  | ||||||
|  |  | ||||||
| def avatar_mode_none(user: "User", mode: str) -> Optional[str]: | def avatar_mode_none(user: "User", mode: str) -> str | None: | ||||||
|     """No avatar""" |     """No avatar""" | ||||||
|     return DEFAULT_AVATAR |     return DEFAULT_AVATAR | ||||||
|  |  | ||||||
|  |  | ||||||
| def avatar_mode_attribute(user: "User", mode: str) -> Optional[str]: | def avatar_mode_attribute(user: "User", mode: str) -> str | None: | ||||||
|     """Avatars based on a user attribute""" |     """Avatars based on a user attribute""" | ||||||
|     avatar = get_path_from_dict(user.attributes, mode[11:], default=None) |     avatar = get_path_from_dict(user.attributes, mode[11:], default=None) | ||||||
|     return avatar |     return avatar | ||||||
|  |  | ||||||
|  |  | ||||||
| def avatar_mode_gravatar(user: "User", mode: str) -> Optional[str]: | def avatar_mode_gravatar(user: "User", mode: str) -> str | None: | ||||||
|     """Gravatar avatars""" |     """Gravatar avatars""" | ||||||
|     # gravatar uses md5 for their URLs, so md5 can't be avoided |     # gravatar uses md5 for their URLs, so md5 can't be avoided | ||||||
|     mail_hash = md5(user.email.lower().encode("utf-8")).hexdigest()  # nosec |     mail_hash = md5(user.email.lower().encode("utf-8")).hexdigest()  # nosec | ||||||
| @ -65,7 +65,7 @@ def avatar_mode_gravatar(user: "User", mode: str) -> Optional[str]: | |||||||
|         # (HEAD since we don't need the body) |         # (HEAD since we don't need the body) | ||||||
|         # so if that returns a 404, move onto the next mode |         # so if that returns a 404, move onto the next mode | ||||||
|         res = get_http_session().head(gravatar_url, timeout=5) |         res = get_http_session().head(gravatar_url, timeout=5) | ||||||
|         if res.status_code == 404: |         if res.status_code == HttpResponseNotFound.status_code: | ||||||
|             cache.set(full_key, None) |             cache.set(full_key, None) | ||||||
|             return None |             return None | ||||||
|         res.raise_for_status() |         res.raise_for_status() | ||||||
| @ -86,12 +86,13 @@ def generate_colors(text: str) -> tuple[str, str]: | |||||||
|     red = min(max((color >> 16) & 0xFF, 55), 200) |     red = min(max((color >> 16) & 0xFF, 55), 200) | ||||||
|     bg_hex = f"{red:02x}{green:02x}{blue:02x}" |     bg_hex = f"{red:02x}{green:02x}{blue:02x}" | ||||||
|     # Contrasting text color (https://stackoverflow.com/a/3943023) |     # Contrasting text color (https://stackoverflow.com/a/3943023) | ||||||
|     text_hex = "000" if (red * 0.299 + green * 0.587 + blue * 0.114) > 186 else "fff" |     text_hex = ( | ||||||
|  |         "000" if (red * 0.299 + green * 0.587 + blue * 0.114) > 186 else "fff"  # noqa: PLR2004 | ||||||
|  |     ) | ||||||
|     return bg_hex, text_hex |     return bg_hex, text_hex | ||||||
|  |  | ||||||
|  |  | ||||||
| @funccache | @funccache | ||||||
| # pylint: disable=too-many-arguments,too-many-locals |  | ||||||
| def generate_avatar_from_name( | def generate_avatar_from_name( | ||||||
|     name: str, |     name: str, | ||||||
|     length: int = 2, |     length: int = 2, | ||||||
| @ -107,7 +108,7 @@ def generate_avatar_from_name( | |||||||
|     """ |     """ | ||||||
|     name_parts = name.split() |     name_parts = name.split() | ||||||
|     # Only abbreviate first and last name |     # Only abbreviate first and last name | ||||||
|     if len(name_parts) > 2: |     if len(name_parts) > 2:  # noqa: PLR2004 | ||||||
|         name_parts = [name_parts[0], name_parts[-1]] |         name_parts = [name_parts[0], name_parts[-1]] | ||||||
|  |  | ||||||
|     if len(name_parts) == 1: |     if len(name_parts) == 1: | ||||||
| @ -155,7 +156,7 @@ def generate_avatar_from_name( | |||||||
|     return etree.tostring(root_element).decode() |     return etree.tostring(root_element).decode() | ||||||
|  |  | ||||||
|  |  | ||||||
| def avatar_mode_generated(user: "User", mode: str) -> Optional[str]: | def avatar_mode_generated(user: "User", mode: str) -> str | None: | ||||||
|     """Wrapper that converts generated avatar to base64 svg""" |     """Wrapper that converts generated avatar to base64 svg""" | ||||||
|     # By default generate based off of user's display name |     # By default generate based off of user's display name | ||||||
|     name = user.name.strip() |     name = user.name.strip() | ||||||
| @ -169,7 +170,7 @@ def avatar_mode_generated(user: "User", mode: str) -> Optional[str]: | |||||||
|     return f"data:image/svg+xml;base64,{b64encode(svg.encode('utf-8')).decode('utf-8')}" |     return f"data:image/svg+xml;base64,{b64encode(svg.encode('utf-8')).decode('utf-8')}" | ||||||
|  |  | ||||||
|  |  | ||||||
| def avatar_mode_url(user: "User", mode: str) -> Optional[str]: | def avatar_mode_url(user: "User", mode: str) -> str | None: | ||||||
|     """Format url""" |     """Format url""" | ||||||
|     mail_hash = md5(user.email.lower().encode("utf-8")).hexdigest()  # nosec |     mail_hash = md5(user.email.lower().encode("utf-8")).hexdigest()  # nosec | ||||||
|     return mode % { |     return mode % { | ||||||
| @ -179,7 +180,7 @@ def avatar_mode_url(user: "User", mode: str) -> Optional[str]: | |||||||
|     } |     } | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_avatar(user: "User", request: Optional[HttpRequest] = None) -> str: | def get_avatar(user: "User", request: HttpRequest | None = None) -> str: | ||||||
|     """Get avatar with configured mode""" |     """Get avatar with configured mode""" | ||||||
|     mode_map = { |     mode_map = { | ||||||
|         "none": avatar_mode_none, |         "none": avatar_mode_none, | ||||||
|  | |||||||
| @ -13,7 +13,7 @@ from json.decoder import JSONDecodeError | |||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from sys import argv, stderr | from sys import argv, stderr | ||||||
| from time import time | from time import time | ||||||
| from typing import Any, Optional | from typing import Any | ||||||
| from urllib.parse import urlparse | from urllib.parse import urlparse | ||||||
|  |  | ||||||
| import yaml | import yaml | ||||||
| @ -89,7 +89,7 @@ class Attr: | |||||||
|  |  | ||||||
|     # depending on source_type, might contain the environment variable or the path |     # depending on source_type, might contain the environment variable or the path | ||||||
|     # to the config file containing this change or the file containing this value |     # to the config file containing this change or the file containing this value | ||||||
|     source: Optional[str] = field(default=None) |     source: str | None = field(default=None) | ||||||
|  |  | ||||||
|     def __post_init__(self): |     def __post_init__(self): | ||||||
|         if isinstance(self.value, Attr): |         if isinstance(self.value, Attr): | ||||||
| @ -190,16 +190,18 @@ class ConfigLoader: | |||||||
|  |  | ||||||
|     def update(self, root: dict[str, Any], updatee: dict[str, Any]) -> dict[str, Any]: |     def update(self, root: dict[str, Any], updatee: dict[str, Any]) -> dict[str, Any]: | ||||||
|         """Recursively update dictionary""" |         """Recursively update dictionary""" | ||||||
|         for key, value in updatee.items(): |         for key, raw_value in updatee.items(): | ||||||
|             if isinstance(value, Mapping): |             if isinstance(raw_value, Mapping): | ||||||
|                 root[key] = self.update(root.get(key, {}), value) |                 root[key] = self.update(root.get(key, {}), raw_value) | ||||||
|             else: |             else: | ||||||
|                 if isinstance(value, str): |                 if isinstance(raw_value, str): | ||||||
|                     value = self.parse_uri(value) |                     value = self.parse_uri(raw_value) | ||||||
|                 elif isinstance(value, Attr) and isinstance(value.value, str): |                 elif isinstance(raw_value, Attr) and isinstance(raw_value.value, str): | ||||||
|                     value = self.parse_uri(value.value) |                     value = self.parse_uri(raw_value.value) | ||||||
|                 elif not isinstance(value, Attr): |                 elif not isinstance(raw_value, Attr): | ||||||
|                     value = Attr(value) |                     value = Attr(raw_value) | ||||||
|  |                 else: | ||||||
|  |                     value = raw_value | ||||||
|                 root[key] = value |                 root[key] = value | ||||||
|         return root |         return root | ||||||
|  |  | ||||||
| @ -219,7 +221,7 @@ class ConfigLoader: | |||||||
|             parsed_value = os.getenv(url.netloc, url.query) |             parsed_value = os.getenv(url.netloc, url.query) | ||||||
|         if url.scheme == "file": |         if url.scheme == "file": | ||||||
|             try: |             try: | ||||||
|                 with open(url.path, "r", encoding="utf8") as _file: |                 with open(url.path, encoding="utf8") as _file: | ||||||
|                     parsed_value = _file.read().strip() |                     parsed_value = _file.read().strip() | ||||||
|             except OSError as exc: |             except OSError as exc: | ||||||
|                 self.log("error", f"Failed to read config value from {url.path}: {exc}") |                 self.log("error", f"Failed to read config value from {url.path}: {exc}") | ||||||
| @ -257,7 +259,7 @@ class ConfigLoader: | |||||||
|             relative_key = key.replace(f"{ENV_PREFIX}_", "", 1).replace("__", ".").lower() |             relative_key = key.replace(f"{ENV_PREFIX}_", "", 1).replace("__", ".").lower() | ||||||
|             # Check if the value is json, and try to load it |             # Check if the value is json, and try to load it | ||||||
|             try: |             try: | ||||||
|                 value = loads(value) |                 value = loads(value)  # noqa: PLW2901 | ||||||
|             except JSONDecodeError: |             except JSONDecodeError: | ||||||
|                 pass |                 pass | ||||||
|             attr_value = Attr(value, Attr.Source.ENV, relative_key) |             attr_value = Attr(value, Attr.Source.ENV, relative_key) | ||||||
| @ -330,7 +332,7 @@ CONFIG = ConfigLoader() | |||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     if len(argv) < 2: |     if len(argv) < 2:  # noqa: PLR2004 | ||||||
|         print(dumps(CONFIG.raw, indent=4, cls=AttrEncoder)) |         print(dumps(CONFIG.raw, indent=4, cls=AttrEncoder)) | ||||||
|     else: |     else: | ||||||
|         print(CONFIG.get(argv[1])) |         print(CONFIG.get(argv[1])) | ||||||
|  | |||||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user
	