root: early spring clean for linting (#8498)

* remove pyright

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* remove pylint

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* replace pylint with ruff

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* ruff fix

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix UP038

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix DJ012

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix default arg

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix UP031

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* rename stage type to view

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix DJ008

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix remaining upgrade

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix PLR2004

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix B904

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix PLW2901

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix remaining issues

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* prevent ruff from breaking the code

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* stages/prompt: refactor field building

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix tests

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix lint

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fully remove isort

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
Co-authored-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Jens L
2024-02-24 18:13:35 +01:00
committed by GitHub
parent 507f9b7ae2
commit b225b0200e
260 changed files with 1058 additions and 1352 deletions

View File

@ -1,6 +1,6 @@
"""Source API Views"""
from typing import Any, Optional
from typing import Any
from django.core.cache import cache
from django_filters.filters import AllValuesMultipleFilter
@ -39,7 +39,7 @@ class LDAPSourceSerializer(SourceSerializer):
required=False,
)
def get_connectivity(self, source: LDAPSource) -> Optional[dict[str, dict[str, str]]]:
def get_connectivity(self, source: LDAPSource) -> dict[str, dict[str, str]] | None:
"""Get cached source connectivity"""
return cache.get(CACHE_KEY_STATUS + source.slug, None)

View File

@ -1,7 +1,5 @@
"""authentik LDAP Authentication Backend"""
from typing import Optional
from django.http import HttpRequest
from ldap3.core.exceptions import LDAPException, LDAPInvalidCredentialsResult
from structlog.stdlib import get_logger
@ -29,7 +27,7 @@ class LDAPBackend(InbuiltBackend):
return user
return None
def auth_user(self, source: LDAPSource, password: str, **filters: str) -> Optional[User]:
def auth_user(self, source: LDAPSource, password: str, **filters: str) -> User | None:
"""Try to bind as either user_dn or mail with password.
Returns True on success, otherwise False"""
users = User.objects.filter(**filters)
@ -52,7 +50,7 @@ class LDAPBackend(InbuiltBackend):
LOGGER.debug("Failed to bind, password invalid")
return None
def auth_user_by_bind(self, source: LDAPSource, user: User, password: str) -> Optional[User]:
def auth_user_by_bind(self, source: LDAPSource, user: User, password: str) -> User | None:
"""Attempt authentication by binding to the LDAP server as `user`. This
method should be avoided as its slow to do the bind."""
# Try to bind as new user

View File

@ -5,7 +5,6 @@ from os.path import dirname, exists
from shutil import rmtree
from ssl import CERT_REQUIRED
from tempfile import NamedTemporaryFile, mkdtemp
from typing import Optional
from django.core.cache import cache
from django.db import connection, models
@ -160,9 +159,9 @@ class LDAPSource(Source):
def connection(
self,
server: Optional[Server] = None,
server_kwargs: Optional[dict] = None,
connection_kwargs: Optional[dict] = None,
server: Server | None = None,
server_kwargs: dict | None = None,
connection_kwargs: dict | None = None,
) -> Connection:
"""Get a fully connected and bound LDAP Connection"""
server_kwargs = server_kwargs or {}

View File

@ -2,7 +2,6 @@
from enum import IntFlag
from re import split
from typing import Optional
from ldap3 import BASE
from ldap3.core.exceptions import (
@ -20,6 +19,7 @@ LOGGER = get_logger()
NON_ALPHA = r"~!@#$%^&*_-+=`|\(){}[]:;\"'<>,.?/"
RE_DISPLAYNAME_SEPARATORS = r",\.—_\s#\t"
MIN_TOKEN_SIZE = 3
class PwdProperties(IntFlag):
@ -119,7 +119,7 @@ class LDAPPasswordChanger:
raise AssertionError()
user_attributes = users[0]["attributes"]
# If sAMAccountName is longer than 3 chars, check if its contained in password
if len(user_attributes["sAMAccountName"]) >= 3:
if len(user_attributes["sAMAccountName"]) >= MIN_TOKEN_SIZE:
if password.lower() in user_attributes["sAMAccountName"].lower():
return False
# No display name set, can't check any further
@ -129,13 +129,13 @@ class LDAPPasswordChanger:
display_name_tokens = split(RE_DISPLAYNAME_SEPARATORS, display_name)
for token in display_name_tokens:
# Ignore tokens under 3 chars
if len(token) < 3:
if len(token) < MIN_TOKEN_SIZE:
continue
if token.lower() in password.lower():
return False
return True
def ad_password_complexity(self, password: str, user: Optional[User] = None) -> bool:
def ad_password_complexity(self, password: str, user: User | None = None) -> bool:
"""Check if password matches Active directory password policies
https://docs.microsoft.com/en-us/windows/security/threat-protection/

View File

@ -1,6 +1,7 @@
"""Sync LDAP Users and groups into authentik"""
from typing import Any, Generator
from collections.abc import Generator
from typing import Any
from django.conf import settings
from django.db.models.base import Model
@ -90,8 +91,7 @@ class BaseLDAPSynchronizer:
"""Get objects from LDAP, implemented in subclass"""
raise NotImplementedError()
# pylint: disable=too-many-arguments
def search_paginator(
def search_paginator( # noqa: PLR0913
self,
search_base,
search_filter,
@ -103,11 +103,13 @@ class BaseLDAPSynchronizer:
types_only=False,
get_operational_attributes=False,
controls=None,
paged_size=CONFIG.get_int("ldap.page_size", 50),
paged_size=None,
paged_criticality=False,
):
"""Search in pages, returns each page"""
cookie = True
if not paged_size:
paged_size = CONFIG.get_int("ldap.page_size", 50)
while cookie:
self._connection.search(
search_base,

View File

@ -1,6 +1,6 @@
"""Sync LDAP Users and groups into authentik"""
from typing import Generator
from collections.abc import Generator
from django.core.exceptions import FieldError
from django.db.utils import IntegrityError

View File

@ -1,6 +1,7 @@
"""Sync LDAP Users and groups into authentik"""
from typing import Any, Generator, Optional
from collections.abc import Generator
from typing import Any
from django.db.models import Q
from ldap3 import SUBTREE
@ -76,7 +77,7 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer):
self._logger.debug("Successfully updated group membership")
return membership_count
def get_group(self, group_dict: dict[str, Any]) -> Optional[Group]:
def get_group(self, group_dict: dict[str, Any]) -> Group | None:
"""Check if we fetched the group already, and if not cache it for later"""
group_dn = group_dict.get("attributes", {}).get(LDAP_DISTINGUISHED_NAME, [])
group_uniq = group_dict.get("attributes", {}).get(self._source.object_uniqueness_field, [])

View File

@ -1,6 +1,6 @@
"""Sync LDAP Users into authentik"""
from typing import Generator
from collections.abc import Generator
from django.core.exceptions import FieldError
from django.db.utils import IntegrityError

View File

@ -1,7 +1,8 @@
"""FreeIPA specific"""
from datetime import datetime, timezone
from typing import Any, Generator
from collections.abc import Generator
from datetime import UTC, datetime
from typing import Any
from authentik.core.models import User
from authentik.sources.ldap.sync.base import BaseLDAPSynchronizer, flatten
@ -26,7 +27,7 @@ class FreeIPA(BaseLDAPSynchronizer):
if "krbLastPwdChange" not in attributes:
return
pwd_last_set: datetime = attributes.get("krbLastPwdChange", datetime.now())
pwd_last_set = pwd_last_set.replace(tzinfo=timezone.utc)
pwd_last_set = pwd_last_set.replace(tzinfo=UTC)
if created or pwd_last_set >= user.password_change_date:
self.message(f"'{user.username}': Reset user's password")
self._logger.debug(

View File

@ -1,8 +1,9 @@
"""Active Directory specific"""
from datetime import datetime, timezone
from collections.abc import Generator
from datetime import UTC, datetime
from enum import IntFlag
from typing import Any, Generator
from typing import Any
from authentik.core.models import User
from authentik.sources.ldap.sync.base import BaseLDAPSynchronizer
@ -57,7 +58,7 @@ class MicrosoftActiveDirectory(BaseLDAPSynchronizer):
if "pwdLastSet" not in attributes:
return
pwd_last_set: datetime = attributes.get("pwdLastSet", datetime.now())
pwd_last_set = pwd_last_set.replace(tzinfo=timezone.utc)
pwd_last_set = pwd_last_set.replace(tzinfo=UTC)
if created or pwd_last_set >= user.password_change_date:
self.message(f"'{user.username}': Reset user's password")
self._logger.debug(

View File

@ -1,6 +1,5 @@
"""LDAP Sync tasks"""
from typing import Optional
from uuid import uuid4
from celery import chain, group
@ -40,7 +39,7 @@ def ldap_sync_all():
@CELERY_APP.task()
def ldap_connectivity_check(pk: Optional[str] = None):
def ldap_connectivity_check(pk: str | None = None):
"""Check connectivity for LDAP Sources"""
# 2 hour timeout, this task should run every hour
timeout = 60 * 60 * 2

View File

@ -57,7 +57,6 @@ class OAuthSourceSerializer(SourceSerializer):
"""Get source's type configuration"""
return SourceTypeSerializer(instance.source_type).data
# pylint: disable=too-many-locals
def validate(self, attrs: dict) -> dict:
session = get_http_session()
source_type = registry.find_type(attrs["provider_type"])
@ -71,7 +70,7 @@ class OAuthSourceSerializer(SourceSerializer):
well_known_config.raise_for_status()
except RequestException as exc:
text = exc.response.text if exc.response else str(exc)
raise ValidationError({"oidc_well_known_url": text})
raise ValidationError({"oidc_well_known_url": text}) from None
config = well_known_config.json()
if "issuer" not in config:
raise ValidationError({"oidc_well_known_url": "Invalid well-known configuration"})
@ -97,7 +96,7 @@ class OAuthSourceSerializer(SourceSerializer):
jwks_config.raise_for_status()
except RequestException as exc:
text = exc.response.text if exc.response else str(exc)
raise ValidationError({"oidc_jwks_url": text})
raise ValidationError({"oidc_jwks_url": text}) from None
config = jwks_config.json()
attrs["oidc_jwks"] = config

View File

@ -1,6 +1,6 @@
"""OAuth Clients"""
from typing import Any, Optional
from typing import Any
from urllib.parse import parse_qs, quote, urlencode, urlparse, urlunparse
from django.http import HttpRequest
@ -22,20 +22,20 @@ class BaseOAuthClient:
source: OAuthSource
request: HttpRequest
callback: Optional[str]
callback: str | None
def __init__(self, source: OAuthSource, request: HttpRequest, callback: Optional[str] = None):
def __init__(self, source: OAuthSource, request: HttpRequest, callback: str | None = None):
self.source = source
self.session = get_http_session()
self.request = request
self.callback = callback
self.logger = get_logger().bind(source=source.slug)
def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
def get_access_token(self, **request_kwargs) -> dict[str, Any] | None:
"""Fetch access token from callback request."""
raise NotImplementedError("Defined in a sub-class") # pragma: no cover
def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
def get_profile_info(self, token: dict[str, str]) -> dict[str, Any] | None:
"""Fetch user profile information."""
profile_url = self.source.source_type.profile_url or ""
if self.source.source_type.urls_customizable and self.source.profile_url:

View File

@ -1,6 +1,6 @@
"""OAuth 1 Clients"""
from typing import Any, Optional
from typing import Any
from urllib.parse import parse_qsl
from requests.exceptions import RequestException
@ -21,7 +21,7 @@ class OAuthClient(BaseOAuthClient):
"Accept": "application/json",
}
def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
def get_access_token(self, **request_kwargs) -> dict[str, Any] | None:
"""Fetch access token from callback request."""
raw_token = self.request.session.get(self.session_key, None)
verifier = self.request.GET.get("oauth_verifier", None)

View File

@ -1,7 +1,7 @@
"""OAuth 2 Clients"""
from json import loads
from typing import Any, Optional
from typing import Any
from urllib.parse import parse_qsl
from django.utils.crypto import constant_time_compare, get_random_string
@ -23,7 +23,7 @@ class OAuth2Client(BaseOAuthClient):
"Accept": "application/json",
}
def get_request_arg(self, key: str, default: Optional[Any] = None) -> Any:
def get_request_arg(self, key: str, default: Any | None = None) -> Any:
"""Depending on request type, get data from post or get"""
if self.request.method == "POST":
return self.request.POST.get(key, default)
@ -55,7 +55,7 @@ class OAuth2Client(BaseOAuthClient):
"""Get client secret"""
return self.source.consumer_secret
def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
def get_access_token(self, **request_kwargs) -> dict[str, Any] | None:
"""Fetch access token from callback request."""
callback = self.request.build_absolute_uri(self.callback or self.request.path)
if not self.check_application_state():
@ -139,7 +139,7 @@ class OAuth2Client(BaseOAuthClient):
class UserprofileHeaderAuthClient(OAuth2Client):
"""OAuth client which only sends authentication via header, not querystring"""
def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
def get_profile_info(self, token: dict[str, str]) -> dict[str, Any] | None:
"Fetch user profile information."
profile_url = self.source.source_type.profile_url or ""
if self.source.source_type.urls_customizable and self.source.profile_url:

View File

@ -1,6 +1,6 @@
"""OAuth Client models"""
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING
from django.db import models
from django.http.request import HttpRequest
@ -84,7 +84,7 @@ class OAuthSource(Source):
icon_url=icon,
)
def ui_user_settings(self) -> Optional[UserSettingSerializer]:
def ui_user_settings(self) -> UserSettingSerializer | None:
provider_type = self.source_type
icon = self.icon_url
if not icon:

View File

@ -35,7 +35,7 @@ class TestOAuthSourceTasks(TestCase):
},
)
mock.get("http://foo/jwks", json={"foo": "bar"})
update_well_known_jwks() # pylint: disable=no-value-for-parameter
update_well_known_jwks()
self.source.refresh_from_db()
self.assertEqual(self.source.authorization_url, "foo")
self.assertEqual(self.source.access_token_url, "foo")

View File

@ -1,7 +1,7 @@
"""Apple OAuth Views"""
from time import time
from typing import Any, Optional
from typing import Any
from django.http.request import HttpRequest
from django.urls.base import reverse
@ -17,6 +17,7 @@ from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect
LOGGER = get_logger()
APPLE_CLIENT_ID_PARTS = 3
class AppleLoginChallenge(Challenge):
@ -30,7 +31,7 @@ class AppleLoginChallenge(Challenge):
class AppleChallengeResponse(ChallengeResponse):
"""Pseudo class for plex response"""
"""Pseudo class for apple response"""
component = CharField(default="ak-source-oauth-apple")
@ -40,14 +41,14 @@ class AppleOAuthClient(OAuth2Client):
def get_client_id(self) -> str:
parts: list[str] = self.source.consumer_key.split(";")
if len(parts) < 3:
if len(parts) < APPLE_CLIENT_ID_PARTS:
return self.source.consumer_key
return parts[0].strip()
def get_client_secret(self) -> str:
now = time()
parts: list[str] = self.source.consumer_key.split(";")
if len(parts) < 3:
if len(parts) < APPLE_CLIENT_ID_PARTS:
raise ValueError(
"Apple Source client_id should be formatted like "
"services_id_identifier;apple_team_id;key_id"
@ -64,7 +65,7 @@ class AppleOAuthClient(OAuth2Client):
LOGGER.debug("signing payload as secret key", payload=payload, jwt=jwt)
return jwt
def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
def get_profile_info(self, token: dict[str, str]) -> dict[str, Any] | None:
id_token = token.get("id_token")
return decode(id_token, options={"verify_signature": False})
@ -86,7 +87,7 @@ class AppleOAuth2Callback(OAuthCallback):
client_class = AppleOAuthClient
def get_user_id(self, info: dict[str, Any]) -> Optional[str]:
def get_user_id(self, info: dict[str, Any]) -> str | None:
return info["sub"]
def get_user_enroll_context(

View File

@ -1,6 +1,6 @@
"""Facebook OAuth Views"""
from typing import Any, Optional
from typing import Any
from facebook import GraphAPI
@ -22,7 +22,7 @@ class FacebookOAuthRedirect(OAuthRedirect):
class FacebookOAuth2Client(OAuth2Client):
"""Facebook OAuth2 Client"""
def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
def get_profile_info(self, token: dict[str, str]) -> dict[str, Any] | None:
api = GraphAPI(access_token=token["access_token"])
return api.get_object("me", fields="id,name,email")

View File

@ -1,6 +1,6 @@
"""Mailcow OAuth Views"""
from typing import Any, Optional
from typing import Any
from requests.exceptions import RequestException
from structlog.stdlib import get_logger
@ -25,7 +25,7 @@ class MailcowOAuthRedirect(OAuthRedirect):
class MailcowOAuth2Client(OAuth2Client):
"""MailcowOAuth2Client, for some reason, mailcow does not like the default headers"""
def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
def get_profile_info(self, token: dict[str, str]) -> dict[str, Any] | None:
"Fetch user profile information."
profile_url = self.source.source_type.profile_url or ""
if self.source.source_type.urls_customizable and self.source.profile_url:

View File

@ -1,7 +1,7 @@
"""Source type manager"""
from collections.abc import Callable
from enum import Enum
from typing import Callable, Optional, Type
from django.http.request import HttpRequest
from django.templatetags.static import static
@ -33,12 +33,12 @@ class SourceType:
urls_customizable = False
request_token_url: Optional[str] = None
authorization_url: Optional[str] = None
access_token_url: Optional[str] = None
profile_url: Optional[str] = None
oidc_well_known_url: Optional[str] = None
oidc_jwks_url: Optional[str] = None
request_token_url: str | None = None
authorization_url: str | None = None
access_token_url: str | None = None
profile_url: str | None = None
oidc_well_known_url: str | None = None
oidc_jwks_url: str | None = None
def icon_url(self) -> str:
"""Get Icon URL for login"""
@ -80,7 +80,7 @@ class SourceTypeRegistry:
"""Get list of tuples of all registered names"""
return [(x.name, x.verbose_name) for x in self.__sources]
def find_type(self, type_name: str) -> Type[SourceType]:
def find_type(self, type_name: str) -> type[SourceType]:
"""Find type based on source"""
found_type = None
for src_type in self.__sources:

View File

@ -1,7 +1,7 @@
"""Twitch OAuth Views"""
from json import dumps
from typing import Any, Optional
from typing import Any
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback
@ -12,7 +12,7 @@ from authentik.sources.oauth.views.redirect import OAuthRedirect
class TwitchClient(UserprofileHeaderAuthClient):
"""Twitch needs the token_type to be capitalized for the request header."""
def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
def get_profile_info(self, token: dict[str, str]) -> dict[str, Any] | None:
token["token_type"] = token["token_type"].capitalize()
return super().get_profile_info(token)

View File

@ -1,6 +1,6 @@
"""Twitter OAuth Views"""
from typing import Any, Optional
from typing import Any
from authentik.lib.generators import generate_id
from authentik.sources.oauth.clients.oauth2 import (
@ -20,7 +20,7 @@ class TwitterClient(UserprofileHeaderAuthClient):
# is set via query parameter, so we reuse the azure client
# see https://github.com/goauthentik/authentik/issues/1910
def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
def get_access_token(self, **request_kwargs) -> dict[str, Any] | None:
return super().get_access_token(
auth=(
self.source.consumer_key,

View File

@ -1,7 +1,5 @@
"""OAuth Base views"""
from typing import Optional
from django.http.request import HttpRequest
from structlog.stdlib import get_logger
@ -13,18 +11,17 @@ from authentik.sources.oauth.models import OAuthSource
LOGGER = get_logger()
# pylint: disable=too-few-public-methods
class OAuthClientMixin:
"Mixin for getting OAuth client for a source."
request: HttpRequest # Set by View class
client_class: Optional[type[BaseOAuthClient]] = None
client_class: type[BaseOAuthClient] | None = None
def get_client(self, source: OAuthSource, **kwargs) -> BaseOAuthClient:
"Get instance of the OAuth client for this source."
if self.client_class is not None:
# pylint: disable=not-callable
return self.client_class(source, self.request, **kwargs)
if source.source_type.request_token_url or source.request_token_url:
client = OAuthClient(source, self.request, **kwargs)

View File

@ -1,7 +1,7 @@
"""OAuth Callback Views"""
from json import JSONDecodeError
from typing import Any, Optional
from typing import Any
from django.conf import settings
from django.contrib import messages
@ -23,16 +23,15 @@ class OAuthCallback(OAuthClientMixin, View):
"Base OAuth callback view."
source: OAuthSource
token: Optional[dict] = None
token: dict | None = None
# pylint: disable=too-many-return-statements
def dispatch(self, request: HttpRequest, *_, **kwargs) -> HttpResponse:
"""View Get handler"""
slug = kwargs.get("source_slug", "")
try:
self.source = OAuthSource.objects.get(slug=slug)
except OAuthSource.DoesNotExist:
raise Http404(f"Unknown OAuth source '{slug}'.")
raise Http404(f"Unknown OAuth source '{slug}'.") from None
if not self.source.enabled:
raise Http404(f"Source {slug} is not enabled.")
@ -86,7 +85,7 @@ class OAuthCallback(OAuthClientMixin, View):
"""Create a dict of User data"""
raise NotImplementedError()
def get_user_id(self, info: dict[str, Any]) -> Optional[str]:
def get_user_id(self, info: dict[str, Any]) -> str | None:
"""Return unique identifier from the profile info."""
if "id" in info:
return info["id"]
@ -98,10 +97,11 @@ class OAuthCallback(OAuthClientMixin, View):
messages.error(
self.request,
_(
"Authentication failed: %(reason)s"
% {
"reason": reason,
}
"Authentication failed: {reason}".format_map(
{
"reason": reason,
}
)
),
)
return redirect(self.get_error_redirect(self.source, reason))
@ -115,7 +115,7 @@ class OAuthSourceFlowManager(SourceFlowManager):
def update_connection(
self,
connection: UserOAuthSourceConnection,
access_token: Optional[str] = None,
access_token: str | None = None,
) -> UserOAuthSourceConnection:
"""Set the access_token on the connection"""
connection.access_token = access_token

View File

@ -36,7 +36,7 @@ class OAuthRedirect(OAuthClientMixin, RedirectView):
try:
source: OAuthSource = OAuthSource.objects.get(slug=slug)
except OAuthSource.DoesNotExist:
raise Http404(f"Unknown OAuth source '{slug}'.")
raise Http404(f"Unknown OAuth source '{slug}'.") from None
if not source.enabled:
raise Http404(f"source {slug} is not enabled.")
client = self.get_client(source, callback=self.get_callback_url(source))

View File

@ -1,7 +1,5 @@
"""Plex source"""
from typing import Optional
from django.contrib.postgres.fields import ArrayField
from django.db import models
from django.http.request import HttpRequest
@ -79,7 +77,7 @@ class PlexSource(Source):
name=self.name,
)
def ui_user_settings(self) -> Optional[UserSettingSerializer]:
def ui_user_settings(self) -> UserSettingSerializer | None:
icon = self.icon_url
if not icon:
icon = static("authentik/sources/plex.svg")

View File

@ -85,7 +85,7 @@ class PlexAuth:
resources = self.get_resources()
except RequestException as exc:
LOGGER.warning("Unable to fetch user resources", exc=exc)
raise Http404
raise Http404 from None
for resource in resources:
if resource["provides"] != "server":
continue

View File

@ -1,7 +1,5 @@
"""saml sp models"""
from typing import Optional
from django.db import models
from django.http import HttpRequest
from django.templatetags.static import static
@ -204,7 +202,7 @@ class SAMLSource(Source):
icon_url=self.icon_url,
)
def ui_user_settings(self) -> Optional[UserSettingSerializer]:
def ui_user_settings(self) -> UserSettingSerializer | None:
icon = self.icon_url
if not icon:
icon = static(f"authentik/sources/{self.slug}.svg")

View File

@ -1,6 +1,7 @@
"""SAML Service Provider Metadata Processor"""
from typing import Iterator, Optional
from collections.abc import Iterator
from typing import Optional
from django.http import HttpRequest
from lxml.etree import Element, SubElement, tostring # nosec
@ -30,7 +31,8 @@ class MetadataProcessor:
self.source = source
self.http_request = request
def get_signing_key_descriptor(self) -> Optional[Element]:
# Using type unions doesn't work with cython types (which is what lxml is)
def get_signing_key_descriptor(self) -> Optional[Element]: # noqa: UP007
"""Get Signing KeyDescriptor, if enabled for the source"""
if self.source.signing_kp:
key_descriptor = Element(f"{{{NS_SAML_METADATA}}}KeyDescriptor")

View File

@ -88,7 +88,7 @@ class InitiateView(View):
try:
plan = planner.plan(self.request, kwargs)
except FlowNonApplicableException:
raise Http404
raise Http404 from None
for stage in stages_to_append:
plan.append_stage(stage)
self.request.session[SESSION_KEY_PLAN] = plan