enterprise: fix expired license's users being counted (#14451)
* enterprise: fix expired license's users being counted Signed-off-by: Jens Langhammer <jens@goauthentik.io> * tests to the rescue Signed-off-by: Jens Langhammer <jens@goauthentik.io> * hmm Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
		@ -132,13 +132,14 @@ class LicenseKey:
 | 
			
		||||
        """Get a summarized version of all (not expired) licenses"""
 | 
			
		||||
        total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0)
 | 
			
		||||
        for lic in License.objects.all():
 | 
			
		||||
            total.internal_users += lic.internal_users
 | 
			
		||||
            total.external_users += lic.external_users
 | 
			
		||||
            if lic.is_valid:
 | 
			
		||||
                total.internal_users += lic.internal_users
 | 
			
		||||
                total.external_users += lic.external_users
 | 
			
		||||
                total.license_flags.extend(lic.status.license_flags)
 | 
			
		||||
            exp_ts = int(mktime(lic.expiry.timetuple()))
 | 
			
		||||
            if total.exp == 0:
 | 
			
		||||
                total.exp = exp_ts
 | 
			
		||||
            total.exp = max(total.exp, exp_ts)
 | 
			
		||||
            total.license_flags.extend(lic.status.license_flags)
 | 
			
		||||
        return total
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
 | 
			
		||||
@ -39,6 +39,10 @@ class License(SerializerModel):
 | 
			
		||||
    internal_users = models.BigIntegerField()
 | 
			
		||||
    external_users = models.BigIntegerField()
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def is_valid(self) -> bool:
 | 
			
		||||
        return self.expiry >= now()
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def serializer(self) -> type[BaseSerializer]:
 | 
			
		||||
        from authentik.enterprise.api import LicenseSerializer
 | 
			
		||||
 | 
			
		||||
@ -8,6 +8,7 @@ from django.test import TestCase
 | 
			
		||||
from django.utils.timezone import now
 | 
			
		||||
from rest_framework.exceptions import ValidationError
 | 
			
		||||
 | 
			
		||||
from authentik.core.models import User
 | 
			
		||||
from authentik.enterprise.license import LicenseKey
 | 
			
		||||
from authentik.enterprise.models import (
 | 
			
		||||
    THRESHOLD_READ_ONLY_WEEKS,
 | 
			
		||||
@ -71,9 +72,9 @@ class TestEnterpriseLicense(TestCase):
 | 
			
		||||
    )
 | 
			
		||||
    def test_valid_multiple(self):
 | 
			
		||||
        """Check license verification"""
 | 
			
		||||
        lic = License.objects.create(key=generate_id())
 | 
			
		||||
        lic = License.objects.create(key=generate_id(), expiry=expiry_valid)
 | 
			
		||||
        self.assertTrue(lic.status.status().is_valid)
 | 
			
		||||
        lic2 = License.objects.create(key=generate_id())
 | 
			
		||||
        lic2 = License.objects.create(key=generate_id(), expiry=expiry_valid)
 | 
			
		||||
        self.assertTrue(lic2.status.status().is_valid)
 | 
			
		||||
        total = LicenseKey.get_total()
 | 
			
		||||
        self.assertEqual(total.internal_users, 200)
 | 
			
		||||
@ -232,7 +233,9 @@ class TestEnterpriseLicense(TestCase):
 | 
			
		||||
    )
 | 
			
		||||
    def test_expiry_expired(self):
 | 
			
		||||
        """Check license verification"""
 | 
			
		||||
        License.objects.create(key=generate_id())
 | 
			
		||||
        User.objects.all().delete()
 | 
			
		||||
        License.objects.all().delete()
 | 
			
		||||
        License.objects.create(key=generate_id(), expiry=expiry_expired)
 | 
			
		||||
        self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.EXPIRED)
 | 
			
		||||
 | 
			
		||||
    @patch(
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user