From 977757f561eaebe8c7940d7629307b98e55442ba Mon Sep 17 00:00:00 2001 From: Jens L Date: Thu, 6 Apr 2023 09:42:29 +0200 Subject: [PATCH] policies: provider raw result for better policy reusability (#5189) * policies: include raw_result in PolicyResult Signed-off-by: Jens Langhammer * move ak_call_policy to base evaluator Signed-off-by: Jens Langhammer --------- Signed-off-by: Jens Langhammer --- authentik/core/tests/test_property_mapping.py | 32 ++++++++++++++--- authentik/lib/expression/evaluator.py | 35 ++++++++++++++----- authentik/policies/expression/evaluator.py | 14 +------- authentik/policies/types.py | 4 ++- website/docs/expressions/_functions.md | 23 ++++++++++++ website/docs/policies/expression.mdx | 21 ----------- 6 files changed, 81 insertions(+), 48 deletions(-) diff --git a/authentik/core/tests/test_property_mapping.py b/authentik/core/tests/test_property_mapping.py index a7fce579c2..ff73b7d228 100644 --- a/authentik/core/tests/test_property_mapping.py +++ b/authentik/core/tests/test_property_mapping.py @@ -4,7 +4,10 @@ from guardian.shortcuts import get_anonymous_user from authentik.core.exceptions import PropertyMappingExpressionException from authentik.core.models import PropertyMapping +from authentik.core.tests.utils import create_test_admin_user from authentik.events.models import Event, EventAction +from authentik.lib.generators import generate_id +from authentik.policies.expression.models import ExpressionPolicy class TestPropertyMappings(TestCase): @@ -12,23 +15,24 @@ class TestPropertyMappings(TestCase): def setUp(self) -> None: super().setUp() + self.user = create_test_admin_user() self.factory = RequestFactory() def test_expression(self): """Test expression""" - mapping = PropertyMapping.objects.create(name="test", expression="return 'test'") + mapping = PropertyMapping.objects.create(name=generate_id(), expression="return 'test'") self.assertEqual(mapping.evaluate(None, None), "test") def test_expression_syntax(self): """Test expression syntax error""" - mapping = PropertyMapping.objects.create(name="test", expression="-") + mapping = PropertyMapping.objects.create(name=generate_id(), expression="-") with self.assertRaises(PropertyMappingExpressionException): mapping.evaluate(None, None) def test_expression_error_general(self): """Test expression error""" expr = "return aaa" - mapping = PropertyMapping.objects.create(name="test", expression=expr) + mapping = PropertyMapping.objects.create(name=generate_id(), expression=expr) with self.assertRaises(PropertyMappingExpressionException): mapping.evaluate(None, None) events = Event.objects.filter( @@ -41,7 +45,7 @@ class TestPropertyMappings(TestCase): """Test expression error (with user and http request""" expr = "return aaa" request = self.factory.get("/") - mapping = PropertyMapping.objects.create(name="test", expression=expr) + mapping = PropertyMapping.objects.create(name=generate_id(), expression=expr) with self.assertRaises(PropertyMappingExpressionException): mapping.evaluate(get_anonymous_user(), request) events = Event.objects.filter( @@ -52,3 +56,23 @@ class TestPropertyMappings(TestCase): event = events.first() self.assertEqual(event.user["username"], "AnonymousUser") self.assertEqual(event.client_ip, "127.0.0.1") + + def test_call_policy(self): + """test ak_call_policy""" + expr = ExpressionPolicy.objects.create( + name=generate_id(), + execution_logging=True, + expression="return request.http_request.path", + ) + http_request = self.factory.get("/") + tmpl = ( + """ + res = ak_call_policy('%s') + result = [request.http_request.path, res.raw_result] + return result + """ + % expr.name + ) + evaluator = PropertyMapping(expression=tmpl, name=generate_id()) + res = evaluator.evaluate(self.user, http_request) + self.assertEqual(res, ["/", "/"]) diff --git a/authentik/lib/expression/evaluator.py b/authentik/lib/expression/evaluator.py index 851e9da24e..2365e4494e 100644 --- a/authentik/lib/expression/evaluator.py +++ b/authentik/lib/expression/evaluator.py @@ -8,6 +8,7 @@ from typing import Any, Iterable, Optional from cachetools import TLRUCache, cached from django.core.exceptions import FieldError from django_otp import devices_for_user +from guardian.shortcuts import get_anonymous_user from rest_framework.serializers import ValidationError from sentry_sdk.hub import Hub from sentry_sdk.tracing import Span @@ -16,7 +17,9 @@ from structlog.stdlib import get_logger from authentik.core.models import User from authentik.events.models import Event from authentik.lib.utils.http import get_http_session -from authentik.policies.types import PolicyRequest +from authentik.policies.models import Policy, PolicyBinding +from authentik.policies.process import PolicyProcess +from authentik.policies.types import PolicyRequest, PolicyResult LOGGER = get_logger() @@ -37,19 +40,20 @@ class BaseEvaluator: # update website/docs/expressions/_objects.md # update website/docs/expressions/_functions.md self._globals = { - "regex_match": BaseEvaluator.expr_regex_match, - "regex_replace": BaseEvaluator.expr_regex_replace, - "list_flatten": BaseEvaluator.expr_flatten, + "ak_call_policy": self.expr_func_call_policy, + "ak_create_event": self.expr_event_create, "ak_is_group_member": BaseEvaluator.expr_is_group_member, + "ak_logger": get_logger(self._filename).bind(), "ak_user_by": BaseEvaluator.expr_user_by, "ak_user_has_authenticator": BaseEvaluator.expr_func_user_has_authenticator, - "resolve_dns": BaseEvaluator.expr_resolve_dns, - "reverse_dns": BaseEvaluator.expr_reverse_dns, - "ak_create_event": self.expr_event_create, - "ak_logger": get_logger(self._filename).bind(), - "requests": get_http_session(), "ip_address": ip_address, "ip_network": ip_network, + "list_flatten": BaseEvaluator.expr_flatten, + "regex_match": BaseEvaluator.expr_regex_match, + "regex_replace": BaseEvaluator.expr_regex_replace, + "requests": get_http_session(), + "resolve_dns": BaseEvaluator.expr_resolve_dns, + "reverse_dns": BaseEvaluator.expr_reverse_dns, } self._context = {} @@ -152,6 +156,19 @@ class BaseEvaluator: return event.save() + def expr_func_call_policy(self, name: str, **kwargs) -> PolicyResult: + """Call policy by name, with current request""" + policy = Policy.objects.filter(name=name).select_subclasses().first() + if not policy: + raise ValueError(f"Policy '{name}' not found.") + user = self._context.get("user", get_anonymous_user()) + req = PolicyRequest(user) + if "request" in self._context: + req = self._context["request"] + req.context.update(kwargs) + proc = PolicyProcess(PolicyBinding(policy=policy), request=req, connection=None) + return proc.profiling_wrapper() + def wrap_expression(self, expression: str, params: Iterable[str]) -> str: """Wrap expression in a function, call it, and save the result as `result`""" handler_signature = ",".join(params) diff --git a/authentik/policies/expression/evaluator.py b/authentik/policies/expression/evaluator.py index bf1b2b209d..7617efdb36 100644 --- a/authentik/policies/expression/evaluator.py +++ b/authentik/policies/expression/evaluator.py @@ -9,8 +9,6 @@ from authentik.flows.planner import PLAN_CONTEXT_SSO from authentik.lib.expression.evaluator import BaseEvaluator from authentik.lib.utils.http import get_client_ip from authentik.policies.exceptions import PolicyException -from authentik.policies.models import Policy, PolicyBinding -from authentik.policies.process import PolicyProcess from authentik.policies.types import PolicyRequest, PolicyResult LOGGER = get_logger() @@ -32,22 +30,11 @@ class PolicyEvaluator(BaseEvaluator): # update website/docs/expressions/_functions.md self._context["ak_message"] = self.expr_func_message self._context["ak_user_has_authenticator"] = self.expr_func_user_has_authenticator - self._context["ak_call_policy"] = self.expr_func_call_policy def expr_func_message(self, message: str): """Wrapper to append to messages list, which is returned with PolicyResult""" self._messages.append(message) - def expr_func_call_policy(self, name: str, **kwargs) -> PolicyResult: - """Call policy by name, with current request""" - policy = Policy.objects.filter(name=name).select_subclasses().first() - if not policy: - raise ValueError(f"Policy '{name}' not found.") - req: PolicyRequest = self._context["request"] - req.context.update(kwargs) - proc = PolicyProcess(PolicyBinding(policy=policy), request=req, connection=None) - return proc.profiling_wrapper() - def set_policy_request(self, request: PolicyRequest): """Update context based on policy request (if http request is given, update that too)""" # update website/docs/expressions/_objects.md @@ -83,6 +70,7 @@ class PolicyEvaluator(BaseEvaluator): return PolicyResult(False, str(exc)) else: policy_result = PolicyResult(False, *self._messages) + policy_result.raw_result = result if result is None: LOGGER.warning( "Expression policy returned None", diff --git a/authentik/policies/types.py b/authentik/policies/types.py index 76653d6c2c..aeacf73169 100644 --- a/authentik/policies/types.py +++ b/authentik/policies/types.py @@ -69,10 +69,11 @@ class PolicyRequest: @dataclass class PolicyResult: - """Small data-class to hold policy results""" + """Result from evaluating a policy.""" passing: bool messages: tuple[str, ...] + raw_result: Any source_binding: Optional["PolicyBinding"] source_results: Optional[list["PolicyResult"]] @@ -83,6 +84,7 @@ class PolicyResult: super().__init__() self.passing = passing self.messages = messages + self.raw_result = None self.source_binding = None self.source_results = [] self.log_messages = [] diff --git a/website/docs/expressions/_functions.md b/website/docs/expressions/_functions.md index 04efb2edf5..57624b1196 100644 --- a/website/docs/expressions/_functions.md +++ b/website/docs/expressions/_functions.md @@ -29,6 +29,29 @@ user = list_flatten(["foo"]) # user = "foo" ``` +### `ak_call_policy(name: str, **kwargs) -> PolicyResult` + +:::info +Requires authentik 2021.12 +::: + +Call another policy with the name _name_. Current request is passed to policy. Key-word arguments +can be used to modify the request's context. + +Example: + +```python +result = ak_call_policy("test-policy") +# result is a PolicyResult object, so you can access `.passing` and `.messages`. +# Starting with authentik 2023.4 you can also access `.raw_result`, which is the raw value returned from the called policy +# `result.passing` will always be a boolean if the policy is passing or not. +return result.passing + +result = ak_call_policy("test-policy-2", foo="bar") +# Inside the `test-policy-2` you can then use `request.context["foo"]` +return result.passing +``` + ### `ak_is_group_member(user: User, **group_filters) -> bool` Check if `user` is member of a group matching `**group_filters`. diff --git a/website/docs/policies/expression.mdx b/website/docs/policies/expression.mdx index b6d9545bf9..c98803f657 100644 --- a/website/docs/policies/expression.mdx +++ b/website/docs/policies/expression.mdx @@ -29,27 +29,6 @@ ak_message("Access denied") return False ``` -### `ak_call_policy(name: str, **kwargs) -> PolicyResult` - -:::info -Requires authentik 2021.12 -::: - -Call another policy with the name _name_. Current request is passed to policy. Key-word arguments -can be used to modify the request's context. - -Example: - -```python -result = ak_call_policy("test-policy") -# result is a PolicyResult object, so you can access `.passing` and `.messages`. -return result.passing - -result = ak_call_policy("test-policy-2", foo="bar") -# Inside the `test-policy-2` you can then use `request.context["foo"]` -return result.passing -``` - import Functions from "../expressions/_functions.md";