Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens Langhammer
2025-06-08 02:07:23 +02:00
parent dd8e71df7d
commit b7d2c5188b
2 changed files with 64 additions and 25 deletions

View File

@ -5,7 +5,7 @@ from multiprocessing.connection import Connection
from time import perf_counter from time import perf_counter
from django.core.cache import cache from django.core.cache import cache
from django.db.models import Q, QuerySet from django.db.models import Count, Q, QuerySet
from django.http import HttpRequest from django.http import HttpRequest
from sentry_sdk import start_span from sentry_sdk import start_span
from sentry_sdk.tracing import Span from sentry_sdk.tracing import Span
@ -25,17 +25,12 @@ CURRENT_PROCESS = current_process()
class PolicyProcessInfo: class PolicyProcessInfo:
"""Dataclass to hold all information and communication channels to a process""" """Dataclass to hold all information and communication channels to a process"""
process: PolicyProcess | None process: PolicyProcess
connection: Connection | None connection: Connection
result: PolicyResult | None result: PolicyResult | None
binding: PolicyBinding binding: PolicyBinding
def __init__( def __init__(self, process: PolicyProcess, connection: Connection, binding: PolicyBinding):
self,
process: PolicyProcess | None,
connection: Connection | None,
binding: PolicyBinding,
):
self.process = process self.process = process
self.connection = connection self.connection = connection
self.binding = binding self.binding = binding
@ -72,6 +67,7 @@ class PolicyEngine:
self.__processes: list[PolicyProcessInfo] = [] self.__processes: list[PolicyProcessInfo] = []
self.use_cache = True self.use_cache = True
self.__expected_result_count = 0 self.__expected_result_count = 0
self.__static_result = None
def bindings(self) -> QuerySet[PolicyBinding]: def bindings(self) -> QuerySet[PolicyBinding]:
"""Make sure all Policies are their respective classes""" """Make sure all Policies are their respective classes"""
@ -111,23 +107,32 @@ class PolicyEngine:
def compute_static_bindings(self, bindings: QuerySet[PolicyBinding]): def compute_static_bindings(self, bindings: QuerySet[PolicyBinding]):
all_groups = self.request.user.all_groups() all_groups = self.request.user.all_groups()
matched_bindings = bindings.filter( matched_bindings = bindings.aggregate(
Q( total=Count(
Q(user=self.request.user) | Q(group__in=all_groups), "pk", filter=Q(Q(group__isnull=False) | Q(user__isnull=False), policy=None)
negate=False,
)
| Q(
Q(~Q(user=self.request.user), user__isnull=False)
| Q(~Q(group__in=all_groups), group__isnull=False),
negate=True,
), ),
enabled=True, matching=Count(
).order_by("order") "pk",
for binding in matched_bindings: filter=Q(
self.__expected_result_count += 1 Q(
pi = PolicyProcessInfo(process=None, connection=None, binding=binding) Q(user=self.request.user) | Q(group__in=all_groups),
pi.result = PolicyResult(True) negate=False,
self.__processes.append(pi) )
| Q(
Q(~Q(user=self.request.user), user__isnull=False)
| Q(~Q(group__in=all_groups), group__isnull=False),
negate=True,
),
enabled=True,
),
),
)
passing = False
if matched_bindings["matching"] > 0:
passing = True
elif matched_bindings["total"] > 0 and matched_bindings["matching"] < 1:
passing = False
self.__static_result = PolicyResult(passing)
def build(self) -> "PolicyEngine": def build(self) -> "PolicyEngine":
"""Build wrapper which monitors performance""" """Build wrapper which monitors performance"""
@ -181,6 +186,8 @@ class PolicyEngine:
all_results = list(process_results + self.__cached_policies) all_results = list(process_results + self.__cached_policies)
if len(all_results) < self.__expected_result_count: # pragma: no cover if len(all_results) < self.__expected_result_count: # pragma: no cover
raise AssertionError("Got less results than polices") raise AssertionError("Got less results than polices")
if self.__static_result:
all_results.append(self.__static_result)
# No results, no policies attached -> passing # No results, no policies attached -> passing
if len(all_results) == 0: if len(all_results) == 0:
return PolicyResult(self.empty_result) return PolicyResult(self.empty_result)

View File

@ -1,8 +1,11 @@
"""policy engine tests""" """policy engine tests"""
from django.core.cache import cache from django.core.cache import cache
from django.db import connections
from django.test import TestCase from django.test import TestCase
from django.test.utils import CaptureQueriesContext
from authentik.core.models import Group
from authentik.core.tests.utils import create_test_admin_user from authentik.core.tests.utils import create_test_admin_user
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
from authentik.policies.dummy.models import DummyPolicy from authentik.policies.dummy.models import DummyPolicy
@ -127,3 +130,32 @@ class TestPolicyEngine(TestCase):
self.assertEqual(len(cache.keys(f"{CACHE_PREFIX}{binding.policy_binding_uuid.hex}*")), 1) self.assertEqual(len(cache.keys(f"{CACHE_PREFIX}{binding.policy_binding_uuid.hex}*")), 1)
self.assertEqual(engine.build().passing, False) self.assertEqual(engine.build().passing, False)
self.assertEqual(len(cache.keys(f"{CACHE_PREFIX}{binding.policy_binding_uuid.hex}*")), 1) self.assertEqual(len(cache.keys(f"{CACHE_PREFIX}{binding.policy_binding_uuid.hex}*")), 1)
def test_engine_static_bindings(self):
"""Test static bindings"""
group = Group.objects.create(name=generate_id())
pbm = PolicyBindingModel.objects.create()
for x in range(1000):
PolicyBinding.objects.create(target=pbm, group=group, order=x)
engine = PolicyEngine(pbm, self.user)
engine.use_cache = False
with CaptureQueriesContext(connections["default"]) as ctx:
engine.build()
self.assertLess(ctx.final_queries, 1000)
self.assertEqual(engine.result.passing, False)
def test_engine_static_bindings_matching(self):
"""Test static bindings"""
group = Group.objects.create(name=generate_id())
group.users.add(self.user)
pbm = PolicyBindingModel.objects.create()
for x in range(1000):
PolicyBinding.objects.create(target=pbm, group=group, order=x)
engine = PolicyEngine(pbm, self.user)
engine.use_cache = False
with CaptureQueriesContext(connections["default"]) as ctx:
engine.build()
self.assertLess(ctx.final_queries, 1000)
self.assertEqual(engine.result.passing, False)