core: optimise user list endpoint (#8353)

* unrelated changes

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* optimization pass 1: reduce N tenant lookups by taking tenant from request, reduce get_anonymous calls

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix lint

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* make it easier to exclude anonymous user

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix?

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L
2024-01-30 01:55:26 +01:00
committed by GitHub
parent 0413afc2a8
commit 25e72558eb
15 changed files with 71 additions and 39 deletions

View File

@ -7,7 +7,6 @@ from django.contrib.auth import get_user_model
from django.db.models import Model, Q, QuerySet from django.db.models import Model, Q, QuerySet
from django.utils.timezone import now from django.utils.timezone import now
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from guardian.shortcuts import get_anonymous_user
from yaml import dump from yaml import dump
from authentik.blueprints.v1.common import ( from authentik.blueprints.v1.common import (
@ -48,7 +47,7 @@ class Exporter:
"""Return a queryset for `model`. Can be used to filter some """Return a queryset for `model`. Can be used to filter some
objects on some models""" objects on some models"""
if model == get_user_model(): if model == get_user_model():
return model.objects.exclude(pk=get_anonymous_user().pk) return model.objects.exclude_anonymous()
return model.objects.all() return model.objects.all()
def _pre_export(self, blueprint: Blueprint): def _pre_export(self, blueprint: Blueprint):

View File

@ -8,7 +8,6 @@ from sentry_sdk.hub import Hub
from authentik import get_full_version from authentik import get_full_version
from authentik.brands.models import Brand from authentik.brands.models import Brand
from authentik.tenants.utils import get_current_tenant
_q_default = Q(default=True) _q_default = Q(default=True)
DEFAULT_BRAND = Brand(domain="fallback") DEFAULT_BRAND = Brand(domain="fallback")
@ -36,7 +35,7 @@ def context_processor(request: HttpRequest) -> dict[str, Any]:
trace = span.to_traceparent() trace = span.to_traceparent()
return { return {
"brand": brand, "brand": brand,
"footer_links": get_current_tenant().footer_links, "footer_links": request.tenant.footer_links,
"sentry_trace": trace, "sentry_trace": trace,
"version": get_full_version(), "version": get_full_version(),
} }

View File

@ -5,12 +5,12 @@ from typing import Optional
from django.core.cache import cache from django.core.cache import cache
from django.db.models import QuerySet from django.db.models import QuerySet
from django.db.models.functions import ExtractHour from django.db.models.functions import ExtractHour
from django.http.response import HttpResponseBadRequest
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from drf_spectacular.types import OpenApiTypes from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
from guardian.shortcuts import get_objects_for_user from guardian.shortcuts import get_objects_for_user
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError
from rest_framework.fields import CharField, ReadOnlyField, SerializerMethodField from rest_framework.fields import CharField, ReadOnlyField, SerializerMethodField
from rest_framework.parsers import MultiPartParser from rest_framework.parsers import MultiPartParser
from rest_framework.request import Request from rest_framework.request import Request
@ -147,7 +147,6 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
], ],
responses={ responses={
200: PolicyTestResultSerializer(), 200: PolicyTestResultSerializer(),
404: OpenApiResponse(description="for_user user not found"),
}, },
) )
@action(detail=True, methods=["GET"]) @action(detail=True, methods=["GET"])
@ -160,9 +159,11 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
for_user = request.user for_user = request.user
if request.user.is_superuser and "for_user" in request.query_params: if request.user.is_superuser and "for_user" in request.query_params:
try: try:
for_user = get_object_or_404(User, pk=request.query_params.get("for_user")) for_user = User.objects.filter(pk=request.query_params.get("for_user")).first()
except ValueError: except ValueError:
return HttpResponseBadRequest("for_user must be numerical") raise ValidationError({"for_user": "for_user must be numerical"})
if not for_user:
raise ValidationError({"for_user": "User not found"})
engine = PolicyEngine(application, for_user, request) engine = PolicyEngine(application, for_user, request)
engine.use_cache = False engine.use_cache = False
with capture_logs() as logs: with capture_logs() as logs:

View File

@ -30,7 +30,7 @@ from drf_spectacular.utils import (
extend_schema_field, extend_schema_field,
inline_serializer, inline_serializer,
) )
from guardian.shortcuts import get_anonymous_user, get_objects_for_user from guardian.shortcuts import get_objects_for_user
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.fields import CharField, IntegerField, ListField, SerializerMethodField from rest_framework.fields import CharField, IntegerField, ListField, SerializerMethodField
from rest_framework.request import Request from rest_framework.request import Request
@ -72,6 +72,7 @@ from authentik.flows.exceptions import FlowNonApplicableException
from authentik.flows.models import FlowToken from authentik.flows.models import FlowToken
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlanner from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlanner
from authentik.flows.views.executor import QS_KEY_TOKEN from authentik.flows.views.executor import QS_KEY_TOKEN
from authentik.lib.avatars import get_avatar
from authentik.stages.email.models import EmailStage from authentik.stages.email.models import EmailStage
from authentik.stages.email.tasks import send_mails from authentik.stages.email.tasks import send_mails
from authentik.stages.email.utils import TemplateEmailMessage from authentik.stages.email.utils import TemplateEmailMessage
@ -102,14 +103,21 @@ class UserSerializer(ModelSerializer):
"""User Serializer""" """User Serializer"""
is_superuser = BooleanField(read_only=True) is_superuser = BooleanField(read_only=True)
avatar = CharField(read_only=True) avatar = SerializerMethodField()
attributes = JSONDictField(required=False) attributes = JSONDictField(required=False)
groups = PrimaryKeyRelatedField( groups = PrimaryKeyRelatedField(
allow_empty=True, many=True, source="ak_groups", queryset=Group.objects.all(), default=list allow_empty=True,
many=True,
source="ak_groups",
queryset=Group.objects.all().order_by("name"),
default=list,
) )
groups_obj = ListSerializer(child=UserGroupSerializer(), read_only=True, source="ak_groups") groups_obj = ListSerializer(child=UserGroupSerializer(), read_only=True, source="ak_groups")
uid = CharField(read_only=True) uid = CharField(read_only=True)
username = CharField(max_length=150, validators=[UniqueValidator(queryset=User.objects.all())]) username = CharField(
max_length=150,
validators=[UniqueValidator(queryset=User.objects.all().order_by("username"))],
)
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -143,6 +151,10 @@ class UserSerializer(ModelSerializer):
instance.set_unusable_password() instance.set_unusable_password()
instance.save() instance.save()
def get_avatar(self, user: User) -> str:
"""User's avatar, either a http/https URL or a data URI"""
return get_avatar(user, self.context["request"])
def validate_path(self, path: str) -> str: def validate_path(self, path: str) -> str:
"""Validate path""" """Validate path"""
if path[:1] == "/" or path[-1] == "/": if path[:1] == "/" or path[-1] == "/":
@ -197,12 +209,16 @@ class UserSelfSerializer(ModelSerializer):
"""User Serializer for information a user can retrieve about themselves""" """User Serializer for information a user can retrieve about themselves"""
is_superuser = BooleanField(read_only=True) is_superuser = BooleanField(read_only=True)
avatar = CharField(read_only=True) avatar = SerializerMethodField()
groups = SerializerMethodField() groups = SerializerMethodField()
uid = CharField(read_only=True) uid = CharField(read_only=True)
settings = SerializerMethodField() settings = SerializerMethodField()
system_permissions = SerializerMethodField() system_permissions = SerializerMethodField()
def get_avatar(self, user: User) -> str:
"""User's avatar, either a http/https URL or a data URI"""
return get_avatar(user, self.context["request"])
@extend_schema_field( @extend_schema_field(
ListSerializer( ListSerializer(
child=inline_serializer( child=inline_serializer(
@ -329,11 +345,11 @@ class UsersFilter(FilterSet):
groups_by_name = ModelMultipleChoiceFilter( groups_by_name = ModelMultipleChoiceFilter(
field_name="ak_groups__name", field_name="ak_groups__name",
to_field_name="name", to_field_name="name",
queryset=Group.objects.all(), queryset=Group.objects.all().order_by("name"),
) )
groups_by_pk = ModelMultipleChoiceFilter( groups_by_pk = ModelMultipleChoiceFilter(
field_name="ak_groups", field_name="ak_groups",
queryset=Group.objects.all(), queryset=Group.objects.all().order_by("name"),
) )
def filter_attributes(self, queryset, name, value): def filter_attributes(self, queryset, name, value):
@ -378,7 +394,7 @@ class UserViewSet(UsedByMixin, ModelViewSet):
filterset_class = UsersFilter filterset_class = UsersFilter
def get_queryset(self): # pragma: no cover def get_queryset(self): # pragma: no cover
return User.objects.all().exclude(pk=get_anonymous_user().pk) return User.objects.all().exclude_anonymous().prefetch_related("ak_groups")
def _create_recovery_link(self) -> tuple[Optional[str], Optional[Token]]: def _create_recovery_link(self) -> tuple[Optional[str], Optional[Token]]:
"""Create a recovery link (when the current brand has a recovery flow set), """Create a recovery link (when the current brand has a recovery flow set),

View File

@ -14,6 +14,7 @@ from django.http import HttpRequest
from django.utils.functional import SimpleLazyObject, cached_property from django.utils.functional import SimpleLazyObject, cached_property
from django.utils.timezone import now from django.utils.timezone import now
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from guardian.conf import settings
from guardian.mixins import GuardianUserMixin from guardian.mixins import GuardianUserMixin
from model_utils.managers import InheritanceManager from model_utils.managers import InheritanceManager
from rest_framework.serializers import Serializer from rest_framework.serializers import Serializer
@ -169,13 +170,29 @@ class Group(SerializerModel):
verbose_name_plural = _("Groups") verbose_name_plural = _("Groups")
class UserQuerySet(models.QuerySet):
"""User queryset"""
def exclude_anonymous(self):
"""Exclude anonymous user"""
return self.exclude(**{User.USERNAME_FIELD: settings.ANONYMOUS_USER_NAME})
class UserManager(DjangoUserManager): class UserManager(DjangoUserManager):
"""User manager that doesn't assign is_superuser and is_staff""" """User manager that doesn't assign is_superuser and is_staff"""
def get_queryset(self):
"""Create special user queryset"""
return UserQuerySet(self.model, using=self._db)
def create_user(self, username, email=None, password=None, **extra_fields): def create_user(self, username, email=None, password=None, **extra_fields):
"""User manager that doesn't assign is_superuser and is_staff""" """User manager that doesn't assign is_superuser and is_staff"""
return self._create_user(username, email, password, **extra_fields) return self._create_user(username, email, password, **extra_fields)
def exclude_anonymous(self) -> QuerySet:
"""Exclude anonymous user"""
return self.get_queryset().exclude_anonymous()
class User(SerializerModel, GuardianUserMixin, AbstractUser): class User(SerializerModel, GuardianUserMixin, AbstractUser):
"""authentik User model, based on django's contrib auth user model.""" """authentik User model, based on django's contrib auth user model."""

View File

@ -16,7 +16,6 @@ from django.db import models
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
from django.utils.timezone import now from django.utils.timezone import now
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from guardian.shortcuts import get_anonymous_user
from jwt import PyJWTError, decode, get_unverified_header from jwt import PyJWTError, decode, get_unverified_header
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
from rest_framework.serializers import BaseSerializer from rest_framework.serializers import BaseSerializer
@ -104,7 +103,7 @@ class LicenseKey:
@staticmethod @staticmethod
def base_user_qs() -> QuerySet: def base_user_qs() -> QuerySet:
"""Base query set for all users""" """Base query set for all users"""
return User.objects.all().exclude(is_active=False).exclude(pk=get_anonymous_user().pk) return User.objects.all().exclude_anonymous().exclude(is_active=False)
@staticmethod @staticmethod
def get_default_user_count(): def get_default_user_count():

View File

@ -26,7 +26,7 @@ from authentik.flows.challenge import (
from authentik.flows.exceptions import StageInvalidException from authentik.flows.exceptions import StageInvalidException
from authentik.flows.models import InvalidResponseAction from authentik.flows.models import InvalidResponseAction
from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, PLAN_CONTEXT_PENDING_USER from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, PLAN_CONTEXT_PENDING_USER
from authentik.lib.avatars import DEFAULT_AVATAR from authentik.lib.avatars import DEFAULT_AVATAR, get_avatar
from authentik.lib.utils.reflection import class_to_path from authentik.lib.utils.reflection import class_to_path
if TYPE_CHECKING: if TYPE_CHECKING:
@ -197,7 +197,7 @@ class ChallengeStageView(StageView):
challenge.initial_data["pending_user"] = user.username challenge.initial_data["pending_user"] = user.username
challenge.initial_data["pending_user_avatar"] = DEFAULT_AVATAR challenge.initial_data["pending_user_avatar"] = DEFAULT_AVATAR
if not isinstance(user, AnonymousUser): if not isinstance(user, AnonymousUser):
challenge.initial_data["pending_user_avatar"] = user.avatar challenge.initial_data["pending_user_avatar"] = get_avatar(user, self.request)
return challenge return challenge
def get_challenge(self, *args, **kwargs) -> Challenge: def get_challenge(self, *args, **kwargs) -> Challenge:

View File

@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Optional
from urllib.parse import urlencode from urllib.parse import urlencode
from django.core.cache import cache from django.core.cache import cache
from django.http import HttpRequest
from django.templatetags.static import static from django.templatetags.static import static
from lxml import etree # nosec from lxml import etree # nosec
from lxml.etree import Element, SubElement # nosec from lxml.etree import Element, SubElement # nosec
@ -15,13 +16,13 @@ from authentik.lib.config import get_path_from_dict
from authentik.lib.utils.http import get_http_session from authentik.lib.utils.http import get_http_session
from authentik.tenants.utils import get_current_tenant from authentik.tenants.utils import get_current_tenant
if TYPE_CHECKING:
from authentik.core.models import User
GRAVATAR_URL = "https://secure.gravatar.com" GRAVATAR_URL = "https://secure.gravatar.com"
DEFAULT_AVATAR = static("dist/assets/images/user_default.png") DEFAULT_AVATAR = static("dist/assets/images/user_default.png")
CACHE_KEY_GRAVATAR = "goauthentik.io/lib/avatars/" CACHE_KEY_GRAVATAR = "goauthentik.io/lib/avatars/"
if TYPE_CHECKING:
from authentik.core.models import User
SVG_XML_NS = "http://www.w3.org/2000/svg" SVG_XML_NS = "http://www.w3.org/2000/svg"
SVG_NS_MAP = {None: SVG_XML_NS} SVG_NS_MAP = {None: SVG_XML_NS}
# Match fonts used in web UI # Match fonts used in web UI
@ -177,14 +178,19 @@ def avatar_mode_url(user: "User", mode: str) -> Optional[str]:
} }
def get_avatar(user: "User") -> str: def get_avatar(user: "User", request: Optional[HttpRequest] = None) -> str:
"""Get avatar with configured mode""" """Get avatar with configured mode"""
mode_map = { mode_map = {
"none": avatar_mode_none, "none": avatar_mode_none,
"initials": avatar_mode_generated, "initials": avatar_mode_generated,
"gravatar": avatar_mode_gravatar, "gravatar": avatar_mode_gravatar,
} }
modes: str = get_current_tenant().avatars tenant = None
if request:
tenant = request.tenant
else:
tenant = get_current_tenant()
modes: str = tenant.avatars
for mode in modes.split(","): for mode in modes.split(","):
avatar = None avatar = None
if mode in mode_map: if mode in mode_map:

View File

@ -3,7 +3,6 @@ from django.core.cache import cache
from django.db import models from django.db import models
from django.db.models import QuerySet from django.db.models import QuerySet
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from guardian.shortcuts import get_anonymous_user
from redis.lock import Lock from redis.lock import Lock
from rest_framework.serializers import Serializer from rest_framework.serializers import Serializer
@ -42,7 +41,7 @@ class SCIMProvider(BackchannelProvider):
def get_user_qs(self) -> QuerySet[User]: def get_user_qs(self) -> QuerySet[User]:
"""Get queryset of all users with consistent ordering """Get queryset of all users with consistent ordering
according to the provider's settings""" according to the provider's settings"""
base = User.objects.all().exclude(pk=get_anonymous_user().pk) base = User.objects.all().exclude_anonymous()
if self.exclude_users_service_account: if self.exclude_users_service_account:
base = base.exclude(type=UserTypes.SERVICE_ACCOUNT).exclude( base = base.exclude(type=UserTypes.SERVICE_ACCOUNT).exclude(
type=UserTypes.INTERNAL_SERVICE_ACCOUNT type=UserTypes.INTERNAL_SERVICE_ACCOUNT

View File

@ -2,7 +2,6 @@
from json import loads from json import loads
from django.test import TestCase from django.test import TestCase
from guardian.shortcuts import get_anonymous_user
from jsonschema import validate from jsonschema import validate
from requests_mock import Mocker from requests_mock import Mocker
@ -19,7 +18,7 @@ class SCIMGroupTests(TestCase):
def setUp(self) -> None: def setUp(self) -> None:
# Delete all users and groups as the mocked HTTP responses only return one ID # Delete all users and groups as the mocked HTTP responses only return one ID
# which will cause errors with multiple users # which will cause errors with multiple users
User.objects.all().exclude(pk=get_anonymous_user().pk).delete() User.objects.all().exclude_anonymous().delete()
Group.objects.all().delete() Group.objects.all().delete()
self.provider: SCIMProvider = SCIMProvider.objects.create( self.provider: SCIMProvider = SCIMProvider.objects.create(
name=generate_id(), name=generate_id(),

View File

@ -1,6 +1,5 @@
"""SCIM Membership tests""" """SCIM Membership tests"""
from django.test import TestCase from django.test import TestCase
from guardian.shortcuts import get_anonymous_user
from requests_mock import Mocker from requests_mock import Mocker
from authentik.blueprints.tests import apply_blueprint from authentik.blueprints.tests import apply_blueprint
@ -21,7 +20,7 @@ class SCIMMembershipTests(TestCase):
def setUp(self) -> None: def setUp(self) -> None:
# Delete all users and groups as the mocked HTTP responses only return one ID # Delete all users and groups as the mocked HTTP responses only return one ID
# which will cause errors with multiple users # which will cause errors with multiple users
User.objects.all().exclude(pk=get_anonymous_user().pk).delete() User.objects.all().exclude_anonymous().delete()
Group.objects.all().delete() Group.objects.all().delete()
Tenant.objects.update(avatars="none") Tenant.objects.update(avatars="none")

View File

@ -2,7 +2,6 @@
from json import loads from json import loads
from django.test import TestCase from django.test import TestCase
from guardian.shortcuts import get_anonymous_user
from jsonschema import validate from jsonschema import validate
from requests_mock import Mocker from requests_mock import Mocker
@ -22,7 +21,7 @@ class SCIMUserTests(TestCase):
# Delete all users and groups as the mocked HTTP responses only return one ID # Delete all users and groups as the mocked HTTP responses only return one ID
# which will cause errors with multiple users # which will cause errors with multiple users
Tenant.objects.update(avatars="none") Tenant.objects.update(avatars="none")
User.objects.all().exclude(pk=get_anonymous_user().pk).delete() User.objects.all().exclude_anonymous().delete()
Group.objects.all().delete() Group.objects.all().delete()
self.provider: SCIMProvider = SCIMProvider.objects.create( self.provider: SCIMProvider = SCIMProvider.objects.create(
name=generate_id(), name=generate_id(),

View File

@ -54,7 +54,7 @@ func (cs *CryptoStore) getFingerprint(uuid string) string {
func (cs *CryptoStore) Fetch(uuid string) error { func (cs *CryptoStore) Fetch(uuid string) error {
cfp := cs.getFingerprint(uuid) cfp := cs.getFingerprint(uuid)
if cfp == cs.fingerprints[uuid] { if cfp == cs.fingerprints[uuid] {
cs.log.WithField("uuid", uuid).Info("Fingerprint hasn't changed, not fetching cert") cs.log.WithField("uuid", uuid).Debug("Fingerprint hasn't changed, not fetching cert")
return nil return nil
} }
cs.log.WithField("uuid", uuid).Info("Fetching certificate and private key") cs.log.WithField("uuid", uuid).Info("Fetching certificate and private key")

View File

@ -78,7 +78,7 @@ func (w *Watcher) GetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, err
if bestSelection == nil { if bestSelection == nil {
return w.fallback, nil return w.fallback, nil
} }
cert := w.cs.Get(*bestSelection.WebCertificate.Get()) cert := w.cs.Get(bestSelection.GetWebCertificate())
if cert == nil { if cert == nil {
return w.fallback, nil return w.fallback, nil
} }

View File

@ -3,7 +3,6 @@ from time import sleep
from typing import Any, Optional from typing import Any, Optional
from docker.types import Healthcheck from docker.types import Healthcheck
from guardian.utils import get_anonymous_user
from selenium.webdriver.common.by import By from selenium.webdriver.common.by import By
from selenium.webdriver.common.keys import Keys from selenium.webdriver.common.keys import Keys
from selenium.webdriver.support import expected_conditions as ec from selenium.webdriver.support import expected_conditions as ec
@ -161,7 +160,7 @@ class TestSourceSAML(SeleniumTestCase):
self.assert_user( self.assert_user(
User.objects.exclude(username="akadmin") User.objects.exclude(username="akadmin")
.exclude(username__startswith="ak-outpost") .exclude(username__startswith="ak-outpost")
.exclude(pk=get_anonymous_user().pk) .exclude_anonymous()
.exclude(pk=self.user.pk) .exclude(pk=self.user.pk)
.first() .first()
) )
@ -244,7 +243,7 @@ class TestSourceSAML(SeleniumTestCase):
self.assert_user( self.assert_user(
User.objects.exclude(username="akadmin") User.objects.exclude(username="akadmin")
.exclude(username__startswith="ak-outpost") .exclude(username__startswith="ak-outpost")
.exclude(pk=get_anonymous_user().pk) .exclude_anonymous()
.exclude(pk=self.user.pk) .exclude(pk=self.user.pk)
.first() .first()
) )
@ -314,7 +313,7 @@ class TestSourceSAML(SeleniumTestCase):
self.assert_user( self.assert_user(
User.objects.exclude(username="akadmin") User.objects.exclude(username="akadmin")
.exclude(username__startswith="ak-outpost") .exclude(username__startswith="ak-outpost")
.exclude(pk=get_anonymous_user().pk) .exclude_anonymous()
.exclude(pk=self.user.pk) .exclude(pk=self.user.pk)
.first() .first()
) )