core: optimise user list endpoint (#8353)
* unrelated changes Signed-off-by: Jens Langhammer <jens@goauthentik.io> * optimization pass 1: reduce N tenant lookups by taking tenant from request, reduce get_anonymous calls Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix lint Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix Signed-off-by: Jens Langhammer <jens@goauthentik.io> * make it easier to exclude anonymous user Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix? Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
@ -7,7 +7,6 @@ from django.contrib.auth import get_user_model
|
|||||||
from django.db.models import Model, Q, QuerySet
|
from django.db.models import Model, Q, QuerySet
|
||||||
from django.utils.timezone import now
|
from django.utils.timezone import now
|
||||||
from django.utils.translation import gettext as _
|
from django.utils.translation import gettext as _
|
||||||
from guardian.shortcuts import get_anonymous_user
|
|
||||||
from yaml import dump
|
from yaml import dump
|
||||||
|
|
||||||
from authentik.blueprints.v1.common import (
|
from authentik.blueprints.v1.common import (
|
||||||
@ -48,7 +47,7 @@ class Exporter:
|
|||||||
"""Return a queryset for `model`. Can be used to filter some
|
"""Return a queryset for `model`. Can be used to filter some
|
||||||
objects on some models"""
|
objects on some models"""
|
||||||
if model == get_user_model():
|
if model == get_user_model():
|
||||||
return model.objects.exclude(pk=get_anonymous_user().pk)
|
return model.objects.exclude_anonymous()
|
||||||
return model.objects.all()
|
return model.objects.all()
|
||||||
|
|
||||||
def _pre_export(self, blueprint: Blueprint):
|
def _pre_export(self, blueprint: Blueprint):
|
||||||
|
|||||||
@ -8,7 +8,6 @@ from sentry_sdk.hub import Hub
|
|||||||
|
|
||||||
from authentik import get_full_version
|
from authentik import get_full_version
|
||||||
from authentik.brands.models import Brand
|
from authentik.brands.models import Brand
|
||||||
from authentik.tenants.utils import get_current_tenant
|
|
||||||
|
|
||||||
_q_default = Q(default=True)
|
_q_default = Q(default=True)
|
||||||
DEFAULT_BRAND = Brand(domain="fallback")
|
DEFAULT_BRAND = Brand(domain="fallback")
|
||||||
@ -36,7 +35,7 @@ def context_processor(request: HttpRequest) -> dict[str, Any]:
|
|||||||
trace = span.to_traceparent()
|
trace = span.to_traceparent()
|
||||||
return {
|
return {
|
||||||
"brand": brand,
|
"brand": brand,
|
||||||
"footer_links": get_current_tenant().footer_links,
|
"footer_links": request.tenant.footer_links,
|
||||||
"sentry_trace": trace,
|
"sentry_trace": trace,
|
||||||
"version": get_full_version(),
|
"version": get_full_version(),
|
||||||
}
|
}
|
||||||
|
|||||||
@ -5,12 +5,12 @@ from typing import Optional
|
|||||||
from django.core.cache import cache
|
from django.core.cache import cache
|
||||||
from django.db.models import QuerySet
|
from django.db.models import QuerySet
|
||||||
from django.db.models.functions import ExtractHour
|
from django.db.models.functions import ExtractHour
|
||||||
from django.http.response import HttpResponseBadRequest
|
|
||||||
from django.shortcuts import get_object_or_404
|
from django.shortcuts import get_object_or_404
|
||||||
from drf_spectacular.types import OpenApiTypes
|
from drf_spectacular.types import OpenApiTypes
|
||||||
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
|
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
|
||||||
from guardian.shortcuts import get_objects_for_user
|
from guardian.shortcuts import get_objects_for_user
|
||||||
from rest_framework.decorators import action
|
from rest_framework.decorators import action
|
||||||
|
from rest_framework.exceptions import ValidationError
|
||||||
from rest_framework.fields import CharField, ReadOnlyField, SerializerMethodField
|
from rest_framework.fields import CharField, ReadOnlyField, SerializerMethodField
|
||||||
from rest_framework.parsers import MultiPartParser
|
from rest_framework.parsers import MultiPartParser
|
||||||
from rest_framework.request import Request
|
from rest_framework.request import Request
|
||||||
@ -147,7 +147,6 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
|
|||||||
],
|
],
|
||||||
responses={
|
responses={
|
||||||
200: PolicyTestResultSerializer(),
|
200: PolicyTestResultSerializer(),
|
||||||
404: OpenApiResponse(description="for_user user not found"),
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@action(detail=True, methods=["GET"])
|
@action(detail=True, methods=["GET"])
|
||||||
@ -160,9 +159,11 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
|
|||||||
for_user = request.user
|
for_user = request.user
|
||||||
if request.user.is_superuser and "for_user" in request.query_params:
|
if request.user.is_superuser and "for_user" in request.query_params:
|
||||||
try:
|
try:
|
||||||
for_user = get_object_or_404(User, pk=request.query_params.get("for_user"))
|
for_user = User.objects.filter(pk=request.query_params.get("for_user")).first()
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return HttpResponseBadRequest("for_user must be numerical")
|
raise ValidationError({"for_user": "for_user must be numerical"})
|
||||||
|
if not for_user:
|
||||||
|
raise ValidationError({"for_user": "User not found"})
|
||||||
engine = PolicyEngine(application, for_user, request)
|
engine = PolicyEngine(application, for_user, request)
|
||||||
engine.use_cache = False
|
engine.use_cache = False
|
||||||
with capture_logs() as logs:
|
with capture_logs() as logs:
|
||||||
|
|||||||
@ -30,7 +30,7 @@ from drf_spectacular.utils import (
|
|||||||
extend_schema_field,
|
extend_schema_field,
|
||||||
inline_serializer,
|
inline_serializer,
|
||||||
)
|
)
|
||||||
from guardian.shortcuts import get_anonymous_user, get_objects_for_user
|
from guardian.shortcuts import get_objects_for_user
|
||||||
from rest_framework.decorators import action
|
from rest_framework.decorators import action
|
||||||
from rest_framework.fields import CharField, IntegerField, ListField, SerializerMethodField
|
from rest_framework.fields import CharField, IntegerField, ListField, SerializerMethodField
|
||||||
from rest_framework.request import Request
|
from rest_framework.request import Request
|
||||||
@ -72,6 +72,7 @@ from authentik.flows.exceptions import FlowNonApplicableException
|
|||||||
from authentik.flows.models import FlowToken
|
from authentik.flows.models import FlowToken
|
||||||
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlanner
|
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlanner
|
||||||
from authentik.flows.views.executor import QS_KEY_TOKEN
|
from authentik.flows.views.executor import QS_KEY_TOKEN
|
||||||
|
from authentik.lib.avatars import get_avatar
|
||||||
from authentik.stages.email.models import EmailStage
|
from authentik.stages.email.models import EmailStage
|
||||||
from authentik.stages.email.tasks import send_mails
|
from authentik.stages.email.tasks import send_mails
|
||||||
from authentik.stages.email.utils import TemplateEmailMessage
|
from authentik.stages.email.utils import TemplateEmailMessage
|
||||||
@ -102,14 +103,21 @@ class UserSerializer(ModelSerializer):
|
|||||||
"""User Serializer"""
|
"""User Serializer"""
|
||||||
|
|
||||||
is_superuser = BooleanField(read_only=True)
|
is_superuser = BooleanField(read_only=True)
|
||||||
avatar = CharField(read_only=True)
|
avatar = SerializerMethodField()
|
||||||
attributes = JSONDictField(required=False)
|
attributes = JSONDictField(required=False)
|
||||||
groups = PrimaryKeyRelatedField(
|
groups = PrimaryKeyRelatedField(
|
||||||
allow_empty=True, many=True, source="ak_groups", queryset=Group.objects.all(), default=list
|
allow_empty=True,
|
||||||
|
many=True,
|
||||||
|
source="ak_groups",
|
||||||
|
queryset=Group.objects.all().order_by("name"),
|
||||||
|
default=list,
|
||||||
)
|
)
|
||||||
groups_obj = ListSerializer(child=UserGroupSerializer(), read_only=True, source="ak_groups")
|
groups_obj = ListSerializer(child=UserGroupSerializer(), read_only=True, source="ak_groups")
|
||||||
uid = CharField(read_only=True)
|
uid = CharField(read_only=True)
|
||||||
username = CharField(max_length=150, validators=[UniqueValidator(queryset=User.objects.all())])
|
username = CharField(
|
||||||
|
max_length=150,
|
||||||
|
validators=[UniqueValidator(queryset=User.objects.all().order_by("username"))],
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
@ -143,6 +151,10 @@ class UserSerializer(ModelSerializer):
|
|||||||
instance.set_unusable_password()
|
instance.set_unusable_password()
|
||||||
instance.save()
|
instance.save()
|
||||||
|
|
||||||
|
def get_avatar(self, user: User) -> str:
|
||||||
|
"""User's avatar, either a http/https URL or a data URI"""
|
||||||
|
return get_avatar(user, self.context["request"])
|
||||||
|
|
||||||
def validate_path(self, path: str) -> str:
|
def validate_path(self, path: str) -> str:
|
||||||
"""Validate path"""
|
"""Validate path"""
|
||||||
if path[:1] == "/" or path[-1] == "/":
|
if path[:1] == "/" or path[-1] == "/":
|
||||||
@ -197,12 +209,16 @@ class UserSelfSerializer(ModelSerializer):
|
|||||||
"""User Serializer for information a user can retrieve about themselves"""
|
"""User Serializer for information a user can retrieve about themselves"""
|
||||||
|
|
||||||
is_superuser = BooleanField(read_only=True)
|
is_superuser = BooleanField(read_only=True)
|
||||||
avatar = CharField(read_only=True)
|
avatar = SerializerMethodField()
|
||||||
groups = SerializerMethodField()
|
groups = SerializerMethodField()
|
||||||
uid = CharField(read_only=True)
|
uid = CharField(read_only=True)
|
||||||
settings = SerializerMethodField()
|
settings = SerializerMethodField()
|
||||||
system_permissions = SerializerMethodField()
|
system_permissions = SerializerMethodField()
|
||||||
|
|
||||||
|
def get_avatar(self, user: User) -> str:
|
||||||
|
"""User's avatar, either a http/https URL or a data URI"""
|
||||||
|
return get_avatar(user, self.context["request"])
|
||||||
|
|
||||||
@extend_schema_field(
|
@extend_schema_field(
|
||||||
ListSerializer(
|
ListSerializer(
|
||||||
child=inline_serializer(
|
child=inline_serializer(
|
||||||
@ -329,11 +345,11 @@ class UsersFilter(FilterSet):
|
|||||||
groups_by_name = ModelMultipleChoiceFilter(
|
groups_by_name = ModelMultipleChoiceFilter(
|
||||||
field_name="ak_groups__name",
|
field_name="ak_groups__name",
|
||||||
to_field_name="name",
|
to_field_name="name",
|
||||||
queryset=Group.objects.all(),
|
queryset=Group.objects.all().order_by("name"),
|
||||||
)
|
)
|
||||||
groups_by_pk = ModelMultipleChoiceFilter(
|
groups_by_pk = ModelMultipleChoiceFilter(
|
||||||
field_name="ak_groups",
|
field_name="ak_groups",
|
||||||
queryset=Group.objects.all(),
|
queryset=Group.objects.all().order_by("name"),
|
||||||
)
|
)
|
||||||
|
|
||||||
def filter_attributes(self, queryset, name, value):
|
def filter_attributes(self, queryset, name, value):
|
||||||
@ -378,7 +394,7 @@ class UserViewSet(UsedByMixin, ModelViewSet):
|
|||||||
filterset_class = UsersFilter
|
filterset_class = UsersFilter
|
||||||
|
|
||||||
def get_queryset(self): # pragma: no cover
|
def get_queryset(self): # pragma: no cover
|
||||||
return User.objects.all().exclude(pk=get_anonymous_user().pk)
|
return User.objects.all().exclude_anonymous().prefetch_related("ak_groups")
|
||||||
|
|
||||||
def _create_recovery_link(self) -> tuple[Optional[str], Optional[Token]]:
|
def _create_recovery_link(self) -> tuple[Optional[str], Optional[Token]]:
|
||||||
"""Create a recovery link (when the current brand has a recovery flow set),
|
"""Create a recovery link (when the current brand has a recovery flow set),
|
||||||
|
|||||||
@ -14,6 +14,7 @@ from django.http import HttpRequest
|
|||||||
from django.utils.functional import SimpleLazyObject, cached_property
|
from django.utils.functional import SimpleLazyObject, cached_property
|
||||||
from django.utils.timezone import now
|
from django.utils.timezone import now
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
|
from guardian.conf import settings
|
||||||
from guardian.mixins import GuardianUserMixin
|
from guardian.mixins import GuardianUserMixin
|
||||||
from model_utils.managers import InheritanceManager
|
from model_utils.managers import InheritanceManager
|
||||||
from rest_framework.serializers import Serializer
|
from rest_framework.serializers import Serializer
|
||||||
@ -169,13 +170,29 @@ class Group(SerializerModel):
|
|||||||
verbose_name_plural = _("Groups")
|
verbose_name_plural = _("Groups")
|
||||||
|
|
||||||
|
|
||||||
|
class UserQuerySet(models.QuerySet):
|
||||||
|
"""User queryset"""
|
||||||
|
|
||||||
|
def exclude_anonymous(self):
|
||||||
|
"""Exclude anonymous user"""
|
||||||
|
return self.exclude(**{User.USERNAME_FIELD: settings.ANONYMOUS_USER_NAME})
|
||||||
|
|
||||||
|
|
||||||
class UserManager(DjangoUserManager):
|
class UserManager(DjangoUserManager):
|
||||||
"""User manager that doesn't assign is_superuser and is_staff"""
|
"""User manager that doesn't assign is_superuser and is_staff"""
|
||||||
|
|
||||||
|
def get_queryset(self):
|
||||||
|
"""Create special user queryset"""
|
||||||
|
return UserQuerySet(self.model, using=self._db)
|
||||||
|
|
||||||
def create_user(self, username, email=None, password=None, **extra_fields):
|
def create_user(self, username, email=None, password=None, **extra_fields):
|
||||||
"""User manager that doesn't assign is_superuser and is_staff"""
|
"""User manager that doesn't assign is_superuser and is_staff"""
|
||||||
return self._create_user(username, email, password, **extra_fields)
|
return self._create_user(username, email, password, **extra_fields)
|
||||||
|
|
||||||
|
def exclude_anonymous(self) -> QuerySet:
|
||||||
|
"""Exclude anonymous user"""
|
||||||
|
return self.get_queryset().exclude_anonymous()
|
||||||
|
|
||||||
|
|
||||||
class User(SerializerModel, GuardianUserMixin, AbstractUser):
|
class User(SerializerModel, GuardianUserMixin, AbstractUser):
|
||||||
"""authentik User model, based on django's contrib auth user model."""
|
"""authentik User model, based on django's contrib auth user model."""
|
||||||
|
|||||||
@ -16,7 +16,6 @@ from django.db import models
|
|||||||
from django.db.models.query import QuerySet
|
from django.db.models.query import QuerySet
|
||||||
from django.utils.timezone import now
|
from django.utils.timezone import now
|
||||||
from django.utils.translation import gettext as _
|
from django.utils.translation import gettext as _
|
||||||
from guardian.shortcuts import get_anonymous_user
|
|
||||||
from jwt import PyJWTError, decode, get_unverified_header
|
from jwt import PyJWTError, decode, get_unverified_header
|
||||||
from rest_framework.exceptions import ValidationError
|
from rest_framework.exceptions import ValidationError
|
||||||
from rest_framework.serializers import BaseSerializer
|
from rest_framework.serializers import BaseSerializer
|
||||||
@ -104,7 +103,7 @@ class LicenseKey:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def base_user_qs() -> QuerySet:
|
def base_user_qs() -> QuerySet:
|
||||||
"""Base query set for all users"""
|
"""Base query set for all users"""
|
||||||
return User.objects.all().exclude(is_active=False).exclude(pk=get_anonymous_user().pk)
|
return User.objects.all().exclude_anonymous().exclude(is_active=False)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_default_user_count():
|
def get_default_user_count():
|
||||||
|
|||||||
@ -26,7 +26,7 @@ from authentik.flows.challenge import (
|
|||||||
from authentik.flows.exceptions import StageInvalidException
|
from authentik.flows.exceptions import StageInvalidException
|
||||||
from authentik.flows.models import InvalidResponseAction
|
from authentik.flows.models import InvalidResponseAction
|
||||||
from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, PLAN_CONTEXT_PENDING_USER
|
from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, PLAN_CONTEXT_PENDING_USER
|
||||||
from authentik.lib.avatars import DEFAULT_AVATAR
|
from authentik.lib.avatars import DEFAULT_AVATAR, get_avatar
|
||||||
from authentik.lib.utils.reflection import class_to_path
|
from authentik.lib.utils.reflection import class_to_path
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -197,7 +197,7 @@ class ChallengeStageView(StageView):
|
|||||||
challenge.initial_data["pending_user"] = user.username
|
challenge.initial_data["pending_user"] = user.username
|
||||||
challenge.initial_data["pending_user_avatar"] = DEFAULT_AVATAR
|
challenge.initial_data["pending_user_avatar"] = DEFAULT_AVATAR
|
||||||
if not isinstance(user, AnonymousUser):
|
if not isinstance(user, AnonymousUser):
|
||||||
challenge.initial_data["pending_user_avatar"] = user.avatar
|
challenge.initial_data["pending_user_avatar"] = get_avatar(user, self.request)
|
||||||
return challenge
|
return challenge
|
||||||
|
|
||||||
def get_challenge(self, *args, **kwargs) -> Challenge:
|
def get_challenge(self, *args, **kwargs) -> Challenge:
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Optional
|
|||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
from django.core.cache import cache
|
from django.core.cache import cache
|
||||||
|
from django.http import HttpRequest
|
||||||
from django.templatetags.static import static
|
from django.templatetags.static import static
|
||||||
from lxml import etree # nosec
|
from lxml import etree # nosec
|
||||||
from lxml.etree import Element, SubElement # nosec
|
from lxml.etree import Element, SubElement # nosec
|
||||||
@ -15,13 +16,13 @@ from authentik.lib.config import get_path_from_dict
|
|||||||
from authentik.lib.utils.http import get_http_session
|
from authentik.lib.utils.http import get_http_session
|
||||||
from authentik.tenants.utils import get_current_tenant
|
from authentik.tenants.utils import get_current_tenant
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from authentik.core.models import User
|
||||||
|
|
||||||
GRAVATAR_URL = "https://secure.gravatar.com"
|
GRAVATAR_URL = "https://secure.gravatar.com"
|
||||||
DEFAULT_AVATAR = static("dist/assets/images/user_default.png")
|
DEFAULT_AVATAR = static("dist/assets/images/user_default.png")
|
||||||
CACHE_KEY_GRAVATAR = "goauthentik.io/lib/avatars/"
|
CACHE_KEY_GRAVATAR = "goauthentik.io/lib/avatars/"
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from authentik.core.models import User
|
|
||||||
|
|
||||||
SVG_XML_NS = "http://www.w3.org/2000/svg"
|
SVG_XML_NS = "http://www.w3.org/2000/svg"
|
||||||
SVG_NS_MAP = {None: SVG_XML_NS}
|
SVG_NS_MAP = {None: SVG_XML_NS}
|
||||||
# Match fonts used in web UI
|
# Match fonts used in web UI
|
||||||
@ -177,14 +178,19 @@ def avatar_mode_url(user: "User", mode: str) -> Optional[str]:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_avatar(user: "User") -> str:
|
def get_avatar(user: "User", request: Optional[HttpRequest] = None) -> str:
|
||||||
"""Get avatar with configured mode"""
|
"""Get avatar with configured mode"""
|
||||||
mode_map = {
|
mode_map = {
|
||||||
"none": avatar_mode_none,
|
"none": avatar_mode_none,
|
||||||
"initials": avatar_mode_generated,
|
"initials": avatar_mode_generated,
|
||||||
"gravatar": avatar_mode_gravatar,
|
"gravatar": avatar_mode_gravatar,
|
||||||
}
|
}
|
||||||
modes: str = get_current_tenant().avatars
|
tenant = None
|
||||||
|
if request:
|
||||||
|
tenant = request.tenant
|
||||||
|
else:
|
||||||
|
tenant = get_current_tenant()
|
||||||
|
modes: str = tenant.avatars
|
||||||
for mode in modes.split(","):
|
for mode in modes.split(","):
|
||||||
avatar = None
|
avatar = None
|
||||||
if mode in mode_map:
|
if mode in mode_map:
|
||||||
|
|||||||
@ -3,7 +3,6 @@ from django.core.cache import cache
|
|||||||
from django.db import models
|
from django.db import models
|
||||||
from django.db.models import QuerySet
|
from django.db.models import QuerySet
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
from guardian.shortcuts import get_anonymous_user
|
|
||||||
from redis.lock import Lock
|
from redis.lock import Lock
|
||||||
from rest_framework.serializers import Serializer
|
from rest_framework.serializers import Serializer
|
||||||
|
|
||||||
@ -42,7 +41,7 @@ class SCIMProvider(BackchannelProvider):
|
|||||||
def get_user_qs(self) -> QuerySet[User]:
|
def get_user_qs(self) -> QuerySet[User]:
|
||||||
"""Get queryset of all users with consistent ordering
|
"""Get queryset of all users with consistent ordering
|
||||||
according to the provider's settings"""
|
according to the provider's settings"""
|
||||||
base = User.objects.all().exclude(pk=get_anonymous_user().pk)
|
base = User.objects.all().exclude_anonymous()
|
||||||
if self.exclude_users_service_account:
|
if self.exclude_users_service_account:
|
||||||
base = base.exclude(type=UserTypes.SERVICE_ACCOUNT).exclude(
|
base = base.exclude(type=UserTypes.SERVICE_ACCOUNT).exclude(
|
||||||
type=UserTypes.INTERNAL_SERVICE_ACCOUNT
|
type=UserTypes.INTERNAL_SERVICE_ACCOUNT
|
||||||
|
|||||||
@ -2,7 +2,6 @@
|
|||||||
from json import loads
|
from json import loads
|
||||||
|
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from guardian.shortcuts import get_anonymous_user
|
|
||||||
from jsonschema import validate
|
from jsonschema import validate
|
||||||
from requests_mock import Mocker
|
from requests_mock import Mocker
|
||||||
|
|
||||||
@ -19,7 +18,7 @@ class SCIMGroupTests(TestCase):
|
|||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
# Delete all users and groups as the mocked HTTP responses only return one ID
|
# Delete all users and groups as the mocked HTTP responses only return one ID
|
||||||
# which will cause errors with multiple users
|
# which will cause errors with multiple users
|
||||||
User.objects.all().exclude(pk=get_anonymous_user().pk).delete()
|
User.objects.all().exclude_anonymous().delete()
|
||||||
Group.objects.all().delete()
|
Group.objects.all().delete()
|
||||||
self.provider: SCIMProvider = SCIMProvider.objects.create(
|
self.provider: SCIMProvider = SCIMProvider.objects.create(
|
||||||
name=generate_id(),
|
name=generate_id(),
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
"""SCIM Membership tests"""
|
"""SCIM Membership tests"""
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from guardian.shortcuts import get_anonymous_user
|
|
||||||
from requests_mock import Mocker
|
from requests_mock import Mocker
|
||||||
|
|
||||||
from authentik.blueprints.tests import apply_blueprint
|
from authentik.blueprints.tests import apply_blueprint
|
||||||
@ -21,7 +20,7 @@ class SCIMMembershipTests(TestCase):
|
|||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
# Delete all users and groups as the mocked HTTP responses only return one ID
|
# Delete all users and groups as the mocked HTTP responses only return one ID
|
||||||
# which will cause errors with multiple users
|
# which will cause errors with multiple users
|
||||||
User.objects.all().exclude(pk=get_anonymous_user().pk).delete()
|
User.objects.all().exclude_anonymous().delete()
|
||||||
Group.objects.all().delete()
|
Group.objects.all().delete()
|
||||||
Tenant.objects.update(avatars="none")
|
Tenant.objects.update(avatars="none")
|
||||||
|
|
||||||
|
|||||||
@ -2,7 +2,6 @@
|
|||||||
from json import loads
|
from json import loads
|
||||||
|
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from guardian.shortcuts import get_anonymous_user
|
|
||||||
from jsonschema import validate
|
from jsonschema import validate
|
||||||
from requests_mock import Mocker
|
from requests_mock import Mocker
|
||||||
|
|
||||||
@ -22,7 +21,7 @@ class SCIMUserTests(TestCase):
|
|||||||
# Delete all users and groups as the mocked HTTP responses only return one ID
|
# Delete all users and groups as the mocked HTTP responses only return one ID
|
||||||
# which will cause errors with multiple users
|
# which will cause errors with multiple users
|
||||||
Tenant.objects.update(avatars="none")
|
Tenant.objects.update(avatars="none")
|
||||||
User.objects.all().exclude(pk=get_anonymous_user().pk).delete()
|
User.objects.all().exclude_anonymous().delete()
|
||||||
Group.objects.all().delete()
|
Group.objects.all().delete()
|
||||||
self.provider: SCIMProvider = SCIMProvider.objects.create(
|
self.provider: SCIMProvider = SCIMProvider.objects.create(
|
||||||
name=generate_id(),
|
name=generate_id(),
|
||||||
|
|||||||
@ -54,7 +54,7 @@ func (cs *CryptoStore) getFingerprint(uuid string) string {
|
|||||||
func (cs *CryptoStore) Fetch(uuid string) error {
|
func (cs *CryptoStore) Fetch(uuid string) error {
|
||||||
cfp := cs.getFingerprint(uuid)
|
cfp := cs.getFingerprint(uuid)
|
||||||
if cfp == cs.fingerprints[uuid] {
|
if cfp == cs.fingerprints[uuid] {
|
||||||
cs.log.WithField("uuid", uuid).Info("Fingerprint hasn't changed, not fetching cert")
|
cs.log.WithField("uuid", uuid).Debug("Fingerprint hasn't changed, not fetching cert")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
cs.log.WithField("uuid", uuid).Info("Fetching certificate and private key")
|
cs.log.WithField("uuid", uuid).Info("Fetching certificate and private key")
|
||||||
|
|||||||
@ -78,7 +78,7 @@ func (w *Watcher) GetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, err
|
|||||||
if bestSelection == nil {
|
if bestSelection == nil {
|
||||||
return w.fallback, nil
|
return w.fallback, nil
|
||||||
}
|
}
|
||||||
cert := w.cs.Get(*bestSelection.WebCertificate.Get())
|
cert := w.cs.Get(bestSelection.GetWebCertificate())
|
||||||
if cert == nil {
|
if cert == nil {
|
||||||
return w.fallback, nil
|
return w.fallback, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -3,7 +3,6 @@ from time import sleep
|
|||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from docker.types import Healthcheck
|
from docker.types import Healthcheck
|
||||||
from guardian.utils import get_anonymous_user
|
|
||||||
from selenium.webdriver.common.by import By
|
from selenium.webdriver.common.by import By
|
||||||
from selenium.webdriver.common.keys import Keys
|
from selenium.webdriver.common.keys import Keys
|
||||||
from selenium.webdriver.support import expected_conditions as ec
|
from selenium.webdriver.support import expected_conditions as ec
|
||||||
@ -161,7 +160,7 @@ class TestSourceSAML(SeleniumTestCase):
|
|||||||
self.assert_user(
|
self.assert_user(
|
||||||
User.objects.exclude(username="akadmin")
|
User.objects.exclude(username="akadmin")
|
||||||
.exclude(username__startswith="ak-outpost")
|
.exclude(username__startswith="ak-outpost")
|
||||||
.exclude(pk=get_anonymous_user().pk)
|
.exclude_anonymous()
|
||||||
.exclude(pk=self.user.pk)
|
.exclude(pk=self.user.pk)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
@ -244,7 +243,7 @@ class TestSourceSAML(SeleniumTestCase):
|
|||||||
self.assert_user(
|
self.assert_user(
|
||||||
User.objects.exclude(username="akadmin")
|
User.objects.exclude(username="akadmin")
|
||||||
.exclude(username__startswith="ak-outpost")
|
.exclude(username__startswith="ak-outpost")
|
||||||
.exclude(pk=get_anonymous_user().pk)
|
.exclude_anonymous()
|
||||||
.exclude(pk=self.user.pk)
|
.exclude(pk=self.user.pk)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
@ -314,7 +313,7 @@ class TestSourceSAML(SeleniumTestCase):
|
|||||||
self.assert_user(
|
self.assert_user(
|
||||||
User.objects.exclude(username="akadmin")
|
User.objects.exclude(username="akadmin")
|
||||||
.exclude(username__startswith="ak-outpost")
|
.exclude(username__startswith="ak-outpost")
|
||||||
.exclude(pk=get_anonymous_user().pk)
|
.exclude_anonymous()
|
||||||
.exclude(pk=self.user.pk)
|
.exclude(pk=self.user.pk)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user