brands: add OIDC webfinger support (#10400)
* brands: add OIDC webfinger support for default application Signed-off-by: Jens Langhammer <jens@goauthentik.io> * add tests Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
@ -9,3 +9,6 @@ class AuthentikBrandsConfig(AppConfig):
|
|||||||
name = "authentik.brands"
|
name = "authentik.brands"
|
||||||
label = "authentik_brands"
|
label = "authentik_brands"
|
||||||
verbose_name = "authentik Brands"
|
verbose_name = "authentik Brands"
|
||||||
|
mountpoints = {
|
||||||
|
"authentik.brands.urls_root": "",
|
||||||
|
}
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
from django.http import HttpRequest
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
from rest_framework.serializers import Serializer
|
from rest_framework.serializers import Serializer
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
@ -98,3 +99,13 @@ class Brand(SerializerModel):
|
|||||||
models.Index(fields=["domain"]),
|
models.Index(fields=["domain"]),
|
||||||
models.Index(fields=["default"]),
|
models.Index(fields=["default"]),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class WebfingerProvider(models.Model):
|
||||||
|
"""Provider which supports webfinger discovery"""
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
abstract = True
|
||||||
|
|
||||||
|
def webfinger(self, resource: str, request: HttpRequest) -> dict:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|||||||
@ -5,7 +5,11 @@ from rest_framework.test import APITestCase
|
|||||||
|
|
||||||
from authentik.brands.api import Themes
|
from authentik.brands.api import Themes
|
||||||
from authentik.brands.models import Brand
|
from authentik.brands.models import Brand
|
||||||
|
from authentik.core.models import Application
|
||||||
from authentik.core.tests.utils import create_test_admin_user, create_test_brand
|
from authentik.core.tests.utils import create_test_admin_user, create_test_brand
|
||||||
|
from authentik.lib.generators import generate_id
|
||||||
|
from authentik.providers.oauth2.models import OAuth2Provider
|
||||||
|
from authentik.providers.saml.models import SAMLProvider
|
||||||
|
|
||||||
|
|
||||||
class TestBrands(APITestCase):
|
class TestBrands(APITestCase):
|
||||||
@ -75,3 +79,45 @@ class TestBrands(APITestCase):
|
|||||||
reverse("authentik_api:brand-list"), data={"domain": "bar", "default": True}
|
reverse("authentik_api:brand-list"), data={"domain": "bar", "default": True}
|
||||||
)
|
)
|
||||||
self.assertEqual(response.status_code, 400)
|
self.assertEqual(response.status_code, 400)
|
||||||
|
|
||||||
|
def test_webfinger_no_app(self):
|
||||||
|
"""Test Webfinger"""
|
||||||
|
create_test_brand()
|
||||||
|
self.assertJSONEqual(
|
||||||
|
self.client.get(reverse("authentik_brands:webfinger")).content.decode(), {}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_webfinger_not_supported(self):
|
||||||
|
"""Test Webfinger"""
|
||||||
|
brand = create_test_brand()
|
||||||
|
provider = SAMLProvider.objects.create(
|
||||||
|
name=generate_id(),
|
||||||
|
)
|
||||||
|
app = Application.objects.create(name=generate_id(), slug=generate_id(), provider=provider)
|
||||||
|
brand.default_application = app
|
||||||
|
brand.save()
|
||||||
|
self.assertJSONEqual(
|
||||||
|
self.client.get(reverse("authentik_brands:webfinger")).content.decode(), {}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_webfinger_oidc(self):
|
||||||
|
"""Test Webfinger"""
|
||||||
|
brand = create_test_brand()
|
||||||
|
provider = OAuth2Provider.objects.create(
|
||||||
|
name=generate_id(),
|
||||||
|
)
|
||||||
|
app = Application.objects.create(name=generate_id(), slug=generate_id(), provider=provider)
|
||||||
|
brand.default_application = app
|
||||||
|
brand.save()
|
||||||
|
self.assertJSONEqual(
|
||||||
|
self.client.get(reverse("authentik_brands:webfinger")).content.decode(),
|
||||||
|
{
|
||||||
|
"links": [
|
||||||
|
{
|
||||||
|
"href": f"http://testserver/application/o/{app.slug}/",
|
||||||
|
"rel": "http://openid.net/specs/connect/1.0/issuer",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"subject": None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|||||||
9
authentik/brands/urls_root.py
Normal file
9
authentik/brands/urls_root.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
"""authentik brand root URLs"""
|
||||||
|
|
||||||
|
from django.urls import path
|
||||||
|
|
||||||
|
from authentik.brands.views.webfinger import WebFingerView
|
||||||
|
|
||||||
|
urlpatterns = [
|
||||||
|
path(".well-known/webfinger", WebFingerView.as_view(), name="webfinger"),
|
||||||
|
]
|
||||||
0
authentik/brands/views/__init__.py
Normal file
0
authentik/brands/views/__init__.py
Normal file
29
authentik/brands/views/webfinger.py
Normal file
29
authentik/brands/views/webfinger.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from django.http import HttpRequest, HttpResponse, JsonResponse
|
||||||
|
from django.views import View
|
||||||
|
|
||||||
|
from authentik.brands.models import Brand, WebfingerProvider
|
||||||
|
from authentik.core.models import Application
|
||||||
|
|
||||||
|
|
||||||
|
class WebFingerView(View):
|
||||||
|
"""Webfinger endpoint"""
|
||||||
|
|
||||||
|
def get(self, request: HttpRequest) -> HttpResponse:
|
||||||
|
brand: Brand = request.brand
|
||||||
|
if not brand.default_application:
|
||||||
|
return JsonResponse({})
|
||||||
|
application: Application = brand.default_application
|
||||||
|
provider = application.get_provider()
|
||||||
|
if not provider or not isinstance(provider, WebfingerProvider):
|
||||||
|
return JsonResponse({})
|
||||||
|
webfinger_data = provider.webfinger(request.GET.get("resource"), request)
|
||||||
|
return JsonResponse(webfinger_data)
|
||||||
|
|
||||||
|
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
|
||||||
|
response = super().dispatch(request, *args, **kwargs)
|
||||||
|
# RFC7033 spec
|
||||||
|
response["Access-Control-Allow-Origin"] = "*"
|
||||||
|
response["Content-Type"] = "application/jrd+json"
|
||||||
|
return response
|
||||||
@ -22,6 +22,7 @@ from jwt import encode
|
|||||||
from rest_framework.serializers import Serializer
|
from rest_framework.serializers import Serializer
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
|
from authentik.brands.models import WebfingerProvider
|
||||||
from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User
|
from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User
|
||||||
from authentik.crypto.models import CertificateKeyPair
|
from authentik.crypto.models import CertificateKeyPair
|
||||||
from authentik.lib.generators import generate_code_fixed_length, generate_id, generate_key
|
from authentik.lib.generators import generate_code_fixed_length, generate_id, generate_key
|
||||||
@ -120,7 +121,7 @@ class ScopeMapping(PropertyMapping):
|
|||||||
verbose_name_plural = _("Scope Mappings")
|
verbose_name_plural = _("Scope Mappings")
|
||||||
|
|
||||||
|
|
||||||
class OAuth2Provider(Provider):
|
class OAuth2Provider(WebfingerProvider, Provider):
|
||||||
"""OAuth2 Provider for generic OAuth and OpenID Connect Applications."""
|
"""OAuth2 Provider for generic OAuth and OpenID Connect Applications."""
|
||||||
|
|
||||||
client_type = models.CharField(
|
client_type = models.CharField(
|
||||||
@ -288,6 +289,24 @@ class OAuth2Provider(Provider):
|
|||||||
key, alg = self.jwt_key
|
key, alg = self.jwt_key
|
||||||
return encode(payload, key, algorithm=alg, headers=headers)
|
return encode(payload, key, algorithm=alg, headers=headers)
|
||||||
|
|
||||||
|
def webfinger(self, resource: str, request: HttpRequest):
|
||||||
|
return {
|
||||||
|
"subject": resource,
|
||||||
|
"links": [
|
||||||
|
{
|
||||||
|
"rel": "http://openid.net/specs/connect/1.0/issuer",
|
||||||
|
"href": request.build_absolute_uri(
|
||||||
|
reverse(
|
||||||
|
"authentik_providers_oauth2:provider-root",
|
||||||
|
kwargs={
|
||||||
|
"application_slug": self.application.slug,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
verbose_name = _("OAuth2/OpenID Provider")
|
verbose_name = _("OAuth2/OpenID Provider")
|
||||||
verbose_name_plural = _("OAuth2/OpenID Providers")
|
verbose_name_plural = _("OAuth2/OpenID Providers")
|
||||||
|
|||||||
Reference in New Issue
Block a user