core: optimise user list endpoint (#8353)
* unrelated changes Signed-off-by: Jens Langhammer <jens@goauthentik.io> * optimization pass 1: reduce N tenant lookups by taking tenant from request, reduce get_anonymous calls Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix lint Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix Signed-off-by: Jens Langhammer <jens@goauthentik.io> * make it easier to exclude anonymous user Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix? Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
@ -7,7 +7,6 @@ from django.contrib.auth import get_user_model
|
||||
from django.db.models import Model, Q, QuerySet
|
||||
from django.utils.timezone import now
|
||||
from django.utils.translation import gettext as _
|
||||
from guardian.shortcuts import get_anonymous_user
|
||||
from yaml import dump
|
||||
|
||||
from authentik.blueprints.v1.common import (
|
||||
@ -48,7 +47,7 @@ class Exporter:
|
||||
"""Return a queryset for `model`. Can be used to filter some
|
||||
objects on some models"""
|
||||
if model == get_user_model():
|
||||
return model.objects.exclude(pk=get_anonymous_user().pk)
|
||||
return model.objects.exclude_anonymous()
|
||||
return model.objects.all()
|
||||
|
||||
def _pre_export(self, blueprint: Blueprint):
|
||||
|
@ -8,7 +8,6 @@ from sentry_sdk.hub import Hub
|
||||
|
||||
from authentik import get_full_version
|
||||
from authentik.brands.models import Brand
|
||||
from authentik.tenants.utils import get_current_tenant
|
||||
|
||||
_q_default = Q(default=True)
|
||||
DEFAULT_BRAND = Brand(domain="fallback")
|
||||
@ -36,7 +35,7 @@ def context_processor(request: HttpRequest) -> dict[str, Any]:
|
||||
trace = span.to_traceparent()
|
||||
return {
|
||||
"brand": brand,
|
||||
"footer_links": get_current_tenant().footer_links,
|
||||
"footer_links": request.tenant.footer_links,
|
||||
"sentry_trace": trace,
|
||||
"version": get_full_version(),
|
||||
}
|
||||
|
@ -5,12 +5,12 @@ from typing import Optional
|
||||
from django.core.cache import cache
|
||||
from django.db.models import QuerySet
|
||||
from django.db.models.functions import ExtractHour
|
||||
from django.http.response import HttpResponseBadRequest
|
||||
from django.shortcuts import get_object_or_404
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
|
||||
from guardian.shortcuts import get_objects_for_user
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.fields import CharField, ReadOnlyField, SerializerMethodField
|
||||
from rest_framework.parsers import MultiPartParser
|
||||
from rest_framework.request import Request
|
||||
@ -147,7 +147,6 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
|
||||
],
|
||||
responses={
|
||||
200: PolicyTestResultSerializer(),
|
||||
404: OpenApiResponse(description="for_user user not found"),
|
||||
},
|
||||
)
|
||||
@action(detail=True, methods=["GET"])
|
||||
@ -160,9 +159,11 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
|
||||
for_user = request.user
|
||||
if request.user.is_superuser and "for_user" in request.query_params:
|
||||
try:
|
||||
for_user = get_object_or_404(User, pk=request.query_params.get("for_user"))
|
||||
for_user = User.objects.filter(pk=request.query_params.get("for_user")).first()
|
||||
except ValueError:
|
||||
return HttpResponseBadRequest("for_user must be numerical")
|
||||
raise ValidationError({"for_user": "for_user must be numerical"})
|
||||
if not for_user:
|
||||
raise ValidationError({"for_user": "User not found"})
|
||||
engine = PolicyEngine(application, for_user, request)
|
||||
engine.use_cache = False
|
||||
with capture_logs() as logs:
|
||||
|
@ -30,7 +30,7 @@ from drf_spectacular.utils import (
|
||||
extend_schema_field,
|
||||
inline_serializer,
|
||||
)
|
||||
from guardian.shortcuts import get_anonymous_user, get_objects_for_user
|
||||
from guardian.shortcuts import get_objects_for_user
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.fields import CharField, IntegerField, ListField, SerializerMethodField
|
||||
from rest_framework.request import Request
|
||||
@ -72,6 +72,7 @@ from authentik.flows.exceptions import FlowNonApplicableException
|
||||
from authentik.flows.models import FlowToken
|
||||
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlanner
|
||||
from authentik.flows.views.executor import QS_KEY_TOKEN
|
||||
from authentik.lib.avatars import get_avatar
|
||||
from authentik.stages.email.models import EmailStage
|
||||
from authentik.stages.email.tasks import send_mails
|
||||
from authentik.stages.email.utils import TemplateEmailMessage
|
||||
@ -102,14 +103,21 @@ class UserSerializer(ModelSerializer):
|
||||
"""User Serializer"""
|
||||
|
||||
is_superuser = BooleanField(read_only=True)
|
||||
avatar = CharField(read_only=True)
|
||||
avatar = SerializerMethodField()
|
||||
attributes = JSONDictField(required=False)
|
||||
groups = PrimaryKeyRelatedField(
|
||||
allow_empty=True, many=True, source="ak_groups", queryset=Group.objects.all(), default=list
|
||||
allow_empty=True,
|
||||
many=True,
|
||||
source="ak_groups",
|
||||
queryset=Group.objects.all().order_by("name"),
|
||||
default=list,
|
||||
)
|
||||
groups_obj = ListSerializer(child=UserGroupSerializer(), read_only=True, source="ak_groups")
|
||||
uid = CharField(read_only=True)
|
||||
username = CharField(max_length=150, validators=[UniqueValidator(queryset=User.objects.all())])
|
||||
username = CharField(
|
||||
max_length=150,
|
||||
validators=[UniqueValidator(queryset=User.objects.all().order_by("username"))],
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@ -143,6 +151,10 @@ class UserSerializer(ModelSerializer):
|
||||
instance.set_unusable_password()
|
||||
instance.save()
|
||||
|
||||
def get_avatar(self, user: User) -> str:
|
||||
"""User's avatar, either a http/https URL or a data URI"""
|
||||
return get_avatar(user, self.context["request"])
|
||||
|
||||
def validate_path(self, path: str) -> str:
|
||||
"""Validate path"""
|
||||
if path[:1] == "/" or path[-1] == "/":
|
||||
@ -197,12 +209,16 @@ class UserSelfSerializer(ModelSerializer):
|
||||
"""User Serializer for information a user can retrieve about themselves"""
|
||||
|
||||
is_superuser = BooleanField(read_only=True)
|
||||
avatar = CharField(read_only=True)
|
||||
avatar = SerializerMethodField()
|
||||
groups = SerializerMethodField()
|
||||
uid = CharField(read_only=True)
|
||||
settings = SerializerMethodField()
|
||||
system_permissions = SerializerMethodField()
|
||||
|
||||
def get_avatar(self, user: User) -> str:
|
||||
"""User's avatar, either a http/https URL or a data URI"""
|
||||
return get_avatar(user, self.context["request"])
|
||||
|
||||
@extend_schema_field(
|
||||
ListSerializer(
|
||||
child=inline_serializer(
|
||||
@ -329,11 +345,11 @@ class UsersFilter(FilterSet):
|
||||
groups_by_name = ModelMultipleChoiceFilter(
|
||||
field_name="ak_groups__name",
|
||||
to_field_name="name",
|
||||
queryset=Group.objects.all(),
|
||||
queryset=Group.objects.all().order_by("name"),
|
||||
)
|
||||
groups_by_pk = ModelMultipleChoiceFilter(
|
||||
field_name="ak_groups",
|
||||
queryset=Group.objects.all(),
|
||||
queryset=Group.objects.all().order_by("name"),
|
||||
)
|
||||
|
||||
def filter_attributes(self, queryset, name, value):
|
||||
@ -378,7 +394,7 @@ class UserViewSet(UsedByMixin, ModelViewSet):
|
||||
filterset_class = UsersFilter
|
||||
|
||||
def get_queryset(self): # pragma: no cover
|
||||
return User.objects.all().exclude(pk=get_anonymous_user().pk)
|
||||
return User.objects.all().exclude_anonymous().prefetch_related("ak_groups")
|
||||
|
||||
def _create_recovery_link(self) -> tuple[Optional[str], Optional[Token]]:
|
||||
"""Create a recovery link (when the current brand has a recovery flow set),
|
||||
|
@ -14,6 +14,7 @@ from django.http import HttpRequest
|
||||
from django.utils.functional import SimpleLazyObject, cached_property
|
||||
from django.utils.timezone import now
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from guardian.conf import settings
|
||||
from guardian.mixins import GuardianUserMixin
|
||||
from model_utils.managers import InheritanceManager
|
||||
from rest_framework.serializers import Serializer
|
||||
@ -169,13 +170,29 @@ class Group(SerializerModel):
|
||||
verbose_name_plural = _("Groups")
|
||||
|
||||
|
||||
class UserQuerySet(models.QuerySet):
|
||||
"""User queryset"""
|
||||
|
||||
def exclude_anonymous(self):
|
||||
"""Exclude anonymous user"""
|
||||
return self.exclude(**{User.USERNAME_FIELD: settings.ANONYMOUS_USER_NAME})
|
||||
|
||||
|
||||
class UserManager(DjangoUserManager):
|
||||
"""User manager that doesn't assign is_superuser and is_staff"""
|
||||
|
||||
def get_queryset(self):
|
||||
"""Create special user queryset"""
|
||||
return UserQuerySet(self.model, using=self._db)
|
||||
|
||||
def create_user(self, username, email=None, password=None, **extra_fields):
|
||||
"""User manager that doesn't assign is_superuser and is_staff"""
|
||||
return self._create_user(username, email, password, **extra_fields)
|
||||
|
||||
def exclude_anonymous(self) -> QuerySet:
|
||||
"""Exclude anonymous user"""
|
||||
return self.get_queryset().exclude_anonymous()
|
||||
|
||||
|
||||
class User(SerializerModel, GuardianUserMixin, AbstractUser):
|
||||
"""authentik User model, based on django's contrib auth user model."""
|
||||
|
@ -16,7 +16,6 @@ from django.db import models
|
||||
from django.db.models.query import QuerySet
|
||||
from django.utils.timezone import now
|
||||
from django.utils.translation import gettext as _
|
||||
from guardian.shortcuts import get_anonymous_user
|
||||
from jwt import PyJWTError, decode, get_unverified_header
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.serializers import BaseSerializer
|
||||
@ -104,7 +103,7 @@ class LicenseKey:
|
||||
@staticmethod
|
||||
def base_user_qs() -> QuerySet:
|
||||
"""Base query set for all users"""
|
||||
return User.objects.all().exclude(is_active=False).exclude(pk=get_anonymous_user().pk)
|
||||
return User.objects.all().exclude_anonymous().exclude(is_active=False)
|
||||
|
||||
@staticmethod
|
||||
def get_default_user_count():
|
||||
|
@ -26,7 +26,7 @@ from authentik.flows.challenge import (
|
||||
from authentik.flows.exceptions import StageInvalidException
|
||||
from authentik.flows.models import InvalidResponseAction
|
||||
from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, PLAN_CONTEXT_PENDING_USER
|
||||
from authentik.lib.avatars import DEFAULT_AVATAR
|
||||
from authentik.lib.avatars import DEFAULT_AVATAR, get_avatar
|
||||
from authentik.lib.utils.reflection import class_to_path
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -197,7 +197,7 @@ class ChallengeStageView(StageView):
|
||||
challenge.initial_data["pending_user"] = user.username
|
||||
challenge.initial_data["pending_user_avatar"] = DEFAULT_AVATAR
|
||||
if not isinstance(user, AnonymousUser):
|
||||
challenge.initial_data["pending_user_avatar"] = user.avatar
|
||||
challenge.initial_data["pending_user_avatar"] = get_avatar(user, self.request)
|
||||
return challenge
|
||||
|
||||
def get_challenge(self, *args, **kwargs) -> Challenge:
|
||||
|
@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.http import HttpRequest
|
||||
from django.templatetags.static import static
|
||||
from lxml import etree # nosec
|
||||
from lxml.etree import Element, SubElement # nosec
|
||||
@ -15,13 +16,13 @@ from authentik.lib.config import get_path_from_dict
|
||||
from authentik.lib.utils.http import get_http_session
|
||||
from authentik.tenants.utils import get_current_tenant
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from authentik.core.models import User
|
||||
|
||||
GRAVATAR_URL = "https://secure.gravatar.com"
|
||||
DEFAULT_AVATAR = static("dist/assets/images/user_default.png")
|
||||
CACHE_KEY_GRAVATAR = "goauthentik.io/lib/avatars/"
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from authentik.core.models import User
|
||||
|
||||
SVG_XML_NS = "http://www.w3.org/2000/svg"
|
||||
SVG_NS_MAP = {None: SVG_XML_NS}
|
||||
# Match fonts used in web UI
|
||||
@ -177,14 +178,19 @@ def avatar_mode_url(user: "User", mode: str) -> Optional[str]:
|
||||
}
|
||||
|
||||
|
||||
def get_avatar(user: "User") -> str:
|
||||
def get_avatar(user: "User", request: Optional[HttpRequest] = None) -> str:
|
||||
"""Get avatar with configured mode"""
|
||||
mode_map = {
|
||||
"none": avatar_mode_none,
|
||||
"initials": avatar_mode_generated,
|
||||
"gravatar": avatar_mode_gravatar,
|
||||
}
|
||||
modes: str = get_current_tenant().avatars
|
||||
tenant = None
|
||||
if request:
|
||||
tenant = request.tenant
|
||||
else:
|
||||
tenant = get_current_tenant()
|
||||
modes: str = tenant.avatars
|
||||
for mode in modes.split(","):
|
||||
avatar = None
|
||||
if mode in mode_map:
|
||||
|
@ -3,7 +3,6 @@ from django.core.cache import cache
|
||||
from django.db import models
|
||||
from django.db.models import QuerySet
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from guardian.shortcuts import get_anonymous_user
|
||||
from redis.lock import Lock
|
||||
from rest_framework.serializers import Serializer
|
||||
|
||||
@ -42,7 +41,7 @@ class SCIMProvider(BackchannelProvider):
|
||||
def get_user_qs(self) -> QuerySet[User]:
|
||||
"""Get queryset of all users with consistent ordering
|
||||
according to the provider's settings"""
|
||||
base = User.objects.all().exclude(pk=get_anonymous_user().pk)
|
||||
base = User.objects.all().exclude_anonymous()
|
||||
if self.exclude_users_service_account:
|
||||
base = base.exclude(type=UserTypes.SERVICE_ACCOUNT).exclude(
|
||||
type=UserTypes.INTERNAL_SERVICE_ACCOUNT
|
||||
|
@ -2,7 +2,6 @@
|
||||
from json import loads
|
||||
|
||||
from django.test import TestCase
|
||||
from guardian.shortcuts import get_anonymous_user
|
||||
from jsonschema import validate
|
||||
from requests_mock import Mocker
|
||||
|
||||
@ -19,7 +18,7 @@ class SCIMGroupTests(TestCase):
|
||||
def setUp(self) -> None:
|
||||
# Delete all users and groups as the mocked HTTP responses only return one ID
|
||||
# which will cause errors with multiple users
|
||||
User.objects.all().exclude(pk=get_anonymous_user().pk).delete()
|
||||
User.objects.all().exclude_anonymous().delete()
|
||||
Group.objects.all().delete()
|
||||
self.provider: SCIMProvider = SCIMProvider.objects.create(
|
||||
name=generate_id(),
|
||||
|
@ -1,6 +1,5 @@
|
||||
"""SCIM Membership tests"""
|
||||
from django.test import TestCase
|
||||
from guardian.shortcuts import get_anonymous_user
|
||||
from requests_mock import Mocker
|
||||
|
||||
from authentik.blueprints.tests import apply_blueprint
|
||||
@ -21,7 +20,7 @@ class SCIMMembershipTests(TestCase):
|
||||
def setUp(self) -> None:
|
||||
# Delete all users and groups as the mocked HTTP responses only return one ID
|
||||
# which will cause errors with multiple users
|
||||
User.objects.all().exclude(pk=get_anonymous_user().pk).delete()
|
||||
User.objects.all().exclude_anonymous().delete()
|
||||
Group.objects.all().delete()
|
||||
Tenant.objects.update(avatars="none")
|
||||
|
||||
|
@ -2,7 +2,6 @@
|
||||
from json import loads
|
||||
|
||||
from django.test import TestCase
|
||||
from guardian.shortcuts import get_anonymous_user
|
||||
from jsonschema import validate
|
||||
from requests_mock import Mocker
|
||||
|
||||
@ -22,7 +21,7 @@ class SCIMUserTests(TestCase):
|
||||
# Delete all users and groups as the mocked HTTP responses only return one ID
|
||||
# which will cause errors with multiple users
|
||||
Tenant.objects.update(avatars="none")
|
||||
User.objects.all().exclude(pk=get_anonymous_user().pk).delete()
|
||||
User.objects.all().exclude_anonymous().delete()
|
||||
Group.objects.all().delete()
|
||||
self.provider: SCIMProvider = SCIMProvider.objects.create(
|
||||
name=generate_id(),
|
||||
|
@ -54,7 +54,7 @@ func (cs *CryptoStore) getFingerprint(uuid string) string {
|
||||
func (cs *CryptoStore) Fetch(uuid string) error {
|
||||
cfp := cs.getFingerprint(uuid)
|
||||
if cfp == cs.fingerprints[uuid] {
|
||||
cs.log.WithField("uuid", uuid).Info("Fingerprint hasn't changed, not fetching cert")
|
||||
cs.log.WithField("uuid", uuid).Debug("Fingerprint hasn't changed, not fetching cert")
|
||||
return nil
|
||||
}
|
||||
cs.log.WithField("uuid", uuid).Info("Fetching certificate and private key")
|
||||
|
@ -78,7 +78,7 @@ func (w *Watcher) GetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, err
|
||||
if bestSelection == nil {
|
||||
return w.fallback, nil
|
||||
}
|
||||
cert := w.cs.Get(*bestSelection.WebCertificate.Get())
|
||||
cert := w.cs.Get(bestSelection.GetWebCertificate())
|
||||
if cert == nil {
|
||||
return w.fallback, nil
|
||||
}
|
||||
|
@ -3,7 +3,6 @@ from time import sleep
|
||||
from typing import Any, Optional
|
||||
|
||||
from docker.types import Healthcheck
|
||||
from guardian.utils import get_anonymous_user
|
||||
from selenium.webdriver.common.by import By
|
||||
from selenium.webdriver.common.keys import Keys
|
||||
from selenium.webdriver.support import expected_conditions as ec
|
||||
@ -161,7 +160,7 @@ class TestSourceSAML(SeleniumTestCase):
|
||||
self.assert_user(
|
||||
User.objects.exclude(username="akadmin")
|
||||
.exclude(username__startswith="ak-outpost")
|
||||
.exclude(pk=get_anonymous_user().pk)
|
||||
.exclude_anonymous()
|
||||
.exclude(pk=self.user.pk)
|
||||
.first()
|
||||
)
|
||||
@ -244,7 +243,7 @@ class TestSourceSAML(SeleniumTestCase):
|
||||
self.assert_user(
|
||||
User.objects.exclude(username="akadmin")
|
||||
.exclude(username__startswith="ak-outpost")
|
||||
.exclude(pk=get_anonymous_user().pk)
|
||||
.exclude_anonymous()
|
||||
.exclude(pk=self.user.pk)
|
||||
.first()
|
||||
)
|
||||
@ -314,7 +313,7 @@ class TestSourceSAML(SeleniumTestCase):
|
||||
self.assert_user(
|
||||
User.objects.exclude(username="akadmin")
|
||||
.exclude(username__startswith="ak-outpost")
|
||||
.exclude(pk=get_anonymous_user().pk)
|
||||
.exclude_anonymous()
|
||||
.exclude(pk=self.user.pk)
|
||||
.first()
|
||||
)
|
||||
|
Reference in New Issue
Block a user