core: ensure proxy provider is correctly looked up (#11267) Signed-off-by: Jens Langhammer <jens@goauthentik.io> Co-authored-by: Jens L. <jens@goauthentik.io>
This commit is contained in:
committed by
GitHub
parent
f5580d311d
commit
61778053b4
@ -466,8 +466,6 @@ class ApplicationQuerySet(QuerySet):
|
||||
def with_provider(self) -> "QuerySet[Application]":
|
||||
qs = self.select_related("provider")
|
||||
for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider):
|
||||
if LOOKUP_SEP in subclass:
|
||||
continue
|
||||
qs = qs.select_related(f"provider__{subclass}")
|
||||
return qs
|
||||
|
||||
@ -545,15 +543,20 @@ class Application(SerializerModel, PolicyBindingModel):
|
||||
if not self.provider:
|
||||
return None
|
||||
|
||||
candidates = []
|
||||
for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider):
|
||||
# We don't care about recursion, skip nested models
|
||||
if LOOKUP_SEP in subclass:
|
||||
continue
|
||||
parent = self.provider
|
||||
for level in subclass.split(LOOKUP_SEP):
|
||||
try:
|
||||
return getattr(self.provider, subclass)
|
||||
parent = getattr(parent, level)
|
||||
except AttributeError:
|
||||
pass
|
||||
break
|
||||
if parent in candidates:
|
||||
continue
|
||||
candidates.insert(subclass.count(LOOKUP_SEP), parent)
|
||||
if not candidates:
|
||||
return None
|
||||
return candidates[-1]
|
||||
|
||||
def __str__(self):
|
||||
return str(self.name)
|
||||
|
||||
@ -9,9 +9,11 @@ from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.models import Application
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.policies.dummy.models import DummyPolicy
|
||||
from authentik.policies.models import PolicyBinding
|
||||
from authentik.providers.oauth2.models import OAuth2Provider
|
||||
from authentik.providers.proxy.models import ProxyProvider
|
||||
|
||||
|
||||
class TestApplicationsAPI(APITestCase):
|
||||
@ -222,3 +224,15 @@ class TestApplicationsAPI(APITestCase):
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
def test_get_provider(self):
|
||||
"""Ensure that proxy providers (at the time of writing that is the only provider
|
||||
that inherits from another proxy type (OAuth) instead of inheriting from the root
|
||||
provider class) is correctly looked up and selected from the database"""
|
||||
provider = ProxyProvider.objects.create(name=generate_id())
|
||||
app = Application.objects.create(
|
||||
name=generate_id(),
|
||||
slug=generate_id(),
|
||||
provider=provider,
|
||||
)
|
||||
self.assertEqual(app.get_provider(), provider)
|
||||
|
||||
Reference in New Issue
Block a user