Files
authentik/authentik/policies/api/policies.py
Jens L 2c781ae423 root: use custom model serializer that saves m2m without bulk (#10139)
* use custom model serializer that saves m2m without bulk

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* sigh

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2024-06-18 22:48:05 +09:00

162 lines
5.8 KiB
Python

"""policy API Views"""
from django.core.cache import cache
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiResponse, extend_schema
from guardian.shortcuts import get_objects_for_user
from rest_framework import mixins
from rest_framework.decorators import action
from rest_framework.fields import SerializerMethodField
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.viewsets import GenericViewSet
from structlog.stdlib import get_logger
from authentik.core.api.applications import user_app_cache_key
from authentik.core.api.object_types import TypesMixin
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import (
CacheSerializer,
MetaNameSerializer,
ModelSerializer,
)
from authentik.events.logs import LogEventSerializer, capture_logs
from authentik.policies.api.exec import PolicyTestResultSerializer, PolicyTestSerializer
from authentik.policies.models import Policy, PolicyBinding
from authentik.policies.process import PolicyProcess
from authentik.policies.types import CACHE_PREFIX, PolicyRequest
from authentik.rbac.decorators import permission_required
# Module-level structlog logger for this API module
LOGGER = get_logger()
class PolicySerializer(ModelSerializer, MetaNameSerializer):
    """Policy Serializer"""

    # Whether to_representation() should hand non-base instances over to the
    # serializer exposed by the instance itself (see to_representation below).
    _resolve_inheritance: bool

    component = SerializerMethodField()
    bound_to = SerializerMethodField()

    def __init__(self, *args, resolve_inheritance: bool = True, **kwargs):
        # `resolve_inheritance` is keyword-only and consumed here; everything
        # else is forwarded untouched to the parent serializer.
        super().__init__(*args, **kwargs)
        self._resolve_inheritance = resolve_inheritance

    def get_component(self, obj: Policy) -> str:  # pragma: no cover
        """Get object component so that we know how to edit the object"""
        # An instance of exactly Policy (not a subclass) yields an empty string.
        is_base_policy = obj.__class__ == Policy
        return "" if is_base_policy else obj.component

    def get_bound_to(self, obj: Policy) -> int:
        """Return objects policy is bound to"""
        # Sum of related bindings and related prompt stages.
        binding_total = obj.bindings.count()
        prompt_stage_total = obj.promptstage_set.count()
        return binding_total + prompt_stage_total

    def to_representation(self, instance: Policy):
        # Delegate only for instances that are not exactly Policy, and only
        # while inheritance resolution is enabled. The delegated serializer is
        # created with resolve_inheritance=False so it does not delegate again.
        delegate = instance.__class__ != Policy and self._resolve_inheritance
        if delegate:
            subclass_serializer = instance.serializer(
                instance=instance, resolve_inheritance=False
            )
            return dict(subclass_serializer.data)
        return super().to_representation(instance)

    class Meta:
        model = Policy
        fields = [
            "pk",
            "name",
            "execution_logging",
            "component",
            "verbose_name",
            "verbose_name_plural",
            "meta_model_name",
            "bound_to",
        ]
        depth = 3
class PolicyViewSet(
    TypesMixin,
    mixins.RetrieveModelMixin,
    mixins.DestroyModelMixin,
    UsedByMixin,
    mixins.ListModelMixin,
    GenericViewSet,
):
    """Policy Viewset"""

    # Retrieve/list/destroy over policies, plus three extra actions:
    # cache_info, cache_clear and test.
    queryset = Policy.objects.all()
    serializer_class = PolicySerializer
    filterset_fields = {
        "bindings": ["isnull"],
        "promptstage": ["isnull"],
    }
    search_fields = ["name"]
    ordering = ["name"]

    def get_queryset(self):  # pragma: no cover
        # Resolve subclasses and prefetch the two relations that
        # PolicySerializer.get_bound_to() counts.
        return Policy.objects.select_subclasses().prefetch_related("bindings", "promptstage_set")

    @permission_required(None, ["authentik_policies.view_policy_cache"])
    @extend_schema(responses={200: CacheSerializer(many=False)})
    @action(detail=False, pagination_class=None, filter_backends=[])
    def cache_info(self, request: Request) -> Response:
        """Info about cached policies"""
        # Number of cache keys matching the policy cache prefix.
        return Response(data={"count": len(cache.keys(f"{CACHE_PREFIX}*"))})

    @permission_required(None, ["authentik_policies.clear_policy_cache"])
    @extend_schema(
        request=OpenApiTypes.NONE,
        responses={
            204: OpenApiResponse(description="Successfully cleared cache"),
            400: OpenApiResponse(description="Bad request"),
        },
    )
    @action(detail=False, methods=["POST"])
    def cache_clear(self, request: Request) -> Response:
        """Clear policy cache"""
        keys = cache.keys(f"{CACHE_PREFIX}*")
        cache.delete_many(keys)
        LOGGER.debug("Cleared Policy cache", keys=len(keys))
        # Also delete user application cache
        keys = cache.keys(user_app_cache_key("*"))
        cache.delete_many(keys)
        return Response(status=204)

    @permission_required("authentik_policies.view_policy")
    @extend_schema(
        request=PolicyTestSerializer(),
        responses={
            200: PolicyTestResultSerializer(),
            400: OpenApiResponse(description="Invalid parameters"),
        },
    )
    @action(detail=True, pagination_class=None, filter_backends=[], methods=["POST"])
    def test(self, request: Request, pk: str) -> Response:
        """Test policy"""
        policy = self.get_object()
        test_params = PolicyTestSerializer(data=request.data)
        if not test_params.is_valid():
            return Response(test_params.errors, status=400)

        # User permission check, only allow policy testing for users that are readable.
        # Fetch the row once: a separate exists() followed by first() costs two
        # queries and could hand None to PolicyRequest if the row vanished
        # between them.
        user = (
            get_objects_for_user(request.user, "authentik_core.view_user")
            .filter(pk=test_params.validated_data["user"].pk)
            .first()
        )
        if user is None:
            return Response(status=400)

        p_request = PolicyRequest(user)
        p_request.debug = True
        p_request.set_http_request(self.request)
        p_request.context = test_params.validated_data.get("context", {})

        # Execute against an unsaved binding created just for this test run.
        proc = PolicyProcess(PolicyBinding(policy=policy), p_request, None)
        with capture_logs() as logs:
            result = proc.execute()

        # Keep every captured entry except those whose "process" attribute is
        # "PolicyProcess".
        result.log_messages = [
            LogEventSerializer(log).data
            for log in logs
            if log.attributes.get("process", "") != "PolicyProcess"
        ]
        response = PolicyTestResultSerializer(result)
        return Response(response.data)