brands: add OIDC webfinger support (#10400)

* brands: add OIDC webfinger support for default application

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add tests

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L.
2024-08-06 14:00:01 +02:00
committed by GitHub
parent ae88ea3543
commit 3d63143c38
7 changed files with 118 additions and 1 deletions

View File

@ -9,3 +9,6 @@ class AuthentikBrandsConfig(AppConfig):
name = "authentik.brands" name = "authentik.brands"
label = "authentik_brands" label = "authentik_brands"
verbose_name = "authentik Brands" verbose_name = "authentik Brands"
mountpoints = {
"authentik.brands.urls_root": "",
}

View File

@ -3,6 +3,7 @@
from uuid import uuid4 from uuid import uuid4
from django.db import models from django.db import models
from django.http import HttpRequest
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from rest_framework.serializers import Serializer from rest_framework.serializers import Serializer
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
@ -98,3 +99,13 @@ class Brand(SerializerModel):
models.Index(fields=["domain"]), models.Index(fields=["domain"]),
models.Index(fields=["default"]), models.Index(fields=["default"]),
] ]
class WebfingerProvider(models.Model):
"""Provider which supports webfinger discovery"""
class Meta:
abstract = True
def webfinger(self, resource: str, request: HttpRequest) -> dict:
raise NotImplementedError()

View File

@ -5,7 +5,11 @@ from rest_framework.test import APITestCase
from authentik.brands.api import Themes from authentik.brands.api import Themes
from authentik.brands.models import Brand from authentik.brands.models import Brand
from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user, create_test_brand from authentik.core.tests.utils import create_test_admin_user, create_test_brand
from authentik.lib.generators import generate_id
from authentik.providers.oauth2.models import OAuth2Provider
from authentik.providers.saml.models import SAMLProvider
class TestBrands(APITestCase): class TestBrands(APITestCase):
@ -75,3 +79,45 @@ class TestBrands(APITestCase):
reverse("authentik_api:brand-list"), data={"domain": "bar", "default": True} reverse("authentik_api:brand-list"), data={"domain": "bar", "default": True}
) )
self.assertEqual(response.status_code, 400) self.assertEqual(response.status_code, 400)
def test_webfinger_no_app(self):
"""Test Webfinger"""
create_test_brand()
self.assertJSONEqual(
self.client.get(reverse("authentik_brands:webfinger")).content.decode(), {}
)
def test_webfinger_not_supported(self):
"""Test Webfinger"""
brand = create_test_brand()
provider = SAMLProvider.objects.create(
name=generate_id(),
)
app = Application.objects.create(name=generate_id(), slug=generate_id(), provider=provider)
brand.default_application = app
brand.save()
self.assertJSONEqual(
self.client.get(reverse("authentik_brands:webfinger")).content.decode(), {}
)
def test_webfinger_oidc(self):
"""Test Webfinger"""
brand = create_test_brand()
provider = OAuth2Provider.objects.create(
name=generate_id(),
)
app = Application.objects.create(name=generate_id(), slug=generate_id(), provider=provider)
brand.default_application = app
brand.save()
self.assertJSONEqual(
self.client.get(reverse("authentik_brands:webfinger")).content.decode(),
{
"links": [
{
"href": f"http://testserver/application/o/{app.slug}/",
"rel": "http://openid.net/specs/connect/1.0/issuer",
}
],
"subject": None,
},
)

View File

@ -0,0 +1,9 @@
"""authentik brand root URLs"""
from django.urls import path
from authentik.brands.views.webfinger import WebFingerView
urlpatterns = [
path(".well-known/webfinger", WebFingerView.as_view(), name="webfinger"),
]

View File

View File

@ -0,0 +1,29 @@
from typing import Any
from django.http import HttpRequest, HttpResponse, JsonResponse
from django.views import View
from authentik.brands.models import Brand, WebfingerProvider
from authentik.core.models import Application
class WebFingerView(View):
"""Webfinger endpoint"""
def get(self, request: HttpRequest) -> HttpResponse:
brand: Brand = request.brand
if not brand.default_application:
return JsonResponse({})
application: Application = brand.default_application
provider = application.get_provider()
if not provider or not isinstance(provider, WebfingerProvider):
return JsonResponse({})
webfinger_data = provider.webfinger(request.GET.get("resource"), request)
return JsonResponse(webfinger_data)
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
response = super().dispatch(request, *args, **kwargs)
# RFC7033 spec
response["Access-Control-Allow-Origin"] = "*"
response["Content-Type"] = "application/jrd+json"
return response

View File

@ -22,6 +22,7 @@ from jwt import encode
from rest_framework.serializers import Serializer from rest_framework.serializers import Serializer
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.brands.models import WebfingerProvider
from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User
from authentik.crypto.models import CertificateKeyPair from authentik.crypto.models import CertificateKeyPair
from authentik.lib.generators import generate_code_fixed_length, generate_id, generate_key from authentik.lib.generators import generate_code_fixed_length, generate_id, generate_key
@ -120,7 +121,7 @@ class ScopeMapping(PropertyMapping):
verbose_name_plural = _("Scope Mappings") verbose_name_plural = _("Scope Mappings")
class OAuth2Provider(Provider): class OAuth2Provider(WebfingerProvider, Provider):
"""OAuth2 Provider for generic OAuth and OpenID Connect Applications.""" """OAuth2 Provider for generic OAuth and OpenID Connect Applications."""
client_type = models.CharField( client_type = models.CharField(
@ -288,6 +289,24 @@ class OAuth2Provider(Provider):
key, alg = self.jwt_key key, alg = self.jwt_key
return encode(payload, key, algorithm=alg, headers=headers) return encode(payload, key, algorithm=alg, headers=headers)
def webfinger(self, resource: str, request: HttpRequest):
return {
"subject": resource,
"links": [
{
"rel": "http://openid.net/specs/connect/1.0/issuer",
"href": request.build_absolute_uri(
reverse(
"authentik_providers_oauth2:provider-root",
kwargs={
"application_slug": self.application.slug,
},
)
),
},
],
}
class Meta: class Meta:
verbose_name = _("OAuth2/OpenID Provider") verbose_name = _("OAuth2/OpenID Provider")
verbose_name_plural = _("OAuth2/OpenID Providers") verbose_name_plural = _("OAuth2/OpenID Providers")