diff --git a/authentik/policies/engine.py b/authentik/policies/engine.py index 107c0933e4..85ad84c884 100644 --- a/authentik/policies/engine.py +++ b/authentik/policies/engine.py @@ -5,7 +5,7 @@ from multiprocessing.connection import Connection from time import perf_counter from django.core.cache import cache -from django.db.models import Q, QuerySet +from django.db.models import Count, Q, QuerySet from django.http import HttpRequest from sentry_sdk import start_span from sentry_sdk.tracing import Span @@ -25,17 +25,12 @@ CURRENT_PROCESS = current_process() class PolicyProcessInfo: """Dataclass to hold all information and communication channels to a process""" - process: PolicyProcess | None - connection: Connection | None + process: PolicyProcess + connection: Connection result: PolicyResult | None binding: PolicyBinding - def __init__( - self, - process: PolicyProcess | None, - connection: Connection | None, - binding: PolicyBinding, - ): + def __init__(self, process: PolicyProcess, connection: Connection, binding: PolicyBinding): self.process = process self.connection = connection self.binding = binding @@ -72,6 +67,7 @@ class PolicyEngine: self.__processes: list[PolicyProcessInfo] = [] self.use_cache = True self.__expected_result_count = 0 + self.__static_result = None def bindings(self) -> QuerySet[PolicyBinding]: """Make sure all Policies are their respective classes""" @@ -111,23 +107,32 @@ class PolicyEngine: def compute_static_bindings(self, bindings: QuerySet[PolicyBinding]): all_groups = self.request.user.all_groups() - matched_bindings = bindings.filter( - Q( - Q(user=self.request.user) | Q(group__in=all_groups), - negate=False, - ) - | Q( - Q(~Q(user=self.request.user), user__isnull=False) - | Q(~Q(group__in=all_groups), group__isnull=False), - negate=True, + matched_bindings = bindings.aggregate( + total=Count( + "pk", filter=Q(Q(group__isnull=False) | Q(user__isnull=False), policy=None) ), - enabled=True, - ).order_by("order") - for binding in matched_bindings: - self.__expected_result_count += 1 - pi = PolicyProcessInfo(process=None, connection=None, binding=binding) - pi.result = PolicyResult(True) - self.__processes.append(pi) + matching=Count( + "pk", + filter=Q( + Q( + Q(user=self.request.user) | Q(group__in=all_groups), + negate=False, + ) + | Q( + Q(~Q(user=self.request.user), user__isnull=False) + | Q(~Q(group__in=all_groups), group__isnull=False), + negate=True, + ), + enabled=True, + ), + ), + ) + passing = False + if matched_bindings["matching"] > 0: + passing = True + elif matched_bindings["total"] > 0 and matched_bindings["matching"] < 1: + passing = False + self.__static_result = PolicyResult(passing) def build(self) -> "PolicyEngine": """Build wrapper which monitors performance""" @@ -181,6 +186,8 @@ class PolicyEngine: all_results = list(process_results + self.__cached_policies) if len(all_results) < self.__expected_result_count: # pragma: no cover raise AssertionError("Got less results than polices") + if self.__static_result: + all_results.append(self.__static_result) # No results, no policies attached -> passing if len(all_results) == 0: return PolicyResult(self.empty_result) diff --git a/authentik/policies/tests/test_engine.py b/authentik/policies/tests/test_engine.py index 89b49ef1cb..56f35f1a44 100644 --- a/authentik/policies/tests/test_engine.py +++ b/authentik/policies/tests/test_engine.py @@ -1,8 +1,11 @@ """policy engine tests""" from django.core.cache import cache +from django.db import connections from django.test import TestCase +from django.test.utils import CaptureQueriesContext +from authentik.core.models import Group from authentik.core.tests.utils import create_test_admin_user from authentik.lib.generators import generate_id from authentik.policies.dummy.models import DummyPolicy @@ -127,3 +130,32 @@ class TestPolicyEngine(TestCase): self.assertEqual(len(cache.keys(f"{CACHE_PREFIX}{binding.policy_binding_uuid.hex}*")), 1) self.assertEqual(engine.build().passing, False) self.assertEqual(len(cache.keys(f"{CACHE_PREFIX}{binding.policy_binding_uuid.hex}*")), 1) + + def test_engine_static_bindings(self): + """Test static bindings""" + group = Group.objects.create(name=generate_id()) + + pbm = PolicyBindingModel.objects.create() + for x in range(1000): + PolicyBinding.objects.create(target=pbm, group=group, order=x) + engine = PolicyEngine(pbm, self.user) + engine.use_cache = False + with CaptureQueriesContext(connections["default"]) as ctx: + engine.build() + self.assertLess(ctx.final_queries, 1000) + self.assertEqual(engine.result.passing, False) + + def test_engine_static_bindings_matching(self): + """Test static bindings""" + group = Group.objects.create(name=generate_id()) + group.users.add(self.user) + + pbm = PolicyBindingModel.objects.create() + for x in range(1000): + PolicyBinding.objects.create(target=pbm, group=group, order=x) + engine = PolicyEngine(pbm, self.user) + engine.use_cache = False + with CaptureQueriesContext(connections["default"]) as ctx: + engine.build() + self.assertLess(ctx.final_queries, 1000) + self.assertEqual(engine.result.passing, False)