diff --git a/authentik/core/models.py b/authentik/core/models.py index c81beffade..8b6fbcbf56 100644 --- a/authentik/core/models.py +++ b/authentik/core/models.py @@ -544,7 +544,8 @@ class Application(SerializerModel, PolicyBindingModel): return None candidates = [] - for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider): + base_class = Provider + for subclass in base_class.objects.get_queryset()._get_subclasses_recurse(base_class): parent = self.provider for level in subclass.split(LOOKUP_SEP): try: @@ -553,7 +554,10 @@ class Application(SerializerModel, PolicyBindingModel): break if parent in candidates: continue - candidates.insert(subclass.count(LOOKUP_SEP), parent) + idx = subclass.count(LOOKUP_SEP) + if type(parent) is not base_class: + idx += 1 + candidates.insert(idx, parent) if not candidates: return None return candidates[-1] diff --git a/authentik/core/tests/test_applications_api.py b/authentik/core/tests/test_applications_api.py index df16059c15..51adf4b878 100644 --- a/authentik/core/tests/test_applications_api.py +++ b/authentik/core/tests/test_applications_api.py @@ -14,6 +14,7 @@ from authentik.policies.dummy.models import DummyPolicy from authentik.policies.models import PolicyBinding from authentik.providers.oauth2.models import OAuth2Provider from authentik.providers.proxy.models import ProxyProvider +from authentik.providers.saml.models import SAMLProvider class TestApplicationsAPI(APITestCase): @@ -229,10 +230,26 @@ class TestApplicationsAPI(APITestCase): """Ensure that proxy providers (at the time of writing that is the only provider that inherits from another proxy type (OAuth) instead of inheriting from the root provider class) is correctly looked up and selected from the database""" + slug = generate_id() provider = ProxyProvider.objects.create(name=generate_id()) - app = Application.objects.create( + Application.objects.create( name=generate_id(), - slug=generate_id(), + slug=slug, provider=provider, ) - self.assertEqual(app.get_provider(), provider) + self.assertEqual(Application.objects.get(slug=slug).get_provider(), provider) + self.assertEqual( + Application.objects.with_provider().get(slug=slug).get_provider(), provider + ) + + slug = generate_id() + provider = SAMLProvider.objects.create(name=generate_id()) + Application.objects.create( + name=generate_id(), + slug=slug, + provider=provider, + ) + self.assertEqual(Application.objects.get(slug=slug).get_provider(), provider) + self.assertEqual( + Application.objects.with_provider().get(slug=slug).get_provider(), provider + )