Files
authentik/authentik/enterprise/license.py
gcp-cherry-pick-bot[bot] 53143e0c40 enterprise: fix expired license's users being counted (cherry-pick #14451) (#14496)
enterprise: fix expired license's users being counted (#14451)

* enterprise: fix expired license's users being counted



* tests to the rescue



* hmm



---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
Co-authored-by: Jens L. <jens@goauthentik.io>
2025-05-13 16:07:08 +02:00

240 lines
8.6 KiB
Python

"""Enterprise license"""
from base64 import b64decode
from binascii import Error
from dataclasses import asdict, dataclass, field
from datetime import UTC, datetime, timedelta
from enum import Enum
from functools import lru_cache
from time import mktime
from cryptography.exceptions import InvalidSignature
from cryptography.x509 import Certificate, load_der_x509_certificate, load_pem_x509_certificate
from dacite import DaciteError, from_dict
from django.core.cache import cache
from django.db.models.query import QuerySet
from django.utils.timezone import now
from jwt import PyJWTError, decode, get_unverified_header
from rest_framework.exceptions import ValidationError
from rest_framework.fields import (
ChoiceField,
DateTimeField,
IntegerField,
ListField,
)
from authentik.core.api.utils import PassiveSerializer
from authentik.core.models import User, UserTypes
from authentik.enterprise.models import (
THRESHOLD_READ_ONLY_WEEKS,
THRESHOLD_WARNING_ADMIN_WEEKS,
THRESHOLD_WARNING_EXPIRY_WEEKS,
THRESHOLD_WARNING_USER_WEEKS,
License,
LicenseUsage,
LicenseUsageStatus,
)
from authentik.tenants.utils import get_unique_identifier
# Cache key under which the most recent license summary is stored
CACHE_KEY_ENTERPRISE_LICENSE = "goauthentik.io/enterprise/license"
# Passed as the cache ``timeout`` when the summary is stored
CACHE_EXPIRY_ENTERPRISE_LICENSE = 3 * 60 * 60  # 3 hours
@lru_cache
def get_licensing_key() -> Certificate:
    """Load the root CA certificate used to verify license chains.

    The PEM file is opened via a path relative to the current working directory.
    The parsed certificate is memoised by ``lru_cache``, so the file is only read
    on the first call.
    """
    with open("authentik/enterprise/public.pem", "rb") as pem_file:
        pem_bytes = pem_file.read()
    return load_pem_x509_certificate(pem_bytes)
def get_license_aud() -> str:
    """Return the JWT audience value expected in licenses for this installation.

    The audience is a fixed prefix followed by the value of ``get_unique_identifier()``.
    """
    unique_identifier = get_unique_identifier()
    return f"enterprise.goauthentik.io/license/{unique_identifier}"
class LicenseFlags(Enum):
    """Flags that may be attached to a license.

    The member values are the strings offered as choices by the license summary
    serializer.
    """

    TRIAL = "trial"
    NON_PRODUCTION = "non_production"
@dataclass
class LicenseSummary:
    """Internal, cache-friendly snapshot of the overall license state.

    Built by ``LicenseKey.summary()``; stored in the cache as a plain dict and
    rebuilt from it by ``LicenseKey.cached_summary()``.
    """

    # Licensed user counts (not the live counts)
    internal_users: int
    external_users: int
    # Result of ``LicenseKey.status()`` at the time the summary was built
    status: LicenseUsageStatus
    # Derived from the license expiry timestamp
    latest_valid: datetime
    license_flags: list[LicenseFlags]
class LicenseSummarySerializer(PassiveSerializer):
    """Serializer exposing the fields of a ``LicenseSummary``"""

    internal_users = IntegerField(required=True)
    external_users = IntegerField(required=True)
    status = ChoiceField(choices=LicenseUsageStatus.choices)
    latest_valid = DateTimeField()
    # Each list entry must be one of the LicenseFlags values
    license_flags = ListField(
        child=ChoiceField(choices=tuple(flag.value for flag in LicenseFlags)),
    )
@dataclass
class LicenseKey:
    """License JWT claims.

    Instances are either decoded from a license JWT by ``validate`` or built by
    ``get_total`` as the aggregate over all installed licenses.
    """

    # JWT audience the license was issued for
    aud: str
    # Expiry as a POSIX timestamp; 0 is used as "no expiry known"
    exp: int
    name: str
    # Licensed user counts, compared against the live user counts in ``status``
    internal_users: int = 0
    external_users: int = 0
    license_flags: list[LicenseFlags] = field(default_factory=list)

    @staticmethod
    def validate(jwt: str, check_expiry=True) -> "LicenseKey":
        """Validate the license from a given JWT.

        The JWT header must carry an ``x5c`` chain whose first entry is the signing
        certificate and whose second entry is the intermediate that issued it. The
        intermediate must itself be issued by the certificate returned from
        ``get_licensing_key()``.

        Args:
            jwt: The encoded license JWT.
            check_expiry: Forwarded to the JWT decoder as both ``verify_exp`` and
                ``verify_signature``.

        Returns:
            The decoded claims as a ``LicenseKey``.

        Raises:
            ValidationError: If the header, certificate chain or claims cannot be
                verified, or if the audience does not match ``get_license_aud()``.
        """
        try:
            headers = get_unverified_header(jwt)
        except PyJWTError:
            raise ValidationError("Unable to verify license") from None
        x5c = headers.get("x5c", [])
        # Both the leaf (index 0) and the intermediate (index 1) are read below, so
        # anything shorter than two entries must be rejected here; otherwise the
        # lookup of x5c[1] raises an IndexError that is not handled.
        if not isinstance(x5c, list) or len(x5c) < 2:  # noqa: PLR2004
            raise ValidationError("Unable to verify license")
        try:
            our_cert = load_der_x509_certificate(b64decode(x5c[0]))
            intermediate = load_der_x509_certificate(b64decode(x5c[1]))
            our_cert.verify_directly_issued_by(intermediate)
            intermediate.verify_directly_issued_by(get_licensing_key())
        except (InvalidSignature, TypeError, ValueError, Error):
            raise ValidationError("Unable to verify license") from None
        try:
            # NOTE(review): verify_signature is tied to check_expiry, so with
            # check_expiry=False the JWT signature itself is not verified (only the
            # certificate chain above is). Confirm this is intended.
            claims = decode(
                jwt,
                our_cert.public_key(),
                algorithms=["ES512"],
                audience=get_license_aud(),
                options={"verify_exp": check_expiry, "verify_signature": check_expiry},
            )
            body = from_dict(LicenseKey, claims)
        except PyJWTError:
            # Look at the unverified claims only to give a more helpful message when
            # the license was issued for a different installation.
            message = "Unable to verify license"
            try:
                unverified = decode(jwt, options={"verify_signature": False})
                if unverified["aud"] != get_license_aud():
                    message = "Invalid Install ID in license"
            except (PyJWTError, KeyError):
                # Payload is undecodable or has no audience: keep the generic message
                pass
            raise ValidationError(message) from None
        except DaciteError:
            # Claims decoded but do not match the LicenseKey fields
            raise ValidationError("Unable to verify license") from None
        return body

    @staticmethod
    def get_total() -> "LicenseKey":
        """Get a summarized version of all (not expired) licenses.

        User counts and flags are summed over licenses whose ``is_valid`` is true.
        The expiry of every license is considered, and the latest one is kept, so
        ``exp`` is only 0 when no license contributed an expiry.
        """
        total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0)
        for lic in License.objects.all():
            if lic.is_valid:
                total.internal_users += lic.internal_users
                total.external_users += lic.external_users
                total.license_flags.extend(lic.status.license_flags)
            # datetime.timestamp() honours the datetime's own timezone, whereas
            # mktime(timetuple()) would reinterpret the fields as server-local time.
            exp_ts = int(lic.expiry.timestamp())
            if total.exp == 0:
                total.exp = exp_ts
            total.exp = max(total.exp, exp_ts)
        return total

    @staticmethod
    def base_user_qs() -> QuerySet:
        """Base query set for all users: active and not anonymous"""
        return User.objects.all().exclude_anonymous().exclude(is_active=False)

    @staticmethod
    def get_internal_user_count() -> int:
        """Get current internal user count"""
        return LicenseKey.base_user_qs().filter(type=UserTypes.INTERNAL).count()

    @staticmethod
    def get_external_user_count() -> int:
        """Get current external user count"""
        return LicenseKey.base_user_qs().filter(type=UserTypes.EXTERNAL).count()

    def _last_valid_date(self):
        """Record date of the newest usage record with status VALID.

        Falls back to the UNIX epoch (UTC) when no such record exists.
        """
        latest_valid_usage = (
            LicenseUsage.objects.order_by("-record_date")
            .filter(status=LicenseUsageStatus.VALID)
            .first()
        )
        if not latest_valid_usage:
            return datetime.fromtimestamp(0, UTC)
        return latest_valid_usage.record_date

    def status(self) -> LicenseUsageStatus:
        """Check if the given license body covers all users, and is valid.

        Checks are applied in order: unlicensed, user limit exceeded (graded by how
        long ago usage was last valid), expired, expiring soon, valid.
        """
        if self.exp == 0 and not License.objects.exists():
            return LicenseUsageStatus.UNLICENSED
        _now = now()
        # Check limit-exceeded based status
        internal_users = self.get_internal_user_count()
        external_users = self.get_external_user_count()
        if internal_users > self.internal_users or external_users > self.external_users:
            # Only needed (and only queried) when a limit is actually exceeded
            last_valid = self._last_valid_date()
            if last_valid < _now - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS):
                return LicenseUsageStatus.READ_ONLY
            if last_valid < _now - timedelta(weeks=THRESHOLD_WARNING_USER_WEEKS):
                return LicenseUsageStatus.LIMIT_EXCEEDED_USER
            if last_valid < _now - timedelta(weeks=THRESHOLD_WARNING_ADMIN_WEEKS):
                return LicenseUsageStatus.LIMIT_EXCEEDED_ADMIN
            # Exceeded only recently: fall through to the expiry checks
        expiry = datetime.fromtimestamp(self.exp, UTC)
        # Check expiry based status
        if expiry < _now:
            if expiry < _now - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS):
                return LicenseUsageStatus.READ_ONLY
            return LicenseUsageStatus.EXPIRED
        # Expiry warning
        if expiry <= _now + timedelta(weeks=THRESHOLD_WARNING_EXPIRY_WEEKS):
            return LicenseUsageStatus.EXPIRY_SOON
        return LicenseUsageStatus.VALID

    def record_usage(self):
        """Capture the current validity status and metrics and save them.

        A new usage record is only created when none exists from the last 8 hours;
        the cached summary is refreshed on every call.

        Returns:
            The existing recent usage record, or the newly created one.
        """
        threshold = now() - timedelta(hours=8)
        usage = (
            LicenseUsage.objects.order_by("-record_date").filter(record_date__gte=threshold).first()
        )
        if not usage:
            usage = LicenseUsage.objects.create(
                internal_user_count=self.get_internal_user_count(),
                external_user_count=self.get_external_user_count(),
                status=self.status(),
            )
        summary = asdict(self.summary())
        # Also cache the latest summary for the middleware
        cache.set(CACHE_KEY_ENTERPRISE_LICENSE, summary, timeout=CACHE_EXPIRY_ENTERPRISE_LICENSE)
        return usage

    def summary(self) -> LicenseSummary:
        """Summary of license status"""
        status = self.status()
        # NOTE(review): without a tz argument this is a naive local-time datetime,
        # unlike the UTC-aware values used in status(); confirm consumers expect that.
        latest_valid = datetime.fromtimestamp(self.exp)
        return LicenseSummary(
            latest_valid=latest_valid,
            internal_users=self.internal_users,
            external_users=self.external_users,
            status=status,
            license_flags=self.license_flags,
        )

    @staticmethod
    def cached_summary() -> LicenseSummary:
        """Helper method which looks up the last summary.

        Falls back to computing a fresh summary from ``get_total()`` when nothing is
        cached, or when the cached dict can no longer be converted (in which case the
        cache entry is deleted first).
        """
        summary = cache.get(CACHE_KEY_ENTERPRISE_LICENSE)
        if not summary:
            return LicenseKey.get_total().summary()
        try:
            return from_dict(LicenseSummary, summary)
        except DaciteError:
            cache.delete(CACHE_KEY_ENTERPRISE_LICENSE)
            return LicenseKey.get_total().summary()