policies/expression: use pb_message() for messages instead of returning a tuple
This commit is contained in:
		@ -1,8 +1,9 @@
 | 
			
		||||
"""passbook expression policy evaluator"""
 | 
			
		||||
import re
 | 
			
		||||
from typing import TYPE_CHECKING, Any, Dict, Optional
 | 
			
		||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
 | 
			
		||||
 | 
			
		||||
from django.core.exceptions import ValidationError
 | 
			
		||||
from django.http import HttpRequest
 | 
			
		||||
from jinja2 import Undefined
 | 
			
		||||
from jinja2.exceptions import TemplateSyntaxError
 | 
			
		||||
from jinja2.nativetypes import NativeEnvironment
 | 
			
		||||
@ -25,12 +26,32 @@ class Evaluator:
 | 
			
		||||
 | 
			
		||||
    _env: NativeEnvironment
 | 
			
		||||
 | 
			
		||||
    _context: Dict[str, Any]
 | 
			
		||||
    _messages: List[str]
 | 
			
		||||
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self._env = NativeEnvironment()
 | 
			
		||||
        self._env = NativeEnvironment(
 | 
			
		||||
            extensions=["jinja2.ext.do",],
 | 
			
		||||
            trim_blocks=True,
 | 
			
		||||
            lstrip_blocks=True,
 | 
			
		||||
            line_statement_prefix=">",
 | 
			
		||||
        )
 | 
			
		||||
        # update passbook/policies/expression/templates/policy/expression/form.html
 | 
			
		||||
        # update docs/policies/expression/index.md
 | 
			
		||||
        self._env.filters["regex_match"] = Evaluator.jinja2_filter_regex_match
 | 
			
		||||
        self._env.filters["regex_replace"] = Evaluator.jinja2_filter_regex_replace
 | 
			
		||||
        self._env.globals["pb_message"] = self.jinja2_func_message
 | 
			
		||||
        self._context = {
 | 
			
		||||
            "pb_is_group_member": Evaluator.jinja2_func_is_group_member,
 | 
			
		||||
            "pb_logger": get_logger(),
 | 
			
		||||
            "requests": Session(),
 | 
			
		||||
        }
 | 
			
		||||
        self._messages = []
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def env(self) -> NativeEnvironment:
 | 
			
		||||
        """Access to our custom NativeEnvironment"""
 | 
			
		||||
        return self._env
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def jinja2_filter_regex_match(value: Any, regex: str) -> bool:
 | 
			
		||||
@ -47,52 +68,57 @@ class Evaluator:
 | 
			
		||||
        """Check if `user` is member of group with name `group_name`"""
 | 
			
		||||
        return user.groups.filter(name=group_name).exists()
 | 
			
		||||
 | 
			
		||||
    def _get_expression_context(
 | 
			
		||||
        self, request: PolicyRequest, **kwargs
 | 
			
		||||
    ) -> Dict[str, Any]:
 | 
			
		||||
        """Return dictionary with additional global variables passed to expression"""
 | 
			
		||||
    def jinja2_func_message(self, message: str):
 | 
			
		||||
        """Wrapper to append to messages list, which is returned with PolicyResult"""
 | 
			
		||||
        self._messages.append(message)
 | 
			
		||||
 | 
			
		||||
    def set_policy_request(self, request: PolicyRequest):
 | 
			
		||||
        """Update context based on policy request (if http request is given, update that too)"""
 | 
			
		||||
        # update passbook/policies/expression/templates/policy/expression/form.html
 | 
			
		||||
        # update docs/policies/expression/index.md
 | 
			
		||||
        kwargs["pb_is_group_member"] = Evaluator.jinja2_func_is_group_member
 | 
			
		||||
        kwargs["pb_logger"] = get_logger()
 | 
			
		||||
        kwargs["requests"] = Session()
 | 
			
		||||
        kwargs["pb_is_sso_flow"] = request.context.get(PLAN_CONTEXT_SSO, False)
 | 
			
		||||
        self._context["pb_is_sso_flow"] = request.context.get(PLAN_CONTEXT_SSO, False)
 | 
			
		||||
        self._context["request"] = request
 | 
			
		||||
        if request.http_request:
 | 
			
		||||
            kwargs["pb_client_ip"] = (
 | 
			
		||||
                get_client_ip(request.http_request) or "255.255.255.255"
 | 
			
		||||
            )
 | 
			
		||||
            if SESSION_KEY_PLAN in request.http_request.session:
 | 
			
		||||
                kwargs["pb_flow_plan"] = request.http_request.session[SESSION_KEY_PLAN]
 | 
			
		||||
        return kwargs
 | 
			
		||||
            self.set_http_request(request.http_request)
 | 
			
		||||
 | 
			
		||||
    def evaluate(self, expression_source: str, request: PolicyRequest) -> PolicyResult:
 | 
			
		||||
        """Parse and evaluate expression.
 | 
			
		||||
        If the Expression evaluates to a list with 2 items, the first is used as passing bool and
 | 
			
		||||
        the second as messages.
 | 
			
		||||
        If the Expression evaluates to a truthy-object, it is used as passing bool."""
 | 
			
		||||
    def set_http_request(self, request: HttpRequest):
 | 
			
		||||
        """Update context based on http request"""
 | 
			
		||||
        # update passbook/policies/expression/templates/policy/expression/form.html
 | 
			
		||||
        # update docs/policies/expression/index.md
 | 
			
		||||
        self._context["pb_client_ip"] = (
 | 
			
		||||
            get_client_ip(request.http_request) or "255.255.255.255"
 | 
			
		||||
        )
 | 
			
		||||
        self._context["request"] = request
 | 
			
		||||
        if SESSION_KEY_PLAN in request.http_request.session:
 | 
			
		||||
            self._context["pb_flow_plan"] = request.http_request.session[
 | 
			
		||||
                SESSION_KEY_PLAN
 | 
			
		||||
            ]
 | 
			
		||||
 | 
			
		||||
    def evaluate(self, expression_source: str) -> PolicyResult:
 | 
			
		||||
        """Parse and evaluate expression. Policy is expected to return a truthy object.
 | 
			
		||||
        Messages can be added using 'do pb_message()'."""
 | 
			
		||||
        try:
 | 
			
		||||
            expression = self._env.from_string(expression_source)
 | 
			
		||||
            expression = self._env.from_string(expression_source.lstrip().rstrip())
 | 
			
		||||
        except TemplateSyntaxError as exc:
 | 
			
		||||
            return PolicyResult(False, str(exc))
 | 
			
		||||
        try:
 | 
			
		||||
            result: Optional[Any] = expression.render(
 | 
			
		||||
                request=request, **self._get_expression_context(request)
 | 
			
		||||
            )
 | 
			
		||||
            result: Optional[Any] = expression.render(self._context)
 | 
			
		||||
        except Exception as exc:  # pylint: disable=broad-except
 | 
			
		||||
            LOGGER.warning("Expression error", exc=exc)
 | 
			
		||||
            return PolicyResult(False, str(exc))
 | 
			
		||||
        else:
 | 
			
		||||
            policy_result = PolicyResult(False)
 | 
			
		||||
            policy_result.messages = tuple(self._messages)
 | 
			
		||||
            if isinstance(result, Undefined):
 | 
			
		||||
                LOGGER.warning(
 | 
			
		||||
                    "Expression policy returned undefined",
 | 
			
		||||
                    src=expression_source,
 | 
			
		||||
                    req=request,
 | 
			
		||||
                    req=self._context,
 | 
			
		||||
                )
 | 
			
		||||
                return PolicyResult(False)
 | 
			
		||||
            if isinstance(result, (list, tuple)) and len(result) == 2:
 | 
			
		||||
                return PolicyResult(*result)
 | 
			
		||||
                policy_result.passing = False
 | 
			
		||||
            if result:
 | 
			
		||||
                return PolicyResult(bool(result))
 | 
			
		||||
            return PolicyResult(False)
 | 
			
		||||
        except Exception as exc:  # pylint: disable=broad-except
 | 
			
		||||
            LOGGER.warning("Expression error", exc=exc)
 | 
			
		||||
            return PolicyResult(False, str(exc))
 | 
			
		||||
                policy_result.passing = bool(result)
 | 
			
		||||
            return policy_result
 | 
			
		||||
 | 
			
		||||
    def validate(self, expression: str):
 | 
			
		||||
        """Validate expression's syntax, raise ValidationError if Syntax is invalid"""
 | 
			
		||||
 | 
			
		||||
@ -16,7 +16,9 @@ class ExpressionPolicy(Policy):
 | 
			
		||||
 | 
			
		||||
    def passes(self, request: PolicyRequest) -> PolicyResult:
 | 
			
		||||
        """Evaluate and render expression. Returns PolicyResult(false) on error."""
 | 
			
		||||
        return Evaluator().evaluate(self.expression, request)
 | 
			
		||||
        evaluator = Evaluator()
 | 
			
		||||
        evaluator.set_policy_request(request)
 | 
			
		||||
        return evaluator.evaluate(self.expression)
 | 
			
		||||
 | 
			
		||||
    def save(self, *args, **kwargs):
 | 
			
		||||
        Evaluator().validate(self.expression)
 | 
			
		||||
 | 
			
		||||
@ -17,13 +17,15 @@ class TestEvaluator(TestCase):
 | 
			
		||||
        """test simple value expression"""
 | 
			
		||||
        template = "True"
 | 
			
		||||
        evaluator = Evaluator()
 | 
			
		||||
        self.assertEqual(evaluator.evaluate(template, self.request).passing, True)
 | 
			
		||||
        evaluator.set_policy_request(self.request)
 | 
			
		||||
        self.assertEqual(evaluator.evaluate(template).passing, True)
 | 
			
		||||
 | 
			
		||||
    def test_messages(self):
 | 
			
		||||
        """test expression with message return"""
 | 
			
		||||
        template = "False, 'some message'"
 | 
			
		||||
        template = '{% do pb_message("some message") %}False'
 | 
			
		||||
        evaluator = Evaluator()
 | 
			
		||||
        result = evaluator.evaluate(template, self.request)
 | 
			
		||||
        evaluator.set_policy_request(self.request)
 | 
			
		||||
        result = evaluator.evaluate(template)
 | 
			
		||||
        self.assertEqual(result.passing, False)
 | 
			
		||||
        self.assertEqual(result.messages, ("some message",))
 | 
			
		||||
 | 
			
		||||
@ -31,7 +33,8 @@ class TestEvaluator(TestCase):
 | 
			
		||||
        """test invalid syntax"""
 | 
			
		||||
        template = "{%"
 | 
			
		||||
        evaluator = Evaluator()
 | 
			
		||||
        result = evaluator.evaluate(template, self.request)
 | 
			
		||||
        evaluator.set_policy_request(self.request)
 | 
			
		||||
        result = evaluator.evaluate(template)
 | 
			
		||||
        self.assertEqual(result.passing, False)
 | 
			
		||||
        self.assertEqual(result.messages, ("tag name expected",))
 | 
			
		||||
 | 
			
		||||
@ -39,7 +42,8 @@ class TestEvaluator(TestCase):
 | 
			
		||||
        """test undefined result"""
 | 
			
		||||
        template = "{{ foo.bar }}"
 | 
			
		||||
        evaluator = Evaluator()
 | 
			
		||||
        result = evaluator.evaluate(template, self.request)
 | 
			
		||||
        evaluator.set_policy_request(self.request)
 | 
			
		||||
        result = evaluator.evaluate(template)
 | 
			
		||||
        self.assertEqual(result.passing, False)
 | 
			
		||||
        self.assertEqual(result.messages, ("'foo' is undefined",))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user