Compare commits

...

9 Commits

Author SHA1 Message Date
6e1cf8a23c slight refactors
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-06-09 00:48:28 +02:00
fde6120e67 fix
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-06-08 15:52:24 +02:00
dabd812071 fix more
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-06-08 03:03:48 +02:00
db92da4cb8 fix em actually
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-06-08 02:51:27 +02:00
81c23fff98 found the first bug
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-06-08 02:25:47 +02:00
54b5774a15 better
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-06-08 02:16:03 +02:00
ba4650a088 less hardcoded names
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-06-08 02:07:29 +02:00
b7d2c5188b improve
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-06-08 02:07:23 +02:00
dd8e71df7d initial optimisation
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-06-08 01:24:12 +02:00
5 changed files with 131 additions and 43 deletions

View File

@ -153,10 +153,10 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
return applications
def _filter_applications_with_launch_url(
self, pagined_apps: Iterator[Application]
self, paginated_apps: Iterator[Application]
) -> list[Application]:
applications = []
for app in pagined_apps:
for app in paginated_apps:
if app.get_launch_url():
applications.append(app)
return applications

View File

@ -1,11 +1,11 @@
"""authentik policy engine"""
from collections.abc import Iterator
from collections.abc import Iterable
from multiprocessing import Pipe, current_process
from multiprocessing.connection import Connection
from time import perf_counter
from django.core.cache import cache
from django.db.models import Count, Q, QuerySet
from django.http import HttpRequest
from sentry_sdk import start_span
from sentry_sdk.tracing import Span
@ -67,14 +67,11 @@ class PolicyEngine:
self.__processes: list[PolicyProcessInfo] = []
self.use_cache = True
self.__expected_result_count = 0
self.__static_result: PolicyResult | None = None
def iterate_bindings(self) -> Iterator[PolicyBinding]:
def bindings(self) -> QuerySet[PolicyBinding] | Iterable[PolicyBinding]:
"""Make sure all Policies are their respective classes"""
return (
PolicyBinding.objects.filter(target=self.__pbm, enabled=True)
.order_by("order")
.iterator()
)
return PolicyBinding.objects.filter(target=self.__pbm, enabled=True).order_by("order")
def _check_policy_type(self, binding: PolicyBinding):
"""Check policy type, make sure it's not the root class as that has no logic implemented"""
@ -84,30 +81,66 @@ class PolicyEngine:
def _check_cache(self, binding: PolicyBinding):
if not self.use_cache:
return False
before = perf_counter()
key = cache_key(binding, self.request)
cached_policy = cache.get(key, None)
duration = max(perf_counter() - before, 0)
if not cached_policy:
return False
self.logger.debug(
"P_ENG: Taking result from cache",
binding=binding,
cache_key=key,
request=self.request,
)
HIST_POLICIES_EXECUTION_TIME.labels(
# It's a bit silly to time this, but
with HIST_POLICIES_EXECUTION_TIME.labels(
binding_order=binding.order,
binding_target_type=binding.target_type,
binding_target_name=binding.target_name,
object_pk=str(self.request.obj.pk),
object_type=class_to_path(self.request.obj.__class__),
mode="cache_retrieve",
).observe(duration)
# It's a bit silly to time this, but
).time():
key = cache_key(binding, self.request)
cached_policy = cache.get(key, None)
if not cached_policy:
return False
self.logger.debug(
"P_ENG: Taking result from cache",
binding=binding,
cache_key=key,
request=self.request,
)
self.__cached_policies.append(cached_policy)
return True
def compute_static_bindings(self, bindings: QuerySet[PolicyBinding]):
"""Check static bindings if possible"""
aggrs = {
"total": Count(
"pk", filter=Q(Q(group__isnull=False) | Q(user__isnull=False), policy=None)
),
}
if self.request.user.pk:
all_groups = self.request.user.all_groups()
aggrs["passing"] = Count(
"pk",
filter=Q(
Q(
Q(user=self.request.user) | Q(group__in=all_groups),
negate=False,
)
| Q(
Q(~Q(user=self.request.user), user__isnull=False)
| Q(~Q(group__in=all_groups), group__isnull=False),
negate=True,
),
enabled=True,
),
)
matched_bindings = bindings.aggregate(**aggrs)
passing = False
if matched_bindings["total"] == 0 and matched_bindings.get("passing", 0) == 0:
# If we didn't find any static bindings, do nothing
return
self.logger.debug("P_ENG: Found static bindings", **matched_bindings)
if matched_bindings.get("passing", 0) > 0:
# Any passing static binding -> passing
passing = True
elif matched_bindings["total"] > 0 and matched_bindings.get("passing", 0) < 1:
# No matching static bindings but at least one is configured -> not passing
passing = False
self.__static_result = PolicyResult(passing)
def build(self) -> "PolicyEngine":
"""Build wrapper which monitors performance"""
with (
@ -123,7 +156,12 @@ class PolicyEngine:
span: Span
span.set_data("pbm", self.__pbm)
span.set_data("request", self.request)
for binding in self.iterate_bindings():
bindings = self.bindings()
policy_bindings = bindings
if isinstance(bindings, QuerySet):
self.compute_static_bindings(bindings)
policy_bindings = [x for x in bindings if x.policy]
for binding in policy_bindings:
self.__expected_result_count += 1
self._check_policy_type(binding)
@ -153,10 +191,13 @@ class PolicyEngine:
@property
def result(self) -> PolicyResult:
"""Get policy-checking result"""
self.__processes.sort(key=lambda x: x.binding.order)
process_results: list[PolicyResult] = [x.result for x in self.__processes if x.result]
all_results = list(process_results + self.__cached_policies)
if len(all_results) < self.__expected_result_count: # pragma: no cover
raise AssertionError("Got less results than polices")
if self.__static_result:
all_results.append(self.__static_result)
# No results, no policies attached -> passing
if len(all_results) == 0:
return PolicyResult(self.empty_result)

View File

@ -1,9 +1,12 @@
"""policy engine tests"""
from django.core.cache import cache
from django.db import connections
from django.test import TestCase
from django.test.utils import CaptureQueriesContext
from authentik.core.tests.utils import create_test_admin_user
from authentik.core.models import Group
from authentik.core.tests.utils import create_test_user
from authentik.lib.generators import generate_id
from authentik.policies.dummy.models import DummyPolicy
from authentik.policies.engine import PolicyEngine
@ -19,7 +22,7 @@ class TestPolicyEngine(TestCase):
def setUp(self):
clear_policy_cache()
self.user = create_test_admin_user()
self.user = create_test_user()
self.policy_false = DummyPolicy.objects.create(
name=generate_id(), result=False, wait_min=0, wait_max=1
)
@ -127,3 +130,43 @@ class TestPolicyEngine(TestCase):
self.assertEqual(len(cache.keys(f"{CACHE_PREFIX}{binding.policy_binding_uuid.hex}*")), 1)
self.assertEqual(engine.build().passing, False)
self.assertEqual(len(cache.keys(f"{CACHE_PREFIX}{binding.policy_binding_uuid.hex}*")), 1)
def test_engine_static_bindings(self):
"""Test static bindings"""
group_a = Group.objects.create(name=generate_id())
group_b = Group.objects.create(name=generate_id())
group_b.users.add(self.user)
user = create_test_user()
for case in [
{
"message": "Group, not member",
"binding_args": {"group": group_a},
"passing": False,
},
{
"message": "Group, member",
"binding_args": {"group": group_b},
"passing": True,
},
{
"message": "User, other",
"binding_args": {"user": user},
"passing": False,
},
{
"message": "User, same",
"binding_args": {"user": self.user},
"passing": True,
},
]:
with self.subTest():
pbm = PolicyBindingModel.objects.create()
for x in range(1000):
PolicyBinding.objects.create(target=pbm, order=x, **case["binding_args"])
engine = PolicyEngine(pbm, self.user)
engine.use_cache = False
with CaptureQueriesContext(connections["default"]) as ctx:
engine.build()
self.assertLess(ctx.final_queries, 1000)
self.assertEqual(engine.result.passing, case["passing"])

View File

@ -29,13 +29,12 @@ class TestPolicyProcess(TestCase):
def setUp(self):
clear_policy_cache()
self.factory = RequestFactory()
self.user = User.objects.create_user(username="policyuser")
self.user = User.objects.create_user(username=generate_id())
def test_group_passing(self):
"""Test binding to group"""
group = Group.objects.create(name="test-group")
group = Group.objects.create(name=generate_id())
group.users.add(self.user)
group.save()
binding = PolicyBinding(group=group)
request = PolicyRequest(self.user)
@ -44,8 +43,7 @@ class TestPolicyProcess(TestCase):
def test_group_negative(self):
"""Test binding to group"""
group = Group.objects.create(name="test-group")
group.save()
group = Group.objects.create(name=generate_id())
binding = PolicyBinding(group=group)
request = PolicyRequest(self.user)
@ -115,8 +113,10 @@ class TestPolicyProcess(TestCase):
def test_exception(self):
"""Test policy execution"""
policy = Policy.objects.create(name="test-execution")
binding = PolicyBinding(policy=policy, target=Application.objects.create(name="test"))
policy = Policy.objects.create(name=generate_id())
binding = PolicyBinding(
policy=policy, target=Application.objects.create(name=generate_id())
)
request = PolicyRequest(self.user)
response = PolicyProcess(binding, request, None).execute()
@ -125,13 +125,15 @@ class TestPolicyProcess(TestCase):
def test_execution_logging(self):
"""Test policy execution creates event"""
policy = DummyPolicy.objects.create(
name="test-execution-logging",
name=generate_id(),
result=False,
wait_min=0,
wait_max=1,
execution_logging=True,
)
binding = PolicyBinding(policy=policy, target=Application.objects.create(name="test"))
binding = PolicyBinding(
policy=policy, target=Application.objects.create(name=generate_id())
)
http_request = self.factory.get(reverse("authentik_api:user-impersonate-end"))
http_request.user = self.user
@ -186,13 +188,15 @@ class TestPolicyProcess(TestCase):
def test_execution_logging_anonymous(self):
"""Test policy execution creates event with anonymous user"""
policy = DummyPolicy.objects.create(
name="test-execution-logging-anon",
name=generate_id(),
result=False,
wait_min=0,
wait_max=1,
execution_logging=True,
)
binding = PolicyBinding(policy=policy, target=Application.objects.create(name="test"))
binding = PolicyBinding(
policy=policy, target=Application.objects.create(name=generate_id())
)
user = AnonymousUser()
@ -219,9 +223,9 @@ class TestPolicyProcess(TestCase):
def test_raises(self):
"""Test policy that raises error"""
policy_raises = ExpressionPolicy.objects.create(name="raises", expression="{{ 0/0 }}")
policy_raises = ExpressionPolicy.objects.create(name=generate_id(), expression="{{ 0/0 }}")
binding = PolicyBinding(
policy=policy_raises, target=Application.objects.create(name="test")
policy=policy_raises, target=Application.objects.create(name=generate_id())
)
request = PolicyRequest(self.user)

View File

@ -1,6 +1,6 @@
"""Prompt Stage Logic"""
from collections.abc import Callable, Iterator
from collections.abc import Callable
from email.policy import Policy
from types import MethodType
from typing import Any
@ -190,7 +190,7 @@ class ListPolicyEngine(PolicyEngine):
self.__list = policies
self.use_cache = False
def iterate_bindings(self) -> Iterator[PolicyBinding]:
def bindings(self):
for policy in self.__list:
yield PolicyBinding(
policy=policy,