diff --git a/authentik/enterprise/license.py b/authentik/enterprise/license.py index 1bc92bdf5c..346cfe6438 100644 --- a/authentik/enterprise/license.py +++ b/authentik/enterprise/license.py @@ -132,13 +132,14 @@ class LicenseKey: """Get a summarized version of all (not expired) licenses""" total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0) for lic in License.objects.all(): - total.internal_users += lic.internal_users - total.external_users += lic.external_users + if lic.is_valid: + total.internal_users += lic.internal_users + total.external_users += lic.external_users + total.license_flags.extend(lic.status.license_flags) exp_ts = int(mktime(lic.expiry.timetuple())) if total.exp == 0: total.exp = exp_ts total.exp = max(total.exp, exp_ts) - total.license_flags.extend(lic.status.license_flags) return total @staticmethod diff --git a/authentik/enterprise/models.py b/authentik/enterprise/models.py index 2ef24311b8..0b96c9754b 100644 --- a/authentik/enterprise/models.py +++ b/authentik/enterprise/models.py @@ -39,6 +39,10 @@ class License(SerializerModel): internal_users = models.BigIntegerField() external_users = models.BigIntegerField() + @property + def is_valid(self) -> bool: + return self.expiry >= now() + @property def serializer(self) -> type[BaseSerializer]: from authentik.enterprise.api import LicenseSerializer diff --git a/authentik/enterprise/tests/test_license.py b/authentik/enterprise/tests/test_license.py index c76f141f10..6ab2ced0c7 100644 --- a/authentik/enterprise/tests/test_license.py +++ b/authentik/enterprise/tests/test_license.py @@ -8,6 +8,7 @@ from django.test import TestCase from django.utils.timezone import now from rest_framework.exceptions import ValidationError +from authentik.core.models import User from authentik.enterprise.license import LicenseKey from authentik.enterprise.models import ( THRESHOLD_READ_ONLY_WEEKS, @@ -71,9 +72,9 @@ class TestEnterpriseLicense(TestCase): ) def test_valid_multiple(self): """Check license verification""" - lic = License.objects.create(key=generate_id()) + lic = License.objects.create(key=generate_id(), expiry=expiry_valid) self.assertTrue(lic.status.status().is_valid) - lic2 = License.objects.create(key=generate_id()) + lic2 = License.objects.create(key=generate_id(), expiry=expiry_valid) self.assertTrue(lic2.status.status().is_valid) total = LicenseKey.get_total() self.assertEqual(total.internal_users, 200) @@ -232,7 +233,9 @@ class TestEnterpriseLicense(TestCase): ) def test_expiry_expired(self): """Check license verification""" - License.objects.create(key=generate_id()) + User.objects.all().delete() + License.objects.all().delete() + License.objects.create(key=generate_id(), expiry=expiry_expired) self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.EXPIRED) @patch(