
* use custom model serializer that saves m2m without bulk Signed-off-by: Jens Langhammer <jens@goauthentik.io> * sigh Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io>
135 lines
4.8 KiB
Python
135 lines
4.8 KiB
Python
"""RAC Provider API Views"""
|
|
|
|
from django.core.cache import cache
|
|
from django.db.models import QuerySet
|
|
from django.urls import reverse
|
|
from drf_spectacular.types import OpenApiTypes
|
|
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
|
|
from rest_framework.fields import SerializerMethodField
|
|
from rest_framework.request import Request
|
|
from rest_framework.response import Response
|
|
from rest_framework.viewsets import ModelViewSet
|
|
from structlog.stdlib import get_logger
|
|
|
|
from authentik.core.api.used_by import UsedByMixin
|
|
from authentik.core.api.utils import ModelSerializer
|
|
from authentik.core.models import Provider
|
|
from authentik.enterprise.api import EnterpriseRequiredMixin
|
|
from authentik.enterprise.providers.rac.api.providers import RACProviderSerializer
|
|
from authentik.enterprise.providers.rac.models import Endpoint
|
|
from authentik.policies.engine import PolicyEngine
|
|
from authentik.rbac.filters import ObjectFilter
|
|
|
|
# Module-level structlog logger for this API module
LOGGER = get_logger()
|
|
|
|
|
|
def user_endpoint_cache_key(user_pk: str) -> str:
    """Cache key where endpoint list for user is saved

    ``user_pk`` is interpolated verbatim at the end of a fixed namespace, so any
    value (e.g. an integer primary key or a wildcard string) yields a distinct key.
    """
    namespace = "goauthentik.io/providers/rac/endpoint_access"
    return f"{namespace}/{user_pk}"
|
|
|
|
|
|
class EndpointSerializer(EnterpriseRequiredMixin, ModelSerializer):
    """Endpoint Serializer"""

    # NOTE: drf-spectacular typically uses serializer and method docstrings as
    # schema descriptions, so the docstrings in this class are kept verbatim.

    # Read-only nested representation of the endpoint's provider
    provider_obj = RACProviderSerializer(source="provider", read_only=True)
    # Computed by get_launch_url below
    launch_url = SerializerMethodField()

    def get_launch_url(self, endpoint: Endpoint) -> str | None:
        """Build actual launch URL (the provider itself does not have one, just
        individual endpoints)"""
        # The URL is keyed by the slug of the application bound to the endpoint's
        # provider; a provider without an application has no launch URL.
        try:
            application = endpoint.provider.application
        except Provider.application.RelatedObjectDoesNotExist:
            return None
        return reverse(
            "authentik_providers_rac:start",
            kwargs={"app": application.slug, "endpoint": endpoint.pk},
        )

    class Meta:
        model = Endpoint
        fields = [
            # identity
            "pk",
            "name",
            # provider (writable id + read-only nested object)
            "provider",
            "provider_obj",
            # connection details
            "protocol",
            "host",
            "settings",
            "property_mappings",
            "auth_mode",
            # computed / limits
            "launch_url",
            "maximum_connections",
        ]
|
|
|
|
|
|
class EndpointViewSet(UsedByMixin, ModelViewSet):
    """Endpoint Viewset"""

    # NOTE: drf-spectacular typically uses view/action docstrings as OpenAPI
    # descriptions, so the public docstrings of this class are kept verbatim.

    queryset = Endpoint.objects.all()
    serializer_class = EndpointSerializer
    filterset_fields = ["name", "provider"]
    search_fields = ["name", "protocol"]
    ordering = ["name", "protocol"]

    def _filter_queryset_for_list(self, queryset: QuerySet) -> QuerySet:
        """Custom filter_queryset method which ignores guardian, but still supports sorting

        Runs every configured filter backend except ``ObjectFilter``: search, field
        filters and ordering still apply, while per-object permission filtering is
        replaced by the policy check in ``_get_allowed_endpoints``.
        """
        for backend in list(self.filter_backends):
            if backend == ObjectFilter:
                continue
            queryset = backend().filter_queryset(self.request, queryset, self)
        return queryset

    def _get_allowed_endpoints(self, queryset: QuerySet) -> list[Endpoint]:
        """Return the endpoints of ``queryset`` for which a PolicyEngine built for the
        request user passes, preserving the queryset's order."""
        endpoints = []
        for endpoint in queryset:
            engine = PolicyEngine(endpoint, self.request.user, self.request)
            engine.build()
            if engine.passing:
                endpoints.append(endpoint)
        return endpoints

    def _is_cacheable_request(self, request: Request) -> bool:
        """Check whether this list request may use the per-user endpoint cache.

        The cache key only identifies the user, so the cached list must always come
        from the same unfiltered, default-ordered queryset. Any non-empty search,
        ordering or field-filter parameter narrows or reorders the result and must
        therefore bypass the cache (otherwise e.g. ``?provider=1`` would be answered
        with a list cached for a different filter, or would poison the cache).
        """
        params = ("search", "ordering", *self.filterset_fields)
        return all(request.query_params.get(param, "") == "" for param in params)

    @extend_schema(
        parameters=[
            OpenApiParameter(
                "search",
                OpenApiTypes.STR,
            ),
            OpenApiParameter(
                name="superuser_full_list",
                location=OpenApiParameter.QUERY,
                type=OpenApiTypes.BOOL,
            ),
        ],
        responses={
            200: EndpointSerializer(many=True),
            400: OpenApiResponse(description="Bad request"),
        },
    )
    def list(self, request: Request, *args, **kwargs) -> Response:
        """List accessible endpoints"""
        superuser_full_list = (
            str(request.query_params.get("superuser_full_list", "false")).lower() == "true"
        )
        if superuser_full_list and request.user.is_superuser:
            # Standard ModelViewSet listing with all filter backends, no policy check
            return super().list(request, *args, **kwargs)

        queryset = self._filter_queryset_for_list(self.get_queryset())

        if self._is_cacheable_request(request):
            cache_key = user_endpoint_cache_key(request.user.pk)
            allowed_endpoints = cache.get(cache_key)
            # A miss (None) and a cached empty list are both re-evaluated
            if not allowed_endpoints:
                LOGGER.debug("Caching allowed endpoint list")
                allowed_endpoints = self._get_allowed_endpoints(queryset)
                # 86400 s = 24 h
                cache.set(cache_key, allowed_endpoints, timeout=86400)
        else:
            allowed_endpoints = self._get_allowed_endpoints(queryset)

        # Paginate the policy-filtered list (not the raw queryset) so the returned
        # results, counts and page numbers all refer to endpoints the user can access.
        page = self.paginate_queryset(allowed_endpoints)
        if page is None:
            serializer = self.get_serializer(allowed_endpoints, many=True)
            return Response(serializer.data)
        serializer = self.get_serializer(page, many=True)
        return self.get_paginated_response(serializer.data)
|