brands: add OIDC webfinger support (#10400)

* brands: add OIDC webfinger support for default application

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add tests

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L.
2024-08-06 14:00:01 +02:00
committed by GitHub
parent ae88ea3543
commit 3d63143c38
7 changed files with 118 additions and 1 deletions

View File

@ -9,3 +9,6 @@ class AuthentikBrandsConfig(AppConfig):
name = "authentik.brands"
label = "authentik_brands"
verbose_name = "authentik Brands"
mountpoints = {
"authentik.brands.urls_root": "",
}

View File

@ -3,6 +3,7 @@
from uuid import uuid4
from django.db import models
from django.http import HttpRequest
from django.utils.translation import gettext_lazy as _
from rest_framework.serializers import Serializer
from structlog.stdlib import get_logger
@ -98,3 +99,13 @@ class Brand(SerializerModel):
models.Index(fields=["domain"]),
models.Index(fields=["default"]),
]
class WebfingerProvider(models.Model):
"""Provider which supports webfinger discovery"""
class Meta:
abstract = True
def webfinger(self, resource: str, request: HttpRequest) -> dict:
raise NotImplementedError()

View File

@ -5,7 +5,11 @@ from rest_framework.test import APITestCase
from authentik.brands.api import Themes
from authentik.brands.models import Brand
from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user, create_test_brand
from authentik.lib.generators import generate_id
from authentik.providers.oauth2.models import OAuth2Provider
from authentik.providers.saml.models import SAMLProvider
class TestBrands(APITestCase):
@ -75,3 +79,45 @@ class TestBrands(APITestCase):
reverse("authentik_api:brand-list"), data={"domain": "bar", "default": True}
)
self.assertEqual(response.status_code, 400)
def test_webfinger_no_app(self):
"""Test Webfinger"""
create_test_brand()
self.assertJSONEqual(
self.client.get(reverse("authentik_brands:webfinger")).content.decode(), {}
)
def test_webfinger_not_supported(self):
"""Test Webfinger"""
brand = create_test_brand()
provider = SAMLProvider.objects.create(
name=generate_id(),
)
app = Application.objects.create(name=generate_id(), slug=generate_id(), provider=provider)
brand.default_application = app
brand.save()
self.assertJSONEqual(
self.client.get(reverse("authentik_brands:webfinger")).content.decode(), {}
)
def test_webfinger_oidc(self):
"""Test Webfinger"""
brand = create_test_brand()
provider = OAuth2Provider.objects.create(
name=generate_id(),
)
app = Application.objects.create(name=generate_id(), slug=generate_id(), provider=provider)
brand.default_application = app
brand.save()
self.assertJSONEqual(
self.client.get(reverse("authentik_brands:webfinger")).content.decode(),
{
"links": [
{
"href": f"http://testserver/application/o/{app.slug}/",
"rel": "http://openid.net/specs/connect/1.0/issuer",
}
],
"subject": None,
},
)

View File

@ -0,0 +1,9 @@
"""authentik brand root URLs"""
from django.urls import path
from authentik.brands.views.webfinger import WebFingerView
urlpatterns = [
path(".well-known/webfinger", WebFingerView.as_view(), name="webfinger"),
]

View File

View File

@ -0,0 +1,29 @@
from typing import Any
from django.http import HttpRequest, HttpResponse, JsonResponse
from django.views import View
from authentik.brands.models import Brand, WebfingerProvider
from authentik.core.models import Application
class WebFingerView(View):
"""Webfinger endpoint"""
def get(self, request: HttpRequest) -> HttpResponse:
brand: Brand = request.brand
if not brand.default_application:
return JsonResponse({})
application: Application = brand.default_application
provider = application.get_provider()
if not provider or not isinstance(provider, WebfingerProvider):
return JsonResponse({})
webfinger_data = provider.webfinger(request.GET.get("resource"), request)
return JsonResponse(webfinger_data)
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
response = super().dispatch(request, *args, **kwargs)
# RFC7033 spec
response["Access-Control-Allow-Origin"] = "*"
response["Content-Type"] = "application/jrd+json"
return response

View File

@ -22,6 +22,7 @@ from jwt import encode
from rest_framework.serializers import Serializer
from structlog.stdlib import get_logger
from authentik.brands.models import WebfingerProvider
from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User
from authentik.crypto.models import CertificateKeyPair
from authentik.lib.generators import generate_code_fixed_length, generate_id, generate_key
@ -120,7 +121,7 @@ class ScopeMapping(PropertyMapping):
verbose_name_plural = _("Scope Mappings")
class OAuth2Provider(Provider):
class OAuth2Provider(WebfingerProvider, Provider):
"""OAuth2 Provider for generic OAuth and OpenID Connect Applications."""
client_type = models.CharField(
@ -288,6 +289,24 @@ class OAuth2Provider(Provider):
key, alg = self.jwt_key
return encode(payload, key, algorithm=alg, headers=headers)
def webfinger(self, resource: str, request: HttpRequest):
return {
"subject": resource,
"links": [
{
"rel": "http://openid.net/specs/connect/1.0/issuer",
"href": request.build_absolute_uri(
reverse(
"authentik_providers_oauth2:provider-root",
kwargs={
"application_slug": self.application.slug,
},
)
),
},
],
}
class Meta:
verbose_name = _("OAuth2/OpenID Provider")
verbose_name_plural = _("OAuth2/OpenID Providers")