From cf6c3c6d3f3c1128ced8fa4c03cbcdd67ce1e043 Mon Sep 17 00:00:00 2001 From: "Jens L." Date: Wed, 13 Nov 2024 21:45:16 +0100 Subject: [PATCH] providers/oauth2: fix manual device code entry (#12017) * providers/oauth2: fix manual device code entry Signed-off-by: Jens Langhammer * make code input a char field to prevent leading 0s from being cut off Signed-off-by: Jens Langhammer --------- Signed-off-by: Jens Langhammer --- .../oauth2/tests/test_device_init.py | 76 ++++++++++++++++++- .../providers/oauth2/views/device_init.py | 7 +- schema.yml | 3 +- 3 files changed, 82 insertions(+), 4 deletions(-) diff --git a/authentik/providers/oauth2/tests/test_device_init.py b/authentik/providers/oauth2/tests/test_device_init.py index e691e17d49..e503bc0152 100644 --- a/authentik/providers/oauth2/tests/test_device_init.py +++ b/authentik/providers/oauth2/tests/test_device_init.py @@ -3,6 +3,7 @@ from urllib.parse import urlencode from django.urls import reverse +from rest_framework.test import APIClient from authentik.core.models import Application, Group from authentik.core.tests.utils import create_test_admin_user, create_test_brand, create_test_flow @@ -34,7 +35,10 @@ class TesOAuth2DeviceInit(OAuthTestCase): self.brand.flow_device_code = self.device_flow self.brand.save() - def test_device_init(self): + self.api_client = APIClient() + self.api_client.force_login(self.user) + + def test_device_init_get(self): """Test device init""" res = self.client.get(reverse("authentik_providers_oauth2_root:device-login")) self.assertEqual(res.status_code, 302) @@ -48,6 +52,76 @@ class TesOAuth2DeviceInit(OAuthTestCase): ), ) + def test_device_init_post(self): + """Test device init""" + res = self.api_client.get(reverse("authentik_providers_oauth2_root:device-login")) + self.assertEqual(res.status_code, 302) + self.assertEqual( + res.url, + reverse( + "authentik_core:if-flow", + kwargs={ + "flow_slug": self.device_flow.slug, + }, + ), + ) + res = self.api_client.get( + reverse( + "authentik_api:flow-executor", + kwargs={ + "flow_slug": self.device_flow.slug, + }, + ), + ) + self.assertEqual(res.status_code, 200) + self.assertJSONEqual( + res.content, + { + "component": "ak-provider-oauth2-device-code", + "flow_info": { + "background": "/static/dist/assets/images/flow_background.jpg", + "cancel_url": "/flows/-/cancel/", + "layout": "stacked", + "title": self.device_flow.title, + }, + }, + ) + + provider = OAuth2Provider.objects.create( + name=generate_id(), + authorization_flow=create_test_flow(), + ) + Application.objects.create(name=generate_id(), slug=generate_id(), provider=provider) + token = DeviceToken.objects.create( + provider=provider, + ) + + res = self.api_client.post( + reverse( + "authentik_api:flow-executor", + kwargs={ + "flow_slug": self.device_flow.slug, + }, + ), + data={ + "component": "ak-provider-oauth2-device-code", + "code": token.user_code, + }, + ) + self.assertEqual(res.status_code, 200) + self.assertJSONEqual( + res.content, + { + "component": "xak-flow-redirect", + "to": reverse( + "authentik_core:if-flow", + kwargs={ + "flow_slug": provider.authorization_flow.slug, + }, + ), + }, + ) + def test_no_flow(self): """Test no flow""" self.brand.flow_device_code = None diff --git a/authentik/providers/oauth2/views/device_init.py b/authentik/providers/oauth2/views/device_init.py index ffbce26b5b..85b32d8051 100644 --- a/authentik/providers/oauth2/views/device_init.py +++ b/authentik/providers/oauth2/views/device_init.py @@ -5,7 +5,7 @@ from typing import Any from django.http import HttpRequest, HttpResponse from django.utils.translation import gettext as _ from rest_framework.exceptions import ValidationError -from rest_framework.fields import CharField, IntegerField +from rest_framework.fields import CharField from structlog.stdlib import get_logger from authentik.brands.models import Brand @@ -47,6 +47,9 @@ class CodeValidatorView(PolicyAccessView): self.provider = self.token.provider self.application = self.token.provider.application + def post(self, request: HttpRequest, *args, **kwargs): + return self.get(request, *args, **kwargs) + def get(self, request: HttpRequest, *args, **kwargs): scope_descriptions = UserInfoView().get_scope_descriptions(self.token.scope, self.provider) planner = FlowPlanner(self.provider.authorization_flow) @@ -122,7 +125,7 @@ class OAuthDeviceCodeChallenge(Challenge): class OAuthDeviceCodeChallengeResponse(ChallengeResponse): """Response that includes the user-entered device code""" - code = IntegerField() + code = CharField() component = CharField(default="ak-provider-oauth2-device-code") def validate_code(self, code: int) -> HttpResponse | None: diff --git a/schema.yml b/schema.yml index a46794370b..2ffda1d4e3 100644 --- a/schema.yml +++ b/schema.yml @@ -44957,7 +44957,8 @@ components: minLength: 1 default: ak-provider-oauth2-device-code code: - type: integer + type: string + minLength: 1 required: - code OAuthDeviceCodeFinishChallenge: