From 4b5bb77d99c8bd025bb7c9dd6f9979218940c465 Mon Sep 17 00:00:00 2001 From: "Jens L." Date: Fri, 9 Aug 2024 14:26:38 +0200 Subject: [PATCH] enterprise: UI improvements, better handling of expiry (#10828) * web/admin: show enterprise banner on the very top Signed-off-by: Jens Langhammer * rework license Signed-off-by: Jens Langhammer * fix a bunch of things Signed-off-by: Jens Langhammer * add some more tests Signed-off-by: Jens Langhammer * add more tests Signed-off-by: Jens Langhammer * fix middleware Signed-off-by: Jens Langhammer * better api Signed-off-by: Jens Langhammer * format Signed-off-by: Jens Langhammer * add tests for and fix read only mode Signed-off-by: Jens Langhammer * field name consistency Signed-off-by: Jens Langhammer --------- Signed-off-by: Jens Langhammer --- authentik/admin/api/system.py | 2 +- authentik/blueprints/v1/importer.py | 2 +- authentik/enterprise/api.py | 6 +- authentik/enterprise/apps.py | 2 +- authentik/enterprise/license.py | 131 ++++++----- authentik/enterprise/middleware.py | 7 +- ...ove_licenseusage_within_limits_and_more.py | 68 ++++++ authentik/enterprise/models.py | 37 ++- authentik/enterprise/policy.py | 2 +- authentik/enterprise/tests/test_license.py | 209 ++++++++++++++++- authentik/enterprise/tests/test_read_only.py | 217 ++++++++++++++++++ authentik/outposts/api/outposts.py | 2 +- blueprints/schema.json | 56 ++--- schema.yml | 26 +-- .../admin/AdminInterface/AdminInterface.ts | 7 + .../enterprise/EnterpriseLicenseListPage.ts | 3 +- .../Interface/licenseSummaryProvider.ts | 4 +- web/src/elements/PageHeader.ts | 113 +++++---- .../enterprise/EnterpriseStatusBanner.ts | 48 +++- web/src/elements/sidebar/Sidebar.ts | 1 - 20 files changed, 749 insertions(+), 194 deletions(-) create mode 100644 authentik/enterprise/migrations/0003_remove_licenseusage_within_limits_and_more.py create mode 100644 authentik/enterprise/tests/test_read_only.py diff --git a/authentik/admin/api/system.py b/authentik/admin/api/system.py index ac9df17e71..1e119e5fbc 100644 --- a/authentik/admin/api/system.py +++ b/authentik/admin/api/system.py @@ -73,7 +73,7 @@ class SystemInfoSerializer(PassiveSerializer): "authentik_version": get_full_version(), "environment": get_env(), "openssl_fips_enabled": ( - backend._fips_enabled if LicenseKey.get_total().is_valid() else None + backend._fips_enabled if LicenseKey.get_total().status().is_valid else None ), "openssl_version": OPENSSL_VERSION, "platform": platform.platform(), diff --git a/authentik/blueprints/v1/importer.py b/authentik/blueprints/v1/importer.py index 2143fde053..08d7976932 100644 --- a/authentik/blueprints/v1/importer.py +++ b/authentik/blueprints/v1/importer.py @@ -171,7 +171,7 @@ class Importer: def default_context(self): """Default context""" return { - "goauthentik.io/enterprise/licensed": LicenseKey.get_total().is_valid(), + "goauthentik.io/enterprise/licensed": LicenseKey.get_total().status().is_valid, "goauthentik.io/rbac/models": rbac_models(), } diff --git a/authentik/enterprise/api.py b/authentik/enterprise/api.py index 9f66cd0653..510fc378f4 100644 --- a/authentik/enterprise/api.py +++ b/authentik/enterprise/api.py @@ -19,7 +19,7 @@ from authentik.core.api.used_by import UsedByMixin from authentik.core.api.utils import ModelSerializer, PassiveSerializer from authentik.core.models import User, UserTypes from authentik.enterprise.license import LicenseKey, LicenseSummarySerializer -from authentik.enterprise.models import License +from authentik.enterprise.models import License, LicenseUsageStatus from authentik.rbac.decorators import permission_required from authentik.tenants.utils import get_unique_identifier @@ -30,7 +30,7 @@ class EnterpriseRequiredMixin: def validate(self, attrs: dict) -> dict: """Check that a valid license exists""" - if not LicenseKey.cached_summary().has_license: + if LicenseKey.cached_summary().status != LicenseUsageStatus.UNLICENSED: raise ValidationError(_("Enterprise is required to create/update this object.")) return super().validate(attrs) @@ -128,7 +128,7 @@ class LicenseViewSet(UsedByMixin, ModelViewSet): forecast_for_months = 12 response = LicenseForecastSerializer( data={ - "internal_users": LicenseKey.get_default_user_count(), + "internal_users": LicenseKey.get_internal_user_count(), "external_users": LicenseKey.get_external_user_count(), "forecasted_internal_users": (internal_in_last_month * forecast_for_months), "forecasted_external_users": (external_in_last_month * forecast_for_months), diff --git a/authentik/enterprise/apps.py b/authentik/enterprise/apps.py index 83dbefa06a..e62f317a7c 100644 --- a/authentik/enterprise/apps.py +++ b/authentik/enterprise/apps.py @@ -25,4 +25,4 @@ class AuthentikEnterpriseConfig(EnterpriseConfig): """Actual enterprise check, cached""" from authentik.enterprise.license import LicenseKey - return LicenseKey.cached_summary().valid + return LicenseKey.cached_summary().status diff --git a/authentik/enterprise/license.py b/authentik/enterprise/license.py index 1d2062811e..3ce789e34c 100644 --- a/authentik/enterprise/license.py +++ b/authentik/enterprise/license.py @@ -3,24 +3,36 @@ from base64 import b64decode from binascii import Error from dataclasses import asdict, dataclass, field -from datetime import datetime, timedelta +from datetime import UTC, datetime, timedelta from enum import Enum from functools import lru_cache from time import mktime from cryptography.exceptions import InvalidSignature from cryptography.x509 import Certificate, load_der_x509_certificate, load_pem_x509_certificate -from dacite import from_dict +from dacite import DaciteError, from_dict from django.core.cache import cache from django.db.models.query import QuerySet from django.utils.timezone import now from jwt import PyJWTError, decode, get_unverified_header from rest_framework.exceptions import ValidationError -from rest_framework.fields import BooleanField, DateTimeField, IntegerField +from rest_framework.fields import ( + ChoiceField, + DateTimeField, + IntegerField, +) from authentik.core.api.utils import PassiveSerializer from authentik.core.models import User, UserTypes -from authentik.enterprise.models import License, LicenseUsage +from authentik.enterprise.models import ( + THRESHOLD_READ_ONLY_WEEKS, + THRESHOLD_WARNING_ADMIN_WEEKS, + THRESHOLD_WARNING_EXPIRY_WEEKS, + THRESHOLD_WARNING_USER_WEEKS, + License, + LicenseUsage, + LicenseUsageStatus, +) from authentik.tenants.utils import get_unique_identifier CACHE_KEY_ENTERPRISE_LICENSE = "goauthentik.io/enterprise/license" @@ -42,6 +54,8 @@ def get_license_aud() -> str: class LicenseFlags(Enum): """License flags""" + TRIAL = "trial" + @dataclass class LicenseSummary: @@ -49,12 +63,8 @@ class LicenseSummary: internal_users: int external_users: int - valid: bool - show_admin_warning: bool - show_user_warning: bool - read_only: bool + status: LicenseUsageStatus latest_valid: datetime - has_license: bool class LicenseSummarySerializer(PassiveSerializer): @@ -62,12 +72,8 @@ class LicenseSummarySerializer(PassiveSerializer): internal_users = IntegerField(required=True) external_users = IntegerField(required=True) - valid = BooleanField() - show_admin_warning = BooleanField() - show_user_warning = BooleanField() - read_only = BooleanField() + status = ChoiceField(choices=LicenseUsageStatus.choices) latest_valid = DateTimeField() - has_license = BooleanField() @dataclass @@ -83,7 +89,7 @@ class LicenseKey: flags: list[LicenseFlags] = field(default_factory=list) @staticmethod - def validate(jwt: str) -> "LicenseKey": + def validate(jwt: str, check_expiry=True) -> "LicenseKey": """Validate the license from a given JWT""" try: headers = get_unverified_header(jwt) @@ -107,6 +113,7 @@ class LicenseKey: our_cert.public_key(), algorithms=["ES512"], audience=get_license_aud(), + options={"verify_exp": check_expiry}, ), ) except PyJWTError: @@ -116,9 +123,8 @@ class LicenseKey: @staticmethod def get_total() -> "LicenseKey": """Get a summarized version of all (not expired) licenses""" - active_licenses = License.objects.filter(expiry__gte=now()) total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0) - for lic in active_licenses: + for lic in License.objects.all(): total.internal_users += lic.internal_users total.external_users += lic.external_users exp_ts = int(mktime(lic.expiry.timetuple())) @@ -135,7 +141,7 @@ class LicenseKey: return User.objects.all().exclude_anonymous().exclude(is_active=False) @staticmethod - def get_default_user_count(): + def get_internal_user_count(): """Get current default user count""" return LicenseKey.base_user_qs().filter(type=UserTypes.INTERNAL).count() @@ -144,59 +150,72 @@ class LicenseKey: """Get current external user count""" return LicenseKey.base_user_qs().filter(type=UserTypes.EXTERNAL).count() - def is_valid(self) -> bool: - """Check if the given license body covers all users + def _last_valid_date(self): + last_valid_date = ( + LicenseUsage.objects.order_by("-record_date") + .filter(status=LicenseUsageStatus.VALID) + .first() + ) + if not last_valid_date: + return datetime.fromtimestamp(0, UTC) + return last_valid_date.record_date - Only checks the current count, no historical data is checked""" - default_users = self.get_default_user_count() - if default_users > self.internal_users: - return False - active_users = self.get_external_user_count() - if active_users > self.external_users: - return False - return True + def status(self) -> LicenseUsageStatus: + """Check if the given license body covers all users, and is valid.""" + last_valid = self._last_valid_date() + if self.exp == 0 and not License.objects.exists(): + return LicenseUsageStatus.UNLICENSED + _now = now() + # Check limit-exceeded based status + internal_users = self.get_internal_user_count() + external_users = self.get_external_user_count() + if internal_users > self.internal_users or external_users > self.external_users: + if last_valid < _now - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS): + return LicenseUsageStatus.READ_ONLY + if last_valid < _now - timedelta(weeks=THRESHOLD_WARNING_USER_WEEKS): + return LicenseUsageStatus.LIMIT_EXCEEDED_USER + if last_valid < _now - timedelta(weeks=THRESHOLD_WARNING_ADMIN_WEEKS): + return LicenseUsageStatus.LIMIT_EXCEEDED_ADMIN + # Check expiry based status + if datetime.fromtimestamp(self.exp, UTC) < _now: + if datetime.fromtimestamp(self.exp, UTC) < _now - timedelta( + weeks=THRESHOLD_READ_ONLY_WEEKS + ): + return LicenseUsageStatus.READ_ONLY + return LicenseUsageStatus.EXPIRED + # Expiry warning + if datetime.fromtimestamp(self.exp, UTC) <= _now + timedelta( + weeks=THRESHOLD_WARNING_EXPIRY_WEEKS + ): + return LicenseUsageStatus.EXPIRY_SOON + return LicenseUsageStatus.VALID def record_usage(self): """Capture the current validity status and metrics and save them""" threshold = now() - timedelta(hours=8) - if not LicenseUsage.objects.filter(record_date__gte=threshold).exists(): - LicenseUsage.objects.create( - user_count=self.get_default_user_count(), + usage = ( + LicenseUsage.objects.order_by("-record_date").filter(record_date__gte=threshold).first() + ) + if not usage: + usage = LicenseUsage.objects.create( + internal_user_count=self.get_internal_user_count(), external_user_count=self.get_external_user_count(), - within_limits=self.is_valid(), + status=self.status(), ) summary = asdict(self.summary()) # Also cache the latest summary for the middleware cache.set(CACHE_KEY_ENTERPRISE_LICENSE, summary, timeout=CACHE_EXPIRY_ENTERPRISE_LICENSE) - return summary - - @staticmethod - def last_valid_date() -> datetime: - """Get the last date the license was valid""" - usage: LicenseUsage = ( - LicenseUsage.filter_not_expired(within_limits=True).order_by("-record_date").first() - ) - if not usage: - return now() - return usage.record_date + return usage def summary(self) -> LicenseSummary: """Summary of license status""" - has_license = License.objects.all().count() > 0 - last_valid = LicenseKey.last_valid_date() - show_admin_warning = last_valid < now() - timedelta(weeks=2) - show_user_warning = last_valid < now() - timedelta(weeks=4) - read_only = last_valid < now() - timedelta(weeks=6) + status = self.status() latest_valid = datetime.fromtimestamp(self.exp) return LicenseSummary( - show_admin_warning=show_admin_warning and has_license, - show_user_warning=show_user_warning and has_license, - read_only=read_only and has_license, latest_valid=latest_valid, internal_users=self.internal_users, external_users=self.external_users, - valid=self.is_valid(), - has_license=has_license, + status=status, ) @staticmethod @@ -205,4 +224,8 @@ class LicenseKey: summary = cache.get(CACHE_KEY_ENTERPRISE_LICENSE) if not summary: return LicenseKey.get_total().summary() - return from_dict(LicenseSummary, summary) + try: + return from_dict(LicenseSummary, summary) + except DaciteError: + cache.delete(CACHE_KEY_ENTERPRISE_LICENSE) + return LicenseKey.get_total().summary() diff --git a/authentik/enterprise/middleware.py b/authentik/enterprise/middleware.py index 83ff8af05f..681194eece 100644 --- a/authentik/enterprise/middleware.py +++ b/authentik/enterprise/middleware.py @@ -8,6 +8,7 @@ from structlog.stdlib import BoundLogger, get_logger from authentik.enterprise.api import LicenseViewSet from authentik.enterprise.license import LicenseKey +from authentik.enterprise.models import LicenseUsageStatus from authentik.flows.views.executor import FlowExecutorView from authentik.lib.utils.reflection import class_to_path @@ -43,7 +44,7 @@ class EnterpriseMiddleware: cached_status = LicenseKey.cached_summary() if not cached_status: return True - if cached_status.read_only: + if cached_status.status == LicenseUsageStatus.READ_ONLY: return False return True @@ -53,10 +54,10 @@ class EnterpriseMiddleware: if request.method.lower() in ["get", "head", "options", "trace"]: return True # Always allow requests to manage licenses - if class_to_path(request.resolver_match.func) == class_to_path(LicenseViewSet): + if request.resolver_match._func_path == class_to_path(LicenseViewSet): return True # Flow executor is mounted as an API path but explicitly allowed - if class_to_path(request.resolver_match.func) == class_to_path(FlowExecutorView): + if request.resolver_match._func_path == class_to_path(FlowExecutorView): return True # Only apply these restrictions to the API if "authentik_api" not in request.resolver_match.app_names: diff --git a/authentik/enterprise/migrations/0003_remove_licenseusage_within_limits_and_more.py b/authentik/enterprise/migrations/0003_remove_licenseusage_within_limits_and_more.py new file mode 100644 index 0000000000..4943da8536 --- /dev/null +++ b/authentik/enterprise/migrations/0003_remove_licenseusage_within_limits_and_more.py @@ -0,0 +1,68 @@ +# Generated by Django 5.0.8 on 2024-08-08 14:15 + +from django.db import migrations, models +from django.apps.registry import Apps +from django.db.backends.base.schema import BaseDatabaseSchemaEditor + + +def migrate_license_usage(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): + LicenseUsage = apps.get_model("authentik_enterprise", "licenseusage") + db_alias = schema_editor.connection.alias + + for usage in LicenseUsage.objects.using(db_alias).all(): + usage.status = "valid" if usage.within_limits else "limit_exceeded_admin" + usage.save() + + +class Migration(migrations.Migration): + + dependencies = [ + ("authentik_enterprise", "0002_rename_users_license_internal_users_and_more"), + ] + + operations = [ + migrations.AddField( + model_name="licenseusage", + name="status", + field=models.TextField( + choices=[ + ("unlicensed", "Unlicensed"), + ("valid", "Valid"), + ("expired", "Expired"), + ("expiry_soon", "Expiry Soon"), + ("limit_exceeded_admin", "Limit Exceeded Admin"), + ("limit_exceeded_user", "Limit Exceeded User"), + ("read_only", "Read Only"), + ], + default=None, + null=True, + ), + preserve_default=False, + ), + migrations.RunPython(migrate_license_usage), + migrations.RemoveField( + model_name="licenseusage", + name="within_limits", + ), + migrations.AlterField( + model_name="licenseusage", + name="status", + field=models.TextField( + choices=[ + ("unlicensed", "Unlicensed"), + ("valid", "Valid"), + ("expired", "Expired"), + ("expiry_soon", "Expiry Soon"), + ("limit_exceeded_admin", "Limit Exceeded Admin"), + ("limit_exceeded_user", "Limit Exceeded User"), + ("read_only", "Read Only"), + ], + ), + preserve_default=False, + ), + migrations.RenameField( + model_name="licenseusage", + old_name="user_count", + new_name="internal_user_count", + ), + ] diff --git a/authentik/enterprise/models.py b/authentik/enterprise/models.py index 6600e5c07a..3130e29eb9 100644 --- a/authentik/enterprise/models.py +++ b/authentik/enterprise/models.py @@ -17,6 +17,17 @@ if TYPE_CHECKING: from authentik.enterprise.license import LicenseKey +def usage_expiry(): + """Keep license usage records for 3 months""" + return now() + timedelta(days=30 * 3) + + +THRESHOLD_WARNING_ADMIN_WEEKS = 2 +THRESHOLD_WARNING_USER_WEEKS = 4 +THRESHOLD_WARNING_EXPIRY_WEEKS = 2 +THRESHOLD_READ_ONLY_WEEKS = 6 + + class License(SerializerModel): """An authentik enterprise license""" @@ -39,7 +50,7 @@ class License(SerializerModel): """Get parsed license status""" from authentik.enterprise.license import LicenseKey - return LicenseKey.validate(self.key) + return LicenseKey.validate(self.key, check_expiry=False) class Meta: indexes = (HashIndex(fields=("key",)),) @@ -47,9 +58,23 @@ class License(SerializerModel): verbose_name_plural = _("Licenses") -def usage_expiry(): - """Keep license usage records for 3 months""" - return now() + timedelta(days=30 * 3) +class LicenseUsageStatus(models.TextChoices): + """License states an instance/tenant can be in""" + + UNLICENSED = "unlicensed" + VALID = "valid" + EXPIRED = "expired" + EXPIRY_SOON = "expiry_soon" + # User limit exceeded, 2 week threshold, show message in admin interface + LIMIT_EXCEEDED_ADMIN = "limit_exceeded_admin" + # User limit exceeded, 4 week threshold, show message in user interface + LIMIT_EXCEEDED_USER = "limit_exceeded_user" + READ_ONLY = "read_only" + + @property + def is_valid(self) -> bool: + """Quickly check if a license is valid""" + return self in [LicenseUsageStatus.VALID, LicenseUsageStatus.EXPIRY_SOON] class LicenseUsage(ExpiringModel): @@ -59,9 +84,9 @@ class LicenseUsage(ExpiringModel): usage_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4) - user_count = models.BigIntegerField() + internal_user_count = models.BigIntegerField() external_user_count = models.BigIntegerField() - within_limits = models.BooleanField() + status = models.TextField(choices=LicenseUsageStatus.choices) record_date = models.DateTimeField(auto_now_add=True) diff --git a/authentik/enterprise/policy.py b/authentik/enterprise/policy.py index 904c3f73ee..0c4cc91ae8 100644 --- a/authentik/enterprise/policy.py +++ b/authentik/enterprise/policy.py @@ -13,7 +13,7 @@ class EnterprisePolicyAccessView(PolicyAccessView): def check_license(self): """Check license""" - if not LicenseKey.get_total().is_valid(): + if not LicenseKey.get_total().status().is_valid: return PolicyResult(False, _("Enterprise required to access this feature.")) if self.request.user.type != UserTypes.INTERNAL: return PolicyResult(False, _("Feature only accessible for internal users.")) diff --git a/authentik/enterprise/tests/test_license.py b/authentik/enterprise/tests/test_license.py index efa45e0eb6..c76f141f10 100644 --- a/authentik/enterprise/tests/test_license.py +++ b/authentik/enterprise/tests/test_license.py @@ -9,10 +9,26 @@ from django.utils.timezone import now from rest_framework.exceptions import ValidationError from authentik.enterprise.license import LicenseKey -from authentik.enterprise.models import License +from authentik.enterprise.models import ( + THRESHOLD_READ_ONLY_WEEKS, + THRESHOLD_WARNING_ADMIN_WEEKS, + THRESHOLD_WARNING_USER_WEEKS, + License, + LicenseUsage, + LicenseUsageStatus, +) from authentik.lib.generators import generate_id -_exp = int(mktime((now() + timedelta(days=3000)).timetuple())) +# Valid license expiry +expiry_valid = int(mktime((now() + timedelta(days=3000)).timetuple())) +# Valid license expiry, expires soon +expiry_soon = int(mktime((now() + timedelta(hours=10)).timetuple())) +# Invalid license expiry, recently expired +expiry_expired = int(mktime((now() - timedelta(hours=10)).timetuple())) +# Invalid license expiry, expired longer ago +expiry_expired_read_only = int( + mktime((now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)).timetuple()) +) class TestEnterpriseLicense(TestCase): @@ -23,7 +39,7 @@ class TestEnterpriseLicense(TestCase): MagicMock( return_value=LicenseKey( aud="", - exp=_exp, + exp=expiry_valid, name=generate_id(), internal_users=100, external_users=100, @@ -33,7 +49,7 @@ class TestEnterpriseLicense(TestCase): def test_valid(self): """Check license verification""" lic = License.objects.create(key=generate_id()) - self.assertTrue(lic.status.is_valid()) + self.assertTrue(lic.status.status().is_valid) self.assertEqual(lic.internal_users, 100) def test_invalid(self): @@ -46,7 +62,7 @@ class TestEnterpriseLicense(TestCase): MagicMock( return_value=LicenseKey( aud="", - exp=_exp, + exp=expiry_valid, name=generate_id(), internal_users=100, external_users=100, @@ -56,11 +72,186 @@ class TestEnterpriseLicense(TestCase): def test_valid_multiple(self): """Check license verification""" lic = License.objects.create(key=generate_id()) - self.assertTrue(lic.status.is_valid()) + self.assertTrue(lic.status.status().is_valid) lic2 = License.objects.create(key=generate_id()) - self.assertTrue(lic2.status.is_valid()) + self.assertTrue(lic2.status.status().is_valid) total = LicenseKey.get_total() self.assertEqual(total.internal_users, 200) self.assertEqual(total.external_users, 200) - self.assertEqual(total.exp, _exp) - self.assertTrue(total.is_valid()) + self.assertEqual(total.exp, expiry_valid) + self.assertTrue(total.status().is_valid) + + @patch( + "authentik.enterprise.license.LicenseKey.validate", + MagicMock( + return_value=LicenseKey( + aud="", + exp=expiry_valid, + name=generate_id(), + internal_users=100, + external_users=100, + ) + ), + ) + @patch( + "authentik.enterprise.license.LicenseKey.get_internal_user_count", + MagicMock(return_value=1000), + ) + @patch( + "authentik.enterprise.license.LicenseKey.get_external_user_count", + MagicMock(return_value=1000), + ) + @patch( + "authentik.enterprise.license.LicenseKey.record_usage", + MagicMock(), + ) + def test_limit_exceeded_read_only(self): + """Check license verification""" + License.objects.create(key=generate_id()) + usage = LicenseUsage.objects.create( + internal_user_count=100, + external_user_count=100, + status=LicenseUsageStatus.VALID, + ) + usage.record_date = now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1) + usage.save(update_fields=["record_date"]) + self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.READ_ONLY) + + @patch( + "authentik.enterprise.license.LicenseKey.validate", + MagicMock( + return_value=LicenseKey( + aud="", + exp=expiry_valid, + name=generate_id(), + internal_users=100, + external_users=100, + ) + ), + ) + @patch( + "authentik.enterprise.license.LicenseKey.get_internal_user_count", + MagicMock(return_value=1000), + ) + @patch( + "authentik.enterprise.license.LicenseKey.get_external_user_count", + MagicMock(return_value=1000), + ) + @patch( + "authentik.enterprise.license.LicenseKey.record_usage", + MagicMock(), + ) + def test_limit_exceeded_user_warning(self): + """Check license verification""" + License.objects.create(key=generate_id()) + usage = LicenseUsage.objects.create( + internal_user_count=100, + external_user_count=100, + status=LicenseUsageStatus.VALID, + ) + usage.record_date = now() - timedelta(weeks=THRESHOLD_WARNING_USER_WEEKS + 1) + usage.save(update_fields=["record_date"]) + self.assertEqual( + LicenseKey.get_total().summary().status, LicenseUsageStatus.LIMIT_EXCEEDED_USER + ) + + @patch( + "authentik.enterprise.license.LicenseKey.validate", + MagicMock( + return_value=LicenseKey( + aud="", + exp=expiry_valid, + name=generate_id(), + internal_users=100, + external_users=100, + ) + ), + ) + @patch( + "authentik.enterprise.license.LicenseKey.get_internal_user_count", + MagicMock(return_value=1000), + ) + @patch( + "authentik.enterprise.license.LicenseKey.get_external_user_count", + MagicMock(return_value=1000), + ) + @patch( + "authentik.enterprise.license.LicenseKey.record_usage", + MagicMock(), + ) + def test_limit_exceeded_admin_warning(self): + """Check license verification""" + License.objects.create(key=generate_id()) + usage = LicenseUsage.objects.create( + internal_user_count=100, + external_user_count=100, + status=LicenseUsageStatus.VALID, + ) + usage.record_date = now() - timedelta(weeks=THRESHOLD_WARNING_ADMIN_WEEKS + 1) + usage.save(update_fields=["record_date"]) + self.assertEqual( + LicenseKey.get_total().summary().status, LicenseUsageStatus.LIMIT_EXCEEDED_ADMIN + ) + + @patch( + "authentik.enterprise.license.LicenseKey.validate", + MagicMock( + return_value=LicenseKey( + aud="", + exp=expiry_expired_read_only, + name=generate_id(), + internal_users=100, + external_users=100, + ) + ), + ) + @patch( + "authentik.enterprise.license.LicenseKey.record_usage", + MagicMock(), + ) + def test_expiry_read_only(self): + """Check license verification""" + License.objects.create(key=generate_id()) + self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.READ_ONLY) + + @patch( + "authentik.enterprise.license.LicenseKey.validate", + MagicMock( + return_value=LicenseKey( + aud="", + exp=expiry_expired, + name=generate_id(), + internal_users=100, + external_users=100, + ) + ), + ) + @patch( + "authentik.enterprise.license.LicenseKey.record_usage", + MagicMock(), + ) + def test_expiry_expired(self): + """Check license verification""" + License.objects.create(key=generate_id()) + self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.EXPIRED) + + @patch( + "authentik.enterprise.license.LicenseKey.validate", + MagicMock( + return_value=LicenseKey( + aud="", + exp=expiry_soon, + name=generate_id(), + internal_users=100, + external_users=100, + ) + ), + ) + @patch( + "authentik.enterprise.license.LicenseKey.record_usage", + MagicMock(), + ) + def test_expiry_soon(self): + """Check license verification""" + License.objects.create(key=generate_id()) + self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.EXPIRY_SOON) diff --git a/authentik/enterprise/tests/test_read_only.py b/authentik/enterprise/tests/test_read_only.py new file mode 100644 index 0000000000..f6af83c989 --- /dev/null +++ b/authentik/enterprise/tests/test_read_only.py @@ -0,0 +1,217 @@ +"""read only tests""" + +from datetime import timedelta +from unittest.mock import MagicMock, patch + +from django.urls import reverse +from django.utils.timezone import now + +from authentik.core.tests.utils import create_test_admin_user, create_test_flow, create_test_user +from authentik.enterprise.license import LicenseKey +from authentik.enterprise.models import ( + THRESHOLD_READ_ONLY_WEEKS, + License, + LicenseUsage, + LicenseUsageStatus, +) +from authentik.enterprise.tests.test_license import expiry_valid +from authentik.flows.models import ( + FlowDesignation, + FlowStageBinding, +) +from authentik.flows.tests import FlowTestCase +from authentik.lib.generators import generate_id +from authentik.stages.identification.models import IdentificationStage, UserFields +from authentik.stages.password import BACKEND_INBUILT +from authentik.stages.password.models import PasswordStage +from authentik.stages.user_login.models import UserLoginStage + + +class TestReadOnly(FlowTestCase): + """Test read_only""" + + @patch( + "authentik.enterprise.license.LicenseKey.validate", + MagicMock( + return_value=LicenseKey( + aud="", + exp=expiry_valid, + name=generate_id(), + internal_users=100, + external_users=100, + ) + ), + ) + @patch( + "authentik.enterprise.license.LicenseKey.get_internal_user_count", + MagicMock(return_value=1000), + ) + @patch( + "authentik.enterprise.license.LicenseKey.get_external_user_count", + MagicMock(return_value=1000), + ) + @patch( + "authentik.enterprise.license.LicenseKey.record_usage", + MagicMock(), + ) + def test_login(self): + """Test flow, ensure login is still possible with read only mode""" + License.objects.create(key=generate_id()) + usage = LicenseUsage.objects.create( + internal_user_count=100, + external_user_count=100, + status=LicenseUsageStatus.VALID, + ) + usage.record_date = now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1) + usage.save(update_fields=["record_date"]) + + flow = create_test_flow( + FlowDesignation.AUTHENTICATION, + ) + + ident_stage = IdentificationStage.objects.create( + name=generate_id(), + user_fields=[UserFields.E_MAIL], + pretend_user_exists=False, + ) + FlowStageBinding.objects.create( + target=flow, + stage=ident_stage, + order=0, + ) + password_stage = PasswordStage.objects.create( + name=generate_id(), backends=[BACKEND_INBUILT] + ) + FlowStageBinding.objects.create( + target=flow, + stage=password_stage, + order=1, + ) + login_stage = UserLoginStage.objects.create( + name=generate_id(), + ) + FlowStageBinding.objects.create( + target=flow, + stage=login_stage, + order=2, + ) + + user = create_test_user() + + exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) + response = self.client.get(exec_url) + self.assertStageResponse( + response, + flow, + component="ak-stage-identification", + password_fields=False, + primary_action="Log in", + sources=[], + show_source_labels=False, + user_fields=[UserFields.E_MAIL], + ) + response = self.client.post(exec_url, {"uid_field": user.email}, follow=True) + self.assertStageResponse(response, flow, component="ak-stage-password") + response = self.client.post(exec_url, {"password": user.username}, follow=True) + self.assertStageRedirects(response, reverse("authentik_core:root-redirect")) + + @patch( + "authentik.enterprise.license.LicenseKey.validate", + MagicMock( + return_value=LicenseKey( + aud="", + exp=expiry_valid, + name=generate_id(), + internal_users=100, + external_users=100, + ) + ), + ) + @patch( + "authentik.enterprise.license.LicenseKey.get_internal_user_count", + MagicMock(return_value=1000), + ) + @patch( + "authentik.enterprise.license.LicenseKey.get_external_user_count", + MagicMock(return_value=1000), + ) + @patch( + "authentik.enterprise.license.LicenseKey.record_usage", + MagicMock(), + ) + def test_manage_licenses(self): + """Test that managing licenses is still possible""" + license = License.objects.create(key=generate_id()) + usage = LicenseUsage.objects.create( + internal_user_count=100, + external_user_count=100, + status=LicenseUsageStatus.VALID, + ) + usage.record_date = now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1) + usage.save(update_fields=["record_date"]) + + admin = create_test_admin_user() + self.client.force_login(admin) + + # Reading is always allowed + response = self.client.get(reverse("authentik_api:license-list")) + self.assertEqual(response.status_code, 200) + + # Writing should also be allowed + response = self.client.patch( + reverse("authentik_api:license-detail", kwargs={"pk": license.pk}) + ) + self.assertEqual(response.status_code, 200) + + @patch( + "authentik.enterprise.license.LicenseKey.validate", + MagicMock( + return_value=LicenseKey( + aud="", + exp=expiry_valid, + name=generate_id(), + internal_users=100, + external_users=100, + ) + ), + ) + @patch( + "authentik.enterprise.license.LicenseKey.get_internal_user_count", + MagicMock(return_value=1000), + ) + @patch( + "authentik.enterprise.license.LicenseKey.get_external_user_count", + MagicMock(return_value=1000), + ) + @patch( + "authentik.enterprise.license.LicenseKey.record_usage", + MagicMock(), + ) + def test_manage_flows(self): + """Test flow""" + License.objects.create(key=generate_id()) + usage = LicenseUsage.objects.create( + internal_user_count=100, + external_user_count=100, + status=LicenseUsageStatus.VALID, + ) + usage.record_date = now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1) + usage.save(update_fields=["record_date"]) + + admin = create_test_admin_user() + self.client.force_login(admin) + + # Read only is still allowed + response = self.client.get(reverse("authentik_api:flow-list")) + self.assertEqual(response.status_code, 200) + + flow = create_test_flow() + # Writing is not + response = self.client.patch( + reverse("authentik_api:flow-detail", kwargs={"slug": flow.slug}) + ) + self.assertJSONEqual( + response.content, + {"detail": "Request denied due to expired/invalid license.", "code": "denied_license"}, + ) + self.assertEqual(response.status_code, 400) diff --git a/authentik/outposts/api/outposts.py b/authentik/outposts/api/outposts.py index a79f48f883..582dd1eac3 100644 --- a/authentik/outposts/api/outposts.py +++ b/authentik/outposts/api/outposts.py @@ -140,7 +140,7 @@ class OutpostHealthSerializer(PassiveSerializer): def get_fips_enabled(self, obj: dict) -> bool | None: """Get FIPS enabled""" - if not LicenseKey.get_total().is_valid(): + if not LicenseKey.get_total().status().is_valid: return None return obj["fips_enabled"] diff --git a/blueprints/schema.json b/blueprints/schema.json index ec03caedef..7278a76d01 100644 --- a/blueprints/schema.json +++ b/blueprints/schema.json @@ -6321,22 +6321,6 @@ "authentik_rbac.edit_system_settings", "authentik_rbac.view_system_info", "authentik_rbac.view_system_settings", - "authentik_sources_kerberos.add_groupkerberossourceconnection", - "authentik_sources_kerberos.change_groupkerberossourceconnection", - "authentik_sources_kerberos.delete_groupkerberossourceconnection", - "authentik_sources_kerberos.view_groupkerberossourceconnection", - "authentik_sources_kerberos.add_kerberospropertymapping", - "authentik_sources_kerberos.change_kerberospropertymapping", - "authentik_sources_kerberos.delete_kerberospropertymapping", - "authentik_sources_kerberos.view_kerberospropertymapping", - "authentik_sources_kerberos.add_kerberossource", - "authentik_sources_kerberos.change_kerberossource", - "authentik_sources_kerberos.delete_kerberossource", - "authentik_sources_kerberos.view_kerberossource", - "authentik_sources_kerberos.add_userkerberossourceconnection", - "authentik_sources_kerberos.change_userkerberossourceconnection", - "authentik_sources_kerberos.delete_userkerberossourceconnection", - "authentik_sources_kerberos.view_userkerberossourceconnection", "authentik_sources_ldap.add_ldapsource", "authentik_sources_ldap.change_ldapsource", "authentik_sources_ldap.delete_ldapsource", @@ -6361,14 +6345,26 @@ "authentik_sources_oauth.change_useroauthsourceconnection", "authentik_sources_oauth.delete_useroauthsourceconnection", "authentik_sources_oauth.view_useroauthsourceconnection", + "authentik_sources_plex.add_groupplexsourceconnection", + "authentik_sources_plex.change_groupplexsourceconnection", + "authentik_sources_plex.delete_groupplexsourceconnection", + "authentik_sources_plex.view_groupplexsourceconnection", "authentik_sources_plex.add_plexsource", "authentik_sources_plex.change_plexsource", "authentik_sources_plex.delete_plexsource", "authentik_sources_plex.view_plexsource", + "authentik_sources_plex.add_plexsourcepropertymapping", + "authentik_sources_plex.change_plexsourcepropertymapping", + "authentik_sources_plex.delete_plexsourcepropertymapping", + "authentik_sources_plex.view_plexsourcepropertymapping", "authentik_sources_plex.add_plexsourceconnection", + "authentik_sources_plex.add_userplexsourceconnection", "authentik_sources_plex.change_plexsourceconnection", + "authentik_sources_plex.change_userplexsourceconnection", "authentik_sources_plex.delete_plexsourceconnection", + "authentik_sources_plex.delete_userplexsourceconnection", "authentik_sources_plex.view_plexsourceconnection", + "authentik_sources_plex.view_userplexsourceconnection", "authentik_sources_saml.add_groupsamlsourceconnection", "authentik_sources_saml.change_groupsamlsourceconnection", "authentik_sources_saml.delete_groupsamlsourceconnection", @@ -11984,22 +11980,6 @@ "authentik_rbac.edit_system_settings", "authentik_rbac.view_system_info", "authentik_rbac.view_system_settings", - "authentik_sources_kerberos.add_groupkerberossourceconnection", - "authentik_sources_kerberos.change_groupkerberossourceconnection", - "authentik_sources_kerberos.delete_groupkerberossourceconnection", - "authentik_sources_kerberos.view_groupkerberossourceconnection", - "authentik_sources_kerberos.add_kerberospropertymapping", - "authentik_sources_kerberos.change_kerberospropertymapping", - "authentik_sources_kerberos.delete_kerberospropertymapping", - "authentik_sources_kerberos.view_kerberospropertymapping", - "authentik_sources_kerberos.add_kerberossource", - "authentik_sources_kerberos.change_kerberossource", - "authentik_sources_kerberos.delete_kerberossource", - "authentik_sources_kerberos.view_kerberossource", - "authentik_sources_kerberos.add_userkerberossourceconnection", - "authentik_sources_kerberos.change_userkerberossourceconnection", - "authentik_sources_kerberos.delete_userkerberossourceconnection", - "authentik_sources_kerberos.view_userkerberossourceconnection", "authentik_sources_ldap.add_ldapsource", "authentik_sources_ldap.change_ldapsource", "authentik_sources_ldap.delete_ldapsource", @@ -12024,14 +12004,26 @@ "authentik_sources_oauth.change_useroauthsourceconnection", "authentik_sources_oauth.delete_useroauthsourceconnection", "authentik_sources_oauth.view_useroauthsourceconnection", + "authentik_sources_plex.add_groupplexsourceconnection", + "authentik_sources_plex.change_groupplexsourceconnection", + "authentik_sources_plex.delete_groupplexsourceconnection", + "authentik_sources_plex.view_groupplexsourceconnection", "authentik_sources_plex.add_plexsource", "authentik_sources_plex.change_plexsource", "authentik_sources_plex.delete_plexsource", "authentik_sources_plex.view_plexsource", + "authentik_sources_plex.add_plexsourcepropertymapping", + "authentik_sources_plex.change_plexsourcepropertymapping", + "authentik_sources_plex.delete_plexsourcepropertymapping", + "authentik_sources_plex.view_plexsourcepropertymapping", "authentik_sources_plex.add_plexsourceconnection", + "authentik_sources_plex.add_userplexsourceconnection", "authentik_sources_plex.change_plexsourceconnection", + "authentik_sources_plex.change_userplexsourceconnection", "authentik_sources_plex.delete_plexsourceconnection", + "authentik_sources_plex.delete_userplexsourceconnection", "authentik_sources_plex.view_plexsourceconnection", + "authentik_sources_plex.view_userplexsourceconnection", "authentik_sources_saml.add_groupsamlsourceconnection", "authentik_sources_saml.change_groupsamlsourceconnection", "authentik_sources_saml.delete_groupsamlsourceconnection", diff --git a/schema.yml b/schema.yml index e5c8013ffc..efaaa462f9 100644 --- a/schema.yml +++ b/schema.yml @@ -41386,28 +41386,26 @@ components: type: integer external_users: type: integer - valid: - type: boolean - show_admin_warning: - type: boolean - show_user_warning: - type: boolean - read_only: - type: boolean + status: + $ref: '#/components/schemas/LicenseSummaryStatusEnum' latest_valid: type: string format: date-time - has_license: - type: boolean required: - external_users - - has_license - internal_users - latest_valid - - read_only - - show_admin_warning - - show_user_warning + - status + LicenseSummaryStatusEnum: + enum: + - unlicensed - valid + - expired + - expiry_soon + - limit_exceeded_admin + - limit_exceeded_user + - read_only + type: string Link: type: object description: Returns a single link diff --git a/web/src/admin/AdminInterface/AdminInterface.ts b/web/src/admin/AdminInterface/AdminInterface.ts index 473d97c829..4eeeaa9b30 100644 --- a/web/src/admin/AdminInterface/AdminInterface.ts +++ b/web/src/admin/AdminInterface/AdminInterface.ts @@ -71,6 +71,12 @@ export class AdminInterface extends EnterpriseAwareInterface { :host([theme="dark"]) .pf-c-page { --pf-c-page--BackgroundColor: var(--ak-dark-background); } + ak-enterprise-status { + grid-area: header; + } + ak-admin-sidebar { + grid-area: nav; + } `, ]; } @@ -118,6 +124,7 @@ export class AdminInterface extends EnterpriseAwareInterface { return html`
+ diff --git a/web/src/admin/enterprise/EnterpriseLicenseListPage.ts b/web/src/admin/enterprise/EnterpriseLicenseListPage.ts index 47d506a149..fec91f6eee 100644 --- a/web/src/admin/enterprise/EnterpriseLicenseListPage.ts +++ b/web/src/admin/enterprise/EnterpriseLicenseListPage.ts @@ -29,6 +29,7 @@ import { License, LicenseForecast, LicenseSummary, + LicenseSummaryStatusEnum, RbacPermissionsAssignedByUsersListModelEnum, } from "@goauthentik/api"; @@ -182,7 +183,7 @@ export class EnterpriseLicenseListPage extends TablePage { header=${msg("Expiry")} subtext=${msg("Cumulative license expiry")} > - ${this.summary?.hasLicense + ${this.summary?.status === LicenseSummaryStatusEnum.Unlicensed ? html`
${getRelativeTime(this.summary.latestValid)}
${this.summary.latestValid.toLocaleString()}` : "-"} diff --git a/web/src/elements/Interface/licenseSummaryProvider.ts b/web/src/elements/Interface/licenseSummaryProvider.ts index 4a73ffab6d..a7e6b9fa93 100644 --- a/web/src/elements/Interface/licenseSummaryProvider.ts +++ b/web/src/elements/Interface/licenseSummaryProvider.ts @@ -4,7 +4,7 @@ import { Constructor } from "@goauthentik/elements/types.js"; import { consume } from "@lit/context"; import type { LitElement } from "lit"; -import type { LicenseSummary } from "@goauthentik/api"; +import { type LicenseSummary, LicenseSummaryStatusEnum } from "@goauthentik/api"; export function WithLicenseSummary>( superclass: T, @@ -15,7 +15,7 @@ export function WithLicenseSummary>( public licenseSummary!: LicenseSummary; get hasEnterpriseLicense() { - return this.licenseSummary?.hasLicense; + return this.licenseSummary?.status !== LicenseSummaryStatusEnum.Unlicensed; } } diff --git a/web/src/elements/PageHeader.ts b/web/src/elements/PageHeader.ts index 8fa2c74be2..5d1f96249c 100644 --- a/web/src/elements/PageHeader.ts +++ b/web/src/elements/PageHeader.ts @@ -138,63 +138,62 @@ export class PageHeader extends WithBrandConfig(AKElement) { } render(): TemplateResult { - return html` -
- -
-
-

- ${this.renderIcon()}  - ${this.header} -

- ${this.description ? html`

${this.description}

` : html``} -
-
- - -
`; + return html`
+ +
+
+

+ ${this.renderIcon()}  + ${this.header} +

+ ${this.description ? html`

${this.description}

` : html``} +
+
+ + +
`; } } diff --git a/web/src/elements/enterprise/EnterpriseStatusBanner.ts b/web/src/elements/enterprise/EnterpriseStatusBanner.ts index a83e30e129..5a898b892e 100644 --- a/web/src/elements/enterprise/EnterpriseStatusBanner.ts +++ b/web/src/elements/enterprise/EnterpriseStatusBanner.ts @@ -7,6 +7,8 @@ import { customElement, property } from "lit/decorators.js"; import PFBanner from "@patternfly/patternfly/components/Banner/banner.css"; +import { LicenseSummaryStatusEnum } from "@goauthentik/api"; + @customElement("ak-enterprise-status") export class EnterpriseStatusBanner extends WithLicenseSummary(AKElement) { @property() @@ -17,26 +19,58 @@ export class EnterpriseStatusBanner extends WithLicenseSummary(AKElement) { } renderBanner(): TemplateResult { + let message = ""; + switch (this.licenseSummary.status) { + case LicenseSummaryStatusEnum.LimitExceededAdmin: + case LicenseSummaryStatusEnum.LimitExceededUser: + message = msg( + "Warning: The current user count has exceeded the configured licenses.", + ); + break; + case LicenseSummaryStatusEnum.Expired: + message = msg("Warning: One or more license(s) have expired."); + break; + case LicenseSummaryStatusEnum.ExpirySoon: + message = msg( + "Warning: One or more license(s) will expire within the next 2 weeks.", + ); + break; + case LicenseSummaryStatusEnum.ReadOnly: + message = msg( + "Caution: This authentik instance has entered read-only mode due to expired/exceeded licenses.", + ); + break; + default: + break; + } return html`
- ${msg("Warning: The current user count has exceeded the configured licenses.")} + ${message} ${msg("Click here for more info.")}
`; } render(): TemplateResult { - switch (this.interface.toLowerCase()) { - case "admin": - if (this.licenseSummary?.showAdminWarning || this.licenseSummary?.readOnly) { + switch (this.licenseSummary.status) { + case LicenseSummaryStatusEnum.LimitExceededUser: + if (this.interface.toLowerCase() === "user") { return this.renderBanner(); } break; - case "user": - if (this.licenseSummary?.showUserWarning || this.licenseSummary?.readOnly) { + case LicenseSummaryStatusEnum.ExpirySoon: + case LicenseSummaryStatusEnum.Expired: + case LicenseSummaryStatusEnum.LimitExceededAdmin: + if (this.interface.toLowerCase() === "admin") { return this.renderBanner(); } break; + case LicenseSummaryStatusEnum.ReadOnly: + return this.renderBanner(); + default: + break; } return html``; } diff --git a/web/src/elements/sidebar/Sidebar.ts b/web/src/elements/sidebar/Sidebar.ts index 0640c94184..76575acefb 100644 --- a/web/src/elements/sidebar/Sidebar.ts +++ b/web/src/elements/sidebar/Sidebar.ts @@ -42,7 +42,6 @@ export class Sidebar extends AKElement { nav { display: flex; flex-direction: column; - max-height: 100vh; height: 100%; overflow-y: hidden; }