diff --git a/authentik/blueprints/v1/importer.py b/authentik/blueprints/v1/importer.py index 1e935c4d6b..21127ec00e 100644 --- a/authentik/blueprints/v1/importer.py +++ b/authentik/blueprints/v1/importer.py @@ -39,7 +39,8 @@ from authentik.core.models import ( Source, UserSourceConnection, ) -from authentik.enterprise.models import LicenseKey, LicenseUsage +from authentik.enterprise.license import LicenseKey +from authentik.enterprise.models import LicenseUsage from authentik.enterprise.providers.rac.models import ConnectionToken from authentik.events.models import SystemTask from authentik.events.utils import cleanse_dict diff --git a/authentik/enterprise/api.py b/authentik/enterprise/api.py index b7a91d7649..3835c85610 100644 --- a/authentik/enterprise/api.py +++ b/authentik/enterprise/api.py @@ -1,6 +1,7 @@ """Enterprise API Views""" -from datetime import datetime, timedelta +from dataclasses import asdict +from datetime import timedelta from django.utils.timezone import now from django.utils.translation import gettext as _ @@ -8,7 +9,7 @@ from drf_spectacular.types import OpenApiTypes from drf_spectacular.utils import extend_schema, inline_serializer from rest_framework.decorators import action from rest_framework.exceptions import ValidationError -from rest_framework.fields import BooleanField, CharField, DateTimeField, IntegerField +from rest_framework.fields import CharField, IntegerField from rest_framework.permissions import IsAuthenticated from rest_framework.request import Request from rest_framework.response import Response @@ -19,18 +20,18 @@ from authentik.api.decorators import permission_required from authentik.core.api.used_by import UsedByMixin from authentik.core.api.utils import PassiveSerializer from authentik.core.models import User, UserTypes -from authentik.enterprise.models import License, LicenseKey +from authentik.enterprise.license import LicenseKey, LicenseSummarySerializer +from authentik.enterprise.models import License from authentik.root.install_id import get_install_id class EnterpriseRequiredMixin: """Mixin to validate that a valid enterprise license - exists before allowing to safe the object""" + exists before allowing to save the object""" def validate(self, attrs: dict) -> dict: """Check that a valid license exists""" - total = LicenseKey.get_total() - if not total.is_valid(): + if not LicenseKey.cached_summary().valid: raise ValidationError(_("Enterprise is required to create/update this object.")) return super().validate(attrs) @@ -61,19 +62,6 @@ class LicenseSerializer(ModelSerializer): } -class LicenseSummary(PassiveSerializer): - """Serializer for license status""" - - internal_users = IntegerField(required=True) - external_users = IntegerField(required=True) - valid = BooleanField() - show_admin_warning = BooleanField() - show_user_warning = BooleanField() - read_only = BooleanField() - latest_valid = DateTimeField() - has_license = BooleanField() - - class LicenseForecastSerializer(PassiveSerializer): """Serializer for license forecast""" @@ -111,31 +99,13 @@ class LicenseViewSet(UsedByMixin, ModelViewSet): @extend_schema( request=OpenApiTypes.NONE, responses={ - 200: LicenseSummary(), + 200: LicenseSummarySerializer(), }, ) @action(detail=False, methods=["GET"], permission_classes=[IsAuthenticated]) def summary(self, request: Request) -> Response: """Get the total license status""" - total = LicenseKey.get_total() - last_valid = LicenseKey.last_valid_date() - # TODO: move this to a different place? - show_admin_warning = last_valid < now() - timedelta(weeks=2) - show_user_warning = last_valid < now() - timedelta(weeks=4) - read_only = last_valid < now() - timedelta(weeks=6) - latest_valid = datetime.fromtimestamp(total.exp) - response = LicenseSummary( - data={ - "internal_users": total.internal_users, - "external_users": total.external_users, - "valid": total.is_valid(), - "show_admin_warning": show_admin_warning, - "show_user_warning": show_user_warning, - "read_only": read_only, - "latest_valid": latest_valid, - "has_license": License.objects.all().count() > 0, - } - ) + response = LicenseSummarySerializer(data=asdict(LicenseKey.cached_summary())) response.is_valid(raise_exception=True) return Response(response.data) diff --git a/authentik/enterprise/apps.py b/authentik/enterprise/apps.py index 0b1057ac9a..83dbefa06a 100644 --- a/authentik/enterprise/apps.py +++ b/authentik/enterprise/apps.py @@ -23,6 +23,6 @@ class AuthentikEnterpriseConfig(EnterpriseConfig): def check_enabled(self): """Actual enterprise check, cached""" - from authentik.enterprise.models import LicenseKey + from authentik.enterprise.license import LicenseKey - return LicenseKey.get_total().is_valid() + return LicenseKey.cached_summary().valid diff --git a/authentik/enterprise/audit/middleware.py b/authentik/enterprise/audit/middleware.py index cca240b96f..ad649d1a04 100644 --- a/authentik/enterprise/audit/middleware.py +++ b/authentik/enterprise/audit/middleware.py @@ -19,14 +19,10 @@ from authentik.events.utils import cleanse_dict, sanitize_item class EnterpriseAuditMiddleware(AuditMiddleware): """Enterprise audit middleware""" - _enabled = None - @property def enabled(self): - """Lazy check if audit logging is enabled""" - if self._enabled is None: - self._enabled = apps.get_app_config("authentik_enterprise").enabled() - return self._enabled + """Check if audit logging is enabled""" + return apps.get_app_config("authentik_enterprise").enabled() def connect(self, request: HttpRequest): super().connect(request) diff --git a/authentik/enterprise/license.py b/authentik/enterprise/license.py new file mode 100644 index 0000000000..7baa1b378a --- /dev/null +++ b/authentik/enterprise/license.py @@ -0,0 +1,213 @@ +"""Enterprise license""" + +from base64 import b64decode +from binascii import Error +from dataclasses import asdict, dataclass, field +from datetime import datetime, timedelta +from enum import Enum +from functools import lru_cache +from time import mktime + +from cryptography.exceptions import InvalidSignature +from cryptography.x509 import Certificate, load_der_x509_certificate, load_pem_x509_certificate +from dacite import from_dict +from django.core.cache import cache +from django.db.models.query import QuerySet +from django.utils.timezone import now +from jwt import PyJWTError, decode, get_unverified_header +from rest_framework.exceptions import ValidationError +from rest_framework.fields import BooleanField, DateTimeField, IntegerField + +from authentik.core.api.utils import PassiveSerializer +from authentik.core.models import User, UserTypes +from authentik.enterprise.models import License, LicenseUsage +from authentik.root.install_id import get_install_id + +CACHE_KEY_ENTERPRISE_LICENSE = "goauthentik.io/enterprise/license" +CACHE_EXPIRY_ENTERPRISE_LICENSE = 3 * 60 * 60 # 2 Hours + + +@lru_cache() +def get_licensing_key() -> Certificate: + """Get Root CA PEM""" + with open("authentik/enterprise/public.pem", "rb") as _key: + return load_pem_x509_certificate(_key.read()) + + +def get_license_aud() -> str: + """Get the JWT audience field""" + return f"enterprise.goauthentik.io/license/{get_install_id()}" + + +class LicenseFlags(Enum): + """License flags""" + + +@dataclass +class LicenseSummary: + """Internal representation of a license summary""" + + internal_users: int + external_users: int + valid: bool + show_admin_warning: bool + show_user_warning: bool + read_only: bool + latest_valid: datetime + has_license: bool + + +class LicenseSummarySerializer(PassiveSerializer): + """Serializer for license status""" + + internal_users = IntegerField(required=True) + external_users = IntegerField(required=True) + valid = BooleanField() + show_admin_warning = BooleanField() + show_user_warning = BooleanField() + read_only = BooleanField() + latest_valid = DateTimeField() + has_license = BooleanField() + + +@dataclass +class LicenseKey: + """License JWT claims""" + + aud: str + exp: int + + name: str + internal_users: int = 0 + external_users: int = 0 + flags: list[LicenseFlags] = field(default_factory=list) + + @staticmethod + def validate(jwt: str) -> "LicenseKey": + """Validate the license from a given JWT""" + try: + headers = get_unverified_header(jwt) + except PyJWTError: + raise ValidationError("Unable to verify license") + x5c: list[str] = headers.get("x5c", []) + if len(x5c) < 1: + raise ValidationError("Unable to verify license") + try: + our_cert = load_der_x509_certificate(b64decode(x5c[0])) + intermediate = load_der_x509_certificate(b64decode(x5c[1])) + our_cert.verify_directly_issued_by(intermediate) + intermediate.verify_directly_issued_by(get_licensing_key()) + except (InvalidSignature, TypeError, ValueError, Error): + raise ValidationError("Unable to verify license") + try: + body = from_dict( + LicenseKey, + decode( + jwt, + our_cert.public_key(), + algorithms=["ES512"], + audience=get_license_aud(), + ), + ) + except PyJWTError: + raise ValidationError("Unable to verify license") + return body + + @staticmethod + def get_total() -> "LicenseKey": + """Get a summarized version of all (not expired) licenses""" + active_licenses = License.objects.filter(expiry__gte=now()) + total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0) + for lic in active_licenses: + total.internal_users += lic.internal_users + total.external_users += lic.external_users + exp_ts = int(mktime(lic.expiry.timetuple())) + if total.exp == 0: + total.exp = exp_ts + if exp_ts <= total.exp: + total.exp = exp_ts + total.flags.extend(lic.status.flags) + return total + + @staticmethod + def base_user_qs() -> QuerySet: + """Base query set for all users""" + return User.objects.all().exclude_anonymous().exclude(is_active=False) + + @staticmethod + def get_default_user_count(): + """Get current default user count""" + return LicenseKey.base_user_qs().filter(type=UserTypes.INTERNAL).count() + + @staticmethod + def get_external_user_count(): + """Get current external user count""" + # Count since start of the month + last_month = now().replace(day=1) + return ( + LicenseKey.base_user_qs() + .filter(type=UserTypes.EXTERNAL, last_login__gte=last_month) + .count() + ) + + def is_valid(self) -> bool: + """Check if the given license body covers all users + + Only checks the current count, no historical data is checked""" + default_users = self.get_default_user_count() + if default_users > self.internal_users: + return False + active_users = self.get_external_user_count() + if active_users > self.external_users: + return False + return True + + def record_usage(self): + """Capture the current validity status and metrics and save them""" + threshold = now() - timedelta(hours=8) + if not LicenseUsage.objects.filter(record_date__gte=threshold).exists(): + LicenseUsage.objects.create( + user_count=self.get_default_user_count(), + external_user_count=self.get_external_user_count(), + within_limits=self.is_valid(), + ) + summary = asdict(self.summary()) + # Also cache the latest summary for the middleware + cache.set(CACHE_KEY_ENTERPRISE_LICENSE, summary, timeout=CACHE_EXPIRY_ENTERPRISE_LICENSE) + return summary + + @staticmethod + def last_valid_date() -> datetime: + """Get the last date the license was valid""" + usage: LicenseUsage = ( + LicenseUsage.filter_not_expired(within_limits=True).order_by("-record_date").first() + ) + if not usage: + return now() + return usage.record_date + + def summary(self) -> LicenseSummary: + """Summary of license status""" + last_valid = LicenseKey.last_valid_date() + show_admin_warning = last_valid < now() - timedelta(weeks=2) + show_user_warning = last_valid < now() - timedelta(weeks=4) + read_only = last_valid < now() - timedelta(weeks=6) + latest_valid = datetime.fromtimestamp(self.exp) + return LicenseSummary( + show_admin_warning=show_admin_warning, + show_user_warning=show_user_warning, + read_only=read_only, + latest_valid=latest_valid, + internal_users=self.internal_users, + external_users=self.external_users, + valid=self.is_valid(), + has_license=License.objects.all().count() > 0, + ) + + @staticmethod + def cached_summary() -> LicenseSummary: + """Helper method which looks up the last summary""" + summary = cache.get(CACHE_KEY_ENTERPRISE_LICENSE) + if not summary: + return LicenseKey.get_total().summary() + return from_dict(LicenseSummary, summary) diff --git a/authentik/enterprise/middleware.py b/authentik/enterprise/middleware.py new file mode 100644 index 0000000000..83ff8af05f --- /dev/null +++ b/authentik/enterprise/middleware.py @@ -0,0 +1,64 @@ +"""Enterprise middleware""" + +from collections.abc import Callable + +from django.http import HttpRequest, HttpResponse, JsonResponse +from django.urls import resolve +from structlog.stdlib import BoundLogger, get_logger + +from authentik.enterprise.api import LicenseViewSet +from authentik.enterprise.license import LicenseKey +from authentik.flows.views.executor import FlowExecutorView +from authentik.lib.utils.reflection import class_to_path + + +class EnterpriseMiddleware: + """Enterprise middleware""" + + get_response: Callable[[HttpRequest], HttpResponse] + logger: BoundLogger + + def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]): + self.get_response = get_response + self.logger = get_logger().bind() + + def __call__(self, request: HttpRequest) -> HttpResponse: + resolver_match = resolve(request.path_info) + request.resolver_match = resolver_match + if not self.is_request_allowed(request): + self.logger.warning("Refusing request due to expired/invalid license") + return JsonResponse( + { + "detail": "Request denied due to expired/invalid license.", + "code": "denied_license", + }, + status=400, + ) + return self.get_response(request) + + def is_request_allowed(self, request: HttpRequest) -> bool: + """Check if a specific request is allowed""" + if self.is_request_always_allowed(request): + return True + cached_status = LicenseKey.cached_summary() + if not cached_status: + return True + if cached_status.read_only: + return False + return True + + def is_request_always_allowed(self, request: HttpRequest): + """Check if a request is always allowed""" + # Always allow "safe" methods + if request.method.lower() in ["get", "head", "options", "trace"]: + return True + # Always allow requests to manage licenses + if class_to_path(request.resolver_match.func) == class_to_path(LicenseViewSet): + return True + # Flow executor is mounted as an API path but explicitly allowed + if class_to_path(request.resolver_match.func) == class_to_path(FlowExecutorView): + return True + # Only apply these restrictions to the API + if "authentik_api" not in request.resolver_match.app_names: + return True + return False diff --git a/authentik/enterprise/models.py b/authentik/enterprise/models.py index 2c20169fbd..6600e5c07a 100644 --- a/authentik/enterprise/models.py +++ b/authentik/enterprise/models.py @@ -1,159 +1,20 @@ """Enterprise models""" -from base64 import b64decode -from binascii import Error -from dataclasses import dataclass, field -from datetime import datetime, timedelta -from enum import Enum -from functools import lru_cache -from time import mktime +from datetime import timedelta +from typing import TYPE_CHECKING from uuid import uuid4 -from cryptography.exceptions import InvalidSignature -from cryptography.x509 import Certificate, load_der_x509_certificate, load_pem_x509_certificate -from dacite import from_dict from django.contrib.postgres.indexes import HashIndex from django.db import models -from django.db.models.query import QuerySet from django.utils.timezone import now from django.utils.translation import gettext as _ -from jwt import PyJWTError, decode, get_unverified_header -from rest_framework.exceptions import ValidationError from rest_framework.serializers import BaseSerializer -from authentik.core.models import ExpiringModel, User, UserTypes +from authentik.core.models import ExpiringModel from authentik.lib.models import SerializerModel -from authentik.root.install_id import get_install_id - -@lru_cache() -def get_licensing_key() -> Certificate: - """Get Root CA PEM""" - with open("authentik/enterprise/public.pem", "rb") as _key: - return load_pem_x509_certificate(_key.read()) - - -def get_license_aud() -> str: - """Get the JWT audience field""" - return f"enterprise.goauthentik.io/license/{get_install_id()}" - - -class LicenseFlags(Enum): - """License flags""" - - -@dataclass -class LicenseKey: - """License JWT claims""" - - aud: str - exp: int - - name: str - internal_users: int = 0 - external_users: int = 0 - flags: list[LicenseFlags] = field(default_factory=list) - - @staticmethod - def validate(jwt: str) -> "LicenseKey": - """Validate the license from a given JWT""" - try: - headers = get_unverified_header(jwt) - except PyJWTError: - raise ValidationError("Unable to verify license") - x5c: list[str] = headers.get("x5c", []) - if len(x5c) < 1: - raise ValidationError("Unable to verify license") - try: - our_cert = load_der_x509_certificate(b64decode(x5c[0])) - intermediate = load_der_x509_certificate(b64decode(x5c[1])) - our_cert.verify_directly_issued_by(intermediate) - intermediate.verify_directly_issued_by(get_licensing_key()) - except (InvalidSignature, TypeError, ValueError, Error): - raise ValidationError("Unable to verify license") - try: - body = from_dict( - LicenseKey, - decode( - jwt, - our_cert.public_key(), - algorithms=["ES512"], - audience=get_license_aud(), - ), - ) - except PyJWTError: - raise ValidationError("Unable to verify license") - return body - - @staticmethod - def get_total() -> "LicenseKey": - """Get a summarized version of all (not expired) licenses""" - active_licenses = License.objects.filter(expiry__gte=now()) - total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0) - for lic in active_licenses: - total.internal_users += lic.internal_users - total.external_users += lic.external_users - exp_ts = int(mktime(lic.expiry.timetuple())) - if total.exp == 0: - total.exp = exp_ts - if exp_ts <= total.exp: - total.exp = exp_ts - total.flags.extend(lic.status.flags) - return total - - @staticmethod - def base_user_qs() -> QuerySet: - """Base query set for all users""" - return User.objects.all().exclude_anonymous().exclude(is_active=False) - - @staticmethod - def get_default_user_count(): - """Get current default user count""" - return LicenseKey.base_user_qs().filter(type=UserTypes.INTERNAL).count() - - @staticmethod - def get_external_user_count(): - """Get current external user count""" - # Count since start of the month - last_month = now().replace(day=1) - return ( - LicenseKey.base_user_qs() - .filter(type=UserTypes.EXTERNAL, last_login__gte=last_month) - .count() - ) - - def is_valid(self) -> bool: - """Check if the given license body covers all users - - Only checks the current count, no historical data is checked""" - default_users = self.get_default_user_count() - if default_users > self.internal_users: - return False - active_users = self.get_external_user_count() - if active_users > self.external_users: - return False - return True - - def record_usage(self): - """Capture the current validity status and metrics and save them""" - threshold = now() - timedelta(hours=8) - if LicenseUsage.objects.filter(record_date__gte=threshold).exists(): - return - LicenseUsage.objects.create( - user_count=self.get_default_user_count(), - external_user_count=self.get_external_user_count(), - within_limits=self.is_valid(), - ) - - @staticmethod - def last_valid_date() -> datetime: - """Get the last date the license was valid""" - usage: LicenseUsage = ( - LicenseUsage.filter_not_expired(within_limits=True).order_by("-record_date").first() - ) - if not usage: - return now() - return usage.record_date +if TYPE_CHECKING: + from authentik.enterprise.license import LicenseKey class License(SerializerModel): @@ -174,8 +35,10 @@ class License(SerializerModel): return LicenseSerializer @property - def status(self) -> LicenseKey: + def status(self) -> "LicenseKey": """Get parsed license status""" + from authentik.enterprise.license import LicenseKey + return LicenseKey.validate(self.key) class Meta: diff --git a/authentik/enterprise/policy.py b/authentik/enterprise/policy.py index a448c087c5..2e2535de0c 100644 --- a/authentik/enterprise/policy.py +++ b/authentik/enterprise/policy.py @@ -5,7 +5,7 @@ from typing import Optional from django.utils.translation import gettext_lazy as _ from authentik.core.models import User, UserTypes -from authentik.enterprise.models import LicenseKey +from authentik.enterprise.license import LicenseKey from authentik.policies.types import PolicyRequest, PolicyResult from authentik.policies.views import PolicyAccessView diff --git a/authentik/enterprise/providers/rac/tests/test_views.py b/authentik/enterprise/providers/rac/tests/test_views.py index 380b925a76..a63f27fba0 100644 --- a/authentik/enterprise/providers/rac/tests/test_views.py +++ b/authentik/enterprise/providers/rac/tests/test_views.py @@ -11,7 +11,8 @@ from rest_framework.test import APITestCase from authentik.core.models import Application from authentik.core.tests.utils import create_test_admin_user, create_test_flow -from authentik.enterprise.models import License, LicenseKey +from authentik.enterprise.license import LicenseKey +from authentik.enterprise.models import License from authentik.enterprise.providers.rac.models import Endpoint, Protocols, RACProvider from authentik.lib.generators import generate_id from authentik.policies.denied import AccessDeniedResponse @@ -39,7 +40,7 @@ class TestRACViews(APITestCase): ) @patch( - "authentik.enterprise.models.LicenseKey.validate", + "authentik.enterprise.license.LicenseKey.validate", MagicMock( return_value=LicenseKey( aud="", @@ -70,7 +71,7 @@ class TestRACViews(APITestCase): self.assertEqual(final_response.status_code, 200) @patch( - "authentik.enterprise.models.LicenseKey.validate", + "authentik.enterprise.license.LicenseKey.validate", MagicMock( return_value=LicenseKey( aud="", @@ -99,7 +100,7 @@ class TestRACViews(APITestCase): self.assertIsInstance(response, AccessDeniedResponse) @patch( - "authentik.enterprise.models.LicenseKey.validate", + "authentik.enterprise.license.LicenseKey.validate", MagicMock( return_value=LicenseKey( aud="", diff --git a/authentik/enterprise/settings.py b/authentik/enterprise/settings.py index f026c70c22..7eb238a831 100644 --- a/authentik/enterprise/settings.py +++ b/authentik/enterprise/settings.py @@ -16,3 +16,5 @@ TENANT_APPS = [ "authentik.enterprise.audit", "authentik.enterprise.providers.rac", ] + +MIDDLEWARE = ["authentik.enterprise.middleware.EnterpriseMiddleware"] diff --git a/authentik/enterprise/tasks.py b/authentik/enterprise/tasks.py index 0d5a537a8d..a55ab5e13d 100644 --- a/authentik/enterprise/tasks.py +++ b/authentik/enterprise/tasks.py @@ -1,6 +1,6 @@ """Enterprise tasks""" -from authentik.enterprise.models import LicenseKey +from authentik.enterprise.license import LicenseKey from authentik.events.models import TaskStatus from authentik.events.system_tasks import SystemTask, prefill_task from authentik.root.celery import CELERY_APP diff --git a/authentik/enterprise/tests/test_license.py b/authentik/enterprise/tests/test_license.py index a972d961b7..efa45e0eb6 100644 --- a/authentik/enterprise/tests/test_license.py +++ b/authentik/enterprise/tests/test_license.py @@ -8,7 +8,8 @@ from django.test import TestCase from django.utils.timezone import now from rest_framework.exceptions import ValidationError -from authentik.enterprise.models import License, LicenseKey +from authentik.enterprise.license import LicenseKey +from authentik.enterprise.models import License from authentik.lib.generators import generate_id _exp = int(mktime((now() + timedelta(days=3000)).timetuple())) @@ -18,7 +19,7 @@ class TestEnterpriseLicense(TestCase): """Enterprise license tests""" @patch( - "authentik.enterprise.models.LicenseKey.validate", + "authentik.enterprise.license.LicenseKey.validate", MagicMock( return_value=LicenseKey( aud="", @@ -41,7 +42,7 @@ class TestEnterpriseLicense(TestCase): License.objects.create(key=generate_id()) @patch( - "authentik.enterprise.models.LicenseKey.validate", + "authentik.enterprise.license.LicenseKey.validate", MagicMock( return_value=LicenseKey( aud="",