enterprise: rework license summary caching (#8501)

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L
2024-02-14 19:00:08 +01:00
committed by GitHub
parent 4733778460
commit 7d527beea8
12 changed files with 313 additions and 202 deletions

View File

@ -39,7 +39,8 @@ from authentik.core.models import (
Source,
UserSourceConnection,
)
from authentik.enterprise.models import LicenseKey, LicenseUsage
from authentik.enterprise.license import LicenseKey
from authentik.enterprise.models import LicenseUsage
from authentik.enterprise.providers.rac.models import ConnectionToken
from authentik.events.models import SystemTask
from authentik.events.utils import cleanse_dict

View File

@ -1,6 +1,7 @@
"""Enterprise API Views"""
from datetime import datetime, timedelta
from dataclasses import asdict
from datetime import timedelta
from django.utils.timezone import now
from django.utils.translation import gettext as _
@ -8,7 +9,7 @@ from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema, inline_serializer
from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError
from rest_framework.fields import BooleanField, CharField, DateTimeField, IntegerField
from rest_framework.fields import CharField, IntegerField
from rest_framework.permissions import IsAuthenticated
from rest_framework.request import Request
from rest_framework.response import Response
@ -19,18 +20,18 @@ from authentik.api.decorators import permission_required
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import PassiveSerializer
from authentik.core.models import User, UserTypes
from authentik.enterprise.models import License, LicenseKey
from authentik.enterprise.license import LicenseKey, LicenseSummarySerializer
from authentik.enterprise.models import License
from authentik.root.install_id import get_install_id
class EnterpriseRequiredMixin:
"""Mixin to validate that a valid enterprise license
exists before allowing to safe the object"""
exists before allowing to save the object"""
def validate(self, attrs: dict) -> dict:
"""Check that a valid license exists"""
total = LicenseKey.get_total()
if not total.is_valid():
if not LicenseKey.cached_summary().valid:
raise ValidationError(_("Enterprise is required to create/update this object."))
return super().validate(attrs)
@ -61,19 +62,6 @@ class LicenseSerializer(ModelSerializer):
}
class LicenseSummary(PassiveSerializer):
"""Serializer for license status"""
internal_users = IntegerField(required=True)
external_users = IntegerField(required=True)
valid = BooleanField()
show_admin_warning = BooleanField()
show_user_warning = BooleanField()
read_only = BooleanField()
latest_valid = DateTimeField()
has_license = BooleanField()
class LicenseForecastSerializer(PassiveSerializer):
"""Serializer for license forecast"""
@ -111,31 +99,13 @@ class LicenseViewSet(UsedByMixin, ModelViewSet):
@extend_schema(
request=OpenApiTypes.NONE,
responses={
200: LicenseSummary(),
200: LicenseSummarySerializer(),
},
)
@action(detail=False, methods=["GET"], permission_classes=[IsAuthenticated])
def summary(self, request: Request) -> Response:
"""Get the total license status"""
total = LicenseKey.get_total()
last_valid = LicenseKey.last_valid_date()
# TODO: move this to a different place?
show_admin_warning = last_valid < now() - timedelta(weeks=2)
show_user_warning = last_valid < now() - timedelta(weeks=4)
read_only = last_valid < now() - timedelta(weeks=6)
latest_valid = datetime.fromtimestamp(total.exp)
response = LicenseSummary(
data={
"internal_users": total.internal_users,
"external_users": total.external_users,
"valid": total.is_valid(),
"show_admin_warning": show_admin_warning,
"show_user_warning": show_user_warning,
"read_only": read_only,
"latest_valid": latest_valid,
"has_license": License.objects.all().count() > 0,
}
)
response = LicenseSummarySerializer(data=asdict(LicenseKey.cached_summary()))
response.is_valid(raise_exception=True)
return Response(response.data)

View File

@ -23,6 +23,6 @@ class AuthentikEnterpriseConfig(EnterpriseConfig):
def check_enabled(self):
"""Actual enterprise check, cached"""
from authentik.enterprise.models import LicenseKey
from authentik.enterprise.license import LicenseKey
return LicenseKey.get_total().is_valid()
return LicenseKey.cached_summary().valid

View File

@ -19,14 +19,10 @@ from authentik.events.utils import cleanse_dict, sanitize_item
class EnterpriseAuditMiddleware(AuditMiddleware):
"""Enterprise audit middleware"""
_enabled = None
@property
def enabled(self):
"""Lazy check if audit logging is enabled"""
if self._enabled is None:
self._enabled = apps.get_app_config("authentik_enterprise").enabled()
return self._enabled
"""Check if audit logging is enabled"""
return apps.get_app_config("authentik_enterprise").enabled()
def connect(self, request: HttpRequest):
super().connect(request)

View File

@ -0,0 +1,213 @@
"""Enterprise license"""
from base64 import b64decode
from binascii import Error
from dataclasses import asdict, dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from functools import lru_cache
from time import mktime
from cryptography.exceptions import InvalidSignature
from cryptography.x509 import Certificate, load_der_x509_certificate, load_pem_x509_certificate
from dacite import from_dict
from django.core.cache import cache
from django.db.models.query import QuerySet
from django.utils.timezone import now
from jwt import PyJWTError, decode, get_unverified_header
from rest_framework.exceptions import ValidationError
from rest_framework.fields import BooleanField, DateTimeField, IntegerField
from authentik.core.api.utils import PassiveSerializer
from authentik.core.models import User, UserTypes
from authentik.enterprise.models import License, LicenseUsage
from authentik.root.install_id import get_install_id
CACHE_KEY_ENTERPRISE_LICENSE = "goauthentik.io/enterprise/license"
CACHE_EXPIRY_ENTERPRISE_LICENSE = 3 * 60 * 60 # 2 Hours
@lru_cache()
def get_licensing_key() -> Certificate:
"""Get Root CA PEM"""
with open("authentik/enterprise/public.pem", "rb") as _key:
return load_pem_x509_certificate(_key.read())
def get_license_aud() -> str:
"""Get the JWT audience field"""
return f"enterprise.goauthentik.io/license/{get_install_id()}"
class LicenseFlags(Enum):
"""License flags"""
@dataclass
class LicenseSummary:
"""Internal representation of a license summary"""
internal_users: int
external_users: int
valid: bool
show_admin_warning: bool
show_user_warning: bool
read_only: bool
latest_valid: datetime
has_license: bool
class LicenseSummarySerializer(PassiveSerializer):
"""Serializer for license status"""
internal_users = IntegerField(required=True)
external_users = IntegerField(required=True)
valid = BooleanField()
show_admin_warning = BooleanField()
show_user_warning = BooleanField()
read_only = BooleanField()
latest_valid = DateTimeField()
has_license = BooleanField()
@dataclass
class LicenseKey:
"""License JWT claims"""
aud: str
exp: int
name: str
internal_users: int = 0
external_users: int = 0
flags: list[LicenseFlags] = field(default_factory=list)
@staticmethod
def validate(jwt: str) -> "LicenseKey":
"""Validate the license from a given JWT"""
try:
headers = get_unverified_header(jwt)
except PyJWTError:
raise ValidationError("Unable to verify license")
x5c: list[str] = headers.get("x5c", [])
if len(x5c) < 1:
raise ValidationError("Unable to verify license")
try:
our_cert = load_der_x509_certificate(b64decode(x5c[0]))
intermediate = load_der_x509_certificate(b64decode(x5c[1]))
our_cert.verify_directly_issued_by(intermediate)
intermediate.verify_directly_issued_by(get_licensing_key())
except (InvalidSignature, TypeError, ValueError, Error):
raise ValidationError("Unable to verify license")
try:
body = from_dict(
LicenseKey,
decode(
jwt,
our_cert.public_key(),
algorithms=["ES512"],
audience=get_license_aud(),
),
)
except PyJWTError:
raise ValidationError("Unable to verify license")
return body
@staticmethod
def get_total() -> "LicenseKey":
"""Get a summarized version of all (not expired) licenses"""
active_licenses = License.objects.filter(expiry__gte=now())
total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0)
for lic in active_licenses:
total.internal_users += lic.internal_users
total.external_users += lic.external_users
exp_ts = int(mktime(lic.expiry.timetuple()))
if total.exp == 0:
total.exp = exp_ts
if exp_ts <= total.exp:
total.exp = exp_ts
total.flags.extend(lic.status.flags)
return total
@staticmethod
def base_user_qs() -> QuerySet:
"""Base query set for all users"""
return User.objects.all().exclude_anonymous().exclude(is_active=False)
@staticmethod
def get_default_user_count():
"""Get current default user count"""
return LicenseKey.base_user_qs().filter(type=UserTypes.INTERNAL).count()
@staticmethod
def get_external_user_count():
"""Get current external user count"""
# Count since start of the month
last_month = now().replace(day=1)
return (
LicenseKey.base_user_qs()
.filter(type=UserTypes.EXTERNAL, last_login__gte=last_month)
.count()
)
def is_valid(self) -> bool:
"""Check if the given license body covers all users
Only checks the current count, no historical data is checked"""
default_users = self.get_default_user_count()
if default_users > self.internal_users:
return False
active_users = self.get_external_user_count()
if active_users > self.external_users:
return False
return True
def record_usage(self):
"""Capture the current validity status and metrics and save them"""
threshold = now() - timedelta(hours=8)
if not LicenseUsage.objects.filter(record_date__gte=threshold).exists():
LicenseUsage.objects.create(
user_count=self.get_default_user_count(),
external_user_count=self.get_external_user_count(),
within_limits=self.is_valid(),
)
summary = asdict(self.summary())
# Also cache the latest summary for the middleware
cache.set(CACHE_KEY_ENTERPRISE_LICENSE, summary, timeout=CACHE_EXPIRY_ENTERPRISE_LICENSE)
return summary
@staticmethod
def last_valid_date() -> datetime:
"""Get the last date the license was valid"""
usage: LicenseUsage = (
LicenseUsage.filter_not_expired(within_limits=True).order_by("-record_date").first()
)
if not usage:
return now()
return usage.record_date
def summary(self) -> LicenseSummary:
"""Summary of license status"""
last_valid = LicenseKey.last_valid_date()
show_admin_warning = last_valid < now() - timedelta(weeks=2)
show_user_warning = last_valid < now() - timedelta(weeks=4)
read_only = last_valid < now() - timedelta(weeks=6)
latest_valid = datetime.fromtimestamp(self.exp)
return LicenseSummary(
show_admin_warning=show_admin_warning,
show_user_warning=show_user_warning,
read_only=read_only,
latest_valid=latest_valid,
internal_users=self.internal_users,
external_users=self.external_users,
valid=self.is_valid(),
has_license=License.objects.all().count() > 0,
)
@staticmethod
def cached_summary() -> LicenseSummary:
"""Helper method which looks up the last summary"""
summary = cache.get(CACHE_KEY_ENTERPRISE_LICENSE)
if not summary:
return LicenseKey.get_total().summary()
return from_dict(LicenseSummary, summary)

View File

@ -0,0 +1,64 @@
"""Enterprise middleware"""
from collections.abc import Callable
from django.http import HttpRequest, HttpResponse, JsonResponse
from django.urls import resolve
from structlog.stdlib import BoundLogger, get_logger
from authentik.enterprise.api import LicenseViewSet
from authentik.enterprise.license import LicenseKey
from authentik.flows.views.executor import FlowExecutorView
from authentik.lib.utils.reflection import class_to_path
class EnterpriseMiddleware:
"""Enterprise middleware"""
get_response: Callable[[HttpRequest], HttpResponse]
logger: BoundLogger
def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]):
self.get_response = get_response
self.logger = get_logger().bind()
def __call__(self, request: HttpRequest) -> HttpResponse:
resolver_match = resolve(request.path_info)
request.resolver_match = resolver_match
if not self.is_request_allowed(request):
self.logger.warning("Refusing request due to expired/invalid license")
return JsonResponse(
{
"detail": "Request denied due to expired/invalid license.",
"code": "denied_license",
},
status=400,
)
return self.get_response(request)
def is_request_allowed(self, request: HttpRequest) -> bool:
"""Check if a specific request is allowed"""
if self.is_request_always_allowed(request):
return True
cached_status = LicenseKey.cached_summary()
if not cached_status:
return True
if cached_status.read_only:
return False
return True
def is_request_always_allowed(self, request: HttpRequest):
"""Check if a request is always allowed"""
# Always allow "safe" methods
if request.method.lower() in ["get", "head", "options", "trace"]:
return True
# Always allow requests to manage licenses
if class_to_path(request.resolver_match.func) == class_to_path(LicenseViewSet):
return True
# Flow executor is mounted as an API path but explicitly allowed
if class_to_path(request.resolver_match.func) == class_to_path(FlowExecutorView):
return True
# Only apply these restrictions to the API
if "authentik_api" not in request.resolver_match.app_names:
return True
return False

View File

@ -1,159 +1,20 @@
"""Enterprise models"""
from base64 import b64decode
from binascii import Error
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from functools import lru_cache
from time import mktime
from datetime import timedelta
from typing import TYPE_CHECKING
from uuid import uuid4
from cryptography.exceptions import InvalidSignature
from cryptography.x509 import Certificate, load_der_x509_certificate, load_pem_x509_certificate
from dacite import from_dict
from django.contrib.postgres.indexes import HashIndex
from django.db import models
from django.db.models.query import QuerySet
from django.utils.timezone import now
from django.utils.translation import gettext as _
from jwt import PyJWTError, decode, get_unverified_header
from rest_framework.exceptions import ValidationError
from rest_framework.serializers import BaseSerializer
from authentik.core.models import ExpiringModel, User, UserTypes
from authentik.core.models import ExpiringModel
from authentik.lib.models import SerializerModel
from authentik.root.install_id import get_install_id
@lru_cache()
def get_licensing_key() -> Certificate:
"""Get Root CA PEM"""
with open("authentik/enterprise/public.pem", "rb") as _key:
return load_pem_x509_certificate(_key.read())
def get_license_aud() -> str:
"""Get the JWT audience field"""
return f"enterprise.goauthentik.io/license/{get_install_id()}"
class LicenseFlags(Enum):
"""License flags"""
@dataclass
class LicenseKey:
"""License JWT claims"""
aud: str
exp: int
name: str
internal_users: int = 0
external_users: int = 0
flags: list[LicenseFlags] = field(default_factory=list)
@staticmethod
def validate(jwt: str) -> "LicenseKey":
"""Validate the license from a given JWT"""
try:
headers = get_unverified_header(jwt)
except PyJWTError:
raise ValidationError("Unable to verify license")
x5c: list[str] = headers.get("x5c", [])
if len(x5c) < 1:
raise ValidationError("Unable to verify license")
try:
our_cert = load_der_x509_certificate(b64decode(x5c[0]))
intermediate = load_der_x509_certificate(b64decode(x5c[1]))
our_cert.verify_directly_issued_by(intermediate)
intermediate.verify_directly_issued_by(get_licensing_key())
except (InvalidSignature, TypeError, ValueError, Error):
raise ValidationError("Unable to verify license")
try:
body = from_dict(
LicenseKey,
decode(
jwt,
our_cert.public_key(),
algorithms=["ES512"],
audience=get_license_aud(),
),
)
except PyJWTError:
raise ValidationError("Unable to verify license")
return body
@staticmethod
def get_total() -> "LicenseKey":
"""Get a summarized version of all (not expired) licenses"""
active_licenses = License.objects.filter(expiry__gte=now())
total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0)
for lic in active_licenses:
total.internal_users += lic.internal_users
total.external_users += lic.external_users
exp_ts = int(mktime(lic.expiry.timetuple()))
if total.exp == 0:
total.exp = exp_ts
if exp_ts <= total.exp:
total.exp = exp_ts
total.flags.extend(lic.status.flags)
return total
@staticmethod
def base_user_qs() -> QuerySet:
"""Base query set for all users"""
return User.objects.all().exclude_anonymous().exclude(is_active=False)
@staticmethod
def get_default_user_count():
"""Get current default user count"""
return LicenseKey.base_user_qs().filter(type=UserTypes.INTERNAL).count()
@staticmethod
def get_external_user_count():
"""Get current external user count"""
# Count since start of the month
last_month = now().replace(day=1)
return (
LicenseKey.base_user_qs()
.filter(type=UserTypes.EXTERNAL, last_login__gte=last_month)
.count()
)
def is_valid(self) -> bool:
"""Check if the given license body covers all users
Only checks the current count, no historical data is checked"""
default_users = self.get_default_user_count()
if default_users > self.internal_users:
return False
active_users = self.get_external_user_count()
if active_users > self.external_users:
return False
return True
def record_usage(self):
"""Capture the current validity status and metrics and save them"""
threshold = now() - timedelta(hours=8)
if LicenseUsage.objects.filter(record_date__gte=threshold).exists():
return
LicenseUsage.objects.create(
user_count=self.get_default_user_count(),
external_user_count=self.get_external_user_count(),
within_limits=self.is_valid(),
)
@staticmethod
def last_valid_date() -> datetime:
"""Get the last date the license was valid"""
usage: LicenseUsage = (
LicenseUsage.filter_not_expired(within_limits=True).order_by("-record_date").first()
)
if not usage:
return now()
return usage.record_date
if TYPE_CHECKING:
from authentik.enterprise.license import LicenseKey
class License(SerializerModel):
@ -174,8 +35,10 @@ class License(SerializerModel):
return LicenseSerializer
@property
def status(self) -> LicenseKey:
def status(self) -> "LicenseKey":
"""Get parsed license status"""
from authentik.enterprise.license import LicenseKey
return LicenseKey.validate(self.key)
class Meta:

View File

@ -5,7 +5,7 @@ from typing import Optional
from django.utils.translation import gettext_lazy as _
from authentik.core.models import User, UserTypes
from authentik.enterprise.models import LicenseKey
from authentik.enterprise.license import LicenseKey
from authentik.policies.types import PolicyRequest, PolicyResult
from authentik.policies.views import PolicyAccessView

View File

@ -11,7 +11,8 @@ from rest_framework.test import APITestCase
from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.enterprise.models import License, LicenseKey
from authentik.enterprise.license import LicenseKey
from authentik.enterprise.models import License
from authentik.enterprise.providers.rac.models import Endpoint, Protocols, RACProvider
from authentik.lib.generators import generate_id
from authentik.policies.denied import AccessDeniedResponse
@ -39,7 +40,7 @@ class TestRACViews(APITestCase):
)
@patch(
"authentik.enterprise.models.LicenseKey.validate",
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
@ -70,7 +71,7 @@ class TestRACViews(APITestCase):
self.assertEqual(final_response.status_code, 200)
@patch(
"authentik.enterprise.models.LicenseKey.validate",
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
@ -99,7 +100,7 @@ class TestRACViews(APITestCase):
self.assertIsInstance(response, AccessDeniedResponse)
@patch(
"authentik.enterprise.models.LicenseKey.validate",
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",

View File

@ -16,3 +16,5 @@ TENANT_APPS = [
"authentik.enterprise.audit",
"authentik.enterprise.providers.rac",
]
MIDDLEWARE = ["authentik.enterprise.middleware.EnterpriseMiddleware"]

View File

@ -1,6 +1,6 @@
"""Enterprise tasks"""
from authentik.enterprise.models import LicenseKey
from authentik.enterprise.license import LicenseKey
from authentik.events.models import TaskStatus
from authentik.events.system_tasks import SystemTask, prefill_task
from authentik.root.celery import CELERY_APP

View File

@ -8,7 +8,8 @@ from django.test import TestCase
from django.utils.timezone import now
from rest_framework.exceptions import ValidationError
from authentik.enterprise.models import License, LicenseKey
from authentik.enterprise.license import LicenseKey
from authentik.enterprise.models import License
from authentik.lib.generators import generate_id
_exp = int(mktime((now() + timedelta(days=3000)).timetuple()))
@ -18,7 +19,7 @@ class TestEnterpriseLicense(TestCase):
"""Enterprise license tests"""
@patch(
"authentik.enterprise.models.LicenseKey.validate",
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
@ -41,7 +42,7 @@ class TestEnterpriseLicense(TestCase):
License.objects.create(key=generate_id())
@patch(
"authentik.enterprise.models.LicenseKey.validate",
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",