Compare commits
9 Commits
main
...
policies/o
Author | SHA1 | Date | |
---|---|---|---|
6e1cf8a23c | |||
fde6120e67 | |||
dabd812071 | |||
db92da4cb8 | |||
81c23fff98 | |||
54b5774a15 | |||
ba4650a088 | |||
b7d2c5188b | |||
dd8e71df7d |
@ -153,10 +153,10 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
|
||||
return applications
|
||||
|
||||
def _filter_applications_with_launch_url(
|
||||
self, pagined_apps: Iterator[Application]
|
||||
self, paginated_apps: Iterator[Application]
|
||||
) -> list[Application]:
|
||||
applications = []
|
||||
for app in pagined_apps:
|
||||
for app in paginated_apps:
|
||||
if app.get_launch_url():
|
||||
applications.append(app)
|
||||
return applications
|
||||
|
@ -1,11 +1,11 @@
|
||||
"""authentik policy engine"""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Iterable
|
||||
from multiprocessing import Pipe, current_process
|
||||
from multiprocessing.connection import Connection
|
||||
from time import perf_counter
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.db.models import Count, Q, QuerySet
|
||||
from django.http import HttpRequest
|
||||
from sentry_sdk import start_span
|
||||
from sentry_sdk.tracing import Span
|
||||
@ -67,14 +67,11 @@ class PolicyEngine:
|
||||
self.__processes: list[PolicyProcessInfo] = []
|
||||
self.use_cache = True
|
||||
self.__expected_result_count = 0
|
||||
self.__static_result: PolicyResult | None = None
|
||||
|
||||
def iterate_bindings(self) -> Iterator[PolicyBinding]:
|
||||
def bindings(self) -> QuerySet[PolicyBinding] | Iterable[PolicyBinding]:
|
||||
"""Make sure all Policies are their respective classes"""
|
||||
return (
|
||||
PolicyBinding.objects.filter(target=self.__pbm, enabled=True)
|
||||
.order_by("order")
|
||||
.iterator()
|
||||
)
|
||||
return PolicyBinding.objects.filter(target=self.__pbm, enabled=True).order_by("order")
|
||||
|
||||
def _check_policy_type(self, binding: PolicyBinding):
|
||||
"""Check policy type, make sure it's not the root class as that has no logic implemented"""
|
||||
@ -84,30 +81,66 @@ class PolicyEngine:
|
||||
def _check_cache(self, binding: PolicyBinding):
|
||||
if not self.use_cache:
|
||||
return False
|
||||
before = perf_counter()
|
||||
key = cache_key(binding, self.request)
|
||||
cached_policy = cache.get(key, None)
|
||||
duration = max(perf_counter() - before, 0)
|
||||
if not cached_policy:
|
||||
return False
|
||||
self.logger.debug(
|
||||
"P_ENG: Taking result from cache",
|
||||
binding=binding,
|
||||
cache_key=key,
|
||||
request=self.request,
|
||||
)
|
||||
HIST_POLICIES_EXECUTION_TIME.labels(
|
||||
# It's a bit silly to time this, but
|
||||
with HIST_POLICIES_EXECUTION_TIME.labels(
|
||||
binding_order=binding.order,
|
||||
binding_target_type=binding.target_type,
|
||||
binding_target_name=binding.target_name,
|
||||
object_pk=str(self.request.obj.pk),
|
||||
object_type=class_to_path(self.request.obj.__class__),
|
||||
mode="cache_retrieve",
|
||||
).observe(duration)
|
||||
# It's a bit silly to time this, but
|
||||
).time():
|
||||
key = cache_key(binding, self.request)
|
||||
cached_policy = cache.get(key, None)
|
||||
if not cached_policy:
|
||||
return False
|
||||
self.logger.debug(
|
||||
"P_ENG: Taking result from cache",
|
||||
binding=binding,
|
||||
cache_key=key,
|
||||
request=self.request,
|
||||
)
|
||||
self.__cached_policies.append(cached_policy)
|
||||
return True
|
||||
|
||||
def compute_static_bindings(self, bindings: QuerySet[PolicyBinding]):
|
||||
"""Check static bindings if possible"""
|
||||
aggrs = {
|
||||
"total": Count(
|
||||
"pk", filter=Q(Q(group__isnull=False) | Q(user__isnull=False), policy=None)
|
||||
),
|
||||
}
|
||||
if self.request.user.pk:
|
||||
all_groups = self.request.user.all_groups()
|
||||
aggrs["passing"] = Count(
|
||||
"pk",
|
||||
filter=Q(
|
||||
Q(
|
||||
Q(user=self.request.user) | Q(group__in=all_groups),
|
||||
negate=False,
|
||||
)
|
||||
| Q(
|
||||
Q(~Q(user=self.request.user), user__isnull=False)
|
||||
| Q(~Q(group__in=all_groups), group__isnull=False),
|
||||
negate=True,
|
||||
),
|
||||
enabled=True,
|
||||
),
|
||||
)
|
||||
matched_bindings = bindings.aggregate(**aggrs)
|
||||
passing = False
|
||||
if matched_bindings["total"] == 0 and matched_bindings.get("passing", 0) == 0:
|
||||
# If we didn't find any static bindings, do nothing
|
||||
return
|
||||
self.logger.debug("P_ENG: Found static bindings", **matched_bindings)
|
||||
if matched_bindings.get("passing", 0) > 0:
|
||||
# Any passing static binding -> passing
|
||||
passing = True
|
||||
elif matched_bindings["total"] > 0 and matched_bindings.get("passing", 0) < 1:
|
||||
# No matching static bindings but at least one is configured -> not passing
|
||||
passing = False
|
||||
self.__static_result = PolicyResult(passing)
|
||||
|
||||
def build(self) -> "PolicyEngine":
|
||||
"""Build wrapper which monitors performance"""
|
||||
with (
|
||||
@ -123,7 +156,12 @@ class PolicyEngine:
|
||||
span: Span
|
||||
span.set_data("pbm", self.__pbm)
|
||||
span.set_data("request", self.request)
|
||||
for binding in self.iterate_bindings():
|
||||
bindings = self.bindings()
|
||||
policy_bindings = bindings
|
||||
if isinstance(bindings, QuerySet):
|
||||
self.compute_static_bindings(bindings)
|
||||
policy_bindings = [x for x in bindings if x.policy]
|
||||
for binding in policy_bindings:
|
||||
self.__expected_result_count += 1
|
||||
|
||||
self._check_policy_type(binding)
|
||||
@ -153,10 +191,13 @@ class PolicyEngine:
|
||||
@property
|
||||
def result(self) -> PolicyResult:
|
||||
"""Get policy-checking result"""
|
||||
self.__processes.sort(key=lambda x: x.binding.order)
|
||||
process_results: list[PolicyResult] = [x.result for x in self.__processes if x.result]
|
||||
all_results = list(process_results + self.__cached_policies)
|
||||
if len(all_results) < self.__expected_result_count: # pragma: no cover
|
||||
raise AssertionError("Got less results than polices")
|
||||
if self.__static_result:
|
||||
all_results.append(self.__static_result)
|
||||
# No results, no policies attached -> passing
|
||||
if len(all_results) == 0:
|
||||
return PolicyResult(self.empty_result)
|
||||
|
@ -1,9 +1,12 @@
|
||||
"""policy engine tests"""
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.db import connections
|
||||
from django.test import TestCase
|
||||
from django.test.utils import CaptureQueriesContext
|
||||
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.core.models import Group
|
||||
from authentik.core.tests.utils import create_test_user
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.policies.dummy.models import DummyPolicy
|
||||
from authentik.policies.engine import PolicyEngine
|
||||
@ -19,7 +22,7 @@ class TestPolicyEngine(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
clear_policy_cache()
|
||||
self.user = create_test_admin_user()
|
||||
self.user = create_test_user()
|
||||
self.policy_false = DummyPolicy.objects.create(
|
||||
name=generate_id(), result=False, wait_min=0, wait_max=1
|
||||
)
|
||||
@ -127,3 +130,43 @@ class TestPolicyEngine(TestCase):
|
||||
self.assertEqual(len(cache.keys(f"{CACHE_PREFIX}{binding.policy_binding_uuid.hex}*")), 1)
|
||||
self.assertEqual(engine.build().passing, False)
|
||||
self.assertEqual(len(cache.keys(f"{CACHE_PREFIX}{binding.policy_binding_uuid.hex}*")), 1)
|
||||
|
||||
def test_engine_static_bindings(self):
|
||||
"""Test static bindings"""
|
||||
group_a = Group.objects.create(name=generate_id())
|
||||
group_b = Group.objects.create(name=generate_id())
|
||||
group_b.users.add(self.user)
|
||||
user = create_test_user()
|
||||
|
||||
for case in [
|
||||
{
|
||||
"message": "Group, not member",
|
||||
"binding_args": {"group": group_a},
|
||||
"passing": False,
|
||||
},
|
||||
{
|
||||
"message": "Group, member",
|
||||
"binding_args": {"group": group_b},
|
||||
"passing": True,
|
||||
},
|
||||
{
|
||||
"message": "User, other",
|
||||
"binding_args": {"user": user},
|
||||
"passing": False,
|
||||
},
|
||||
{
|
||||
"message": "User, same",
|
||||
"binding_args": {"user": self.user},
|
||||
"passing": True,
|
||||
},
|
||||
]:
|
||||
with self.subTest():
|
||||
pbm = PolicyBindingModel.objects.create()
|
||||
for x in range(1000):
|
||||
PolicyBinding.objects.create(target=pbm, order=x, **case["binding_args"])
|
||||
engine = PolicyEngine(pbm, self.user)
|
||||
engine.use_cache = False
|
||||
with CaptureQueriesContext(connections["default"]) as ctx:
|
||||
engine.build()
|
||||
self.assertLess(ctx.final_queries, 1000)
|
||||
self.assertEqual(engine.result.passing, case["passing"])
|
||||
|
@ -29,13 +29,12 @@ class TestPolicyProcess(TestCase):
|
||||
def setUp(self):
|
||||
clear_policy_cache()
|
||||
self.factory = RequestFactory()
|
||||
self.user = User.objects.create_user(username="policyuser")
|
||||
self.user = User.objects.create_user(username=generate_id())
|
||||
|
||||
def test_group_passing(self):
|
||||
"""Test binding to group"""
|
||||
group = Group.objects.create(name="test-group")
|
||||
group = Group.objects.create(name=generate_id())
|
||||
group.users.add(self.user)
|
||||
group.save()
|
||||
binding = PolicyBinding(group=group)
|
||||
|
||||
request = PolicyRequest(self.user)
|
||||
@ -44,8 +43,7 @@ class TestPolicyProcess(TestCase):
|
||||
|
||||
def test_group_negative(self):
|
||||
"""Test binding to group"""
|
||||
group = Group.objects.create(name="test-group")
|
||||
group.save()
|
||||
group = Group.objects.create(name=generate_id())
|
||||
binding = PolicyBinding(group=group)
|
||||
|
||||
request = PolicyRequest(self.user)
|
||||
@ -115,8 +113,10 @@ class TestPolicyProcess(TestCase):
|
||||
|
||||
def test_exception(self):
|
||||
"""Test policy execution"""
|
||||
policy = Policy.objects.create(name="test-execution")
|
||||
binding = PolicyBinding(policy=policy, target=Application.objects.create(name="test"))
|
||||
policy = Policy.objects.create(name=generate_id())
|
||||
binding = PolicyBinding(
|
||||
policy=policy, target=Application.objects.create(name=generate_id())
|
||||
)
|
||||
|
||||
request = PolicyRequest(self.user)
|
||||
response = PolicyProcess(binding, request, None).execute()
|
||||
@ -125,13 +125,15 @@ class TestPolicyProcess(TestCase):
|
||||
def test_execution_logging(self):
|
||||
"""Test policy execution creates event"""
|
||||
policy = DummyPolicy.objects.create(
|
||||
name="test-execution-logging",
|
||||
name=generate_id(),
|
||||
result=False,
|
||||
wait_min=0,
|
||||
wait_max=1,
|
||||
execution_logging=True,
|
||||
)
|
||||
binding = PolicyBinding(policy=policy, target=Application.objects.create(name="test"))
|
||||
binding = PolicyBinding(
|
||||
policy=policy, target=Application.objects.create(name=generate_id())
|
||||
)
|
||||
|
||||
http_request = self.factory.get(reverse("authentik_api:user-impersonate-end"))
|
||||
http_request.user = self.user
|
||||
@ -186,13 +188,15 @@ class TestPolicyProcess(TestCase):
|
||||
def test_execution_logging_anonymous(self):
|
||||
"""Test policy execution creates event with anonymous user"""
|
||||
policy = DummyPolicy.objects.create(
|
||||
name="test-execution-logging-anon",
|
||||
name=generate_id(),
|
||||
result=False,
|
||||
wait_min=0,
|
||||
wait_max=1,
|
||||
execution_logging=True,
|
||||
)
|
||||
binding = PolicyBinding(policy=policy, target=Application.objects.create(name="test"))
|
||||
binding = PolicyBinding(
|
||||
policy=policy, target=Application.objects.create(name=generate_id())
|
||||
)
|
||||
|
||||
user = AnonymousUser()
|
||||
|
||||
@ -219,9 +223,9 @@ class TestPolicyProcess(TestCase):
|
||||
|
||||
def test_raises(self):
|
||||
"""Test policy that raises error"""
|
||||
policy_raises = ExpressionPolicy.objects.create(name="raises", expression="{{ 0/0 }}")
|
||||
policy_raises = ExpressionPolicy.objects.create(name=generate_id(), expression="{{ 0/0 }}")
|
||||
binding = PolicyBinding(
|
||||
policy=policy_raises, target=Application.objects.create(name="test")
|
||||
policy=policy_raises, target=Application.objects.create(name=generate_id())
|
||||
)
|
||||
|
||||
request = PolicyRequest(self.user)
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""Prompt Stage Logic"""
|
||||
|
||||
from collections.abc import Callable, Iterator
|
||||
from collections.abc import Callable
|
||||
from email.policy import Policy
|
||||
from types import MethodType
|
||||
from typing import Any
|
||||
@ -190,7 +190,7 @@ class ListPolicyEngine(PolicyEngine):
|
||||
self.__list = policies
|
||||
self.use_cache = False
|
||||
|
||||
def iterate_bindings(self) -> Iterator[PolicyBinding]:
|
||||
def bindings(self):
|
||||
for policy in self.__list:
|
||||
yield PolicyBinding(
|
||||
policy=policy,
|
||||
|
Reference in New Issue
Block a user