@ -5,7 +5,7 @@ from multiprocessing.connection import Connection
|
|||||||
from time import perf_counter
|
from time import perf_counter
|
||||||
|
|
||||||
from django.core.cache import cache
|
from django.core.cache import cache
|
||||||
from django.db.models import Q, QuerySet
|
from django.db.models import Count, Q, QuerySet
|
||||||
from django.http import HttpRequest
|
from django.http import HttpRequest
|
||||||
from sentry_sdk import start_span
|
from sentry_sdk import start_span
|
||||||
from sentry_sdk.tracing import Span
|
from sentry_sdk.tracing import Span
|
||||||
@ -25,17 +25,12 @@ CURRENT_PROCESS = current_process()
|
|||||||
class PolicyProcessInfo:
|
class PolicyProcessInfo:
|
||||||
"""Dataclass to hold all information and communication channels to a process"""
|
"""Dataclass to hold all information and communication channels to a process"""
|
||||||
|
|
||||||
process: PolicyProcess | None
|
process: PolicyProcess
|
||||||
connection: Connection | None
|
connection: Connection
|
||||||
result: PolicyResult | None
|
result: PolicyResult | None
|
||||||
binding: PolicyBinding
|
binding: PolicyBinding
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, process: PolicyProcess, connection: Connection, binding: PolicyBinding):
|
||||||
self,
|
|
||||||
process: PolicyProcess | None,
|
|
||||||
connection: Connection | None,
|
|
||||||
binding: PolicyBinding,
|
|
||||||
):
|
|
||||||
self.process = process
|
self.process = process
|
||||||
self.connection = connection
|
self.connection = connection
|
||||||
self.binding = binding
|
self.binding = binding
|
||||||
@ -72,6 +67,7 @@ class PolicyEngine:
|
|||||||
self.__processes: list[PolicyProcessInfo] = []
|
self.__processes: list[PolicyProcessInfo] = []
|
||||||
self.use_cache = True
|
self.use_cache = True
|
||||||
self.__expected_result_count = 0
|
self.__expected_result_count = 0
|
||||||
|
self.__static_result = None
|
||||||
|
|
||||||
def bindings(self) -> QuerySet[PolicyBinding]:
|
def bindings(self) -> QuerySet[PolicyBinding]:
|
||||||
"""Make sure all Policies are their respective classes"""
|
"""Make sure all Policies are their respective classes"""
|
||||||
@ -111,23 +107,32 @@ class PolicyEngine:
|
|||||||
|
|
||||||
def compute_static_bindings(self, bindings: QuerySet[PolicyBinding]):
|
def compute_static_bindings(self, bindings: QuerySet[PolicyBinding]):
|
||||||
all_groups = self.request.user.all_groups()
|
all_groups = self.request.user.all_groups()
|
||||||
matched_bindings = bindings.filter(
|
matched_bindings = bindings.aggregate(
|
||||||
Q(
|
total=Count(
|
||||||
Q(user=self.request.user) | Q(group__in=all_groups),
|
"pk", filter=Q(Q(group__isnull=False) | Q(user__isnull=False), policy=None)
|
||||||
negate=False,
|
|
||||||
)
|
|
||||||
| Q(
|
|
||||||
Q(~Q(user=self.request.user), user__isnull=False)
|
|
||||||
| Q(~Q(group__in=all_groups), group__isnull=False),
|
|
||||||
negate=True,
|
|
||||||
),
|
),
|
||||||
enabled=True,
|
matching=Count(
|
||||||
).order_by("order")
|
"pk",
|
||||||
for binding in matched_bindings:
|
filter=Q(
|
||||||
self.__expected_result_count += 1
|
Q(
|
||||||
pi = PolicyProcessInfo(process=None, connection=None, binding=binding)
|
Q(user=self.request.user) | Q(group__in=all_groups),
|
||||||
pi.result = PolicyResult(True)
|
negate=False,
|
||||||
self.__processes.append(pi)
|
)
|
||||||
|
| Q(
|
||||||
|
Q(~Q(user=self.request.user), user__isnull=False)
|
||||||
|
| Q(~Q(group__in=all_groups), group__isnull=False),
|
||||||
|
negate=True,
|
||||||
|
),
|
||||||
|
enabled=True,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
passing = False
|
||||||
|
if matched_bindings["matching"] > 0:
|
||||||
|
passing = True
|
||||||
|
elif matched_bindings["total"] > 0 and matched_bindings["matching"] < 1:
|
||||||
|
passing = False
|
||||||
|
self.__static_result = PolicyResult(passing)
|
||||||
|
|
||||||
def build(self) -> "PolicyEngine":
|
def build(self) -> "PolicyEngine":
|
||||||
"""Build wrapper which monitors performance"""
|
"""Build wrapper which monitors performance"""
|
||||||
@ -181,6 +186,8 @@ class PolicyEngine:
|
|||||||
all_results = list(process_results + self.__cached_policies)
|
all_results = list(process_results + self.__cached_policies)
|
||||||
if len(all_results) < self.__expected_result_count: # pragma: no cover
|
if len(all_results) < self.__expected_result_count: # pragma: no cover
|
||||||
raise AssertionError("Got less results than polices")
|
raise AssertionError("Got less results than polices")
|
||||||
|
if self.__static_result:
|
||||||
|
all_results.append(self.__static_result)
|
||||||
# No results, no policies attached -> passing
|
# No results, no policies attached -> passing
|
||||||
if len(all_results) == 0:
|
if len(all_results) == 0:
|
||||||
return PolicyResult(self.empty_result)
|
return PolicyResult(self.empty_result)
|
||||||
|
|||||||
@ -1,8 +1,11 @@
|
|||||||
"""policy engine tests"""
|
"""policy engine tests"""
|
||||||
|
|
||||||
from django.core.cache import cache
|
from django.core.cache import cache
|
||||||
|
from django.db import connections
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
from django.test.utils import CaptureQueriesContext
|
||||||
|
|
||||||
|
from authentik.core.models import Group
|
||||||
from authentik.core.tests.utils import create_test_admin_user
|
from authentik.core.tests.utils import create_test_admin_user
|
||||||
from authentik.lib.generators import generate_id
|
from authentik.lib.generators import generate_id
|
||||||
from authentik.policies.dummy.models import DummyPolicy
|
from authentik.policies.dummy.models import DummyPolicy
|
||||||
@ -127,3 +130,32 @@ class TestPolicyEngine(TestCase):
|
|||||||
self.assertEqual(len(cache.keys(f"{CACHE_PREFIX}{binding.policy_binding_uuid.hex}*")), 1)
|
self.assertEqual(len(cache.keys(f"{CACHE_PREFIX}{binding.policy_binding_uuid.hex}*")), 1)
|
||||||
self.assertEqual(engine.build().passing, False)
|
self.assertEqual(engine.build().passing, False)
|
||||||
self.assertEqual(len(cache.keys(f"{CACHE_PREFIX}{binding.policy_binding_uuid.hex}*")), 1)
|
self.assertEqual(len(cache.keys(f"{CACHE_PREFIX}{binding.policy_binding_uuid.hex}*")), 1)
|
||||||
|
|
||||||
|
def test_engine_static_bindings(self):
|
||||||
|
"""Test static bindings"""
|
||||||
|
group = Group.objects.create(name=generate_id())
|
||||||
|
|
||||||
|
pbm = PolicyBindingModel.objects.create()
|
||||||
|
for x in range(1000):
|
||||||
|
PolicyBinding.objects.create(target=pbm, group=group, order=x)
|
||||||
|
engine = PolicyEngine(pbm, self.user)
|
||||||
|
engine.use_cache = False
|
||||||
|
with CaptureQueriesContext(connections["default"]) as ctx:
|
||||||
|
engine.build()
|
||||||
|
self.assertLess(ctx.final_queries, 1000)
|
||||||
|
self.assertEqual(engine.result.passing, False)
|
||||||
|
|
||||||
|
def test_engine_static_bindings_matching(self):
|
||||||
|
"""Test static bindings"""
|
||||||
|
group = Group.objects.create(name=generate_id())
|
||||||
|
group.users.add(self.user)
|
||||||
|
|
||||||
|
pbm = PolicyBindingModel.objects.create()
|
||||||
|
for x in range(1000):
|
||||||
|
PolicyBinding.objects.create(target=pbm, group=group, order=x)
|
||||||
|
engine = PolicyEngine(pbm, self.user)
|
||||||
|
engine.use_cache = False
|
||||||
|
with CaptureQueriesContext(connections["default"]) as ctx:
|
||||||
|
engine.build()
|
||||||
|
self.assertLess(ctx.final_queries, 1000)
|
||||||
|
self.assertEqual(engine.result.passing, False)
|
||||||
|
|||||||
Reference in New Issue
Block a user