core: ensure proxy provider is correctly looked up (cherry-pick #11267) (#11269)

core: ensure proxy provider is correctly looked up (#11267)

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
Co-authored-by: Jens L. <jens@goauthentik.io>
This commit is contained in:
gcp-cherry-pick-bot[bot]
2024-09-07 21:53:30 +02:00
committed by GitHub
parent f5580d311d
commit 61778053b4
2 changed files with 26 additions and 9 deletions

View File

@ -466,8 +466,6 @@ class ApplicationQuerySet(QuerySet):
def with_provider(self) -> "QuerySet[Application]": def with_provider(self) -> "QuerySet[Application]":
qs = self.select_related("provider") qs = self.select_related("provider")
for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider): for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider):
if LOOKUP_SEP in subclass:
continue
qs = qs.select_related(f"provider__{subclass}") qs = qs.select_related(f"provider__{subclass}")
return qs return qs
@ -545,15 +543,20 @@ class Application(SerializerModel, PolicyBindingModel):
if not self.provider: if not self.provider:
return None return None
candidates = []
for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider): for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider):
# We don't care about recursion, skip nested models parent = self.provider
if LOOKUP_SEP in subclass: for level in subclass.split(LOOKUP_SEP):
continue
try: try:
return getattr(self.provider, subclass) parent = getattr(parent, level)
except AttributeError: except AttributeError:
pass break
if parent in candidates:
continue
candidates.insert(subclass.count(LOOKUP_SEP), parent)
if not candidates:
return None return None
return candidates[-1]
def __str__(self): def __str__(self):
return str(self.name) return str(self.name)

View File

@ -9,9 +9,11 @@ from rest_framework.test import APITestCase
from authentik.core.models import Application from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user, create_test_flow from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.lib.generators import generate_id
from authentik.policies.dummy.models import DummyPolicy from authentik.policies.dummy.models import DummyPolicy
from authentik.policies.models import PolicyBinding from authentik.policies.models import PolicyBinding
from authentik.providers.oauth2.models import OAuth2Provider from authentik.providers.oauth2.models import OAuth2Provider
from authentik.providers.proxy.models import ProxyProvider
class TestApplicationsAPI(APITestCase): class TestApplicationsAPI(APITestCase):
@ -222,3 +224,15 @@ class TestApplicationsAPI(APITestCase):
], ],
}, },
) )
def test_get_provider(self):
"""Ensure that proxy providers (at the time of writing that is the only provider
that inherits from another proxy type (OAuth) instead of inheriting from the root
provider class) is correctly looked up and selected from the database"""
provider = ProxyProvider.objects.create(name=generate_id())
app = Application.objects.create(
name=generate_id(),
slug=generate_id(),
provider=provider,
)
self.assertEqual(app.get_provider(), provider)