diff --git a/authentik/core/models.py b/authentik/core/models.py index 85ee1fd925..c81beffade 100644 --- a/authentik/core/models.py +++ b/authentik/core/models.py @@ -466,8 +466,6 @@ class ApplicationQuerySet(QuerySet): def with_provider(self) -> "QuerySet[Application]": qs = self.select_related("provider") for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider): - if LOOKUP_SEP in subclass: - continue qs = qs.select_related(f"provider__{subclass}") return qs @@ -545,15 +543,20 @@ class Application(SerializerModel, PolicyBindingModel): if not self.provider: return None + candidates = [] for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider): - # We don't care about recursion, skip nested models - if LOOKUP_SEP in subclass: + parent = self.provider + for level in subclass.split(LOOKUP_SEP): + try: + parent = getattr(parent, level) + except AttributeError: + break + if parent in candidates: continue - try: - return getattr(self.provider, subclass) - except AttributeError: - pass - return None + candidates.insert(subclass.count(LOOKUP_SEP), parent) + if not candidates: + return None + return candidates[-1] def __str__(self): return str(self.name) diff --git a/authentik/core/tests/test_applications_api.py b/authentik/core/tests/test_applications_api.py index 6e970b2079..df16059c15 100644 --- a/authentik/core/tests/test_applications_api.py +++ b/authentik/core/tests/test_applications_api.py @@ -9,9 +9,11 @@ from rest_framework.test import APITestCase from authentik.core.models import Application from authentik.core.tests.utils import create_test_admin_user, create_test_flow +from authentik.lib.generators import generate_id from authentik.policies.dummy.models import DummyPolicy from authentik.policies.models import PolicyBinding from authentik.providers.oauth2.models import OAuth2Provider +from authentik.providers.proxy.models import ProxyProvider class TestApplicationsAPI(APITestCase): @@ -222,3 +224,15 @@ class TestApplicationsAPI(APITestCase): ], }, ) + + def test_get_provider(self): + """Ensure that proxy providers (at the time of writing that is the only provider + that inherits from another proxy type (OAuth) instead of inheriting from the root + provider class) is correctly looked up and selected from the database""" + provider = ProxyProvider.objects.create(name=generate_id()) + app = Application.objects.create( + name=generate_id(), + slug=generate_id(), + provider=provider, + ) + self.assertEqual(app.get_provider(), provider)