policy(minor): Move policy-related code to separate package
This commit is contained in:
		
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @ -191,3 +191,4 @@ pip-selfcheck.json | ||||
| # End of https://www.gitignore.io/api/python,django | ||||
| /static/ | ||||
| local.env.yml | ||||
| .vscode/ | ||||
|  | ||||
| @ -11,8 +11,8 @@ from django.views.generic.detail import DetailView | ||||
| from passbook.admin.forms.policies import PolicyTestForm | ||||
| from passbook.admin.mixins import AdminRequiredMixin | ||||
| from passbook.core.models import Policy | ||||
| from passbook.core.policies import PolicyEngine | ||||
| from passbook.lib.utils.reflection import path_to_class | ||||
| from passbook.policy.engine import PolicyEngine | ||||
|  | ||||
|  | ||||
| class PolicyListView(AdminRequiredMixin, ListView): | ||||
|  | ||||
| @ -27,7 +27,7 @@ class ApplicationGatewayMiddleware: | ||||
|         handler = RequestHandler(app_gw, request) | ||||
|  | ||||
|         if not handler.check_permission(): | ||||
|             to_url = 'https://%s/?next=%s' % (CONFIG.get('domains')[0], request.get_full_path()) | ||||
|             to_url = 'https://%s/?next=%s' % (CONFIG.y('domains')[0], request.get_full_path()) | ||||
|             return RedirectView.as_view(url=to_url)(request) | ||||
|  | ||||
|         return handler.get_response() | ||||
|  | ||||
| @ -15,7 +15,7 @@ from passbook.app_gw.proxy.response import get_django_response | ||||
| from passbook.app_gw.proxy.rewrite import Rewriter | ||||
| from passbook.app_gw.proxy.utils import encode_items, normalize_request_headers | ||||
| from passbook.core.models import Application | ||||
| from passbook.core.policies import PolicyEngine | ||||
| from passbook.policy.engine import PolicyEngine | ||||
|  | ||||
| SESSION_UPSTREAM_KEY = 'passbook_app_gw_upstream' | ||||
| IGNORED_HOSTNAMES_KEY = 'passbook_app_gw_ignored' | ||||
|  | ||||
| @ -17,7 +17,7 @@ class PassbookCoreConfig(AppConfig): | ||||
|     mountpoint = '' | ||||
|  | ||||
|     def ready(self): | ||||
|         import_module('passbook.core.policies') | ||||
|         import_module('passbook.policy.engine') | ||||
|         factors_to_load = CONFIG.y('passbook.factors', []) | ||||
|         for factors_to_load in factors_to_load: | ||||
|             try: | ||||
|  | ||||
| @ -1,5 +1,6 @@ | ||||
| """passbook multi-factor authentication engine""" | ||||
| from logging import getLogger | ||||
| from typing import List, Tuple | ||||
|  | ||||
| from django.contrib.auth import login | ||||
| from django.contrib.auth.mixins import UserPassesTestMixin | ||||
| @ -8,10 +9,10 @@ from django.utils.http import urlencode | ||||
| from django.views.generic import View | ||||
|  | ||||
| from passbook.core.models import Factor, User | ||||
| from passbook.core.policies import PolicyEngine | ||||
| from passbook.core.views.utils import PermissionDeniedView | ||||
| from passbook.lib.utils.reflection import class_to_path, path_to_class | ||||
| from passbook.lib.utils.urls import is_url_absolute | ||||
| from passbook.policy.engine import PolicyEngine | ||||
|  | ||||
| LOGGER = getLogger(__name__) | ||||
|  | ||||
| @ -31,12 +32,12 @@ class AuthenticationView(UserPassesTestMixin, View): | ||||
|     SESSION_USER_BACKEND = 'passbook_user_backend' | ||||
|     SESSION_IS_SSO_LOGIN = 'passbook_sso_login' | ||||
|  | ||||
|     pending_user = None | ||||
|     pending_factors = [] | ||||
|     pending_user: User | ||||
|     pending_factors: List[Tuple[str, str]] = [] | ||||
|  | ||||
|     _current_factor_class = None | ||||
|     _current_factor_class: Factor | ||||
|  | ||||
|     current_factor = None | ||||
|     current_factor: Factor | ||||
|  | ||||
|     # Allow only not authenticated users to login | ||||
|     def test_func(self): | ||||
|  | ||||
| @ -4,7 +4,7 @@ from datetime import timedelta | ||||
| from logging import getLogger | ||||
| from random import SystemRandom | ||||
| from time import sleep | ||||
| from typing import Tuple, Union | ||||
| from typing import List | ||||
| from uuid import uuid4 | ||||
|  | ||||
| from django.contrib.auth.models import AbstractUser | ||||
| @ -25,6 +25,20 @@ def default_nonce_duration(): | ||||
|     """Default duration a Nonce is valid""" | ||||
|     return now() + timedelta(hours=4) | ||||
|  | ||||
|  | ||||
| class PolicyResult: | ||||
|     """Small data-class to hold policy results""" | ||||
|  | ||||
|     passing: bool = False | ||||
|     messages: List[str] = [] | ||||
|  | ||||
|     def __init__(self, passing: bool, *messages: str): | ||||
|         self.passing = passing | ||||
|         self.messages = messages | ||||
|  | ||||
|     def __str__(self): | ||||
|         return f"<PolicyResult passing={self.passing}>" | ||||
|  | ||||
| class Group(UUIDModel): | ||||
|     """Custom Group model which supports a basic hierarchy""" | ||||
|  | ||||
| @ -229,7 +243,7 @@ class Policy(UUIDModel, CreatedUpdatedModel): | ||||
|             return self.name | ||||
|         return "%s action %s" % (self.name, self.action) | ||||
|  | ||||
|     def passes(self, user: User) -> Union[bool, Tuple[bool, str]]: | ||||
|     def passes(self, user: User) -> PolicyResult: | ||||
|         """Check if user instance passes this policy""" | ||||
|         raise NotImplementedError() | ||||
|  | ||||
| @ -273,7 +287,7 @@ class FieldMatcherPolicy(Policy): | ||||
|             description = "%s: %s" % (self.name, description) | ||||
|         return description | ||||
|  | ||||
|     def passes(self, user: User) -> Union[bool, Tuple[bool, str]]: | ||||
|     def passes(self, user: User) -> PolicyResult: | ||||
|         """Check if user instance passes this role""" | ||||
|         if not hasattr(user, self.user_field): | ||||
|             raise ValueError("Field does not exist") | ||||
| @ -294,7 +308,7 @@ class FieldMatcherPolicy(Policy): | ||||
|             passes = user_field_value == self.value | ||||
|  | ||||
|         LOGGER.debug("User got '%r'", passes) | ||||
|         return passes | ||||
|         return PolicyResult(passes) | ||||
|  | ||||
|     class Meta: | ||||
|  | ||||
| @ -313,10 +327,10 @@ class PasswordPolicy(Policy): | ||||
|  | ||||
|     form = 'passbook.core.forms.policies.PasswordPolicyForm' | ||||
|  | ||||
|     def passes(self, user: User) -> Union[bool, Tuple[bool, str]]: | ||||
|     def passes(self, user: User) -> PolicyResult: | ||||
|         # Only check if password is being set | ||||
|         if not hasattr(user, '__password__'): | ||||
|             return True | ||||
|             return PolicyResult(True) | ||||
|         password = getattr(user, '__password__') | ||||
|  | ||||
|         filter_regex = r'' | ||||
| @ -329,8 +343,8 @@ class PasswordPolicy(Policy): | ||||
|         result = bool(re.compile(filter_regex).match(password)) | ||||
|         LOGGER.debug("User got %r", result) | ||||
|         if not result: | ||||
|             return result, self.error_message | ||||
|         return result | ||||
|             return PolicyResult(result, self.error_message) | ||||
|         return PolicyResult(result) | ||||
|  | ||||
|     class Meta: | ||||
|  | ||||
| @ -364,7 +378,7 @@ class WebhookPolicy(Policy): | ||||
|  | ||||
|     form = 'passbook.core.forms.policies.WebhookPolicyForm' | ||||
|  | ||||
|     def passes(self, user: User): | ||||
|     def passes(self, user: User) -> PolicyResult: | ||||
|         """Call webhook asynchronously and report back""" | ||||
|         raise NotImplementedError() | ||||
|  | ||||
| @ -383,12 +397,12 @@ class DebugPolicy(Policy): | ||||
|  | ||||
|     form = 'passbook.core.forms.policies.DebugPolicyForm' | ||||
|  | ||||
|     def passes(self, user: User): | ||||
|     def passes(self, user: User) -> PolicyResult: | ||||
|         """Wait random time then return result""" | ||||
|         wait = SystemRandom().randrange(self.wait_min, self.wait_max) | ||||
|         LOGGER.debug("Policy '%s' waiting for %ds", self.name, wait) | ||||
|         sleep(wait) | ||||
|         return self.result, 'Debugging' | ||||
|         return PolicyResult(self.result, 'Debugging') | ||||
|  | ||||
|     class Meta: | ||||
|  | ||||
| @ -402,8 +416,8 @@ class GroupMembershipPolicy(Policy): | ||||
|  | ||||
|     form = 'passbook.core.forms.policies.GroupMembershipPolicyForm' | ||||
|  | ||||
|     def passes(self, user: User) -> Union[bool, Tuple[bool, str]]: | ||||
|         return self.group.user_set.filter(pk=user.pk).exists() | ||||
|     def passes(self, user: User) -> PolicyResult: | ||||
|         return PolicyResult(self.group.user_set.filter(pk=user.pk).exists()) | ||||
|  | ||||
|     class Meta: | ||||
|  | ||||
| @ -415,10 +429,10 @@ class SSOLoginPolicy(Policy): | ||||
|  | ||||
|     form = 'passbook.core.forms.policies.SSOLoginPolicyForm' | ||||
|  | ||||
|     def passes(self, user): | ||||
|     def passes(self, user) -> PolicyResult: | ||||
|         """Check if user instance passes this policy""" | ||||
|         from passbook.core.auth.view import AuthenticationView | ||||
|         return user.session.get(AuthenticationView.SESSION_IS_SSO_LOGIN, False), "" | ||||
|         return PolicyResult(user.session.get(AuthenticationView.SESSION_IS_SSO_LOGIN, False)) | ||||
|  | ||||
|     class Meta: | ||||
|  | ||||
|  | ||||
| @ -1,134 +0,0 @@ | ||||
| """passbook core policy engine""" | ||||
| from logging import getLogger | ||||
|  | ||||
| from amqp.exceptions import UnexpectedFrame | ||||
| from celery import group | ||||
| from celery.exceptions import TimeoutError as CeleryTimeoutError | ||||
| from django.core.cache import cache | ||||
| from ipware import get_client_ip | ||||
|  | ||||
| from passbook.core.models import Policy, User | ||||
| from passbook.root.celery import CELERY_APP | ||||
|  | ||||
| LOGGER = getLogger(__name__) | ||||
|  | ||||
| def _cache_key(policy, user): | ||||
|     return "policy_%s#%s" % (policy.uuid, user.pk) | ||||
|  | ||||
| @CELERY_APP.task() | ||||
| def _policy_engine_task(user_pk, policy_pk, **kwargs): | ||||
|     """Task wrapper to run policy checking""" | ||||
|     if not user_pk: | ||||
|         raise ValueError() | ||||
|     policy_obj = Policy.objects.filter(pk=policy_pk).select_subclasses().first() | ||||
|     user_obj = User.objects.get(pk=user_pk) | ||||
|     for key, value in kwargs.items(): | ||||
|         setattr(user_obj, key, value) | ||||
|     LOGGER.debug("Running policy `%s`#%s for user %s...", policy_obj.name, | ||||
|                  policy_obj.pk.hex, user_obj) | ||||
|     policy_result = policy_obj.passes(user_obj) | ||||
|     # Handle policy result correctly if result, message or just result | ||||
|     message = None | ||||
|     if isinstance(policy_result, (tuple, list)): | ||||
|         policy_result, message = policy_result | ||||
|     # Invert result if policy.negate is set | ||||
|     if policy_obj.negate: | ||||
|         policy_result = not policy_result | ||||
|     LOGGER.debug("Policy %r#%s got %s", policy_obj.name, policy_obj.pk.hex, policy_result) | ||||
|     cache_key = _cache_key(policy_obj, user_obj) | ||||
|     cache.set(cache_key, (policy_obj.action, policy_result, message)) | ||||
|     LOGGER.debug("Cached entry as %s", cache_key) | ||||
|     return policy_obj.action, policy_result, message | ||||
|  | ||||
| class PolicyEngine: | ||||
|     """Orchestrate policy checking, launch tasks and return result""" | ||||
|  | ||||
|     __group = None | ||||
|     __cached = None | ||||
|  | ||||
|     policies = None | ||||
|     __get_timeout = 0 | ||||
|     __request = None | ||||
|     __user = None | ||||
|  | ||||
|     def __init__(self, policies): | ||||
|         self.policies = policies | ||||
|         self.__request = None | ||||
|         self.__user = None | ||||
|  | ||||
|     def for_user(self, user): | ||||
|         """Check policies for user""" | ||||
|         self.__user = user | ||||
|         return self | ||||
|  | ||||
|     def with_request(self, request): | ||||
|         """Set request""" | ||||
|         self.__request = request | ||||
|         return self | ||||
|  | ||||
|     def build(self): | ||||
|         """Build task group""" | ||||
|         if not self.__user: | ||||
|             raise ValueError("User not set.") | ||||
|         signatures = [] | ||||
|         cached_policies = [] | ||||
|         kwargs = { | ||||
|             '__password__': getattr(self.__user, '__password__', None), | ||||
|             'session': dict(getattr(self.__request, 'session', {}).items()), | ||||
|         } | ||||
|         if self.__request: | ||||
|             kwargs['remote_ip'], _ = get_client_ip(self.__request) | ||||
|             if not kwargs['remote_ip']: | ||||
|                 kwargs['remote_ip'] = '255.255.255.255' | ||||
|         for policy in self.policies: | ||||
|             cached_policy = cache.get(_cache_key(policy, self.__user), None) | ||||
|             if cached_policy: | ||||
|                 LOGGER.debug("Taking result from cache for %s", policy.pk.hex) | ||||
|                 cached_policies.append(cached_policy) | ||||
|             else: | ||||
|                 LOGGER.debug("Evaluating policy %s", policy.pk.hex) | ||||
|                 signatures.append(_policy_engine_task.signature( | ||||
|                     args=(self.__user.pk, policy.pk.hex), | ||||
|                     kwargs=kwargs, | ||||
|                     time_limit=policy.timeout)) | ||||
|                 self.__get_timeout += policy.timeout | ||||
|         LOGGER.debug("Set total policy timeout to %r", self.__get_timeout) | ||||
|         # If all policies are cached, we have an empty list here. | ||||
|         if signatures: | ||||
|             self.__group = group(signatures)() | ||||
|             self.__get_timeout += 3 | ||||
|             self.__get_timeout = (self.__get_timeout / len(self.policies)) * 1.5 | ||||
|         self.__cached = cached_policies | ||||
|         return self | ||||
|  | ||||
|     @property | ||||
|     def result(self): | ||||
|         """Get policy-checking result""" | ||||
|         messages = [] | ||||
|         result = [] | ||||
|         try: | ||||
|             if self.__group: | ||||
|                 # ValueError can be thrown from _policy_engine_task when user is None | ||||
|                 result += self.__group.get(timeout=self.__get_timeout) | ||||
|             result += self.__cached | ||||
|         except ValueError as exc: | ||||
|             # ValueError can be thrown from _policy_engine_task when user is None | ||||
|             return False, [str(exc)] | ||||
|         except UnexpectedFrame as exc: | ||||
|             return False, [str(exc)] | ||||
|         except CeleryTimeoutError as exc: | ||||
|             return False, [str(exc)] | ||||
|         for policy_action, policy_result, policy_message in result: | ||||
|             passing = (policy_action == Policy.ACTION_ALLOW and policy_result) or \ | ||||
|                       (policy_action == Policy.ACTION_DENY and not policy_result) | ||||
|             LOGGER.debug('Action=%s, Result=%r => %r', policy_action, policy_result, passing) | ||||
|             if policy_message: | ||||
|                 messages.append(policy_message) | ||||
|             if not passing: | ||||
|                 return False, messages | ||||
|         return True, messages | ||||
|  | ||||
|     @property | ||||
|     def passing(self): | ||||
|         """Only get true/false if user passes""" | ||||
|         return self.result[0] | ||||
| @ -20,7 +20,7 @@ password_changed = Signal(providing_args=['user', 'password']) | ||||
| def password_policy_checker(sender, password, **kwargs): | ||||
|     """Run password through all password policies which are applied to the user""" | ||||
|     from passbook.core.models import PasswordFactor | ||||
|     from passbook.core.policies import PolicyEngine | ||||
|     from passbook.policy.engine import PolicyEngine | ||||
|     setattr(sender, '__password__', password) | ||||
|     _all_factors = PasswordFactor.objects.filter(enabled=True).order_by('order') | ||||
|     for factor in _all_factors: | ||||
|  | ||||
| @ -3,7 +3,7 @@ | ||||
| from django import template | ||||
|  | ||||
| from passbook.core.models import Factor, Source | ||||
| from passbook.core.policies import PolicyEngine | ||||
| from passbook.policy.engine import PolicyEngine | ||||
|  | ||||
| register = template.Library() | ||||
|  | ||||
|  | ||||
| @ -5,7 +5,7 @@ from django.contrib import messages | ||||
| from django.utils.translation import gettext as _ | ||||
|  | ||||
| from passbook.core.models import Application | ||||
| from passbook.core.policies import PolicyEngine | ||||
| from passbook.policy.engine import PolicyEngine | ||||
|  | ||||
| LOGGER = getLogger(__name__) | ||||
|  | ||||
|  | ||||
| @ -4,7 +4,7 @@ from django.contrib.auth.mixins import LoginRequiredMixin | ||||
| from django.views.generic import TemplateView | ||||
|  | ||||
| from passbook.core.models import Application | ||||
| from passbook.core.policies import PolicyEngine | ||||
| from passbook.policy.engine import PolicyEngine | ||||
|  | ||||
|  | ||||
| class OverviewView(LoginRequiredMixin, TemplateView): | ||||
|  | ||||
| @ -6,7 +6,7 @@ from django.db import models | ||||
| from django.utils.translation import gettext as _ | ||||
| from requests import get | ||||
|  | ||||
| from passbook.core.models import Policy, User | ||||
| from passbook.core.models import Policy, PolicyResult, User | ||||
|  | ||||
| LOGGER = getLogger(__name__) | ||||
|  | ||||
| @ -18,13 +18,13 @@ class HaveIBeenPwendPolicy(Policy): | ||||
|  | ||||
|     form = 'passbook.hibp_policy.forms.HaveIBeenPwnedPolicyForm' | ||||
|  | ||||
|     def passes(self, user: User) -> bool: | ||||
|     def passes(self, user: User) -> PolicyResult: | ||||
|         """Check if password is in HIBP DB. Hashes given Password with SHA1, uses the first 5 | ||||
|         characters of Password in request and checks if full hash is in response. Returns 0 | ||||
|         if Password is not in result otherwise the count of how many times it was used.""" | ||||
|         # Only check if password is being set | ||||
|         if not hasattr(user, '__password__'): | ||||
|             return True | ||||
|             return PolicyResult(True) | ||||
|         password = getattr(user, '__password__') | ||||
|         pw_hash = sha1(password.encode('utf-8')).hexdigest() # nosec | ||||
|         url = 'https://api.pwnedpasswords.com/range/%s' % pw_hash[:5] | ||||
| @ -36,8 +36,9 @@ class HaveIBeenPwendPolicy(Policy): | ||||
|                 final_count = int(count) | ||||
|         LOGGER.debug("Got count %d for hash %s", final_count, pw_hash[:5]) | ||||
|         if final_count > self.allowed_count: | ||||
|             return False, _("Password exists on %(count)d online lists." % {'count': final_count}) | ||||
|         return True | ||||
|             message = _("Password exists on %(count)d online lists." % {'count': final_count}) | ||||
|             return PolicyResult(False, message) | ||||
|         return PolicyResult(True) | ||||
|  | ||||
|     class Meta: | ||||
|  | ||||
|  | ||||
| @ -34,8 +34,7 @@ class ConfigLoader: | ||||
|  | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
|         base_dir = os.path.realpath(os.path.join( | ||||
|             os.path.dirname(__file__), '../..')) | ||||
|         base_dir = os.path.realpath(os.path.join(os.path.dirname(__file__), '../..')) | ||||
|         for path in SEARCH_PATHS: | ||||
|             # Check if path is relative, and if so join with base_dir | ||||
|             if not os.path.isabs(path): | ||||
|  | ||||
| @ -2,9 +2,9 @@ | ||||
| from django.template import Context, Template, loader | ||||
|  | ||||
|  | ||||
| def render_from_string(template: str, ctx: Context) -> str: | ||||
| def render_from_string(tmpl: str, ctx: Context) -> str: | ||||
|     """Render template from string to string""" | ||||
|     template = Template(template) | ||||
|     template = Template(tmpl) | ||||
|     return template.render(ctx) | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -5,7 +5,7 @@ from django.contrib import messages | ||||
| from django.shortcuts import redirect | ||||
|  | ||||
| from passbook.core.models import Application | ||||
| from passbook.core.policies import PolicyEngine | ||||
| from passbook.policy.engine import PolicyEngine | ||||
|  | ||||
| LOGGER = getLogger(__name__) | ||||
|  | ||||
|  | ||||
| @ -6,7 +6,7 @@ from django.db import models | ||||
| from django.utils.timezone import now | ||||
| from django.utils.translation import gettext as _ | ||||
|  | ||||
| from passbook.core.models import Policy, User | ||||
| from passbook.core.models import Policy, PolicyResult, User | ||||
|  | ||||
| LOGGER = getLogger(__name__) | ||||
|  | ||||
| @ -20,7 +20,7 @@ class PasswordExpiryPolicy(Policy): | ||||
|  | ||||
|     form = 'passbook.password_expiry_policy.forms.PasswordExpiryPolicyForm' | ||||
|  | ||||
|     def passes(self, user: User) -> bool: | ||||
|     def passes(self, user: User) -> PolicyResult: | ||||
|         """If password change date is more than x days in the past, call set_unusable_password | ||||
|         and show a notice""" | ||||
|         actual_days = (now() - user.password_change_date).days | ||||
| @ -29,12 +29,13 @@ class PasswordExpiryPolicy(Policy): | ||||
|             if not self.deny_only: | ||||
|                 user.set_unusable_password() | ||||
|                 user.save() | ||||
|                 return False, _(('Password expired %(days)d days ago. ' | ||||
|                 message = _(('Password expired %(days)d days ago. ' | ||||
|                              'Please update your password.') % { | ||||
|                                  'days': days_since_expiry | ||||
|                              }) | ||||
|             return False, _('Password has expired.') | ||||
|         return True | ||||
|                 return PolicyResult(False, message) | ||||
|             return PolicyResult(False, _('Password has expired.')) | ||||
|         return PolicyResult(True) | ||||
|  | ||||
|     class Meta: | ||||
|  | ||||
|  | ||||
							
								
								
									
										0
									
								
								passbook/policy/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								passbook/policy/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										97
									
								
								passbook/policy/engine.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										97
									
								
								passbook/policy/engine.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,97 @@ | ||||
| """passbook policy engine""" | ||||
| from multiprocessing import Pipe | ||||
| from multiprocessing.connection import Connection | ||||
| from typing import List, Tuple | ||||
|  | ||||
| from django.core.cache import cache | ||||
| from django.http import HttpRequest | ||||
| from structlog import get_logger | ||||
|  | ||||
| from passbook.core.models import Policy, PolicyResult, User | ||||
| from passbook.policy.task import PolicyTask | ||||
|  | ||||
| LOGGER = get_logger() | ||||
|  | ||||
| def _cache_key(policy, user): | ||||
|     return "policy_%s#%s" % (policy.uuid, user.pk) | ||||
|  | ||||
| class PolicyEngine: | ||||
|     """Orchestrate policy checking, launch tasks and return result""" | ||||
|  | ||||
|     # __group = None | ||||
|     # __cached = None | ||||
|  | ||||
|     policies: List[Policy] = [] | ||||
|     __request: HttpRequest | ||||
|     __user: User | ||||
|  | ||||
|     __proc_list: List[Tuple[Connection, PolicyTask]] = [] | ||||
|  | ||||
|     def __init__(self, policies, user: User = None, request: HttpRequest = None): | ||||
|         self.policies = policies | ||||
|         self.__request = request | ||||
|         self.__user = user | ||||
|  | ||||
|     def for_user(self, user: User) -> 'PolicyEngine': | ||||
|         """Check policies for user""" | ||||
|         self.__user = user | ||||
|         return self | ||||
|  | ||||
|     def with_request(self, request: HttpRequest) -> 'PolicyEngine': | ||||
|         """Set request""" | ||||
|         self.__request = request | ||||
|         return self | ||||
|  | ||||
|     def build(self) -> 'PolicyEngine': | ||||
|         """Build task group""" | ||||
|         if not self.__user: | ||||
|             raise ValueError("User not set.") | ||||
|         cached_policies = [] | ||||
|         kwargs = { | ||||
|             '__password__': getattr(self.__user, '__password__', None), | ||||
|             'session': dict(getattr(self.__request, 'session', {}).items()), | ||||
|             'request': self.__request, | ||||
|         } | ||||
|         for policy in self.policies: | ||||
|             cached_policy = cache.get(_cache_key(policy, self.__user), None) | ||||
|             if cached_policy: | ||||
|                 LOGGER.debug("Taking result from cache for %s", policy.pk.hex) | ||||
|                 cached_policies.append(cached_policy) | ||||
|             else: | ||||
|                 LOGGER.debug("Evaluating policy %s", policy.pk.hex) | ||||
|                 our_end, task_end = Pipe(False) | ||||
|                 task = PolicyTask() | ||||
|                 task.ret = task_end | ||||
|                 task.user = self.__user | ||||
|                 task.policy = policy | ||||
|                 task.params = kwargs | ||||
|                 LOGGER.debug("Starting Process %s", task.__class__.__name__) | ||||
|                 task.start() | ||||
|                 self.__proc_list.append((our_end, task)) | ||||
|         # If all policies are cached, we have an empty list here. | ||||
|         if self.__proc_list: | ||||
|             for _, running_proc in self.__proc_list: | ||||
|                 running_proc.join() | ||||
|         return self | ||||
|  | ||||
|     @property | ||||
|     def result(self): | ||||
|         """Get policy-checking result""" | ||||
|         results: List[PolicyResult] = [] | ||||
|         messages: List[str] = [] | ||||
|         for our_end, _ in self.__proc_list: | ||||
|             results.append(our_end.recv()) | ||||
|         for policy_result in results: | ||||
|             # passing = (policy_action == Policy.ACTION_ALLOW and policy_result) or \ | ||||
|             #           (policy_action == Policy.ACTION_DENY and not policy_result) | ||||
|             LOGGER.debug('Result=%r => %r', policy_result, policy_result.passing) | ||||
|             if policy_result.messages: | ||||
|                 messages += policy_result.messages | ||||
|             if not policy_result.passing: | ||||
|                 return False, messages | ||||
|         return True, messages | ||||
|  | ||||
|     @property | ||||
|     def passing(self): | ||||
|         """Only get true/false if user passes""" | ||||
|         return self.result[0] | ||||
							
								
								
									
										38
									
								
								passbook/policy/task.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								passbook/policy/task.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,38 @@ | ||||
| """passbook policy task""" | ||||
| from logging import getLogger | ||||
| from multiprocessing import Process | ||||
| from multiprocessing.connection import Connection | ||||
| from typing import Any, Dict | ||||
|  | ||||
| from passbook.core.models import Policy, User | ||||
|  | ||||
| LOGGER = getLogger(__name__) | ||||
|  | ||||
|  | ||||
| def _cache_key(policy, user): | ||||
|     return "policy_%s#%s" % (policy.uuid, user.pk) | ||||
|  | ||||
| class PolicyTask(Process): | ||||
|     """Evaluate a single policy within a seprate process""" | ||||
|  | ||||
|     ret: Connection | ||||
|     user: User | ||||
|     policy: Policy | ||||
|     params: Dict[str, Any] | ||||
|  | ||||
|     def run(self): | ||||
|         """Task wrapper to run policy checking""" | ||||
|         for key, value in self.params.items(): | ||||
|             setattr(self.user, key, value) | ||||
|         LOGGER.debug("Running policy `%s`#%s for user %s...", self.policy.name, | ||||
|                      self.policy.pk.hex, self.user) | ||||
|         policy_result = self.policy.passes(self.user) | ||||
|         # Invert result if policy.negate is set | ||||
|         if self.policy.negate: | ||||
|             policy_result = not policy_result | ||||
|         LOGGER.debug("Policy %r#%s got %s", self.policy.name, self.policy.pk.hex, policy_result) | ||||
|         # cache_key = _cache_key(self.policy, self.user) | ||||
|         # cache.set(cache_key, (self.policy.action, policy_result, message)) | ||||
|         # LOGGER.debug("Cached entry as %s", cache_key) | ||||
|         self.ret.send(policy_result) | ||||
|         self.ret.close() | ||||
| @ -11,7 +11,6 @@ https://docs.djangoproject.com/en/2.1/ref/settings/ | ||||
| """ | ||||
|  | ||||
| import importlib | ||||
| import logging | ||||
| import os | ||||
| import sys | ||||
|  | ||||
|  | ||||
| @ -260,7 +260,6 @@ class Processor: | ||||
|     def _validate_user(self): | ||||
|         """Validates the User. Sub-classes should override this and | ||||
|         throw an CannotHandleAssertion Exception if the validation does not succeed.""" | ||||
|         pass | ||||
|  | ||||
|     def can_handle(self, request): | ||||
|         """Returns true if this processor can handle this request.""" | ||||
|  | ||||
| @ -3,9 +3,7 @@ | ||||
|  | ||||
| class CannotHandleAssertion(Exception): | ||||
|     """This processor does not handle this assertion.""" | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class UserNotAuthorized(Exception): | ||||
|     """User not authorized for SAML 2.0 authentication.""" | ||||
|     pass | ||||
|  | ||||
| @ -16,9 +16,9 @@ from signxml.util import strip_pem_header | ||||
|  | ||||
| from passbook.audit.models import AuditEntry | ||||
| from passbook.core.models import Application | ||||
| from passbook.core.policies import PolicyEngine | ||||
| from passbook.lib.mixins import CSRFExemptMixin | ||||
| from passbook.lib.utils.template import render_to_string | ||||
| from passbook.policy.engine import PolicyEngine | ||||
| from passbook.saml_idp import exceptions | ||||
| from passbook.saml_idp.models import SAMLProvider | ||||
|  | ||||
|  | ||||
| @ -2,7 +2,7 @@ | ||||
| from django.db import models | ||||
| from django.utils.translation import gettext as _ | ||||
|  | ||||
| from passbook.core.models import Policy, User | ||||
| from passbook.core.models import Policy, PolicyResult, User | ||||
|  | ||||
|  | ||||
| class SuspiciousRequestPolicy(Policy): | ||||
| @ -14,7 +14,7 @@ class SuspiciousRequestPolicy(Policy): | ||||
|  | ||||
|     form = 'passbook.suspicious_policy.forms.SuspiciousRequestPolicyForm' | ||||
|  | ||||
|     def passes(self, user: User): | ||||
|     def passes(self, user: User) -> PolicyResult: | ||||
|         remote_ip = user.remote_ip | ||||
|         passing = True | ||||
|         if self.check_ip: | ||||
| @ -23,7 +23,7 @@ class SuspiciousRequestPolicy(Policy): | ||||
|         if self.check_username: | ||||
|             user_scores = UserScore.objects.filter(user=user, score__lte=self.threshold) | ||||
|             passing = passing and user_scores.exists() | ||||
|         return passing | ||||
|         return PolicyResult(passing) | ||||
|  | ||||
|     class Meta: | ||||
|  | ||||
|  | ||||
		Reference in New Issue
	
	Block a user
	 Langhammer, Jens
					Langhammer, Jens