Compare commits

...

9 Commits

Author SHA1 Message Date
6e1cf8a23c slight refactors
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-06-09 00:48:28 +02:00
fde6120e67 fix
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-06-08 15:52:24 +02:00
dabd812071 fix more
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-06-08 03:03:48 +02:00
db92da4cb8 fix em actually
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-06-08 02:51:27 +02:00
81c23fff98 found the first bug
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-06-08 02:25:47 +02:00
54b5774a15 better
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-06-08 02:16:03 +02:00
ba4650a088 less hardcoded names
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-06-08 02:07:29 +02:00
b7d2c5188b improve
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-06-08 02:07:23 +02:00
dd8e71df7d initial optimisation
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-06-08 01:24:12 +02:00
5 changed files with 131 additions and 43 deletions

View File

@ -153,10 +153,10 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
return applications return applications
def _filter_applications_with_launch_url( def _filter_applications_with_launch_url(
self, pagined_apps: Iterator[Application] self, paginated_apps: Iterator[Application]
) -> list[Application]: ) -> list[Application]:
applications = [] applications = []
for app in pagined_apps: for app in paginated_apps:
if app.get_launch_url(): if app.get_launch_url():
applications.append(app) applications.append(app)
return applications return applications

View File

@ -1,11 +1,11 @@
"""authentik policy engine""" """authentik policy engine"""
from collections.abc import Iterator from collections.abc import Iterable
from multiprocessing import Pipe, current_process from multiprocessing import Pipe, current_process
from multiprocessing.connection import Connection from multiprocessing.connection import Connection
from time import perf_counter
from django.core.cache import cache from django.core.cache import cache
from django.db.models import Count, Q, QuerySet
from django.http import HttpRequest from django.http import HttpRequest
from sentry_sdk import start_span from sentry_sdk import start_span
from sentry_sdk.tracing import Span from sentry_sdk.tracing import Span
@ -67,14 +67,11 @@ class PolicyEngine:
self.__processes: list[PolicyProcessInfo] = [] self.__processes: list[PolicyProcessInfo] = []
self.use_cache = True self.use_cache = True
self.__expected_result_count = 0 self.__expected_result_count = 0
self.__static_result: PolicyResult | None = None
def iterate_bindings(self) -> Iterator[PolicyBinding]: def bindings(self) -> QuerySet[PolicyBinding] | Iterable[PolicyBinding]:
"""Make sure all Policies are their respective classes""" """Make sure all Policies are their respective classes"""
return ( return PolicyBinding.objects.filter(target=self.__pbm, enabled=True).order_by("order")
PolicyBinding.objects.filter(target=self.__pbm, enabled=True)
.order_by("order")
.iterator()
)
def _check_policy_type(self, binding: PolicyBinding): def _check_policy_type(self, binding: PolicyBinding):
"""Check policy type, make sure it's not the root class as that has no logic implemented""" """Check policy type, make sure it's not the root class as that has no logic implemented"""
@ -84,10 +81,17 @@ class PolicyEngine:
def _check_cache(self, binding: PolicyBinding): def _check_cache(self, binding: PolicyBinding):
if not self.use_cache: if not self.use_cache:
return False return False
before = perf_counter() # It's a bit silly to time this, but
with HIST_POLICIES_EXECUTION_TIME.labels(
binding_order=binding.order,
binding_target_type=binding.target_type,
binding_target_name=binding.target_name,
object_pk=str(self.request.obj.pk),
object_type=class_to_path(self.request.obj.__class__),
mode="cache_retrieve",
).time():
key = cache_key(binding, self.request) key = cache_key(binding, self.request)
cached_policy = cache.get(key, None) cached_policy = cache.get(key, None)
duration = max(perf_counter() - before, 0)
if not cached_policy: if not cached_policy:
return False return False
self.logger.debug( self.logger.debug(
@ -96,18 +100,47 @@ class PolicyEngine:
cache_key=key, cache_key=key,
request=self.request, request=self.request,
) )
HIST_POLICIES_EXECUTION_TIME.labels(
binding_order=binding.order,
binding_target_type=binding.target_type,
binding_target_name=binding.target_name,
object_pk=str(self.request.obj.pk),
object_type=class_to_path(self.request.obj.__class__),
mode="cache_retrieve",
).observe(duration)
# It's a bit silly to time this, but
self.__cached_policies.append(cached_policy) self.__cached_policies.append(cached_policy)
return True return True
def compute_static_bindings(self, bindings: QuerySet[PolicyBinding]):
"""Check static bindings if possible"""
aggrs = {
"total": Count(
"pk", filter=Q(Q(group__isnull=False) | Q(user__isnull=False), policy=None)
),
}
if self.request.user.pk:
all_groups = self.request.user.all_groups()
aggrs["passing"] = Count(
"pk",
filter=Q(
Q(
Q(user=self.request.user) | Q(group__in=all_groups),
negate=False,
)
| Q(
Q(~Q(user=self.request.user), user__isnull=False)
| Q(~Q(group__in=all_groups), group__isnull=False),
negate=True,
),
enabled=True,
),
)
matched_bindings = bindings.aggregate(**aggrs)
passing = False
if matched_bindings["total"] == 0 and matched_bindings.get("passing", 0) == 0:
# If we didn't find any static bindings, do nothing
return
self.logger.debug("P_ENG: Found static bindings", **matched_bindings)
if matched_bindings.get("passing", 0) > 0:
# Any passing static binding -> passing
passing = True
elif matched_bindings["total"] > 0 and matched_bindings.get("passing", 0) < 1:
# No matching static bindings but at least one is configured -> not passing
passing = False
self.__static_result = PolicyResult(passing)
def build(self) -> "PolicyEngine": def build(self) -> "PolicyEngine":
"""Build wrapper which monitors performance""" """Build wrapper which monitors performance"""
with ( with (
@ -123,7 +156,12 @@ class PolicyEngine:
span: Span span: Span
span.set_data("pbm", self.__pbm) span.set_data("pbm", self.__pbm)
span.set_data("request", self.request) span.set_data("request", self.request)
for binding in self.iterate_bindings(): bindings = self.bindings()
policy_bindings = bindings
if isinstance(bindings, QuerySet):
self.compute_static_bindings(bindings)
policy_bindings = [x for x in bindings if x.policy]
for binding in policy_bindings:
self.__expected_result_count += 1 self.__expected_result_count += 1
self._check_policy_type(binding) self._check_policy_type(binding)
@ -153,10 +191,13 @@ class PolicyEngine:
@property @property
def result(self) -> PolicyResult: def result(self) -> PolicyResult:
"""Get policy-checking result""" """Get policy-checking result"""
self.__processes.sort(key=lambda x: x.binding.order)
process_results: list[PolicyResult] = [x.result for x in self.__processes if x.result] process_results: list[PolicyResult] = [x.result for x in self.__processes if x.result]
all_results = list(process_results + self.__cached_policies) all_results = list(process_results + self.__cached_policies)
if len(all_results) < self.__expected_result_count: # pragma: no cover if len(all_results) < self.__expected_result_count: # pragma: no cover
raise AssertionError("Got less results than polices") raise AssertionError("Got less results than polices")
if self.__static_result:
all_results.append(self.__static_result)
# No results, no policies attached -> passing # No results, no policies attached -> passing
if len(all_results) == 0: if len(all_results) == 0:
return PolicyResult(self.empty_result) return PolicyResult(self.empty_result)

View File

@ -1,9 +1,12 @@
"""policy engine tests""" """policy engine tests"""
from django.core.cache import cache from django.core.cache import cache
from django.db import connections
from django.test import TestCase from django.test import TestCase
from django.test.utils import CaptureQueriesContext
from authentik.core.tests.utils import create_test_admin_user from authentik.core.models import Group
from authentik.core.tests.utils import create_test_user
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
from authentik.policies.dummy.models import DummyPolicy from authentik.policies.dummy.models import DummyPolicy
from authentik.policies.engine import PolicyEngine from authentik.policies.engine import PolicyEngine
@ -19,7 +22,7 @@ class TestPolicyEngine(TestCase):
def setUp(self): def setUp(self):
clear_policy_cache() clear_policy_cache()
self.user = create_test_admin_user() self.user = create_test_user()
self.policy_false = DummyPolicy.objects.create( self.policy_false = DummyPolicy.objects.create(
name=generate_id(), result=False, wait_min=0, wait_max=1 name=generate_id(), result=False, wait_min=0, wait_max=1
) )
@ -127,3 +130,43 @@ class TestPolicyEngine(TestCase):
self.assertEqual(len(cache.keys(f"{CACHE_PREFIX}{binding.policy_binding_uuid.hex}*")), 1) self.assertEqual(len(cache.keys(f"{CACHE_PREFIX}{binding.policy_binding_uuid.hex}*")), 1)
self.assertEqual(engine.build().passing, False) self.assertEqual(engine.build().passing, False)
self.assertEqual(len(cache.keys(f"{CACHE_PREFIX}{binding.policy_binding_uuid.hex}*")), 1) self.assertEqual(len(cache.keys(f"{CACHE_PREFIX}{binding.policy_binding_uuid.hex}*")), 1)
def test_engine_static_bindings(self):
"""Test static bindings"""
group_a = Group.objects.create(name=generate_id())
group_b = Group.objects.create(name=generate_id())
group_b.users.add(self.user)
user = create_test_user()
for case in [
{
"message": "Group, not member",
"binding_args": {"group": group_a},
"passing": False,
},
{
"message": "Group, member",
"binding_args": {"group": group_b},
"passing": True,
},
{
"message": "User, other",
"binding_args": {"user": user},
"passing": False,
},
{
"message": "User, same",
"binding_args": {"user": self.user},
"passing": True,
},
]:
with self.subTest():
pbm = PolicyBindingModel.objects.create()
for x in range(1000):
PolicyBinding.objects.create(target=pbm, order=x, **case["binding_args"])
engine = PolicyEngine(pbm, self.user)
engine.use_cache = False
with CaptureQueriesContext(connections["default"]) as ctx:
engine.build()
self.assertLess(ctx.final_queries, 1000)
self.assertEqual(engine.result.passing, case["passing"])

View File

@ -29,13 +29,12 @@ class TestPolicyProcess(TestCase):
def setUp(self): def setUp(self):
clear_policy_cache() clear_policy_cache()
self.factory = RequestFactory() self.factory = RequestFactory()
self.user = User.objects.create_user(username="policyuser") self.user = User.objects.create_user(username=generate_id())
def test_group_passing(self): def test_group_passing(self):
"""Test binding to group""" """Test binding to group"""
group = Group.objects.create(name="test-group") group = Group.objects.create(name=generate_id())
group.users.add(self.user) group.users.add(self.user)
group.save()
binding = PolicyBinding(group=group) binding = PolicyBinding(group=group)
request = PolicyRequest(self.user) request = PolicyRequest(self.user)
@ -44,8 +43,7 @@ class TestPolicyProcess(TestCase):
def test_group_negative(self): def test_group_negative(self):
"""Test binding to group""" """Test binding to group"""
group = Group.objects.create(name="test-group") group = Group.objects.create(name=generate_id())
group.save()
binding = PolicyBinding(group=group) binding = PolicyBinding(group=group)
request = PolicyRequest(self.user) request = PolicyRequest(self.user)
@ -115,8 +113,10 @@ class TestPolicyProcess(TestCase):
def test_exception(self): def test_exception(self):
"""Test policy execution""" """Test policy execution"""
policy = Policy.objects.create(name="test-execution") policy = Policy.objects.create(name=generate_id())
binding = PolicyBinding(policy=policy, target=Application.objects.create(name="test")) binding = PolicyBinding(
policy=policy, target=Application.objects.create(name=generate_id())
)
request = PolicyRequest(self.user) request = PolicyRequest(self.user)
response = PolicyProcess(binding, request, None).execute() response = PolicyProcess(binding, request, None).execute()
@ -125,13 +125,15 @@ class TestPolicyProcess(TestCase):
def test_execution_logging(self): def test_execution_logging(self):
"""Test policy execution creates event""" """Test policy execution creates event"""
policy = DummyPolicy.objects.create( policy = DummyPolicy.objects.create(
name="test-execution-logging", name=generate_id(),
result=False, result=False,
wait_min=0, wait_min=0,
wait_max=1, wait_max=1,
execution_logging=True, execution_logging=True,
) )
binding = PolicyBinding(policy=policy, target=Application.objects.create(name="test")) binding = PolicyBinding(
policy=policy, target=Application.objects.create(name=generate_id())
)
http_request = self.factory.get(reverse("authentik_api:user-impersonate-end")) http_request = self.factory.get(reverse("authentik_api:user-impersonate-end"))
http_request.user = self.user http_request.user = self.user
@ -186,13 +188,15 @@ class TestPolicyProcess(TestCase):
def test_execution_logging_anonymous(self): def test_execution_logging_anonymous(self):
"""Test policy execution creates event with anonymous user""" """Test policy execution creates event with anonymous user"""
policy = DummyPolicy.objects.create( policy = DummyPolicy.objects.create(
name="test-execution-logging-anon", name=generate_id(),
result=False, result=False,
wait_min=0, wait_min=0,
wait_max=1, wait_max=1,
execution_logging=True, execution_logging=True,
) )
binding = PolicyBinding(policy=policy, target=Application.objects.create(name="test")) binding = PolicyBinding(
policy=policy, target=Application.objects.create(name=generate_id())
)
user = AnonymousUser() user = AnonymousUser()
@ -219,9 +223,9 @@ class TestPolicyProcess(TestCase):
def test_raises(self): def test_raises(self):
"""Test policy that raises error""" """Test policy that raises error"""
policy_raises = ExpressionPolicy.objects.create(name="raises", expression="{{ 0/0 }}") policy_raises = ExpressionPolicy.objects.create(name=generate_id(), expression="{{ 0/0 }}")
binding = PolicyBinding( binding = PolicyBinding(
policy=policy_raises, target=Application.objects.create(name="test") policy=policy_raises, target=Application.objects.create(name=generate_id())
) )
request = PolicyRequest(self.user) request = PolicyRequest(self.user)

View File

@ -1,6 +1,6 @@
"""Prompt Stage Logic""" """Prompt Stage Logic"""
from collections.abc import Callable, Iterator from collections.abc import Callable
from email.policy import Policy from email.policy import Policy
from types import MethodType from types import MethodType
from typing import Any from typing import Any
@ -190,7 +190,7 @@ class ListPolicyEngine(PolicyEngine):
self.__list = policies self.__list = policies
self.use_cache = False self.use_cache = False
def iterate_bindings(self) -> Iterator[PolicyBinding]: def bindings(self):
for policy in self.__list: for policy in self.__list:
yield PolicyBinding( yield PolicyBinding(
policy=policy, policy=policy,