@ -5,7 +5,7 @@ from multiprocessing.connection import Connection
|
||||
from time import perf_counter
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.db.models import Q, QuerySet
|
||||
from django.db.models import Count, Q, QuerySet
|
||||
from django.http import HttpRequest
|
||||
from sentry_sdk import start_span
|
||||
from sentry_sdk.tracing import Span
|
||||
@ -25,17 +25,12 @@ CURRENT_PROCESS = current_process()
|
||||
class PolicyProcessInfo:
|
||||
"""Dataclass to hold all information and communication channels to a process"""
|
||||
|
||||
process: PolicyProcess | None
|
||||
connection: Connection | None
|
||||
process: PolicyProcess
|
||||
connection: Connection
|
||||
result: PolicyResult | None
|
||||
binding: PolicyBinding
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
process: PolicyProcess | None,
|
||||
connection: Connection | None,
|
||||
binding: PolicyBinding,
|
||||
):
|
||||
def __init__(self, process: PolicyProcess, connection: Connection, binding: PolicyBinding):
|
||||
self.process = process
|
||||
self.connection = connection
|
||||
self.binding = binding
|
||||
@ -72,6 +67,7 @@ class PolicyEngine:
|
||||
self.__processes: list[PolicyProcessInfo] = []
|
||||
self.use_cache = True
|
||||
self.__expected_result_count = 0
|
||||
self.__static_result = None
|
||||
|
||||
def bindings(self) -> QuerySet[PolicyBinding]:
|
||||
"""Make sure all Policies are their respective classes"""
|
||||
@ -111,7 +107,13 @@ class PolicyEngine:
|
||||
|
||||
def compute_static_bindings(self, bindings: QuerySet[PolicyBinding]):
|
||||
all_groups = self.request.user.all_groups()
|
||||
matched_bindings = bindings.filter(
|
||||
matched_bindings = bindings.aggregate(
|
||||
total=Count(
|
||||
"pk", filter=Q(Q(group__isnull=False) | Q(user__isnull=False), policy=None)
|
||||
),
|
||||
matching=Count(
|
||||
"pk",
|
||||
filter=Q(
|
||||
Q(
|
||||
Q(user=self.request.user) | Q(group__in=all_groups),
|
||||
negate=False,
|
||||
@ -122,12 +124,15 @@ class PolicyEngine:
|
||||
negate=True,
|
||||
),
|
||||
enabled=True,
|
||||
).order_by("order")
|
||||
for binding in matched_bindings:
|
||||
self.__expected_result_count += 1
|
||||
pi = PolicyProcessInfo(process=None, connection=None, binding=binding)
|
||||
pi.result = PolicyResult(True)
|
||||
self.__processes.append(pi)
|
||||
),
|
||||
),
|
||||
)
|
||||
passing = False
|
||||
if matched_bindings["matching"] > 0:
|
||||
passing = True
|
||||
elif matched_bindings["total"] > 0 and matched_bindings["matching"] < 1:
|
||||
passing = False
|
||||
self.__static_result = PolicyResult(passing)
|
||||
|
||||
def build(self) -> "PolicyEngine":
|
||||
"""Build wrapper which monitors performance"""
|
||||
@ -181,6 +186,8 @@ class PolicyEngine:
|
||||
all_results = list(process_results + self.__cached_policies)
|
||||
if len(all_results) < self.__expected_result_count: # pragma: no cover
|
||||
raise AssertionError("Got less results than polices")
|
||||
if self.__static_result:
|
||||
all_results.append(self.__static_result)
|
||||
# No results, no policies attached -> passing
|
||||
if len(all_results) == 0:
|
||||
return PolicyResult(self.empty_result)
|
||||
|
||||
@ -1,8 +1,11 @@
|
||||
"""policy engine tests"""
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.db import connections
|
||||
from django.test import TestCase
|
||||
from django.test.utils import CaptureQueriesContext
|
||||
|
||||
from authentik.core.models import Group
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.policies.dummy.models import DummyPolicy
|
||||
@ -127,3 +130,32 @@ class TestPolicyEngine(TestCase):
|
||||
self.assertEqual(len(cache.keys(f"{CACHE_PREFIX}{binding.policy_binding_uuid.hex}*")), 1)
|
||||
self.assertEqual(engine.build().passing, False)
|
||||
self.assertEqual(len(cache.keys(f"{CACHE_PREFIX}{binding.policy_binding_uuid.hex}*")), 1)
|
||||
|
||||
def test_engine_static_bindings(self):
|
||||
"""Test static bindings"""
|
||||
group = Group.objects.create(name=generate_id())
|
||||
|
||||
pbm = PolicyBindingModel.objects.create()
|
||||
for x in range(1000):
|
||||
PolicyBinding.objects.create(target=pbm, group=group, order=x)
|
||||
engine = PolicyEngine(pbm, self.user)
|
||||
engine.use_cache = False
|
||||
with CaptureQueriesContext(connections["default"]) as ctx:
|
||||
engine.build()
|
||||
self.assertLess(ctx.final_queries, 1000)
|
||||
self.assertEqual(engine.result.passing, False)
|
||||
|
||||
def test_engine_static_bindings_matching(self):
|
||||
"""Test static bindings"""
|
||||
group = Group.objects.create(name=generate_id())
|
||||
group.users.add(self.user)
|
||||
|
||||
pbm = PolicyBindingModel.objects.create()
|
||||
for x in range(1000):
|
||||
PolicyBinding.objects.create(target=pbm, group=group, order=x)
|
||||
engine = PolicyEngine(pbm, self.user)
|
||||
engine.use_cache = False
|
||||
with CaptureQueriesContext(connections["default"]) as ctx:
|
||||
engine.build()
|
||||
self.assertLess(ctx.final_queries, 1000)
|
||||
self.assertEqual(engine.result.passing, False)
|
||||
|
||||
Reference in New Issue
Block a user