diff --git a/.github/workflows/ci-main.yml b/.github/workflows/ci-main.yml index 96a235c4e2..535cd6efc6 100644 --- a/.github/workflows/ci-main.yml +++ b/.github/workflows/ci-main.yml @@ -28,10 +28,7 @@ jobs: - bandit - black - codespell - - isort - pending-migrations - # - pylint - - pyright - ruff runs-on: ubuntu-latest steps: diff --git a/.vscode/extensions.json b/.vscode/extensions.json index dd280d682e..53a07a56d2 100644 --- a/.vscode/extensions.json +++ b/.vscode/extensions.json @@ -10,8 +10,7 @@ "Gruntfuggly.todo-tree", "mechatroner.rainbow-csv", "ms-python.black-formatter", - "ms-python.isort", - "ms-python.pylint", + "charliermarsh.ruff", "ms-python.python", "ms-python.vscode-pylance", "ms-python.black-formatter", diff --git a/Makefile b/Makefile index 7c6b2063ef..9f56b19d45 100644 --- a/Makefile +++ b/Makefile @@ -59,15 +59,12 @@ test: ## Run the server tests and produce a coverage report (locally) coverage report lint-fix: ## Lint and automatically fix errors in the python source code. Reports spelling errors. - isort $(PY_SOURCES) black $(PY_SOURCES) - ruff --fix $(PY_SOURCES) + ruff check --fix $(PY_SOURCES) codespell -w $(CODESPELL_ARGS) lint: ## Lint the python and golang sources bandit -r $(PY_SOURCES) -x node_modules - ./web/node_modules/.bin/pyright $(PY_SOURCES) - pylint $(PY_SOURCES) golangci-lint run -v core-install: @@ -249,9 +246,6 @@ ci--meta-debug: python -V node --version -ci-pylint: ci--meta-debug - pylint $(PY_SOURCES) - ci-black: ci--meta-debug black --check $(PY_SOURCES) @@ -261,14 +255,8 @@ ci-ruff: ci--meta-debug ci-codespell: ci--meta-debug codespell $(CODESPELL_ARGS) -s -ci-isort: ci--meta-debug - isort --check $(PY_SOURCES) - ci-bandit: ci--meta-debug bandit -r $(PY_SOURCES) -ci-pyright: ci--meta-debug - ./web/node_modules/.bin/pyright $(PY_SOURCES) - ci-pending-migrations: ci--meta-debug ak makemigrations --check diff --git a/authentik/__init__.py b/authentik/__init__.py index 8f5c2da79f..1f78a84baf 100644 --- a/authentik/__init__.py +++ b/authentik/__init__.py @@ -1,13 +1,12 @@ """authentik root module""" from os import environ -from typing import Optional __version__ = "2024.2.1" ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" -def get_build_hash(fallback: Optional[str] = None) -> str: +def get_build_hash(fallback: str | None = None) -> str: """Get build hash""" build_hash = environ.get(ENV_GIT_HASH_KEY, fallback if fallback else "") return fallback if build_hash == "" and fallback else build_hash diff --git a/authentik/api/apps.py b/authentik/api/apps.py index 7b91d99a9e..5acf3e29c1 100644 --- a/authentik/api/apps.py +++ b/authentik/api/apps.py @@ -18,7 +18,7 @@ class AuthentikAPIConfig(AppConfig): # Class is defined here as it needs to be created early enough that drf-spectacular will # find it, but also won't cause any import issues - # pylint: disable=unused-variable + class TokenSchema(OpenApiAuthenticationExtension): """Auth schema""" diff --git a/authentik/api/authentication.py b/authentik/api/authentication.py index 04edc73580..cbc606a5a5 100644 --- a/authentik/api/authentication.py +++ b/authentik/api/authentication.py @@ -1,7 +1,7 @@ """API Authentication""" from hmac import compare_digest -from typing import Any, Optional +from typing import Any from django.conf import settings from rest_framework.authentication import BaseAuthentication, get_authorization_header @@ -17,7 +17,7 @@ from authentik.providers.oauth2.constants import SCOPE_AUTHENTIK_API LOGGER = get_logger() -def validate_auth(header: bytes) -> Optional[str]: +def validate_auth(header: bytes) -> str | None: """Validate that the header is in a correct format, returns type and credentials""" auth_credentials = header.decode().strip() @@ -32,7 +32,7 @@ def validate_auth(header: bytes) -> Optional[str]: return auth_credentials -def bearer_auth(raw_header: bytes) -> Optional[User]: +def bearer_auth(raw_header: bytes) -> User | None: """raw_header in the Format of `Bearer ....`""" user = auth_user_lookup(raw_header) if not user: @@ -42,7 +42,7 @@ def bearer_auth(raw_header: bytes) -> Optional[User]: return user -def auth_user_lookup(raw_header: bytes) -> Optional[User]: +def auth_user_lookup(raw_header: bytes) -> User | None: """raw_header in the Format of `Bearer ....`""" from authentik.providers.oauth2.models import AccessToken @@ -75,7 +75,7 @@ def auth_user_lookup(raw_header: bytes) -> Optional[User]: raise AuthenticationFailed("Token invalid/expired") -def token_secret_key(value: str) -> Optional[User]: +def token_secret_key(value: str) -> User | None: """Check if the token is the secret key and return the service account for the managed outpost""" from authentik.outposts.apps import MANAGED_OUTPOST diff --git a/authentik/api/tests/test_auth.py b/authentik/api/tests/test_auth.py index 7978236a40..f449b88b60 100644 --- a/authentik/api/tests/test_auth.py +++ b/authentik/api/tests/test_auth.py @@ -25,17 +25,17 @@ class TestAPIAuth(TestCase): def test_invalid_type(self): """Test invalid type""" with self.assertRaises(AuthenticationFailed): - bearer_auth("foo bar".encode()) + bearer_auth(b"foo bar") def test_invalid_empty(self): """Test invalid type""" - self.assertIsNone(bearer_auth("Bearer ".encode())) - self.assertIsNone(bearer_auth("".encode())) + self.assertIsNone(bearer_auth(b"Bearer ")) + self.assertIsNone(bearer_auth(b"")) def test_invalid_no_token(self): """Test invalid with no token""" with self.assertRaises(AuthenticationFailed): - auth = b64encode(":abc".encode()).decode() + auth = b64encode(b":abc").decode() self.assertIsNone(bearer_auth(f"Basic :{auth}".encode())) def test_bearer_valid(self): diff --git a/authentik/api/tests/test_viewsets.py b/authentik/api/tests/test_viewsets.py index 08cc457c49..da16e801c9 100644 --- a/authentik/api/tests/test_viewsets.py +++ b/authentik/api/tests/test_viewsets.py @@ -1,6 +1,6 @@ """authentik API Modelviewset tests""" -from typing import Callable +from collections.abc import Callable from django.test import TestCase from rest_framework.viewsets import ModelViewSet, ReadOnlyModelViewSet @@ -26,6 +26,6 @@ def viewset_tester_factory(test_viewset: type[ModelViewSet]) -> Callable: for _, viewset, _ in router.registry: - if not issubclass(viewset, (ModelViewSet, ReadOnlyModelViewSet)): + if not issubclass(viewset, ModelViewSet | ReadOnlyModelViewSet): continue setattr(TestModelViewSets, f"test_viewset_{viewset.__name__}", viewset_tester_factory(viewset)) diff --git a/authentik/api/v3/urls.py b/authentik/api/v3/urls.py index e56b8510bd..735197dbd1 100644 --- a/authentik/api/v3/urls.py +++ b/authentik/api/v3/urls.py @@ -33,7 +33,7 @@ for _authentik_app in get_apps(): app_name=_authentik_app.name, ) continue - urls: list = getattr(api_urls, "api_urlpatterns") + urls: list = api_urls.api_urlpatterns for url in urls: if isinstance(url, URLPattern): _other_urls.append(url) diff --git a/authentik/blueprints/api.py b/authentik/blueprints/api.py index c15ff8f37e..993121e705 100644 --- a/authentik/blueprints/api.py +++ b/authentik/blueprints/api.py @@ -52,7 +52,9 @@ class BlueprintInstanceSerializer(ModelSerializer): valid, logs = Importer.from_string(content, context).validate() if not valid: text_logs = "\n".join([x["event"] for x in logs]) - raise ValidationError(_("Failed to validate blueprint: %(logs)s" % {"logs": text_logs})) + raise ValidationError( + _("Failed to validate blueprint: {logs}".format_map({"logs": text_logs})) + ) return content def validate(self, attrs: dict) -> dict: diff --git a/authentik/blueprints/apps.py b/authentik/blueprints/apps.py index 9a9fb53dfb..9d7718916d 100644 --- a/authentik/blueprints/apps.py +++ b/authentik/blueprints/apps.py @@ -1,8 +1,8 @@ """authentik Blueprints app""" +from collections.abc import Callable from importlib import import_module from inspect import ismethod -from typing import Callable from django.apps import AppConfig from django.db import DatabaseError, InternalError, ProgrammingError @@ -66,13 +66,13 @@ class ManagedAppConfig(AppConfig): @staticmethod def reconcile_tenant(func: Callable): """Mark a function to be called on startup (for each tenant)""" - setattr(func, "_authentik_managed_reconcile", ManagedAppConfig.RECONCILE_TENANT_CATEGORY) + func._authentik_managed_reconcile = ManagedAppConfig.RECONCILE_TENANT_CATEGORY return func @staticmethod def reconcile_global(func: Callable): """Mark a function to be called on startup (globally)""" - setattr(func, "_authentik_managed_reconcile", ManagedAppConfig.RECONCILE_GLOBAL_CATEGORY) + func._authentik_managed_reconcile = ManagedAppConfig.RECONCILE_GLOBAL_CATEGORY return func def _reconcile_tenant(self) -> None: diff --git a/authentik/blueprints/models.py b/authentik/blueprints/models.py index f1679866ee..a3abcba59f 100644 --- a/authentik/blueprints/models.py +++ b/authentik/blueprints/models.py @@ -71,6 +71,19 @@ class BlueprintInstance(SerializerModel, ManagedModel, CreatedUpdatedModel): enabled = models.BooleanField(default=True) managed_models = ArrayField(models.TextField(), default=list) + class Meta: + verbose_name = _("Blueprint Instance") + verbose_name_plural = _("Blueprint Instances") + unique_together = ( + ( + "name", + "path", + ), + ) + + def __str__(self) -> str: + return f"Blueprint Instance {self.name}" + def retrieve_oci(self) -> str: """Get blueprint from an OCI registry""" client = BlueprintOCIClient(self.path.replace(OCI_PREFIX, "https://")) @@ -89,7 +102,7 @@ class BlueprintInstance(SerializerModel, ManagedModel, CreatedUpdatedModel): raise BlueprintRetrievalFailed("Invalid blueprint path") with full_path.open("r", encoding="utf-8") as _file: return _file.read() - except (IOError, OSError) as exc: + except OSError as exc: raise BlueprintRetrievalFailed(exc) from exc def retrieve(self) -> str: @@ -105,16 +118,3 @@ class BlueprintInstance(SerializerModel, ManagedModel, CreatedUpdatedModel): from authentik.blueprints.api import BlueprintInstanceSerializer return BlueprintInstanceSerializer - - def __str__(self) -> str: - return f"Blueprint Instance {self.name}" - - class Meta: - verbose_name = _("Blueprint Instance") - verbose_name_plural = _("Blueprint Instances") - unique_together = ( - ( - "name", - "path", - ), - ) diff --git a/authentik/blueprints/tests/__init__.py b/authentik/blueprints/tests/__init__.py index ab8e780310..1dd22a453d 100644 --- a/authentik/blueprints/tests/__init__.py +++ b/authentik/blueprints/tests/__init__.py @@ -1,7 +1,7 @@ """Blueprint helpers""" +from collections.abc import Callable from functools import wraps -from typing import Callable from django.apps import apps diff --git a/authentik/blueprints/tests/test_packaged.py b/authentik/blueprints/tests/test_packaged.py index 9e1111c321..443173bac2 100644 --- a/authentik/blueprints/tests/test_packaged.py +++ b/authentik/blueprints/tests/test_packaged.py @@ -1,7 +1,7 @@ """test packaged blueprints""" +from collections.abc import Callable from pathlib import Path -from typing import Callable from django.test import TransactionTestCase diff --git a/authentik/blueprints/tests/test_serializer_models.py b/authentik/blueprints/tests/test_serializer_models.py index 8bff9994ed..a3ae6005aa 100644 --- a/authentik/blueprints/tests/test_serializer_models.py +++ b/authentik/blueprints/tests/test_serializer_models.py @@ -1,6 +1,6 @@ """authentik managed models tests""" -from typing import Callable, Type +from collections.abc import Callable from django.apps import apps from django.test import TestCase @@ -14,7 +14,7 @@ class TestModels(TestCase): """Test Models""" -def serializer_tester_factory(test_model: Type[SerializerModel]) -> Callable: +def serializer_tester_factory(test_model: type[SerializerModel]) -> Callable: """Test serializer""" def tester(self: TestModels): diff --git a/authentik/blueprints/tests/test_v1_tasks.py b/authentik/blueprints/tests/test_v1_tasks.py index fefdb62e25..b1d201419d 100644 --- a/authentik/blueprints/tests/test_v1_tasks.py +++ b/authentik/blueprints/tests/test_v1_tasks.py @@ -54,7 +54,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase): file.seek(0) file_hash = sha512(file.read().encode()).hexdigest() file.flush() - blueprints_discovery() # pylint: disable=no-value-for-parameter + blueprints_discovery() instance = BlueprintInstance.objects.filter(name=blueprint_id).first() self.assertEqual(instance.last_applied_hash, file_hash) self.assertEqual( @@ -82,7 +82,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase): ) ) file.flush() - blueprints_discovery() # pylint: disable=no-value-for-parameter + blueprints_discovery() blueprint = BlueprintInstance.objects.filter(name="foo").first() self.assertEqual( blueprint.last_applied_hash, @@ -107,7 +107,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase): ) ) file.flush() - blueprints_discovery() # pylint: disable=no-value-for-parameter + blueprints_discovery() blueprint.refresh_from_db() self.assertEqual( blueprint.last_applied_hash, @@ -149,7 +149,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase): instance.status, BlueprintInstanceStatus.UNKNOWN, ) - apply_blueprint(instance.pk) # pylint: disable=no-value-for-parameter + apply_blueprint(instance.pk) instance.refresh_from_db() self.assertEqual(instance.last_applied_hash, "") self.assertEqual( diff --git a/authentik/blueprints/v1/common.py b/authentik/blueprints/v1/common.py index 2199b3ebec..4d66143281 100644 --- a/authentik/blueprints/v1/common.py +++ b/authentik/blueprints/v1/common.py @@ -1,13 +1,14 @@ """transfer common classes""" from collections import OrderedDict +from collections.abc import Iterable, Mapping from copy import copy from dataclasses import asdict, dataclass, field, is_dataclass from enum import Enum from functools import reduce from operator import ixor from os import getenv -from typing import Any, Iterable, Literal, Mapping, Optional, Union +from typing import Any, Literal, Union from uuid import UUID from deepmerge import always_merger @@ -45,7 +46,7 @@ def get_attrs(obj: SerializerModel) -> dict[str, Any]: class BlueprintEntryState: """State of a single instance""" - instance: Optional[Model] = None + instance: Model | None = None class BlueprintEntryDesiredState(Enum): @@ -67,9 +68,9 @@ class BlueprintEntry: ) conditions: list[Any] = field(default_factory=list) identifiers: dict[str, Any] = field(default_factory=dict) - attrs: Optional[dict[str, Any]] = field(default_factory=dict) + attrs: dict[str, Any] | None = field(default_factory=dict) - id: Optional[str] = None + id: str | None = None _state: BlueprintEntryState = field(default_factory=BlueprintEntryState) @@ -92,10 +93,10 @@ class BlueprintEntry: attrs=all_attrs, ) - def _get_tag_context( + def get_tag_context( self, depth: int = 0, - context_tag_type: Optional[type["YAMLTagContext"] | tuple["YAMLTagContext", ...]] = None, + context_tag_type: type["YAMLTagContext"] | tuple["YAMLTagContext", ...] | None = None, ) -> "YAMLTagContext": """Get a YAMLTagContext object located at a certain depth in the tag tree""" if depth < 0: @@ -108,8 +109,8 @@ class BlueprintEntry: try: return contexts[-(depth + 1)] - except IndexError: - raise ValueError(f"invalid depth: {depth}. Max depth: {len(contexts) - 1}") + except IndexError as exc: + raise ValueError(f"invalid depth: {depth}. Max depth: {len(contexts) - 1}") from exc def tag_resolver(self, value: Any, blueprint: "Blueprint") -> Any: """Check if we have any special tags that need handling""" @@ -170,7 +171,7 @@ class Blueprint: entries: list[BlueprintEntry] = field(default_factory=list) context: dict = field(default_factory=dict) - metadata: Optional[BlueprintMetadata] = field(default=None) + metadata: BlueprintMetadata | None = field(default=None) class YAMLTag: @@ -218,7 +219,7 @@ class Env(YAMLTag): """Lookup environment variable with optional default""" key: str - default: Optional[Any] + default: Any | None def __init__(self, loader: "BlueprintLoader", node: ScalarNode | SequenceNode) -> None: super().__init__() @@ -237,7 +238,7 @@ class Context(YAMLTag): """Lookup key from instance context""" key: str - default: Optional[Any] + default: Any | None def __init__(self, loader: "BlueprintLoader", node: ScalarNode | SequenceNode) -> None: super().__init__() @@ -281,7 +282,7 @@ class Format(YAMLTag): try: return self.format_string % tuple(args) except TypeError as exc: - raise EntryInvalidError.from_entry(exc, entry) + raise EntryInvalidError.from_entry(exc, entry) from exc class Find(YAMLTag): @@ -366,7 +367,7 @@ class Condition(YAMLTag): comparator = self._COMPARATORS[self.mode.upper()] return comparator(tuple(bool(x) for x in args)) except (TypeError, KeyError) as exc: - raise EntryInvalidError.from_entry(exc, entry) + raise EntryInvalidError.from_entry(exc, entry) from exc class If(YAMLTag): @@ -398,7 +399,7 @@ class If(YAMLTag): blueprint, ) except TypeError as exc: - raise EntryInvalidError.from_entry(exc, entry) + raise EntryInvalidError.from_entry(exc, entry) from exc class Enumerate(YAMLTag, YAMLTagContext): @@ -412,9 +413,7 @@ class Enumerate(YAMLTag, YAMLTagContext): "SEQ": (list, lambda a, b: [*a, b]), "MAP": ( dict, - lambda a, b: always_merger.merge( - a, {b[0]: b[1]} if isinstance(b, (tuple, list)) else b - ), + lambda a, b: always_merger.merge(a, {b[0]: b[1]} if isinstance(b, tuple | list) else b), ), } @@ -456,7 +455,7 @@ class Enumerate(YAMLTag, YAMLTagContext): try: output_class, add_fn = self._OUTPUT_BODIES[self.output_body.upper()] except KeyError as exc: - raise EntryInvalidError.from_entry(exc, entry) + raise EntryInvalidError.from_entry(exc, entry) from exc result = output_class() @@ -484,13 +483,13 @@ class EnumeratedItem(YAMLTag): _SUPPORTED_CONTEXT_TAGS = (Enumerate,) - def __init__(self, loader: "BlueprintLoader", node: ScalarNode) -> None: + def __init__(self, _loader: "BlueprintLoader", node: ScalarNode) -> None: super().__init__() self.depth = int(node.value) def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any: try: - context_tag: Enumerate = entry._get_tag_context( + context_tag: Enumerate = entry.get_tag_context( depth=self.depth, context_tag_type=EnumeratedItem._SUPPORTED_CONTEXT_TAGS, ) @@ -500,9 +499,11 @@ class EnumeratedItem(YAMLTag): f"{self.__class__.__name__} tags are only usable " f"inside an {Enumerate.__name__} tag", entry, - ) + ) from exc - raise EntryInvalidError.from_entry(f"{self.__class__.__name__} tag: {exc}", entry) + raise EntryInvalidError.from_entry( + f"{self.__class__.__name__} tag: {exc}", entry + ) from exc return context_tag.get_context(entry, blueprint) @@ -515,8 +516,8 @@ class Index(EnumeratedItem): try: return context[0] - except IndexError: # pragma: no cover - raise EntryInvalidError.from_entry(f"Empty/invalid context: {context}", entry) + except IndexError as exc: # pragma: no cover + raise EntryInvalidError.from_entry(f"Empty/invalid context: {context}", entry) from exc class Value(EnumeratedItem): @@ -527,8 +528,8 @@ class Value(EnumeratedItem): try: return context[1] - except IndexError: # pragma: no cover - raise EntryInvalidError.from_entry(f"Empty/invalid context: {context}", entry) + except IndexError as exc: # pragma: no cover + raise EntryInvalidError.from_entry(f"Empty/invalid context: {context}", entry) from exc class BlueprintDumper(SafeDumper): @@ -582,13 +583,13 @@ class BlueprintLoader(SafeLoader): class EntryInvalidError(SentryIgnoredException): """Error raised when an entry is invalid""" - entry_model: Optional[str] - entry_id: Optional[str] - validation_error: Optional[ValidationError] - serializer: Optional[Serializer] = None + entry_model: str | None + entry_id: str | None + validation_error: ValidationError | None + serializer: Serializer | None = None def __init__( - self, *args: object, validation_error: Optional[ValidationError] = None, **kwargs + self, *args: object, validation_error: ValidationError | None = None, **kwargs ) -> None: super().__init__(*args) self.entry_model = None diff --git a/authentik/blueprints/v1/exporter.py b/authentik/blueprints/v1/exporter.py index f4712ee533..89ed20be04 100644 --- a/authentik/blueprints/v1/exporter.py +++ b/authentik/blueprints/v1/exporter.py @@ -1,6 +1,6 @@ """Blueprint exporter""" -from typing import Iterable +from collections.abc import Iterable from uuid import UUID from django.apps import apps @@ -59,7 +59,7 @@ class Exporter: blueprint = Blueprint() self._pre_export(blueprint) blueprint.metadata = BlueprintMetadata( - name=_("authentik Export - %(date)s" % {"date": str(now())}), + name=_("authentik Export - {date}".format_map({"date": str(now())})), labels={ LABEL_AUTHENTIK_GENERATED: "true", }, diff --git a/authentik/blueprints/v1/importer.py b/authentik/blueprints/v1/importer.py index 21127ec00e..ff1138d6fe 100644 --- a/authentik/blueprints/v1/importer.py +++ b/authentik/blueprints/v1/importer.py @@ -2,7 +2,7 @@ from contextlib import contextmanager from copy import deepcopy -from typing import Any, Optional +from typing import Any from dacite.config import Config from dacite.core import from_dict @@ -62,7 +62,7 @@ SERIALIZER_CONTEXT_BLUEPRINT = "blueprint_entry" def excluded_models() -> list[type[Model]]: """Return a list of all excluded models that shouldn't be exposed via API or other means (internal only, base classes, non-used objects, etc)""" - # pylint: disable=imported-auth-user + from django.contrib.auth.models import Group as DjangoGroup from django.contrib.auth.models import User as DjangoUser @@ -101,7 +101,7 @@ def excluded_models() -> list[type[Model]]: def is_model_allowed(model: type[Model]) -> bool: """Check if model is allowed""" - return model not in excluded_models() and issubclass(model, (SerializerModel, BaseMetaModel)) + return model not in excluded_models() and issubclass(model, SerializerModel | BaseMetaModel) class DoRollback(SentryIgnoredException): @@ -125,7 +125,7 @@ class Importer: logger: BoundLogger _import: Blueprint - def __init__(self, blueprint: Blueprint, context: Optional[dict] = None): + def __init__(self, blueprint: Blueprint, context: dict | None = None): self.__pk_map: dict[Any, Model] = {} self._import = blueprint self.logger = get_logger() @@ -168,7 +168,7 @@ class Importer: for key, value in attrs.items(): try: if isinstance(value, dict): - for idx, _inner_key in enumerate(value): + for _, _inner_key in enumerate(value): value[_inner_key] = updater(value[_inner_key]) elif isinstance(value, list): for idx, _inner_value in enumerate(value): @@ -197,8 +197,7 @@ class Importer: return main_query | sub_query - # pylint: disable-msg=too-many-locals - def _validate_single(self, entry: BlueprintEntry) -> Optional[BaseSerializer]: + def _validate_single(self, entry: BlueprintEntry) -> BaseSerializer | None: """Validate a single entry""" if not entry.check_all_conditions_match(self._import): self.logger.debug("One or more conditions of this entry are not fulfilled, skipping") @@ -369,7 +368,7 @@ class Importer: self.__pk_map[entry.identifiers["pk"]] = instance.pk entry._state = BlueprintEntryState(instance) elif state == BlueprintEntryDesiredState.ABSENT: - instance: Optional[Model] = serializer.instance + instance: Model | None = serializer.instance if instance.pk: instance.delete() self.logger.debug("deleted model", mode=instance) diff --git a/authentik/blueprints/v1/meta/apply_blueprint.py b/authentik/blueprints/v1/meta/apply_blueprint.py index f3c5aa4704..abd593c045 100644 --- a/authentik/blueprints/v1/meta/apply_blueprint.py +++ b/authentik/blueprints/v1/meta/apply_blueprint.py @@ -43,7 +43,7 @@ class ApplyBlueprintMetaSerializer(PassiveSerializer): LOGGER.info("Blueprint does not exist, but not required") return MetaResult() LOGGER.debug("Applying blueprint from meta model", blueprint=self.blueprint_instance) - # pylint: disable=no-value-for-parameter + apply_blueprint(str(self.blueprint_instance.pk)) return MetaResult() diff --git a/authentik/blueprints/v1/meta/registry.py b/authentik/blueprints/v1/meta/registry.py index 588f03e806..7d64815914 100644 --- a/authentik/blueprints/v1/meta/registry.py +++ b/authentik/blueprints/v1/meta/registry.py @@ -8,15 +8,15 @@ from rest_framework.serializers import Serializer class BaseMetaModel(Model): """Base models""" + class Meta: + abstract = True + @staticmethod def serializer() -> Serializer: """Serializer similar to SerializerModel, but as a static method since this is an abstract model""" raise NotImplementedError - class Meta: - abstract = True - class MetaResult: """Result returned by Meta Models' serializers. Empty class but we can't return none as diff --git a/authentik/blueprints/v1/tasks.py b/authentik/blueprints/v1/tasks.py index 6d12c28d1d..058cb4a11c 100644 --- a/authentik/blueprints/v1/tasks.py +++ b/authentik/blueprints/v1/tasks.py @@ -4,7 +4,6 @@ from dataclasses import asdict, dataclass, field from hashlib import sha512 from pathlib import Path from sys import platform -from typing import Optional from dacite.core import from_dict from django.db import DatabaseError, InternalError, ProgrammingError @@ -50,14 +49,14 @@ class BlueprintFile: version: int hash: str last_m: int - meta: Optional[BlueprintMetadata] = field(default=None) + meta: BlueprintMetadata | None = field(default=None) def start_blueprint_watcher(): """Start blueprint watcher, if it's not running already.""" # This function might be called twice since it's called on celery startup - # pylint: disable=global-statement - global _file_watcher_started + + global _file_watcher_started # noqa: PLW0603 if _file_watcher_started: return observer = Observer() @@ -126,7 +125,7 @@ def blueprints_find() -> list[BlueprintFile]: # Check if any part in the path starts with a dot and assume a hidden file if any(part for part in path.parts if part.startswith(".")): continue - with open(path, "r", encoding="utf-8") as blueprint_file: + with open(path, encoding="utf-8") as blueprint_file: try: raw_blueprint = load(blueprint_file.read(), BlueprintLoader) except YAMLError as exc: @@ -150,7 +149,7 @@ def blueprints_find() -> list[BlueprintFile]: throws=(DatabaseError, ProgrammingError, InternalError), base=SystemTask, bind=True ) @prefill_task -def blueprints_discovery(self: SystemTask, path: Optional[str] = None): +def blueprints_discovery(self: SystemTask, path: str | None = None): """Find blueprints and check if they need to be created in the database""" count = 0 for blueprint in blueprints_find(): @@ -197,7 +196,7 @@ def check_blueprint_v1_file(blueprint: BlueprintFile): def apply_blueprint(self: SystemTask, instance_pk: str): """Apply single blueprint""" self.save_on_success = False - instance: Optional[BlueprintInstance] = None + instance: BlueprintInstance | None = None try: instance: BlueprintInstance = BlueprintInstance.objects.filter(pk=instance_pk).first() if not instance or not instance.enabled: @@ -225,10 +224,10 @@ def apply_blueprint(self: SystemTask, instance_pk: str): instance.last_applied = now() self.set_status(TaskStatus.SUCCESSFUL) except ( + OSError, DatabaseError, ProgrammingError, InternalError, - IOError, BlueprintRetrievalFailed, EntryInvalidError, ) as exc: diff --git a/authentik/brands/middleware.py b/authentik/brands/middleware.py index 3bf0ea9f52..71650cc621 100644 --- a/authentik/brands/middleware.py +++ b/authentik/brands/middleware.py @@ -1,6 +1,6 @@ """Inject brand into current request""" -from typing import Callable +from collections.abc import Callable from django.http.request import HttpRequest from django.http.response import HttpResponse @@ -20,7 +20,7 @@ class BrandMiddleware: def __call__(self, request: HttpRequest) -> HttpResponse: if not hasattr(request, "brand"): brand = get_brand_for_request(request) - setattr(request, "brand", brand) + request.brand = brand locale = brand.default_locale if locale != "": activate(locale) diff --git a/authentik/brands/models.py b/authentik/brands/models.py index 8d3cdc121b..dca4f724ad 100644 --- a/authentik/brands/models.py +++ b/authentik/brands/models.py @@ -71,7 +71,7 @@ class Brand(SerializerModel): """Get default locale""" try: return self.attributes.get("settings", {}).get("locale", "") - # pylint: disable=broad-except + except Exception as exc: LOGGER.warning("Failed to get default locale", exc=exc) return "" diff --git a/authentik/core/api/applications.py b/authentik/core/api/applications.py index 3a89bf28f5..9e766ceacc 100644 --- a/authentik/core/api/applications.py +++ b/authentik/core/api/applications.py @@ -1,8 +1,8 @@ """Application API Views""" +from collections.abc import Iterator from copy import copy from datetime import timedelta -from typing import Iterator, Optional from django.core.cache import cache from django.db.models import QuerySet @@ -60,7 +60,7 @@ class ApplicationSerializer(ModelSerializer): meta_icon = ReadOnlyField(source="get_meta_icon") - def get_launch_url(self, app: Application) -> Optional[str]: + def get_launch_url(self, app: Application) -> str | None: """Allow formatting of launch URL""" user = None if "request" in self.context: @@ -100,7 +100,6 @@ class ApplicationSerializer(ModelSerializer): class ApplicationViewSet(UsedByMixin, ModelViewSet): """Application Viewset""" - # pylint: disable=no-member queryset = Application.objects.all().prefetch_related("provider") serializer_class = ApplicationSerializer search_fields = [ @@ -131,7 +130,7 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet): return queryset def _get_allowed_applications( - self, pagined_apps: Iterator[Application], user: Optional[User] = None + self, pagined_apps: Iterator[Application], user: User | None = None ) -> list[Application]: applications = [] request = self.request._request @@ -169,7 +168,7 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet): try: for_user = User.objects.filter(pk=request.query_params.get("for_user")).first() except ValueError: - raise ValidationError({"for_user": "for_user must be numerical"}) + raise ValidationError({"for_user": "for_user must be numerical"}) from None if not for_user: raise ValidationError({"for_user": "User not found"}) engine = PolicyEngine(application, for_user, request) diff --git a/authentik/core/api/authenticated_sessions.py b/authentik/core/api/authenticated_sessions.py index c9d29aabed..b8094238bf 100644 --- a/authentik/core/api/authenticated_sessions.py +++ b/authentik/core/api/authenticated_sessions.py @@ -1,6 +1,6 @@ """AuthenticatedSessions API Viewset""" -from typing import Optional, TypedDict +from typing import TypedDict from django_filters.rest_framework import DjangoFilterBackend from guardian.utils import get_anonymous_user @@ -72,11 +72,11 @@ class AuthenticatedSessionSerializer(ModelSerializer): """Get parsed user agent""" return user_agent_parser.Parse(instance.last_user_agent) - def get_geo_ip(self, instance: AuthenticatedSession) -> Optional[GeoIPDict]: # pragma: no cover + def get_geo_ip(self, instance: AuthenticatedSession) -> GeoIPDict | None: # pragma: no cover """Get GeoIP Data""" return GEOIP_CONTEXT_PROCESSOR.city_dict(instance.last_ip) - def get_asn(self, instance: AuthenticatedSession) -> Optional[ASNDict]: # pragma: no cover + def get_asn(self, instance: AuthenticatedSession) -> ASNDict | None: # pragma: no cover """Get ASN Data""" return ASN_CONTEXT_PROCESSOR.asn_dict(instance.last_ip) diff --git a/authentik/core/api/groups.py b/authentik/core/api/groups.py index 3f3cdf0a61..680726a440 100644 --- a/authentik/core/api/groups.py +++ b/authentik/core/api/groups.py @@ -1,7 +1,6 @@ """Groups API Viewset""" from json import loads -from typing import Optional from django.http import Http404 from django_filters.filters import CharFilter, ModelMultipleChoiceFilter @@ -59,7 +58,7 @@ class GroupSerializer(ModelSerializer): num_pk = IntegerField(read_only=True) - def validate_parent(self, parent: Optional[Group]): + def validate_parent(self, parent: Group | None): """Validate group parent (if set), ensuring the parent isn't itself""" if not self.instance or not parent: return parent @@ -114,7 +113,7 @@ class GroupFilter(FilterSet): try: value = loads(value) except ValueError: - raise ValidationError(detail="filter: failed to parse JSON") + raise ValidationError(detail="filter: failed to parse JSON") from None if not isinstance(value, dict): raise ValidationError(detail="filter: value must be key:value mapping") qs = {} @@ -140,7 +139,6 @@ class UserAccountSerializer(PassiveSerializer): class GroupViewSet(UsedByMixin, ModelViewSet): """Group Viewset""" - # pylint: disable=no-member queryset = Group.objects.all().select_related("parent").prefetch_related("users") serializer_class = GroupSerializer search_fields = ["name", "is_superuser"] diff --git a/authentik/core/api/propertymappings.py b/authentik/core/api/propertymappings.py index 25c83df843..8cb9590977 100644 --- a/authentik/core/api/propertymappings.py +++ b/authentik/core/api/propertymappings.py @@ -146,7 +146,7 @@ class PropertyMappingViewSet( response_data["result"] = dumps( sanitize_item(result), indent=(4 if format_result else None) ) - except Exception as exc: # pylint: disable=broad-except + except Exception as exc: response_data["result"] = str(exc) response_data["successful"] = False response = PropertyMappingTestResultSerializer(response_data) diff --git a/authentik/core/api/sources.py b/authentik/core/api/sources.py index 395464c20b..04341d8804 100644 --- a/authentik/core/api/sources.py +++ b/authentik/core/api/sources.py @@ -1,6 +1,6 @@ """Source API Views""" -from typing import Iterable +from collections.abc import Iterable from django_filters.rest_framework import DjangoFilterBackend from drf_spectacular.utils import OpenApiResponse, extend_schema diff --git a/authentik/core/api/transactional_applications.py b/authentik/core/api/transactional_applications.py index e320fae07e..1e47096ca6 100644 --- a/authentik/core/api/transactional_applications.py +++ b/authentik/core/api/transactional_applications.py @@ -65,7 +65,7 @@ class TransactionApplicationSerializer(PassiveSerializer): raise ValidationError("Invalid provider model") self._provider_model = model except LookupError: - raise ValidationError("Invalid provider model") + raise ValidationError("Invalid provider model") from None return fq_model_name def validate(self, attrs: dict) -> dict: @@ -106,7 +106,7 @@ class TransactionApplicationSerializer(PassiveSerializer): { exc.entry_id: exc.validation_error.detail, } - ) + ) from None return blueprint diff --git a/authentik/core/api/used_by.py b/authentik/core/api/used_by.py index 2aa2b07968..73c3a6a790 100644 --- a/authentik/core/api/used_by.py +++ b/authentik/core/api/used_by.py @@ -54,7 +54,6 @@ class UsedByMixin: responses={200: UsedBySerializer(many=True)}, ) @action(detail=True, pagination_class=None, filter_backends=[]) - # pylint: disable=too-many-locals def used_by(self, request: Request, *args, **kwargs) -> Response: """Get a list of all objects that use this object""" model: Model = self.get_object() diff --git a/authentik/core/api/users.py b/authentik/core/api/users.py index 6449d6760a..79bb389bf4 100644 --- a/authentik/core/api/users.py +++ b/authentik/core/api/users.py @@ -2,7 +2,7 @@ from datetime import timedelta from json import loads -from typing import Any, Optional +from typing import Any from django.contrib.auth import update_session_auth_hash from django.contrib.sessions.backends.cache import KEY_PREFIX @@ -142,7 +142,7 @@ class UserSerializer(ModelSerializer): self._set_password(instance, password) return instance - def _set_password(self, instance: User, password: Optional[str]): + def _set_password(self, instance: User, password: str | None): """Set password of user if we're in a blueprint context, and if it's an empty string then use an unusable password""" if SERIALIZER_CONTEXT_BLUEPRINT in self.context and password: @@ -358,7 +358,7 @@ class UsersFilter(FilterSet): try: value = loads(value) except ValueError: - raise ValidationError(detail="filter: failed to parse JSON") + raise ValidationError(detail="filter: failed to parse JSON") from None if not isinstance(value, dict): raise ValidationError(detail="filter: value must be key:value mapping") qs = {} @@ -416,7 +416,9 @@ class UserViewSet(UsedByMixin, ModelViewSet): }, ) except FlowNonApplicableException: - raise ValidationError({"non_field_errors": "Recovery flow not applicable to user"}) + raise ValidationError( + {"non_field_errors": "Recovery flow not applicable to user"} + ) from None token, __ = FlowToken.objects.update_or_create( identifier=f"{user.uid}-password-reset", defaults={ diff --git a/authentik/core/auth.py b/authentik/core/auth.py index c73a3802a5..7d2522d48e 100644 --- a/authentik/core/auth.py +++ b/authentik/core/auth.py @@ -1,6 +1,6 @@ """Authenticate with tokens""" -from typing import Any, Optional +from typing import Any from django.contrib.auth.backends import ModelBackend from django.http.request import HttpRequest @@ -16,15 +16,15 @@ class InbuiltBackend(ModelBackend): """Inbuilt backend""" def authenticate( - self, request: HttpRequest, username: Optional[str], password: Optional[str], **kwargs: Any - ) -> Optional[User]: + self, request: HttpRequest, username: str | None, password: str | None, **kwargs: Any + ) -> User | None: user = super().authenticate(request, username=username, password=password, **kwargs) if not user: return None self.set_method("password", request) return user - def set_method(self, method: str, request: Optional[HttpRequest], **kwargs): + def set_method(self, method: str, request: HttpRequest | None, **kwargs): """Set method data on current flow, if possbiel""" if not request: return @@ -40,18 +40,18 @@ class TokenBackend(InbuiltBackend): """Authenticate with token""" def authenticate( - self, request: HttpRequest, username: Optional[str], password: Optional[str], **kwargs: Any - ) -> Optional[User]: + self, request: HttpRequest, username: str | None, password: str | None, **kwargs: Any + ) -> User | None: try: - # pylint: disable=no-member + user = User._default_manager.get_by_natural_key(username) - # pylint: disable=no-member + except User.DoesNotExist: # Run the default password hasher once to reduce the timing # difference between an existing and a nonexistent user (#20760). User().set_password(password) return None - # pylint: disable=no-member + tokens = Token.filter_not_expired( user=user, key=password, intent=TokenIntents.INTENT_APP_PASSWORD ) diff --git a/authentik/core/channels.py b/authentik/core/channels.py index 3678be005b..4b0de49aa7 100644 --- a/authentik/core/channels.py +++ b/authentik/core/channels.py @@ -38,6 +38,6 @@ class TokenOutpostMiddleware: raise DenyConnection() except AuthenticationFailed as exc: LOGGER.warning("Failed to authenticate", exc=exc) - raise DenyConnection() + raise DenyConnection() from None scope["user"] = user diff --git a/authentik/core/expression/evaluator.py b/authentik/core/expression/evaluator.py index 77aa890c1b..a8d365a1c0 100644 --- a/authentik/core/expression/evaluator.py +++ b/authentik/core/expression/evaluator.py @@ -1,6 +1,6 @@ """Property Mapping Evaluator""" -from typing import Any, Optional +from typing import Any from django.db.models import Model from django.http import HttpRequest @@ -27,9 +27,9 @@ class PropertyMappingEvaluator(BaseEvaluator): def __init__( self, model: Model, - user: Optional[User] = None, - request: Optional[HttpRequest] = None, - dry_run: Optional[bool] = False, + user: User | None = None, + request: HttpRequest | None = None, + dry_run: bool | None = False, **kwargs, ): if hasattr(model, "name"): diff --git a/authentik/core/management/commands/shell.py b/authentik/core/management/commands/shell.py index a04545f5cd..6731069ae3 100644 --- a/authentik/core/management/commands/shell.py +++ b/authentik/core/management/commands/shell.py @@ -16,13 +16,8 @@ from authentik.events.middleware import should_log_model from authentik.events.models import Event, EventAction from authentik.events.utils import model_to_dict -BANNER_TEXT = """### authentik shell ({authentik}) -### Node {node} | Arch {arch} | Python {python} """.format( - node=platform.node(), - python=platform.python_version(), - arch=platform.machine(), - authentik=get_full_version(), -) +BANNER_TEXT = f"""### authentik shell ({get_full_version()}) +### Node {platform.node()} | Arch {platform.machine()} | Python {platform.python_version()} """ class Command(BaseCommand): @@ -86,7 +81,7 @@ class Command(BaseCommand): # If Python code has been passed, execute it and exit. if options["command"]: - # pylint: disable=exec-used + exec(options["command"], namespace) # nosec # noqa return @@ -99,7 +94,7 @@ class Command(BaseCommand): else: try: hook() - except Exception: # pylint: disable=broad-except + except Exception: # Match the behavior of the cpython shell where an error in # sys.__interactivehook__ prints a warning and the exception # and continues. diff --git a/authentik/core/middleware.py b/authentik/core/middleware.py index 4b710ef12a..f59b9aa6b7 100644 --- a/authentik/core/middleware.py +++ b/authentik/core/middleware.py @@ -1,7 +1,7 @@ """authentik admin Middleware to impersonate users""" +from collections.abc import Callable from contextvars import ContextVar -from typing import Callable, Optional from uuid import uuid4 from django.http import HttpRequest, HttpResponse @@ -15,9 +15,9 @@ RESPONSE_HEADER_ID = "X-authentik-id" KEY_AUTH_VIA = "auth_via" KEY_USER = "user" -CTX_REQUEST_ID = ContextVar[Optional[str]](STRUCTLOG_KEY_PREFIX + "request_id", default=None) -CTX_HOST = ContextVar[Optional[str]](STRUCTLOG_KEY_PREFIX + "host", default=None) -CTX_AUTH_VIA = ContextVar[Optional[str]](STRUCTLOG_KEY_PREFIX + KEY_AUTH_VIA, default=None) +CTX_REQUEST_ID = ContextVar[str | None](STRUCTLOG_KEY_PREFIX + "request_id", default=None) +CTX_HOST = ContextVar[str | None](STRUCTLOG_KEY_PREFIX + "host", default=None) +CTX_AUTH_VIA = ContextVar[str | None](STRUCTLOG_KEY_PREFIX + KEY_AUTH_VIA, default=None) class ImpersonateMiddleware: @@ -55,7 +55,7 @@ class RequestIDMiddleware: def __call__(self, request: HttpRequest) -> HttpResponse: if not hasattr(request, "request_id"): request_id = uuid4().hex - setattr(request, "request_id", request_id) + request.request_id = request_id CTX_REQUEST_ID.set(request_id) CTX_HOST.set(request.get_host()) set_tag("authentik.request_id", request_id) @@ -67,7 +67,7 @@ class RequestIDMiddleware: response = self.get_response(request) response[RESPONSE_HEADER_ID] = request.request_id - setattr(response, "ak_context", {}) + response.ak_context = {} response.ak_context["request_id"] = CTX_REQUEST_ID.get() response.ak_context["host"] = CTX_HOST.get() response.ak_context[KEY_AUTH_VIA] = CTX_AUTH_VIA.get() diff --git a/authentik/core/models.py b/authentik/core/models.py index d35015d0d3..0c3dd3fbf2 100644 --- a/authentik/core/models.py +++ b/authentik/core/models.py @@ -222,7 +222,7 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser): there are at most 3 queries done""" return Group.children_recursive(self.ak_groups.all()) - def group_attributes(self, request: Optional[HttpRequest] = None) -> dict[str, Any]: + def group_attributes(self, request: HttpRequest | None = None) -> dict[str, Any]: """Get a dictionary containing the attributes from all groups the user belongs to, including the users attributes""" final_attributes = {} @@ -278,11 +278,11 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser): """Generate a globally unique UID, based on the user ID and the hashed secret key""" return sha256(f"{self.id}-{get_install_id()}".encode("ascii")).hexdigest() - def locale(self, request: Optional[HttpRequest] = None) -> str: + def locale(self, request: HttpRequest | None = None) -> str: """Get the locale the user has configured""" try: return self.attributes.get("settings", {}).get("locale", "") - # pylint: disable=broad-except + except Exception as exc: LOGGER.warning("Failed to get default locale", exc=exc) if request: @@ -358,7 +358,7 @@ class Provider(SerializerModel): objects = InheritanceManager() @property - def launch_url(self) -> Optional[str]: + def launch_url(self) -> str | None: """URL to this provider and initiate authorization for the user. Can return None for providers that are not URL-based""" return None @@ -435,7 +435,7 @@ class Application(SerializerModel, PolicyBindingModel): return ApplicationSerializer @property - def get_meta_icon(self) -> Optional[str]: + def get_meta_icon(self) -> str | None: """Get the URL to the App Icon image. If the name is /static or starts with http it is returned as-is""" if not self.meta_icon: @@ -444,7 +444,7 @@ class Application(SerializerModel, PolicyBindingModel): return self.meta_icon.name return self.meta_icon.url - def get_launch_url(self, user: Optional["User"] = None) -> Optional[str]: + def get_launch_url(self, user: Optional["User"] = None) -> str | None: """Get launch URL if set, otherwise attempt to get launch URL based on provider.""" url = None if self.meta_launch_url: @@ -457,13 +457,13 @@ class Application(SerializerModel, PolicyBindingModel): user = user._wrapped try: return url % user.__dict__ - # pylint: disable=broad-except + except Exception as exc: LOGGER.warning("Failed to format launch url", exc=exc) return url return url - def get_provider(self) -> Optional[Provider]: + def get_provider(self) -> Provider | None: """Get casted provider instance""" if not self.provider: return None @@ -551,7 +551,7 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel): objects = InheritanceManager() @property - def icon_url(self) -> Optional[str]: + def icon_url(self) -> str | None: """Get the URL to the Icon. If the name is /static or starts with http it is returned as-is""" if not self.icon: @@ -566,7 +566,7 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel): return self.user_path_template % { "slug": self.slug, } - # pylint: disable=broad-except + except Exception as exc: LOGGER.warning("Failed to template user path", exc=exc, source=self) return User.default_path() @@ -576,12 +576,12 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel): """Return component used to edit this object""" raise NotImplementedError - def ui_login_button(self, request: HttpRequest) -> Optional[UILoginButton]: + def ui_login_button(self, request: HttpRequest) -> UILoginButton | None: """If source uses a http-based flow, return UI Information about the login button. If source doesn't use http-based flow, return None.""" return None - def ui_user_settings(self) -> Optional[UserSettingSerializer]: + def ui_user_settings(self) -> UserSettingSerializer | None: """Entrypoint to integrate with User settings. Can either return None if no user settings are available, or UserSettingSerializer.""" return None @@ -627,6 +627,9 @@ class ExpiringModel(models.Model): expires = models.DateTimeField(default=default_token_duration) expiring = models.BooleanField(default=True) + class Meta: + abstract = True + def expire_action(self, *args, **kwargs): """Handler which is called when this object is expired. By default the object is deleted. This is less efficient compared @@ -649,9 +652,6 @@ class ExpiringModel(models.Model): return False return now() > self.expires - class Meta: - abstract = True - class TokenIntents(models.TextChoices): """Intents a Token can be created for.""" @@ -681,6 +681,21 @@ class Token(SerializerModel, ManagedModel, ExpiringModel): user = models.ForeignKey("User", on_delete=models.CASCADE, related_name="+") description = models.TextField(default="", blank=True) + class Meta: + verbose_name = _("Token") + verbose_name_plural = _("Tokens") + indexes = [ + models.Index(fields=["identifier"]), + models.Index(fields=["key"]), + ] + permissions = [("view_token_key", _("View token's key"))] + + def __str__(self): + description = f"{self.identifier}" + if self.expiring: + description += f" (expires={self.expires})" + return description + @property def serializer(self) -> type[Serializer]: from authentik.core.api.tokens import TokenSerializer @@ -708,21 +723,6 @@ class Token(SerializerModel, ManagedModel, ExpiringModel): message=f"Token {self.identifier}'s secret was rotated.", ).save() - def __str__(self): - description = f"{self.identifier}" - if self.expiring: - description += f" (expires={self.expires})" - return description - - class Meta: - verbose_name = _("Token") - verbose_name_plural = _("Tokens") - indexes = [ - models.Index(fields=["identifier"]), - models.Index(fields=["key"]), - ] - permissions = [("view_token_key", _("View token's key"))] - class PropertyMapping(SerializerModel, ManagedModel): """User-defined key -> x mapping which can be used by providers to expose extra data.""" @@ -743,7 +743,7 @@ class PropertyMapping(SerializerModel, ManagedModel): """Get serializer for this model""" raise NotImplementedError - def evaluate(self, user: Optional[User], request: Optional[HttpRequest], **kwargs) -> Any: + def evaluate(self, user: User | None, request: HttpRequest | None, **kwargs) -> Any: """Evaluate `self.expression` using `**kwargs` as Context.""" from authentik.core.expression.evaluator import PropertyMappingEvaluator @@ -779,6 +779,13 @@ class AuthenticatedSession(ExpiringModel): last_user_agent = models.TextField(blank=True) last_used = models.DateTimeField(auto_now=True) + class Meta: + verbose_name = _("Authenticated Session") + verbose_name_plural = _("Authenticated Sessions") + + def __str__(self) -> str: + return f"Authenticated Session {self.session_key[:10]}" + @staticmethod def from_request(request: HttpRequest, user: User) -> Optional["AuthenticatedSession"]: """Create a new session from a http request""" @@ -793,7 +800,3 @@ class AuthenticatedSession(ExpiringModel): last_user_agent=request.META.get("HTTP_USER_AGENT", ""), expires=request.session.get_expiry_date(), ) - - class Meta: - verbose_name = _("Authenticated Session") - verbose_name_plural = _("Authenticated Sessions") diff --git a/authentik/core/sources/flow_manager.py b/authentik/core/sources/flow_manager.py index b90add829b..bbcc9b3779 100644 --- a/authentik/core/sources/flow_manager.py +++ b/authentik/core/sources/flow_manager.py @@ -1,7 +1,7 @@ """Source decision helper""" from enum import Enum -from typing import Any, Optional +from typing import Any from django.contrib import messages from django.db import IntegrityError @@ -90,15 +90,14 @@ class SourceFlowManager: self._logger = get_logger().bind(source=source, identifier=identifier) self.policy_context = {} - # pylint: disable=too-many-return-statements - def get_action(self, **kwargs) -> tuple[Action, Optional[UserSourceConnection]]: + def get_action(self, **kwargs) -> tuple[Action, UserSourceConnection | None]: # noqa: PLR0911 """decide which action should be taken""" new_connection = self.connection_type(source=self.source, identifier=self.identifier) # When request is authenticated, always link if self.request.user.is_authenticated: new_connection.user = self.request.user new_connection = self.update_connection(new_connection, **kwargs) - # pylint: disable=no-member + new_connection.save() return Action.LINK, new_connection @@ -188,8 +187,10 @@ class SourceFlowManager: # Default case, assume deny error = Exception( _( - "Request to authenticate with %(source)s has been denied. Please authenticate " - "with the source you've previously signed up with." % {"source": self.source.name} + "Request to authenticate with {source} has been denied. Please authenticate " + "with the source you've previously signed up with.".format_map( + {"source": self.source.name} + ) ), ) return self.error_handler(error) @@ -217,7 +218,7 @@ class SourceFlowManager: self, flow: Flow, connection: UserSourceConnection, - stages: Optional[list[StageView]] = None, + stages: list[StageView] | None = None, **kwargs, ) -> HttpResponse: """Prepare Authentication Plan, redirect user FlowExecutor""" @@ -270,7 +271,9 @@ class SourceFlowManager: in_memory_stage( MessageStage, message=_( - "Successfully authenticated with %(source)s!" % {"source": self.source.name} + "Successfully authenticated with {source}!".format_map( + {"source": self.source.name} + ) ), ) ], @@ -294,7 +297,7 @@ class SourceFlowManager: ).from_http(self.request) messages.success( self.request, - _("Successfully linked %(source)s!" % {"source": self.source.name}), + _("Successfully linked {source}!".format_map({"source": self.source.name})), ) return redirect( reverse( @@ -322,7 +325,9 @@ class SourceFlowManager: in_memory_stage( MessageStage, message=_( - "Successfully authenticated with %(source)s!" % {"source": self.source.name} + "Successfully authenticated with {source}!".format_map( + {"source": self.source.name} + ) ), ) ], diff --git a/authentik/core/tasks.py b/authentik/core/tasks.py index 61ac773d6d..406e80c60d 100644 --- a/authentik/core/tasks.py +++ b/authentik/core/tasks.py @@ -37,20 +37,20 @@ def clean_expired_models(self: SystemTask): messages.append(f"Expired {amount} {cls._meta.verbose_name_plural}") # Special case amount = 0 - # pylint: disable=no-member + for session in AuthenticatedSession.objects.all(): cache_key = f"{KEY_PREFIX}{session.session_key}" value = None try: value = cache.get(cache_key) - # pylint: disable=broad-except + except Exception as exc: LOGGER.debug("Failed to get session from cache", exc=exc) if not value: session.delete() amount += 1 LOGGER.debug("Expired sessions", model=AuthenticatedSession, amount=amount) - # pylint: disable=no-member + messages.append(f"Expired {amount} {AuthenticatedSession._meta.verbose_name_plural}") self.set_status(TaskStatus.SUCCESSFUL, *messages) diff --git a/authentik/core/tests/test_models.py b/authentik/core/tests/test_models.py index 1c187daea5..4fd246bbb8 100644 --- a/authentik/core/tests/test_models.py +++ b/authentik/core/tests/test_models.py @@ -1,7 +1,7 @@ """authentik core models tests""" +from collections.abc import Callable from time import sleep -from typing import Callable from django.test import RequestFactory, TestCase from django.utils.timezone import now diff --git a/authentik/core/tests/test_source_flow_manager.py b/authentik/core/tests/test_source_flow_manager.py index 264cdc9d70..f4fa71a4be 100644 --- a/authentik/core/tests/test_source_flow_manager.py +++ b/authentik/core/tests/test_source_flow_manager.py @@ -173,5 +173,5 @@ class TestSourceFlowManager(TestCase): self.assertEqual(action, Action.ENROLL) response = flow_manager.get_flow() self.assertIsInstance(response, AccessDeniedResponse) - # pylint: disable=no-member + self.assertEqual(response.error_message, "foo") diff --git a/authentik/core/tests/utils.py b/authentik/core/tests/utils.py index c46c6ad792..c54492a6dc 100644 --- a/authentik/core/tests/utils.py +++ b/authentik/core/tests/utils.py @@ -1,7 +1,5 @@ """Test Utils""" -from typing import Optional - from django.utils.text import slugify from authentik.brands.models import Brand @@ -22,7 +20,7 @@ def create_test_flow( ) -def create_test_user(name: Optional[str] = None, **kwargs) -> User: +def create_test_user(name: str | None = None, **kwargs) -> User: """Generate a test user""" uid = generate_id(20) if not name else name kwargs.setdefault("email", f"{uid}@goauthentik.io") @@ -36,7 +34,7 @@ def create_test_user(name: Optional[str] = None, **kwargs) -> User: return user -def create_test_admin_user(name: Optional[str] = None, **kwargs) -> User: +def create_test_admin_user(name: str | None = None, **kwargs) -> User: """Generate a test-admin user""" user = create_test_user(name, **kwargs) group = Group.objects.create(name=user.name or name, is_superuser=True) diff --git a/authentik/core/types.py b/authentik/core/types.py index af8d32eb56..32468674c0 100644 --- a/authentik/core/types.py +++ b/authentik/core/types.py @@ -1,7 +1,6 @@ """authentik core dataclasses""" from dataclasses import dataclass -from typing import Optional from rest_framework.fields import CharField @@ -20,7 +19,7 @@ class UILoginButton: challenge: Challenge # Icon URL, used as-is - icon_url: Optional[str] = None + icon_url: str | None = None class UserSettingSerializer(PassiveSerializer): diff --git a/authentik/core/views/apps.py b/authentik/core/views/apps.py index 351507256b..9843cd1257 100644 --- a/authentik/core/views/apps.py +++ b/authentik/core/views/apps.py @@ -57,7 +57,7 @@ class RedirectToAppLaunch(View): }, ) except FlowNonApplicableException: - raise Http404 + raise Http404 from None plan.insert_stage(in_memory_stage(RedirectToAppStage)) request.session[SESSION_KEY_PLAN] = plan return redirect_with_qs("authentik_core:if-flow", request.GET, flow_slug=flow.slug) diff --git a/authentik/core/views/error.py b/authentik/core/views/error.py index 51868d9971..68c40ed4a5 100644 --- a/authentik/core/views/error.py +++ b/authentik/core/views/error.py @@ -61,7 +61,6 @@ class ServerErrorView(TemplateView): response_class = ServerErrorTemplateResponse template_name = "if/error.html" - # pylint: disable=useless-super-delegation def dispatch(self, *args, **kwargs): # pragma: no cover """Little wrapper so django accepts this function""" return super().dispatch(*args, **kwargs) diff --git a/authentik/crypto/api.py b/authentik/crypto/api.py index c68235c849..e028c12318 100644 --- a/authentik/crypto/api.py +++ b/authentik/crypto/api.py @@ -1,7 +1,6 @@ """Crypto API Views""" from datetime import datetime -from typing import Optional from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.serialization import load_pem_private_key @@ -56,25 +55,25 @@ class CertificateKeyPairSerializer(ModelSerializer): return True return str(request.query_params.get("include_details", "true")).lower() == "true" - def get_fingerprint_sha256(self, instance: CertificateKeyPair) -> Optional[str]: + def get_fingerprint_sha256(self, instance: CertificateKeyPair) -> str | None: "Get certificate Hash (SHA256)" if not self._should_include_details: return None return instance.fingerprint_sha256 - def get_fingerprint_sha1(self, instance: CertificateKeyPair) -> Optional[str]: + def get_fingerprint_sha1(self, instance: CertificateKeyPair) -> str | None: "Get certificate Hash (SHA1)" if not self._should_include_details: return None return instance.fingerprint_sha1 - def get_cert_expiry(self, instance: CertificateKeyPair) -> Optional[datetime]: + def get_cert_expiry(self, instance: CertificateKeyPair) -> datetime | None: "Get certificate expiry" if not self._should_include_details: return None return DateTimeField().to_representation(instance.certificate.not_valid_after) - def get_cert_subject(self, instance: CertificateKeyPair) -> Optional[str]: + def get_cert_subject(self, instance: CertificateKeyPair) -> str | None: """Get certificate subject as full rfc4514""" if not self._should_include_details: return None @@ -84,7 +83,7 @@ class CertificateKeyPairSerializer(ModelSerializer): """Show if this keypair has a private key configured or not""" return instance.key_data != "" and instance.key_data is not None - def get_private_key_type(self, instance: CertificateKeyPair) -> Optional[str]: + def get_private_key_type(self, instance: CertificateKeyPair) -> str | None: """Get the private key's type, if set""" if not self._should_include_details: return None @@ -121,7 +120,7 @@ class CertificateKeyPairSerializer(ModelSerializer): str(load_pem_x509_certificate(value.encode("utf-8"), default_backend())) except ValueError as exc: LOGGER.warning("Failed to load certificate", exc=exc) - raise ValidationError("Unable to load certificate.") + raise ValidationError("Unable to load certificate.") from None return value def validate_key_data(self, value: str) -> str: @@ -140,7 +139,7 @@ class CertificateKeyPairSerializer(ModelSerializer): ) except (ValueError, TypeError) as exc: LOGGER.warning("Failed to load private key", exc=exc) - raise ValidationError("Unable to load private key (possibly encrypted?).") + raise ValidationError("Unable to load private key (possibly encrypted?).") from None return value class Meta: diff --git a/authentik/crypto/apps.py b/authentik/crypto/apps.py index 2336930cca..cdb01b3a1b 100644 --- a/authentik/crypto/apps.py +++ b/authentik/crypto/apps.py @@ -1,7 +1,6 @@ """authentik crypto app config""" -from datetime import datetime, timezone -from typing import Optional +from datetime import UTC, datetime from authentik.blueprints.apps import ManagedAppConfig from authentik.lib.generators import generate_id @@ -41,10 +40,10 @@ class AuthentikCryptoConfig(ManagedAppConfig): """Ensure managed JWT certificate""" from authentik.crypto.models import CertificateKeyPair - cert: Optional[CertificateKeyPair] = CertificateKeyPair.objects.filter( + cert: CertificateKeyPair | None = CertificateKeyPair.objects.filter( managed=MANAGED_KEY ).first() - now = datetime.now(tz=timezone.utc) + now = datetime.now(tz=UTC) if not cert or ( now < cert.certificate.not_valid_after_utc or now > cert.certificate.not_valid_after_utc ): diff --git a/authentik/crypto/builder.py b/authentik/crypto/builder.py index 2ad0563916..70ad419a90 100644 --- a/authentik/crypto/builder.py +++ b/authentik/crypto/builder.py @@ -2,7 +2,6 @@ import datetime import uuid -from typing import Optional from cryptography import x509 from cryptography.hazmat.backends import default_backend @@ -52,7 +51,7 @@ class CertificateBuilder: def build( self, validity_days: int = 365, - subject_alt_names: Optional[list[str]] = None, + subject_alt_names: list[str] | None = None, ): """Build self-signed certificate""" one_day = datetime.timedelta(1, 0, 0) diff --git a/authentik/crypto/management/commands/import_certificate.py b/authentik/crypto/management/commands/import_certificate.py index b8b04e6a10..708e908207 100644 --- a/authentik/crypto/management/commands/import_certificate.py +++ b/authentik/crypto/management/commands/import_certificate.py @@ -24,13 +24,13 @@ class Command(TenantCommand): if not keypair: keypair = CertificateKeyPair(name=options["name"]) dirty = True - with open(options["certificate"], mode="r", encoding="utf-8") as _cert: + with open(options["certificate"], encoding="utf-8") as _cert: cert_data = _cert.read() if keypair.certificate_data != cert_data: dirty = True keypair.certificate_data = cert_data if options["private_key"]: - with open(options["private_key"], mode="r", encoding="utf-8") as _key: + with open(options["private_key"], encoding="utf-8") as _key: key_data = _key.read() if keypair.key_data != key_data: dirty = True diff --git a/authentik/crypto/models.py b/authentik/crypto/models.py index 3c8120ea9c..444b31ce62 100644 --- a/authentik/crypto/models.py +++ b/authentik/crypto/models.py @@ -2,7 +2,6 @@ from binascii import hexlify from hashlib import md5 -from typing import Optional from uuid import uuid4 from cryptography.hazmat.backends import default_backend @@ -37,9 +36,9 @@ class CertificateKeyPair(SerializerModel, ManagedModel, CreatedUpdatedModel): default="", ) - _cert: Optional[Certificate] = None - _private_key: Optional[PrivateKeyTypes] = None - _public_key: Optional[PublicKeyTypes] = None + _cert: Certificate | None = None + _private_key: PrivateKeyTypes | None = None + _public_key: PublicKeyTypes | None = None @property def serializer(self) -> Serializer: @@ -57,7 +56,7 @@ class CertificateKeyPair(SerializerModel, ManagedModel, CreatedUpdatedModel): return self._cert @property - def public_key(self) -> Optional[PublicKeyTypes]: + def public_key(self) -> PublicKeyTypes | None: """Get public key of the private key""" if not self._public_key: self._public_key = self.private_key.public_key() @@ -66,7 +65,7 @@ class CertificateKeyPair(SerializerModel, ManagedModel, CreatedUpdatedModel): @property def private_key( self, - ) -> Optional[PrivateKeyTypes]: + ) -> PrivateKeyTypes | None: """Get python cryptography PrivateKey instance""" if not self._private_key and self.key_data != "": try: diff --git a/authentik/crypto/tasks.py b/authentik/crypto/tasks.py index cef8f61c61..43704e7682 100644 --- a/authentik/crypto/tasks.py +++ b/authentik/crypto/tasks.py @@ -58,7 +58,7 @@ def certificate_discovery(self: SystemTask): else: cert_name = path.name.replace(path.suffix, "") try: - with open(path, "r", encoding="utf-8") as _file: + with open(path, encoding="utf-8") as _file: body = _file.read() if "PRIVATE KEY" in body: private_keys[cert_name] = ensure_private_key_valid(body) diff --git a/authentik/crypto/tests.py b/authentik/crypto/tests.py index 056425e06d..52fd8cba0b 100644 --- a/authentik/crypto/tests.py +++ b/authentik/crypto/tests.py @@ -267,7 +267,7 @@ class TestCrypto(APITestCase): with open(f"{temp_dir}/foo.bar/privkey.pem", "w+", encoding="utf-8") as _key: _key.write(builder.private_key) with CONFIG.patch("cert_discovery_dir", temp_dir): - certificate_discovery() # pylint: disable=no-value-for-parameter + certificate_discovery() keypair: CertificateKeyPair = CertificateKeyPair.objects.filter( managed=MANAGED_DISCOVERED % "foo" ).first() diff --git a/authentik/enterprise/audit/middleware.py b/authentik/enterprise/audit/middleware.py index ad649d1a04..9a621721eb 100644 --- a/authentik/enterprise/audit/middleware.py +++ b/authentik/enterprise/audit/middleware.py @@ -62,7 +62,7 @@ class EnterpriseAuditMiddleware(AuditMiddleware): field_value = value.name # If current field value is an expression, we are not evaluating it - if isinstance(field_value, (BaseExpression, Combinable)): + if isinstance(field_value, BaseExpression | Combinable): continue field_value = field.to_python(field_value) data[field.name] = deepcopy(field_value) @@ -83,12 +83,11 @@ class EnterpriseAuditMiddleware(AuditMiddleware): if hasattr(instance, "_previous_state"): return before = len(connection.queries) - setattr(instance, "_previous_state", self.serialize_simple(instance)) + instance._previous_state = self.serialize_simple(instance) after = len(connection.queries) if after > before: raise AssertionError("More queries generated by serialize_simple") - # pylint: disable=too-many-arguments def post_save_handler( self, user: User, diff --git a/authentik/enterprise/license.py b/authentik/enterprise/license.py index 7baa1b378a..7978f62a00 100644 --- a/authentik/enterprise/license.py +++ b/authentik/enterprise/license.py @@ -27,7 +27,7 @@ CACHE_KEY_ENTERPRISE_LICENSE = "goauthentik.io/enterprise/license" CACHE_EXPIRY_ENTERPRISE_LICENSE = 3 * 60 * 60 # 2 Hours -@lru_cache() +@lru_cache def get_licensing_key() -> Certificate: """Get Root CA PEM""" with open("authentik/enterprise/public.pem", "rb") as _key: @@ -88,7 +88,7 @@ class LicenseKey: try: headers = get_unverified_header(jwt) except PyJWTError: - raise ValidationError("Unable to verify license") + raise ValidationError("Unable to verify license") from None x5c: list[str] = headers.get("x5c", []) if len(x5c) < 1: raise ValidationError("Unable to verify license") @@ -98,7 +98,7 @@ class LicenseKey: our_cert.verify_directly_issued_by(intermediate) intermediate.verify_directly_issued_by(get_licensing_key()) except (InvalidSignature, TypeError, ValueError, Error): - raise ValidationError("Unable to verify license") + raise ValidationError("Unable to verify license") from None try: body = from_dict( LicenseKey, @@ -110,7 +110,7 @@ class LicenseKey: ), ) except PyJWTError: - raise ValidationError("Unable to verify license") + raise ValidationError("Unable to verify license") from None return body @staticmethod diff --git a/authentik/enterprise/policy.py b/authentik/enterprise/policy.py index 2e2535de0c..904c3f73ee 100644 --- a/authentik/enterprise/policy.py +++ b/authentik/enterprise/policy.py @@ -1,7 +1,5 @@ """Enterprise license policies""" -from typing import Optional - from django.utils.translation import gettext_lazy as _ from authentik.core.models import User, UserTypes @@ -21,7 +19,7 @@ class EnterprisePolicyAccessView(PolicyAccessView): return PolicyResult(False, _("Feature only accessible for internal users.")) return PolicyResult(True) - def user_has_access(self, user: Optional[User] = None) -> PolicyResult: + def user_has_access(self, user: User | None = None) -> PolicyResult: user = user or self.request.user request = PolicyRequest(user) request.http_request = self.request diff --git a/authentik/enterprise/providers/rac/api/endpoints.py b/authentik/enterprise/providers/rac/api/endpoints.py index c681350e8e..0dab4ca5f2 100644 --- a/authentik/enterprise/providers/rac/api/endpoints.py +++ b/authentik/enterprise/providers/rac/api/endpoints.py @@ -1,7 +1,5 @@ """RAC Provider API Views""" -from typing import Optional - from django.core.cache import cache from django.db.models import QuerySet from django.urls import reverse @@ -36,11 +34,11 @@ class EndpointSerializer(EnterpriseRequiredMixin, ModelSerializer): provider_obj = RACProviderSerializer(source="provider", read_only=True) launch_url = SerializerMethodField() - def get_launch_url(self, endpoint: Endpoint) -> Optional[str]: + def get_launch_url(self, endpoint: Endpoint) -> str | None: """Build actual launch URL (the provider itself does not have one, just individual endpoints)""" try: - # pylint: disable=no-member + return reverse( "authentik_providers_rac:start", kwargs={"app": endpoint.provider.application.slug, "endpoint": endpoint.pk}, diff --git a/authentik/enterprise/providers/rac/models.py b/authentik/enterprise/providers/rac/models.py index c5f866bfd6..5c6e835c61 100644 --- a/authentik/enterprise/providers/rac/models.py +++ b/authentik/enterprise/providers/rac/models.py @@ -1,6 +1,6 @@ """RAC Models""" -from typing import Any, Optional +from typing import Any from uuid import uuid4 from deepmerge import always_merger @@ -58,7 +58,7 @@ class RACProvider(Provider): ) @property - def launch_url(self) -> Optional[str]: + def launch_url(self) -> str | None: """URL to this provider and initiate authorization for the user. Can return None for providers that are not URL-based""" return "goauthentik.io://providers/rac/launch" @@ -112,7 +112,7 @@ class RACPropertyMapping(PropertyMapping): static_settings = models.JSONField(default=dict) - def evaluate(self, user: Optional[User], request: Optional[HttpRequest], **kwargs) -> Any: + def evaluate(self, user: User | None, request: HttpRequest | None, **kwargs) -> Any: """Evaluate `self.expression` using `**kwargs` as Context.""" if len(self.static_settings) > 0: return self.static_settings diff --git a/authentik/enterprise/providers/rac/views.py b/authentik/enterprise/providers/rac/views.py index 1028c1cf70..3cdcce2e0a 100644 --- a/authentik/enterprise/providers/rac/views.py +++ b/authentik/enterprise/providers/rac/views.py @@ -47,7 +47,7 @@ class RACStartView(EnterprisePolicyAccessView): }, ) except FlowNonApplicableException: - raise Http404 + raise Http404 from None plan.insert_stage( in_memory_stage( RACFinalStage, @@ -132,16 +132,7 @@ class RACFinalStage(RedirectStage): flow=self.executor.plan.flow_pk, endpoint=self.endpoint.name, ).from_http(self.request) - setattr( - self.executor.current_stage, - "destination", - self.request.build_absolute_uri( - reverse( - "authentik_providers_rac:if-rac", - kwargs={ - "token": str(token.token), - }, - ) - ), + self.executor.current_stage.destination = self.request.build_absolute_uri( + reverse("authentik_providers_rac:if-rac", kwargs={"token": str(token.token)}) ) return super().get_challenge(*args, **kwargs) diff --git a/authentik/events/api/tasks.py b/authentik/events/api/tasks.py index 581b0fcc60..25b6c38273 100644 --- a/authentik/events/api/tasks.py +++ b/authentik/events/api/tasks.py @@ -92,7 +92,7 @@ class SystemTaskViewSet(ReadOnlyModelViewSet): task_func.delay(*task.task_call_args, **task.task_call_kwargs) messages.success( self.request, - _("Successfully started task %(name)s." % {"name": task.name}), + _("Successfully started task {name}.".format_map({"name": task.name})), ) return Response(status=204) except (ImportError, AttributeError) as exc: # pragma: no cover diff --git a/authentik/events/context_processors/asn.py b/authentik/events/context_processors/asn.py index a125b07446..23288d55f7 100644 --- a/authentik/events/context_processors/asn.py +++ b/authentik/events/context_processors/asn.py @@ -46,7 +46,7 @@ class ASNContextProcessor(MMDBContextProcessor): "asn": self.asn_dict(ClientIPMiddleware.get_client_ip(request)), } - def asn(self, ip_address: str) -> Optional[ASN]: + def asn(self, ip_address: str) -> ASN | None: """Wrapper for Reader.asn""" with Hub.current.start_span( op="authentik.events.asn.asn", @@ -71,7 +71,7 @@ class ASNContextProcessor(MMDBContextProcessor): } return asn_dict - def asn_dict(self, ip_address: str) -> Optional[ASNDict]: + def asn_dict(self, ip_address: str) -> ASNDict | None: """Wrapper for self.asn that returns a dict""" asn = self.asn(ip_address) if not asn: diff --git a/authentik/events/context_processors/geoip.py b/authentik/events/context_processors/geoip.py index 0d32ec991c..76de1a3ae4 100644 --- a/authentik/events/context_processors/geoip.py +++ b/authentik/events/context_processors/geoip.py @@ -47,7 +47,7 @@ class GeoIPContextProcessor(MMDBContextProcessor): # Different key `geoip` vs `geo` for legacy reasons return {"geoip": self.city(ClientIPMiddleware.get_client_ip(request))} - def city(self, ip_address: str) -> Optional[City]: + def city(self, ip_address: str) -> City | None: """Wrapper for Reader.city""" with Hub.current.start_span( op="authentik.events.geo.city", @@ -76,7 +76,7 @@ class GeoIPContextProcessor(MMDBContextProcessor): city_dict["city"] = city.city.name return city_dict - def city_dict(self, ip_address: str) -> Optional[GeoIPDict]: + def city_dict(self, ip_address: str) -> GeoIPDict | None: """Wrapper for self.city that returns a dict""" city = self.city(ip_address) if not city: diff --git a/authentik/events/context_processors/mmdb.py b/authentik/events/context_processors/mmdb.py index 4453e78215..4ba762fe84 100644 --- a/authentik/events/context_processors/mmdb.py +++ b/authentik/events/context_processors/mmdb.py @@ -1,7 +1,6 @@ """Common logic for reading MMDB files""" from pathlib import Path -from typing import Optional from geoip2.database import Reader from structlog.stdlib import get_logger @@ -13,7 +12,7 @@ class MMDBContextProcessor(EventContextProcessor): """Common logic for reading MaxMind DB files, including re-loading if the file has changed""" def __init__(self): - self.reader: Optional[Reader] = None + self.reader: Reader | None = None self._last_mtime: float = 0.0 self.logger = get_logger() self.open() diff --git a/authentik/events/middleware.py b/authentik/events/middleware.py index b39ecb7337..2b4705f3a3 100644 --- a/authentik/events/middleware.py +++ b/authentik/events/middleware.py @@ -1,8 +1,9 @@ """Events middleware""" +from collections.abc import Callable from functools import partial from threading import Thread -from typing import Any, Callable, Optional +from typing import Any from django.conf import settings from django.contrib.sessions.models import Session @@ -49,9 +50,9 @@ class EventNewThread(Thread): action: str request: HttpRequest kwargs: dict[str, Any] - user: Optional[User] = None + user: User | None = None - def __init__(self, action: str, request: HttpRequest, user: Optional[User] = None, **kwargs): + def __init__(self, action: str, request: HttpRequest, user: User | None = None, **kwargs): super().__init__() self.action = action self.request = request @@ -144,7 +145,6 @@ class AuditMiddleware: ) thread.run() - # pylint: disable=too-many-arguments def post_save_handler( self, user: User, @@ -152,7 +152,7 @@ class AuditMiddleware: sender, instance: Model, created: bool, - thread_kwargs: Optional[dict] = None, + thread_kwargs: dict | None = None, **_, ): """Signal handler for all object's post_save""" diff --git a/authentik/events/models.py b/authentik/events/models.py index 3bb2ff1458..017e3e8fe2 100644 --- a/authentik/events/models.py +++ b/authentik/events/models.py @@ -7,7 +7,6 @@ from difflib import get_close_matches from functools import lru_cache from inspect import currentframe from smtplib import SMTPException -from typing import Optional from uuid import uuid4 from django.apps import apps @@ -52,6 +51,8 @@ from authentik.stages.email.utils import TemplateEmailMessage from authentik.tenants.models import Tenant LOGGER = get_logger() +DISCORD_FIELD_LIMIT = 25 +NOTIFICATION_SUMMARY_LENGTH = 75 def default_event_duration(): @@ -65,7 +66,7 @@ def default_brand(): return sanitize_dict(model_to_dict(DEFAULT_BRAND)) -@lru_cache() +@lru_cache def django_app_names() -> list[str]: """Get a cached list of all django apps' names (not labels)""" return [x.name for x in apps.app_configs.values()] @@ -198,7 +199,7 @@ class Event(SerializerModel, ExpiringModel): @staticmethod def new( action: str | EventAction, - app: Optional[str] = None, + app: str | None = None, **kwargs, ) -> "Event": """Create new Event instance from arguments. Instance is NOT saved.""" @@ -224,7 +225,7 @@ class Event(SerializerModel, ExpiringModel): self.user = get_user(user) return self - def from_http(self, request: HttpRequest, user: Optional[User] = None) -> "Event": + def from_http(self, request: HttpRequest, user: User | None = None) -> "Event": """Add data from a Django-HttpRequest, allowing the creation of Events independently from requests. `user` arguments optionally overrides user from requests.""" @@ -418,7 +419,7 @@ class NotificationTransport(SerializerModel): if not isinstance(value, str): continue # https://birdie0.github.io/discord-webhooks-guide/other/field_limits.html - if len(fields) >= 25: + if len(fields) >= DISCORD_FIELD_LIMIT: continue fields.append({"title": key[:256], "value": value[:1024]}) body = { @@ -472,7 +473,7 @@ class NotificationTransport(SerializerModel): continue context["key_value"][key] = value else: - context["title"] += notification.body[:75] + context["title"] += notification.body[:NOTIFICATION_SUMMARY_LENGTH] # TODO: improve permission check if notification.user.is_superuser: context["source"] = { @@ -489,7 +490,7 @@ class NotificationTransport(SerializerModel): try: from authentik.stages.email.tasks import send_mail - return send_mail(mail.__dict__) # pylint: disable=no-value-for-parameter + return send_mail(mail.__dict__) except (SMTPException, ConnectionError, OSError) as exc: raise NotificationTransportError(exc) from exc @@ -533,7 +534,11 @@ class Notification(SerializerModel): return NotificationSerializer def __str__(self) -> str: - body_trunc = (self.body[:75] + "..") if len(self.body) > 75 else self.body + body_trunc = ( + (self.body[:NOTIFICATION_SUMMARY_LENGTH] + "..") + if len(self.body) > NOTIFICATION_SUMMARY_LENGTH + else self.body + ) return f"Notification for user {self.user}: {body_trunc}" class Meta: diff --git a/authentik/events/signals.py b/authentik/events/signals.py index 2f29491708..160cb219f3 100644 --- a/authentik/events/signals.py +++ b/authentik/events/signals.py @@ -1,6 +1,6 @@ """authentik events signal listener""" -from typing import Any, Optional +from typing import Any from django.contrib.auth.signals import user_logged_in, user_logged_out from django.db.models.signals import post_save, pre_delete @@ -42,7 +42,7 @@ def on_user_logged_in(sender, request: HttpRequest, user: User, **_): request.session[SESSION_LOGIN_EVENT] = event -def get_login_event(request: HttpRequest) -> Optional[Event]: +def get_login_event(request: HttpRequest) -> Event | None: """Wrapper to get login event that can be mocked in tests""" return request.session.get(SESSION_LOGIN_EVENT, None) @@ -71,7 +71,7 @@ def on_login_failed( sender, credentials: dict[str, str], request: HttpRequest, - stage: Optional[Stage] = None, + stage: Stage | None = None, **kwargs, ): """Failed Login, authentik custom event""" diff --git a/authentik/events/system_tasks.py b/authentik/events/system_tasks.py index 625203f42a..f8de8530eb 100644 --- a/authentik/events/system_tasks.py +++ b/authentik/events/system_tasks.py @@ -2,16 +2,15 @@ from datetime import datetime, timedelta from time import perf_counter -from typing import Any, Optional +from typing import Any from django.utils.timezone import now from django.utils.translation import gettext_lazy as _ from structlog.stdlib import get_logger from tenant_schemas_celery.task import TenantTask -from authentik.events.models import Event, EventAction +from authentik.events.models import Event, EventAction, TaskStatus from authentik.events.models import SystemTask as DBSystemTask -from authentik.events.models import TaskStatus from authentik.events.utils import sanitize_item from authentik.lib.utils.errors import exception_to_string @@ -27,10 +26,10 @@ class SystemTask(TenantTask): _status: TaskStatus _messages: list[str] - _uid: Optional[str] + _uid: str | None # Precise start time from perf_counter - _start_precise: Optional[float] = None - _start: Optional[datetime] = None + _start_precise: float | None = None + _start: datetime | None = None def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -60,14 +59,13 @@ class SystemTask(TenantTask): self._start = now() return super().before_start(task_id, args, kwargs) - def db(self) -> Optional[DBSystemTask]: + def db(self) -> DBSystemTask | None: """Get DB object for latest task""" return DBSystemTask.objects.filter( name=self.__name__, uid=self._uid, ).first() - # pylint: disable=too-many-arguments def after_return(self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo): super().after_return(status, retval, task_id, args, kwargs, einfo=einfo) if not self._status: @@ -97,7 +95,6 @@ class SystemTask(TenantTask): }, ) - # pylint: disable=too-many-arguments def on_failure(self, exc, task_id, args, kwargs, einfo): super().on_failure(exc, task_id, args, kwargs, einfo=einfo) if not self._status: diff --git a/authentik/events/tasks.py b/authentik/events/tasks.py index e196712cb9..db08715f53 100644 --- a/authentik/events/tasks.py +++ b/authentik/events/tasks.py @@ -1,7 +1,5 @@ """Event notification tasks""" -from typing import Optional - from django.db.models.query_utils import Q from guardian.shortcuts import get_anonymous_user from structlog.stdlib import get_logger @@ -38,7 +36,7 @@ def event_trigger_handler(event_uuid: str, trigger_name: str): if not event: LOGGER.warning("event doesn't exist yet or anymore", event_uuid=event_uuid) return - trigger: Optional[NotificationRule] = NotificationRule.objects.filter(name=trigger_name).first() + trigger: NotificationRule | None = NotificationRule.objects.filter(name=trigger_name).first() if not trigger: return diff --git a/authentik/events/tests/test_event.py b/authentik/events/tests/test_event.py index 78771d472d..b40aad7be7 100644 --- a/authentik/events/tests/test_event.py +++ b/authentik/events/tests/test_event.py @@ -105,7 +105,7 @@ class TestEvents(TestCase): # Test brand request = self.factory.get("/") brand = Brand(domain="test-brand") - setattr(request, "brand", brand) + request.brand = brand event = Event.new("unittest").from_http(request) self.assertEqual( event.brand, diff --git a/authentik/events/utils.py b/authentik/events/utils.py index b9071435c4..ab1778f446 100644 --- a/authentik/events/utils.py +++ b/authentik/events/utils.py @@ -7,7 +7,7 @@ from datetime import date, datetime, time, timedelta from enum import Enum from pathlib import Path from types import GeneratorType, NoneType -from typing import Any, Optional +from typing import Any from uuid import UUID from django.contrib.auth.models import AnonymousUser @@ -37,7 +37,7 @@ def cleanse_item(key: str, value: Any) -> Any: """Cleanse a single item""" if isinstance(value, dict): return cleanse_dict(value) - if isinstance(value, (list, tuple, set)): + if isinstance(value, list | tuple | set): for idx, item in enumerate(value): value[idx] = cleanse_item(key, item) return value @@ -74,7 +74,7 @@ def model_to_dict(model: Model) -> dict[str, Any]: } -def get_user(user: User | AnonymousUser, original_user: Optional[User] = None) -> dict[str, Any]: +def get_user(user: User | AnonymousUser, original_user: User | None = None) -> dict[str, Any]: """Convert user object to dictionary, optionally including the original user""" if isinstance(user, AnonymousUser): try: @@ -95,8 +95,7 @@ def get_user(user: User | AnonymousUser, original_user: Optional[User] = None) - return user_data -# pylint: disable=too-many-return-statements,too-many-branches -def sanitize_item(value: Any) -> Any: +def sanitize_item(value: Any) -> Any: # noqa: PLR0911, PLR0912 """Sanitize a single item, ensure it is JSON parsable""" if is_dataclass(value): # Because asdict calls `copy.deepcopy(obj)` on everything that's not tuple/dict, @@ -115,20 +114,20 @@ def sanitize_item(value: Any) -> Any: return sanitize_dict(value) if isinstance(value, GeneratorType): return sanitize_item(list(value)) - if isinstance(value, (list, tuple, set)): + if isinstance(value, list | tuple | set): new_values = [] for item in value: new_value = sanitize_item(item) if new_value: new_values.append(new_value) return new_values - if isinstance(value, (User, AnonymousUser)): + if isinstance(value, User | AnonymousUser): return sanitize_dict(get_user(value)) if isinstance(value, models.Model): return sanitize_dict(model_to_dict(value)) if isinstance(value, UUID): return value.hex - if isinstance(value, (HttpRequest, WSGIRequest)): + if isinstance(value, HttpRequest | WSGIRequest): return ... if isinstance(value, City): return GEOIP_CONTEXT_PROCESSOR.city_to_dict(value) @@ -171,7 +170,7 @@ def sanitize_item(value: Any) -> Any: "module": value.__module__, } # List taken from the stdlib's JSON encoder (_make_iterencode, encoder.py:415) - if isinstance(value, (bool, int, float, NoneType, list, tuple, dict)): + if isinstance(value, bool | int | float | NoneType | list | tuple | dict): return value try: return DjangoJSONEncoder().default(value) diff --git a/authentik/flows/api/flows.py b/authentik/flows/api/flows.py index 3947a7057d..7d96778998 100644 --- a/authentik/flows/api/flows.py +++ b/authentik/flows/api/flows.py @@ -114,7 +114,6 @@ class FlowImportResultSerializer(PassiveSerializer): class FlowViewSet(UsedByMixin, ModelViewSet): """Flow Viewset""" - # pylint: disable=no-member queryset = Flow.objects.all().prefetch_related("stages", "policies") serializer_class = FlowSerializer lookup_field = "slug" @@ -279,7 +278,7 @@ class FlowViewSet(UsedByMixin, ModelViewSet): }, ) @action(detail=True, pagination_class=None, filter_backends=[]) - def execute(self, request: Request, slug: str): + def execute(self, request: Request, _slug: str): """Execute flow for current user""" # Because we pre-plan the flow here, and not in the planner, we need to manually clear # the history of the inspector @@ -294,8 +293,9 @@ class FlowViewSet(UsedByMixin, ModelViewSet): return bad_request_message( request, _( - "Flow not applicable to current user/request: %(messages)s" - % {"messages": exc.messages} + "Flow not applicable to current user/request: {messages}".format_map( + {"messages": exc.messages} + ) ), ) return Response( diff --git a/authentik/flows/api/flows_diagram.py b/authentik/flows/api/flows_diagram.py index ba0056c14a..784eeed8f3 100644 --- a/authentik/flows/api/flows_diagram.py +++ b/authentik/flows/api/flows_diagram.py @@ -1,7 +1,6 @@ """Flows Diagram API""" from dataclasses import dataclass, field -from typing import Optional from django.utils.translation import gettext as _ from guardian.shortcuts import get_objects_for_user @@ -18,8 +17,8 @@ class DiagramElement: identifier: str description: str - action: Optional[str] = None - source: Optional[list["DiagramElement"]] = None + action: str | None = None + source: list["DiagramElement"] | None = None style: list[str] = field(default_factory=lambda: ["[", "]"]) @@ -66,10 +65,10 @@ class FlowDiagram: ): element = DiagramElement( f"flow_policy_{p_index}", - _("Policy (%(type)s)" % {"type": policy_binding.policy._meta.verbose_name}) + _("Policy ({type})".format_map({"type": policy_binding.policy._meta.verbose_name})) + "\n" + policy_binding.policy.name, - _("Binding %(order)d" % {"order": policy_binding.order}), + _("Binding {order}".format_map({"order": policy_binding.order})), parent_elements, style=["{{", "}}"], ) @@ -92,7 +91,7 @@ class FlowDiagram: ): element = DiagramElement( f"stage_{stage_index}_policy_{p_index}", - _("Policy (%(type)s)" % {"type": policy_binding.policy._meta.verbose_name}) + _("Policy ({type})".format_map({"type": policy_binding.policy._meta.verbose_name})) + "\n" + policy_binding.policy.name, "", @@ -120,7 +119,7 @@ class FlowDiagram: element = DiagramElement( f"stage_{s_index}", - _("Stage (%(type)s)" % {"type": stage_binding.stage._meta.verbose_name}) + _("Stage ({type})".format_map({"type": stage_binding.stage._meta.verbose_name})) + "\n" + stage_binding.stage.name, action, diff --git a/authentik/flows/apps.py b/authentik/flows/apps.py index da36733f45..09f85db1c9 100644 --- a/authentik/flows/apps.py +++ b/authentik/flows/apps.py @@ -37,4 +37,4 @@ class AuthentikFlowsConfig(ManagedAppConfig): from authentik.flows.models import Stage for stage in all_subclasses(Stage): - _ = stage().type + _ = stage().view diff --git a/authentik/flows/challenge.py b/authentik/flows/challenge.py index b04ca18199..03d6b5b819 100644 --- a/authentik/flows/challenge.py +++ b/authentik/flows/challenge.py @@ -104,7 +104,7 @@ class FlowErrorChallenge(Challenge): error = CharField(required=False) traceback = CharField(required=False) - def __init__(self, request: Optional[Request] = None, error: Optional[Exception] = None): + def __init__(self, request: Request | None = None, error: Exception | None = None): super().__init__(data={}) if not request or not error: return diff --git a/authentik/flows/exceptions.py b/authentik/flows/exceptions.py index 04f8855577..81b9793507 100644 --- a/authentik/flows/exceptions.py +++ b/authentik/flows/exceptions.py @@ -1,7 +1,5 @@ """flow exceptions""" -from typing import Optional - from django.utils.translation import gettext_lazy as _ from authentik.lib.sentry import SentryIgnoredException @@ -11,7 +9,7 @@ from authentik.policies.types import PolicyResult class FlowNonApplicableException(SentryIgnoredException): """Flow does not apply to current user (denied by policy, or otherwise).""" - policy_result: Optional[PolicyResult] = None + policy_result: PolicyResult | None = None @property def messages(self) -> str: diff --git a/authentik/flows/markers.py b/authentik/flows/markers.py index 0aeaa66c60..ac5509c958 100644 --- a/authentik/flows/markers.py +++ b/authentik/flows/markers.py @@ -1,7 +1,7 @@ """Stage Markers""" from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from django.http.request import HttpRequest from structlog.stdlib import get_logger @@ -25,7 +25,7 @@ class StageMarker: plan: "FlowPlan", binding: FlowStageBinding, http_request: HttpRequest, - ) -> Optional[FlowStageBinding]: + ) -> FlowStageBinding | None: """Process callback for this marker. This should be overridden by sub-classes. If a stage should be removed, return None.""" return binding @@ -42,7 +42,7 @@ class ReevaluateMarker(StageMarker): plan: "FlowPlan", binding: FlowStageBinding, http_request: HttpRequest, - ) -> Optional[FlowStageBinding]: + ) -> FlowStageBinding | None: """Re-evaluate policies bound to stage, and if they fail, remove from plan""" from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER diff --git a/authentik/flows/migrations/0027_auto_20231028_1424.py b/authentik/flows/migrations/0027_auto_20231028_1424.py index 70784e6e37..49894c669d 100644 --- a/authentik/flows/migrations/0027_auto_20231028_1424.py +++ b/authentik/flows/migrations/0027_auto_20231028_1424.py @@ -16,7 +16,7 @@ def set_oobe_flow_authentication(apps: Apps, schema_editor: BaseDatabaseSchemaEd users = User.objects.using(db_alias).exclude(username="akadmin") try: users = users.exclude(pk=get_anonymous_user().pk) - # pylint: disable=broad-except + except Exception: # nosec pass diff --git a/authentik/flows/models.py b/authentik/flows/models.py index e7efaca816..f34b0e3472 100644 --- a/authentik/flows/models.py +++ b/authentik/flows/models.py @@ -2,7 +2,7 @@ from base64 import b64decode, b64encode from pickle import dumps, loads # nosec -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from uuid import uuid4 from django.db import models @@ -83,7 +83,7 @@ class Stage(SerializerModel): objects = InheritanceManager() @property - def type(self) -> type["StageView"]: + def view(self) -> type["StageView"]: """Return StageView class that implements logic for this stage""" # This is a bit of a workaround, since we can't set class methods with setattr if hasattr(self, "__in_memory_type"): @@ -95,7 +95,7 @@ class Stage(SerializerModel): """Return component used to edit this object""" raise NotImplementedError - def ui_user_settings(self) -> Optional[UserSettingSerializer]: + def ui_user_settings(self) -> UserSettingSerializer | None: """Entrypoint to integrate with User settings. Can either return None if no user settings are available, or a challenge.""" return None @@ -113,8 +113,8 @@ def in_memory_stage(view: type["StageView"], **kwargs) -> Stage: # we set the view as a separate property and reference a generic function # that returns that member setattr(stage, "__in_memory_type", view) - setattr(stage, "name", _("Dynamic In-memory stage: %(doc)s" % {"doc": view.__doc__})) - setattr(stage._meta, "verbose_name", class_to_path(view)) + stage.name = _("Dynamic In-memory stage: {doc}".format_map({"doc": view.__doc__})) + stage._meta.verbose_name = class_to_path(view) for key, value in kwargs.items(): setattr(stage, key, value) return stage diff --git a/authentik/flows/planner.py b/authentik/flows/planner.py index 28f34a2655..a4c8c0c1ac 100644 --- a/authentik/flows/planner.py +++ b/authentik/flows/planner.py @@ -1,7 +1,7 @@ """Flows Planner""" from dataclasses import dataclass, field -from typing import Any, Optional +from typing import Any from django.core.cache import cache from django.http import HttpRequest @@ -39,7 +39,7 @@ CACHE_TIMEOUT = CONFIG.get_int("cache.timeout_flows") CACHE_PREFIX = "goauthentik.io/flows/planner/" -def cache_key(flow: Flow, user: Optional[User] = None) -> str: +def cache_key(flow: Flow, user: User | None = None) -> str: """Generate Cache key for flow""" prefix = CACHE_PREFIX + str(flow.pk) if user: @@ -58,16 +58,16 @@ class FlowPlan: context: dict[str, Any] = field(default_factory=dict) markers: list[StageMarker] = field(default_factory=list) - def append_stage(self, stage: Stage, marker: Optional[StageMarker] = None): + def append_stage(self, stage: Stage, marker: StageMarker | None = None): """Append `stage` to all stages, optionally with stage marker""" return self.append(FlowStageBinding(stage=stage), marker) - def append(self, binding: FlowStageBinding, marker: Optional[StageMarker] = None): + def append(self, binding: FlowStageBinding, marker: StageMarker | None = None): """Append `stage` to all stages, optionally with stage marker""" self.bindings.append(binding) self.markers.append(marker or StageMarker()) - def insert_stage(self, stage: Stage, marker: Optional[StageMarker] = None): + def insert_stage(self, stage: Stage, marker: StageMarker | None = None): """Insert stage into plan, as immediate next stage""" self.bindings.insert(1, FlowStageBinding(stage=stage, order=0)) self.markers.insert(1, marker or StageMarker()) @@ -78,7 +78,7 @@ class FlowPlan: self.insert_stage(in_memory_stage(RedirectStage, destination=destination)) - def next(self, http_request: Optional[HttpRequest]) -> Optional[FlowStageBinding]: + def next(self, http_request: HttpRequest | None) -> FlowStageBinding | None: """Return next pending stage from the bottom of the list""" if not self.has_stages: return None @@ -94,7 +94,7 @@ class FlowPlan: self.markers.remove(marker) if not self.has_stages: return None - # pylint: disable=not-callable + return self.next(http_request) return marked_stage @@ -148,9 +148,7 @@ class FlowPlanner: if not outpost_user: raise FlowNonApplicableException() - def plan( - self, request: HttpRequest, default_context: Optional[dict[str, Any]] = None - ) -> FlowPlan: + def plan(self, request: HttpRequest, default_context: dict[str, Any] | None = None) -> FlowPlan: """Check each of the flows' policies, check policies for each stage with PolicyBinding and return ordered list""" with Hub.current.start_span( @@ -214,7 +212,7 @@ class FlowPlanner: self, user: User, request: HttpRequest, - default_context: Optional[dict[str, Any]], + default_context: dict[str, Any] | None, ) -> FlowPlan: """Build flow plan by checking each stage in their respective order and checking the applied policies""" diff --git a/authentik/flows/stage.py b/authentik/flows/stage.py index e96cb33322..b12838e96d 100644 --- a/authentik/flows/stage.py +++ b/authentik/flows/stage.py @@ -1,6 +1,6 @@ """authentik stage Base view""" -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from django.contrib.auth.models import AnonymousUser from django.http import HttpRequest @@ -153,7 +153,7 @@ class ChallengeStageView(StageView): "app": self.executor.plan.context.get(PLAN_CONTEXT_APPLICATION, ""), "user": self.get_pending_user(for_display=True), } - # pylint: disable=broad-except + except Exception as exc: self.logger.warning("failed to template title", exc=exc) return self.executor.flow.title @@ -234,9 +234,9 @@ class ChallengeStageView(StageView): class AccessDeniedChallengeView(ChallengeStageView): """Used internally by FlowExecutor's stage_invalid()""" - error_message: Optional[str] + error_message: str | None - def __init__(self, executor: "FlowExecutorView", error_message: Optional[str] = None, **kwargs): + def __init__(self, executor: "FlowExecutorView", error_message: str | None = None, **kwargs): super().__init__(executor, **kwargs) self.error_message = error_message diff --git a/authentik/flows/tests/__init__.py b/authentik/flows/tests/__init__.py index 2123611845..f846563b04 100644 --- a/authentik/flows/tests/__init__.py +++ b/authentik/flows/tests/__init__.py @@ -1,7 +1,7 @@ """Test helpers""" from json import loads -from typing import Any, Optional +from typing import Any from django.http.response import HttpResponse from django.urls.base import reverse @@ -15,12 +15,11 @@ from authentik.flows.models import Flow class FlowTestCase(APITestCase): """Helpers for testing flows and stages.""" - # pylint: disable=invalid-name def assertStageResponse( self, response: HttpResponse, - flow: Optional[Flow] = None, - user: Optional[User] = None, + flow: Flow | None = None, + user: User | None = None, **kwargs, ) -> dict[str, Any]: """Assert various attributes of a stage response""" @@ -45,7 +44,6 @@ class FlowTestCase(APITestCase): self.assertEqual(raw_response[key], expected) return raw_response - # pylint: disable=invalid-name def assertStageRedirects(self, response: HttpResponse, to: str) -> dict[str, Any]: """Wrapper around assertStageResponse that checks for a redirect""" return self.assertStageResponse( diff --git a/authentik/flows/tests/test_stage_model.py b/authentik/flows/tests/test_stage_model.py index 0987bf2f44..a6bac2a81b 100644 --- a/authentik/flows/tests/test_stage_model.py +++ b/authentik/flows/tests/test_stage_model.py @@ -1,6 +1,6 @@ """base model tests""" -from typing import Callable +from collections.abc import Callable from django.test import TestCase @@ -22,7 +22,7 @@ def model_tester_factory(test_model: type[Stage]) -> Callable: model_class = test_model.__bases__[0]() else: model_class = test_model() - self.assertTrue(issubclass(model_class.type, StageView)) + self.assertTrue(issubclass(model_class.view, StageView)) self.assertIsNotNone(test_model.component) _ = model_class.ui_user_settings() diff --git a/authentik/flows/tests/test_stage_views.py b/authentik/flows/tests/test_stage_views.py index 46364ef298..a9ec1e8a4a 100644 --- a/authentik/flows/tests/test_stage_views.py +++ b/authentik/flows/tests/test_stage_views.py @@ -1,6 +1,6 @@ """stage view tests""" -from typing import Callable +from collections.abc import Callable from django.test import RequestFactory, TestCase diff --git a/authentik/flows/views/executor.py b/authentik/flows/views/executor.py index 71ee200402..c68c8e4d43 100644 --- a/authentik/flows/views/executor.py +++ b/authentik/flows/views/executor.py @@ -1,7 +1,6 @@ """authentik multi-stage authentication engine""" from copy import deepcopy -from typing import Optional from django.conf import settings from django.contrib.auth.mixins import LoginRequiredMixin @@ -107,8 +106,8 @@ class FlowExecutorView(APIView): flow: Flow - plan: Optional[FlowPlan] = None - current_binding: Optional[FlowStageBinding] = None + plan: FlowPlan | None = None + current_binding: FlowStageBinding | None = None current_stage: Stage current_stage_view: View @@ -136,9 +135,9 @@ class FlowExecutorView(APIView): ) return to_stage_response(self.request, self.stage_invalid(error_message=exc.messages)) - def _check_flow_token(self, key: str) -> Optional[FlowPlan]: + def _check_flow_token(self, key: str) -> FlowPlan | None: """Check if the user is using a flow token to restore a plan""" - token: Optional[FlowToken] = FlowToken.filter_not_expired(key=key).first() + token: FlowToken | None = FlowToken.filter_not_expired(key=key).first() if not token: return None plan = None @@ -154,7 +153,6 @@ class FlowExecutorView(APIView): self._logger.debug("f(exec): restored flow plan from token", plan=plan) return plan - # pylint: disable=too-many-return-statements def dispatch(self, request: HttpRequest, flow_slug: str) -> HttpResponse: with Hub.current.start_span( op="authentik.flow.executor.dispatch", description=self.flow.slug @@ -201,7 +199,7 @@ class FlowExecutorView(APIView): # if the cached plan is from an older version, it might have different attributes # in which case we just delete the plan and invalidate everything next_binding = self.plan.next(self.request) - except Exception as exc: # pylint: disable=broad-except + except Exception as exc: self._logger.warning( "f(exec): found incompatible flow plan, invalidating run", exc=exc ) @@ -219,7 +217,7 @@ class FlowExecutorView(APIView): flow_slug=self.flow.slug, ) try: - stage_cls = self.current_stage.type + stage_cls = self.current_stage.view except NotImplementedError as exc: self._logger.debug("Error getting stage type", exc=exc) return self.stage_invalid() @@ -290,7 +288,7 @@ class FlowExecutorView(APIView): span.set_data("authentik Flow", self.flow.slug) stage_response = self.current_stage_view.dispatch(request) return to_stage_response(request, stage_response) - except Exception as exc: # pylint: disable=broad-except + except Exception as exc: return self.handle_exception(exc) @extend_schema( @@ -341,7 +339,7 @@ class FlowExecutorView(APIView): span.set_data("authentik Flow", self.flow.slug) stage_response = self.current_stage_view.dispatch(request) return to_stage_response(request, stage_response) - except Exception as exc: # pylint: disable=broad-except + except Exception as exc: return self.handle_exception(exc) def _initiate_plan(self) -> FlowPlan: @@ -353,7 +351,7 @@ class FlowExecutorView(APIView): # there are no issues with the class we might've gotten # from the cache. If there are errors, just delete all cached flows _ = plan.has_stages - except Exception: # pylint: disable=broad-except + except Exception: keys = cache.keys(f"{CACHE_PREFIX}*") cache.delete_many(keys) return self._initiate_plan() @@ -421,7 +419,7 @@ class FlowExecutorView(APIView): ) return self._flow_done() - def stage_invalid(self, error_message: Optional[str] = None) -> HttpResponse: + def stage_invalid(self, error_message: str | None = None) -> HttpResponse: """Callback used stage when data is correct but a policy denies access or the user account is disabled. @@ -479,9 +477,9 @@ class CancelView(View): class ToDefaultFlow(View): """Redirect to default flow matching by designation""" - designation: Optional[FlowDesignation] = None + designation: FlowDesignation | None = None - def flow_by_policy(self, request: HttpRequest, **flow_filter) -> Optional[Flow]: + def flow_by_policy(self, request: HttpRequest, **flow_filter) -> Flow | None: """Get a Flow by `**flow_filter` and check if the request from `request` can access it.""" flows = Flow.objects.filter(**flow_filter).order_by("slug") for flow in flows: @@ -503,9 +501,7 @@ class ToDefaultFlow(View): if self.designation == FlowDesignation.AUTHENTICATION: flow = brand.flow_authentication # Check if we have a default flow from application - application: Optional[Application] = self.request.session.get( - SESSION_KEY_APPLICATION_PRE - ) + application: Application | None = self.request.session.get(SESSION_KEY_APPLICATION_PRE) if application and application.provider and application.provider.authentication_flow: flow = application.provider.authentication_flow elif self.designation == FlowDesignation.INVALIDATION: @@ -535,7 +531,10 @@ class ToDefaultFlow(View): def to_stage_response(request: HttpRequest, source: HttpResponse) -> HttpResponse: """Convert normal HttpResponse into JSON Response""" - if isinstance(source, HttpResponseRedirect) or source.status_code == 302: + if ( + isinstance(source, HttpResponseRedirect) + or source.status_code == HttpResponseRedirect.status_code + ): redirect_url = source["Location"] # Redirects to the same URL usually indicate an Error within a form if request.get_full_path() == redirect_url: @@ -599,7 +598,7 @@ class ConfigureFlowInitView(LoginRequiredMixin, View): ) except FlowNonApplicableException: LOGGER.warning("Flow not applicable to user") - raise Http404 + raise Http404 from None request.session[SESSION_KEY_PLAN] = plan return redirect_with_qs( "authentik_core:if-flow", diff --git a/authentik/flows/views/inspector.py b/authentik/flows/views/inspector.py index 8af581f756..51d008137f 100644 --- a/authentik/flows/views/inspector.py +++ b/authentik/flows/views/inspector.py @@ -26,6 +26,8 @@ from authentik.flows.planner import FlowPlan from authentik.flows.views.executor import SESSION_KEY_HISTORY, SESSION_KEY_PLAN from authentik.root.install_id import get_install_id +MIN_FLOW_LENGTH = 2 + class FlowInspectorPlanSerializer(PassiveSerializer): """Serializer for an active FlowPlan""" @@ -41,7 +43,7 @@ class FlowInspectorPlanSerializer(PassiveSerializer): def get_next_planned_stage(self, plan: FlowPlan) -> FlowStageBindingSerializer: """Get the next planned stage""" - if len(plan.bindings) < 2: + if len(plan.bindings) < MIN_FLOW_LENGTH: return FlowStageBindingSerializer().data return FlowStageBindingSerializer(instance=plan.bindings[1]).data @@ -49,7 +51,7 @@ class FlowInspectorPlanSerializer(PassiveSerializer): """Get the plan's context, sanitized""" return sanitize_dict(plan.context) - def get_session_id(self, plan: FlowPlan) -> str: + def get_session_id(self, _plan: FlowPlan) -> str: """Get a unique session ID""" request: Request = self.context["request"] return sha256( diff --git a/authentik/lib/avatars.py b/authentik/lib/avatars.py index 593ef6b828..f25a71f1bc 100644 --- a/authentik/lib/avatars.py +++ b/authentik/lib/avatars.py @@ -3,11 +3,11 @@ from base64 import b64encode from functools import cache as funccache from hashlib import md5 -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from urllib.parse import urlencode from django.core.cache import cache -from django.http import HttpRequest +from django.http import HttpRequest, HttpResponseNotFound from django.templatetags.static import static from lxml import etree # nosec from lxml.etree import Element, SubElement # nosec @@ -37,18 +37,18 @@ SVG_FONTS = [ ] -def avatar_mode_none(user: "User", mode: str) -> Optional[str]: +def avatar_mode_none(user: "User", mode: str) -> str | None: """No avatar""" return DEFAULT_AVATAR -def avatar_mode_attribute(user: "User", mode: str) -> Optional[str]: +def avatar_mode_attribute(user: "User", mode: str) -> str | None: """Avatars based on a user attribute""" avatar = get_path_from_dict(user.attributes, mode[11:], default=None) return avatar -def avatar_mode_gravatar(user: "User", mode: str) -> Optional[str]: +def avatar_mode_gravatar(user: "User", mode: str) -> str | None: """Gravatar avatars""" # gravatar uses md5 for their URLs, so md5 can't be avoided mail_hash = md5(user.email.lower().encode("utf-8")).hexdigest() # nosec @@ -65,7 +65,7 @@ def avatar_mode_gravatar(user: "User", mode: str) -> Optional[str]: # (HEAD since we don't need the body) # so if that returns a 404, move onto the next mode res = get_http_session().head(gravatar_url, timeout=5) - if res.status_code == 404: + if res.status_code == HttpResponseNotFound.status_code: cache.set(full_key, None) return None res.raise_for_status() @@ -86,12 +86,13 @@ def generate_colors(text: str) -> tuple[str, str]: red = min(max((color >> 16) & 0xFF, 55), 200) bg_hex = f"{red:02x}{green:02x}{blue:02x}" # Contrasting text color (https://stackoverflow.com/a/3943023) - text_hex = "000" if (red * 0.299 + green * 0.587 + blue * 0.114) > 186 else "fff" + text_hex = ( + "000" if (red * 0.299 + green * 0.587 + blue * 0.114) > 186 else "fff" # noqa: PLR2004 + ) return bg_hex, text_hex @funccache -# pylint: disable=too-many-arguments,too-many-locals def generate_avatar_from_name( name: str, length: int = 2, @@ -107,7 +108,7 @@ def generate_avatar_from_name( """ name_parts = name.split() # Only abbreviate first and last name - if len(name_parts) > 2: + if len(name_parts) > 2: # noqa: PLR2004 name_parts = [name_parts[0], name_parts[-1]] if len(name_parts) == 1: @@ -155,7 +156,7 @@ def generate_avatar_from_name( return etree.tostring(root_element).decode() -def avatar_mode_generated(user: "User", mode: str) -> Optional[str]: +def avatar_mode_generated(user: "User", mode: str) -> str | None: """Wrapper that converts generated avatar to base64 svg""" # By default generate based off of user's display name name = user.name.strip() @@ -169,7 +170,7 @@ def avatar_mode_generated(user: "User", mode: str) -> Optional[str]: return f"data:image/svg+xml;base64,{b64encode(svg.encode('utf-8')).decode('utf-8')}" -def avatar_mode_url(user: "User", mode: str) -> Optional[str]: +def avatar_mode_url(user: "User", mode: str) -> str | None: """Format url""" mail_hash = md5(user.email.lower().encode("utf-8")).hexdigest() # nosec return mode % { @@ -179,7 +180,7 @@ def avatar_mode_url(user: "User", mode: str) -> Optional[str]: } -def get_avatar(user: "User", request: Optional[HttpRequest] = None) -> str: +def get_avatar(user: "User", request: HttpRequest | None = None) -> str: """Get avatar with configured mode""" mode_map = { "none": avatar_mode_none, diff --git a/authentik/lib/config.py b/authentik/lib/config.py index 8d499e9b94..6b21cc5da2 100644 --- a/authentik/lib/config.py +++ b/authentik/lib/config.py @@ -13,7 +13,7 @@ from json.decoder import JSONDecodeError from pathlib import Path from sys import argv, stderr from time import time -from typing import Any, Optional +from typing import Any from urllib.parse import urlparse import yaml @@ -89,7 +89,7 @@ class Attr: # depending on source_type, might contain the environment variable or the path # to the config file containing this change or the file containing this value - source: Optional[str] = field(default=None) + source: str | None = field(default=None) def __post_init__(self): if isinstance(self.value, Attr): @@ -190,16 +190,16 @@ class ConfigLoader: def update(self, root: dict[str, Any], updatee: dict[str, Any]) -> dict[str, Any]: """Recursively update dictionary""" - for key, value in updatee.items(): - if isinstance(value, Mapping): - root[key] = self.update(root.get(key, {}), value) + for key, raw_value in updatee.items(): + if isinstance(raw_value, Mapping): + root[key] = self.update(root.get(key, {}), raw_value) else: - if isinstance(value, str): - value = self.parse_uri(value) - elif isinstance(value, Attr) and isinstance(value.value, str): - value = self.parse_uri(value.value) - elif not isinstance(value, Attr): - value = Attr(value) + if isinstance(raw_value, str): + value = self.parse_uri(raw_value) + elif isinstance(raw_value, Attr) and isinstance(raw_value.value, str): + value = self.parse_uri(raw_value.value) + elif not isinstance(raw_value, Attr): + value = Attr(raw_value) root[key] = value return root @@ -219,7 +219,7 @@ class ConfigLoader: parsed_value = os.getenv(url.netloc, url.query) if url.scheme == "file": try: - with open(url.path, "r", encoding="utf8") as _file: + with open(url.path, encoding="utf8") as _file: parsed_value = _file.read().strip() except OSError as exc: self.log("error", f"Failed to read config value from {url.path}: {exc}") @@ -257,7 +257,7 @@ class ConfigLoader: relative_key = key.replace(f"{ENV_PREFIX}_", "", 1).replace("__", ".").lower() # Check if the value is json, and try to load it try: - value = loads(value) + value = loads(value) # noqa: PLW2901 except JSONDecodeError: pass attr_value = Attr(value, Attr.Source.ENV, relative_key) @@ -330,7 +330,7 @@ CONFIG = ConfigLoader() if __name__ == "__main__": - if len(argv) < 2: + if len(argv) < 2: # noqa: PLR2004 print(dumps(CONFIG.raw, indent=4, cls=AttrEncoder)) else: print(CONFIG.get(argv[1])) diff --git a/authentik/lib/expression/evaluator.py b/authentik/lib/expression/evaluator.py index e4742e096e..24bfe2f6f8 100644 --- a/authentik/lib/expression/evaluator.py +++ b/authentik/lib/expression/evaluator.py @@ -2,9 +2,10 @@ import re import socket +from collections.abc import Iterable from ipaddress import ip_address, ip_network from textwrap import indent -from typing import Any, Iterable, Optional +from typing import Any from cachetools import TLRUCache, cached from django.core.exceptions import FieldError @@ -36,7 +37,7 @@ class BaseEvaluator: # Filename used for exec _filename: str - def __init__(self, filename: Optional[str] = None): + def __init__(self, filename: str | None = None): self._filename = filename if filename else "BaseEvaluator" # update website/docs/expressions/_objects.md # update website/docs/expressions/_functions.md @@ -60,7 +61,7 @@ class BaseEvaluator: @cached(cache=TLRUCache(maxsize=32, ttu=lambda key, value, now: now + 180)) @staticmethod - def expr_resolve_dns(host: str, ip_version: Optional[int] = None) -> list[str]: + def expr_resolve_dns(host: str, ip_version: int | None = None) -> list[str]: """Resolve host to a list of IPv4 and/or IPv6 addresses.""" # Although it seems to be fine (raising OSError), docs warn # against passing `None` for both the host and the port @@ -70,9 +71,9 @@ class BaseEvaluator: ip_list = [] family = 0 - if ip_version == 4: + if ip_version == 4: # noqa: PLR2004 family = socket.AF_INET - if ip_version == 6: + if ip_version == 6: # noqa: PLR2004 family = socket.AF_INET6 try: @@ -92,7 +93,7 @@ class BaseEvaluator: return ip_addr @staticmethod - def expr_flatten(value: list[Any] | Any) -> Optional[Any]: + def expr_flatten(value: list[Any] | Any) -> Any | None: """Flatten `value` if its a list""" if isinstance(value, list): if len(value) < 1: @@ -116,7 +117,7 @@ class BaseEvaluator: return user.all_groups().filter(**group_filters).exists() @staticmethod - def expr_user_by(**filters) -> Optional[User]: + def expr_user_by(**filters) -> User | None: """Get user by filters""" try: users = User.objects.filter(**filters) @@ -127,7 +128,7 @@ class BaseEvaluator: return None @staticmethod - def expr_func_user_has_authenticator(user: User, device_type: Optional[str] = None) -> bool: + def expr_func_user_has_authenticator(user: User, device_type: str | None = None) -> bool: """Check if a user has any authenticator devices, optionally matching *device_type*""" user_devices = devices_for_user(user) if device_type: @@ -204,7 +205,7 @@ class BaseEvaluator: # Yes this is an exec, yes it is potentially bad. Since we limit what variables are # available here, and these policies can only be edited by admins, this is a risk # we're willing to take. - # pylint: disable=exec-used + exec(ast_obj, self._globals, _locals) # nosec # noqa result = _locals["result"] except Exception as exc: diff --git a/authentik/lib/migrations.py b/authentik/lib/migrations.py index 65dc3a84bb..8f86be4af1 100644 --- a/authentik/lib/migrations.py +++ b/authentik/lib/migrations.py @@ -1,6 +1,6 @@ """Migration helpers""" -from typing import Iterable +from collections.abc import Iterable from django.apps.registry import Apps from django.db.backends.base.schema import BaseDatabaseSchemaEditor diff --git a/authentik/lib/models.py b/authentik/lib/models.py index 36a1e173b4..0b2ec9ac68 100644 --- a/authentik/lib/models.py +++ b/authentik/lib/models.py @@ -12,14 +12,14 @@ from rest_framework.serializers import BaseSerializer class SerializerModel(models.Model): """Base Abstract Model which has a serializer""" + class Meta: + abstract = True + @property def serializer(self) -> type[BaseSerializer]: """Get serializer for this model""" raise NotImplementedError - class Meta: - abstract = True - class CreatedUpdatedModel(models.Model): """Base Abstract Model to save created and update""" diff --git a/authentik/lib/sentry.py b/authentik/lib/sentry.py index 75198764bd..b42a299660 100644 --- a/authentik/lib/sentry.py +++ b/authentik/lib/sentry.py @@ -1,7 +1,7 @@ """authentik sentry integration""" from asyncio.exceptions import CancelledError -from typing import Any, Optional +from typing import Any from billiard.exceptions import SoftTimeLimitExceeded, WorkerLostError from celery.exceptions import CeleryError @@ -61,7 +61,7 @@ def sentry_init(**sentry_init_kwargs): }, } kwargs.update(**sentry_init_kwargs) - # pylint: disable=abstract-class-instantiated + sentry_sdk_init( dsn=CONFIG.get("error_reporting.sentry_dsn"), integrations=[ @@ -96,9 +96,9 @@ def traces_sampler(sampling_context: dict) -> float: return float(CONFIG.get("error_reporting.sample_rate", 0.1)) -def before_send(event: dict, hint: dict) -> Optional[dict]: +def before_send(event: dict, hint: dict) -> dict | None: """Check if error is database error, and ignore if so""" - # pylint: disable=no-name-in-module + from psycopg.errors import Error ignored_classes = ( diff --git a/authentik/lib/tests/test_config.py b/authentik/lib/tests/test_config.py index d2ef0a897b..d436787375 100644 --- a/authentik/lib/tests/test_config.py +++ b/authentik/lib/tests/test_config.py @@ -59,7 +59,7 @@ class TestConfig(TestCase): """Test URI parsing (file load)""" config = ConfigLoader() file, file_name = mkstemp() - write(file, "foo".encode()) + write(file, b"foo") _, file2_name = mkstemp() chmod(file2_name, 0o000) # Remove all permissions so we can't read the file self.assertEqual(config.parse_uri(f"file://{file_name}").value, "foo") @@ -70,12 +70,12 @@ class TestConfig(TestCase): def test_uri_file_update(self): """Test URI parsing (file load and update)""" file, file_name = mkstemp() - write(file, "foo".encode()) + write(file, b"foo") config = ConfigLoader(file_test=f"file://{file_name}") self.assertEqual(config.get("file_test"), "foo") # Update config file - write(file, "bar".encode()) + write(file, b"bar") config.refresh("file_test") self.assertEqual(config.get("file_test"), "foobar") @@ -91,9 +91,9 @@ class TestConfig(TestCase): """Test update_from_file""" config = ConfigLoader() file, file_name = mkstemp() - write(file, "{".encode()) + write(file, b"{") file2, file2_name = mkstemp() - write(file2, "{".encode()) + write(file2, b"{") chmod(file2_name, 0o000) # Remove all permissions so we can't read the file with self.assertRaises(ImproperlyConfigured): config.update_from_file(file_name) @@ -116,7 +116,7 @@ class TestConfig(TestCase): def test_get_dict_from_b64_json(self): """Test get_dict_from_b64_json""" config = ConfigLoader() - test_value = ' { "foo": "bar" } '.encode("utf-8") + test_value = b' { "foo": "bar" } ' b64_value = base64.b64encode(test_value) config.set("foo", b64_value) self.assertEqual(config.get_dict_from_b64_json("foo"), {"foo": "bar"}) @@ -124,7 +124,7 @@ class TestConfig(TestCase): def test_get_dict_from_b64_json_missing_brackets(self): """Test get_dict_from_b64_json with missing brackets""" config = ConfigLoader() - test_value = ' "foo": "bar" '.encode("utf-8") + test_value = b' "foo": "bar" ' b64_value = base64.b64encode(test_value) config.set("foo", b64_value) self.assertEqual(config.get_dict_from_b64_json("foo"), {"foo": "bar"}) diff --git a/authentik/lib/tests/test_serializer_model.py b/authentik/lib/tests/test_serializer_model.py index aea9d73fca..0d91884e31 100644 --- a/authentik/lib/tests/test_serializer_model.py +++ b/authentik/lib/tests/test_serializer_model.py @@ -1,6 +1,6 @@ """base model tests""" -from typing import Callable +from collections.abc import Callable from django.test import TestCase from rest_framework.serializers import BaseSerializer diff --git a/authentik/lib/tests/utils.py b/authentik/lib/tests/utils.py index 1345a747fd..c2f16c8dcc 100644 --- a/authentik/lib/tests/utils.py +++ b/authentik/lib/tests/utils.py @@ -20,9 +20,7 @@ def load_fixture(path: str, **kwargs) -> str: current = currentframe() parent = current.f_back calling_file_path = parent.f_globals["__file__"] - with open( - Path(calling_file_path).resolve().parent / Path(path), "r", encoding="utf-8" - ) as _fixture: + with open(Path(calling_file_path).resolve().parent / Path(path), encoding="utf-8") as _fixture: fixture = _fixture.read() try: return fixture % kwargs diff --git a/authentik/lib/utils/urls.py b/authentik/lib/utils/urls.py index b0e1b5493a..a374642404 100644 --- a/authentik/lib/utils/urls.py +++ b/authentik/lib/utils/urls.py @@ -1,6 +1,5 @@ """URL-related utils""" -from typing import Optional from urllib.parse import urlparse from django.http import HttpResponse, QueryDict @@ -17,9 +16,7 @@ def is_url_absolute(url): return bool(urlparse(url).netloc) -def redirect_with_qs( - view: str, get_query_set: Optional[QueryDict] = None, **kwargs -) -> HttpResponse: +def redirect_with_qs(view: str, get_query_set: QueryDict | None = None, **kwargs) -> HttpResponse: """Wrapper to redirect whilst keeping GET Parameters""" try: target = reverse(view, kwargs=kwargs) @@ -33,7 +30,7 @@ def redirect_with_qs( return redirect(target) -def reverse_with_qs(view: str, query: Optional[QueryDict] = None, **kwargs) -> str: +def reverse_with_qs(view: str, query: QueryDict | None = None, **kwargs) -> str: """Reverse a view to it's url but include get params""" url = reverse(view, **kwargs) if query: diff --git a/authentik/lib/validators.py b/authentik/lib/validators.py index 668f84c7b7..31fb3f6f03 100644 --- a/authentik/lib/validators.py +++ b/authentik/lib/validators.py @@ -1,7 +1,5 @@ """Serializer validators""" -from typing import Optional - from django.utils.translation import gettext_lazy as _ from rest_framework.exceptions import ValidationError from rest_framework.serializers import Serializer @@ -16,7 +14,7 @@ class RequiredTogetherValidator: requires_context = True message = _("The fields {field_names} must be used together.") - def __init__(self, fields: list[str], message: Optional[str] = None) -> None: + def __init__(self, fields: list[str], message: str | None = None) -> None: self.fields = fields self.message = message or self.message @@ -30,4 +28,4 @@ class RequiredTogetherValidator: raise ValidationError(message, code="required") def __repr__(self): - return "<%s(fields=%s)>" % (self.__class__.__name__, smart_repr(self.fields)) + return f"<{self.__class__.__name__}(fields={smart_repr(self.fields)})>" diff --git a/authentik/outposts/api/service_connections.py b/authentik/outposts/api/service_connections.py index 21eac69e49..6fe8300d22 100644 --- a/authentik/outposts/api/service_connections.py +++ b/authentik/outposts/api/service_connections.py @@ -133,7 +133,7 @@ class KubernetesServiceConnectionSerializer(ServiceConnectionSerializer): try: load_kube_config_from_dict(kubeconfig, client_configuration=config) except ConfigException: - raise serializers.ValidationError(_("Invalid kubeconfig")) + raise serializers.ValidationError(_("Invalid kubeconfig")) from None return kubeconfig class Meta: diff --git a/authentik/outposts/consumer.py b/authentik/outposts/consumer.py index 126526a6b8..f635de60db 100644 --- a/authentik/outposts/consumer.py +++ b/authentik/outposts/consumer.py @@ -3,7 +3,7 @@ from dataclasses import asdict, dataclass, field from datetime import datetime from enum import IntEnum -from typing import Any, Optional +from typing import Any from asgiref.sync import async_to_sync from channels.exceptions import DenyConnection @@ -49,10 +49,10 @@ class WebsocketMessage: class OutpostConsumer(JsonWebsocketConsumer): """Handler for Outposts that connect over websockets for health checks and live updates""" - outpost: Optional[Outpost] = None + outpost: Outpost | None = None logger: BoundLogger - instance_uid: Optional[str] = None + instance_uid: str | None = None def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -71,7 +71,7 @@ class OutpostConsumer(JsonWebsocketConsumer): self.accept() except RuntimeError as exc: self.logger.warning("runtime error during accept", exc=exc) - raise DenyConnection() + raise DenyConnection() from None self.outpost = outpost query = QueryDict(self.scope["query_string"].decode()) self.instance_uid = query.get("instance_uuid", self.channel_name) diff --git a/authentik/outposts/controllers/base.py b/authentik/outposts/controllers/base.py index d0098df8ca..5f7cdafc93 100644 --- a/authentik/outposts/controllers/base.py +++ b/authentik/outposts/controllers/base.py @@ -1,7 +1,6 @@ """Base Controller""" from dataclasses import dataclass -from typing import Optional from structlog.stdlib import get_logger from structlog.testing import capture_logs @@ -29,7 +28,7 @@ class DeploymentPort: port: int name: str protocol: str - inner_port: Optional[int] = None + inner_port: int | None = None class BaseClient: @@ -60,7 +59,6 @@ class BaseController: self.logger = get_logger() self.deployment_ports = [] - # pylint: disable=invalid-name def up(self): """Called by scheduled task to reconcile deployment/service/etc""" raise NotImplementedError diff --git a/authentik/outposts/controllers/docker.py b/authentik/outposts/controllers/docker.py index 6280ee57f2..69ea01477f 100644 --- a/authentik/outposts/controllers/docker.py +++ b/authentik/outposts/controllers/docker.py @@ -1,7 +1,6 @@ """Docker controller""" from time import sleep -from typing import Optional from urllib.parse import urlparse from django.conf import settings @@ -25,12 +24,14 @@ from authentik.outposts.models import ( ServiceConnectionInvalid, ) +DOCKER_MAX_ATTEMPTS = 10 + class DockerClient(UpstreamDockerClient, BaseClient): """Custom docker client, which can handle TLS and SSH from a database.""" - tls: Optional[DockerInlineTLS] - ssh: Optional[DockerInlineSSH] + tls: DockerInlineTLS | None + ssh: DockerInlineSSH | None def __init__(self, connection: DockerServiceConnection): self.tls = None @@ -226,11 +227,10 @@ class DockerController(BaseController): except NotFound: return - # pylint: disable=too-many-return-statements def up(self, depth=1): if self.outpost.managed == MANAGED_OUTPOST: return None - if depth >= 10: + if depth >= DOCKER_MAX_ATTEMPTS: raise ControllerException("Giving up since we exceeded recursion limit.") self._migrate_container_name() try: diff --git a/authentik/outposts/controllers/k8s/base.py b/authentik/outposts/controllers/k8s/base.py index 7eac71234f..2a254ade09 100644 --- a/authentik/outposts/controllers/k8s/base.py +++ b/authentik/outposts/controllers/k8s/base.py @@ -2,9 +2,10 @@ from dataclasses import asdict from json import dumps -from typing import TYPE_CHECKING, Generic, Optional, TypeVar +from typing import TYPE_CHECKING, Generic, TypeVar from dacite.core import from_dict +from django.http import HttpResponseNotFound from django.utils.text import slugify from jsonpatch import JsonPatchConflict, JsonPatchException, JsonPatchTestFailed, apply_patch from kubernetes.client import ApiClient, V1ObjectMeta @@ -100,7 +101,6 @@ class KubernetesObjectReconciler(Generic[T]): return result - # pylint: disable=invalid-name def up(self): """Create object if it doesn't exist, update if needed or recreate if needed.""" current = None @@ -112,8 +112,8 @@ class KubernetesObjectReconciler(Generic[T]): try: current = self.retrieve() except (OpenApiException, HTTPError) as exc: - # pylint: disable=no-member - if isinstance(exc, ApiException) and exc.status == 404: + + if isinstance(exc, ApiException) and exc.status == HttpResponseNotFound.status_code: self.logger.debug("Failed to get current, triggering recreate") raise NeedsRecreate from exc self.logger.debug("Other unhandled error", exc=exc) @@ -124,8 +124,8 @@ class KubernetesObjectReconciler(Generic[T]): self.update(current, reference) self.logger.debug("Updating") except (OpenApiException, HTTPError) as exc: - # pylint: disable=no-member - if isinstance(exc, ApiException) and exc.status == 422: + + if isinstance(exc, ApiException) and exc.status == 422: # noqa: PLR2004 self.logger.debug("Failed to update current, triggering re-create") self._recreate(current=current, reference=reference) return @@ -136,7 +136,7 @@ class KubernetesObjectReconciler(Generic[T]): else: self.logger.debug("Object is up-to-date.") - def _recreate(self, reference: T, current: Optional[T] = None): + def _recreate(self, reference: T, current: T | None = None): """Recreate object""" self.logger.debug("Recreate requested") if current: @@ -157,8 +157,8 @@ class KubernetesObjectReconciler(Generic[T]): self.delete(current) self.logger.debug("Removing") except (OpenApiException, HTTPError) as exc: - # pylint: disable=no-member - if isinstance(exc, ApiException) and exc.status == 404: + + if isinstance(exc, ApiException) and exc.status == HttpResponseNotFound.status_code: self.logger.debug("Failed to get current, assuming non-existent") return self.logger.debug("Other unhandled error", exc=exc) diff --git a/authentik/outposts/controllers/k8s/service_monitor.py b/authentik/outposts/controllers/k8s/service_monitor.py index 8001ae8553..4e80e23ca3 100644 --- a/authentik/outposts/controllers/k8s/service_monitor.py +++ b/authentik/outposts/controllers/k8s/service_monitor.py @@ -25,7 +25,6 @@ class PrometheusServiceMonitorSpecEndpoint: class PrometheusServiceMonitorSpecSelector: """Prometheus ServiceMonitor selector spec""" - # pylint: disable=invalid-name matchLabels: dict @@ -34,7 +33,7 @@ class PrometheusServiceMonitorSpec: """Prometheus ServiceMonitor spec""" endpoints: list[PrometheusServiceMonitorSpecEndpoint] - # pylint: disable=invalid-name + selector: PrometheusServiceMonitorSpecSelector @@ -51,7 +50,6 @@ class PrometheusServiceMonitorMetadata: class PrometheusServiceMonitor: """Prometheus ServiceMonitor""" - # pylint: disable=invalid-name apiVersion: str kind: str metadata: PrometheusServiceMonitorMetadata diff --git a/authentik/outposts/controllers/k8s/utils.py b/authentik/outposts/controllers/k8s/utils.py index a0395ffd2c..c2f22f0c8e 100644 --- a/authentik/outposts/controllers/k8s/utils.py +++ b/authentik/outposts/controllers/k8s/utils.py @@ -1,7 +1,6 @@ """k8s utils""" from pathlib import Path -from typing import Optional from kubernetes.client.models.v1_container_port import V1ContainerPort from kubernetes.client.models.v1_service_port import V1ServicePort @@ -14,7 +13,7 @@ def get_namespace() -> str: """Get the namespace if we're running in a pod, otherwise default to default""" path = Path(SERVICE_TOKEN_FILENAME.replace("token", "namespace")) if path.exists(): - with open(path, "r", encoding="utf8") as _namespace_file: + with open(path, encoding="utf8") as _namespace_file: return _namespace_file.read() return "default" @@ -39,8 +38,8 @@ def compare_port( def compare_ports( - current: Optional[list[V1ServicePort | V1ContainerPort]], - reference: Optional[list[V1ServicePort | V1ContainerPort]], + current: list[V1ServicePort | V1ContainerPort] | None, + reference: list[V1ServicePort | V1ContainerPort] | None, ): """Compare ports of a list""" if not current or not reference: diff --git a/authentik/outposts/docker_ssh.py b/authentik/outposts/docker_ssh.py index 2736fe258c..ac2de3e125 100644 --- a/authentik/outposts/docker_ssh.py +++ b/authentik/outposts/docker_ssh.py @@ -81,7 +81,7 @@ class DockerInlineSSH: """Cleanup when we're done""" try: os.unlink(self.key_path) - with open(self.config_path, "r", encoding="utf-8") as ssh_config: + with open(self.config_path, encoding="utf-8") as ssh_config: start = 0 end = 0 lines = ssh_config.readlines() diff --git a/authentik/outposts/docker_tls.py b/authentik/outposts/docker_tls.py index 378295541a..a16550e35d 100644 --- a/authentik/outposts/docker_tls.py +++ b/authentik/outposts/docker_tls.py @@ -3,7 +3,6 @@ from os import unlink from pathlib import Path from tempfile import gettempdir -from typing import Optional from docker.tls import TLSConfig @@ -13,15 +12,15 @@ from authentik.crypto.models import CertificateKeyPair class DockerInlineTLS: """Create Docker TLSConfig from CertificateKeyPair""" - verification_kp: Optional[CertificateKeyPair] - authentication_kp: Optional[CertificateKeyPair] + verification_kp: CertificateKeyPair | None + authentication_kp: CertificateKeyPair | None _paths: list[str] def __init__( self, - verification_kp: Optional[CertificateKeyPair], - authentication_kp: Optional[CertificateKeyPair], + verification_kp: CertificateKeyPair | None, + authentication_kp: CertificateKeyPair | None, ) -> None: self.verification_kp = verification_kp self.authentication_kp = authentication_kp diff --git a/authentik/outposts/models.py b/authentik/outposts/models.py index d99751c9f0..29984ebd62 100644 --- a/authentik/outposts/models.py +++ b/authentik/outposts/models.py @@ -1,8 +1,9 @@ """Outpost models""" +from collections.abc import Iterable from dataclasses import asdict, dataclass, field from datetime import datetime -from typing import Any, Iterable, Optional +from typing import Any from uuid import uuid4 from dacite.core import from_dict @@ -49,7 +50,6 @@ class ServiceConnectionInvalid(SentryIgnoredException): @dataclass -# pylint: disable=too-many-instance-attributes class OutpostConfig: """Configuration an outpost uses to configure it self""" @@ -62,21 +62,21 @@ class OutpostConfig: log_level: str = CONFIG.get("log_level") object_naming_template: str = field(default="ak-outpost-%(name)s") - container_image: Optional[str] = field(default=None) + container_image: str | None = field(default=None) - docker_network: Optional[str] = field(default=None) + docker_network: str | None = field(default=None) docker_map_ports: bool = field(default=True) - docker_labels: Optional[dict[str, str]] = field(default=None) + docker_labels: dict[str, str] | None = field(default=None) kubernetes_replicas: int = field(default=1) kubernetes_namespace: str = field(default_factory=get_namespace) kubernetes_ingress_annotations: dict[str, str] = field(default_factory=dict) kubernetes_ingress_secret_name: str = field(default="authentik-outpost-tls") - kubernetes_ingress_class_name: Optional[str] = field(default=None) + kubernetes_ingress_class_name: str | None = field(default=None) kubernetes_service_type: str = field(default="ClusterIP") kubernetes_disabled_components: list[str] = field(default_factory=list) kubernetes_image_pull_secrets: list[str] = field(default_factory=list) - kubernetes_json_patches: Optional[dict[str, list[dict[str, Any]]]] = field(default=None) + kubernetes_json_patches: dict[str, list[dict[str, Any]]] | None = field(default=None) class OutpostModel(Model): @@ -99,7 +99,7 @@ class OutpostType(models.TextChoices): RAC = "rac" -def default_outpost_config(host: Optional[str] = None): +def default_outpost_config(host: str | None = None): """Get default outpost config""" return asdict(OutpostConfig(authentik_host=host or "")) @@ -127,6 +127,13 @@ class OutpostServiceConnection(models.Model): objects = InheritanceManager() + class Meta: + verbose_name = _("Outpost Service-Connection") + verbose_name_plural = _("Outpost Service-Connections") + + def __str__(self) -> __version__: + return f"Outpost service connection {self.name}" + @property def state_key(self) -> str: """Key used to save connection state in cache""" @@ -150,10 +157,6 @@ class OutpostServiceConnection(models.Model): # since the response doesn't use the correct inheritance return "" - class Meta: - verbose_name = _("Outpost Service-Connection") - verbose_name_plural = _("Outpost Service-Connections") - class DockerServiceConnection(SerializerModel, OutpostServiceConnection): """Service Connection to a Docker endpoint""" @@ -188,6 +191,13 @@ class DockerServiceConnection(SerializerModel, OutpostServiceConnection): ), ) + class Meta: + verbose_name = _("Docker Service-Connection") + verbose_name_plural = _("Docker Service-Connections") + + def __str__(self) -> str: + return f"Docker Service-Connection {self.name}" + @property def serializer(self) -> Serializer: from authentik.outposts.api.service_connections import DockerServiceConnectionSerializer @@ -198,13 +208,6 @@ class DockerServiceConnection(SerializerModel, OutpostServiceConnection): def component(self) -> str: return "ak-service-connection-docker-form" - def __str__(self) -> str: - return f"Docker Service-Connection {self.name}" - - class Meta: - verbose_name = _("Docker Service-Connection") - verbose_name_plural = _("Docker Service-Connections") - class KubernetesServiceConnection(SerializerModel, OutpostServiceConnection): """Service Connection to a Kubernetes cluster""" @@ -220,6 +223,13 @@ class KubernetesServiceConnection(SerializerModel, OutpostServiceConnection): default=True, help_text=_("Verify SSL Certificates of the Kubernetes API endpoint") ) + class Meta: + verbose_name = _("Kubernetes Service-Connection") + verbose_name_plural = _("Kubernetes Service-Connections") + + def __str__(self) -> str: + return f"Kubernetes Service-Connection {self.name}" + @property def serializer(self) -> Serializer: from authentik.outposts.api.service_connections import KubernetesServiceConnectionSerializer @@ -230,13 +240,6 @@ class KubernetesServiceConnection(SerializerModel, OutpostServiceConnection): def component(self) -> str: return "ak-service-connection-kubernetes-form" - def __str__(self) -> str: - return f"Kubernetes Service-Connection {self.name}" - - class Meta: - verbose_name = _("Kubernetes Service-Connection") - verbose_name_plural = _("Kubernetes Service-Connections") - class Outpost(SerializerModel, ManagedModel): """Outpost instance which manages a service user and token""" @@ -427,14 +430,14 @@ class OutpostState: """Outpost instance state, last_seen and version""" uid: str - last_seen: Optional[datetime] = field(default=None) - version: Optional[str] = field(default=None) + last_seen: datetime | None = field(default=None) + version: str | None = field(default=None) version_should: Version = field(default=OUR_VERSION) build_hash: str = field(default="") hostname: str = field(default="") args: dict = field(default_factory=dict) - _outpost: Optional[Outpost] = field(default=None) + _outpost: Outpost | None = field(default=None) @property def version_outdated(self) -> bool: @@ -467,7 +470,7 @@ class OutpostState: cache.delete(key) data = default_data state = from_dict(OutpostState, data) - # pylint: disable=protected-access + state._outpost = outpost return state diff --git a/authentik/outposts/tasks.py b/authentik/outposts/tasks.py index 73405f6cff..d2db180a49 100644 --- a/authentik/outposts/tasks.py +++ b/authentik/outposts/tasks.py @@ -3,7 +3,7 @@ from os import R_OK, access from pathlib import Path from socket import gethostname -from typing import Any, Optional +from typing import Any from urllib.parse import urlparse from asgiref.sync import async_to_sync @@ -49,8 +49,7 @@ LOGGER = get_logger() CACHE_KEY_OUTPOST_DOWN = "goauthentik.io/outposts/teardown/%s" -# pylint: disable=too-many-return-statements -def controller_for_outpost(outpost: Outpost) -> Optional[type[BaseController]]: +def controller_for_outpost(outpost: Outpost) -> type[BaseController] | None: """Get a controller for the outpost, when a service connection is defined""" if not outpost.service_connection: return None @@ -195,7 +194,7 @@ def outpost_post_save(model_class: str, model_pk: Any): LOGGER.debug("Trigger reconcile for outpost", instance=instance) outpost_controller.delay(str(instance.pk)) - if isinstance(instance, (OutpostModel, Outpost)): + if isinstance(instance, OutpostModel | Outpost): LOGGER.debug("triggering outpost update from outpostmodel/outpost", instance=instance) outpost_send_update(instance) diff --git a/authentik/policies/api/bindings.py b/authentik/policies/api/bindings.py index 8ca6d17d29..f9a5ab1c7e 100644 --- a/authentik/policies/api/bindings.py +++ b/authentik/policies/api/bindings.py @@ -1,6 +1,6 @@ """policy binding API Views""" -from typing import OrderedDict +from collections import OrderedDict from django.core.exceptions import ObjectDoesNotExist from django_filters.filters import BooleanFilter, ModelMultipleChoiceFilter @@ -25,7 +25,6 @@ class PolicyBindingModelForeignKey(PrimaryKeyRelatedField): def use_pk_only_optimization(self): return False - # pylint: disable=inconsistent-return-statements def to_internal_value(self, data): if self.pk_field is not None: data = self.pk_field.to_internal_value(data) diff --git a/authentik/policies/denied.py b/authentik/policies/denied.py index faf0736d64..b93348372d 100644 --- a/authentik/policies/denied.py +++ b/authentik/policies/denied.py @@ -1,6 +1,6 @@ """policy http response""" -from typing import Any, Optional +from typing import Any from django.http.request import HttpRequest from django.template.response import TemplateResponse @@ -17,14 +17,14 @@ class AccessDeniedResponse(TemplateResponse): title: str - error_message: Optional[str] = None - policy_result: Optional[PolicyResult] = None + error_message: str | None = None + policy_result: PolicyResult | None = None def __init__(self, request: HttpRequest, template="policies/denied.html") -> None: super().__init__(request, template) self.title = _("Access denied") - def resolve_context(self, context: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]: + def resolve_context(self, context: dict[str, Any] | None) -> dict[str, Any] | None: if not context: context = {} context["title"] = self.title diff --git a/authentik/policies/engine.py b/authentik/policies/engine.py index 7b786ea4cb..8b24334bf6 100644 --- a/authentik/policies/engine.py +++ b/authentik/policies/engine.py @@ -1,9 +1,9 @@ """authentik policy engine""" +from collections.abc import Iterator from multiprocessing import Pipe, current_process from multiprocessing.connection import Connection from time import perf_counter -from typing import Iterator, Optional from django.core.cache import cache from django.http import HttpRequest @@ -27,7 +27,7 @@ class PolicyProcessInfo: process: PolicyProcess connection: Connection - result: Optional[PolicyResult] + result: PolicyResult | None binding: PolicyBinding def __init__(self, process: PolicyProcess, connection: Connection, binding: PolicyBinding): diff --git a/authentik/policies/exceptions.py b/authentik/policies/exceptions.py index 928510b5a1..83e97bd318 100644 --- a/authentik/policies/exceptions.py +++ b/authentik/policies/exceptions.py @@ -1,7 +1,5 @@ """policy exceptions""" -from typing import Optional - from authentik.lib.sentry import SentryIgnoredException @@ -12,8 +10,8 @@ class PolicyEngineException(SentryIgnoredException): class PolicyException(SentryIgnoredException): """Exception that should be raised during Policy Evaluation, and can be recovered from.""" - src_exc: Optional[Exception] = None + src_exc: Exception | None = None - def __init__(self, src_exc: Optional[Exception] = None) -> None: + def __init__(self, src_exc: Exception | None = None) -> None: super().__init__() self.src_exc = src_exc diff --git a/authentik/policies/expression/evaluator.py b/authentik/policies/expression/evaluator.py index bba01553cf..536b2634d8 100644 --- a/authentik/policies/expression/evaluator.py +++ b/authentik/policies/expression/evaluator.py @@ -24,7 +24,7 @@ class PolicyEvaluator(BaseEvaluator): policy: Optional["ExpressionPolicy"] = None - def __init__(self, policy_name: Optional[str] = None): + def __init__(self, policy_name: str | None = None): super().__init__(policy_name or "PolicyEvaluator") self._messages = [] # update website/docs/expressions/_objects.md @@ -66,7 +66,7 @@ class PolicyEvaluator(BaseEvaluator): # PolicyExceptions should be propagated back to the process, # which handles recording and returning a correct result raise exc - except Exception as exc: # pylint: disable=broad-except + except Exception as exc: LOGGER.warning("Expression error", exc=exc) return PolicyResult(False, str(exc)) else: diff --git a/authentik/policies/models.py b/authentik/policies/models.py index ba4c72f720..2364dfeb28 100644 --- a/authentik/policies/models.py +++ b/authentik/policies/models.py @@ -40,13 +40,13 @@ class PolicyBindingModel(models.Model): objects = InheritanceManager() - def __str__(self) -> str: - return f"PolicyBindingModel {self.pbm_uuid}" - class Meta: verbose_name = _("Policy Binding Model") verbose_name_plural = _("Policy Binding Models") + def __str__(self) -> str: + return f"PolicyBindingModel {self.pbm_uuid}" + class PolicyBinding(SerializerModel): """Relationship between a Policy and a PolicyBindingModel.""" @@ -138,7 +138,7 @@ class PolicyBinding(SerializerModel): suffix = f"{self.target_type.title()} {self.target_name}" try: return f"Binding from {self.target} #{self.order} to {suffix}" - except PolicyBinding.target.RelatedObjectDoesNotExist: # pylint: disable=no-member + except PolicyBinding.target.RelatedObjectDoesNotExist: return f"Binding - #{self.order} to {suffix}" return "" diff --git a/authentik/policies/password/models.py b/authentik/policies/password/models.py index 8dde8ae24e..d8584617bc 100644 --- a/authentik/policies/password/models.py +++ b/authentik/policies/password/models.py @@ -87,7 +87,6 @@ class PasswordPolicy(Policy): return zxcvbn_result return PolicyResult(True) - # pylint: disable=too-many-return-statements def passes_static(self, password: str, request: PolicyRequest) -> PolicyResult: """Check static rules""" if len(password) < self.length_min: diff --git a/authentik/policies/process.py b/authentik/policies/process.py index 3e568fd32f..e2d9139b0c 100644 --- a/authentik/policies/process.py +++ b/authentik/policies/process.py @@ -2,7 +2,6 @@ from multiprocessing import get_context from multiprocessing.connection import Connection -from typing import Optional from django.core.cache import cache from sentry_sdk.hub import Hub @@ -46,7 +45,7 @@ class PolicyProcess(PROCESS_CLASS): self, binding: PolicyBinding, request: PolicyRequest, - connection: Optional[Connection], + connection: Connection | None, ): super().__init__() self.binding = binding @@ -143,6 +142,6 @@ class PolicyProcess(PROCESS_CLASS): """Task wrapper to run policy checking""" try: self.connection.send(self.profiling_wrapper()) - except Exception as exc: # pylint: disable=broad-except + except Exception as exc: LOGGER.warning("Policy failed to run", exc=exception_to_string(exc)) self.connection.send(PolicyResult(False, str(exc))) diff --git a/authentik/policies/types.py b/authentik/policies/types.py index be14dcea8c..d1c06677ae 100644 --- a/authentik/policies/types.py +++ b/authentik/policies/types.py @@ -3,7 +3,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from django.db.models import Model from django.http import HttpRequest @@ -24,8 +24,8 @@ class PolicyRequest: """Data-class to hold policy request data""" user: User - http_request: Optional[HttpRequest] - obj: Optional[Model] + http_request: HttpRequest | None + obj: Model | None context: dict[str, Any] debug: bool @@ -71,10 +71,10 @@ class PolicyResult: messages: tuple[str, ...] raw_result: Any - source_binding: Optional["PolicyBinding"] - source_results: Optional[list["PolicyResult"]] + source_binding: PolicyBinding | None + source_results: list[PolicyResult] | None - log_messages: Optional[list[dict]] + log_messages: list[dict] | None def __init__(self, passing: bool, *messages: str): self.passing = passing diff --git a/authentik/policies/views.py b/authentik/policies/views.py index 99f28dd224..23cf1cd51b 100644 --- a/authentik/policies/views.py +++ b/authentik/policies/views.py @@ -1,6 +1,6 @@ """authentik access helper classes""" -from typing import Any, Optional +from typing import Any from django.contrib import messages from django.contrib.auth.mixins import AccessMixin @@ -23,9 +23,9 @@ LOGGER = get_logger() class RequestValidationError(SentryIgnoredException): """Error raised in pre_permission_check, when a request is invalid.""" - response: Optional[HttpResponse] + response: HttpResponse | None - def __init__(self, response: Optional[HttpResponse] = None): + def __init__(self, response: HttpResponse | None = None): super().__init__() if response: self.response = response @@ -95,7 +95,7 @@ class PolicyAccessView(AccessMixin, View): ) def handle_no_permission_authenticated( - self, result: Optional[PolicyResult] = None + self, result: PolicyResult | None = None ) -> HttpResponse: """Function called when user has no permissions but is authenticated""" response = AccessDeniedResponse(self.request) @@ -107,7 +107,7 @@ class PolicyAccessView(AccessMixin, View): """optionally modify the policy request""" return request - def user_has_access(self, user: Optional[User] = None) -> PolicyResult: + def user_has_access(self, user: User | None = None) -> PolicyResult: """Check if user has access to application.""" user = user or self.request.user policy_engine = PolicyEngine(self.application, user or self.request.user, self.request) diff --git a/authentik/providers/ldap/models.py b/authentik/providers/ldap/models.py index 42dbbdb13f..5032ef3e47 100644 --- a/authentik/providers/ldap/models.py +++ b/authentik/providers/ldap/models.py @@ -1,6 +1,6 @@ """LDAP Provider""" -from typing import Iterable, Optional +from collections.abc import Iterable from django.db import models from django.utils.translation import gettext_lazy as _ @@ -82,7 +82,7 @@ class LDAPProvider(OutpostModel, BackchannelProvider): ) @property - def launch_url(self) -> Optional[str]: + def launch_url(self) -> str | None: """LDAP never has a launch URL""" return None diff --git a/authentik/providers/oauth2/api/providers.py b/authentik/providers/oauth2/api/providers.py index 6702dbf428..632fabca5b 100644 --- a/authentik/providers/oauth2/api/providers.py +++ b/authentik/providers/oauth2/api/providers.py @@ -135,7 +135,7 @@ class OAuth2ProviderViewSet(UsedByMixin, ModelViewSet): kwargs={"application_slug": provider.application.slug}, ) ) - except Provider.application.RelatedObjectDoesNotExist: # pylint: disable=no-member + except Provider.application.RelatedObjectDoesNotExist: pass return Response(data) @@ -170,7 +170,7 @@ class OAuth2ProviderViewSet(UsedByMixin, ModelViewSet): if not for_user: raise ValidationError({"for_user": "User not found"}) except ValueError: - raise ValidationError({"for_user": "input must be numerical"}) + raise ValidationError({"for_user": "input must be numerical"}) from None scope_names = ScopeMapping.objects.filter(provider=provider).values_list( "scope_name", flat=True diff --git a/authentik/providers/oauth2/errors.py b/authentik/providers/oauth2/errors.py index 35dd85912b..e8c5fd9ed8 100644 --- a/authentik/providers/oauth2/errors.py +++ b/authentik/providers/oauth2/errors.py @@ -1,6 +1,5 @@ """OAuth errors""" -from typing import Optional from urllib.parse import quote, urlparse from django.http import HttpRequest, HttpResponse, HttpResponseRedirect @@ -27,7 +26,7 @@ class OAuth2Error(SentryIgnoredException): def __repr__(self) -> str: return self.error - def to_event(self, message: Optional[str] = None, **kwargs) -> Event: + def to_event(self, message: str | None = None, **kwargs) -> Event: """Create configuration_error Event.""" return Event.new( EventAction.CONFIGURATION_ERROR, @@ -142,14 +141,13 @@ class AuthorizeError(OAuth2Error): ), } - # pylint: disable=too-many-arguments def __init__( self, redirect_uri: str, error: str, grant_type: str, state: str, - description: Optional[str] = None, + description: str | None = None, ): super().__init__() self.error = error diff --git a/authentik/providers/oauth2/id_token.py b/authentik/providers/oauth2/id_token.py index 8970b4612b..d21979a636 100644 --- a/authentik/providers/oauth2/id_token.py +++ b/authentik/providers/oauth2/id_token.py @@ -1,7 +1,7 @@ """id_token utils""" from dataclasses import asdict, dataclass, field -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any from django.db import models from django.http import HttpRequest @@ -43,7 +43,6 @@ class SubModes(models.TextChoices): @dataclass(slots=True) -# pylint: disable=too-many-instance-attributes class IDToken: """The primary extension that OpenID Connect makes to OAuth 2.0 to enable End-Users to be Authenticated is the ID Token data structure. The ID Token is a security token that contains @@ -54,36 +53,35 @@ class IDToken: https://openid.net/specs/openid-connect-core-1_0.html#IDToken""" # Issuer, https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.1 - iss: Optional[str] = None + iss: str | None = None # Subject, https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.2 - sub: Optional[str] = None + sub: str | None = None # Audience, https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.3 - aud: Optional[Union[str, list[str]]] = None + aud: str | list[str] | None = None # Expiration time, https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.4 - exp: Optional[int] = None + exp: int | None = None # Issued at, https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.6 - iat: Optional[int] = None + iat: int | None = None # Time when the authentication occurred, # https://openid.net/specs/openid-connect-core-1_0.html#IDToken - auth_time: Optional[int] = None + auth_time: int | None = None # Authentication Context Class Reference, # https://openid.net/specs/openid-connect-core-1_0.html#IDToken - acr: Optional[str] = ACR_AUTHENTIK_DEFAULT + acr: str | None = ACR_AUTHENTIK_DEFAULT # Authentication Methods References, # https://openid.net/specs/openid-connect-core-1_0.html#IDToken - amr: Optional[list[str]] = None + amr: list[str] | None = None # Code hash value, http://openid.net/specs/openid-connect-core-1_0.html - c_hash: Optional[str] = None + c_hash: str | None = None # Value used to associate a Client session with an ID Token, # http://openid.net/specs/openid-connect-core-1_0.html - nonce: Optional[str] = None + nonce: str | None = None # Access Token hash value, http://openid.net/specs/openid-connect-core-1_0.html - at_hash: Optional[str] = None + at_hash: str | None = None claims: dict[str, Any] = field(default_factory=dict) @staticmethod - # pylint: disable=too-many-locals def new( provider: "OAuth2Provider", token: "BaseGrantModel", request: HttpRequest, **kwargs ) -> "IDToken": diff --git a/authentik/providers/oauth2/models.py b/authentik/providers/oauth2/models.py index f5225b1534..3e7f00bcf7 100644 --- a/authentik/providers/oauth2/models.py +++ b/authentik/providers/oauth2/models.py @@ -6,7 +6,7 @@ import json from dataclasses import asdict from functools import cached_property from hashlib import sha256 -from typing import Any, Optional +from typing import Any from urllib.parse import urlparse, urlunparse from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey @@ -233,7 +233,7 @@ class OAuth2Provider(Provider): return private_key, JWTAlgorithms.ES256 raise ValueError(f"Invalid private key type: {type(private_key)}") - def get_issuer(self, request: HttpRequest) -> Optional[str]: + def get_issuer(self, request: HttpRequest) -> str | None: """Get issuer, based on request""" if self.issuer_mode == IssuerMode.GLOBAL: return request.build_absolute_uri(reverse("authentik_core:root-redirect")) @@ -241,17 +241,16 @@ class OAuth2Provider(Provider): url = reverse( "authentik_providers_oauth2:provider-root", kwargs={ - # pylint: disable=no-member "application_slug": self.application.slug, }, ) return request.build_absolute_uri(url) - # pylint: disable=no-member + except Provider.application.RelatedObjectDoesNotExist: return None @property - def launch_url(self) -> Optional[str]: + def launch_url(self) -> str | None: """Guess launch_url based on first redirect_uri""" if self.redirect_uris == "": return None @@ -299,6 +298,9 @@ class BaseGrantModel(models.Model): auth_time = models.DateTimeField(verbose_name="Authentication time") session_id = models.CharField(default="", blank=True) + class Meta: + abstract = True + @property def scope(self) -> list[str]: """Return scopes as list of strings""" @@ -308,9 +310,6 @@ class BaseGrantModel(models.Model): def scope(self, value): self._scope = " ".join(value) - class Meta: - abstract = True - class AuthorizationCode(SerializerModel, ExpiringModel, BaseGrantModel): """OAuth2 Authorization Code""" @@ -322,6 +321,13 @@ class AuthorizationCode(SerializerModel, ExpiringModel, BaseGrantModel): max_length=255, null=True, verbose_name=_("Code Challenge Method") ) + class Meta: + verbose_name = _("Authorization Code") + verbose_name_plural = _("Authorization Codes") + + def __str__(self): + return f"Authorization code for {self.provider} for user {self.user}" + @property def serializer(self) -> Serializer: from authentik.providers.oauth2.api.tokens import ExpiringBaseGrantModelSerializer @@ -338,13 +344,6 @@ class AuthorizationCode(SerializerModel, ExpiringModel, BaseGrantModel): .decode("ascii") ) - class Meta: - verbose_name = _("Authorization Code") - verbose_name_plural = _("Authorization Codes") - - def __str__(self): - return f"Authorization code for {self.provider} for user {self.user}" - class AccessToken(SerializerModel, ExpiringModel, BaseGrantModel): """OAuth2 access token, non-opaque using a JWT as identifier""" @@ -352,6 +351,13 @@ class AccessToken(SerializerModel, ExpiringModel, BaseGrantModel): token = models.TextField() _id_token = models.TextField() + class Meta: + verbose_name = _("OAuth2 Access Token") + verbose_name_plural = _("OAuth2 Access Tokens") + + def __str__(self): + return f"Access Token for {self.provider} for user {self.user}" + @property def id_token(self) -> IDToken: """Load ID Token from json""" @@ -381,13 +387,6 @@ class AccessToken(SerializerModel, ExpiringModel, BaseGrantModel): return TokenModelSerializer - class Meta: - verbose_name = _("OAuth2 Access Token") - verbose_name_plural = _("OAuth2 Access Tokens") - - def __str__(self): - return f"Access Token for {self.provider} for user {self.user}" - class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel): """OAuth2 Refresh Token, opaque""" @@ -395,6 +394,13 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel): token = models.TextField(default=generate_client_secret) _id_token = models.TextField(verbose_name=_("ID Token")) + class Meta: + verbose_name = _("OAuth2 Refresh Token") + verbose_name_plural = _("OAuth2 Refresh Tokens") + + def __str__(self): + return f"Refresh Token for {self.provider} for user {self.user}" + @property def id_token(self) -> IDToken: """Load ID Token from json""" @@ -411,13 +417,6 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel): return TokenModelSerializer - class Meta: - verbose_name = _("OAuth2 Refresh Token") - verbose_name_plural = _("OAuth2 Refresh Tokens") - - def __str__(self): - return f"Refresh Token for {self.provider} for user {self.user}" - class DeviceToken(ExpiringModel): """Temporary device token for OAuth device flow""" diff --git a/authentik/providers/oauth2/utils.py b/authentik/providers/oauth2/utils.py index 2dda028e73..5f36a3f891 100644 --- a/authentik/providers/oauth2/utils.py +++ b/authentik/providers/oauth2/utils.py @@ -3,7 +3,7 @@ import re from base64 import b64decode from binascii import Error -from typing import Any, Optional +from typing import Any from urllib.parse import urlparse from django.http import HttpRequest, HttpResponse, JsonResponse @@ -73,7 +73,7 @@ def cors_allow(request: HttpRequest, response: HttpResponse, *allowed_origins: s return response -def extract_access_token(request: HttpRequest) -> Optional[str]: +def extract_access_token(request: HttpRequest) -> str | None: """ Get the access token using Authorization Request Header Field method. Or try getting via GET. @@ -169,7 +169,7 @@ def protected_resource_view(scopes: list[str]): kwargs["token"] = token CTX_AUTH_VIA.set("oauth_token") response = view(request, *args, **kwargs) - setattr(response, "ak_context", {}) + response.ak_context = {} response.ak_context[KEY_USER] = token.user.username return response @@ -178,12 +178,12 @@ def protected_resource_view(scopes: list[str]): return wrapper -def authenticate_provider(request: HttpRequest) -> Optional[OAuth2Provider]: +def authenticate_provider(request: HttpRequest) -> OAuth2Provider | None: """Attempt to authenticate via Basic auth of client_id:client_secret""" client_id, client_secret = extract_client_auth(request) if client_id == client_secret == "": return None - provider: Optional[OAuth2Provider] = OAuth2Provider.objects.filter(client_id=client_id).first() + provider: OAuth2Provider | None = OAuth2Provider.objects.filter(client_id=client_id).first() if not provider: return None if client_id != provider.client_id or client_secret != provider.client_secret: @@ -200,7 +200,7 @@ class HttpResponseRedirectScheme(HttpResponseRedirect): self, redirect_to: str, *args: Any, - allowed_schemes: Optional[list[str]] = None, + allowed_schemes: list[str] | None = None, **kwargs: Any, ) -> None: self.allowed_schemes = allowed_schemes or ["http", "https", "ftp"] diff --git a/authentik/providers/oauth2/views/authorize.py b/authentik/providers/oauth2/views/authorize.py index bc99b04390..e7989e4191 100644 --- a/authentik/providers/oauth2/views/authorize.py +++ b/authentik/providers/oauth2/views/authorize.py @@ -6,7 +6,6 @@ from hashlib import sha256 from json import dumps from re import error as RegexError from re import fullmatch -from typing import Optional from urllib.parse import parse_qs, urlencode, urlparse, urlsplit, urlunsplit from uuid import uuid4 @@ -81,28 +80,27 @@ FORBIDDEN_URI_SCHEMES = {"javascript", "data", "vbscript"} @dataclass(slots=True) -# pylint: disable=too-many-instance-attributes class OAuthAuthorizationParams: """Parameters required to authorize an OAuth Client""" client_id: str redirect_uri: str response_type: str - response_mode: Optional[str] + response_mode: str | None scope: set[str] state: str - nonce: Optional[str] + nonce: str | None prompt: set[str] grant_type: str provider: OAuth2Provider = field(default_factory=OAuth2Provider) - request: Optional[str] = None + request: str | None = None - max_age: Optional[int] = None + max_age: int | None = None - code_challenge: Optional[str] = None - code_challenge_method: Optional[str] = None + code_challenge: str | None = None + code_challenge_method: str | None = None github_compat: InitVar[bool] = False @@ -221,7 +219,7 @@ class OAuthAuthorizationParams: redirect_uri_given=self.redirect_uri, redirect_uri_expected=allowed_redirect_urls, ) - raise RedirectUriError(self.redirect_uri, allowed_redirect_urls) + raise RedirectUriError(self.redirect_uri, allowed_redirect_urls) from None # Check against forbidden schemes if urlparse(self.redirect_uri).scheme in FORBIDDEN_URI_SCHEMES: raise RedirectUriError(self.redirect_uri, allowed_redirect_urls) @@ -348,14 +346,14 @@ class AuthorizationFlowInitView(PolicyAccessView): ) except AuthorizeError as error: LOGGER.warning(error.description, redirect_uri=error.redirect_uri) - raise RequestValidationError(error.get_response(self.request)) + raise RequestValidationError(error.get_response(self.request)) from None except OAuth2Error as error: LOGGER.warning(error.description) raise RequestValidationError( bad_request_message(self.request, error.description, title=error.error) - ) + ) from None except OAuth2Provider.DoesNotExist: - raise Http404 + raise Http404 from None if PROMPT_NONE in self.params.prompt and not self.request.user.is_authenticated: # When "prompt" is set to "none" but the user is not logged in, show an error message error = AuthorizeError( @@ -487,7 +485,7 @@ class OAuthFulfillmentStage(StageView): "component": "ak-stage-autosubmit", "title": self.executor.plan.context.get( PLAN_CONTEXT_TITLE, - _("Redirecting to %(app)s..." % {"app": self.application.name}), + _("Redirecting to {app}...".format_map({"app": self.application.name})), ), "url": self.params.redirect_uri, "attrs": query_params, @@ -533,7 +531,7 @@ class OAuthFulfillmentStage(StageView): except (ClientIdError, RedirectUriError) as error: error.to_event(application=self.application).from_http(request) self.executor.stage_invalid() - # pylint: disable=no-member + return bad_request_message(request, error.description, title=error.error) except AuthorizeError as error: error.to_event(application=self.application).from_http(request) @@ -596,9 +594,9 @@ class OAuthFulfillmentStage(StageView): "server_error", self.params.grant_type, self.params.state, - ) + ) from None - def create_implicit_response(self, code: Optional[AuthorizationCode]) -> dict: + def create_implicit_response(self, code: AuthorizationCode | None) -> dict: """Create implicit response's URL Fragment dictionary""" query_fragment = {} auth_event = get_login_event(self.request) diff --git a/authentik/providers/oauth2/views/device_backchannel.py b/authentik/providers/oauth2/views/device_backchannel.py index 32b674a719..453ec17fd6 100644 --- a/authentik/providers/oauth2/views/device_backchannel.py +++ b/authentik/providers/oauth2/views/device_backchannel.py @@ -1,6 +1,5 @@ """Device flow views""" -from typing import Optional from urllib.parse import urlencode from django.http import HttpRequest, HttpResponse, HttpResponseBadRequest, JsonResponse @@ -28,7 +27,7 @@ class DeviceView(View): provider: OAuth2Provider scopes: list[str] = [] - def parse_request(self) -> Optional[HttpResponse]: + def parse_request(self) -> HttpResponse | None: """Parse incoming request""" client_id = self.request.POST.get("client_id", None) if not client_id: diff --git a/authentik/providers/oauth2/views/device_init.py b/authentik/providers/oauth2/views/device_init.py index c038e8577b..a758cb1c0e 100644 --- a/authentik/providers/oauth2/views/device_init.py +++ b/authentik/providers/oauth2/views/device_init.py @@ -1,7 +1,5 @@ """Device flow views""" -from typing import Optional - from django.http import HttpRequest, HttpResponse from django.utils.translation import gettext as _ from django.views import View @@ -33,7 +31,7 @@ LOGGER = get_logger() QS_KEY_CODE = "code" # nosec -def get_application(provider: OAuth2Provider) -> Optional[Application]: +def get_application(provider: OAuth2Provider) -> Application | None: """Get application from provider""" try: app = provider.application @@ -44,7 +42,7 @@ def get_application(provider: OAuth2Provider) -> Optional[Application]: return None -def validate_code(code: int, request: HttpRequest) -> Optional[HttpResponse]: +def validate_code(code: int, request: HttpRequest) -> HttpResponse | None: """Validate user token""" token = DeviceToken.objects.filter( user_code=code, diff --git a/authentik/providers/oauth2/views/jwks.py b/authentik/providers/oauth2/views/jwks.py index 88e4328f8a..ea95c4c28c 100644 --- a/authentik/providers/oauth2/views/jwks.py +++ b/authentik/providers/oauth2/views/jwks.py @@ -1,7 +1,6 @@ """authentik OAuth2 JWKS Views""" from base64 import b64encode, urlsafe_b64encode -from typing import Optional from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric.ec import ( @@ -65,7 +64,7 @@ def to_base64url_uint(val: int, min_length: int = 0) -> bytes: class JWKSView(View): """Show RSA Key data for Provider""" - def get_jwk_for_key(self, key: CertificateKeyPair) -> Optional[dict]: + def get_jwk_for_key(self, key: CertificateKeyPair) -> dict | None: """Convert a certificate-key pair into JWK""" private_key = key.private_key key_data = None diff --git a/authentik/providers/oauth2/views/token.py b/authentik/providers/oauth2/views/token.py index 6856e4ab10..464df72636 100644 --- a/authentik/providers/oauth2/views/token.py +++ b/authentik/providers/oauth2/views/token.py @@ -7,7 +7,7 @@ from datetime import datetime from hashlib import sha256 from re import error as RegexError from re import fullmatch -from typing import Any, Optional +from typing import Any from urllib.parse import urlparse from django.http import HttpRequest, HttpResponse @@ -68,7 +68,6 @@ LOGGER = get_logger() @dataclass(slots=True) -# pylint: disable=too-many-instance-attributes class TokenParams: """Token params""" @@ -81,16 +80,16 @@ class TokenParams: provider: OAuth2Provider - authorization_code: Optional[AuthorizationCode] = None - refresh_token: Optional[RefreshToken] = None - device_code: Optional[DeviceToken] = None - user: Optional[User] = None + authorization_code: AuthorizationCode | None = None + refresh_token: RefreshToken | None = None + device_code: DeviceToken | None = None + user: User | None = None - code_verifier: Optional[str] = None + code_verifier: str | None = None raw_code: InitVar[str] = "" raw_token: InitVar[str] = "" - request: InitVar[Optional[HttpRequest]] = None + request: InitVar[HttpRequest | None] = None @staticmethod def parse( @@ -210,7 +209,7 @@ class TokenParams: message="Invalid redirect_uri configured", provider=self.provider, ).from_http(request) - raise TokenError("invalid_client") + raise TokenError("invalid_client") from None # Check against forbidden schemes if urlparse(self.redirect_uri).scheme in FORBIDDEN_URI_SCHEMES: @@ -306,7 +305,7 @@ class TokenParams: user, _, password = b64decode(self.client_secret).decode("utf-8").partition(":") return self.__post_init_client_credentials_creds(request, user, password) except (ValueError, Error): - raise TokenError("invalid_grant") + raise TokenError("invalid_grant") from None def __post_init_client_credentials_creds( self, request: HttpRequest, username: str, password: str @@ -338,7 +337,6 @@ class TokenParams: }, ).from_http(request, user=user) - # pylint: disable=too-many-locals def __post_init_client_credentials_jwt(self, request: HttpRequest): assertion_type = request.POST.get(CLIENT_ASSERTION_TYPE, "") if assertion_type != CLIENT_ASSERTION_TYPE_JWT: @@ -353,8 +351,8 @@ class TokenParams: token = None - source: Optional[OAuthSource] = None - parsed_key: Optional[PyJWK] = None + source: OAuthSource | None = None + parsed_key: PyJWK | None = None # Fully decode the JWT without verifying the signature, so we can get access to # the header. @@ -368,7 +366,7 @@ class TokenParams: ) except (PyJWTError, ValueError, TypeError, AttributeError) as exc: LOGGER.warning("failed to parse JWT for kid lookup", exc=exc) - raise TokenError("invalid_grant") + raise TokenError("invalid_grant") from None expected_kid = decode_unvalidated["header"]["kid"] for source in self.provider.jwks_sources.filter( oidc_jwks__keys__contains=[{"kid": expected_kid}] @@ -489,8 +487,8 @@ class TokenParams: class TokenView(View): """Generate tokens for clients""" - provider: Optional[OAuth2Provider] = None - params: Optional[TokenParams] = None + provider: OAuth2Provider | None = None + params: TokenParams | None = None def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: response = super().dispatch(request, *args, **kwargs) diff --git a/authentik/providers/oauth2/views/userinfo.py b/authentik/providers/oauth2/views/userinfo.py index 19a3fb5bc9..fa4a2fe7a2 100644 --- a/authentik/providers/oauth2/views/userinfo.py +++ b/authentik/providers/oauth2/views/userinfo.py @@ -1,6 +1,6 @@ """authentik OAuth2 OpenID Userinfo views""" -from typing import Any, Optional +from typing import Any from deepmerge import always_merger from django.http import HttpRequest, HttpResponse @@ -39,7 +39,7 @@ class UserInfoView(View): """Create a dictionary with all the requested claims about the End-User. See: http://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse""" - token: Optional[RefreshToken] + token: RefreshToken | None def get_scope_descriptions( self, scopes: list[str], provider: OAuth2Provider diff --git a/authentik/providers/proxy/api.py b/authentik/providers/proxy/api.py index 3da9a98450..70c81d6d4c 100644 --- a/authentik/providers/proxy/api.py +++ b/authentik/providers/proxy/api.py @@ -1,6 +1,6 @@ """ProxyProvider API Views""" -from typing import Any, Optional +from typing import Any from django.utils.translation import gettext_lazy as _ from drf_spectacular.utils import extend_schema_field @@ -143,7 +143,7 @@ class ProxyOutpostConfigSerializer(ModelSerializer): """Embed OpenID Connect provider information""" return ProviderInfoView(request=self.context["request"]._request).get_info(obj) - def get_access_token_validity(self, obj: ProxyProvider) -> Optional[float]: + def get_access_token_validity(self, obj: ProxyProvider) -> float | None: """Get token validity as second count""" return timedelta_from_string(obj.access_token_validity).total_seconds() diff --git a/authentik/providers/proxy/controllers/k8s/traefik_3.py b/authentik/providers/proxy/controllers/k8s/traefik_3.py index 2e101ce2df..c807ecb630 100644 --- a/authentik/providers/proxy/controllers/k8s/traefik_3.py +++ b/authentik/providers/proxy/controllers/k8s/traefik_3.py @@ -20,11 +20,11 @@ class TraefikMiddlewareSpecForwardAuth: """traefik middleware forwardAuth spec""" address: str - # pylint: disable=invalid-name + authResponseHeadersRegex: str = field(default="") - # pylint: disable=invalid-name + authResponseHeaders: list[str] = field(default_factory=list) - # pylint: disable=invalid-name + trustForwardHeader: bool = field(default=True) @@ -32,7 +32,6 @@ class TraefikMiddlewareSpecForwardAuth: class TraefikMiddlewareSpec: """Traefik middleware spec""" - # pylint: disable=invalid-name forwardAuth: TraefikMiddlewareSpecForwardAuth @@ -49,7 +48,6 @@ class TraefikMiddlewareMetadata: class TraefikMiddleware: """Traefik Middleware""" - # pylint: disable=invalid-name apiVersion: str kind: str metadata: TraefikMiddlewareMetadata diff --git a/authentik/providers/proxy/models.py b/authentik/providers/proxy/models.py index 9ec1e4d71e..4445be86c6 100644 --- a/authentik/providers/proxy/models.py +++ b/authentik/providers/proxy/models.py @@ -1,8 +1,8 @@ """authentik proxy models""" import string +from collections.abc import Iterable from random import SystemRandom -from typing import Iterable, Optional from urllib.parse import urljoin from django.db import models @@ -122,7 +122,7 @@ class ProxyProvider(OutpostModel, OAuth2Provider): return ProxyProviderSerializer @property - def launch_url(self) -> Optional[str]: + def launch_url(self) -> str | None: """Use external_host as launch URL""" return self.external_host diff --git a/authentik/providers/radius/models.py b/authentik/providers/radius/models.py index 3fb32b576a..83acd4f564 100644 --- a/authentik/providers/radius/models.py +++ b/authentik/providers/radius/models.py @@ -1,7 +1,5 @@ """Radius Provider""" -from typing import Optional, Type - from django.db import models from django.utils.translation import gettext_lazy as _ from rest_framework.serializers import Serializer @@ -40,7 +38,7 @@ class RadiusProvider(OutpostModel, Provider): ) @property - def launch_url(self) -> Optional[str]: + def launch_url(self) -> str | None: """Radius never has a launch URL""" return None @@ -49,7 +47,7 @@ class RadiusProvider(OutpostModel, Provider): return "ak-provider-radius-form" @property - def serializer(self) -> Type[Serializer]: + def serializer(self) -> type[Serializer]: from authentik.providers.radius.api import RadiusProviderSerializer return RadiusProviderSerializer diff --git a/authentik/providers/saml/api/providers.py b/authentik/providers/saml/api/providers.py index eb3d5cbb43..ecaf673e66 100644 --- a/authentik/providers/saml/api/providers.py +++ b/authentik/providers/saml/api/providers.py @@ -70,7 +70,7 @@ class SAMLProviderSerializer(ProviderSerializer): kwargs={"application_slug": instance.application.slug}, ) ) - except Provider.application.RelatedObjectDoesNotExist: # pylint: disable=no-member + except Provider.application.RelatedObjectDoesNotExist: return "-" def get_url_sso_redirect(self, instance: SAMLProvider) -> str: @@ -85,7 +85,7 @@ class SAMLProviderSerializer(ProviderSerializer): kwargs={"application_slug": instance.application.slug}, ) ) - except Provider.application.RelatedObjectDoesNotExist: # pylint: disable=no-member + except Provider.application.RelatedObjectDoesNotExist: return "-" def get_url_sso_init(self, instance: SAMLProvider) -> str: @@ -100,7 +100,7 @@ class SAMLProviderSerializer(ProviderSerializer): kwargs={"application_slug": instance.application.slug}, ) ) - except Provider.application.RelatedObjectDoesNotExist: # pylint: disable=no-member + except Provider.application.RelatedObjectDoesNotExist: return "-" def get_url_slo_post(self, instance: SAMLProvider) -> str: @@ -115,7 +115,7 @@ class SAMLProviderSerializer(ProviderSerializer): kwargs={"application_slug": instance.application.slug}, ) ) - except Provider.application.RelatedObjectDoesNotExist: # pylint: disable=no-member + except Provider.application.RelatedObjectDoesNotExist: return "-" def get_url_slo_redirect(self, instance: SAMLProvider) -> str: @@ -130,7 +130,7 @@ class SAMLProviderSerializer(ProviderSerializer): kwargs={"application_slug": instance.application.slug}, ) ) - except Provider.application.RelatedObjectDoesNotExist: # pylint: disable=no-member + except Provider.application.RelatedObjectDoesNotExist: return "-" class Meta: @@ -216,7 +216,7 @@ class SAMLProviderViewSet(UsedByMixin, ModelViewSet): try: provider = get_object_or_404(SAMLProvider, pk=pk) except ValueError: - raise Http404 + raise Http404 from None try: proc = MetadataProcessor(provider, request) proc.force_binding = request.query_params.get("force_binding", None) @@ -228,7 +228,7 @@ class SAMLProviderViewSet(UsedByMixin, ModelViewSet): ) return response return Response({"metadata": metadata}) - except Provider.application.RelatedObjectDoesNotExist: # pylint: disable=no-member + except Provider.application.RelatedObjectDoesNotExist: return Response({"metadata": ""}) @permission_required( @@ -258,7 +258,7 @@ class SAMLProviderViewSet(UsedByMixin, ModelViewSet): try: fromstring(file.read()) except ParseError: - raise ValidationError(_("Invalid XML Syntax")) + raise ValidationError(_("Invalid XML Syntax")) from None file.seek(0) try: metadata = ServiceProviderMetadataParser().parse(file.read().decode()) @@ -268,8 +268,8 @@ class SAMLProviderViewSet(UsedByMixin, ModelViewSet): except ValueError as exc: # pragma: no cover LOGGER.warning(str(exc)) raise ValidationError( - _("Failed to import Metadata: %(message)s" % {"message": str(exc)}), - ) + _("Failed to import Metadata: {messages}".format_map({"message": str(exc)})), + ) from None return Response(status=204) @permission_required( @@ -303,7 +303,7 @@ class SAMLProviderViewSet(UsedByMixin, ModelViewSet): if not for_user: raise ValidationError({"for_user": "User not found"}) except ValueError: - raise ValidationError({"for_user": "input must be numerical"}) + raise ValidationError({"for_user": "input must be numerical"}) from None new_request = copy(request._request) new_request.user = for_user diff --git a/authentik/providers/saml/models.py b/authentik/providers/saml/models.py index e7955cae8b..9c8afa591b 100644 --- a/authentik/providers/saml/models.py +++ b/authentik/providers/saml/models.py @@ -1,7 +1,5 @@ """authentik saml_idp Models""" -from typing import Optional - from django.db import models from django.urls import reverse from django.utils.translation import gettext_lazy as _ @@ -144,10 +142,10 @@ class SAMLProvider(Provider): ) @property - def launch_url(self) -> Optional[str]: + def launch_url(self) -> str | None: """Use IDP-Initiated SAML flow as launch URL""" try: - # pylint: disable=no-member + return reverse( "authentik_providers_saml:sso-init", kwargs={"application_slug": self.application.slug}, diff --git a/authentik/providers/saml/processors/assertion.py b/authentik/providers/saml/processors/assertion.py index fff0b2a551..d305d672ad 100644 --- a/authentik/providers/saml/processors/assertion.py +++ b/authentik/providers/saml/processors/assertion.py @@ -92,16 +92,15 @@ class AssertionProcessor: attribute.attrib["FriendlyName"] = mapping.friendly_name attribute.attrib["Name"] = mapping.saml_name - if not isinstance(value, (list, GeneratorType)): + if not isinstance(value, list | GeneratorType): value = [value] for value_item in value: attribute_value = SubElement( attribute, f"{{{NS_SAML_ASSERTION}}}AttributeValue" ) - if not isinstance(value_item, str): - value_item = str(value_item) - attribute_value.text = value_item + str_value = str(value_item) if not isinstance(value_item, str) else value_item + attribute_value.text = str_value attribute_statement.append(attribute) @@ -166,7 +165,6 @@ class AssertionProcessor: audience.text = self.provider.audience return conditions - # pylint: disable=too-many-return-statements def get_name_id(self) -> Element: """Get NameID Element""" name_id = Element(f"{{{NS_SAML_ASSERTION}}}NameID") diff --git a/authentik/providers/saml/processors/authn_request_parser.py b/authentik/providers/saml/processors/authn_request_parser.py index 2e86d8ba49..16585dd12b 100644 --- a/authentik/providers/saml/processors/authn_request_parser.py +++ b/authentik/providers/saml/processors/authn_request_parser.py @@ -2,7 +2,6 @@ from base64 import b64decode from dataclasses import dataclass -from typing import Optional from urllib.parse import quote_plus from xml.etree.ElementTree import ParseError # nosec @@ -36,9 +35,9 @@ ERROR_FAILED_TO_VERIFY = "Failed to verify signature" class AuthNRequest: """AuthNRequest Dataclass""" - id: Optional[str] = None + id: str | None = None - relay_state: Optional[str] = None + relay_state: str | None = None name_id_policy: str = SAML_NAME_ID_FORMAT_UNSPECIFIED @@ -52,7 +51,7 @@ class AuthNRequestParser: self.provider = provider self.logger = get_logger().bind(provider=self.provider) - def _parse_xml(self, decoded_xml: str | bytes, relay_state: Optional[str]) -> AuthNRequest: + def _parse_xml(self, decoded_xml: str | bytes, relay_state: str | None) -> AuthNRequest: root = ElementTree.fromstring(decoded_xml) # http://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf @@ -83,12 +82,12 @@ class AuthNRequestParser: return auth_n_request - def parse(self, saml_request: str, relay_state: Optional[str] = None) -> AuthNRequest: + def parse(self, saml_request: str, relay_state: str | None = None) -> AuthNRequest: """Validate and parse raw request with enveloped signautre.""" try: decoded_xml = b64decode(saml_request.encode()) except UnicodeDecodeError: - raise CannotHandleAssertion(ERROR_CANNOT_DECODE_REQUEST) + raise CannotHandleAssertion(ERROR_CANNOT_DECODE_REQUEST) from None verifier = self.provider.verification_kp if not verifier: @@ -121,15 +120,15 @@ class AuthNRequestParser: def parse_detached( self, saml_request: str, - relay_state: Optional[str], - signature: Optional[str] = None, - sig_alg: Optional[str] = None, + relay_state: str | None, + signature: str | None = None, + sig_alg: str | None = None, ) -> AuthNRequest: """Validate and parse raw request with detached signature""" try: decoded_xml = decode_base64_and_inflate(saml_request) except UnicodeDecodeError: - raise CannotHandleAssertion(ERROR_CANNOT_DECODE_REQUEST) + raise CannotHandleAssertion(ERROR_CANNOT_DECODE_REQUEST) from None verifier = self.provider.verification_kp if not verifier: diff --git a/authentik/providers/saml/processors/logout_request_parser.py b/authentik/providers/saml/processors/logout_request_parser.py index 6bb6ec6288..f34d8031d6 100644 --- a/authentik/providers/saml/processors/logout_request_parser.py +++ b/authentik/providers/saml/processors/logout_request_parser.py @@ -2,7 +2,6 @@ from base64 import b64decode from dataclasses import dataclass -from typing import Optional from defusedxml import ElementTree @@ -17,11 +16,11 @@ from authentik.sources.saml.processors.constants import NS_SAML_PROTOCOL class LogoutRequest: """Logout Request""" - id: Optional[str] = None + id: str | None = None - issuer: Optional[str] = None + issuer: str | None = None - relay_state: Optional[str] = None + relay_state: str | None = None class LogoutRequestParser: @@ -32,9 +31,7 @@ class LogoutRequestParser: def __init__(self, provider: SAMLProvider): self.provider = provider - def _parse_xml( - self, decoded_xml: str | bytes, relay_state: Optional[str] = None - ) -> LogoutRequest: + def _parse_xml(self, decoded_xml: str | bytes, relay_state: str | None = None) -> LogoutRequest: root = ElementTree.fromstring(decoded_xml) request = LogoutRequest( id=root.attrib["ID"], @@ -45,23 +42,23 @@ class LogoutRequestParser: request.relay_state = relay_state return request - def parse(self, saml_request: str, relay_state: Optional[str] = None) -> LogoutRequest: + def parse(self, saml_request: str, relay_state: str | None = None) -> LogoutRequest: """Validate and parse raw request with enveloped signautre.""" try: decoded_xml = b64decode(saml_request.encode()) except UnicodeDecodeError: - raise CannotHandleAssertion(ERROR_CANNOT_DECODE_REQUEST) + raise CannotHandleAssertion(ERROR_CANNOT_DECODE_REQUEST) from None return self._parse_xml(decoded_xml, relay_state) def parse_detached( self, saml_request: str, - relay_state: Optional[str] = None, + relay_state: str | None = None, ) -> LogoutRequest: """Validate and parse raw request with detached signature""" try: decoded_xml = decode_base64_and_inflate(saml_request) except UnicodeDecodeError: - raise CannotHandleAssertion(ERROR_CANNOT_DECODE_REQUEST) + raise CannotHandleAssertion(ERROR_CANNOT_DECODE_REQUEST) from None return self._parse_xml(decoded_xml, relay_state) diff --git a/authentik/providers/saml/processors/metadata.py b/authentik/providers/saml/processors/metadata.py index ec2641cde2..e6a2dd0139 100644 --- a/authentik/providers/saml/processors/metadata.py +++ b/authentik/providers/saml/processors/metadata.py @@ -1,7 +1,8 @@ """SAML Identity Provider Metadata Processor""" +from collections.abc import Iterator from hashlib import sha256 -from typing import Iterator, Optional +from typing import Optional import xmlsec # nosec from django.http import HttpRequest @@ -31,7 +32,7 @@ class MetadataProcessor: provider: SAMLProvider http_request: HttpRequest - force_binding: Optional[str] + force_binding: str | None def __init__(self, provider: SAMLProvider, request: HttpRequest): self.provider = provider @@ -39,7 +40,8 @@ class MetadataProcessor: self.force_binding = None self.xml_id = "_" + sha256(f"{provider.name}-{provider.pk}".encode("ascii")).hexdigest() - def get_signing_key_descriptor(self) -> Optional[Element]: + # Using type unions doesn't work with cython types (which is what lxml is) + def get_signing_key_descriptor(self) -> Optional[Element]: # noqa: UP007 """Get Signing KeyDescriptor, if enabled for the provider""" if not self.provider.signing_kp: return None diff --git a/authentik/providers/saml/processors/metadata_parser.py b/authentik/providers/saml/processors/metadata_parser.py index c9e8fb27d5..dd42b52af5 100644 --- a/authentik/providers/saml/processors/metadata_parser.py +++ b/authentik/providers/saml/processors/metadata_parser.py @@ -1,7 +1,6 @@ """SAML ServiceProvider Metadata Parser and dataclass""" from dataclasses import dataclass -from typing import Optional import xmlsec from cryptography.hazmat.backends import default_backend @@ -48,7 +47,7 @@ class ServiceProviderMetadata: auth_n_request_signed: bool assertion_signed: bool - signing_keypair: Optional[CertificateKeyPair] = None + signing_keypair: CertificateKeyPair | None = None def to_provider(self, name: str, authorization_flow: Flow) -> SAMLProvider: """Create a SAMLProvider instance from the details. `name` is required, @@ -76,7 +75,7 @@ class ServiceProviderMetadata: class ServiceProviderMetadataParser: """Service-Provider Metadata Parser""" - def get_signing_cert(self, root: etree.Element) -> Optional[CertificateKeyPair]: + def get_signing_cert(self, root: etree.Element) -> CertificateKeyPair | None: """Extract X509Certificate from metadata, when given.""" signing_certs = root.xpath( '//md:SPSSODescriptor/md:KeyDescriptor[@use="signing"]//ds:X509Certificate/text()', diff --git a/authentik/providers/saml/utils/time.py b/authentik/providers/saml/utils/time.py index 678fa299d2..dda87d5ca8 100644 --- a/authentik/providers/saml/utils/time.py +++ b/authentik/providers/saml/utils/time.py @@ -1,10 +1,9 @@ """Time utilities""" import datetime -from typing import Optional -def get_time_string(delta: Optional[datetime.timedelta] = None) -> str: +def get_time_string(delta: datetime.timedelta | None = None) -> str: """Get Data formatted in SAML format""" if delta is None: delta = datetime.timedelta() diff --git a/authentik/providers/saml/views/flows.py b/authentik/providers/saml/views/flows.py index af0c142e93..afe18f8b20 100644 --- a/authentik/providers/saml/views/flows.py +++ b/authentik/providers/saml/views/flows.py @@ -85,7 +85,7 @@ class SAMLFlowFinalView(ChallengeStageView): "component": "ak-stage-autosubmit", "title": self.executor.plan.context.get( PLAN_CONTEXT_TITLE, - _("Redirecting to %(app)s..." % {"app": application.name}), + _("Redirecting to {app}...".format_map({"app": application.name})), ), "url": provider.acs_url, "attrs": form_attrs, diff --git a/authentik/providers/saml/views/slo.py b/authentik/providers/saml/views/slo.py index ef7b7edfd5..7f38b1f31a 100644 --- a/authentik/providers/saml/views/slo.py +++ b/authentik/providers/saml/views/slo.py @@ -1,7 +1,5 @@ """SLO Views""" -from typing import Optional - from django.http import HttpRequest from django.http.response import HttpResponse from django.shortcuts import get_object_or_404, redirect @@ -36,7 +34,7 @@ class SAMLSLOView(PolicyAccessView): SAMLProvider, pk=self.application.provider_id ) - def check_saml_request(self) -> Optional[HttpRequest]: + def check_saml_request(self) -> HttpRequest | None: """Handler to verify the SAML Request. Must be implemented by a subclass""" raise NotImplementedError @@ -61,7 +59,7 @@ class SAMLSLOView(PolicyAccessView): class SAMLSLOBindingRedirectView(SAMLSLOView): """SAML Handler for SLO/Redirect bindings, which are sent via GET""" - def check_saml_request(self) -> Optional[HttpRequest]: + def check_saml_request(self) -> HttpRequest | None: if REQUEST_KEY_SAML_REQUEST not in self.request.GET: LOGGER.info("check_saml_request: SAML payload missing") return bad_request_message(self.request, "The SAML request payload is missing.") @@ -88,7 +86,7 @@ class SAMLSLOBindingRedirectView(SAMLSLOView): class SAMLSLOBindingPOSTView(SAMLSLOView): """SAML Handler for SLO/POST bindings""" - def check_saml_request(self) -> Optional[HttpRequest]: + def check_saml_request(self) -> HttpRequest | None: payload = self.request.POST if REQUEST_KEY_SAML_REQUEST not in payload: LOGGER.info("check_saml_request: SAML payload missing") diff --git a/authentik/providers/saml/views/sso.py b/authentik/providers/saml/views/sso.py index 0cf281bb62..fcc9f7bbf7 100644 --- a/authentik/providers/saml/views/sso.py +++ b/authentik/providers/saml/views/sso.py @@ -1,7 +1,5 @@ """authentik SAML IDP Views""" -from typing import Optional - from django.http import Http404, HttpRequest, HttpResponse from django.shortcuts import get_object_or_404 from django.utils.decorators import method_decorator @@ -48,7 +46,7 @@ class SAMLSSOView(PolicyAccessView): SAMLProvider, pk=self.application.provider_id ) - def check_saml_request(self) -> Optional[HttpRequest]: + def check_saml_request(self) -> HttpRequest | None: """Handler to verify the SAML Request. Must be implemented by a subclass""" raise NotImplementedError @@ -74,7 +72,7 @@ class SAMLSSOView(PolicyAccessView): }, ) except FlowNonApplicableException: - raise Http404 + raise Http404 from None plan.append_stage(in_memory_stage(SAMLFlowFinalView)) request.session[SESSION_KEY_PLAN] = plan return redirect_with_qs( @@ -92,7 +90,7 @@ class SAMLSSOView(PolicyAccessView): class SAMLSSOBindingRedirectView(SAMLSSOView): """SAML Handler for SSO/Redirect bindings, which are sent via GET""" - def check_saml_request(self) -> Optional[HttpRequest]: + def check_saml_request(self) -> HttpRequest | None: """Handle REDIRECT bindings""" if REQUEST_KEY_SAML_REQUEST not in self.request.GET: LOGGER.info("SAML payload missing") @@ -122,7 +120,7 @@ class SAMLSSOBindingRedirectView(SAMLSSOView): class SAMLSSOBindingPOSTView(SAMLSSOView): """SAML Handler for SSO/POST bindings""" - def check_saml_request(self) -> Optional[HttpRequest]: + def check_saml_request(self) -> HttpRequest | None: """Handle POST bindings""" payload = self.request.POST # Restore the post body from the session @@ -149,7 +147,7 @@ class SAMLSSOBindingPOSTView(SAMLSSOView): class SAMLSSOBindingInitView(SAMLSSOView): """SAML Handler for for IdP Initiated login flows""" - def check_saml_request(self) -> Optional[HttpRequest]: + def check_saml_request(self) -> HttpRequest | None: """Create SAML Response from scratch""" LOGGER.debug("No SAML Request, using IdP-initiated flow.") auth_n_request = AuthNRequestParser(self.provider).idp_initiated() diff --git a/authentik/providers/scim/clients/base.py b/authentik/providers/scim/clients/base.py index 4038a14942..1d27228c7f 100644 --- a/authentik/providers/scim/clients/base.py +++ b/authentik/providers/scim/clients/base.py @@ -2,6 +2,7 @@ from typing import Generic, TypeVar +from django.http import HttpResponseBadRequest, HttpResponseNotFound from pydantic import ValidationError from requests import RequestException, Session from structlog.stdlib import get_logger @@ -12,7 +13,7 @@ from authentik.providers.scim.clients.schema import ServiceProviderConfiguration from authentik.providers.scim.models import SCIMProvider T = TypeVar("T") -# pylint: disable=invalid-name + SchemaType = TypeVar("SchemaType") @@ -54,14 +55,14 @@ class SCIMClient(Generic[T, SchemaType]): except RequestException as exc: raise SCIMRequestException(message="Failed to send request") from exc self.logger.debug("scim request", path=path, method=method, **kwargs) - if response.status_code >= 400: - if response.status_code == 404: + if response.status_code >= HttpResponseBadRequest.status_code: + if response.status_code == HttpResponseNotFound.status_code: raise ResourceMissing(response) self.logger.warning( "Failed to send SCIM request", path=path, method=method, response=response.text ) raise SCIMRequestException(response) - if response.status_code == 204: + if response.status_code == 204: # noqa: PLR2004 return {} return response.json() diff --git a/authentik/providers/scim/clients/exceptions.py b/authentik/providers/scim/clients/exceptions.py index dc6d7a7b2b..76cd5a3fa4 100644 --- a/authentik/providers/scim/clients/exceptions.py +++ b/authentik/providers/scim/clients/exceptions.py @@ -1,7 +1,5 @@ """SCIM Client exceptions""" -from typing import Optional - from pydantic import ValidationError from requests import Response @@ -12,7 +10,7 @@ from authentik.providers.scim.clients.schema import SCIMError class StopSync(SentryIgnoredException): """Exception raised when a configuration error should stop the sync process""" - def __init__(self, exc: Exception, obj: object, mapping: Optional[object] = None) -> None: + def __init__(self, exc: Exception, obj: object, mapping: object | None = None) -> None: self.exc = exc self.obj = obj self.mapping = mapping @@ -29,10 +27,10 @@ class StopSync(SentryIgnoredException): class SCIMRequestException(SentryIgnoredException): """Exception raised when an SCIM request fails""" - _response: Optional[Response] - _message: Optional[str] + _response: Response | None + _message: str | None - def __init__(self, response: Optional[Response] = None, message: Optional[str] = None) -> None: + def __init__(self, response: Response | None = None, message: str | None = None) -> None: self._response = response self._message = message diff --git a/authentik/providers/scim/clients/schema.py b/authentik/providers/scim/clients/schema.py index 9e96710059..21f187f45e 100644 --- a/authentik/providers/scim/clients/schema.py +++ b/authentik/providers/scim/clients/schema.py @@ -1,15 +1,12 @@ """Custom SCIM schemas""" -from typing import Optional - from pydanticscim.group import Group as BaseGroup from pydanticscim.responses import PatchRequest as BasePatchRequest from pydanticscim.responses import SCIMError as BaseSCIMError -from pydanticscim.service_provider import Bulk, ChangePassword, Filter, Patch +from pydanticscim.service_provider import Bulk, ChangePassword, Filter, Patch, Sort from pydanticscim.service_provider import ( ServiceProviderConfiguration as BaseServiceProviderConfiguration, ) -from pydanticscim.service_provider import Sort from pydanticscim.user import User as BaseUser @@ -17,20 +14,20 @@ class User(BaseUser): """Modified User schema with added externalId field""" schemas: tuple[str] = ("urn:ietf:params:scim:schemas:core:2.0:User",) - externalId: Optional[str] = None + externalId: str | None = None class Group(BaseGroup): """Modified Group schema with added externalId field""" schemas: tuple[str] = ("urn:ietf:params:scim:schemas:core:2.0:Group",) - externalId: Optional[str] = None + externalId: str | None = None class ServiceProviderConfiguration(BaseServiceProviderConfiguration): """ServiceProviderConfig with fallback""" - _is_fallback: Optional[bool] = False + _is_fallback: bool | None = False @property def is_fallback(self) -> bool: @@ -61,4 +58,4 @@ class PatchRequest(BasePatchRequest): class SCIMError(BaseSCIMError): """SCIM error with optional status code""" - status: Optional[int] + status: int | None diff --git a/authentik/providers/scim/models.py b/authentik/providers/scim/models.py index 7f58288e2e..e078eead4c 100644 --- a/authentik/providers/scim/models.py +++ b/authentik/providers/scim/models.py @@ -104,6 +104,9 @@ class SCIMUser(models.Model): class Meta: unique_together = (("id", "user", "provider"),) + def __str__(self) -> str: + return f"SCIM User {self.user.username} to {self.provider.name}" + class SCIMGroup(models.Model): """Mapping of a group and provider to a SCIM user ID""" @@ -114,3 +117,6 @@ class SCIMGroup(models.Model): class Meta: unique_together = (("id", "group", "provider"),) + + def __str__(self) -> str: + return f"SCIM Group {self.group.name} to {self.provider.name}" diff --git a/authentik/providers/scim/tasks.py b/authentik/providers/scim/tasks.py index a98392a2ec..15d0caea5d 100644 --- a/authentik/providers/scim/tasks.py +++ b/authentik/providers/scim/tasks.py @@ -1,6 +1,6 @@ """SCIM Provider tasks""" -from typing import Any, Optional +from typing import Any from celery.result import allow_join_result from django.core.paginator import Paginator @@ -101,21 +101,23 @@ def scim_sync_users(page: int, provider_pk: int): LOGGER.warning("failed to sync user", exc=exc, user=user) messages.append( _( - "Failed to sync user %(user_name)s due to remote error: %(error)s" - % { - "user_name": user.username, - "error": exc.detail(), - } + "Failed to sync user {user_name} due to remote error: {error}".format_map( + { + "user_name": user.username, + "error": exc.detail(), + } + ) ) ) except StopSync as exc: LOGGER.warning("Stopping sync", exc=exc) messages.append( _( - "Stopping sync due to error: %(error)s" - % { - "error": exc.detail(), - } + "Stopping sync due to error: {error}".format_map( + { + "error": exc.detail(), + } + ) ) ) break @@ -142,21 +144,23 @@ def scim_sync_group(page: int, provider_pk: int): LOGGER.warning("failed to sync group", exc=exc, group=group) messages.append( _( - "Failed to sync group %(group_name)s due to remote error: %(error)s" - % { - "group_name": group.name, - "error": exc.detail(), - } + "Failed to sync group {group_name} due to remote error: {error}".format_map( + { + "group_name": group.name, + "error": exc.detail(), + } + ) ) ) except StopSync as exc: LOGGER.warning("Stopping sync", exc=exc) messages.append( _( - "Stopping sync due to error: %(error)s" - % { - "error": exc.detail(), - } + "Stopping sync due to error: {error}".format_map( + { + "error": exc.detail(), + } + ) ) ) break @@ -174,7 +178,7 @@ def scim_signal_direct(model: str, pk: Any, raw_op: str): for provider in SCIMProvider.objects.filter(backchannel_application__isnull=False): client = client_for_model(provider, instance) # Check if the object is allowed within the provider's restrictions - queryset: Optional[QuerySet] = None + queryset: QuerySet | None = None if isinstance(instance, User): queryset = provider.get_user_qs() if isinstance(instance, Group): diff --git a/authentik/providers/scim/tests/test_membership.py b/authentik/providers/scim/tests/test_membership.py index b8e31ec660..54d69b4561 100644 --- a/authentik/providers/scim/tests/test_membership.py +++ b/authentik/providers/scim/tests/test_membership.py @@ -49,7 +49,7 @@ class SCIMMembershipTests(TestCase): def test_member_add(self): """Test member add""" config = ServiceProviderConfiguration.default() - # pylint: disable=assigning-non-slot + config.patch.supported = True user_scim_id = generate_id() group_scim_id = generate_id() @@ -139,7 +139,7 @@ class SCIMMembershipTests(TestCase): def test_member_remove(self): """Test member remove""" config = ServiceProviderConfiguration.default() - # pylint: disable=assigning-non-slot + config.patch.supported = True user_scim_id = generate_id() group_scim_id = generate_id() diff --git a/authentik/rbac/api/rbac_roles.py b/authentik/rbac/api/rbac_roles.py index 37124b2393..60542dfc4f 100644 --- a/authentik/rbac/api/rbac_roles.py +++ b/authentik/rbac/api/rbac_roles.py @@ -1,7 +1,5 @@ """common RBAC serializers""" -from typing import Optional - from django.apps import apps from django_filters.filters import UUIDFilter from django_filters.filterset import FilterSet @@ -39,7 +37,7 @@ class ExtraRoleObjectPermissionSerializer(RoleObjectPermissionSerializer): except LookupError: return f"{instance.content_type.app_label}.{instance.content_type.model}" - def get_object_description(self, instance: GroupObjectPermission) -> Optional[str]: + def get_object_description(self, instance: GroupObjectPermission) -> str | None: """Get model description from attached model. This operation takes at least one additional query, and the description is only shown if the user/role has the view_ permission on the object""" diff --git a/authentik/rbac/api/rbac_users.py b/authentik/rbac/api/rbac_users.py index 0909c6e537..95a31de768 100644 --- a/authentik/rbac/api/rbac_users.py +++ b/authentik/rbac/api/rbac_users.py @@ -1,7 +1,5 @@ """common RBAC serializers""" -from typing import Optional - from django.apps import apps from django_filters.filters import NumberFilter from django_filters.filterset import FilterSet @@ -39,7 +37,7 @@ class ExtraUserObjectPermissionSerializer(UserObjectPermissionSerializer): except LookupError: return f"{instance.content_type.app_label}.{instance.content_type.model}" - def get_object_description(self, instance: UserObjectPermission) -> Optional[str]: + def get_object_description(self, instance: UserObjectPermission) -> str | None: """Get model description from attached model. This operation takes at least one additional query, and the description is only shown if the user/role has the view_ permission on the object""" diff --git a/authentik/rbac/decorators.py b/authentik/rbac/decorators.py index c3d17a8475..0438819309 100644 --- a/authentik/rbac/decorators.py +++ b/authentik/rbac/decorators.py @@ -1,7 +1,7 @@ """API Decorators""" +from collections.abc import Callable from functools import wraps -from typing import Callable, Optional from rest_framework.request import Request from rest_framework.response import Response @@ -11,7 +11,7 @@ from structlog.stdlib import get_logger LOGGER = get_logger() -def permission_required(obj_perm: Optional[str] = None, global_perms: Optional[list[str]] = None): +def permission_required(obj_perm: str | None = None, global_perms: list[str] | None = None): """Check permissions for a single custom action""" def _check_obj_perm(self: ModelViewSet, request: Request): diff --git a/authentik/rbac/models.py b/authentik/rbac/models.py index 4d6ae48f60..76c1dd0a91 100644 --- a/authentik/rbac/models.py +++ b/authentik/rbac/models.py @@ -1,6 +1,5 @@ """RBAC models""" -from typing import Optional from uuid import uuid4 from django.db import models @@ -31,7 +30,7 @@ class Role(SerializerModel): # name field has the same constraints as the group model name = models.TextField(max_length=150, unique=True) - def assign_permission(self, *perms: str, obj: Optional[models.Model] = None): + def assign_permission(self, *perms: str, obj: models.Model | None = None): """Assign permission to role, can handle multiple permissions, but when assigning multiple permissions to an object the permissions must all belong to the object given""" @@ -74,3 +73,6 @@ class SystemPermission(models.Model): ("view_system_settings", _("Can view system settings")), ("edit_system_settings", _("Can edit system settings")), ] + + def __str__(self) -> str: + return "System Permission" diff --git a/authentik/rbac/permissions.py b/authentik/rbac/permissions.py index 882be0e0f7..89ce1ca317 100644 --- a/authentik/rbac/permissions.py +++ b/authentik/rbac/permissions.py @@ -18,12 +18,10 @@ class ObjectPermissions(DjangoObjectPermissions): return super().has_object_permission(request, view, obj) -# pylint: disable=invalid-name def HasPermission(*perm: str) -> type[BasePermission]: """Permission checker for any non-object permissions, returns a BasePermission class that can be used with rest_framework""" - # pylint: disable=missing-class-docstring, invalid-name class checker(BasePermission): def has_permission(self, request: Request, view): return bool(request.user and request.user.has_perms(perm)) diff --git a/authentik/rbac/signals.py b/authentik/rbac/signals.py index b9650ac3e4..1f62f9725b 100644 --- a/authentik/rbac/signals.py +++ b/authentik/rbac/signals.py @@ -40,7 +40,6 @@ def rbac_group_role_m2m(sender: type[Group], action: str, instance: Group, rever LOGGER.debug("Updated users in group", group=instance) -# pylint: disable=no-member @receiver(m2m_changed, sender=Group.users.through) def rbac_group_users_m2m( sender: type[Group], action: str, instance: Group, pk_set: set, reverse: bool, **_ diff --git a/authentik/root/asgi.py b/authentik/root/asgi.py index c6cbab73ae..4c14429103 100644 --- a/authentik/root/asgi.py +++ b/authentik/root/asgi.py @@ -18,8 +18,8 @@ from sentry_sdk.integrations.asgi import SentryAsgiMiddleware defuse_stdlib() django.setup() -# pylint: disable=wrong-import-position -from authentik.root import websocket # noqa # isort:skip + +from authentik.root import websocket # noqa class LifespanApp: diff --git a/authentik/root/celery.py b/authentik/root/celery.py index 33c4b99a70..a7120d9901 100644 --- a/authentik/root/celery.py +++ b/authentik/root/celery.py @@ -1,11 +1,11 @@ """authentik core celery""" import os +from collections.abc import Callable from contextvars import ContextVar from logging.config import dictConfig from pathlib import Path from tempfile import gettempdir -from typing import Callable from celery import bootsteps from celery.apps.worker import Worker diff --git a/authentik/root/middleware.py b/authentik/root/middleware.py index 0c91a2d08a..88edae6143 100644 --- a/authentik/root/middleware.py +++ b/authentik/root/middleware.py @@ -1,15 +1,16 @@ """Dynamically set SameSite depending if the upstream connection is TLS or not""" +from collections.abc import Callable from hashlib import sha512 from time import perf_counter, time -from typing import Any, Callable, Optional +from typing import Any from django.conf import settings from django.contrib.sessions.backends.base import UpdateError from django.contrib.sessions.exceptions import SessionInterrupted from django.contrib.sessions.middleware import SessionMiddleware as UpstreamSessionMiddleware from django.http.request import HttpRequest -from django.http.response import HttpResponse +from django.http.response import HttpResponse, HttpResponseServerError from django.middleware.csrf import CSRF_SESSION_KEY from django.middleware.csrf import CsrfViewMiddleware as UpstreamCsrfViewMiddleware from django.utils.cache import patch_vary_headers @@ -99,7 +100,7 @@ class SessionMiddleware(UpstreamSessionMiddleware): expires = http_date(expires_time) # Save the session data and refresh the client cookie. # Skip session save for 500 responses, refs #3881. - if response.status_code != 500: + if response.status_code != HttpResponseServerError.status_code: try: request.session.save() except UpdateError: @@ -107,7 +108,7 @@ class SessionMiddleware(UpstreamSessionMiddleware): "The request's session was deleted before the " "request completed. The user may have logged " "out in a concurrent request, for example." - ) + ) from None payload = { "sid": request.session.session_key, "iss": "authentik", @@ -191,7 +192,7 @@ class ClientIPMiddleware: # FIXME: this should probably not be in `root` but rather in a middleware in `outposts` # but for now it's fine - def _get_outpost_override_ip(self, request: HttpRequest) -> Optional[str]: + def _get_outpost_override_ip(self, request: HttpRequest) -> str | None: """Get the actual remote IP when set by an outpost. Only allowed when the request is authenticated, by an outpost internal service account""" if ( @@ -228,7 +229,7 @@ class ClientIPMiddleware: setattr(request, self.request_attr_outpost_user, user) return delegated_ip - def _get_client_ip(self, request: Optional[HttpRequest]) -> str: + def _get_client_ip(self, request: HttpRequest | None) -> str: """Attempt to get the client's IP by checking common HTTP Headers. Returns none if no IP Could be found""" if not request: @@ -239,7 +240,7 @@ class ClientIPMiddleware: return self._get_client_ip_from_meta(request.META) @staticmethod - def get_outpost_user(request: HttpRequest) -> Optional[User]: + def get_outpost_user(request: HttpRequest) -> User | None: """Get outpost user that authenticated this request""" return getattr(request, ClientIPMiddleware.request_attr_outpost_user, None) diff --git a/authentik/root/storages.py b/authentik/root/storages.py index 0fbc4260b4..a0ccf27378 100644 --- a/authentik/root/storages.py +++ b/authentik/root/storages.py @@ -15,19 +15,16 @@ from authentik.lib.config import CONFIG class FileStorage(FileSystemStorage): """File storage backend""" - # pylint: disable=invalid-overridden-method @property def base_location(self): return os.path.join( self._value_or_setting(self._location, settings.MEDIA_ROOT), connection.schema_name ) - # pylint: disable=invalid-overridden-method @property def location(self): return os.path.abspath(self.base_location) - # pylint: disable=invalid-overridden-method @property def base_url(self): if self._base_url is not None and not self._base_url.endswith("/"): @@ -35,7 +32,6 @@ class FileStorage(FileSystemStorage): return f"{self._base_url}/{connection.schema_name}/" -# pylint: disable=abstract-method class S3Storage(BaseS3Storage): """S3 storage backend""" @@ -77,13 +73,12 @@ class S3Storage(BaseS3Storage): def _normalize_name(self, name): try: - # pylint: disable=no-member + return safe_join(self.location, connection.schema_name, name) except ValueError: - raise SuspiciousOperation("Attempted access to '%s' denied." % name) + raise SuspiciousOperation("Attempted access to '%s' denied." % name) from None # This is a fix for https://github.com/jschneier/django-storages/pull/839 - # pylint: disable=arguments-differ,no-member def url(self, name, parameters=None, expire=None, http_method=None): # Preserve the trailing slash after normalizing the path. name = self._normalize_name(clean_name(name)) @@ -109,7 +104,7 @@ class S3Storage(BaseS3Storage): # Remove signing parameter and previously added key "/". root_url = self._strip_signing_parameters(root_url_signed)[:-1] # Replace bucket domain with custom domain. - custom_url = "{}//{}/".format(self.url_protocol, self.custom_domain) + custom_url = f"{self.url_protocol}//{self.custom_domain}/" url = url.replace(root_url, custom_url) if self.querystring_auth: diff --git a/authentik/root/test_runner.py b/authentik/root/test_runner.py index e9d90270e7..eddb884b73 100644 --- a/authentik/root/test_runner.py +++ b/authentik/root/test_runner.py @@ -77,23 +77,21 @@ class PytestTestRunner(DiscoverRunner): # pragma: no cover if os.path.exists(label_as_path): self.args.append(label_as_path) valid_label_found = True + elif "::" in label: + self.args.append(label) + valid_label_found = True + # Convert dotted module path to file_path::class::method else: - # Already correctly formatted test found (file_path::class::method) - if "::" in label: - self.args.append(label) - valid_label_found = True - # Convert dotted module path to file_path::class::method - else: - path_pieces = label.split(".") - # Check whether only class or class and method are specified - for i in range(-1, -3, -1): - path = os.path.join(*path_pieces[:i]) + ".py" - label_as_path = os.path.abspath(path) - if os.path.exists(label_as_path): - path_method = label_as_path + "::" + "::".join(path_pieces[i:]) - self.args.append(path_method) - valid_label_found = True - break + path_pieces = label.split(".") + # Check whether only class or class and method are specified + for i in range(-1, -3, -1): + path = os.path.join(*path_pieces[:i]) + ".py" + label_as_path = os.path.abspath(path) + if os.path.exists(label_as_path): + path_method = label_as_path + "::" + "::".join(path_pieces[i:]) + self.args.append(path_method) + valid_label_found = True + break if not valid_label_found: raise RuntimeError( diff --git a/authentik/root/urls.py b/authentik/root/urls.py index 0da8ee814c..1b03051fdd 100644 --- a/authentik/root/urls.py +++ b/authentik/root/urls.py @@ -20,10 +20,10 @@ for _authentik_app in get_apps(): mountpoints = None base_url_module = _authentik_app.name + ".urls" if hasattr(_authentik_app, "mountpoint"): - mountpoint = getattr(_authentik_app, "mountpoint") + mountpoint = _authentik_app.mountpoint mountpoints = {base_url_module: mountpoint} if hasattr(_authentik_app, "mountpoints"): - mountpoints = getattr(_authentik_app, "mountpoints") + mountpoints = _authentik_app.mountpoints if not mountpoints: continue for module, mountpoint in mountpoints.items(): diff --git a/authentik/root/websocket.py b/authentik/root/websocket.py index 5d2a0be250..560a39c07d 100644 --- a/authentik/root/websocket.py +++ b/authentik/root/websocket.py @@ -16,7 +16,7 @@ for _authentik_app in get_apps(): continue if not hasattr(api_urls, "websocket_urlpatterns"): continue - urls: list = getattr(api_urls, "websocket_urlpatterns") + urls: list = api_urls.websocket_urlpatterns websocket_urlpatterns.extend(urls) LOGGER.debug( "Mounted Websocket URLs", diff --git a/authentik/sources/ldap/api.py b/authentik/sources/ldap/api.py index 34def44b03..dc45523380 100644 --- a/authentik/sources/ldap/api.py +++ b/authentik/sources/ldap/api.py @@ -1,6 +1,6 @@ """Source API Views""" -from typing import Any, Optional +from typing import Any from django.core.cache import cache from django_filters.filters import AllValuesMultipleFilter @@ -39,7 +39,7 @@ class LDAPSourceSerializer(SourceSerializer): required=False, ) - def get_connectivity(self, source: LDAPSource) -> Optional[dict[str, dict[str, str]]]: + def get_connectivity(self, source: LDAPSource) -> dict[str, dict[str, str]] | None: """Get cached source connectivity""" return cache.get(CACHE_KEY_STATUS + source.slug, None) diff --git a/authentik/sources/ldap/auth.py b/authentik/sources/ldap/auth.py index 7d351c8be5..a271dac5a5 100644 --- a/authentik/sources/ldap/auth.py +++ b/authentik/sources/ldap/auth.py @@ -1,7 +1,5 @@ """authentik LDAP Authentication Backend""" -from typing import Optional - from django.http import HttpRequest from ldap3.core.exceptions import LDAPException, LDAPInvalidCredentialsResult from structlog.stdlib import get_logger @@ -29,7 +27,7 @@ class LDAPBackend(InbuiltBackend): return user return None - def auth_user(self, source: LDAPSource, password: str, **filters: str) -> Optional[User]: + def auth_user(self, source: LDAPSource, password: str, **filters: str) -> User | None: """Try to bind as either user_dn or mail with password. Returns True on success, otherwise False""" users = User.objects.filter(**filters) @@ -52,7 +50,7 @@ class LDAPBackend(InbuiltBackend): LOGGER.debug("Failed to bind, password invalid") return None - def auth_user_by_bind(self, source: LDAPSource, user: User, password: str) -> Optional[User]: + def auth_user_by_bind(self, source: LDAPSource, user: User, password: str) -> User | None: """Attempt authentication by binding to the LDAP server as `user`. This method should be avoided as its slow to do the bind.""" # Try to bind as new user diff --git a/authentik/sources/ldap/models.py b/authentik/sources/ldap/models.py index 1c3552879b..074e49b5f4 100644 --- a/authentik/sources/ldap/models.py +++ b/authentik/sources/ldap/models.py @@ -5,7 +5,6 @@ from os.path import dirname, exists from shutil import rmtree from ssl import CERT_REQUIRED from tempfile import NamedTemporaryFile, mkdtemp -from typing import Optional from django.core.cache import cache from django.db import connection, models @@ -160,9 +159,9 @@ class LDAPSource(Source): def connection( self, - server: Optional[Server] = None, - server_kwargs: Optional[dict] = None, - connection_kwargs: Optional[dict] = None, + server: Server | None = None, + server_kwargs: dict | None = None, + connection_kwargs: dict | None = None, ) -> Connection: """Get a fully connected and bound LDAP Connection""" server_kwargs = server_kwargs or {} diff --git a/authentik/sources/ldap/password.py b/authentik/sources/ldap/password.py index 2b8d70fb49..662778aa8d 100644 --- a/authentik/sources/ldap/password.py +++ b/authentik/sources/ldap/password.py @@ -2,7 +2,6 @@ from enum import IntFlag from re import split -from typing import Optional from ldap3 import BASE from ldap3.core.exceptions import ( @@ -20,6 +19,7 @@ LOGGER = get_logger() NON_ALPHA = r"~!@#$%^&*_-+=`|\(){}[]:;\"'<>,.?/" RE_DISPLAYNAME_SEPARATORS = r",\.–—_\s#\t" +MIN_TOKEN_SIZE = 3 class PwdProperties(IntFlag): @@ -119,7 +119,7 @@ class LDAPPasswordChanger: raise AssertionError() user_attributes = users[0]["attributes"] # If sAMAccountName is longer than 3 chars, check if its contained in password - if len(user_attributes["sAMAccountName"]) >= 3: + if len(user_attributes["sAMAccountName"]) >= MIN_TOKEN_SIZE: if password.lower() in user_attributes["sAMAccountName"].lower(): return False # No display name set, can't check any further @@ -129,13 +129,13 @@ class LDAPPasswordChanger: display_name_tokens = split(RE_DISPLAYNAME_SEPARATORS, display_name) for token in display_name_tokens: # Ignore tokens under 3 chars - if len(token) < 3: + if len(token) < MIN_TOKEN_SIZE: continue if token.lower() in password.lower(): return False return True - def ad_password_complexity(self, password: str, user: Optional[User] = None) -> bool: + def ad_password_complexity(self, password: str, user: User | None = None) -> bool: """Check if password matches Active directory password policies https://docs.microsoft.com/en-us/windows/security/threat-protection/ diff --git a/authentik/sources/ldap/sync/base.py b/authentik/sources/ldap/sync/base.py index 1ebacfc643..0a2e84cb0e 100644 --- a/authentik/sources/ldap/sync/base.py +++ b/authentik/sources/ldap/sync/base.py @@ -1,6 +1,7 @@ """Sync LDAP Users and groups into authentik""" -from typing import Any, Generator +from collections.abc import Generator +from typing import Any from django.conf import settings from django.db.models.base import Model @@ -90,8 +91,7 @@ class BaseLDAPSynchronizer: """Get objects from LDAP, implemented in subclass""" raise NotImplementedError() - # pylint: disable=too-many-arguments - def search_paginator( + def search_paginator( # noqa: PLR0913 self, search_base, search_filter, @@ -103,11 +103,13 @@ class BaseLDAPSynchronizer: types_only=False, get_operational_attributes=False, controls=None, - paged_size=CONFIG.get_int("ldap.page_size", 50), + paged_size=None, paged_criticality=False, ): """Search in pages, returns each page""" cookie = True + if not paged_size: + paged_size = CONFIG.get_int("ldap.page_size", 50) while cookie: self._connection.search( search_base, diff --git a/authentik/sources/ldap/sync/groups.py b/authentik/sources/ldap/sync/groups.py index 2f2f67307c..cf6d26999c 100644 --- a/authentik/sources/ldap/sync/groups.py +++ b/authentik/sources/ldap/sync/groups.py @@ -1,6 +1,6 @@ """Sync LDAP Users and groups into authentik""" -from typing import Generator +from collections.abc import Generator from django.core.exceptions import FieldError from django.db.utils import IntegrityError diff --git a/authentik/sources/ldap/sync/membership.py b/authentik/sources/ldap/sync/membership.py index 987ac59b9c..7d2a60f5d4 100644 --- a/authentik/sources/ldap/sync/membership.py +++ b/authentik/sources/ldap/sync/membership.py @@ -1,6 +1,7 @@ """Sync LDAP Users and groups into authentik""" -from typing import Any, Generator, Optional +from collections.abc import Generator +from typing import Any from django.db.models import Q from ldap3 import SUBTREE @@ -76,7 +77,7 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer): self._logger.debug("Successfully updated group membership") return membership_count - def get_group(self, group_dict: dict[str, Any]) -> Optional[Group]: + def get_group(self, group_dict: dict[str, Any]) -> Group | None: """Check if we fetched the group already, and if not cache it for later""" group_dn = group_dict.get("attributes", {}).get(LDAP_DISTINGUISHED_NAME, []) group_uniq = group_dict.get("attributes", {}).get(self._source.object_uniqueness_field, []) diff --git a/authentik/sources/ldap/sync/users.py b/authentik/sources/ldap/sync/users.py index 7e3afef54a..9d143579dd 100644 --- a/authentik/sources/ldap/sync/users.py +++ b/authentik/sources/ldap/sync/users.py @@ -1,6 +1,6 @@ """Sync LDAP Users into authentik""" -from typing import Generator +from collections.abc import Generator from django.core.exceptions import FieldError from django.db.utils import IntegrityError diff --git a/authentik/sources/ldap/sync/vendor/freeipa.py b/authentik/sources/ldap/sync/vendor/freeipa.py index e8fc046834..83405ddd19 100644 --- a/authentik/sources/ldap/sync/vendor/freeipa.py +++ b/authentik/sources/ldap/sync/vendor/freeipa.py @@ -1,7 +1,8 @@ """FreeIPA specific""" -from datetime import datetime, timezone -from typing import Any, Generator +from collections.abc import Generator +from datetime import UTC, datetime +from typing import Any from authentik.core.models import User from authentik.sources.ldap.sync.base import BaseLDAPSynchronizer, flatten @@ -26,7 +27,7 @@ class FreeIPA(BaseLDAPSynchronizer): if "krbLastPwdChange" not in attributes: return pwd_last_set: datetime = attributes.get("krbLastPwdChange", datetime.now()) - pwd_last_set = pwd_last_set.replace(tzinfo=timezone.utc) + pwd_last_set = pwd_last_set.replace(tzinfo=UTC) if created or pwd_last_set >= user.password_change_date: self.message(f"'{user.username}': Reset user's password") self._logger.debug( diff --git a/authentik/sources/ldap/sync/vendor/ms_ad.py b/authentik/sources/ldap/sync/vendor/ms_ad.py index fef2653622..e8fdf831c8 100644 --- a/authentik/sources/ldap/sync/vendor/ms_ad.py +++ b/authentik/sources/ldap/sync/vendor/ms_ad.py @@ -1,8 +1,9 @@ """Active Directory specific""" -from datetime import datetime, timezone +from collections.abc import Generator +from datetime import UTC, datetime from enum import IntFlag -from typing import Any, Generator +from typing import Any from authentik.core.models import User from authentik.sources.ldap.sync.base import BaseLDAPSynchronizer @@ -57,7 +58,7 @@ class MicrosoftActiveDirectory(BaseLDAPSynchronizer): if "pwdLastSet" not in attributes: return pwd_last_set: datetime = attributes.get("pwdLastSet", datetime.now()) - pwd_last_set = pwd_last_set.replace(tzinfo=timezone.utc) + pwd_last_set = pwd_last_set.replace(tzinfo=UTC) if created or pwd_last_set >= user.password_change_date: self.message(f"'{user.username}': Reset user's password") self._logger.debug( diff --git a/authentik/sources/ldap/tasks.py b/authentik/sources/ldap/tasks.py index 0f58349028..9184089b97 100644 --- a/authentik/sources/ldap/tasks.py +++ b/authentik/sources/ldap/tasks.py @@ -1,6 +1,5 @@ """LDAP Sync tasks""" -from typing import Optional from uuid import uuid4 from celery import chain, group @@ -40,7 +39,7 @@ def ldap_sync_all(): @CELERY_APP.task() -def ldap_connectivity_check(pk: Optional[str] = None): +def ldap_connectivity_check(pk: str | None = None): """Check connectivity for LDAP Sources""" # 2 hour timeout, this task should run every hour timeout = 60 * 60 * 2 diff --git a/authentik/sources/oauth/api/source.py b/authentik/sources/oauth/api/source.py index 16a862d1be..72c8b96108 100644 --- a/authentik/sources/oauth/api/source.py +++ b/authentik/sources/oauth/api/source.py @@ -57,7 +57,6 @@ class OAuthSourceSerializer(SourceSerializer): """Get source's type configuration""" return SourceTypeSerializer(instance.source_type).data - # pylint: disable=too-many-locals def validate(self, attrs: dict) -> dict: session = get_http_session() source_type = registry.find_type(attrs["provider_type"]) @@ -71,7 +70,7 @@ class OAuthSourceSerializer(SourceSerializer): well_known_config.raise_for_status() except RequestException as exc: text = exc.response.text if exc.response else str(exc) - raise ValidationError({"oidc_well_known_url": text}) + raise ValidationError({"oidc_well_known_url": text}) from None config = well_known_config.json() if "issuer" not in config: raise ValidationError({"oidc_well_known_url": "Invalid well-known configuration"}) @@ -97,7 +96,7 @@ class OAuthSourceSerializer(SourceSerializer): jwks_config.raise_for_status() except RequestException as exc: text = exc.response.text if exc.response else str(exc) - raise ValidationError({"oidc_jwks_url": text}) + raise ValidationError({"oidc_jwks_url": text}) from None config = jwks_config.json() attrs["oidc_jwks"] = config diff --git a/authentik/sources/oauth/clients/base.py b/authentik/sources/oauth/clients/base.py index f86011b17f..a580786e32 100644 --- a/authentik/sources/oauth/clients/base.py +++ b/authentik/sources/oauth/clients/base.py @@ -1,6 +1,6 @@ """OAuth Clients""" -from typing import Any, Optional +from typing import Any from urllib.parse import parse_qs, quote, urlencode, urlparse, urlunparse from django.http import HttpRequest @@ -22,20 +22,20 @@ class BaseOAuthClient: source: OAuthSource request: HttpRequest - callback: Optional[str] + callback: str | None - def __init__(self, source: OAuthSource, request: HttpRequest, callback: Optional[str] = None): + def __init__(self, source: OAuthSource, request: HttpRequest, callback: str | None = None): self.source = source self.session = get_http_session() self.request = request self.callback = callback self.logger = get_logger().bind(source=source.slug) - def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]: + def get_access_token(self, **request_kwargs) -> dict[str, Any] | None: """Fetch access token from callback request.""" raise NotImplementedError("Defined in a sub-class") # pragma: no cover - def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]: + def get_profile_info(self, token: dict[str, str]) -> dict[str, Any] | None: """Fetch user profile information.""" profile_url = self.source.source_type.profile_url or "" if self.source.source_type.urls_customizable and self.source.profile_url: diff --git a/authentik/sources/oauth/clients/oauth1.py b/authentik/sources/oauth/clients/oauth1.py index 739ebfb6aa..2617556dd8 100644 --- a/authentik/sources/oauth/clients/oauth1.py +++ b/authentik/sources/oauth/clients/oauth1.py @@ -1,6 +1,6 @@ """OAuth 1 Clients""" -from typing import Any, Optional +from typing import Any from urllib.parse import parse_qsl from requests.exceptions import RequestException @@ -21,7 +21,7 @@ class OAuthClient(BaseOAuthClient): "Accept": "application/json", } - def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]: + def get_access_token(self, **request_kwargs) -> dict[str, Any] | None: """Fetch access token from callback request.""" raw_token = self.request.session.get(self.session_key, None) verifier = self.request.GET.get("oauth_verifier", None) diff --git a/authentik/sources/oauth/clients/oauth2.py b/authentik/sources/oauth/clients/oauth2.py index ec2fa2f5a5..9ee3c05229 100644 --- a/authentik/sources/oauth/clients/oauth2.py +++ b/authentik/sources/oauth/clients/oauth2.py @@ -1,7 +1,7 @@ """OAuth 2 Clients""" from json import loads -from typing import Any, Optional +from typing import Any from urllib.parse import parse_qsl from django.utils.crypto import constant_time_compare, get_random_string @@ -23,7 +23,7 @@ class OAuth2Client(BaseOAuthClient): "Accept": "application/json", } - def get_request_arg(self, key: str, default: Optional[Any] = None) -> Any: + def get_request_arg(self, key: str, default: Any | None = None) -> Any: """Depending on request type, get data from post or get""" if self.request.method == "POST": return self.request.POST.get(key, default) @@ -55,7 +55,7 @@ class OAuth2Client(BaseOAuthClient): """Get client secret""" return self.source.consumer_secret - def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]: + def get_access_token(self, **request_kwargs) -> dict[str, Any] | None: """Fetch access token from callback request.""" callback = self.request.build_absolute_uri(self.callback or self.request.path) if not self.check_application_state(): @@ -139,7 +139,7 @@ class OAuth2Client(BaseOAuthClient): class UserprofileHeaderAuthClient(OAuth2Client): """OAuth client which only sends authentication via header, not querystring""" - def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]: + def get_profile_info(self, token: dict[str, str]) -> dict[str, Any] | None: "Fetch user profile information." profile_url = self.source.source_type.profile_url or "" if self.source.source_type.urls_customizable and self.source.profile_url: diff --git a/authentik/sources/oauth/models.py b/authentik/sources/oauth/models.py index aaaf1f2ddf..6dae77750a 100644 --- a/authentik/sources/oauth/models.py +++ b/authentik/sources/oauth/models.py @@ -1,6 +1,6 @@ """OAuth Client models""" -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from django.db import models from django.http.request import HttpRequest @@ -84,7 +84,7 @@ class OAuthSource(Source): icon_url=icon, ) - def ui_user_settings(self) -> Optional[UserSettingSerializer]: + def ui_user_settings(self) -> UserSettingSerializer | None: provider_type = self.source_type icon = self.icon_url if not icon: diff --git a/authentik/sources/oauth/tests/test_tasks.py b/authentik/sources/oauth/tests/test_tasks.py index 3671fbc4ba..9ab3aab4bf 100644 --- a/authentik/sources/oauth/tests/test_tasks.py +++ b/authentik/sources/oauth/tests/test_tasks.py @@ -35,7 +35,7 @@ class TestOAuthSourceTasks(TestCase): }, ) mock.get("http://foo/jwks", json={"foo": "bar"}) - update_well_known_jwks() # pylint: disable=no-value-for-parameter + update_well_known_jwks() self.source.refresh_from_db() self.assertEqual(self.source.authorization_url, "foo") self.assertEqual(self.source.access_token_url, "foo") diff --git a/authentik/sources/oauth/types/apple.py b/authentik/sources/oauth/types/apple.py index a92ce95147..b8a9c0f912 100644 --- a/authentik/sources/oauth/types/apple.py +++ b/authentik/sources/oauth/types/apple.py @@ -1,7 +1,7 @@ """Apple OAuth Views""" from time import time -from typing import Any, Optional +from typing import Any from django.http.request import HttpRequest from django.urls.base import reverse @@ -17,6 +17,7 @@ from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.redirect import OAuthRedirect LOGGER = get_logger() +APPLE_CLIENT_ID_PARTS = 3 class AppleLoginChallenge(Challenge): @@ -30,7 +31,7 @@ class AppleLoginChallenge(Challenge): class AppleChallengeResponse(ChallengeResponse): - """Pseudo class for plex response""" + """Pseudo class for apple response""" component = CharField(default="ak-source-oauth-apple") @@ -40,14 +41,14 @@ class AppleOAuthClient(OAuth2Client): def get_client_id(self) -> str: parts: list[str] = self.source.consumer_key.split(";") - if len(parts) < 3: + if len(parts) < APPLE_CLIENT_ID_PARTS: return self.source.consumer_key return parts[0].strip() def get_client_secret(self) -> str: now = time() parts: list[str] = self.source.consumer_key.split(";") - if len(parts) < 3: + if len(parts) < APPLE_CLIENT_ID_PARTS: raise ValueError( "Apple Source client_id should be formatted like " "services_id_identifier;apple_team_id;key_id" @@ -64,7 +65,7 @@ class AppleOAuthClient(OAuth2Client): LOGGER.debug("signing payload as secret key", payload=payload, jwt=jwt) return jwt - def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]: + def get_profile_info(self, token: dict[str, str]) -> dict[str, Any] | None: id_token = token.get("id_token") return decode(id_token, options={"verify_signature": False}) @@ -86,7 +87,7 @@ class AppleOAuth2Callback(OAuthCallback): client_class = AppleOAuthClient - def get_user_id(self, info: dict[str, Any]) -> Optional[str]: + def get_user_id(self, info: dict[str, Any]) -> str | None: return info["sub"] def get_user_enroll_context( diff --git a/authentik/sources/oauth/types/facebook.py b/authentik/sources/oauth/types/facebook.py index 7225109dcb..1d1ac9d42a 100644 --- a/authentik/sources/oauth/types/facebook.py +++ b/authentik/sources/oauth/types/facebook.py @@ -1,6 +1,6 @@ """Facebook OAuth Views""" -from typing import Any, Optional +from typing import Any from facebook import GraphAPI @@ -22,7 +22,7 @@ class FacebookOAuthRedirect(OAuthRedirect): class FacebookOAuth2Client(OAuth2Client): """Facebook OAuth2 Client""" - def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]: + def get_profile_info(self, token: dict[str, str]) -> dict[str, Any] | None: api = GraphAPI(access_token=token["access_token"]) return api.get_object("me", fields="id,name,email") diff --git a/authentik/sources/oauth/types/mailcow.py b/authentik/sources/oauth/types/mailcow.py index 0f64994fe4..37895e114a 100644 --- a/authentik/sources/oauth/types/mailcow.py +++ b/authentik/sources/oauth/types/mailcow.py @@ -1,6 +1,6 @@ """Mailcow OAuth Views""" -from typing import Any, Optional +from typing import Any from requests.exceptions import RequestException from structlog.stdlib import get_logger @@ -25,7 +25,7 @@ class MailcowOAuthRedirect(OAuthRedirect): class MailcowOAuth2Client(OAuth2Client): """MailcowOAuth2Client, for some reason, mailcow does not like the default headers""" - def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]: + def get_profile_info(self, token: dict[str, str]) -> dict[str, Any] | None: "Fetch user profile information." profile_url = self.source.source_type.profile_url or "" if self.source.source_type.urls_customizable and self.source.profile_url: diff --git a/authentik/sources/oauth/types/registry.py b/authentik/sources/oauth/types/registry.py index c86badb200..199451c6cc 100644 --- a/authentik/sources/oauth/types/registry.py +++ b/authentik/sources/oauth/types/registry.py @@ -1,7 +1,7 @@ """Source type manager""" +from collections.abc import Callable from enum import Enum -from typing import Callable, Optional, Type from django.http.request import HttpRequest from django.templatetags.static import static @@ -33,12 +33,12 @@ class SourceType: urls_customizable = False - request_token_url: Optional[str] = None - authorization_url: Optional[str] = None - access_token_url: Optional[str] = None - profile_url: Optional[str] = None - oidc_well_known_url: Optional[str] = None - oidc_jwks_url: Optional[str] = None + request_token_url: str | None = None + authorization_url: str | None = None + access_token_url: str | None = None + profile_url: str | None = None + oidc_well_known_url: str | None = None + oidc_jwks_url: str | None = None def icon_url(self) -> str: """Get Icon URL for login""" @@ -80,7 +80,7 @@ class SourceTypeRegistry: """Get list of tuples of all registered names""" return [(x.name, x.verbose_name) for x in self.__sources] - def find_type(self, type_name: str) -> Type[SourceType]: + def find_type(self, type_name: str) -> type[SourceType]: """Find type based on source""" found_type = None for src_type in self.__sources: diff --git a/authentik/sources/oauth/types/twitch.py b/authentik/sources/oauth/types/twitch.py index ada73ce499..777d457867 100644 --- a/authentik/sources/oauth/types/twitch.py +++ b/authentik/sources/oauth/types/twitch.py @@ -1,7 +1,7 @@ """Twitch OAuth Views""" from json import dumps -from typing import Any, Optional +from typing import Any from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback @@ -12,7 +12,7 @@ from authentik.sources.oauth.views.redirect import OAuthRedirect class TwitchClient(UserprofileHeaderAuthClient): """Twitch needs the token_type to be capitalized for the request header.""" - def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]: + def get_profile_info(self, token: dict[str, str]) -> dict[str, Any] | None: token["token_type"] = token["token_type"].capitalize() return super().get_profile_info(token) diff --git a/authentik/sources/oauth/types/twitter.py b/authentik/sources/oauth/types/twitter.py index 5753b4694c..8b1aa66124 100644 --- a/authentik/sources/oauth/types/twitter.py +++ b/authentik/sources/oauth/types/twitter.py @@ -1,6 +1,6 @@ """Twitter OAuth Views""" -from typing import Any, Optional +from typing import Any from authentik.lib.generators import generate_id from authentik.sources.oauth.clients.oauth2 import ( @@ -20,7 +20,7 @@ class TwitterClient(UserprofileHeaderAuthClient): # is set via query parameter, so we reuse the azure client # see https://github.com/goauthentik/authentik/issues/1910 - def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]: + def get_access_token(self, **request_kwargs) -> dict[str, Any] | None: return super().get_access_token( auth=( self.source.consumer_key, diff --git a/authentik/sources/oauth/views/base.py b/authentik/sources/oauth/views/base.py index 2948cdac6e..e308778b6d 100644 --- a/authentik/sources/oauth/views/base.py +++ b/authentik/sources/oauth/views/base.py @@ -1,7 +1,5 @@ """OAuth Base views""" -from typing import Optional - from django.http.request import HttpRequest from structlog.stdlib import get_logger @@ -13,18 +11,17 @@ from authentik.sources.oauth.models import OAuthSource LOGGER = get_logger() -# pylint: disable=too-few-public-methods class OAuthClientMixin: "Mixin for getting OAuth client for a source." request: HttpRequest # Set by View class - client_class: Optional[type[BaseOAuthClient]] = None + client_class: type[BaseOAuthClient] | None = None def get_client(self, source: OAuthSource, **kwargs) -> BaseOAuthClient: "Get instance of the OAuth client for this source." if self.client_class is not None: - # pylint: disable=not-callable + return self.client_class(source, self.request, **kwargs) if source.source_type.request_token_url or source.request_token_url: client = OAuthClient(source, self.request, **kwargs) diff --git a/authentik/sources/oauth/views/callback.py b/authentik/sources/oauth/views/callback.py index 37bb51aaaf..d04727c8d5 100644 --- a/authentik/sources/oauth/views/callback.py +++ b/authentik/sources/oauth/views/callback.py @@ -1,7 +1,7 @@ """OAuth Callback Views""" from json import JSONDecodeError -from typing import Any, Optional +from typing import Any from django.conf import settings from django.contrib import messages @@ -23,16 +23,15 @@ class OAuthCallback(OAuthClientMixin, View): "Base OAuth callback view." source: OAuthSource - token: Optional[dict] = None + token: dict | None = None - # pylint: disable=too-many-return-statements def dispatch(self, request: HttpRequest, *_, **kwargs) -> HttpResponse: """View Get handler""" slug = kwargs.get("source_slug", "") try: self.source = OAuthSource.objects.get(slug=slug) except OAuthSource.DoesNotExist: - raise Http404(f"Unknown OAuth source '{slug}'.") + raise Http404(f"Unknown OAuth source '{slug}'.") from None if not self.source.enabled: raise Http404(f"Source {slug} is not enabled.") @@ -86,7 +85,7 @@ class OAuthCallback(OAuthClientMixin, View): """Create a dict of User data""" raise NotImplementedError() - def get_user_id(self, info: dict[str, Any]) -> Optional[str]: + def get_user_id(self, info: dict[str, Any]) -> str | None: """Return unique identifier from the profile info.""" if "id" in info: return info["id"] @@ -98,10 +97,11 @@ class OAuthCallback(OAuthClientMixin, View): messages.error( self.request, _( - "Authentication failed: %(reason)s" - % { - "reason": reason, - } + "Authentication failed: {reason}".format_map( + { + "reason": reason, + } + ) ), ) return redirect(self.get_error_redirect(self.source, reason)) @@ -115,7 +115,7 @@ class OAuthSourceFlowManager(SourceFlowManager): def update_connection( self, connection: UserOAuthSourceConnection, - access_token: Optional[str] = None, + access_token: str | None = None, ) -> UserOAuthSourceConnection: """Set the access_token on the connection""" connection.access_token = access_token diff --git a/authentik/sources/oauth/views/redirect.py b/authentik/sources/oauth/views/redirect.py index 97d5356eb3..2a68060906 100644 --- a/authentik/sources/oauth/views/redirect.py +++ b/authentik/sources/oauth/views/redirect.py @@ -36,7 +36,7 @@ class OAuthRedirect(OAuthClientMixin, RedirectView): try: source: OAuthSource = OAuthSource.objects.get(slug=slug) except OAuthSource.DoesNotExist: - raise Http404(f"Unknown OAuth source '{slug}'.") + raise Http404(f"Unknown OAuth source '{slug}'.") from None if not source.enabled: raise Http404(f"source {slug} is not enabled.") client = self.get_client(source, callback=self.get_callback_url(source)) diff --git a/authentik/sources/plex/models.py b/authentik/sources/plex/models.py index 6621d4d269..6f4e7def99 100644 --- a/authentik/sources/plex/models.py +++ b/authentik/sources/plex/models.py @@ -1,7 +1,5 @@ """Plex source""" -from typing import Optional - from django.contrib.postgres.fields import ArrayField from django.db import models from django.http.request import HttpRequest @@ -79,7 +77,7 @@ class PlexSource(Source): name=self.name, ) - def ui_user_settings(self) -> Optional[UserSettingSerializer]: + def ui_user_settings(self) -> UserSettingSerializer | None: icon = self.icon_url if not icon: icon = static("authentik/sources/plex.svg") diff --git a/authentik/sources/plex/plex.py b/authentik/sources/plex/plex.py index 9e324f0ae8..caf245888f 100644 --- a/authentik/sources/plex/plex.py +++ b/authentik/sources/plex/plex.py @@ -85,7 +85,7 @@ class PlexAuth: resources = self.get_resources() except RequestException as exc: LOGGER.warning("Unable to fetch user resources", exc=exc) - raise Http404 + raise Http404 from None for resource in resources: if resource["provides"] != "server": continue diff --git a/authentik/sources/saml/models.py b/authentik/sources/saml/models.py index b05fd61007..593f49446b 100644 --- a/authentik/sources/saml/models.py +++ b/authentik/sources/saml/models.py @@ -1,7 +1,5 @@ """saml sp models""" -from typing import Optional - from django.db import models from django.http import HttpRequest from django.templatetags.static import static @@ -204,7 +202,7 @@ class SAMLSource(Source): icon_url=self.icon_url, ) - def ui_user_settings(self) -> Optional[UserSettingSerializer]: + def ui_user_settings(self) -> UserSettingSerializer | None: icon = self.icon_url if not icon: icon = static(f"authentik/sources/{self.slug}.svg") diff --git a/authentik/sources/saml/processors/metadata.py b/authentik/sources/saml/processors/metadata.py index 2aea4ba871..34bd4b5b66 100644 --- a/authentik/sources/saml/processors/metadata.py +++ b/authentik/sources/saml/processors/metadata.py @@ -1,6 +1,7 @@ """SAML Service Provider Metadata Processor""" -from typing import Iterator, Optional +from collections.abc import Iterator +from typing import Optional from django.http import HttpRequest from lxml.etree import Element, SubElement, tostring # nosec @@ -30,7 +31,8 @@ class MetadataProcessor: self.source = source self.http_request = request - def get_signing_key_descriptor(self) -> Optional[Element]: + # Using type unions doesn't work with cython types (which is what lxml is) + def get_signing_key_descriptor(self) -> Optional[Element]: # noqa: UP007 """Get Signing KeyDescriptor, if enabled for the source""" if self.source.signing_kp: key_descriptor = Element(f"{{{NS_SAML_METADATA}}}KeyDescriptor") diff --git a/authentik/sources/saml/views.py b/authentik/sources/saml/views.py index 502a96603b..37550d735e 100644 --- a/authentik/sources/saml/views.py +++ b/authentik/sources/saml/views.py @@ -88,7 +88,7 @@ class InitiateView(View): try: plan = planner.plan(self.request, kwargs) except FlowNonApplicableException: - raise Http404 + raise Http404 from None for stage in stages_to_append: plan.append_stage(stage) self.request.session[SESSION_KEY_PLAN] = plan diff --git a/authentik/stages/authenticator/__init__.py b/authentik/stages/authenticator/__init__.py index bf675f04e9..9601b13af3 100644 --- a/authentik/stages/authenticator/__init__.py +++ b/authentik/stages/authenticator/__init__.py @@ -121,7 +121,8 @@ def device_classes(): """ Returns an iterable of all loaded device models. """ - from django.apps import apps # isort: skip + from django.apps import apps + from authentik.stages.authenticator.models import Device for config in apps.get_app_configs(): diff --git a/authentik/stages/authenticator/models.py b/authentik/stages/authenticator/models.py index 3843d0fbf2..f7a6125c57 100644 --- a/authentik/stages/authenticator/models.py +++ b/authentik/stages/authenticator/models.py @@ -96,14 +96,14 @@ class Device(models.Model): except ObjectDoesNotExist: user = None - return "{0} ({1})".format(self.name, user) + return f"{self.name} ({user})" @property def persistent_id(self): """ A stable device identifier for forms and APIs. """ - return "{0}/{1}".format(self.model_label(), self.id) + return f"{self.model_label()}/{self.id}" @classmethod def model_label(cls): @@ -113,7 +113,7 @@ class Device(models.Model): This is just the standard "." form. """ - return "{0}.{1}".format(cls._meta.app_label, cls._meta.model_name) + return f"{cls._meta.app_label}.{cls._meta.model_name}" @classmethod def from_persistent_id(cls, persistent_id, for_verify=False): @@ -314,6 +314,9 @@ class ThrottlingMixin(models.Model): default=0, help_text="Number of successive failed attempts." ) + class Meta: + abstract = True + def verify_is_allowed(self): """ If verification is allowed, returns ``(True, None)``. @@ -397,6 +400,3 @@ class ThrottlingMixin(models.Model): """ raise NotImplementedError() - - class Meta: - abstract = True diff --git a/authentik/stages/authenticator/oath.py b/authentik/stages/authenticator/oath.py index 0a1a99b8e6..9ba4a5604b 100644 --- a/authentik/stages/authenticator/oath.py +++ b/authentik/stages/authenticator/oath.py @@ -6,7 +6,6 @@ from struct import pack from time import time -# pylint: disable=invalid-name def hotp(key: bytes, counter: int, digits=6) -> int: """ Implementation of the HOTP algorithm from `RFC 4226 @@ -129,7 +128,6 @@ class TOTP: 359152 """ - # pylint: disable=too-many-arguments def __init__(self, key: bytes, step=30, t0=0, digits=6, drift=0): self.key = key self.step = step diff --git a/authentik/stages/authenticator/util.py b/authentik/stages/authenticator/util.py index afef00cb52..e92f1a168e 100644 --- a/authentik/stages/authenticator/util.py +++ b/authentik/stages/authenticator/util.py @@ -43,10 +43,10 @@ def hex_validator(length=0): unhexlify(value) except Exception: - raise ValidationError("{0} is not valid hex-encoded data.".format(value)) + raise ValidationError(f"{value} is not valid hex-encoded data.") from None if (length > 0) and (len(value) != length * 2): - raise ValidationError("{0} does not represent exactly {1} bytes.".format(value, length)) + raise ValidationError(f"{value} does not represent exactly {length} bytes.") return _validator diff --git a/authentik/stages/authenticator_duo/models.py b/authentik/stages/authenticator_duo/models.py index f8b0f7ff48..37e2f6714c 100644 --- a/authentik/stages/authenticator_duo/models.py +++ b/authentik/stages/authenticator_duo/models.py @@ -1,7 +1,5 @@ """Duo stage""" -from typing import Optional - from django.contrib.auth import get_user_model from django.db import models from django.utils.translation import gettext_lazy as _ @@ -35,7 +33,7 @@ class AuthenticatorDuoStage(ConfigurableStage, FriendlyNamedStage, Stage): return AuthenticatorDuoStageSerializer @property - def type(self) -> type[View]: + def view(self) -> type[View]: from authentik.stages.authenticator_duo.stage import AuthenticatorDuoStageView return AuthenticatorDuoStageView @@ -65,7 +63,7 @@ class AuthenticatorDuoStage(ConfigurableStage, FriendlyNamedStage, Stage): def component(self) -> str: return "ak-stage-authenticator-duo-form" - def ui_user_settings(self) -> Optional[UserSettingSerializer]: + def ui_user_settings(self) -> UserSettingSerializer | None: return UserSettingSerializer( data={ "title": self.friendly_name or str(self._meta.verbose_name), diff --git a/authentik/stages/authenticator_sms/models.py b/authentik/stages/authenticator_sms/models.py index 27bd07faf9..36b3dd5370 100644 --- a/authentik/stages/authenticator_sms/models.py +++ b/authentik/stages/authenticator_sms/models.py @@ -1,10 +1,10 @@ """SMS Authenticator models""" from hashlib import sha256 -from typing import Optional from django.contrib.auth import get_user_model from django.db import models +from django.http import HttpResponseBadRequest from django.utils.translation import gettext_lazy as _ from django.views import View from requests.exceptions import RequestException @@ -79,7 +79,7 @@ class AuthenticatorSMSStage(ConfigurableStage, FriendlyNamedStage, Stage): def get_message(self, token: str) -> str: """Get SMS message""" - return _("Use this code to authenticate in authentik: %(token)s" % {"token": token}) + return _("Use this code to authenticate in authentik: {token}".format_map({"token": token})) def send_twilio(self, token: str, device: "SMSDevice"): """send sms via twilio provider""" @@ -92,7 +92,7 @@ class AuthenticatorSMSStage(ConfigurableStage, FriendlyNamedStage, Stage): LOGGER.debug("Sent SMS", to=device, message=message.sid) except TwilioRestException as exc: LOGGER.warning("Error sending token by Twilio SMS", exc=exc, msg=exc.msg) - raise ValidationError(exc.msg) + raise ValidationError(exc.msg) from None def send_generic(self, token: str, device: "SMSDevice"): """Send SMS via outside API""" @@ -146,8 +146,8 @@ class AuthenticatorSMSStage(ConfigurableStage, FriendlyNamedStage, Stage): status_code=response.status_code, body=response.text, ).set_user(device.user).save() - if response.status_code >= 400: - raise ValidationError(response.text) + if response.status_code >= HttpResponseBadRequest.status_code: + raise ValidationError(response.text) from None raise @property @@ -157,7 +157,7 @@ class AuthenticatorSMSStage(ConfigurableStage, FriendlyNamedStage, Stage): return AuthenticatorSMSStageSerializer @property - def type(self) -> type[View]: + def view(self) -> type[View]: from authentik.stages.authenticator_sms.stage import AuthenticatorSMSStageView return AuthenticatorSMSStageView @@ -166,7 +166,7 @@ class AuthenticatorSMSStage(ConfigurableStage, FriendlyNamedStage, Stage): def component(self) -> str: return "ak-stage-authenticator-sms-form" - def ui_user_settings(self) -> Optional[UserSettingSerializer]: + def ui_user_settings(self) -> UserSettingSerializer | None: return UserSettingSerializer( data={ "title": self.friendly_name or str(self._meta.verbose_name), diff --git a/authentik/stages/authenticator_sms/stage.py b/authentik/stages/authenticator_sms/stage.py index 3802adb709..159e5ede5a 100644 --- a/authentik/stages/authenticator_sms/stage.py +++ b/authentik/stages/authenticator_sms/stage.py @@ -1,7 +1,5 @@ """SMS Setup stage""" -from typing import Optional - from django.db.models import Q from django.http import HttpRequest, HttpResponse from django.http.request import QueryDict @@ -76,7 +74,7 @@ class AuthenticatorSMSStageView(ChallengeStageView): device: SMSDevice = self.request.session[SESSION_KEY_SMS_DEVICE] stage.send(device.token, device) - def _has_phone_number(self) -> Optional[str]: + def _has_phone_number(self) -> str | None: context = self.executor.plan.context if PLAN_CONTEXT_PHONE in context.get(PLAN_CONTEXT_PROMPT, {}): self.logger.debug("got phone number from plan context") diff --git a/authentik/stages/authenticator_static/models.py b/authentik/stages/authenticator_static/models.py index eb11c6ab9f..77af9b24a0 100644 --- a/authentik/stages/authenticator_static/models.py +++ b/authentik/stages/authenticator_static/models.py @@ -2,7 +2,6 @@ from base64 import b32encode from os import urandom -from typing import Optional from django.conf import settings from django.db import models @@ -29,7 +28,7 @@ class AuthenticatorStaticStage(ConfigurableStage, FriendlyNamedStage, Stage): return AuthenticatorStaticStageSerializer @property - def type(self) -> type[View]: + def view(self) -> type[View]: from authentik.stages.authenticator_static.stage import AuthenticatorStaticStageView return AuthenticatorStaticStageView @@ -38,7 +37,7 @@ class AuthenticatorStaticStage(ConfigurableStage, FriendlyNamedStage, Stage): def component(self) -> str: return "ak-stage-authenticator-static-form" - def ui_user_settings(self) -> Optional[UserSettingSerializer]: + def ui_user_settings(self) -> UserSettingSerializer | None: return UserSettingSerializer( data={ "title": self.friendly_name or str(self._meta.verbose_name), @@ -116,6 +115,13 @@ class StaticToken(models.Model): device = models.ForeignKey(StaticDevice, related_name="token_set", on_delete=models.CASCADE) token = models.CharField(max_length=16, db_index=True) + class Meta: + verbose_name = _("Static Token") + verbose_name_plural = _("Static Tokens") + + def __str__(self) -> str: + return "Static Token" + @staticmethod def random_token(): """ @@ -125,7 +131,3 @@ class StaticToken(models.Model): """ return b32encode(urandom(5)).decode("utf-8").lower() - - class Meta: - verbose_name = _("Static Token") - verbose_name_plural = _("Static Tokens") diff --git a/authentik/stages/authenticator_totp/models.py b/authentik/stages/authenticator_totp/models.py index 77a76d3fd3..6275af6a03 100644 --- a/authentik/stages/authenticator_totp/models.py +++ b/authentik/stages/authenticator_totp/models.py @@ -3,7 +3,6 @@ import time from base64 import b32encode from binascii import unhexlify -from typing import Optional from urllib.parse import quote, urlencode from django.conf import settings @@ -39,7 +38,7 @@ class AuthenticatorTOTPStage(ConfigurableStage, FriendlyNamedStage, Stage): return AuthenticatorTOTPStageSerializer @property - def type(self) -> type[View]: + def view(self) -> type[View]: from authentik.stages.authenticator_totp.stage import AuthenticatorTOTPStageView return AuthenticatorTOTPStageView @@ -48,7 +47,7 @@ class AuthenticatorTOTPStage(ConfigurableStage, FriendlyNamedStage, Stage): def component(self) -> str: return "ak-stage-authenticator-totp-form" - def ui_user_settings(self) -> Optional[UserSettingSerializer]: + def ui_user_settings(self) -> UserSettingSerializer | None: return UserSettingSerializer( data={ "title": self.friendly_name or str(self._meta.verbose_name), @@ -220,16 +219,16 @@ class TOTPDevice(SerializerModel, ThrottlingMixin, Device): issuer = self._read_str_from_settings("OTP_TOTP_ISSUER") if issuer: issuer = issuer.replace(":", "") - label = "{}:{}".format(issuer, label) - urlencoded_params += "&issuer={}".format( - quote(issuer) - ) # encode issuer as per RFC 3986, not quote_plus + label = f"{issuer}:{label}" + urlencoded_params += ( + f"&issuer={quote(issuer)}" # encode issuer as per RFC 3986, not quote_plus + ) image = self._read_str_from_settings("OTP_TOTP_IMAGE") if image: urlencoded_params += "&image={}".format(quote(image, safe=":/")) - url = "otpauth://totp/{}?{}".format(quote(label), urlencoded_params) + url = f"otpauth://totp/{quote(label)}?{urlencoded_params}" return url diff --git a/authentik/stages/authenticator_totp/tests.py b/authentik/stages/authenticator_totp/tests.py index 44c6dc9b6e..b5185d9396 100644 --- a/authentik/stages/authenticator_totp/tests.py +++ b/authentik/stages/authenticator_totp/tests.py @@ -46,7 +46,6 @@ class TOTPDeviceMixin: 784503, ] - # pylint: disable=invalid-name def setUp(self): """ Create a device at the fourth time step. The current token is 154567. diff --git a/authentik/stages/authenticator_validate/challenge.py b/authentik/stages/authenticator_validate/challenge.py index 85f57d1925..a762e44c17 100644 --- a/authentik/stages/authenticator_validate/challenge.py +++ b/authentik/stages/authenticator_validate/challenge.py @@ -1,7 +1,6 @@ """Validation stage challenge checking""" from json import loads -from typing import Optional from urllib.parse import urlencode from django.http import HttpRequest @@ -73,7 +72,7 @@ def get_webauthn_challenge_without_user( def get_webauthn_challenge( - request: HttpRequest, stage: AuthenticatorValidateStage, device: Optional[WebAuthnDevice] = None + request: HttpRequest, stage: AuthenticatorValidateStage, device: WebAuthnDevice | None = None ) -> dict: """Send the client a challenge that we'll check later""" request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None) @@ -192,10 +191,11 @@ def validate_challenge_duo(device_pk: int, stage_view: StageView, user: User) -> user_id=device.duo_user_id, ipaddr=ClientIPMiddleware.get_client_ip(stage_view.request), type=__( - "%(brand_name)s Login request" - % { - "brand_name": stage_view.request.brand.branding_title, - } + "{brand_name} Login request".format_map( + { + "brand_name": stage_view.request.brand.branding_title, + } + ) ), display_username=user.username, device="auto", @@ -220,4 +220,4 @@ def validate_challenge_duo(device_pk: int, stage_view: StageView, user: User) -> message=f"Failed to DUO authenticate user: {str(exc)}", user=user, ).from_http(stage_view.request, user) - raise ValidationError("Duo denied access", code="denied") + raise ValidationError("Duo denied access", code="denied") from exc diff --git a/authentik/stages/authenticator_validate/models.py b/authentik/stages/authenticator_validate/models.py index 45bf071b52..3b2c6b2cfd 100644 --- a/authentik/stages/authenticator_validate/models.py +++ b/authentik/stages/authenticator_validate/models.py @@ -79,7 +79,7 @@ class AuthenticatorValidateStage(Stage): return AuthenticatorValidateStageSerializer @property - def type(self) -> type[View]: + def view(self) -> type[View]: from authentik.stages.authenticator_validate.stage import AuthenticatorValidateStageView return AuthenticatorValidateStageView diff --git a/authentik/stages/authenticator_validate/stage.py b/authentik/stages/authenticator_validate/stage.py index f7f62a6d1b..bf83187f07 100644 --- a/authentik/stages/authenticator_validate/stage.py +++ b/authentik/stages/authenticator_validate/stage.py @@ -2,7 +2,6 @@ from datetime import datetime from hashlib import sha256 -from typing import Optional from django.conf import settings from django.http import HttpRequest, HttpResponse @@ -63,7 +62,7 @@ class AuthenticatorValidationChallenge(WithUserInfoChallenge): class AuthenticatorValidationChallengeResponse(ChallengeResponse): """Challenge used for Code-based and WebAuthn authenticators""" - device: Optional[Device] + device: Device | None selected_challenge = DeviceChallenge(required=False) selected_stage = CharField(required=False) @@ -222,8 +221,7 @@ class AuthenticatorValidateStageView(ChallengeStageView): challenge.is_valid() return [challenge.data] - # pylint: disable=too-many-return-statements - def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: + def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: # noqa: PLR0911 """Check if a user is set, and check if the user has any devices if not, we can skip this entire stage""" user = self.get_pending_user() diff --git a/authentik/stages/authenticator_validate/tests/test_duo.py b/authentik/stages/authenticator_validate/tests/test_duo.py index 38cca7811a..bc76442c8b 100644 --- a/authentik/stages/authenticator_validate/tests/test_duo.py +++ b/authentik/stages/authenticator_validate/tests/test_duo.py @@ -37,7 +37,7 @@ class AuthenticatorValidateStageDuoTests(FlowTestCase): middleware = SessionMiddleware(dummy_get_response) middleware.process_request(request) request.session.save() - setattr(request, "brand", get_brand_for_request(request)) + request.brand = get_brand_for_request(request) stage = AuthenticatorDuoStage.objects.create( name=generate_id(), diff --git a/authentik/stages/authenticator_webauthn/models.py b/authentik/stages/authenticator_webauthn/models.py index 6d7551837e..72a5e846e8 100644 --- a/authentik/stages/authenticator_webauthn/models.py +++ b/authentik/stages/authenticator_webauthn/models.py @@ -1,7 +1,5 @@ """WebAuthn stage""" -from typing import Optional - from django.contrib.auth import get_user_model from django.db import models from django.utils.timezone import now @@ -78,7 +76,7 @@ class AuthenticateWebAuthnStage(ConfigurableStage, FriendlyNamedStage, Stage): choices=ResidentKeyRequirement.choices, default=ResidentKeyRequirement.PREFERRED, ) - authenticator_attachment = models.TextField( + authenticator_attachment = models.TextField( # noqa: DJ001 choices=AuthenticatorAttachment.choices, default=None, null=True ) @@ -89,7 +87,7 @@ class AuthenticateWebAuthnStage(ConfigurableStage, FriendlyNamedStage, Stage): return AuthenticateWebAuthnStageSerializer @property - def type(self) -> type[View]: + def view(self) -> type[View]: from authentik.stages.authenticator_webauthn.stage import AuthenticatorWebAuthnStageView return AuthenticatorWebAuthnStageView @@ -98,7 +96,7 @@ class AuthenticateWebAuthnStage(ConfigurableStage, FriendlyNamedStage, Stage): def component(self) -> str: return "ak-stage-authenticator-webauthn-form" - def ui_user_settings(self) -> Optional[UserSettingSerializer]: + def ui_user_settings(self) -> UserSettingSerializer | None: return UserSettingSerializer( data={ "title": self.friendly_name or str(self._meta.verbose_name), diff --git a/authentik/stages/authenticator_webauthn/stage.py b/authentik/stages/authenticator_webauthn/stage.py index 836e05df08..f78c637278 100644 --- a/authentik/stages/authenticator_webauthn/stage.py +++ b/authentik/stages/authenticator_webauthn/stage.py @@ -65,7 +65,7 @@ class AuthenticatorWebAuthnChallengeResponse(ChallengeResponse): ) except InvalidRegistrationResponse as exc: self.stage.logger.warning("registration failed", exc=exc) - raise ValidationError(f"Registration failed. Error: {exc}") + raise ValidationError(f"Registration failed. Error: {exc}") from None credential_id_exists = WebAuthnDevice.objects.filter( credential_id=bytes_to_base64url(registration.credential_id) diff --git a/authentik/stages/authenticator_webauthn/tests.py b/authentik/stages/authenticator_webauthn/tests.py index ecba7ee7ed..20c8d03e4b 100644 --- a/authentik/stages/authenticator_webauthn/tests.py +++ b/authentik/stages/authenticator_webauthn/tests.py @@ -47,10 +47,8 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase): session = self.client.session session[SESSION_KEY_PLAN] = plan session[SESSION_KEY_WEBAUTHN_CHALLENGE] = b64decode( - ( - "o90Yh1osqW3mjGift+6WclWOya5lcdff/G0mqueN3hChacMUz" - "V4mxiDafuQ0x0e1d/fcPai0fx/jMBZ8/nG2qQ==" - ).encode() + b"o90Yh1osqW3mjGift+6WclWOya5lcdff/G0mqueN3hChacMUz" + b"V4mxiDafuQ0x0e1d/fcPai0fx/jMBZ8/nG2qQ==" ) session.save() diff --git a/authentik/stages/captcha/models.py b/authentik/stages/captcha/models.py index 7ef2b65472..e0e126b056 100644 --- a/authentik/stages/captcha/models.py +++ b/authentik/stages/captcha/models.py @@ -24,7 +24,7 @@ class CaptchaStage(Stage): return CaptchaStageSerializer @property - def type(self) -> type[View]: + def view(self) -> type[View]: from authentik.stages.captcha.stage import CaptchaStageView return CaptchaStageView diff --git a/authentik/stages/consent/models.py b/authentik/stages/consent/models.py index bf042cccfd..385aaf8921 100644 --- a/authentik/stages/consent/models.py +++ b/authentik/stages/consent/models.py @@ -37,7 +37,7 @@ class ConsentStage(Stage): return ConsentStageSerializer @property - def type(self) -> type[View]: + def view(self) -> type[View]: from authentik.stages.consent.stage import ConsentStageView return ConsentStageView diff --git a/authentik/stages/consent/stage.py b/authentik/stages/consent/stage.py index eda0d03f60..7fbc9283fa 100644 --- a/authentik/stages/consent/stage.py +++ b/authentik/stages/consent/stage.py @@ -1,6 +1,5 @@ """authentik consent stage""" -from typing import Optional from uuid import uuid4 from django.http import HttpRequest, HttpResponse @@ -99,7 +98,7 @@ class ConsentStageView(ChallengeStageView): if PLAN_CONTEXT_PENDING_USER in self.executor.plan.context: user = self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] - consent: Optional[UserConsent] = UserConsent.filter_not_expired( + consent: UserConsent | None = UserConsent.filter_not_expired( user=user, application=application ).first() self.executor.plan.context[PLAN_CONTEXT_CONSENT] = consent diff --git a/authentik/stages/consent/tests.py b/authentik/stages/consent/tests.py index 09e1541b85..f67496482a 100644 --- a/authentik/stages/consent/tests.py +++ b/authentik/stages/consent/tests.py @@ -54,7 +54,7 @@ class TestConsentStage(FlowTestCase): "token": session[SESSION_KEY_CONSENT_TOKEN], }, ) - # pylint: disable=no-member + self.assertEqual(response.status_code, 200) self.assertStageRedirects(response, reverse("authentik_core:root-redirect")) self.assertFalse(UserConsent.objects.filter(user=self.user).exists()) diff --git a/authentik/stages/deny/models.py b/authentik/stages/deny/models.py index b57776966f..49fcb035d1 100644 --- a/authentik/stages/deny/models.py +++ b/authentik/stages/deny/models.py @@ -20,7 +20,7 @@ class DenyStage(Stage): return DenyStageSerializer @property - def type(self) -> type[View]: + def view(self) -> type[View]: from authentik.stages.deny.stage import DenyStageView return DenyStageView diff --git a/authentik/stages/dummy/models.py b/authentik/stages/dummy/models.py index cfd14b914f..e215ede196 100644 --- a/authentik/stages/dummy/models.py +++ b/authentik/stages/dummy/models.py @@ -22,7 +22,7 @@ class DummyStage(Stage): return DummyStageSerializer @property - def type(self) -> type[View]: + def view(self) -> type[View]: from authentik.stages.dummy.stage import DummyStageView return DummyStageView diff --git a/authentik/stages/email/models.py b/authentik/stages/email/models.py index 0a0f488db1..a695b861b7 100644 --- a/authentik/stages/email/models.py +++ b/authentik/stages/email/models.py @@ -2,7 +2,6 @@ from os import R_OK, access from pathlib import Path -from typing import Type from django.conf import settings from django.core.mail.backends.base import BaseEmailBackend @@ -88,7 +87,7 @@ class EmailStage(Stage): return EmailStageSerializer @property - def type(self) -> type[View]: + def view(self) -> type[View]: from authentik.stages.email.stage import EmailStageView return EmailStageView @@ -98,7 +97,7 @@ class EmailStage(Stage): return "ak-stage-email-form" @property - def backend_class(self) -> Type[BaseEmailBackend]: + def backend_class(self) -> type[BaseEmailBackend]: """Get the email backend class to use""" return EmailBackend diff --git a/authentik/stages/email/tasks.py b/authentik/stages/email/tasks.py index 2a8c0eb4c9..8a6089ea11 100644 --- a/authentik/stages/email/tasks.py +++ b/authentik/stages/email/tasks.py @@ -2,7 +2,7 @@ from email.utils import make_msgid from smtplib import SMTPException -from typing import Any, Optional +from typing import Any from celery import group from django.core.mail import EmailMultiAlternatives @@ -47,7 +47,7 @@ def get_email_body(email: EmailMultiAlternatives) -> str: retry_backoff=True, base=SystemTask, ) -def send_mail(self: SystemTask, message: dict[Any, Any], email_stage_pk: Optional[str] = None): +def send_mail(self: SystemTask, message: dict[Any, Any], email_stage_pk: str | None = None): """Send Email for Email Stage. Retries are scheduled automatically.""" self.save_on_success = False message_id = make_msgid(domain=DNS_NAME) diff --git a/authentik/stages/email/utils.py b/authentik/stages/email/utils.py index a086250596..40373cdbaa 100644 --- a/authentik/stages/email/utils.py +++ b/authentik/stages/email/utils.py @@ -10,7 +10,7 @@ from django.template.loader import render_to_string from django.utils import translation -@lru_cache() +@lru_cache def logo_data() -> MIMEImage: """Get logo as MIME Image for emails""" path = Path("web/icons/icon_left_brand.png") diff --git a/authentik/stages/identification/models.py b/authentik/stages/identification/models.py index 8b8b3d1fda..323b93ebf1 100644 --- a/authentik/stages/identification/models.py +++ b/authentik/stages/identification/models.py @@ -102,7 +102,7 @@ class IdentificationStage(Stage): return IdentificationStageSerializer @property - def type(self) -> type[View]: + def view(self) -> type[View]: from authentik.stages.identification.stage import IdentificationStageView return IdentificationStageView diff --git a/authentik/stages/identification/stage.py b/authentik/stages/identification/stage.py index 2e5de86f8b..130442d2c9 100644 --- a/authentik/stages/identification/stage.py +++ b/authentik/stages/identification/stage.py @@ -3,7 +3,7 @@ from dataclasses import asdict from random import SystemRandom from time import sleep -from typing import Any, Optional +from typing import Any from django.core.exceptions import PermissionDenied from django.db.models import Q @@ -84,7 +84,7 @@ class IdentificationChallengeResponse(ChallengeResponse): password = CharField(required=False, allow_blank=True, allow_null=True) component = CharField(default="ak-stage-identification") - pre_user: Optional[User] = None + pre_user: User | None = None def validate(self, attrs: dict[str, Any]) -> dict[str, Any]: """Validate that user exists, and optionally their password""" @@ -159,7 +159,7 @@ class IdentificationStageView(ChallengeStageView): response_class = IdentificationChallengeResponse - def get_user(self, uid_value: str) -> Optional[User]: + def get_user(self, uid_value: str) -> User | None: """Find user instance. Returns None if no user was found.""" current_stage: IdentificationStage = self.executor.current_stage query = Q() diff --git a/authentik/stages/invitation/models.py b/authentik/stages/invitation/models.py index 4296031a9a..831effcedc 100644 --- a/authentik/stages/invitation/models.py +++ b/authentik/stages/invitation/models.py @@ -32,7 +32,7 @@ class InvitationStage(Stage): return InvitationStageSerializer @property - def type(self) -> type[View]: + def view(self) -> type[View]: from authentik.stages.invitation.stage import InvitationStageView return InvitationStageView diff --git a/authentik/stages/invitation/stage.py b/authentik/stages/invitation/stage.py index 2527727ff5..3a9cc5746c 100644 --- a/authentik/stages/invitation/stage.py +++ b/authentik/stages/invitation/stage.py @@ -1,7 +1,5 @@ """invitation stage logic""" -from typing import Optional - from deepmerge import always_merger from django.core.exceptions import ValidationError from django.http import HttpRequest, HttpResponse @@ -22,7 +20,7 @@ INVITATION = "invitation" class InvitationStageView(StageView): """Finalise Authentication flow by logging the user in""" - def get_token(self) -> Optional[str]: + def get_token(self) -> str | None: """Get token from saved get-arguments or prompt_data""" # Check for ?token= and ?itoken= if INVITATION_TOKEN_KEY in self.request.session.get(SESSION_KEY_GET, {}): @@ -34,7 +32,7 @@ class InvitationStageView(StageView): return self.executor.plan.context[PLAN_CONTEXT_PROMPT][INVITATION_TOKEN_KEY_CONTEXT] return None - def get_invite(self) -> Optional[Invitation]: + def get_invite(self) -> Invitation | None: """Check the token, find the invite and check it's flow""" token = self.get_token() if not token: diff --git a/authentik/stages/password/models.py b/authentik/stages/password/models.py index 56542f3e49..9887b68543 100644 --- a/authentik/stages/password/models.py +++ b/authentik/stages/password/models.py @@ -1,7 +1,5 @@ """password stage models""" -from typing import Optional - from django.contrib.postgres.fields import ArrayField from django.db import models from django.utils.translation import gettext_lazy as _ @@ -53,7 +51,7 @@ class PasswordStage(ConfigurableStage, Stage): return PasswordStageSerializer @property - def type(self) -> type[View]: + def view(self) -> type[View]: from authentik.stages.password.stage import PasswordStageView return PasswordStageView @@ -62,7 +60,7 @@ class PasswordStage(ConfigurableStage, Stage): def component(self) -> str: return "ak-stage-password-form" - def ui_user_settings(self) -> Optional[UserSettingSerializer]: + def ui_user_settings(self) -> UserSettingSerializer | None: if not self.configure_flow: return None return UserSettingSerializer( diff --git a/authentik/stages/password/stage.py b/authentik/stages/password/stage.py index 964652a080..18d78383fb 100644 --- a/authentik/stages/password/stage.py +++ b/authentik/stages/password/stage.py @@ -1,6 +1,6 @@ """authentik password stage""" -from typing import Any, Optional +from typing import Any from django.contrib.auth import _clean_credentials from django.contrib.auth.backends import BaseBackend @@ -36,8 +36,8 @@ SESSION_KEY_INVALID_TRIES = "authentik/stages/password/user_invalid_tries" def authenticate( - request: HttpRequest, backends: list[str], stage: Optional[Stage] = None, **credentials: Any -) -> Optional[User]: + request: HttpRequest, backends: list[str], stage: Stage | None = None, **credentials: Any +) -> User | None: """If the given credentials are valid, return a User object. Customized version of django's authenticate, which accepts a list of backends""" diff --git a/authentik/stages/prompt/models.py b/authentik/stages/prompt/models.py index c30de6a3e1..21776b6b33 100644 --- a/authentik/stages/prompt/models.py +++ b/authentik/stages/prompt/models.py @@ -1,6 +1,6 @@ """prompt models""" -from typing import Any, Optional, Type +from typing import Any, Type # noqa: UP035 from urllib.parse import urlparse, urlunparse from uuid import uuid4 @@ -143,7 +143,7 @@ class Prompt(SerializerModel): initial_value_expression = models.BooleanField(default=False) @property - def serializer(self) -> Type[BaseSerializer]: + def serializer(self) -> Type[BaseSerializer]: # noqa: UP006 from authentik.stages.prompt.api import PromptSerializer return PromptSerializer @@ -153,8 +153,8 @@ class Prompt(SerializerModel): prompt_context: dict, user: User, request: HttpRequest, - dry_run: Optional[bool] = False, - ) -> Optional[tuple[dict[str, Any]]]: + dry_run: bool | None = False, + ) -> tuple[dict[str, Any]] | None: """Get fully interpolated list of choices""" if self.type not in CHOICE_FIELDS: return None @@ -178,7 +178,7 @@ class Prompt(SerializerModel): if dry_run: raise wrapped from exc - if isinstance(raw_choices, (list, tuple, set)): + if isinstance(raw_choices, list | tuple | set): choices = raw_choices else: choices = [raw_choices] @@ -193,7 +193,7 @@ class Prompt(SerializerModel): prompt_context: dict, user: User, request: HttpRequest, - dry_run: Optional[bool] = False, + dry_run: bool | None = False, ) -> str: """Get fully interpolated placeholder""" if self.type in CHOICE_FIELDS: @@ -222,7 +222,7 @@ class Prompt(SerializerModel): prompt_context: dict, user: User, request: HttpRequest, - dry_run: Optional[bool] = False, + dry_run: bool | None = False, ) -> str: """Get fully interpolated initial value""" @@ -258,50 +258,52 @@ class Prompt(SerializerModel): return value - def field(self, default: Optional[Any], choices: Optional[list[Any]] = None) -> CharField: + def field(self, default: Any | None, choices: list[Any] | None = None) -> CharField: """Get field type for Challenge and response. Choices are only valid for CHOICE_FIELDS.""" field_class = CharField kwargs = { "required": self.required, } - if self.type in (FieldTypes.TEXT, FieldTypes.TEXT_AREA): - kwargs["trim_whitespace"] = False - kwargs["allow_blank"] = not self.required - if self.type in (FieldTypes.TEXT_READ_ONLY, FieldTypes.TEXT_AREA_READ_ONLY): - field_class = ReadOnlyField - # required can't be set for ReadOnlyField - kwargs["required"] = False - if self.type == FieldTypes.EMAIL: - field_class = EmailField - kwargs["allow_blank"] = not self.required - if self.type == FieldTypes.NUMBER: - field_class = IntegerField - if self.type == FieldTypes.CHECKBOX: - field_class = BooleanField - kwargs["required"] = False + match self.type: + case FieldTypes.TEXT | FieldTypes.TEXT_AREA: + kwargs["trim_whitespace"] = False + kwargs["allow_blank"] = not self.required + case FieldTypes.TEXT_READ_ONLY, FieldTypes.TEXT_AREA_READ_ONLY: + field_class = ReadOnlyField + # required can't be set for ReadOnlyField + kwargs["required"] = False + case FieldTypes.EMAIL: + field_class = EmailField + kwargs["allow_blank"] = not self.required + case FieldTypes.NUMBER: + field_class = IntegerField + case FieldTypes.CHECKBOX: + field_class = BooleanField + kwargs["required"] = False + case FieldTypes.DATE: + field_class = DateField + case FieldTypes.DATE_TIME: + field_class = DateTimeField + case FieldTypes.FILE: + field_class = InlineFileField + case FieldTypes.SEPARATOR: + kwargs["required"] = False + kwargs["label"] = "" + case FieldTypes.HIDDEN: + field_class = HiddenField + kwargs["required"] = False + kwargs["default"] = self.placeholder + case FieldTypes.STATIC: + kwargs["default"] = self.placeholder + kwargs["required"] = False + kwargs["label"] = "" + + case FieldTypes.AK_LOCALE: + kwargs["allow_blank"] = True + if self.type in CHOICE_FIELDS: field_class = ChoiceField kwargs["choices"] = choices or [] - if self.type == FieldTypes.DATE: - field_class = DateField - if self.type == FieldTypes.DATE_TIME: - field_class = DateTimeField - if self.type == FieldTypes.FILE: - field_class = InlineFileField - if self.type == FieldTypes.SEPARATOR: - kwargs["required"] = False - kwargs["label"] = "" - if self.type == FieldTypes.HIDDEN: - field_class = HiddenField - kwargs["required"] = False - kwargs["default"] = self.placeholder - if self.type == FieldTypes.STATIC: - kwargs["default"] = self.placeholder - kwargs["required"] = False - kwargs["label"] = "" - - if self.type == FieldTypes.AK_LOCALE: - kwargs["allow_blank"] = True if default: kwargs["default"] = default @@ -337,7 +339,7 @@ class PromptStage(Stage): return PromptStageSerializer @property - def type(self) -> type[View]: + def view(self) -> type[View]: from authentik.stages.prompt.stage import PromptStageView return PromptStageView diff --git a/authentik/stages/prompt/stage.py b/authentik/stages/prompt/stage.py index 63733bf83d..244cbc1732 100644 --- a/authentik/stages/prompt/stage.py +++ b/authentik/stages/prompt/stage.py @@ -1,8 +1,9 @@ """Prompt Stage Logic""" +from collections.abc import Callable, Iterator from email.policy import Policy from types import MethodType -from typing import Any, Callable, Iterator +from typing import Any from django.db.models.query import QuerySet from django.http import HttpRequest, HttpResponse @@ -131,7 +132,7 @@ class PromptChallengeResponse(ChallengeResponse): password_fields: QuerySet[Prompt] = self.stage_instance.fields.filter( type=FieldTypes.PASSWORD ) - if password_fields.exists() and password_fields.count() == 2: + if password_fields.exists() and password_fields.count() == 2: # noqa: PLR2004 self._validate_password_fields(*[field.field_key for field in password_fields]) engine = ListPolicyEngine( @@ -152,7 +153,7 @@ class PromptChallengeResponse(ChallengeResponse): def username_field_validator_factory() -> Callable[[PromptChallenge, str], Any]: """Return a `clean_` method for `field`. Clean method checks if username is taken already.""" - def username_field_validator(self: PromptChallenge, value: str) -> Any: + def username_field_validator(_: PromptChallenge, value: str) -> Any: """Check for duplicate usernames""" if User.objects.filter(username=value).exists(): raise ValidationError("Username is already taken.") diff --git a/authentik/stages/prompt/tests.py b/authentik/stages/prompt/tests.py index a52d80a0ab..564e573e8d 100644 --- a/authentik/stages/prompt/tests.py +++ b/authentik/stages/prompt/tests.py @@ -23,7 +23,6 @@ from authentik.stages.prompt.stage import ( ) -# pylint: disable=too-many-public-methods class TestPromptStage(FlowTestCase): """Prompt tests""" diff --git a/authentik/stages/user_delete/models.py b/authentik/stages/user_delete/models.py index 154e4f9ae4..be7aec4d35 100644 --- a/authentik/stages/user_delete/models.py +++ b/authentik/stages/user_delete/models.py @@ -18,7 +18,7 @@ class UserDeleteStage(Stage): return UserDeleteStageSerializer @property - def type(self) -> type[View]: + def view(self) -> type[View]: from authentik.stages.user_delete.stage import UserDeleteStageView return UserDeleteStageView diff --git a/authentik/stages/user_login/middleware.py b/authentik/stages/user_login/middleware.py index 1a70b412b4..649c4a0c63 100644 --- a/authentik/stages/user_login/middleware.py +++ b/authentik/stages/user_login/middleware.py @@ -23,8 +23,7 @@ LOGGER = get_logger() class SessionBindingBroken(SentryIgnoredException): """Session binding was broken due to specified `reason`""" - # pylint: disable=too-many-arguments - def __init__( + def __init__( # noqa: PLR0913 self, reason: str, old_value: str, new_value: str, old_ip: str, new_ip: str ) -> None: self.reason = reason diff --git a/authentik/stages/user_login/models.py b/authentik/stages/user_login/models.py index 27bb6f3bf8..025e960763 100644 --- a/authentik/stages/user_login/models.py +++ b/authentik/stages/user_login/models.py @@ -71,7 +71,7 @@ class UserLoginStage(Stage): return UserLoginStageSerializer @property - def type(self) -> type[View]: + def view(self) -> type[View]: from authentik.stages.user_login.stage import UserLoginStageView return UserLoginStageView diff --git a/authentik/stages/user_logout/models.py b/authentik/stages/user_logout/models.py index 76e6f0db39..92fd36c99a 100644 --- a/authentik/stages/user_logout/models.py +++ b/authentik/stages/user_logout/models.py @@ -17,7 +17,7 @@ class UserLogoutStage(Stage): return UserLogoutStageSerializer @property - def type(self) -> type[View]: + def view(self) -> type[View]: from authentik.stages.user_logout.stage import UserLogoutStageView return UserLogoutStageView diff --git a/authentik/stages/user_logout/tests.py b/authentik/stages/user_logout/tests.py index 8f9035070b..c571354d62 100644 --- a/authentik/stages/user_logout/tests.py +++ b/authentik/stages/user_logout/tests.py @@ -37,7 +37,6 @@ class TestUserLogoutStage(FlowTestCase): reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}) ) - # pylint: disable=no-member self.assertEqual(response.status_code, 200) self.assertStageRedirects(response, reverse("authentik_core:root-redirect")) @@ -54,6 +53,5 @@ class TestUserLogoutStage(FlowTestCase): reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}) ) - # pylint: disable=no-member self.assertEqual(response.status_code, 200) self.assertStageRedirects(response, reverse("authentik_core:root-redirect")) diff --git a/authentik/stages/user_write/models.py b/authentik/stages/user_write/models.py index eb1dfd2c77..a4e94a71b9 100644 --- a/authentik/stages/user_write/models.py +++ b/authentik/stages/user_write/models.py @@ -55,7 +55,7 @@ class UserWriteStage(Stage): return UserWriteStageSerializer @property - def type(self) -> type[View]: + def view(self) -> type[View]: from authentik.stages.user_write.stage import UserWriteStageView return UserWriteStageView diff --git a/authentik/stages/user_write/stage.py b/authentik/stages/user_write/stage.py index 348509e768..3e8950c293 100644 --- a/authentik/stages/user_write/stage.py +++ b/authentik/stages/user_write/stage.py @@ -1,6 +1,6 @@ """Write stage logic""" -from typing import Any, Optional +from typing import Any from django.contrib.auth import update_session_auth_hash from django.db import transaction @@ -49,7 +49,7 @@ class UserWriteStageView(StageView): parts = parts[1:] set_path_in_dict(user.attributes, ".".join(parts), value) - def ensure_user(self) -> tuple[Optional[User], bool]: + def ensure_user(self) -> tuple[User | None, bool]: """Ensure a user exists""" user_created = False path = self.executor.plan.context.get( diff --git a/authentik/tenants/management/__init__.py b/authentik/tenants/management/__init__.py index a666aad85b..2f59bac740 100644 --- a/authentik/tenants/management/__init__.py +++ b/authentik/tenants/management/__init__.py @@ -25,7 +25,7 @@ class TenantCommand(BaseCommand): def handle(self, *args, **options): verbosity = int(options.get("verbosity")) - # pylint: disable=no-member + schema_name = options["schema_name"] or self.schema_name connection.set_schema_to_public() if verbosity >= 1: diff --git a/lifecycle/gunicorn.conf.py b/lifecycle/gunicorn.conf.py index eb32cacfe4..e608c94413 100644 --- a/lifecycle/gunicorn.conf.py +++ b/lifecycle/gunicorn.conf.py @@ -121,7 +121,7 @@ if not CONFIG.get_bool("disable_startup_analytics", False): }, timeout=5, ) - # pylint: disable=broad-exception-caught + except Exception: # nosec pass diff --git a/lifecycle/migrate.py b/lifecycle/migrate.py index 489ffa909a..6dceb6b007 100755 --- a/lifecycle/migrate.py +++ b/lifecycle/migrate.py @@ -55,8 +55,8 @@ def wait_for_lock(cursor: Cursor): """lock an advisory lock to prevent multiple instances from migrating at once""" LOGGER.info("waiting to acquire database lock") cursor.execute("SELECT pg_advisory_lock(%s)", (ADV_LOCK_UID,)) - # pylint: disable=global-statement - global LOCKED + + global LOCKED # noqa: PLW0603 LOCKED = True diff --git a/locale/en/LC_MESSAGES/django.po b/locale/en/LC_MESSAGES/django.po index effd6f223a..a9ffa90cb4 100644 --- a/locale/en/LC_MESSAGES/django.po +++ b/locale/en/LC_MESSAGES/django.po @@ -8,7 +8,7 @@ msgid "" msgstr "" "Project-Id-Version: PACKAGE VERSION\n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2024-02-14 18:51+0000\n" +"POT-Creation-Date: 2024-02-19 23:33+0000\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "Last-Translator: FULL NAME \n" "Language-Team: LANGUAGE \n" @@ -31,8 +31,8 @@ msgid "Blueprint file does not exist" msgstr "" #: authentik/blueprints/api.py -#, python-format -msgid "Failed to validate blueprint: %(logs)s" +#, python-brace-format +msgid "Failed to validate blueprint: {logs}" msgstr "" #: authentik/blueprints/api.py @@ -60,8 +60,8 @@ msgid "Blueprint Instances" msgstr "" #: authentik/blueprints/v1/exporter.py -#, python-format -msgid "authentik Export - %(date)s" +#, python-brace-format +msgid "authentik Export - {date}" msgstr "" #: authentik/blueprints/v1/tasks.py authentik/crypto/tasks.py @@ -271,9 +271,9 @@ msgid "Authenticated Sessions" msgstr "" #: authentik/core/sources/flow_manager.py -#, python-format +#, python-brace-format msgid "" -"Request to authenticate with %(source)s has been denied. Please authenticate " +"Request to authenticate with {source} has been denied. Please authenticate " "with the source you've previously signed up with." msgstr "" @@ -282,13 +282,13 @@ msgid "Configured flow does not exist." msgstr "" #: authentik/core/sources/flow_manager.py -#, python-format -msgid "Successfully authenticated with %(source)s!" +#, python-brace-format +msgid "Successfully authenticated with {source}!" msgstr "" #: authentik/core/sources/flow_manager.py -#, python-format -msgid "Successfully linked %(source)s!" +#, python-brace-format +msgid "Successfully linked {source}!" msgstr "" #: authentik/core/sources/flow_manager.py @@ -453,8 +453,8 @@ msgid "(You are already connected in another tab/window)" msgstr "" #: authentik/events/api/tasks.py -#, python-format -msgid "Successfully started task %(name)s." +#, python-brace-format +msgid "Successfully started task {name}." msgstr "" #: authentik/events/models.py @@ -576,18 +576,18 @@ msgid "Task has not been run yet." msgstr "" #: authentik/flows/api/flows.py -#, python-format -msgid "Flow not applicable to current user/request: %(messages)s" +#, python-brace-format +msgid "Flow not applicable to current user/request: {messages}" msgstr "" #: authentik/flows/api/flows_diagram.py -#, python-format -msgid "Policy (%(type)s)" +#, python-brace-format +msgid "Policy ({type})" msgstr "" #: authentik/flows/api/flows_diagram.py -#, python-format -msgid "Binding %(order)d" +#, python-brace-format +msgid "Binding {order}" msgstr "" #: authentik/flows/api/flows_diagram.py @@ -595,8 +595,8 @@ msgid "Policy passed" msgstr "" #: authentik/flows/api/flows_diagram.py -#, python-format -msgid "Stage (%(type)s)" +#, python-brace-format +msgid "Stage ({type})" msgstr "" #: authentik/flows/api/flows_diagram.py @@ -632,8 +632,8 @@ msgid "Flow does not apply to current user." msgstr "" #: authentik/flows/models.py -#, python-format -msgid "Dynamic In-memory stage: %(doc)s" +#, python-brace-format +msgid "Dynamic In-memory stage: {doc}" msgstr "" #: authentik/flows/models.py @@ -1317,8 +1317,8 @@ msgstr "" #: authentik/providers/oauth2/views/authorize.py #: authentik/providers/saml/views/flows.py -#, python-format -msgid "Redirecting to %(app)s..." +#, python-brace-format +msgid "Redirecting to {app}..." msgstr "" #: authentik/providers/oauth2/views/device_init.py @@ -1435,8 +1435,8 @@ msgid "Invalid XML Syntax" msgstr "" #: authentik/providers/saml/api/providers.py -#, python-format -msgid "Failed to import Metadata: %(message)s" +#, python-brace-format +msgid "Failed to import Metadata: {messages}" msgstr "" #: authentik/providers/saml/models.py @@ -1608,18 +1608,18 @@ msgid "Syncing page %(page)d of groups" msgstr "" #: authentik/providers/scim/tasks.py -#, python-format -msgid "Failed to sync user %(user_name)s due to remote error: %(error)s" +#, python-brace-format +msgid "Failed to sync user {user_name} due to remote error: {error}" msgstr "" #: authentik/providers/scim/tasks.py -#, python-format -msgid "Stopping sync due to error: %(error)s" +#, python-brace-format +msgid "Stopping sync due to error: {error}" msgstr "" #: authentik/providers/scim/tasks.py -#, python-format -msgid "Failed to sync group %(group_name)s due to remote error: %(error)s" +#, python-brace-format +msgid "Failed to sync group {group_name} due to remote error: {error}" msgstr "" #: authentik/rbac/models.py @@ -1925,8 +1925,8 @@ msgid "User OAuth Source Connections" msgstr "" #: authentik/sources/oauth/views/callback.py -#, python-format -msgid "Authentication failed: %(reason)s" +#, python-brace-format +msgid "Authentication failed: {reason}" msgstr "" #: authentik/sources/plex/models.py @@ -2073,8 +2073,8 @@ msgid "Optionally modify the payload being sent to custom providers." msgstr "" #: authentik/stages/authenticator_sms/models.py -#, python-format -msgid "Use this code to authenticate in authentik: %(token)s" +#, python-brace-format +msgid "Use this code to authenticate in authentik: {token}" msgstr "" #: authentik/stages/authenticator_sms/models.py diff --git a/poetry.lock b/poetry.lock index b90552d0a8..1a8a2ac86a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -250,17 +250,6 @@ files = [ {file = "asn1crypto-1.5.1.tar.gz", hash = "sha256:13ae38502be632115abf8a24cbe5f4da52e3b5231990aff31123c805306ccb9c"}, ] -[[package]] -name = "astroid" -version = "3.0.2" -description = "An abstract syntax tree for Python with inference support." -optional = false -python-versions = ">=3.8.0" -files = [ - {file = "astroid-3.0.2-py3-none-any.whl", hash = "sha256:d6e62862355f60e716164082d6b4b041d38e2a8cf1c7cd953ded5108bac8ff5c"}, - {file = "astroid-3.0.2.tar.gz", hash = "sha256:4a61cf0a59097c7bb52689b0fd63717cd2a8a14dc9f1eee97b82d814881c8c91"}, -] - [[package]] name = "attrs" version = "23.2.0" @@ -1130,20 +1119,6 @@ files = [ {file = "defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69"}, ] -[[package]] -name = "dill" -version = "0.3.7" -description = "serialize all of Python" -optional = false -python-versions = ">=3.7" -files = [ - {file = "dill-0.3.7-py3-none-any.whl", hash = "sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e"}, - {file = "dill-0.3.7.tar.gz", hash = "sha256:cc1c8b182eb3013e24bd475ff2e9295af86c1a38eb1aff128dac8962a9ce3c03"}, -] - -[package.extras] -graph = ["objgraph (>=1.7.2)"] - [[package]] name = "django" version = "5.0.2" @@ -1790,20 +1765,6 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] -[[package]] -name = "isort" -version = "5.13.2" -description = "A Python utility / library to sort Python imports." -optional = false -python-versions = ">=3.8.0" -files = [ - {file = "isort-5.13.2-py3-none-any.whl", hash = "sha256:8ca5e72a8d85860d5a3fa69b8745237f2939afe12dbf656afbcb47fe72d947a6"}, - {file = "isort-5.13.2.tar.gz", hash = "sha256:48fdfcb9face5d58a4f6dde2e72a1fb8dcaf8ab26f95ab49fab84c2ddefb0109"}, -] - -[package.extras] -colors = ["colorama (>=0.4.6)"] - [[package]] name = "jinja2" version = "3.1.3" @@ -2421,17 +2382,6 @@ files = [ [package.dependencies] setuptools = ">=68.2.2" -[[package]] -name = "mccabe" -version = "0.7.0" -description = "McCabe checker, plugin for flake8" -optional = false -python-versions = ">=3.6" -files = [ - {file = "mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e"}, - {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, -] - [[package]] name = "mdurl" version = "0.1.2" @@ -3055,62 +3005,6 @@ dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pyte docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"] tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] -[[package]] -name = "pylint" -version = "3.0.3" -description = "python code static checker" -optional = false -python-versions = ">=3.8.0" -files = [ - {file = "pylint-3.0.3-py3-none-any.whl", hash = "sha256:7a1585285aefc5165db81083c3e06363a27448f6b467b3b0f30dbd0ac1f73810"}, - {file = "pylint-3.0.3.tar.gz", hash = "sha256:58c2398b0301e049609a8429789ec6edf3aabe9b6c5fec916acd18639c16de8b"}, -] - -[package.dependencies] -astroid = ">=3.0.1,<=3.1.0-dev0" -colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} -dill = {version = ">=0.3.7", markers = "python_version >= \"3.12\""} -isort = ">=4.2.5,<5.13.0 || >5.13.0,<6" -mccabe = ">=0.6,<0.8" -platformdirs = ">=2.2.0" -tomlkit = ">=0.10.1" - -[package.extras] -spelling = ["pyenchant (>=3.2,<4.0)"] -testutils = ["gitpython (>3)"] - -[[package]] -name = "pylint-django" -version = "2.5.5" -description = "A Pylint plugin to help Pylint understand the Django web framework" -optional = false -python-versions = ">=3.7,<4.0" -files = [ - {file = "pylint_django-2.5.5-py3-none-any.whl", hash = "sha256:5abd5c2228e0e5e2a4cb6d0b4fc1d1cef1e773d0be911412f4dd4fc1a1a440b7"}, - {file = "pylint_django-2.5.5.tar.gz", hash = "sha256:2f339e4bf55776958283395c5139c37700c91bd5ef1d8251ef6ac88b5abbba9b"}, -] - -[package.dependencies] -pylint = ">=2.0,<4" -pylint-plugin-utils = ">=0.8" - -[package.extras] -with-django = ["Django (>=2.2)"] - -[[package]] -name = "pylint-plugin-utils" -version = "0.8.2" -description = "Utilities and helpers for writing Pylint plugins" -optional = false -python-versions = ">=3.7,<4.0" -files = [ - {file = "pylint_plugin_utils-0.8.2-py3-none-any.whl", hash = "sha256:ae11664737aa2effbf26f973a9e0b6779ab7106ec0adc5fe104b0907ca04e507"}, - {file = "pylint_plugin_utils-0.8.2.tar.gz", hash = "sha256:d3cebf68a38ba3fba23a873809155562571386d4c1b03e5b4c4cc26c3eee93e4"}, -] - -[package.dependencies] -pylint = ">=1.7" - [[package]] name = "pynacl" version = "1.5.0" @@ -3867,17 +3761,6 @@ files = [ [package.dependencies] celery = "*" -[[package]] -name = "tomlkit" -version = "0.12.3" -description = "Style preserving TOML library" -optional = false -python-versions = ">=3.7" -files = [ - {file = "tomlkit-0.12.3-py3-none-any.whl", hash = "sha256:b0a645a9156dc7cb5d3a1f0d4bab66db287fcb8e0430bdd4664a095ea16414ba"}, - {file = "tomlkit-0.12.3.tar.gz", hash = "sha256:75baf5012d06501f07bee5bf8e801b9f343e7aac5a92581f20f80ce632e6b5a4"}, -] - [[package]] name = "tornado" version = "6.4" @@ -4667,4 +4550,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "~3.12" -content-hash = "872b759c19aa026742ef493c1b13a5f01dc5baab84efe7385163c4bf428b8f8f" +content-hash = "825f1d552ba34206f7bfd55b70bfb42bc5d769605f59410703828ae787cd0baf" diff --git a/pyproject.toml b/pyproject.toml index 075b4575f7..4ce9ece108 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,19 +1,8 @@ -[tool.pyright] -ignore = ["**/migrations/**", "**/node_modules/**"] -reportMissingTypeStubs = false -strictParameterNoneValue = true -strictDictionaryInference = true -strictListInference = true -reportOptionalMemberAccess = false -reportOptionalContextManager = false -# rest_framework's serializer's `validated_data` is typed as optional None -reportOptionalSubscript = false -# Sadly pyright still has issues with enums, and they fall under general type issues -# so we have to disable those for now -reportGeneralTypeIssues = false -verboseOutput = false -pythonVersion = "3.12" -pythonPlatform = "All" +[tool.poetry] +name = "authentik" +version = "2024.2.1" +description = "" +authors = ["authentik Team "] [tool.black] line-length = 100 @@ -25,14 +14,30 @@ line-length = 100 target-version = "py312" exclude = ["**/migrations/**", "**/node_modules/**"] -[tool.isort] -multi_line_output = 3 -include_trailing_comma = true -force_grid_wrap = 0 -use_parentheses = true -line_length = 100 -src_paths = ["authentik", "tests", "lifecycle"] -force_to_top = "*" +[tool.ruff.lint] +select = [ + # pycodestyle + "E", + # Pyflakes + "F", + # isort + "I", + # pyupgrade + "UP", + # flake8-bugbear + "B", + # django + "DJ", + # pylint + "PL", +] +ignore = [ + "DJ001" # Avoid using `null=True` on string-based fields, +] +[tool.ruff.lint.pylint] +max-args = 7 +max-branches = 18 +max-returns = 10 [tool.coverage.run] source = ["authentik"] @@ -67,40 +72,6 @@ exclude_lines = [ ] show_missing = true -[tool.pylint.basic] -good-names = ["pk", "id", "i", "j", "k", "_", "bar"] - -[tool.pylint.master] -disable = [ - "arguments-differ", - "locally-disabled", - "too-many-ancestors", - "too-few-public-methods", - "import-outside-toplevel", - "signature-differs", - "similarities", - "cyclic-import", - "protected-access", - "unused-argument", - "raise-missing-from", - "fixme", - # To preserve django's translation function we need to use %-formatting - "consider-using-f-string", -] - -load-plugins = ["pylint_django", "pylint.extensions.bad_builtin"] -django-settings-module = "authentik.root.settings" -extension-pkg-whitelist = ["lxml", "xmlsec"] - -# Allow constants to be shorter than normal (and lowercase, for settings.py) -const-rgx = "[a-zA-Z0-9_]{1,40}$" - -ignored-modules = ["binascii", "socket", "zlib"] -generated-members = ["xmlsec.constants.*", "xmlsec.tree.*", "xmlsec.template.*"] -ignore = ["migrations", "tests"] -max-attributes = 12 -max-branches = 20 - [tool.pytest.ini_options] DJANGO_SETTINGS_MODULE = "authentik.root.settings" python_files = ["tests.py", "test_*.py", "*_tests.py"] @@ -111,12 +82,6 @@ filterwarnings = [ "ignore:SelectableGroups dict interface is deprecated. Use select.:DeprecationWarning", ] -[tool.poetry] -name = "authentik" -version = "2024.2.1" -description = "" -authors = ["authentik Team "] - [tool.poetry.dependencies] argon2-cffi = "*" celery = "*" @@ -194,8 +159,6 @@ drf-jsonschema-serializer = "*" freezegun = "*" importlib-metadata = "*" pdoc = "*" -pylint = "*" -pylint-django = "*" pyrad = "*" pytest = "*" pytest-django = "*" diff --git a/tests/e2e/test_provider_ldap.py b/tests/e2e/test_provider_ldap.py index 5a3a8d7f3f..465c6e68fc 100644 --- a/tests/e2e/test_provider_ldap.py +++ b/tests/e2e/test_provider_ldap.py @@ -71,7 +71,7 @@ class TestProviderLDAP(SeleniumTestCase): # Wait until outpost healthcheck succeeds healthcheck_retries = 0 - while healthcheck_retries < 50: + while healthcheck_retries < 50: # noqa: PLR2004 if len(outpost.state) > 0: state = outpost.state[0] if state.last_seen: diff --git a/tests/e2e/test_provider_oauth2_github.py b/tests/e2e/test_provider_oauth2_github.py index 3fbdd7316b..64eecd032d 100644 --- a/tests/e2e/test_provider_oauth2_github.py +++ b/tests/e2e/test_provider_oauth2_github.py @@ -1,7 +1,7 @@ """test OAuth Provider flow""" from time import sleep -from typing import Any, Optional +from typing import Any from docker.types import Healthcheck from selenium.webdriver.common.by import By @@ -25,7 +25,7 @@ class TestProviderOAuth2Github(SeleniumTestCase): self.client_secret = generate_key() super().setUp() - def get_container_specs(self) -> Optional[dict[str, Any]]: + def get_container_specs(self) -> dict[str, Any] | None: """Setup client grafana container which we test OAuth against""" return { "image": "grafana/grafana:7.1.0", diff --git a/tests/e2e/test_provider_oauth2_grafana.py b/tests/e2e/test_provider_oauth2_grafana.py index 2a02f603a4..0bdd167560 100644 --- a/tests/e2e/test_provider_oauth2_grafana.py +++ b/tests/e2e/test_provider_oauth2_grafana.py @@ -1,7 +1,7 @@ """test OAuth2 OpenID Provider flow""" from time import sleep -from typing import Any, Optional +from typing import Any from docker.types import Healthcheck from selenium.webdriver.common.by import By @@ -33,7 +33,7 @@ class TestProviderOAuth2OAuth(SeleniumTestCase): self.app_slug = generate_id(20) super().setUp() - def get_container_specs(self) -> Optional[dict[str, Any]]: + def get_container_specs(self) -> dict[str, Any] | None: return { "image": "grafana/grafana:7.1.0", "detach": True, diff --git a/tests/e2e/test_provider_proxy.py b/tests/e2e/test_provider_proxy.py index 99d306dd90..bb4844dcae 100644 --- a/tests/e2e/test_provider_proxy.py +++ b/tests/e2e/test_provider_proxy.py @@ -4,7 +4,7 @@ from base64 import b64encode from dataclasses import asdict from sys import platform from time import sleep -from typing import Any, Optional +from typing import Any from unittest.case import skip, skipUnless from channels.testing import ChannelsLiveServerTestCase @@ -32,7 +32,7 @@ class TestProviderProxy(SeleniumTestCase): self.output_container_logs(self.proxy_container) self.proxy_container.kill() - def get_container_specs(self) -> Optional[dict[str, Any]]: + def get_container_specs(self) -> dict[str, Any] | None: return { "image": "traefik/whoami:latest", "detach": True, @@ -101,7 +101,7 @@ class TestProviderProxy(SeleniumTestCase): # Wait until outpost healthcheck succeeds healthcheck_retries = 0 - while healthcheck_retries < 50: + while healthcheck_retries < 50: # noqa: PLR2004 if len(outpost.state) > 0: state = outpost.state[0] if state.last_seen: @@ -171,7 +171,7 @@ class TestProviderProxy(SeleniumTestCase): # Wait until outpost healthcheck succeeds healthcheck_retries = 0 - while healthcheck_retries < 50: + while healthcheck_retries < 50: # noqa: PLR2004 if len(outpost.state) > 0: state = outpost.state[0] if state.last_seen: @@ -212,7 +212,7 @@ class TestProviderProxyConnect(ChannelsLiveServerTestCase): @reconcile_app("authentik_crypto") def test_proxy_connectivity(self): """Test proxy connectivity over websocket""" - outpost_connection_discovery() # pylint: disable=no-value-for-parameter + outpost_connection_discovery() proxy: ProxyProvider = ProxyProvider.objects.create( name=generate_id(), authorization_flow=Flow.objects.get( @@ -238,7 +238,7 @@ class TestProviderProxyConnect(ChannelsLiveServerTestCase): # Wait until outpost healthcheck succeeds healthcheck_retries = 0 - while healthcheck_retries < 50: + while healthcheck_retries < 50: # noqa: PLR2004 if len(outpost.state) > 0: state = outpost.state[0] if state.last_seen and state.version: diff --git a/tests/e2e/test_provider_radius.py b/tests/e2e/test_provider_radius.py index a0c7bf1b45..f67f6a1885 100644 --- a/tests/e2e/test_provider_radius.py +++ b/tests/e2e/test_provider_radius.py @@ -66,7 +66,7 @@ class TestProviderRadius(SeleniumTestCase): # Wait until outpost healthcheck succeeds healthcheck_retries = 0 - while healthcheck_retries < 50: + while healthcheck_retries < 50: # noqa: PLR2004 if len(outpost.state) > 0: state = outpost.state[0] if state.last_seen: diff --git a/tests/e2e/test_source_ldap_samba.py b/tests/e2e/test_source_ldap_samba.py index 522b7a03f5..d5210f6b51 100644 --- a/tests/e2e/test_source_ldap_samba.py +++ b/tests/e2e/test_source_ldap_samba.py @@ -1,6 +1,6 @@ """test LDAP Source""" -from typing import Any, Optional +from typing import Any from django.db.models import Q from ldap3.core.exceptions import LDAPSessionTerminatedByServerError @@ -23,7 +23,7 @@ class TestSourceLDAPSamba(SeleniumTestCase): self.admin_password = generate_key() super().setUp() - def get_container_specs(self) -> Optional[dict[str, Any]]: + def get_container_specs(self) -> dict[str, Any] | None: return { "image": "ghcr.io/beryju/test-samba-dc:latest", "detach": True, diff --git a/tests/e2e/test_source_oauth_oauth1.py b/tests/e2e/test_source_oauth_oauth1.py index f1672b04be..1da92a4c3e 100644 --- a/tests/e2e/test_source_oauth_oauth1.py +++ b/tests/e2e/test_source_oauth_oauth1.py @@ -1,7 +1,7 @@ """test OAuth Source""" from time import sleep -from typing import Any, Optional +from typing import Any from selenium.webdriver.common.by import By from selenium.webdriver.common.keys import Keys @@ -60,7 +60,7 @@ class TestSourceOAuth1(SeleniumTestCase): self.source_slug = generate_id() super().setUp() - def get_container_specs(self) -> Optional[dict[str, Any]]: + def get_container_specs(self) -> dict[str, Any] | None: return { "image": "ghcr.io/beryju/oauth1-test-server:v1.1", "detach": True, diff --git a/tests/e2e/test_source_oauth_oauth2.py b/tests/e2e/test_source_oauth_oauth2.py index 7a312ef6a1..5986de7c47 100644 --- a/tests/e2e/test_source_oauth_oauth2.py +++ b/tests/e2e/test_source_oauth_oauth2.py @@ -2,7 +2,7 @@ from pathlib import Path from time import sleep -from typing import Any, Optional +from typing import Any from docker.models.containers import Container from docker.types import Healthcheck @@ -69,7 +69,7 @@ class TestSourceOAuth2(SeleniumTestCase): with open(CONFIG_PATH, "w+", encoding="utf8") as _file: safe_dump(config, _file) - def get_container_specs(self) -> Optional[dict[str, Any]]: + def get_container_specs(self) -> dict[str, Any] | None: return { "image": "ghcr.io/dexidp/dex:v2.28.1", "detach": True, diff --git a/tests/e2e/test_source_saml.py b/tests/e2e/test_source_saml.py index 233f38b9a2..433ca42527 100644 --- a/tests/e2e/test_source_saml.py +++ b/tests/e2e/test_source_saml.py @@ -1,7 +1,7 @@ """test SAML Source""" from time import sleep -from typing import Any, Optional +from typing import Any from docker.types import Healthcheck from selenium.webdriver.common.by import By @@ -77,7 +77,7 @@ class TestSourceSAML(SeleniumTestCase): self.slug = generate_id() super().setUp() - def get_container_specs(self) -> Optional[dict[str, Any]]: + def get_container_specs(self) -> dict[str, Any] | None: return { "image": "kristophjunge/test-saml-idp:1.15", "detach": True, diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index 77decd7951..fa365cc6be 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -3,11 +3,12 @@ import json import os import socket +from collections.abc import Callable from functools import lru_cache, wraps from os import environ from sys import stderr from time import sleep -from typing import Any, Callable, Optional +from typing import Any from django.apps import apps from django.contrib.staticfiles.testing import StaticLiveServerTestCase @@ -66,7 +67,7 @@ class DockerTestCase: return container sleep(1) attempt += 1 - if attempt >= 30: + if attempt >= 30: # noqa: PLR2004 self.failureException("Container failed to start") @@ -74,7 +75,7 @@ class SeleniumTestCase(DockerTestCase, StaticLiveServerTestCase): """StaticLiveServerTestCase which automatically creates a Webdriver instance""" host = get_local_ip() - container: Optional[Container] = None + container: Container | None = None wait_timeout: int user: User @@ -114,7 +115,7 @@ class SeleniumTestCase(DockerTestCase, StaticLiveServerTestCase): self.wait_for_container(container) return container - def output_container_logs(self, container: Optional[Container] = None): + def output_container_logs(self, container: Container | None = None): """Output the container logs to our STDOUT""" _container = container or self.container if IS_CI: @@ -124,7 +125,7 @@ class SeleniumTestCase(DockerTestCase, StaticLiveServerTestCase): if IS_CI: print("::endgroup::") - def get_container_specs(self) -> Optional[dict[str, Any]]: + def get_container_specs(self) -> dict[str, Any] | None: """Optionally get container specs which will launched on setup, wait for the container to be healthy, and deleted again on tearDown""" return None @@ -178,7 +179,7 @@ class SeleniumTestCase(DockerTestCase, StaticLiveServerTestCase): return f"{self.live_server_url}/if/user/#{view}" def get_shadow_root( - self, selector: str, container: Optional[WebElement | WebDriver] = None + self, selector: str, container: WebElement | WebDriver | None = None ) -> WebElement: """Get shadow root element's inner shadowRoot""" if not container: @@ -245,12 +246,12 @@ def retry(max_retires=RETRIES, exceptions=None): nonlocal count try: return func(self, *args, **kwargs) - # pylint: disable=catching-non-exception + except tuple(exceptions) as exc: count += 1 if count > max_retires: logger.debug("Exceeded retry count", exc=exc, test=self) - # pylint: disable=raising-non-exception + raise exc logger.debug("Retrying on error", exc=exc, test=self) self.tearDown() diff --git a/tests/integration/test_outpost_docker.py b/tests/integration/test_outpost_docker.py index 1ea8bfc6d6..7fd6dfd3fd 100644 --- a/tests/integration/test_outpost_docker.py +++ b/tests/integration/test_outpost_docker.py @@ -54,7 +54,7 @@ class OutpostDockerTests(DockerTestCase, ChannelsLiveServerTestCase): self.ssl_folder = mkdtemp() self.container = self._start_container(self.ssl_folder) # Ensure that local connection have been created - outpost_connection_discovery() # pylint: disable=no-value-for-parameter + outpost_connection_discovery() self.provider: ProxyProvider = ProxyProvider.objects.create( name="test", internal_host="http://localhost", diff --git a/tests/integration/test_outpost_kubernetes.py b/tests/integration/test_outpost_kubernetes.py index 068bab14d3..d5454dbc20 100644 --- a/tests/integration/test_outpost_kubernetes.py +++ b/tests/integration/test_outpost_kubernetes.py @@ -23,7 +23,7 @@ class OutpostKubernetesTests(TestCase): def setUp(self): super().setUp() # Ensure that local connection have been created - outpost_connection_discovery() # pylint: disable=no-value-for-parameter + outpost_connection_discovery() self.provider: ProxyProvider = ProxyProvider.objects.create( name="test", internal_host="http://localhost", diff --git a/tests/integration/test_proxy_docker.py b/tests/integration/test_proxy_docker.py index 5db47da734..9aefccf39d 100644 --- a/tests/integration/test_proxy_docker.py +++ b/tests/integration/test_proxy_docker.py @@ -54,7 +54,7 @@ class TestProxyDocker(DockerTestCase, ChannelsLiveServerTestCase): self.ssl_folder = mkdtemp() self.container = self._start_container(self.ssl_folder) # Ensure that local connection have been created - outpost_connection_discovery() # pylint: disable=no-value-for-parameter + outpost_connection_discovery() self.provider: ProxyProvider = ProxyProvider.objects.create( name="test", internal_host="http://localhost", diff --git a/tests/integration/test_proxy_kubernetes.py b/tests/integration/test_proxy_kubernetes.py index 4840ff8326..19b37dcf64 100644 --- a/tests/integration/test_proxy_kubernetes.py +++ b/tests/integration/test_proxy_kubernetes.py @@ -1,7 +1,5 @@ """Test Controllers""" -from typing import Optional - import pytest import yaml from django.test import TestCase @@ -21,11 +19,11 @@ LOGGER = get_logger() class TestProxyKubernetes(TestCase): """Test Controllers""" - controller: Optional[KubernetesController] + controller: KubernetesController | None def setUp(self): # Ensure that local connection have been created - outpost_connection_discovery() # pylint: disable=no-value-for-parameter + outpost_connection_discovery() self.controller = None def tearDown(self) -> None: diff --git a/web/package-lock.json b/web/package-lock.json index a057f3c99f..554ab1d99f 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -90,7 +90,6 @@ "npm-run-all": "^4.1.5", "prettier": "^3.2.5", "pseudolocale": "^2.0.0", - "pyright": "=1.1.338", "react": "^18.2.0", "react-dom": "^18.2.0", "rollup": "^4.12.0", @@ -15720,22 +15719,6 @@ "async-limiter": "~1.0.0" } }, - "node_modules/pyright": { - "version": "1.1.338", - "resolved": "https://registry.npmjs.org/pyright/-/pyright-1.1.338.tgz", - "integrity": "sha512-cY4p/LZjC3E1m6If48n19vZgBOUASIOX6zMTavIo2o2JlJRd6/+gy+aYaMdmljVF2mVP8NG6OuKiGxERSdpQOw==", - "dev": true, - "bin": { - "pyright": "index.js", - "pyright-langserver": "langserver.index.js" - }, - "engines": { - "node": ">=12.0.0" - }, - "optionalDependencies": { - "fsevents": "~2.3.2" - } - }, "node_modules/qrjs": { "version": "0.2.0", "resolved": "https://registry.npmjs.org/qrjs/-/qrjs-0.2.0.tgz", diff --git a/web/package.json b/web/package.json index 70702535f7..bb389d6ef5 100644 --- a/web/package.json +++ b/web/package.json @@ -115,7 +115,6 @@ "npm-run-all": "^4.1.5", "prettier": "^3.2.5", "pseudolocale": "^2.0.0", - "pyright": "=1.1.338", "react": "^18.2.0", "react-dom": "^18.2.0", "rollup": "^4.12.0", diff --git a/website/developer-docs/index.md b/website/developer-docs/index.md index 9b312b8a0c..82c97977fc 100644 --- a/website/developer-docs/index.md +++ b/website/developer-docs/index.md @@ -34,8 +34,6 @@ authentik is at it's very core a Django project. It consists of many individual These are the current packages: - - ``` authentik ├── admin - Administrative tasks and APIs, no models (Version updates, Metrics, system tasks) @@ -145,20 +143,20 @@ While the prerequisites above must be satisfied prior to having your pull reques ### PR naming - Use the format of `: ` - - See [here](#authentik-packages) for `package` + - See [here](#authentiks-structure) for `package` - Example: `providers/saml2: fix parsing of requests` ### Git Commit Messages - Use the format of `: ` - - See [here](#authentik-packages) for `package` + - See [here](#authentiks-structure) for `package` - Example: `providers/saml2: fix parsing of requests` - Reference issues and pull requests liberally after the first line - Naming of commits within a PR does not need to adhere to the guidelines as we squash merge PRs ### Python Styleguide -All Python code is linted with [black](https://black.readthedocs.io/en/stable/), [PyLint](https://www.pylint.org/) and [isort](https://pycqa.github.io/isort/). +All Python code is linted with [black](https://black.readthedocs.io/en/stable/) and [Ruff](https://docs.astral.sh/ruff). authentik runs on Python 3.12 at the time of writing this. diff --git a/website/developer-docs/setup/full-dev-environment.md b/website/developer-docs/setup/full-dev-environment.md index 637d0419b8..df4468e4ec 100644 --- a/website/developer-docs/setup/full-dev-environment.md +++ b/website/developer-docs/setup/full-dev-environment.md @@ -54,10 +54,6 @@ make lint # Ensures your code is well-formatted make gen # Generates an updated OpenAPI Docs for any changes you make ``` -:::info -Linting also requires `pyright`, which is installed in the `web/` folder to make dependency management easier. -::: - ## Frontend Setup By default, no compiled bundle of the frontend is included so this step is required even if you're not developing for the UI.