root: early spring clean for linting (#8498)
* remove pyright Signed-off-by: Jens Langhammer <jens@goauthentik.io> * remove pylint Signed-off-by: Jens Langhammer <jens@goauthentik.io> * replace pylint with ruff Signed-off-by: Jens Langhammer <jens@goauthentik.io> * ruff fix Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space> * fix UP038 Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix DJ012 Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix default arg Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix UP031 Signed-off-by: Jens Langhammer <jens@goauthentik.io> * rename stage type to view Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix DJ008 Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix remaining upgrade Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix PLR2004 Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix B904 Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix PLW2901 Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix remaining issues Signed-off-by: Jens Langhammer <jens@goauthentik.io> * prevent ruff from breaking the code Signed-off-by: Jens Langhammer <jens@goauthentik.io> * stages/prompt: refactor field building Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space> * fix tests Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix lint Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fully remove isort Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io> Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space> Co-authored-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
		@ -1,6 +1,6 @@
 | 
			
		||||
"""Source API Views"""
 | 
			
		||||
 | 
			
		||||
from typing import Any, Optional
 | 
			
		||||
from typing import Any
 | 
			
		||||
 | 
			
		||||
from django.core.cache import cache
 | 
			
		||||
from django_filters.filters import AllValuesMultipleFilter
 | 
			
		||||
@ -39,7 +39,7 @@ class LDAPSourceSerializer(SourceSerializer):
 | 
			
		||||
        required=False,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    def get_connectivity(self, source: LDAPSource) -> Optional[dict[str, dict[str, str]]]:
 | 
			
		||||
    def get_connectivity(self, source: LDAPSource) -> dict[str, dict[str, str]] | None:
 | 
			
		||||
        """Get cached source connectivity"""
 | 
			
		||||
        return cache.get(CACHE_KEY_STATUS + source.slug, None)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1,7 +1,5 @@
 | 
			
		||||
"""authentik LDAP Authentication Backend"""
 | 
			
		||||
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
from django.http import HttpRequest
 | 
			
		||||
from ldap3.core.exceptions import LDAPException, LDAPInvalidCredentialsResult
 | 
			
		||||
from structlog.stdlib import get_logger
 | 
			
		||||
@ -29,7 +27,7 @@ class LDAPBackend(InbuiltBackend):
 | 
			
		||||
                return user
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
    def auth_user(self, source: LDAPSource, password: str, **filters: str) -> Optional[User]:
 | 
			
		||||
    def auth_user(self, source: LDAPSource, password: str, **filters: str) -> User | None:
 | 
			
		||||
        """Try to bind as either user_dn or mail with password.
 | 
			
		||||
        Returns True on success, otherwise False"""
 | 
			
		||||
        users = User.objects.filter(**filters)
 | 
			
		||||
@ -52,7 +50,7 @@ class LDAPBackend(InbuiltBackend):
 | 
			
		||||
        LOGGER.debug("Failed to bind, password invalid")
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
    def auth_user_by_bind(self, source: LDAPSource, user: User, password: str) -> Optional[User]:
 | 
			
		||||
    def auth_user_by_bind(self, source: LDAPSource, user: User, password: str) -> User | None:
 | 
			
		||||
        """Attempt authentication by binding to the LDAP server as `user`. This
 | 
			
		||||
        method should be avoided as its slow to do the bind."""
 | 
			
		||||
        # Try to bind as new user
 | 
			
		||||
 | 
			
		||||
@ -5,7 +5,6 @@ from os.path import dirname, exists
 | 
			
		||||
from shutil import rmtree
 | 
			
		||||
from ssl import CERT_REQUIRED
 | 
			
		||||
from tempfile import NamedTemporaryFile, mkdtemp
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
from django.core.cache import cache
 | 
			
		||||
from django.db import connection, models
 | 
			
		||||
@ -160,9 +159,9 @@ class LDAPSource(Source):
 | 
			
		||||
 | 
			
		||||
    def connection(
 | 
			
		||||
        self,
 | 
			
		||||
        server: Optional[Server] = None,
 | 
			
		||||
        server_kwargs: Optional[dict] = None,
 | 
			
		||||
        connection_kwargs: Optional[dict] = None,
 | 
			
		||||
        server: Server | None = None,
 | 
			
		||||
        server_kwargs: dict | None = None,
 | 
			
		||||
        connection_kwargs: dict | None = None,
 | 
			
		||||
    ) -> Connection:
 | 
			
		||||
        """Get a fully connected and bound LDAP Connection"""
 | 
			
		||||
        server_kwargs = server_kwargs or {}
 | 
			
		||||
 | 
			
		||||
@ -2,7 +2,6 @@
 | 
			
		||||
 | 
			
		||||
from enum import IntFlag
 | 
			
		||||
from re import split
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
from ldap3 import BASE
 | 
			
		||||
from ldap3.core.exceptions import (
 | 
			
		||||
@ -20,6 +19,7 @@ LOGGER = get_logger()
 | 
			
		||||
 | 
			
		||||
NON_ALPHA = r"~!@#$%^&*_-+=`|\(){}[]:;\"'<>,.?/"
 | 
			
		||||
RE_DISPLAYNAME_SEPARATORS = r",\.–—_\s#\t"
 | 
			
		||||
MIN_TOKEN_SIZE = 3
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PwdProperties(IntFlag):
 | 
			
		||||
@ -119,7 +119,7 @@ class LDAPPasswordChanger:
 | 
			
		||||
            raise AssertionError()
 | 
			
		||||
        user_attributes = users[0]["attributes"]
 | 
			
		||||
        # If sAMAccountName is longer than 3 chars, check if its contained in password
 | 
			
		||||
        if len(user_attributes["sAMAccountName"]) >= 3:
 | 
			
		||||
        if len(user_attributes["sAMAccountName"]) >= MIN_TOKEN_SIZE:
 | 
			
		||||
            if password.lower() in user_attributes["sAMAccountName"].lower():
 | 
			
		||||
                return False
 | 
			
		||||
        # No display name set, can't check any further
 | 
			
		||||
@ -129,13 +129,13 @@ class LDAPPasswordChanger:
 | 
			
		||||
            display_name_tokens = split(RE_DISPLAYNAME_SEPARATORS, display_name)
 | 
			
		||||
            for token in display_name_tokens:
 | 
			
		||||
                # Ignore tokens under 3 chars
 | 
			
		||||
                if len(token) < 3:
 | 
			
		||||
                if len(token) < MIN_TOKEN_SIZE:
 | 
			
		||||
                    continue
 | 
			
		||||
                if token.lower() in password.lower():
 | 
			
		||||
                    return False
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
    def ad_password_complexity(self, password: str, user: Optional[User] = None) -> bool:
 | 
			
		||||
    def ad_password_complexity(self, password: str, user: User | None = None) -> bool:
 | 
			
		||||
        """Check if password matches Active directory password policies
 | 
			
		||||
 | 
			
		||||
        https://docs.microsoft.com/en-us/windows/security/threat-protection/
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,7 @@
 | 
			
		||||
"""Sync LDAP Users and groups into authentik"""
 | 
			
		||||
 | 
			
		||||
from typing import Any, Generator
 | 
			
		||||
from collections.abc import Generator
 | 
			
		||||
from typing import Any
 | 
			
		||||
 | 
			
		||||
from django.conf import settings
 | 
			
		||||
from django.db.models.base import Model
 | 
			
		||||
@ -90,8 +91,7 @@ class BaseLDAPSynchronizer:
 | 
			
		||||
        """Get objects from LDAP, implemented in subclass"""
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
    # pylint: disable=too-many-arguments
 | 
			
		||||
    def search_paginator(
 | 
			
		||||
    def search_paginator(  # noqa: PLR0913
 | 
			
		||||
        self,
 | 
			
		||||
        search_base,
 | 
			
		||||
        search_filter,
 | 
			
		||||
@ -103,11 +103,13 @@ class BaseLDAPSynchronizer:
 | 
			
		||||
        types_only=False,
 | 
			
		||||
        get_operational_attributes=False,
 | 
			
		||||
        controls=None,
 | 
			
		||||
        paged_size=CONFIG.get_int("ldap.page_size", 50),
 | 
			
		||||
        paged_size=None,
 | 
			
		||||
        paged_criticality=False,
 | 
			
		||||
    ):
 | 
			
		||||
        """Search in pages, returns each page"""
 | 
			
		||||
        cookie = True
 | 
			
		||||
        if not paged_size:
 | 
			
		||||
            paged_size = CONFIG.get_int("ldap.page_size", 50)
 | 
			
		||||
        while cookie:
 | 
			
		||||
            self._connection.search(
 | 
			
		||||
                search_base,
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,6 @@
 | 
			
		||||
"""Sync LDAP Users and groups into authentik"""
 | 
			
		||||
 | 
			
		||||
from typing import Generator
 | 
			
		||||
from collections.abc import Generator
 | 
			
		||||
 | 
			
		||||
from django.core.exceptions import FieldError
 | 
			
		||||
from django.db.utils import IntegrityError
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,7 @@
 | 
			
		||||
"""Sync LDAP Users and groups into authentik"""
 | 
			
		||||
 | 
			
		||||
from typing import Any, Generator, Optional
 | 
			
		||||
from collections.abc import Generator
 | 
			
		||||
from typing import Any
 | 
			
		||||
 | 
			
		||||
from django.db.models import Q
 | 
			
		||||
from ldap3 import SUBTREE
 | 
			
		||||
@ -76,7 +77,7 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer):
 | 
			
		||||
        self._logger.debug("Successfully updated group membership")
 | 
			
		||||
        return membership_count
 | 
			
		||||
 | 
			
		||||
    def get_group(self, group_dict: dict[str, Any]) -> Optional[Group]:
 | 
			
		||||
    def get_group(self, group_dict: dict[str, Any]) -> Group | None:
 | 
			
		||||
        """Check if we fetched the group already, and if not cache it for later"""
 | 
			
		||||
        group_dn = group_dict.get("attributes", {}).get(LDAP_DISTINGUISHED_NAME, [])
 | 
			
		||||
        group_uniq = group_dict.get("attributes", {}).get(self._source.object_uniqueness_field, [])
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,6 @@
 | 
			
		||||
"""Sync LDAP Users into authentik"""
 | 
			
		||||
 | 
			
		||||
from typing import Generator
 | 
			
		||||
from collections.abc import Generator
 | 
			
		||||
 | 
			
		||||
from django.core.exceptions import FieldError
 | 
			
		||||
from django.db.utils import IntegrityError
 | 
			
		||||
 | 
			
		||||
@ -1,7 +1,8 @@
 | 
			
		||||
"""FreeIPA specific"""
 | 
			
		||||
 | 
			
		||||
from datetime import datetime, timezone
 | 
			
		||||
from typing import Any, Generator
 | 
			
		||||
from collections.abc import Generator
 | 
			
		||||
from datetime import UTC, datetime
 | 
			
		||||
from typing import Any
 | 
			
		||||
 | 
			
		||||
from authentik.core.models import User
 | 
			
		||||
from authentik.sources.ldap.sync.base import BaseLDAPSynchronizer, flatten
 | 
			
		||||
@ -26,7 +27,7 @@ class FreeIPA(BaseLDAPSynchronizer):
 | 
			
		||||
        if "krbLastPwdChange" not in attributes:
 | 
			
		||||
            return
 | 
			
		||||
        pwd_last_set: datetime = attributes.get("krbLastPwdChange", datetime.now())
 | 
			
		||||
        pwd_last_set = pwd_last_set.replace(tzinfo=timezone.utc)
 | 
			
		||||
        pwd_last_set = pwd_last_set.replace(tzinfo=UTC)
 | 
			
		||||
        if created or pwd_last_set >= user.password_change_date:
 | 
			
		||||
            self.message(f"'{user.username}': Reset user's password")
 | 
			
		||||
            self._logger.debug(
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										7
									
								
								authentik/sources/ldap/sync/vendor/ms_ad.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										7
									
								
								authentik/sources/ldap/sync/vendor/ms_ad.py
									
									
									
									
										vendored
									
									
								
							@ -1,8 +1,9 @@
 | 
			
		||||
"""Active Directory specific"""
 | 
			
		||||
 | 
			
		||||
from datetime import datetime, timezone
 | 
			
		||||
from collections.abc import Generator
 | 
			
		||||
from datetime import UTC, datetime
 | 
			
		||||
from enum import IntFlag
 | 
			
		||||
from typing import Any, Generator
 | 
			
		||||
from typing import Any
 | 
			
		||||
 | 
			
		||||
from authentik.core.models import User
 | 
			
		||||
from authentik.sources.ldap.sync.base import BaseLDAPSynchronizer
 | 
			
		||||
@ -57,7 +58,7 @@ class MicrosoftActiveDirectory(BaseLDAPSynchronizer):
 | 
			
		||||
        if "pwdLastSet" not in attributes:
 | 
			
		||||
            return
 | 
			
		||||
        pwd_last_set: datetime = attributes.get("pwdLastSet", datetime.now())
 | 
			
		||||
        pwd_last_set = pwd_last_set.replace(tzinfo=timezone.utc)
 | 
			
		||||
        pwd_last_set = pwd_last_set.replace(tzinfo=UTC)
 | 
			
		||||
        if created or pwd_last_set >= user.password_change_date:
 | 
			
		||||
            self.message(f"'{user.username}': Reset user's password")
 | 
			
		||||
            self._logger.debug(
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,5 @@
 | 
			
		||||
"""LDAP Sync tasks"""
 | 
			
		||||
 | 
			
		||||
from typing import Optional
 | 
			
		||||
from uuid import uuid4
 | 
			
		||||
 | 
			
		||||
from celery import chain, group
 | 
			
		||||
@ -40,7 +39,7 @@ def ldap_sync_all():
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@CELERY_APP.task()
 | 
			
		||||
def ldap_connectivity_check(pk: Optional[str] = None):
 | 
			
		||||
def ldap_connectivity_check(pk: str | None = None):
 | 
			
		||||
    """Check connectivity for LDAP Sources"""
 | 
			
		||||
    # 2 hour timeout, this task should run every hour
 | 
			
		||||
    timeout = 60 * 60 * 2
 | 
			
		||||
 | 
			
		||||
@ -57,7 +57,6 @@ class OAuthSourceSerializer(SourceSerializer):
 | 
			
		||||
        """Get source's type configuration"""
 | 
			
		||||
        return SourceTypeSerializer(instance.source_type).data
 | 
			
		||||
 | 
			
		||||
    # pylint: disable=too-many-locals
 | 
			
		||||
    def validate(self, attrs: dict) -> dict:
 | 
			
		||||
        session = get_http_session()
 | 
			
		||||
        source_type = registry.find_type(attrs["provider_type"])
 | 
			
		||||
@ -71,7 +70,7 @@ class OAuthSourceSerializer(SourceSerializer):
 | 
			
		||||
                well_known_config.raise_for_status()
 | 
			
		||||
            except RequestException as exc:
 | 
			
		||||
                text = exc.response.text if exc.response else str(exc)
 | 
			
		||||
                raise ValidationError({"oidc_well_known_url": text})
 | 
			
		||||
                raise ValidationError({"oidc_well_known_url": text}) from None
 | 
			
		||||
            config = well_known_config.json()
 | 
			
		||||
            if "issuer" not in config:
 | 
			
		||||
                raise ValidationError({"oidc_well_known_url": "Invalid well-known configuration"})
 | 
			
		||||
@ -97,7 +96,7 @@ class OAuthSourceSerializer(SourceSerializer):
 | 
			
		||||
                jwks_config.raise_for_status()
 | 
			
		||||
            except RequestException as exc:
 | 
			
		||||
                text = exc.response.text if exc.response else str(exc)
 | 
			
		||||
                raise ValidationError({"oidc_jwks_url": text})
 | 
			
		||||
                raise ValidationError({"oidc_jwks_url": text}) from None
 | 
			
		||||
            config = jwks_config.json()
 | 
			
		||||
            attrs["oidc_jwks"] = config
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,6 @@
 | 
			
		||||
"""OAuth Clients"""
 | 
			
		||||
 | 
			
		||||
from typing import Any, Optional
 | 
			
		||||
from typing import Any
 | 
			
		||||
from urllib.parse import parse_qs, quote, urlencode, urlparse, urlunparse
 | 
			
		||||
 | 
			
		||||
from django.http import HttpRequest
 | 
			
		||||
@ -22,20 +22,20 @@ class BaseOAuthClient:
 | 
			
		||||
    source: OAuthSource
 | 
			
		||||
    request: HttpRequest
 | 
			
		||||
 | 
			
		||||
    callback: Optional[str]
 | 
			
		||||
    callback: str | None
 | 
			
		||||
 | 
			
		||||
    def __init__(self, source: OAuthSource, request: HttpRequest, callback: Optional[str] = None):
 | 
			
		||||
    def __init__(self, source: OAuthSource, request: HttpRequest, callback: str | None = None):
 | 
			
		||||
        self.source = source
 | 
			
		||||
        self.session = get_http_session()
 | 
			
		||||
        self.request = request
 | 
			
		||||
        self.callback = callback
 | 
			
		||||
        self.logger = get_logger().bind(source=source.slug)
 | 
			
		||||
 | 
			
		||||
    def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
 | 
			
		||||
    def get_access_token(self, **request_kwargs) -> dict[str, Any] | None:
 | 
			
		||||
        """Fetch access token from callback request."""
 | 
			
		||||
        raise NotImplementedError("Defined in a sub-class")  # pragma: no cover
 | 
			
		||||
 | 
			
		||||
    def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
 | 
			
		||||
    def get_profile_info(self, token: dict[str, str]) -> dict[str, Any] | None:
 | 
			
		||||
        """Fetch user profile information."""
 | 
			
		||||
        profile_url = self.source.source_type.profile_url or ""
 | 
			
		||||
        if self.source.source_type.urls_customizable and self.source.profile_url:
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,6 @@
 | 
			
		||||
"""OAuth 1 Clients"""
 | 
			
		||||
 | 
			
		||||
from typing import Any, Optional
 | 
			
		||||
from typing import Any
 | 
			
		||||
from urllib.parse import parse_qsl
 | 
			
		||||
 | 
			
		||||
from requests.exceptions import RequestException
 | 
			
		||||
@ -21,7 +21,7 @@ class OAuthClient(BaseOAuthClient):
 | 
			
		||||
        "Accept": "application/json",
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
 | 
			
		||||
    def get_access_token(self, **request_kwargs) -> dict[str, Any] | None:
 | 
			
		||||
        """Fetch access token from callback request."""
 | 
			
		||||
        raw_token = self.request.session.get(self.session_key, None)
 | 
			
		||||
        verifier = self.request.GET.get("oauth_verifier", None)
 | 
			
		||||
 | 
			
		||||
@ -1,7 +1,7 @@
 | 
			
		||||
"""OAuth 2 Clients"""
 | 
			
		||||
 | 
			
		||||
from json import loads
 | 
			
		||||
from typing import Any, Optional
 | 
			
		||||
from typing import Any
 | 
			
		||||
from urllib.parse import parse_qsl
 | 
			
		||||
 | 
			
		||||
from django.utils.crypto import constant_time_compare, get_random_string
 | 
			
		||||
@ -23,7 +23,7 @@ class OAuth2Client(BaseOAuthClient):
 | 
			
		||||
        "Accept": "application/json",
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    def get_request_arg(self, key: str, default: Optional[Any] = None) -> Any:
 | 
			
		||||
    def get_request_arg(self, key: str, default: Any | None = None) -> Any:
 | 
			
		||||
        """Depending on request type, get data from post or get"""
 | 
			
		||||
        if self.request.method == "POST":
 | 
			
		||||
            return self.request.POST.get(key, default)
 | 
			
		||||
@ -55,7 +55,7 @@ class OAuth2Client(BaseOAuthClient):
 | 
			
		||||
        """Get client secret"""
 | 
			
		||||
        return self.source.consumer_secret
 | 
			
		||||
 | 
			
		||||
    def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
 | 
			
		||||
    def get_access_token(self, **request_kwargs) -> dict[str, Any] | None:
 | 
			
		||||
        """Fetch access token from callback request."""
 | 
			
		||||
        callback = self.request.build_absolute_uri(self.callback or self.request.path)
 | 
			
		||||
        if not self.check_application_state():
 | 
			
		||||
@ -139,7 +139,7 @@ class OAuth2Client(BaseOAuthClient):
 | 
			
		||||
class UserprofileHeaderAuthClient(OAuth2Client):
 | 
			
		||||
    """OAuth client which only sends authentication via header, not querystring"""
 | 
			
		||||
 | 
			
		||||
    def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
 | 
			
		||||
    def get_profile_info(self, token: dict[str, str]) -> dict[str, Any] | None:
 | 
			
		||||
        "Fetch user profile information."
 | 
			
		||||
        profile_url = self.source.source_type.profile_url or ""
 | 
			
		||||
        if self.source.source_type.urls_customizable and self.source.profile_url:
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,6 @@
 | 
			
		||||
"""OAuth Client models"""
 | 
			
		||||
 | 
			
		||||
from typing import TYPE_CHECKING, Optional
 | 
			
		||||
from typing import TYPE_CHECKING
 | 
			
		||||
 | 
			
		||||
from django.db import models
 | 
			
		||||
from django.http.request import HttpRequest
 | 
			
		||||
@ -84,7 +84,7 @@ class OAuthSource(Source):
 | 
			
		||||
            icon_url=icon,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def ui_user_settings(self) -> Optional[UserSettingSerializer]:
 | 
			
		||||
    def ui_user_settings(self) -> UserSettingSerializer | None:
 | 
			
		||||
        provider_type = self.source_type
 | 
			
		||||
        icon = self.icon_url
 | 
			
		||||
        if not icon:
 | 
			
		||||
 | 
			
		||||
@ -35,7 +35,7 @@ class TestOAuthSourceTasks(TestCase):
 | 
			
		||||
            },
 | 
			
		||||
        )
 | 
			
		||||
        mock.get("http://foo/jwks", json={"foo": "bar"})
 | 
			
		||||
        update_well_known_jwks()  # pylint: disable=no-value-for-parameter
 | 
			
		||||
        update_well_known_jwks()
 | 
			
		||||
        self.source.refresh_from_db()
 | 
			
		||||
        self.assertEqual(self.source.authorization_url, "foo")
 | 
			
		||||
        self.assertEqual(self.source.access_token_url, "foo")
 | 
			
		||||
 | 
			
		||||
@ -1,7 +1,7 @@
 | 
			
		||||
"""Apple OAuth Views"""
 | 
			
		||||
 | 
			
		||||
from time import time
 | 
			
		||||
from typing import Any, Optional
 | 
			
		||||
from typing import Any
 | 
			
		||||
 | 
			
		||||
from django.http.request import HttpRequest
 | 
			
		||||
from django.urls.base import reverse
 | 
			
		||||
@ -17,6 +17,7 @@ from authentik.sources.oauth.views.callback import OAuthCallback
 | 
			
		||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
 | 
			
		||||
 | 
			
		||||
LOGGER = get_logger()
 | 
			
		||||
APPLE_CLIENT_ID_PARTS = 3
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AppleLoginChallenge(Challenge):
 | 
			
		||||
@ -30,7 +31,7 @@ class AppleLoginChallenge(Challenge):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AppleChallengeResponse(ChallengeResponse):
 | 
			
		||||
    """Pseudo class for plex response"""
 | 
			
		||||
    """Pseudo class for apple response"""
 | 
			
		||||
 | 
			
		||||
    component = CharField(default="ak-source-oauth-apple")
 | 
			
		||||
 | 
			
		||||
@ -40,14 +41,14 @@ class AppleOAuthClient(OAuth2Client):
 | 
			
		||||
 | 
			
		||||
    def get_client_id(self) -> str:
 | 
			
		||||
        parts: list[str] = self.source.consumer_key.split(";")
 | 
			
		||||
        if len(parts) < 3:
 | 
			
		||||
        if len(parts) < APPLE_CLIENT_ID_PARTS:
 | 
			
		||||
            return self.source.consumer_key
 | 
			
		||||
        return parts[0].strip()
 | 
			
		||||
 | 
			
		||||
    def get_client_secret(self) -> str:
 | 
			
		||||
        now = time()
 | 
			
		||||
        parts: list[str] = self.source.consumer_key.split(";")
 | 
			
		||||
        if len(parts) < 3:
 | 
			
		||||
        if len(parts) < APPLE_CLIENT_ID_PARTS:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                "Apple Source client_id should be formatted like "
 | 
			
		||||
                "services_id_identifier;apple_team_id;key_id"
 | 
			
		||||
@ -64,7 +65,7 @@ class AppleOAuthClient(OAuth2Client):
 | 
			
		||||
        LOGGER.debug("signing payload as secret key", payload=payload, jwt=jwt)
 | 
			
		||||
        return jwt
 | 
			
		||||
 | 
			
		||||
    def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
 | 
			
		||||
    def get_profile_info(self, token: dict[str, str]) -> dict[str, Any] | None:
 | 
			
		||||
        id_token = token.get("id_token")
 | 
			
		||||
        return decode(id_token, options={"verify_signature": False})
 | 
			
		||||
 | 
			
		||||
@ -86,7 +87,7 @@ class AppleOAuth2Callback(OAuthCallback):
 | 
			
		||||
 | 
			
		||||
    client_class = AppleOAuthClient
 | 
			
		||||
 | 
			
		||||
    def get_user_id(self, info: dict[str, Any]) -> Optional[str]:
 | 
			
		||||
    def get_user_id(self, info: dict[str, Any]) -> str | None:
 | 
			
		||||
        return info["sub"]
 | 
			
		||||
 | 
			
		||||
    def get_user_enroll_context(
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,6 @@
 | 
			
		||||
"""Facebook OAuth Views"""
 | 
			
		||||
 | 
			
		||||
from typing import Any, Optional
 | 
			
		||||
from typing import Any
 | 
			
		||||
 | 
			
		||||
from facebook import GraphAPI
 | 
			
		||||
 | 
			
		||||
@ -22,7 +22,7 @@ class FacebookOAuthRedirect(OAuthRedirect):
 | 
			
		||||
class FacebookOAuth2Client(OAuth2Client):
 | 
			
		||||
    """Facebook OAuth2 Client"""
 | 
			
		||||
 | 
			
		||||
    def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
 | 
			
		||||
    def get_profile_info(self, token: dict[str, str]) -> dict[str, Any] | None:
 | 
			
		||||
        api = GraphAPI(access_token=token["access_token"])
 | 
			
		||||
        return api.get_object("me", fields="id,name,email")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,6 @@
 | 
			
		||||
"""Mailcow OAuth Views"""
 | 
			
		||||
 | 
			
		||||
from typing import Any, Optional
 | 
			
		||||
from typing import Any
 | 
			
		||||
 | 
			
		||||
from requests.exceptions import RequestException
 | 
			
		||||
from structlog.stdlib import get_logger
 | 
			
		||||
@ -25,7 +25,7 @@ class MailcowOAuthRedirect(OAuthRedirect):
 | 
			
		||||
class MailcowOAuth2Client(OAuth2Client):
 | 
			
		||||
    """MailcowOAuth2Client, for some reason, mailcow does not like the default headers"""
 | 
			
		||||
 | 
			
		||||
    def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
 | 
			
		||||
    def get_profile_info(self, token: dict[str, str]) -> dict[str, Any] | None:
 | 
			
		||||
        "Fetch user profile information."
 | 
			
		||||
        profile_url = self.source.source_type.profile_url or ""
 | 
			
		||||
        if self.source.source_type.urls_customizable and self.source.profile_url:
 | 
			
		||||
 | 
			
		||||
@ -1,7 +1,7 @@
 | 
			
		||||
"""Source type manager"""
 | 
			
		||||
 | 
			
		||||
from collections.abc import Callable
 | 
			
		||||
from enum import Enum
 | 
			
		||||
from typing import Callable, Optional, Type
 | 
			
		||||
 | 
			
		||||
from django.http.request import HttpRequest
 | 
			
		||||
from django.templatetags.static import static
 | 
			
		||||
@ -33,12 +33,12 @@ class SourceType:
 | 
			
		||||
 | 
			
		||||
    urls_customizable = False
 | 
			
		||||
 | 
			
		||||
    request_token_url: Optional[str] = None
 | 
			
		||||
    authorization_url: Optional[str] = None
 | 
			
		||||
    access_token_url: Optional[str] = None
 | 
			
		||||
    profile_url: Optional[str] = None
 | 
			
		||||
    oidc_well_known_url: Optional[str] = None
 | 
			
		||||
    oidc_jwks_url: Optional[str] = None
 | 
			
		||||
    request_token_url: str | None = None
 | 
			
		||||
    authorization_url: str | None = None
 | 
			
		||||
    access_token_url: str | None = None
 | 
			
		||||
    profile_url: str | None = None
 | 
			
		||||
    oidc_well_known_url: str | None = None
 | 
			
		||||
    oidc_jwks_url: str | None = None
 | 
			
		||||
 | 
			
		||||
    def icon_url(self) -> str:
 | 
			
		||||
        """Get Icon URL for login"""
 | 
			
		||||
@ -80,7 +80,7 @@ class SourceTypeRegistry:
 | 
			
		||||
        """Get list of tuples of all registered names"""
 | 
			
		||||
        return [(x.name, x.verbose_name) for x in self.__sources]
 | 
			
		||||
 | 
			
		||||
    def find_type(self, type_name: str) -> Type[SourceType]:
 | 
			
		||||
    def find_type(self, type_name: str) -> type[SourceType]:
 | 
			
		||||
        """Find type based on source"""
 | 
			
		||||
        found_type = None
 | 
			
		||||
        for src_type in self.__sources:
 | 
			
		||||
 | 
			
		||||
@ -1,7 +1,7 @@
 | 
			
		||||
"""Twitch OAuth Views"""
 | 
			
		||||
 | 
			
		||||
from json import dumps
 | 
			
		||||
from typing import Any, Optional
 | 
			
		||||
from typing import Any
 | 
			
		||||
 | 
			
		||||
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
 | 
			
		||||
from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback
 | 
			
		||||
@ -12,7 +12,7 @@ from authentik.sources.oauth.views.redirect import OAuthRedirect
 | 
			
		||||
class TwitchClient(UserprofileHeaderAuthClient):
 | 
			
		||||
    """Twitch needs the token_type to be capitalized for the request header."""
 | 
			
		||||
 | 
			
		||||
    def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
 | 
			
		||||
    def get_profile_info(self, token: dict[str, str]) -> dict[str, Any] | None:
 | 
			
		||||
        token["token_type"] = token["token_type"].capitalize()
 | 
			
		||||
        return super().get_profile_info(token)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,6 @@
 | 
			
		||||
"""Twitter OAuth Views"""
 | 
			
		||||
 | 
			
		||||
from typing import Any, Optional
 | 
			
		||||
from typing import Any
 | 
			
		||||
 | 
			
		||||
from authentik.lib.generators import generate_id
 | 
			
		||||
from authentik.sources.oauth.clients.oauth2 import (
 | 
			
		||||
@ -20,7 +20,7 @@ class TwitterClient(UserprofileHeaderAuthClient):
 | 
			
		||||
    # is set via query parameter, so we reuse the azure client
 | 
			
		||||
    # see https://github.com/goauthentik/authentik/issues/1910
 | 
			
		||||
 | 
			
		||||
    def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
 | 
			
		||||
    def get_access_token(self, **request_kwargs) -> dict[str, Any] | None:
 | 
			
		||||
        return super().get_access_token(
 | 
			
		||||
            auth=(
 | 
			
		||||
                self.source.consumer_key,
 | 
			
		||||
 | 
			
		||||
@ -1,7 +1,5 @@
 | 
			
		||||
"""OAuth Base views"""
 | 
			
		||||
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
from django.http.request import HttpRequest
 | 
			
		||||
from structlog.stdlib import get_logger
 | 
			
		||||
 | 
			
		||||
@ -13,18 +11,17 @@ from authentik.sources.oauth.models import OAuthSource
 | 
			
		||||
LOGGER = get_logger()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# pylint: disable=too-few-public-methods
 | 
			
		||||
class OAuthClientMixin:
 | 
			
		||||
    "Mixin for getting OAuth client for a source."
 | 
			
		||||
 | 
			
		||||
    request: HttpRequest  # Set by View class
 | 
			
		||||
 | 
			
		||||
    client_class: Optional[type[BaseOAuthClient]] = None
 | 
			
		||||
    client_class: type[BaseOAuthClient] | None = None
 | 
			
		||||
 | 
			
		||||
    def get_client(self, source: OAuthSource, **kwargs) -> BaseOAuthClient:
 | 
			
		||||
        "Get instance of the OAuth client for this source."
 | 
			
		||||
        if self.client_class is not None:
 | 
			
		||||
            # pylint: disable=not-callable
 | 
			
		||||
 | 
			
		||||
            return self.client_class(source, self.request, **kwargs)
 | 
			
		||||
        if source.source_type.request_token_url or source.request_token_url:
 | 
			
		||||
            client = OAuthClient(source, self.request, **kwargs)
 | 
			
		||||
 | 
			
		||||
@ -1,7 +1,7 @@
 | 
			
		||||
"""OAuth Callback Views"""
 | 
			
		||||
 | 
			
		||||
from json import JSONDecodeError
 | 
			
		||||
from typing import Any, Optional
 | 
			
		||||
from typing import Any
 | 
			
		||||
 | 
			
		||||
from django.conf import settings
 | 
			
		||||
from django.contrib import messages
 | 
			
		||||
@ -23,16 +23,15 @@ class OAuthCallback(OAuthClientMixin, View):
 | 
			
		||||
    "Base OAuth callback view."
 | 
			
		||||
 | 
			
		||||
    source: OAuthSource
 | 
			
		||||
    token: Optional[dict] = None
 | 
			
		||||
    token: dict | None = None
 | 
			
		||||
 | 
			
		||||
    # pylint: disable=too-many-return-statements
 | 
			
		||||
    def dispatch(self, request: HttpRequest, *_, **kwargs) -> HttpResponse:
 | 
			
		||||
        """View Get handler"""
 | 
			
		||||
        slug = kwargs.get("source_slug", "")
 | 
			
		||||
        try:
 | 
			
		||||
            self.source = OAuthSource.objects.get(slug=slug)
 | 
			
		||||
        except OAuthSource.DoesNotExist:
 | 
			
		||||
            raise Http404(f"Unknown OAuth source '{slug}'.")
 | 
			
		||||
            raise Http404(f"Unknown OAuth source '{slug}'.") from None
 | 
			
		||||
 | 
			
		||||
        if not self.source.enabled:
 | 
			
		||||
            raise Http404(f"Source {slug} is not enabled.")
 | 
			
		||||
@ -86,7 +85,7 @@ class OAuthCallback(OAuthClientMixin, View):
 | 
			
		||||
        """Create a dict of User data"""
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
    def get_user_id(self, info: dict[str, Any]) -> Optional[str]:
 | 
			
		||||
    def get_user_id(self, info: dict[str, Any]) -> str | None:
 | 
			
		||||
        """Return unique identifier from the profile info."""
 | 
			
		||||
        if "id" in info:
 | 
			
		||||
            return info["id"]
 | 
			
		||||
@ -98,10 +97,11 @@ class OAuthCallback(OAuthClientMixin, View):
 | 
			
		||||
        messages.error(
 | 
			
		||||
            self.request,
 | 
			
		||||
            _(
 | 
			
		||||
                "Authentication failed: %(reason)s"
 | 
			
		||||
                % {
 | 
			
		||||
                    "reason": reason,
 | 
			
		||||
                }
 | 
			
		||||
                "Authentication failed: {reason}".format_map(
 | 
			
		||||
                    {
 | 
			
		||||
                        "reason": reason,
 | 
			
		||||
                    }
 | 
			
		||||
                )
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
        return redirect(self.get_error_redirect(self.source, reason))
 | 
			
		||||
@ -115,7 +115,7 @@ class OAuthSourceFlowManager(SourceFlowManager):
 | 
			
		||||
    def update_connection(
 | 
			
		||||
        self,
 | 
			
		||||
        connection: UserOAuthSourceConnection,
 | 
			
		||||
        access_token: Optional[str] = None,
 | 
			
		||||
        access_token: str | None = None,
 | 
			
		||||
    ) -> UserOAuthSourceConnection:
 | 
			
		||||
        """Set the access_token on the connection"""
 | 
			
		||||
        connection.access_token = access_token
 | 
			
		||||
 | 
			
		||||
@ -36,7 +36,7 @@ class OAuthRedirect(OAuthClientMixin, RedirectView):
 | 
			
		||||
        try:
 | 
			
		||||
            source: OAuthSource = OAuthSource.objects.get(slug=slug)
 | 
			
		||||
        except OAuthSource.DoesNotExist:
 | 
			
		||||
            raise Http404(f"Unknown OAuth source '{slug}'.")
 | 
			
		||||
            raise Http404(f"Unknown OAuth source '{slug}'.") from None
 | 
			
		||||
        if not source.enabled:
 | 
			
		||||
            raise Http404(f"source {slug} is not enabled.")
 | 
			
		||||
        client = self.get_client(source, callback=self.get_callback_url(source))
 | 
			
		||||
 | 
			
		||||
@ -1,7 +1,5 @@
 | 
			
		||||
"""Plex source"""
 | 
			
		||||
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
from django.contrib.postgres.fields import ArrayField
 | 
			
		||||
from django.db import models
 | 
			
		||||
from django.http.request import HttpRequest
 | 
			
		||||
@ -79,7 +77,7 @@ class PlexSource(Source):
 | 
			
		||||
            name=self.name,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def ui_user_settings(self) -> Optional[UserSettingSerializer]:
 | 
			
		||||
    def ui_user_settings(self) -> UserSettingSerializer | None:
 | 
			
		||||
        icon = self.icon_url
 | 
			
		||||
        if not icon:
 | 
			
		||||
            icon = static("authentik/sources/plex.svg")
 | 
			
		||||
 | 
			
		||||
@ -85,7 +85,7 @@ class PlexAuth:
 | 
			
		||||
            resources = self.get_resources()
 | 
			
		||||
        except RequestException as exc:
 | 
			
		||||
            LOGGER.warning("Unable to fetch user resources", exc=exc)
 | 
			
		||||
            raise Http404
 | 
			
		||||
            raise Http404 from None
 | 
			
		||||
        for resource in resources:
 | 
			
		||||
            if resource["provides"] != "server":
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
@ -1,7 +1,5 @@
 | 
			
		||||
"""saml sp models"""
 | 
			
		||||
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
from django.db import models
 | 
			
		||||
from django.http import HttpRequest
 | 
			
		||||
from django.templatetags.static import static
 | 
			
		||||
@ -204,7 +202,7 @@ class SAMLSource(Source):
 | 
			
		||||
            icon_url=self.icon_url,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def ui_user_settings(self) -> Optional[UserSettingSerializer]:
 | 
			
		||||
    def ui_user_settings(self) -> UserSettingSerializer | None:
 | 
			
		||||
        icon = self.icon_url
 | 
			
		||||
        if not icon:
 | 
			
		||||
            icon = static(f"authentik/sources/{self.slug}.svg")
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,7 @@
 | 
			
		||||
"""SAML Service Provider Metadata Processor"""
 | 
			
		||||
 | 
			
		||||
from typing import Iterator, Optional
 | 
			
		||||
from collections.abc import Iterator
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
from django.http import HttpRequest
 | 
			
		||||
from lxml.etree import Element, SubElement, tostring  # nosec
 | 
			
		||||
@ -30,7 +31,8 @@ class MetadataProcessor:
 | 
			
		||||
        self.source = source
 | 
			
		||||
        self.http_request = request
 | 
			
		||||
 | 
			
		||||
    def get_signing_key_descriptor(self) -> Optional[Element]:
 | 
			
		||||
    # Using type unions doesn't work with cython types (which is what lxml is)
 | 
			
		||||
    def get_signing_key_descriptor(self) -> Optional[Element]:  # noqa: UP007
 | 
			
		||||
        """Get Signing KeyDescriptor, if enabled for the source"""
 | 
			
		||||
        if self.source.signing_kp:
 | 
			
		||||
            key_descriptor = Element(f"{{{NS_SAML_METADATA}}}KeyDescriptor")
 | 
			
		||||
 | 
			
		||||
@ -88,7 +88,7 @@ class InitiateView(View):
 | 
			
		||||
        try:
 | 
			
		||||
            plan = planner.plan(self.request, kwargs)
 | 
			
		||||
        except FlowNonApplicableException:
 | 
			
		||||
            raise Http404
 | 
			
		||||
            raise Http404 from None
 | 
			
		||||
        for stage in stages_to_append:
 | 
			
		||||
            plan.append_stage(stage)
 | 
			
		||||
        self.request.session[SESSION_KEY_PLAN] = plan
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user