core: optimise user list endpoint (#8353)

* unrelated changes

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* optimization pass 1: reduce N tenant lookups by taking tenant from request, reduce get_anonymous calls

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix lint

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* make it easier to exclude anonymous user

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix?

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L
2024-01-30 01:55:26 +01:00
committed by GitHub
parent 0413afc2a8
commit 25e72558eb
15 changed files with 71 additions and 39 deletions

View File

@ -7,7 +7,6 @@ from django.contrib.auth import get_user_model
from django.db.models import Model, Q, QuerySet
from django.utils.timezone import now
from django.utils.translation import gettext as _
from guardian.shortcuts import get_anonymous_user
from yaml import dump
from authentik.blueprints.v1.common import (
@ -48,7 +47,7 @@ class Exporter:
"""Return a queryset for `model`. Can be used to filter some
objects on some models"""
if model == get_user_model():
return model.objects.exclude(pk=get_anonymous_user().pk)
return model.objects.exclude_anonymous()
return model.objects.all()
def _pre_export(self, blueprint: Blueprint):

View File

@ -8,7 +8,6 @@ from sentry_sdk.hub import Hub
from authentik import get_full_version
from authentik.brands.models import Brand
from authentik.tenants.utils import get_current_tenant
_q_default = Q(default=True)
DEFAULT_BRAND = Brand(domain="fallback")
@ -36,7 +35,7 @@ def context_processor(request: HttpRequest) -> dict[str, Any]:
trace = span.to_traceparent()
return {
"brand": brand,
"footer_links": get_current_tenant().footer_links,
"footer_links": request.tenant.footer_links,
"sentry_trace": trace,
"version": get_full_version(),
}

View File

@ -5,12 +5,12 @@ from typing import Optional
from django.core.cache import cache
from django.db.models import QuerySet
from django.db.models.functions import ExtractHour
from django.http.response import HttpResponseBadRequest
from django.shortcuts import get_object_or_404
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
from guardian.shortcuts import get_objects_for_user
from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError
from rest_framework.fields import CharField, ReadOnlyField, SerializerMethodField
from rest_framework.parsers import MultiPartParser
from rest_framework.request import Request
@ -147,7 +147,6 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
],
responses={
200: PolicyTestResultSerializer(),
404: OpenApiResponse(description="for_user user not found"),
},
)
@action(detail=True, methods=["GET"])
@ -160,9 +159,11 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
for_user = request.user
if request.user.is_superuser and "for_user" in request.query_params:
try:
for_user = get_object_or_404(User, pk=request.query_params.get("for_user"))
for_user = User.objects.filter(pk=request.query_params.get("for_user")).first()
except ValueError:
return HttpResponseBadRequest("for_user must be numerical")
raise ValidationError({"for_user": "for_user must be numerical"})
if not for_user:
raise ValidationError({"for_user": "User not found"})
engine = PolicyEngine(application, for_user, request)
engine.use_cache = False
with capture_logs() as logs:

View File

@ -30,7 +30,7 @@ from drf_spectacular.utils import (
extend_schema_field,
inline_serializer,
)
from guardian.shortcuts import get_anonymous_user, get_objects_for_user
from guardian.shortcuts import get_objects_for_user
from rest_framework.decorators import action
from rest_framework.fields import CharField, IntegerField, ListField, SerializerMethodField
from rest_framework.request import Request
@ -72,6 +72,7 @@ from authentik.flows.exceptions import FlowNonApplicableException
from authentik.flows.models import FlowToken
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlanner
from authentik.flows.views.executor import QS_KEY_TOKEN
from authentik.lib.avatars import get_avatar
from authentik.stages.email.models import EmailStage
from authentik.stages.email.tasks import send_mails
from authentik.stages.email.utils import TemplateEmailMessage
@ -102,14 +103,21 @@ class UserSerializer(ModelSerializer):
"""User Serializer"""
is_superuser = BooleanField(read_only=True)
avatar = CharField(read_only=True)
avatar = SerializerMethodField()
attributes = JSONDictField(required=False)
groups = PrimaryKeyRelatedField(
allow_empty=True, many=True, source="ak_groups", queryset=Group.objects.all(), default=list
allow_empty=True,
many=True,
source="ak_groups",
queryset=Group.objects.all().order_by("name"),
default=list,
)
groups_obj = ListSerializer(child=UserGroupSerializer(), read_only=True, source="ak_groups")
uid = CharField(read_only=True)
username = CharField(max_length=150, validators=[UniqueValidator(queryset=User.objects.all())])
username = CharField(
max_length=150,
validators=[UniqueValidator(queryset=User.objects.all().order_by("username"))],
)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -143,6 +151,10 @@ class UserSerializer(ModelSerializer):
instance.set_unusable_password()
instance.save()
def get_avatar(self, user: User) -> str:
"""User's avatar, either a http/https URL or a data URI"""
return get_avatar(user, self.context["request"])
def validate_path(self, path: str) -> str:
"""Validate path"""
if path[:1] == "/" or path[-1] == "/":
@ -197,12 +209,16 @@ class UserSelfSerializer(ModelSerializer):
"""User Serializer for information a user can retrieve about themselves"""
is_superuser = BooleanField(read_only=True)
avatar = CharField(read_only=True)
avatar = SerializerMethodField()
groups = SerializerMethodField()
uid = CharField(read_only=True)
settings = SerializerMethodField()
system_permissions = SerializerMethodField()
def get_avatar(self, user: User) -> str:
"""User's avatar, either a http/https URL or a data URI"""
return get_avatar(user, self.context["request"])
@extend_schema_field(
ListSerializer(
child=inline_serializer(
@ -329,11 +345,11 @@ class UsersFilter(FilterSet):
groups_by_name = ModelMultipleChoiceFilter(
field_name="ak_groups__name",
to_field_name="name",
queryset=Group.objects.all(),
queryset=Group.objects.all().order_by("name"),
)
groups_by_pk = ModelMultipleChoiceFilter(
field_name="ak_groups",
queryset=Group.objects.all(),
queryset=Group.objects.all().order_by("name"),
)
def filter_attributes(self, queryset, name, value):
@ -378,7 +394,7 @@ class UserViewSet(UsedByMixin, ModelViewSet):
filterset_class = UsersFilter
def get_queryset(self): # pragma: no cover
return User.objects.all().exclude(pk=get_anonymous_user().pk)
return User.objects.all().exclude_anonymous().prefetch_related("ak_groups")
def _create_recovery_link(self) -> tuple[Optional[str], Optional[Token]]:
"""Create a recovery link (when the current brand has a recovery flow set),

View File

@ -14,6 +14,7 @@ from django.http import HttpRequest
from django.utils.functional import SimpleLazyObject, cached_property
from django.utils.timezone import now
from django.utils.translation import gettext_lazy as _
from guardian.conf import settings
from guardian.mixins import GuardianUserMixin
from model_utils.managers import InheritanceManager
from rest_framework.serializers import Serializer
@ -169,13 +170,29 @@ class Group(SerializerModel):
verbose_name_plural = _("Groups")
class UserQuerySet(models.QuerySet):
"""User queryset"""
def exclude_anonymous(self):
"""Exclude anonymous user"""
return self.exclude(**{User.USERNAME_FIELD: settings.ANONYMOUS_USER_NAME})
class UserManager(DjangoUserManager):
"""User manager that doesn't assign is_superuser and is_staff"""
def get_queryset(self):
"""Create special user queryset"""
return UserQuerySet(self.model, using=self._db)
def create_user(self, username, email=None, password=None, **extra_fields):
"""User manager that doesn't assign is_superuser and is_staff"""
return self._create_user(username, email, password, **extra_fields)
def exclude_anonymous(self) -> QuerySet:
"""Exclude anonymous user"""
return self.get_queryset().exclude_anonymous()
class User(SerializerModel, GuardianUserMixin, AbstractUser):
"""authentik User model, based on django's contrib auth user model."""

View File

@ -16,7 +16,6 @@ from django.db import models
from django.db.models.query import QuerySet
from django.utils.timezone import now
from django.utils.translation import gettext as _
from guardian.shortcuts import get_anonymous_user
from jwt import PyJWTError, decode, get_unverified_header
from rest_framework.exceptions import ValidationError
from rest_framework.serializers import BaseSerializer
@ -104,7 +103,7 @@ class LicenseKey:
@staticmethod
def base_user_qs() -> QuerySet:
"""Base query set for all users"""
return User.objects.all().exclude(is_active=False).exclude(pk=get_anonymous_user().pk)
return User.objects.all().exclude_anonymous().exclude(is_active=False)
@staticmethod
def get_default_user_count():

View File

@ -26,7 +26,7 @@ from authentik.flows.challenge import (
from authentik.flows.exceptions import StageInvalidException
from authentik.flows.models import InvalidResponseAction
from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, PLAN_CONTEXT_PENDING_USER
from authentik.lib.avatars import DEFAULT_AVATAR
from authentik.lib.avatars import DEFAULT_AVATAR, get_avatar
from authentik.lib.utils.reflection import class_to_path
if TYPE_CHECKING:
@ -197,7 +197,7 @@ class ChallengeStageView(StageView):
challenge.initial_data["pending_user"] = user.username
challenge.initial_data["pending_user_avatar"] = DEFAULT_AVATAR
if not isinstance(user, AnonymousUser):
challenge.initial_data["pending_user_avatar"] = user.avatar
challenge.initial_data["pending_user_avatar"] = get_avatar(user, self.request)
return challenge
def get_challenge(self, *args, **kwargs) -> Challenge:

View File

@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Optional
from urllib.parse import urlencode
from django.core.cache import cache
from django.http import HttpRequest
from django.templatetags.static import static
from lxml import etree # nosec
from lxml.etree import Element, SubElement # nosec
@ -15,13 +16,13 @@ from authentik.lib.config import get_path_from_dict
from authentik.lib.utils.http import get_http_session
from authentik.tenants.utils import get_current_tenant
if TYPE_CHECKING:
from authentik.core.models import User
GRAVATAR_URL = "https://secure.gravatar.com"
DEFAULT_AVATAR = static("dist/assets/images/user_default.png")
CACHE_KEY_GRAVATAR = "goauthentik.io/lib/avatars/"
if TYPE_CHECKING:
from authentik.core.models import User
SVG_XML_NS = "http://www.w3.org/2000/svg"
SVG_NS_MAP = {None: SVG_XML_NS}
# Match fonts used in web UI
@ -177,14 +178,19 @@ def avatar_mode_url(user: "User", mode: str) -> Optional[str]:
}
def get_avatar(user: "User") -> str:
def get_avatar(user: "User", request: Optional[HttpRequest] = None) -> str:
"""Get avatar with configured mode"""
mode_map = {
"none": avatar_mode_none,
"initials": avatar_mode_generated,
"gravatar": avatar_mode_gravatar,
}
modes: str = get_current_tenant().avatars
tenant = None
if request:
tenant = request.tenant
else:
tenant = get_current_tenant()
modes: str = tenant.avatars
for mode in modes.split(","):
avatar = None
if mode in mode_map:

View File

@ -3,7 +3,6 @@ from django.core.cache import cache
from django.db import models
from django.db.models import QuerySet
from django.utils.translation import gettext_lazy as _
from guardian.shortcuts import get_anonymous_user
from redis.lock import Lock
from rest_framework.serializers import Serializer
@ -42,7 +41,7 @@ class SCIMProvider(BackchannelProvider):
def get_user_qs(self) -> QuerySet[User]:
"""Get queryset of all users with consistent ordering
according to the provider's settings"""
base = User.objects.all().exclude(pk=get_anonymous_user().pk)
base = User.objects.all().exclude_anonymous()
if self.exclude_users_service_account:
base = base.exclude(type=UserTypes.SERVICE_ACCOUNT).exclude(
type=UserTypes.INTERNAL_SERVICE_ACCOUNT

View File

@ -2,7 +2,6 @@
from json import loads
from django.test import TestCase
from guardian.shortcuts import get_anonymous_user
from jsonschema import validate
from requests_mock import Mocker
@ -19,7 +18,7 @@ class SCIMGroupTests(TestCase):
def setUp(self) -> None:
# Delete all users and groups as the mocked HTTP responses only return one ID
# which will cause errors with multiple users
User.objects.all().exclude(pk=get_anonymous_user().pk).delete()
User.objects.all().exclude_anonymous().delete()
Group.objects.all().delete()
self.provider: SCIMProvider = SCIMProvider.objects.create(
name=generate_id(),

View File

@ -1,6 +1,5 @@
"""SCIM Membership tests"""
from django.test import TestCase
from guardian.shortcuts import get_anonymous_user
from requests_mock import Mocker
from authentik.blueprints.tests import apply_blueprint
@ -21,7 +20,7 @@ class SCIMMembershipTests(TestCase):
def setUp(self) -> None:
# Delete all users and groups as the mocked HTTP responses only return one ID
# which will cause errors with multiple users
User.objects.all().exclude(pk=get_anonymous_user().pk).delete()
User.objects.all().exclude_anonymous().delete()
Group.objects.all().delete()
Tenant.objects.update(avatars="none")

View File

@ -2,7 +2,6 @@
from json import loads
from django.test import TestCase
from guardian.shortcuts import get_anonymous_user
from jsonschema import validate
from requests_mock import Mocker
@ -22,7 +21,7 @@ class SCIMUserTests(TestCase):
# Delete all users and groups as the mocked HTTP responses only return one ID
# which will cause errors with multiple users
Tenant.objects.update(avatars="none")
User.objects.all().exclude(pk=get_anonymous_user().pk).delete()
User.objects.all().exclude_anonymous().delete()
Group.objects.all().delete()
self.provider: SCIMProvider = SCIMProvider.objects.create(
name=generate_id(),

View File

@ -54,7 +54,7 @@ func (cs *CryptoStore) getFingerprint(uuid string) string {
func (cs *CryptoStore) Fetch(uuid string) error {
cfp := cs.getFingerprint(uuid)
if cfp == cs.fingerprints[uuid] {
cs.log.WithField("uuid", uuid).Info("Fingerprint hasn't changed, not fetching cert")
cs.log.WithField("uuid", uuid).Debug("Fingerprint hasn't changed, not fetching cert")
return nil
}
cs.log.WithField("uuid", uuid).Info("Fetching certificate and private key")

View File

@ -78,7 +78,7 @@ func (w *Watcher) GetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, err
if bestSelection == nil {
return w.fallback, nil
}
cert := w.cs.Get(*bestSelection.WebCertificate.Get())
cert := w.cs.Get(bestSelection.GetWebCertificate())
if cert == nil {
return w.fallback, nil
}

View File

@ -3,7 +3,6 @@ from time import sleep
from typing import Any, Optional
from docker.types import Healthcheck
from guardian.utils import get_anonymous_user
from selenium.webdriver.common.by import By
from selenium.webdriver.common.keys import Keys
from selenium.webdriver.support import expected_conditions as ec
@ -161,7 +160,7 @@ class TestSourceSAML(SeleniumTestCase):
self.assert_user(
User.objects.exclude(username="akadmin")
.exclude(username__startswith="ak-outpost")
.exclude(pk=get_anonymous_user().pk)
.exclude_anonymous()
.exclude(pk=self.user.pk)
.first()
)
@ -244,7 +243,7 @@ class TestSourceSAML(SeleniumTestCase):
self.assert_user(
User.objects.exclude(username="akadmin")
.exclude(username__startswith="ak-outpost")
.exclude(pk=get_anonymous_user().pk)
.exclude_anonymous()
.exclude(pk=self.user.pk)
.first()
)
@ -314,7 +313,7 @@ class TestSourceSAML(SeleniumTestCase):
self.assert_user(
User.objects.exclude(username="akadmin")
.exclude(username__startswith="ak-outpost")
.exclude(pk=get_anonymous_user().pk)
.exclude_anonymous()
.exclude(pk=self.user.pk)
.first()
)