core: optimise user list endpoint (#8353)
* unrelated changes Signed-off-by: Jens Langhammer <jens@goauthentik.io> * optimization pass 1: reduce N tenant lookups by taking tenant from request, reduce get_anonymous calls Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix lint Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix Signed-off-by: Jens Langhammer <jens@goauthentik.io> * make it easier to exclude anonymous user Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix? Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
		| @ -7,7 +7,6 @@ from django.contrib.auth import get_user_model | |||||||
| from django.db.models import Model, Q, QuerySet | from django.db.models import Model, Q, QuerySet | ||||||
| from django.utils.timezone import now | from django.utils.timezone import now | ||||||
| from django.utils.translation import gettext as _ | from django.utils.translation import gettext as _ | ||||||
| from guardian.shortcuts import get_anonymous_user |  | ||||||
| from yaml import dump | from yaml import dump | ||||||
|  |  | ||||||
| from authentik.blueprints.v1.common import ( | from authentik.blueprints.v1.common import ( | ||||||
| @ -48,7 +47,7 @@ class Exporter: | |||||||
|         """Return a queryset for `model`. Can be used to filter some |         """Return a queryset for `model`. Can be used to filter some | ||||||
|         objects on some models""" |         objects on some models""" | ||||||
|         if model == get_user_model(): |         if model == get_user_model(): | ||||||
|             return model.objects.exclude(pk=get_anonymous_user().pk) |             return model.objects.exclude_anonymous() | ||||||
|         return model.objects.all() |         return model.objects.all() | ||||||
|  |  | ||||||
|     def _pre_export(self, blueprint: Blueprint): |     def _pre_export(self, blueprint: Blueprint): | ||||||
|  | |||||||
| @ -8,7 +8,6 @@ from sentry_sdk.hub import Hub | |||||||
|  |  | ||||||
| from authentik import get_full_version | from authentik import get_full_version | ||||||
| from authentik.brands.models import Brand | from authentik.brands.models import Brand | ||||||
| from authentik.tenants.utils import get_current_tenant |  | ||||||
|  |  | ||||||
| _q_default = Q(default=True) | _q_default = Q(default=True) | ||||||
| DEFAULT_BRAND = Brand(domain="fallback") | DEFAULT_BRAND = Brand(domain="fallback") | ||||||
| @ -36,7 +35,7 @@ def context_processor(request: HttpRequest) -> dict[str, Any]: | |||||||
|         trace = span.to_traceparent() |         trace = span.to_traceparent() | ||||||
|     return { |     return { | ||||||
|         "brand": brand, |         "brand": brand, | ||||||
|         "footer_links": get_current_tenant().footer_links, |         "footer_links": request.tenant.footer_links, | ||||||
|         "sentry_trace": trace, |         "sentry_trace": trace, | ||||||
|         "version": get_full_version(), |         "version": get_full_version(), | ||||||
|     } |     } | ||||||
|  | |||||||
| @ -5,12 +5,12 @@ from typing import Optional | |||||||
| from django.core.cache import cache | from django.core.cache import cache | ||||||
| from django.db.models import QuerySet | from django.db.models import QuerySet | ||||||
| from django.db.models.functions import ExtractHour | from django.db.models.functions import ExtractHour | ||||||
| from django.http.response import HttpResponseBadRequest |  | ||||||
| from django.shortcuts import get_object_or_404 | from django.shortcuts import get_object_or_404 | ||||||
| from drf_spectacular.types import OpenApiTypes | from drf_spectacular.types import OpenApiTypes | ||||||
| from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema | from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema | ||||||
| from guardian.shortcuts import get_objects_for_user | from guardian.shortcuts import get_objects_for_user | ||||||
| from rest_framework.decorators import action | from rest_framework.decorators import action | ||||||
|  | from rest_framework.exceptions import ValidationError | ||||||
| from rest_framework.fields import CharField, ReadOnlyField, SerializerMethodField | from rest_framework.fields import CharField, ReadOnlyField, SerializerMethodField | ||||||
| from rest_framework.parsers import MultiPartParser | from rest_framework.parsers import MultiPartParser | ||||||
| from rest_framework.request import Request | from rest_framework.request import Request | ||||||
| @ -147,7 +147,6 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet): | |||||||
|         ], |         ], | ||||||
|         responses={ |         responses={ | ||||||
|             200: PolicyTestResultSerializer(), |             200: PolicyTestResultSerializer(), | ||||||
|             404: OpenApiResponse(description="for_user user not found"), |  | ||||||
|         }, |         }, | ||||||
|     ) |     ) | ||||||
|     @action(detail=True, methods=["GET"]) |     @action(detail=True, methods=["GET"]) | ||||||
| @ -160,9 +159,11 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet): | |||||||
|         for_user = request.user |         for_user = request.user | ||||||
|         if request.user.is_superuser and "for_user" in request.query_params: |         if request.user.is_superuser and "for_user" in request.query_params: | ||||||
|             try: |             try: | ||||||
|                 for_user = get_object_or_404(User, pk=request.query_params.get("for_user")) |                 for_user = User.objects.filter(pk=request.query_params.get("for_user")).first() | ||||||
|             except ValueError: |             except ValueError: | ||||||
|                 return HttpResponseBadRequest("for_user must be numerical") |                 raise ValidationError({"for_user": "for_user must be numerical"}) | ||||||
|  |             if not for_user: | ||||||
|  |                 raise ValidationError({"for_user": "User not found"}) | ||||||
|         engine = PolicyEngine(application, for_user, request) |         engine = PolicyEngine(application, for_user, request) | ||||||
|         engine.use_cache = False |         engine.use_cache = False | ||||||
|         with capture_logs() as logs: |         with capture_logs() as logs: | ||||||
|  | |||||||
| @ -30,7 +30,7 @@ from drf_spectacular.utils import ( | |||||||
|     extend_schema_field, |     extend_schema_field, | ||||||
|     inline_serializer, |     inline_serializer, | ||||||
| ) | ) | ||||||
| from guardian.shortcuts import get_anonymous_user, get_objects_for_user | from guardian.shortcuts import get_objects_for_user | ||||||
| from rest_framework.decorators import action | from rest_framework.decorators import action | ||||||
| from rest_framework.fields import CharField, IntegerField, ListField, SerializerMethodField | from rest_framework.fields import CharField, IntegerField, ListField, SerializerMethodField | ||||||
| from rest_framework.request import Request | from rest_framework.request import Request | ||||||
| @ -72,6 +72,7 @@ from authentik.flows.exceptions import FlowNonApplicableException | |||||||
| from authentik.flows.models import FlowToken | from authentik.flows.models import FlowToken | ||||||
| from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlanner | from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlanner | ||||||
| from authentik.flows.views.executor import QS_KEY_TOKEN | from authentik.flows.views.executor import QS_KEY_TOKEN | ||||||
|  | from authentik.lib.avatars import get_avatar | ||||||
| from authentik.stages.email.models import EmailStage | from authentik.stages.email.models import EmailStage | ||||||
| from authentik.stages.email.tasks import send_mails | from authentik.stages.email.tasks import send_mails | ||||||
| from authentik.stages.email.utils import TemplateEmailMessage | from authentik.stages.email.utils import TemplateEmailMessage | ||||||
| @ -102,14 +103,21 @@ class UserSerializer(ModelSerializer): | |||||||
|     """User Serializer""" |     """User Serializer""" | ||||||
|  |  | ||||||
|     is_superuser = BooleanField(read_only=True) |     is_superuser = BooleanField(read_only=True) | ||||||
|     avatar = CharField(read_only=True) |     avatar = SerializerMethodField() | ||||||
|     attributes = JSONDictField(required=False) |     attributes = JSONDictField(required=False) | ||||||
|     groups = PrimaryKeyRelatedField( |     groups = PrimaryKeyRelatedField( | ||||||
|         allow_empty=True, many=True, source="ak_groups", queryset=Group.objects.all(), default=list |         allow_empty=True, | ||||||
|  |         many=True, | ||||||
|  |         source="ak_groups", | ||||||
|  |         queryset=Group.objects.all().order_by("name"), | ||||||
|  |         default=list, | ||||||
|     ) |     ) | ||||||
|     groups_obj = ListSerializer(child=UserGroupSerializer(), read_only=True, source="ak_groups") |     groups_obj = ListSerializer(child=UserGroupSerializer(), read_only=True, source="ak_groups") | ||||||
|     uid = CharField(read_only=True) |     uid = CharField(read_only=True) | ||||||
|     username = CharField(max_length=150, validators=[UniqueValidator(queryset=User.objects.all())]) |     username = CharField( | ||||||
|  |         max_length=150, | ||||||
|  |         validators=[UniqueValidator(queryset=User.objects.all().order_by("username"))], | ||||||
|  |     ) | ||||||
|  |  | ||||||
|     def __init__(self, *args, **kwargs): |     def __init__(self, *args, **kwargs): | ||||||
|         super().__init__(*args, **kwargs) |         super().__init__(*args, **kwargs) | ||||||
| @ -143,6 +151,10 @@ class UserSerializer(ModelSerializer): | |||||||
|             instance.set_unusable_password() |             instance.set_unusable_password() | ||||||
|             instance.save() |             instance.save() | ||||||
|  |  | ||||||
|  |     def get_avatar(self, user: User) -> str: | ||||||
|  |         """User's avatar, either a http/https URL or a data URI""" | ||||||
|  |         return get_avatar(user, self.context["request"]) | ||||||
|  |  | ||||||
|     def validate_path(self, path: str) -> str: |     def validate_path(self, path: str) -> str: | ||||||
|         """Validate path""" |         """Validate path""" | ||||||
|         if path[:1] == "/" or path[-1] == "/": |         if path[:1] == "/" or path[-1] == "/": | ||||||
| @ -197,12 +209,16 @@ class UserSelfSerializer(ModelSerializer): | |||||||
|     """User Serializer for information a user can retrieve about themselves""" |     """User Serializer for information a user can retrieve about themselves""" | ||||||
|  |  | ||||||
|     is_superuser = BooleanField(read_only=True) |     is_superuser = BooleanField(read_only=True) | ||||||
|     avatar = CharField(read_only=True) |     avatar = SerializerMethodField() | ||||||
|     groups = SerializerMethodField() |     groups = SerializerMethodField() | ||||||
|     uid = CharField(read_only=True) |     uid = CharField(read_only=True) | ||||||
|     settings = SerializerMethodField() |     settings = SerializerMethodField() | ||||||
|     system_permissions = SerializerMethodField() |     system_permissions = SerializerMethodField() | ||||||
|  |  | ||||||
|  |     def get_avatar(self, user: User) -> str: | ||||||
|  |         """User's avatar, either a http/https URL or a data URI""" | ||||||
|  |         return get_avatar(user, self.context["request"]) | ||||||
|  |  | ||||||
|     @extend_schema_field( |     @extend_schema_field( | ||||||
|         ListSerializer( |         ListSerializer( | ||||||
|             child=inline_serializer( |             child=inline_serializer( | ||||||
| @ -329,11 +345,11 @@ class UsersFilter(FilterSet): | |||||||
|     groups_by_name = ModelMultipleChoiceFilter( |     groups_by_name = ModelMultipleChoiceFilter( | ||||||
|         field_name="ak_groups__name", |         field_name="ak_groups__name", | ||||||
|         to_field_name="name", |         to_field_name="name", | ||||||
|         queryset=Group.objects.all(), |         queryset=Group.objects.all().order_by("name"), | ||||||
|     ) |     ) | ||||||
|     groups_by_pk = ModelMultipleChoiceFilter( |     groups_by_pk = ModelMultipleChoiceFilter( | ||||||
|         field_name="ak_groups", |         field_name="ak_groups", | ||||||
|         queryset=Group.objects.all(), |         queryset=Group.objects.all().order_by("name"), | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     def filter_attributes(self, queryset, name, value): |     def filter_attributes(self, queryset, name, value): | ||||||
| @ -378,7 +394,7 @@ class UserViewSet(UsedByMixin, ModelViewSet): | |||||||
|     filterset_class = UsersFilter |     filterset_class = UsersFilter | ||||||
|  |  | ||||||
|     def get_queryset(self):  # pragma: no cover |     def get_queryset(self):  # pragma: no cover | ||||||
|         return User.objects.all().exclude(pk=get_anonymous_user().pk) |         return User.objects.all().exclude_anonymous().prefetch_related("ak_groups") | ||||||
|  |  | ||||||
|     def _create_recovery_link(self) -> tuple[Optional[str], Optional[Token]]: |     def _create_recovery_link(self) -> tuple[Optional[str], Optional[Token]]: | ||||||
|         """Create a recovery link (when the current brand has a recovery flow set), |         """Create a recovery link (when the current brand has a recovery flow set), | ||||||
|  | |||||||
| @ -14,6 +14,7 @@ from django.http import HttpRequest | |||||||
| from django.utils.functional import SimpleLazyObject, cached_property | from django.utils.functional import SimpleLazyObject, cached_property | ||||||
| from django.utils.timezone import now | from django.utils.timezone import now | ||||||
| from django.utils.translation import gettext_lazy as _ | from django.utils.translation import gettext_lazy as _ | ||||||
|  | from guardian.conf import settings | ||||||
| from guardian.mixins import GuardianUserMixin | from guardian.mixins import GuardianUserMixin | ||||||
| from model_utils.managers import InheritanceManager | from model_utils.managers import InheritanceManager | ||||||
| from rest_framework.serializers import Serializer | from rest_framework.serializers import Serializer | ||||||
| @ -169,13 +170,29 @@ class Group(SerializerModel): | |||||||
|         verbose_name_plural = _("Groups") |         verbose_name_plural = _("Groups") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class UserQuerySet(models.QuerySet): | ||||||
|  |     """User queryset""" | ||||||
|  |  | ||||||
|  |     def exclude_anonymous(self): | ||||||
|  |         """Exclude anonymous user""" | ||||||
|  |         return self.exclude(**{User.USERNAME_FIELD: settings.ANONYMOUS_USER_NAME}) | ||||||
|  |  | ||||||
|  |  | ||||||
| class UserManager(DjangoUserManager): | class UserManager(DjangoUserManager): | ||||||
|     """User manager that doesn't assign is_superuser and is_staff""" |     """User manager that doesn't assign is_superuser and is_staff""" | ||||||
|  |  | ||||||
|  |     def get_queryset(self): | ||||||
|  |         """Create special user queryset""" | ||||||
|  |         return UserQuerySet(self.model, using=self._db) | ||||||
|  |  | ||||||
|     def create_user(self, username, email=None, password=None, **extra_fields): |     def create_user(self, username, email=None, password=None, **extra_fields): | ||||||
|         """User manager that doesn't assign is_superuser and is_staff""" |         """User manager that doesn't assign is_superuser and is_staff""" | ||||||
|         return self._create_user(username, email, password, **extra_fields) |         return self._create_user(username, email, password, **extra_fields) | ||||||
|  |  | ||||||
|  |     def exclude_anonymous(self) -> QuerySet: | ||||||
|  |         """Exclude anonymous user""" | ||||||
|  |         return self.get_queryset().exclude_anonymous() | ||||||
|  |  | ||||||
|  |  | ||||||
| class User(SerializerModel, GuardianUserMixin, AbstractUser): | class User(SerializerModel, GuardianUserMixin, AbstractUser): | ||||||
|     """authentik User model, based on django's contrib auth user model.""" |     """authentik User model, based on django's contrib auth user model.""" | ||||||
|  | |||||||
| @ -16,7 +16,6 @@ from django.db import models | |||||||
| from django.db.models.query import QuerySet | from django.db.models.query import QuerySet | ||||||
| from django.utils.timezone import now | from django.utils.timezone import now | ||||||
| from django.utils.translation import gettext as _ | from django.utils.translation import gettext as _ | ||||||
| from guardian.shortcuts import get_anonymous_user |  | ||||||
| from jwt import PyJWTError, decode, get_unverified_header | from jwt import PyJWTError, decode, get_unverified_header | ||||||
| from rest_framework.exceptions import ValidationError | from rest_framework.exceptions import ValidationError | ||||||
| from rest_framework.serializers import BaseSerializer | from rest_framework.serializers import BaseSerializer | ||||||
| @ -104,7 +103,7 @@ class LicenseKey: | |||||||
|     @staticmethod |     @staticmethod | ||||||
|     def base_user_qs() -> QuerySet: |     def base_user_qs() -> QuerySet: | ||||||
|         """Base query set for all users""" |         """Base query set for all users""" | ||||||
|         return User.objects.all().exclude(is_active=False).exclude(pk=get_anonymous_user().pk) |         return User.objects.all().exclude_anonymous().exclude(is_active=False) | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def get_default_user_count(): |     def get_default_user_count(): | ||||||
|  | |||||||
| @ -26,7 +26,7 @@ from authentik.flows.challenge import ( | |||||||
| from authentik.flows.exceptions import StageInvalidException | from authentik.flows.exceptions import StageInvalidException | ||||||
| from authentik.flows.models import InvalidResponseAction | from authentik.flows.models import InvalidResponseAction | ||||||
| from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, PLAN_CONTEXT_PENDING_USER | from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, PLAN_CONTEXT_PENDING_USER | ||||||
| from authentik.lib.avatars import DEFAULT_AVATAR | from authentik.lib.avatars import DEFAULT_AVATAR, get_avatar | ||||||
| from authentik.lib.utils.reflection import class_to_path | from authentik.lib.utils.reflection import class_to_path | ||||||
|  |  | ||||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||||
| @ -197,7 +197,7 @@ class ChallengeStageView(StageView): | |||||||
|                     challenge.initial_data["pending_user"] = user.username |                     challenge.initial_data["pending_user"] = user.username | ||||||
|                 challenge.initial_data["pending_user_avatar"] = DEFAULT_AVATAR |                 challenge.initial_data["pending_user_avatar"] = DEFAULT_AVATAR | ||||||
|                 if not isinstance(user, AnonymousUser): |                 if not isinstance(user, AnonymousUser): | ||||||
|                     challenge.initial_data["pending_user_avatar"] = user.avatar |                     challenge.initial_data["pending_user_avatar"] = get_avatar(user, self.request) | ||||||
|         return challenge |         return challenge | ||||||
|  |  | ||||||
|     def get_challenge(self, *args, **kwargs) -> Challenge: |     def get_challenge(self, *args, **kwargs) -> Challenge: | ||||||
|  | |||||||
| @ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Optional | |||||||
| from urllib.parse import urlencode | from urllib.parse import urlencode | ||||||
|  |  | ||||||
| from django.core.cache import cache | from django.core.cache import cache | ||||||
|  | from django.http import HttpRequest | ||||||
| from django.templatetags.static import static | from django.templatetags.static import static | ||||||
| from lxml import etree  # nosec | from lxml import etree  # nosec | ||||||
| from lxml.etree import Element, SubElement  # nosec | from lxml.etree import Element, SubElement  # nosec | ||||||
| @ -15,13 +16,13 @@ from authentik.lib.config import get_path_from_dict | |||||||
| from authentik.lib.utils.http import get_http_session | from authentik.lib.utils.http import get_http_session | ||||||
| from authentik.tenants.utils import get_current_tenant | from authentik.tenants.utils import get_current_tenant | ||||||
|  |  | ||||||
|  | if TYPE_CHECKING: | ||||||
|  |     from authentik.core.models import User | ||||||
|  |  | ||||||
| GRAVATAR_URL = "https://secure.gravatar.com" | GRAVATAR_URL = "https://secure.gravatar.com" | ||||||
| DEFAULT_AVATAR = static("dist/assets/images/user_default.png") | DEFAULT_AVATAR = static("dist/assets/images/user_default.png") | ||||||
| CACHE_KEY_GRAVATAR = "goauthentik.io/lib/avatars/" | CACHE_KEY_GRAVATAR = "goauthentik.io/lib/avatars/" | ||||||
|  |  | ||||||
| if TYPE_CHECKING: |  | ||||||
|     from authentik.core.models import User |  | ||||||
|  |  | ||||||
| SVG_XML_NS = "http://www.w3.org/2000/svg" | SVG_XML_NS = "http://www.w3.org/2000/svg" | ||||||
| SVG_NS_MAP = {None: SVG_XML_NS} | SVG_NS_MAP = {None: SVG_XML_NS} | ||||||
| # Match fonts used in web UI | # Match fonts used in web UI | ||||||
| @ -177,14 +178,19 @@ def avatar_mode_url(user: "User", mode: str) -> Optional[str]: | |||||||
|     } |     } | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_avatar(user: "User") -> str: | def get_avatar(user: "User", request: Optional[HttpRequest] = None) -> str: | ||||||
|     """Get avatar with configured mode""" |     """Get avatar with configured mode""" | ||||||
|     mode_map = { |     mode_map = { | ||||||
|         "none": avatar_mode_none, |         "none": avatar_mode_none, | ||||||
|         "initials": avatar_mode_generated, |         "initials": avatar_mode_generated, | ||||||
|         "gravatar": avatar_mode_gravatar, |         "gravatar": avatar_mode_gravatar, | ||||||
|     } |     } | ||||||
|     modes: str = get_current_tenant().avatars |     tenant = None | ||||||
|  |     if request: | ||||||
|  |         tenant = request.tenant | ||||||
|  |     else: | ||||||
|  |         tenant = get_current_tenant() | ||||||
|  |     modes: str = tenant.avatars | ||||||
|     for mode in modes.split(","): |     for mode in modes.split(","): | ||||||
|         avatar = None |         avatar = None | ||||||
|         if mode in mode_map: |         if mode in mode_map: | ||||||
|  | |||||||
| @ -3,7 +3,6 @@ from django.core.cache import cache | |||||||
| from django.db import models | from django.db import models | ||||||
| from django.db.models import QuerySet | from django.db.models import QuerySet | ||||||
| from django.utils.translation import gettext_lazy as _ | from django.utils.translation import gettext_lazy as _ | ||||||
| from guardian.shortcuts import get_anonymous_user |  | ||||||
| from redis.lock import Lock | from redis.lock import Lock | ||||||
| from rest_framework.serializers import Serializer | from rest_framework.serializers import Serializer | ||||||
|  |  | ||||||
| @ -42,7 +41,7 @@ class SCIMProvider(BackchannelProvider): | |||||||
|     def get_user_qs(self) -> QuerySet[User]: |     def get_user_qs(self) -> QuerySet[User]: | ||||||
|         """Get queryset of all users with consistent ordering |         """Get queryset of all users with consistent ordering | ||||||
|         according to the provider's settings""" |         according to the provider's settings""" | ||||||
|         base = User.objects.all().exclude(pk=get_anonymous_user().pk) |         base = User.objects.all().exclude_anonymous() | ||||||
|         if self.exclude_users_service_account: |         if self.exclude_users_service_account: | ||||||
|             base = base.exclude(type=UserTypes.SERVICE_ACCOUNT).exclude( |             base = base.exclude(type=UserTypes.SERVICE_ACCOUNT).exclude( | ||||||
|                 type=UserTypes.INTERNAL_SERVICE_ACCOUNT |                 type=UserTypes.INTERNAL_SERVICE_ACCOUNT | ||||||
|  | |||||||
| @ -2,7 +2,6 @@ | |||||||
| from json import loads | from json import loads | ||||||
|  |  | ||||||
| from django.test import TestCase | from django.test import TestCase | ||||||
| from guardian.shortcuts import get_anonymous_user |  | ||||||
| from jsonschema import validate | from jsonschema import validate | ||||||
| from requests_mock import Mocker | from requests_mock import Mocker | ||||||
|  |  | ||||||
| @ -19,7 +18,7 @@ class SCIMGroupTests(TestCase): | |||||||
|     def setUp(self) -> None: |     def setUp(self) -> None: | ||||||
|         # Delete all users and groups as the mocked HTTP responses only return one ID |         # Delete all users and groups as the mocked HTTP responses only return one ID | ||||||
|         # which will cause errors with multiple users |         # which will cause errors with multiple users | ||||||
|         User.objects.all().exclude(pk=get_anonymous_user().pk).delete() |         User.objects.all().exclude_anonymous().delete() | ||||||
|         Group.objects.all().delete() |         Group.objects.all().delete() | ||||||
|         self.provider: SCIMProvider = SCIMProvider.objects.create( |         self.provider: SCIMProvider = SCIMProvider.objects.create( | ||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|  | |||||||
| @ -1,6 +1,5 @@ | |||||||
| """SCIM Membership tests""" | """SCIM Membership tests""" | ||||||
| from django.test import TestCase | from django.test import TestCase | ||||||
| from guardian.shortcuts import get_anonymous_user |  | ||||||
| from requests_mock import Mocker | from requests_mock import Mocker | ||||||
|  |  | ||||||
| from authentik.blueprints.tests import apply_blueprint | from authentik.blueprints.tests import apply_blueprint | ||||||
| @ -21,7 +20,7 @@ class SCIMMembershipTests(TestCase): | |||||||
|     def setUp(self) -> None: |     def setUp(self) -> None: | ||||||
|         # Delete all users and groups as the mocked HTTP responses only return one ID |         # Delete all users and groups as the mocked HTTP responses only return one ID | ||||||
|         # which will cause errors with multiple users |         # which will cause errors with multiple users | ||||||
|         User.objects.all().exclude(pk=get_anonymous_user().pk).delete() |         User.objects.all().exclude_anonymous().delete() | ||||||
|         Group.objects.all().delete() |         Group.objects.all().delete() | ||||||
|         Tenant.objects.update(avatars="none") |         Tenant.objects.update(avatars="none") | ||||||
|  |  | ||||||
|  | |||||||
| @ -2,7 +2,6 @@ | |||||||
| from json import loads | from json import loads | ||||||
|  |  | ||||||
| from django.test import TestCase | from django.test import TestCase | ||||||
| from guardian.shortcuts import get_anonymous_user |  | ||||||
| from jsonschema import validate | from jsonschema import validate | ||||||
| from requests_mock import Mocker | from requests_mock import Mocker | ||||||
|  |  | ||||||
| @ -22,7 +21,7 @@ class SCIMUserTests(TestCase): | |||||||
|         # Delete all users and groups as the mocked HTTP responses only return one ID |         # Delete all users and groups as the mocked HTTP responses only return one ID | ||||||
|         # which will cause errors with multiple users |         # which will cause errors with multiple users | ||||||
|         Tenant.objects.update(avatars="none") |         Tenant.objects.update(avatars="none") | ||||||
|         User.objects.all().exclude(pk=get_anonymous_user().pk).delete() |         User.objects.all().exclude_anonymous().delete() | ||||||
|         Group.objects.all().delete() |         Group.objects.all().delete() | ||||||
|         self.provider: SCIMProvider = SCIMProvider.objects.create( |         self.provider: SCIMProvider = SCIMProvider.objects.create( | ||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|  | |||||||
| @ -54,7 +54,7 @@ func (cs *CryptoStore) getFingerprint(uuid string) string { | |||||||
| func (cs *CryptoStore) Fetch(uuid string) error { | func (cs *CryptoStore) Fetch(uuid string) error { | ||||||
| 	cfp := cs.getFingerprint(uuid) | 	cfp := cs.getFingerprint(uuid) | ||||||
| 	if cfp == cs.fingerprints[uuid] { | 	if cfp == cs.fingerprints[uuid] { | ||||||
| 		cs.log.WithField("uuid", uuid).Info("Fingerprint hasn't changed, not fetching cert") | 		cs.log.WithField("uuid", uuid).Debug("Fingerprint hasn't changed, not fetching cert") | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 	cs.log.WithField("uuid", uuid).Info("Fetching certificate and private key") | 	cs.log.WithField("uuid", uuid).Info("Fetching certificate and private key") | ||||||
|  | |||||||
| @ -78,7 +78,7 @@ func (w *Watcher) GetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, err | |||||||
| 	if bestSelection == nil { | 	if bestSelection == nil { | ||||||
| 		return w.fallback, nil | 		return w.fallback, nil | ||||||
| 	} | 	} | ||||||
| 	cert := w.cs.Get(*bestSelection.WebCertificate.Get()) | 	cert := w.cs.Get(bestSelection.GetWebCertificate()) | ||||||
| 	if cert == nil { | 	if cert == nil { | ||||||
| 		return w.fallback, nil | 		return w.fallback, nil | ||||||
| 	} | 	} | ||||||
|  | |||||||
| @ -3,7 +3,6 @@ from time import sleep | |||||||
| from typing import Any, Optional | from typing import Any, Optional | ||||||
|  |  | ||||||
| from docker.types import Healthcheck | from docker.types import Healthcheck | ||||||
| from guardian.utils import get_anonymous_user |  | ||||||
| from selenium.webdriver.common.by import By | from selenium.webdriver.common.by import By | ||||||
| from selenium.webdriver.common.keys import Keys | from selenium.webdriver.common.keys import Keys | ||||||
| from selenium.webdriver.support import expected_conditions as ec | from selenium.webdriver.support import expected_conditions as ec | ||||||
| @ -161,7 +160,7 @@ class TestSourceSAML(SeleniumTestCase): | |||||||
|         self.assert_user( |         self.assert_user( | ||||||
|             User.objects.exclude(username="akadmin") |             User.objects.exclude(username="akadmin") | ||||||
|             .exclude(username__startswith="ak-outpost") |             .exclude(username__startswith="ak-outpost") | ||||||
|             .exclude(pk=get_anonymous_user().pk) |             .exclude_anonymous() | ||||||
|             .exclude(pk=self.user.pk) |             .exclude(pk=self.user.pk) | ||||||
|             .first() |             .first() | ||||||
|         ) |         ) | ||||||
| @ -244,7 +243,7 @@ class TestSourceSAML(SeleniumTestCase): | |||||||
|         self.assert_user( |         self.assert_user( | ||||||
|             User.objects.exclude(username="akadmin") |             User.objects.exclude(username="akadmin") | ||||||
|             .exclude(username__startswith="ak-outpost") |             .exclude(username__startswith="ak-outpost") | ||||||
|             .exclude(pk=get_anonymous_user().pk) |             .exclude_anonymous() | ||||||
|             .exclude(pk=self.user.pk) |             .exclude(pk=self.user.pk) | ||||||
|             .first() |             .first() | ||||||
|         ) |         ) | ||||||
| @ -314,7 +313,7 @@ class TestSourceSAML(SeleniumTestCase): | |||||||
|         self.assert_user( |         self.assert_user( | ||||||
|             User.objects.exclude(username="akadmin") |             User.objects.exclude(username="akadmin") | ||||||
|             .exclude(username__startswith="ak-outpost") |             .exclude(username__startswith="ak-outpost") | ||||||
|             .exclude(pk=get_anonymous_user().pk) |             .exclude_anonymous() | ||||||
|             .exclude(pk=self.user.pk) |             .exclude(pk=self.user.pk) | ||||||
|             .first() |             .first() | ||||||
|         ) |         ) | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user
	 Jens L
					Jens L