enterprise: rework license summary caching (#8501)

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L
2024-02-14 19:00:08 +01:00
committed by GitHub
parent 4733778460
commit 7d527beea8
12 changed files with 313 additions and 202 deletions

View File

@ -39,7 +39,8 @@ from authentik.core.models import (
Source, Source,
UserSourceConnection, UserSourceConnection,
) )
from authentik.enterprise.models import LicenseKey, LicenseUsage from authentik.enterprise.license import LicenseKey
from authentik.enterprise.models import LicenseUsage
from authentik.enterprise.providers.rac.models import ConnectionToken from authentik.enterprise.providers.rac.models import ConnectionToken
from authentik.events.models import SystemTask from authentik.events.models import SystemTask
from authentik.events.utils import cleanse_dict from authentik.events.utils import cleanse_dict

View File

@ -1,6 +1,7 @@
"""Enterprise API Views""" """Enterprise API Views"""
from datetime import datetime, timedelta from dataclasses import asdict
from datetime import timedelta
from django.utils.timezone import now from django.utils.timezone import now
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
@ -8,7 +9,7 @@ from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema, inline_serializer from drf_spectacular.utils import extend_schema, inline_serializer
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
from rest_framework.fields import BooleanField, CharField, DateTimeField, IntegerField from rest_framework.fields import CharField, IntegerField
from rest_framework.permissions import IsAuthenticated from rest_framework.permissions import IsAuthenticated
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
@ -19,18 +20,18 @@ from authentik.api.decorators import permission_required
from authentik.core.api.used_by import UsedByMixin from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import PassiveSerializer from authentik.core.api.utils import PassiveSerializer
from authentik.core.models import User, UserTypes from authentik.core.models import User, UserTypes
from authentik.enterprise.models import License, LicenseKey from authentik.enterprise.license import LicenseKey, LicenseSummarySerializer
from authentik.enterprise.models import License
from authentik.root.install_id import get_install_id from authentik.root.install_id import get_install_id
class EnterpriseRequiredMixin: class EnterpriseRequiredMixin:
"""Mixin to validate that a valid enterprise license """Mixin to validate that a valid enterprise license
exists before allowing to safe the object""" exists before allowing to save the object"""
def validate(self, attrs: dict) -> dict: def validate(self, attrs: dict) -> dict:
"""Check that a valid license exists""" """Check that a valid license exists"""
total = LicenseKey.get_total() if not LicenseKey.cached_summary().valid:
if not total.is_valid():
raise ValidationError(_("Enterprise is required to create/update this object.")) raise ValidationError(_("Enterprise is required to create/update this object."))
return super().validate(attrs) return super().validate(attrs)
@ -61,19 +62,6 @@ class LicenseSerializer(ModelSerializer):
} }
class LicenseSummary(PassiveSerializer):
"""Serializer for license status"""
internal_users = IntegerField(required=True)
external_users = IntegerField(required=True)
valid = BooleanField()
show_admin_warning = BooleanField()
show_user_warning = BooleanField()
read_only = BooleanField()
latest_valid = DateTimeField()
has_license = BooleanField()
class LicenseForecastSerializer(PassiveSerializer): class LicenseForecastSerializer(PassiveSerializer):
"""Serializer for license forecast""" """Serializer for license forecast"""
@ -111,31 +99,13 @@ class LicenseViewSet(UsedByMixin, ModelViewSet):
@extend_schema( @extend_schema(
request=OpenApiTypes.NONE, request=OpenApiTypes.NONE,
responses={ responses={
200: LicenseSummary(), 200: LicenseSummarySerializer(),
}, },
) )
@action(detail=False, methods=["GET"], permission_classes=[IsAuthenticated]) @action(detail=False, methods=["GET"], permission_classes=[IsAuthenticated])
def summary(self, request: Request) -> Response: def summary(self, request: Request) -> Response:
"""Get the total license status""" """Get the total license status"""
total = LicenseKey.get_total() response = LicenseSummarySerializer(data=asdict(LicenseKey.cached_summary()))
last_valid = LicenseKey.last_valid_date()
# TODO: move this to a different place?
show_admin_warning = last_valid < now() - timedelta(weeks=2)
show_user_warning = last_valid < now() - timedelta(weeks=4)
read_only = last_valid < now() - timedelta(weeks=6)
latest_valid = datetime.fromtimestamp(total.exp)
response = LicenseSummary(
data={
"internal_users": total.internal_users,
"external_users": total.external_users,
"valid": total.is_valid(),
"show_admin_warning": show_admin_warning,
"show_user_warning": show_user_warning,
"read_only": read_only,
"latest_valid": latest_valid,
"has_license": License.objects.all().count() > 0,
}
)
response.is_valid(raise_exception=True) response.is_valid(raise_exception=True)
return Response(response.data) return Response(response.data)

View File

@ -23,6 +23,6 @@ class AuthentikEnterpriseConfig(EnterpriseConfig):
def check_enabled(self): def check_enabled(self):
"""Actual enterprise check, cached""" """Actual enterprise check, cached"""
from authentik.enterprise.models import LicenseKey from authentik.enterprise.license import LicenseKey
return LicenseKey.get_total().is_valid() return LicenseKey.cached_summary().valid

View File

@ -19,14 +19,10 @@ from authentik.events.utils import cleanse_dict, sanitize_item
class EnterpriseAuditMiddleware(AuditMiddleware): class EnterpriseAuditMiddleware(AuditMiddleware):
"""Enterprise audit middleware""" """Enterprise audit middleware"""
_enabled = None
@property @property
def enabled(self): def enabled(self):
"""Lazy check if audit logging is enabled""" """Check if audit logging is enabled"""
if self._enabled is None: return apps.get_app_config("authentik_enterprise").enabled()
self._enabled = apps.get_app_config("authentik_enterprise").enabled()
return self._enabled
def connect(self, request: HttpRequest): def connect(self, request: HttpRequest):
super().connect(request) super().connect(request)

View File

@ -0,0 +1,213 @@
"""Enterprise license"""
from base64 import b64decode
from binascii import Error
from dataclasses import asdict, dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from functools import lru_cache
from time import mktime
from cryptography.exceptions import InvalidSignature
from cryptography.x509 import Certificate, load_der_x509_certificate, load_pem_x509_certificate
from dacite import from_dict
from django.core.cache import cache
from django.db.models.query import QuerySet
from django.utils.timezone import now
from jwt import PyJWTError, decode, get_unverified_header
from rest_framework.exceptions import ValidationError
from rest_framework.fields import BooleanField, DateTimeField, IntegerField
from authentik.core.api.utils import PassiveSerializer
from authentik.core.models import User, UserTypes
from authentik.enterprise.models import License, LicenseUsage
from authentik.root.install_id import get_install_id
CACHE_KEY_ENTERPRISE_LICENSE = "goauthentik.io/enterprise/license"
CACHE_EXPIRY_ENTERPRISE_LICENSE = 3 * 60 * 60 # 2 Hours
@lru_cache()
def get_licensing_key() -> Certificate:
"""Get Root CA PEM"""
with open("authentik/enterprise/public.pem", "rb") as _key:
return load_pem_x509_certificate(_key.read())
def get_license_aud() -> str:
"""Get the JWT audience field"""
return f"enterprise.goauthentik.io/license/{get_install_id()}"
class LicenseFlags(Enum):
"""License flags"""
@dataclass
class LicenseSummary:
"""Internal representation of a license summary"""
internal_users: int
external_users: int
valid: bool
show_admin_warning: bool
show_user_warning: bool
read_only: bool
latest_valid: datetime
has_license: bool
class LicenseSummarySerializer(PassiveSerializer):
"""Serializer for license status"""
internal_users = IntegerField(required=True)
external_users = IntegerField(required=True)
valid = BooleanField()
show_admin_warning = BooleanField()
show_user_warning = BooleanField()
read_only = BooleanField()
latest_valid = DateTimeField()
has_license = BooleanField()
@dataclass
class LicenseKey:
"""License JWT claims"""
aud: str
exp: int
name: str
internal_users: int = 0
external_users: int = 0
flags: list[LicenseFlags] = field(default_factory=list)
@staticmethod
def validate(jwt: str) -> "LicenseKey":
"""Validate the license from a given JWT"""
try:
headers = get_unverified_header(jwt)
except PyJWTError:
raise ValidationError("Unable to verify license")
x5c: list[str] = headers.get("x5c", [])
if len(x5c) < 1:
raise ValidationError("Unable to verify license")
try:
our_cert = load_der_x509_certificate(b64decode(x5c[0]))
intermediate = load_der_x509_certificate(b64decode(x5c[1]))
our_cert.verify_directly_issued_by(intermediate)
intermediate.verify_directly_issued_by(get_licensing_key())
except (InvalidSignature, TypeError, ValueError, Error):
raise ValidationError("Unable to verify license")
try:
body = from_dict(
LicenseKey,
decode(
jwt,
our_cert.public_key(),
algorithms=["ES512"],
audience=get_license_aud(),
),
)
except PyJWTError:
raise ValidationError("Unable to verify license")
return body
@staticmethod
def get_total() -> "LicenseKey":
"""Get a summarized version of all (not expired) licenses"""
active_licenses = License.objects.filter(expiry__gte=now())
total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0)
for lic in active_licenses:
total.internal_users += lic.internal_users
total.external_users += lic.external_users
exp_ts = int(mktime(lic.expiry.timetuple()))
if total.exp == 0:
total.exp = exp_ts
if exp_ts <= total.exp:
total.exp = exp_ts
total.flags.extend(lic.status.flags)
return total
@staticmethod
def base_user_qs() -> QuerySet:
"""Base query set for all users"""
return User.objects.all().exclude_anonymous().exclude(is_active=False)
@staticmethod
def get_default_user_count():
"""Get current default user count"""
return LicenseKey.base_user_qs().filter(type=UserTypes.INTERNAL).count()
@staticmethod
def get_external_user_count():
"""Get current external user count"""
# Count since start of the month
last_month = now().replace(day=1)
return (
LicenseKey.base_user_qs()
.filter(type=UserTypes.EXTERNAL, last_login__gte=last_month)
.count()
)
def is_valid(self) -> bool:
"""Check if the given license body covers all users
Only checks the current count, no historical data is checked"""
default_users = self.get_default_user_count()
if default_users > self.internal_users:
return False
active_users = self.get_external_user_count()
if active_users > self.external_users:
return False
return True
def record_usage(self):
"""Capture the current validity status and metrics and save them"""
threshold = now() - timedelta(hours=8)
if not LicenseUsage.objects.filter(record_date__gte=threshold).exists():
LicenseUsage.objects.create(
user_count=self.get_default_user_count(),
external_user_count=self.get_external_user_count(),
within_limits=self.is_valid(),
)
summary = asdict(self.summary())
# Also cache the latest summary for the middleware
cache.set(CACHE_KEY_ENTERPRISE_LICENSE, summary, timeout=CACHE_EXPIRY_ENTERPRISE_LICENSE)
return summary
@staticmethod
def last_valid_date() -> datetime:
"""Get the last date the license was valid"""
usage: LicenseUsage = (
LicenseUsage.filter_not_expired(within_limits=True).order_by("-record_date").first()
)
if not usage:
return now()
return usage.record_date
def summary(self) -> LicenseSummary:
"""Summary of license status"""
last_valid = LicenseKey.last_valid_date()
show_admin_warning = last_valid < now() - timedelta(weeks=2)
show_user_warning = last_valid < now() - timedelta(weeks=4)
read_only = last_valid < now() - timedelta(weeks=6)
latest_valid = datetime.fromtimestamp(self.exp)
return LicenseSummary(
show_admin_warning=show_admin_warning,
show_user_warning=show_user_warning,
read_only=read_only,
latest_valid=latest_valid,
internal_users=self.internal_users,
external_users=self.external_users,
valid=self.is_valid(),
has_license=License.objects.all().count() > 0,
)
@staticmethod
def cached_summary() -> LicenseSummary:
"""Helper method which looks up the last summary"""
summary = cache.get(CACHE_KEY_ENTERPRISE_LICENSE)
if not summary:
return LicenseKey.get_total().summary()
return from_dict(LicenseSummary, summary)

View File

@ -0,0 +1,64 @@
"""Enterprise middleware"""
from collections.abc import Callable
from django.http import HttpRequest, HttpResponse, JsonResponse
from django.urls import resolve
from structlog.stdlib import BoundLogger, get_logger
from authentik.enterprise.api import LicenseViewSet
from authentik.enterprise.license import LicenseKey
from authentik.flows.views.executor import FlowExecutorView
from authentik.lib.utils.reflection import class_to_path
class EnterpriseMiddleware:
"""Enterprise middleware"""
get_response: Callable[[HttpRequest], HttpResponse]
logger: BoundLogger
def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]):
self.get_response = get_response
self.logger = get_logger().bind()
def __call__(self, request: HttpRequest) -> HttpResponse:
resolver_match = resolve(request.path_info)
request.resolver_match = resolver_match
if not self.is_request_allowed(request):
self.logger.warning("Refusing request due to expired/invalid license")
return JsonResponse(
{
"detail": "Request denied due to expired/invalid license.",
"code": "denied_license",
},
status=400,
)
return self.get_response(request)
def is_request_allowed(self, request: HttpRequest) -> bool:
"""Check if a specific request is allowed"""
if self.is_request_always_allowed(request):
return True
cached_status = LicenseKey.cached_summary()
if not cached_status:
return True
if cached_status.read_only:
return False
return True
def is_request_always_allowed(self, request: HttpRequest):
"""Check if a request is always allowed"""
# Always allow "safe" methods
if request.method.lower() in ["get", "head", "options", "trace"]:
return True
# Always allow requests to manage licenses
if class_to_path(request.resolver_match.func) == class_to_path(LicenseViewSet):
return True
# Flow executor is mounted as an API path but explicitly allowed
if class_to_path(request.resolver_match.func) == class_to_path(FlowExecutorView):
return True
# Only apply these restrictions to the API
if "authentik_api" not in request.resolver_match.app_names:
return True
return False

View File

@ -1,159 +1,20 @@
"""Enterprise models""" """Enterprise models"""
from base64 import b64decode from datetime import timedelta
from binascii import Error from typing import TYPE_CHECKING
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from functools import lru_cache
from time import mktime
from uuid import uuid4 from uuid import uuid4
from cryptography.exceptions import InvalidSignature
from cryptography.x509 import Certificate, load_der_x509_certificate, load_pem_x509_certificate
from dacite import from_dict
from django.contrib.postgres.indexes import HashIndex from django.contrib.postgres.indexes import HashIndex
from django.db import models from django.db import models
from django.db.models.query import QuerySet
from django.utils.timezone import now from django.utils.timezone import now
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from jwt import PyJWTError, decode, get_unverified_header
from rest_framework.exceptions import ValidationError
from rest_framework.serializers import BaseSerializer from rest_framework.serializers import BaseSerializer
from authentik.core.models import ExpiringModel, User, UserTypes from authentik.core.models import ExpiringModel
from authentik.lib.models import SerializerModel from authentik.lib.models import SerializerModel
from authentik.root.install_id import get_install_id
if TYPE_CHECKING:
@lru_cache() from authentik.enterprise.license import LicenseKey
def get_licensing_key() -> Certificate:
"""Get Root CA PEM"""
with open("authentik/enterprise/public.pem", "rb") as _key:
return load_pem_x509_certificate(_key.read())
def get_license_aud() -> str:
"""Get the JWT audience field"""
return f"enterprise.goauthentik.io/license/{get_install_id()}"
class LicenseFlags(Enum):
"""License flags"""
@dataclass
class LicenseKey:
"""License JWT claims"""
aud: str
exp: int
name: str
internal_users: int = 0
external_users: int = 0
flags: list[LicenseFlags] = field(default_factory=list)
@staticmethod
def validate(jwt: str) -> "LicenseKey":
"""Validate the license from a given JWT"""
try:
headers = get_unverified_header(jwt)
except PyJWTError:
raise ValidationError("Unable to verify license")
x5c: list[str] = headers.get("x5c", [])
if len(x5c) < 1:
raise ValidationError("Unable to verify license")
try:
our_cert = load_der_x509_certificate(b64decode(x5c[0]))
intermediate = load_der_x509_certificate(b64decode(x5c[1]))
our_cert.verify_directly_issued_by(intermediate)
intermediate.verify_directly_issued_by(get_licensing_key())
except (InvalidSignature, TypeError, ValueError, Error):
raise ValidationError("Unable to verify license")
try:
body = from_dict(
LicenseKey,
decode(
jwt,
our_cert.public_key(),
algorithms=["ES512"],
audience=get_license_aud(),
),
)
except PyJWTError:
raise ValidationError("Unable to verify license")
return body
@staticmethod
def get_total() -> "LicenseKey":
"""Get a summarized version of all (not expired) licenses"""
active_licenses = License.objects.filter(expiry__gte=now())
total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0)
for lic in active_licenses:
total.internal_users += lic.internal_users
total.external_users += lic.external_users
exp_ts = int(mktime(lic.expiry.timetuple()))
if total.exp == 0:
total.exp = exp_ts
if exp_ts <= total.exp:
total.exp = exp_ts
total.flags.extend(lic.status.flags)
return total
@staticmethod
def base_user_qs() -> QuerySet:
"""Base query set for all users"""
return User.objects.all().exclude_anonymous().exclude(is_active=False)
@staticmethod
def get_default_user_count():
"""Get current default user count"""
return LicenseKey.base_user_qs().filter(type=UserTypes.INTERNAL).count()
@staticmethod
def get_external_user_count():
"""Get current external user count"""
# Count since start of the month
last_month = now().replace(day=1)
return (
LicenseKey.base_user_qs()
.filter(type=UserTypes.EXTERNAL, last_login__gte=last_month)
.count()
)
def is_valid(self) -> bool:
"""Check if the given license body covers all users
Only checks the current count, no historical data is checked"""
default_users = self.get_default_user_count()
if default_users > self.internal_users:
return False
active_users = self.get_external_user_count()
if active_users > self.external_users:
return False
return True
def record_usage(self):
"""Capture the current validity status and metrics and save them"""
threshold = now() - timedelta(hours=8)
if LicenseUsage.objects.filter(record_date__gte=threshold).exists():
return
LicenseUsage.objects.create(
user_count=self.get_default_user_count(),
external_user_count=self.get_external_user_count(),
within_limits=self.is_valid(),
)
@staticmethod
def last_valid_date() -> datetime:
"""Get the last date the license was valid"""
usage: LicenseUsage = (
LicenseUsage.filter_not_expired(within_limits=True).order_by("-record_date").first()
)
if not usage:
return now()
return usage.record_date
class License(SerializerModel): class License(SerializerModel):
@ -174,8 +35,10 @@ class License(SerializerModel):
return LicenseSerializer return LicenseSerializer
@property @property
def status(self) -> LicenseKey: def status(self) -> "LicenseKey":
"""Get parsed license status""" """Get parsed license status"""
from authentik.enterprise.license import LicenseKey
return LicenseKey.validate(self.key) return LicenseKey.validate(self.key)
class Meta: class Meta:

View File

@ -5,7 +5,7 @@ from typing import Optional
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from authentik.core.models import User, UserTypes from authentik.core.models import User, UserTypes
from authentik.enterprise.models import LicenseKey from authentik.enterprise.license import LicenseKey
from authentik.policies.types import PolicyRequest, PolicyResult from authentik.policies.types import PolicyRequest, PolicyResult
from authentik.policies.views import PolicyAccessView from authentik.policies.views import PolicyAccessView

View File

@ -11,7 +11,8 @@ from rest_framework.test import APITestCase
from authentik.core.models import Application from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user, create_test_flow from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.enterprise.models import License, LicenseKey from authentik.enterprise.license import LicenseKey
from authentik.enterprise.models import License
from authentik.enterprise.providers.rac.models import Endpoint, Protocols, RACProvider from authentik.enterprise.providers.rac.models import Endpoint, Protocols, RACProvider
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
from authentik.policies.denied import AccessDeniedResponse from authentik.policies.denied import AccessDeniedResponse
@ -39,7 +40,7 @@ class TestRACViews(APITestCase):
) )
@patch( @patch(
"authentik.enterprise.models.LicenseKey.validate", "authentik.enterprise.license.LicenseKey.validate",
MagicMock( MagicMock(
return_value=LicenseKey( return_value=LicenseKey(
aud="", aud="",
@ -70,7 +71,7 @@ class TestRACViews(APITestCase):
self.assertEqual(final_response.status_code, 200) self.assertEqual(final_response.status_code, 200)
@patch( @patch(
"authentik.enterprise.models.LicenseKey.validate", "authentik.enterprise.license.LicenseKey.validate",
MagicMock( MagicMock(
return_value=LicenseKey( return_value=LicenseKey(
aud="", aud="",
@ -99,7 +100,7 @@ class TestRACViews(APITestCase):
self.assertIsInstance(response, AccessDeniedResponse) self.assertIsInstance(response, AccessDeniedResponse)
@patch( @patch(
"authentik.enterprise.models.LicenseKey.validate", "authentik.enterprise.license.LicenseKey.validate",
MagicMock( MagicMock(
return_value=LicenseKey( return_value=LicenseKey(
aud="", aud="",

View File

@ -16,3 +16,5 @@ TENANT_APPS = [
"authentik.enterprise.audit", "authentik.enterprise.audit",
"authentik.enterprise.providers.rac", "authentik.enterprise.providers.rac",
] ]
MIDDLEWARE = ["authentik.enterprise.middleware.EnterpriseMiddleware"]

View File

@ -1,6 +1,6 @@
"""Enterprise tasks""" """Enterprise tasks"""
from authentik.enterprise.models import LicenseKey from authentik.enterprise.license import LicenseKey
from authentik.events.models import TaskStatus from authentik.events.models import TaskStatus
from authentik.events.system_tasks import SystemTask, prefill_task from authentik.events.system_tasks import SystemTask, prefill_task
from authentik.root.celery import CELERY_APP from authentik.root.celery import CELERY_APP

View File

@ -8,7 +8,8 @@ from django.test import TestCase
from django.utils.timezone import now from django.utils.timezone import now
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
from authentik.enterprise.models import License, LicenseKey from authentik.enterprise.license import LicenseKey
from authentik.enterprise.models import License
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
_exp = int(mktime((now() + timedelta(days=3000)).timetuple())) _exp = int(mktime((now() + timedelta(days=3000)).timetuple()))
@ -18,7 +19,7 @@ class TestEnterpriseLicense(TestCase):
"""Enterprise license tests""" """Enterprise license tests"""
@patch( @patch(
"authentik.enterprise.models.LicenseKey.validate", "authentik.enterprise.license.LicenseKey.validate",
MagicMock( MagicMock(
return_value=LicenseKey( return_value=LicenseKey(
aud="", aud="",
@ -41,7 +42,7 @@ class TestEnterpriseLicense(TestCase):
License.objects.create(key=generate_id()) License.objects.create(key=generate_id())
@patch( @patch(
"authentik.enterprise.models.LicenseKey.validate", "authentik.enterprise.license.LicenseKey.validate",
MagicMock( MagicMock(
return_value=LicenseKey( return_value=LicenseKey(
aud="", aud="",