Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens Langhammer
2025-06-08 02:07:23 +02:00
parent dd8e71df7d
commit b7d2c5188b
2 changed files with 64 additions and 25 deletions

View File

@ -5,7 +5,7 @@ from multiprocessing.connection import Connection
from time import perf_counter
from django.core.cache import cache
from django.db.models import Q, QuerySet
from django.db.models import Count, Q, QuerySet
from django.http import HttpRequest
from sentry_sdk import start_span
from sentry_sdk.tracing import Span
@ -25,17 +25,12 @@ CURRENT_PROCESS = current_process()
class PolicyProcessInfo:
"""Dataclass to hold all information and communication channels to a process"""
process: PolicyProcess | None
connection: Connection | None
process: PolicyProcess
connection: Connection
result: PolicyResult | None
binding: PolicyBinding
def __init__(
self,
process: PolicyProcess | None,
connection: Connection | None,
binding: PolicyBinding,
):
def __init__(self, process: PolicyProcess, connection: Connection, binding: PolicyBinding):
self.process = process
self.connection = connection
self.binding = binding
@ -72,6 +67,7 @@ class PolicyEngine:
self.__processes: list[PolicyProcessInfo] = []
self.use_cache = True
self.__expected_result_count = 0
self.__static_result = None
def bindings(self) -> QuerySet[PolicyBinding]:
"""Make sure all Policies are their respective classes"""
@ -111,7 +107,13 @@ class PolicyEngine:
def compute_static_bindings(self, bindings: QuerySet[PolicyBinding]):
all_groups = self.request.user.all_groups()
matched_bindings = bindings.filter(
matched_bindings = bindings.aggregate(
total=Count(
"pk", filter=Q(Q(group__isnull=False) | Q(user__isnull=False), policy=None)
),
matching=Count(
"pk",
filter=Q(
Q(
Q(user=self.request.user) | Q(group__in=all_groups),
negate=False,
@ -122,12 +124,15 @@ class PolicyEngine:
negate=True,
),
enabled=True,
).order_by("order")
for binding in matched_bindings:
self.__expected_result_count += 1
pi = PolicyProcessInfo(process=None, connection=None, binding=binding)
pi.result = PolicyResult(True)
self.__processes.append(pi)
),
),
)
passing = False
if matched_bindings["matching"] > 0:
passing = True
elif matched_bindings["total"] > 0 and matched_bindings["matching"] < 1:
passing = False
self.__static_result = PolicyResult(passing)
def build(self) -> "PolicyEngine":
"""Build wrapper which monitors performance"""
@ -181,6 +186,8 @@ class PolicyEngine:
all_results = list(process_results + self.__cached_policies)
if len(all_results) < self.__expected_result_count: # pragma: no cover
raise AssertionError("Got less results than polices")
if self.__static_result:
all_results.append(self.__static_result)
# No results, no policies attached -> passing
if len(all_results) == 0:
return PolicyResult(self.empty_result)

View File

@ -1,8 +1,11 @@
"""policy engine tests"""
from django.core.cache import cache
from django.db import connections
from django.test import TestCase
from django.test.utils import CaptureQueriesContext
from authentik.core.models import Group
from authentik.core.tests.utils import create_test_admin_user
from authentik.lib.generators import generate_id
from authentik.policies.dummy.models import DummyPolicy
@ -127,3 +130,32 @@ class TestPolicyEngine(TestCase):
self.assertEqual(len(cache.keys(f"{CACHE_PREFIX}{binding.policy_binding_uuid.hex}*")), 1)
self.assertEqual(engine.build().passing, False)
self.assertEqual(len(cache.keys(f"{CACHE_PREFIX}{binding.policy_binding_uuid.hex}*")), 1)
def test_engine_static_bindings(self):
"""Test static bindings"""
group = Group.objects.create(name=generate_id())
pbm = PolicyBindingModel.objects.create()
for x in range(1000):
PolicyBinding.objects.create(target=pbm, group=group, order=x)
engine = PolicyEngine(pbm, self.user)
engine.use_cache = False
with CaptureQueriesContext(connections["default"]) as ctx:
engine.build()
self.assertLess(ctx.final_queries, 1000)
self.assertEqual(engine.result.passing, False)
def test_engine_static_bindings_matching(self):
"""Test static bindings"""
group = Group.objects.create(name=generate_id())
group.users.add(self.user)
pbm = PolicyBindingModel.objects.create()
for x in range(1000):
PolicyBinding.objects.create(target=pbm, group=group, order=x)
engine = PolicyEngine(pbm, self.user)
engine.use_cache = False
with CaptureQueriesContext(connections["default"]) as ctx:
engine.build()
self.assertLess(ctx.final_queries, 1000)
self.assertEqual(engine.result.passing, False)