Compare commits
	
		
			9 Commits
		
	
	
		
			enterprise
			...
			policies/o
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 6e1cf8a23c | |||
| fde6120e67 | |||
| dabd812071 | |||
| db92da4cb8 | |||
| 81c23fff98 | |||
| 54b5774a15 | |||
| ba4650a088 | |||
| b7d2c5188b | |||
| dd8e71df7d | 
| @ -153,10 +153,10 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet): | ||||
|         return applications | ||||
|  | ||||
|     def _filter_applications_with_launch_url( | ||||
|         self, pagined_apps: Iterator[Application] | ||||
|         self, paginated_apps: Iterator[Application] | ||||
|     ) -> list[Application]: | ||||
|         applications = [] | ||||
|         for app in pagined_apps: | ||||
|         for app in paginated_apps: | ||||
|             if app.get_launch_url(): | ||||
|                 applications.append(app) | ||||
|         return applications | ||||
|  | ||||
| @ -1,11 +1,11 @@ | ||||
| """authentik policy engine""" | ||||
|  | ||||
| from collections.abc import Iterator | ||||
| from collections.abc import Iterable | ||||
| from multiprocessing import Pipe, current_process | ||||
| from multiprocessing.connection import Connection | ||||
| from time import perf_counter | ||||
|  | ||||
| from django.core.cache import cache | ||||
| from django.db.models import Count, Q, QuerySet | ||||
| from django.http import HttpRequest | ||||
| from sentry_sdk import start_span | ||||
| from sentry_sdk.tracing import Span | ||||
| @ -67,14 +67,11 @@ class PolicyEngine: | ||||
|         self.__processes: list[PolicyProcessInfo] = [] | ||||
|         self.use_cache = True | ||||
|         self.__expected_result_count = 0 | ||||
|         self.__static_result: PolicyResult | None = None | ||||
|  | ||||
|     def iterate_bindings(self) -> Iterator[PolicyBinding]: | ||||
|     def bindings(self) -> QuerySet[PolicyBinding] | Iterable[PolicyBinding]: | ||||
|         """Make sure all Policies are their respective classes""" | ||||
|         return ( | ||||
|             PolicyBinding.objects.filter(target=self.__pbm, enabled=True) | ||||
|             .order_by("order") | ||||
|             .iterator() | ||||
|         ) | ||||
|         return PolicyBinding.objects.filter(target=self.__pbm, enabled=True).order_by("order") | ||||
|  | ||||
|     def _check_policy_type(self, binding: PolicyBinding): | ||||
|         """Check policy type, make sure it's not the root class as that has no logic implemented""" | ||||
| @ -84,30 +81,66 @@ class PolicyEngine: | ||||
|     def _check_cache(self, binding: PolicyBinding): | ||||
|         if not self.use_cache: | ||||
|             return False | ||||
|         before = perf_counter() | ||||
|         key = cache_key(binding, self.request) | ||||
|         cached_policy = cache.get(key, None) | ||||
|         duration = max(perf_counter() - before, 0) | ||||
|         if not cached_policy: | ||||
|             return False | ||||
|         self.logger.debug( | ||||
|             "P_ENG: Taking result from cache", | ||||
|             binding=binding, | ||||
|             cache_key=key, | ||||
|             request=self.request, | ||||
|         ) | ||||
|         HIST_POLICIES_EXECUTION_TIME.labels( | ||||
|         # It's a bit silly to time this, but | ||||
|         with HIST_POLICIES_EXECUTION_TIME.labels( | ||||
|             binding_order=binding.order, | ||||
|             binding_target_type=binding.target_type, | ||||
|             binding_target_name=binding.target_name, | ||||
|             object_pk=str(self.request.obj.pk), | ||||
|             object_type=class_to_path(self.request.obj.__class__), | ||||
|             mode="cache_retrieve", | ||||
|         ).observe(duration) | ||||
|         # It's a bit silly to time this, but | ||||
|         ).time(): | ||||
|             key = cache_key(binding, self.request) | ||||
|             cached_policy = cache.get(key, None) | ||||
|             if not cached_policy: | ||||
|                 return False | ||||
|         self.logger.debug( | ||||
|             "P_ENG: Taking result from cache", | ||||
|             binding=binding, | ||||
|             cache_key=key, | ||||
|             request=self.request, | ||||
|         ) | ||||
|         self.__cached_policies.append(cached_policy) | ||||
|         return True | ||||
|  | ||||
|     def compute_static_bindings(self, bindings: QuerySet[PolicyBinding]): | ||||
|         """Check static bindings if possible""" | ||||
|         aggrs = { | ||||
|             "total": Count( | ||||
|                 "pk", filter=Q(Q(group__isnull=False) | Q(user__isnull=False), policy=None) | ||||
|             ), | ||||
|         } | ||||
|         if self.request.user.pk: | ||||
|             all_groups = self.request.user.all_groups() | ||||
|             aggrs["passing"] = Count( | ||||
|                 "pk", | ||||
|                 filter=Q( | ||||
|                     Q( | ||||
|                         Q(user=self.request.user) | Q(group__in=all_groups), | ||||
|                         negate=False, | ||||
|                     ) | ||||
|                     | Q( | ||||
|                         Q(~Q(user=self.request.user), user__isnull=False) | ||||
|                         | Q(~Q(group__in=all_groups), group__isnull=False), | ||||
|                         negate=True, | ||||
|                     ), | ||||
|                     enabled=True, | ||||
|                 ), | ||||
|             ) | ||||
|         matched_bindings = bindings.aggregate(**aggrs) | ||||
|         passing = False | ||||
|         if matched_bindings["total"] == 0 and matched_bindings.get("passing", 0) == 0: | ||||
|             # If we didn't find any static bindings, do nothing | ||||
|             return | ||||
|         self.logger.debug("P_ENG: Found static bindings", **matched_bindings) | ||||
|         if matched_bindings.get("passing", 0) > 0: | ||||
|             # Any passing static binding -> passing | ||||
|             passing = True | ||||
|         elif matched_bindings["total"] > 0 and matched_bindings.get("passing", 0) < 1: | ||||
|             # No matching static bindings but at least one is configured -> not passing | ||||
|             passing = False | ||||
|         self.__static_result = PolicyResult(passing) | ||||
|  | ||||
|     def build(self) -> "PolicyEngine": | ||||
|         """Build wrapper which monitors performance""" | ||||
|         with ( | ||||
| @ -123,7 +156,12 @@ class PolicyEngine: | ||||
|             span: Span | ||||
|             span.set_data("pbm", self.__pbm) | ||||
|             span.set_data("request", self.request) | ||||
|             for binding in self.iterate_bindings(): | ||||
|             bindings = self.bindings() | ||||
|             policy_bindings = bindings | ||||
|             if isinstance(bindings, QuerySet): | ||||
|                 self.compute_static_bindings(bindings) | ||||
|                 policy_bindings = [x for x in bindings if x.policy] | ||||
|             for binding in policy_bindings: | ||||
|                 self.__expected_result_count += 1 | ||||
|  | ||||
|                 self._check_policy_type(binding) | ||||
| @ -153,10 +191,13 @@ class PolicyEngine: | ||||
|     @property | ||||
|     def result(self) -> PolicyResult: | ||||
|         """Get policy-checking result""" | ||||
|         self.__processes.sort(key=lambda x: x.binding.order) | ||||
|         process_results: list[PolicyResult] = [x.result for x in self.__processes if x.result] | ||||
|         all_results = list(process_results + self.__cached_policies) | ||||
|         if len(all_results) < self.__expected_result_count:  # pragma: no cover | ||||
|             raise AssertionError("Got less results than polices") | ||||
|         if self.__static_result: | ||||
|             all_results.append(self.__static_result) | ||||
|         # No results, no policies attached -> passing | ||||
|         if len(all_results) == 0: | ||||
|             return PolicyResult(self.empty_result) | ||||
|  | ||||
| @ -1,9 +1,12 @@ | ||||
| """policy engine tests""" | ||||
|  | ||||
| from django.core.cache import cache | ||||
| from django.db import connections | ||||
| from django.test import TestCase | ||||
| from django.test.utils import CaptureQueriesContext | ||||
|  | ||||
| from authentik.core.tests.utils import create_test_admin_user | ||||
| from authentik.core.models import Group | ||||
| from authentik.core.tests.utils import create_test_user | ||||
| from authentik.lib.generators import generate_id | ||||
| from authentik.policies.dummy.models import DummyPolicy | ||||
| from authentik.policies.engine import PolicyEngine | ||||
| @ -19,7 +22,7 @@ class TestPolicyEngine(TestCase): | ||||
|  | ||||
|     def setUp(self): | ||||
|         clear_policy_cache() | ||||
|         self.user = create_test_admin_user() | ||||
|         self.user = create_test_user() | ||||
|         self.policy_false = DummyPolicy.objects.create( | ||||
|             name=generate_id(), result=False, wait_min=0, wait_max=1 | ||||
|         ) | ||||
| @ -127,3 +130,43 @@ class TestPolicyEngine(TestCase): | ||||
|         self.assertEqual(len(cache.keys(f"{CACHE_PREFIX}{binding.policy_binding_uuid.hex}*")), 1) | ||||
|         self.assertEqual(engine.build().passing, False) | ||||
|         self.assertEqual(len(cache.keys(f"{CACHE_PREFIX}{binding.policy_binding_uuid.hex}*")), 1) | ||||
|  | ||||
|     def test_engine_static_bindings(self): | ||||
|         """Test static bindings""" | ||||
|         group_a = Group.objects.create(name=generate_id()) | ||||
|         group_b = Group.objects.create(name=generate_id()) | ||||
|         group_b.users.add(self.user) | ||||
|         user = create_test_user() | ||||
|  | ||||
|         for case in [ | ||||
|             { | ||||
|                 "message": "Group, not member", | ||||
|                 "binding_args": {"group": group_a}, | ||||
|                 "passing": False, | ||||
|             }, | ||||
|             { | ||||
|                 "message": "Group, member", | ||||
|                 "binding_args": {"group": group_b}, | ||||
|                 "passing": True, | ||||
|             }, | ||||
|             { | ||||
|                 "message": "User, other", | ||||
|                 "binding_args": {"user": user}, | ||||
|                 "passing": False, | ||||
|             }, | ||||
|             { | ||||
|                 "message": "User, same", | ||||
|                 "binding_args": {"user": self.user}, | ||||
|                 "passing": True, | ||||
|             }, | ||||
|         ]: | ||||
|             with self.subTest(): | ||||
|                 pbm = PolicyBindingModel.objects.create() | ||||
|                 for x in range(1000): | ||||
|                     PolicyBinding.objects.create(target=pbm, order=x, **case["binding_args"]) | ||||
|                 engine = PolicyEngine(pbm, self.user) | ||||
|                 engine.use_cache = False | ||||
|                 with CaptureQueriesContext(connections["default"]) as ctx: | ||||
|                     engine.build() | ||||
|                 self.assertLess(ctx.final_queries, 1000) | ||||
|                 self.assertEqual(engine.result.passing, case["passing"]) | ||||
|  | ||||
| @ -29,13 +29,12 @@ class TestPolicyProcess(TestCase): | ||||
|     def setUp(self): | ||||
|         clear_policy_cache() | ||||
|         self.factory = RequestFactory() | ||||
|         self.user = User.objects.create_user(username="policyuser") | ||||
|         self.user = User.objects.create_user(username=generate_id()) | ||||
|  | ||||
|     def test_group_passing(self): | ||||
|         """Test binding to group""" | ||||
|         group = Group.objects.create(name="test-group") | ||||
|         group = Group.objects.create(name=generate_id()) | ||||
|         group.users.add(self.user) | ||||
|         group.save() | ||||
|         binding = PolicyBinding(group=group) | ||||
|  | ||||
|         request = PolicyRequest(self.user) | ||||
| @ -44,8 +43,7 @@ class TestPolicyProcess(TestCase): | ||||
|  | ||||
|     def test_group_negative(self): | ||||
|         """Test binding to group""" | ||||
|         group = Group.objects.create(name="test-group") | ||||
|         group.save() | ||||
|         group = Group.objects.create(name=generate_id()) | ||||
|         binding = PolicyBinding(group=group) | ||||
|  | ||||
|         request = PolicyRequest(self.user) | ||||
| @ -115,8 +113,10 @@ class TestPolicyProcess(TestCase): | ||||
|  | ||||
|     def test_exception(self): | ||||
|         """Test policy execution""" | ||||
|         policy = Policy.objects.create(name="test-execution") | ||||
|         binding = PolicyBinding(policy=policy, target=Application.objects.create(name="test")) | ||||
|         policy = Policy.objects.create(name=generate_id()) | ||||
|         binding = PolicyBinding( | ||||
|             policy=policy, target=Application.objects.create(name=generate_id()) | ||||
|         ) | ||||
|  | ||||
|         request = PolicyRequest(self.user) | ||||
|         response = PolicyProcess(binding, request, None).execute() | ||||
| @ -125,13 +125,15 @@ class TestPolicyProcess(TestCase): | ||||
|     def test_execution_logging(self): | ||||
|         """Test policy execution creates event""" | ||||
|         policy = DummyPolicy.objects.create( | ||||
|             name="test-execution-logging", | ||||
|             name=generate_id(), | ||||
|             result=False, | ||||
|             wait_min=0, | ||||
|             wait_max=1, | ||||
|             execution_logging=True, | ||||
|         ) | ||||
|         binding = PolicyBinding(policy=policy, target=Application.objects.create(name="test")) | ||||
|         binding = PolicyBinding( | ||||
|             policy=policy, target=Application.objects.create(name=generate_id()) | ||||
|         ) | ||||
|  | ||||
|         http_request = self.factory.get(reverse("authentik_api:user-impersonate-end")) | ||||
|         http_request.user = self.user | ||||
| @ -186,13 +188,15 @@ class TestPolicyProcess(TestCase): | ||||
|     def test_execution_logging_anonymous(self): | ||||
|         """Test policy execution creates event with anonymous user""" | ||||
|         policy = DummyPolicy.objects.create( | ||||
|             name="test-execution-logging-anon", | ||||
|             name=generate_id(), | ||||
|             result=False, | ||||
|             wait_min=0, | ||||
|             wait_max=1, | ||||
|             execution_logging=True, | ||||
|         ) | ||||
|         binding = PolicyBinding(policy=policy, target=Application.objects.create(name="test")) | ||||
|         binding = PolicyBinding( | ||||
|             policy=policy, target=Application.objects.create(name=generate_id()) | ||||
|         ) | ||||
|  | ||||
|         user = AnonymousUser() | ||||
|  | ||||
| @ -219,9 +223,9 @@ class TestPolicyProcess(TestCase): | ||||
|  | ||||
|     def test_raises(self): | ||||
|         """Test policy that raises error""" | ||||
|         policy_raises = ExpressionPolicy.objects.create(name="raises", expression="{{ 0/0 }}") | ||||
|         policy_raises = ExpressionPolicy.objects.create(name=generate_id(), expression="{{ 0/0 }}") | ||||
|         binding = PolicyBinding( | ||||
|             policy=policy_raises, target=Application.objects.create(name="test") | ||||
|             policy=policy_raises, target=Application.objects.create(name=generate_id()) | ||||
|         ) | ||||
|  | ||||
|         request = PolicyRequest(self.user) | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| """Prompt Stage Logic""" | ||||
|  | ||||
| from collections.abc import Callable, Iterator | ||||
| from collections.abc import Callable | ||||
| from email.policy import Policy | ||||
| from types import MethodType | ||||
| from typing import Any | ||||
| @ -190,7 +190,7 @@ class ListPolicyEngine(PolicyEngine): | ||||
|         self.__list = policies | ||||
|         self.use_cache = False | ||||
|  | ||||
|     def iterate_bindings(self) -> Iterator[PolicyBinding]: | ||||
|     def bindings(self): | ||||
|         for policy in self.__list: | ||||
|             yield PolicyBinding( | ||||
|                 policy=policy, | ||||
|  | ||||
		Reference in New Issue
	
	Block a user
	