sources/oauth: introduce authorization code auth method (#14034)

Co-authored-by: Rsgm <rsgm123@gmail.com>
This commit is contained in:
Marc 'risson' Schmitt
2025-04-16 15:00:08 +02:00
committed by GitHub
parent c6f9d5df7b
commit 155a31fd70
22 changed files with 251 additions and 77 deletions

View File

@ -130,6 +130,7 @@ class OAuthSourceSerializer(SourceSerializer):
"oidc_well_known_url", "oidc_well_known_url",
"oidc_jwks_url", "oidc_jwks_url",
"oidc_jwks", "oidc_jwks",
"authorization_code_auth_method",
] ]
extra_kwargs = { extra_kwargs = {
"consumer_secret": {"write_only": True}, "consumer_secret": {"write_only": True},

View File

@ -6,11 +6,15 @@ from urllib.parse import parse_qsl
from django.utils.crypto import constant_time_compare, get_random_string from django.utils.crypto import constant_time_compare, get_random_string
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from requests.auth import AuthBase, HTTPBasicAuth
from requests.exceptions import RequestException from requests.exceptions import RequestException
from requests.models import Response from requests.models import Response
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.sources.oauth.clients.base import BaseOAuthClient from authentik.sources.oauth.clients.base import BaseOAuthClient
from authentik.sources.oauth.models import (
AuthorizationCodeAuthMethod,
)
LOGGER = get_logger() LOGGER = get_logger()
SESSION_KEY_OAUTH_PKCE = "authentik/sources/oauth/pkce" SESSION_KEY_OAUTH_PKCE = "authentik/sources/oauth/pkce"
@ -55,6 +59,30 @@ class OAuth2Client(BaseOAuthClient):
"""Get client secret""" """Get client secret"""
return self.source.consumer_secret return self.source.consumer_secret
def get_access_token_args(self, callback: str, code: str) -> dict[str, Any]:
args = {
"redirect_uri": callback,
"code": code,
"grant_type": "authorization_code",
}
if SESSION_KEY_OAUTH_PKCE in self.request.session:
args["code_verifier"] = self.request.session[SESSION_KEY_OAUTH_PKCE]
if (
self.source.source_type.authorization_code_auth_method
== AuthorizationCodeAuthMethod.POST_BODY
):
args["client_id"] = self.get_client_id()
args["client_secret"] = self.get_client_secret()
return args
def get_access_token_auth(self) -> AuthBase | None:
if (
self.source.source_type.authorization_code_auth_method
== AuthorizationCodeAuthMethod.BASIC_AUTH
):
return HTTPBasicAuth(self.get_client_id(), self.get_client_secret())
return None
def get_access_token(self, **request_kwargs) -> dict[str, Any] | None: def get_access_token(self, **request_kwargs) -> dict[str, Any] | None:
"""Fetch access token from callback request.""" """Fetch access token from callback request."""
callback = self.request.build_absolute_uri(self.callback or self.request.path) callback = self.request.build_absolute_uri(self.callback or self.request.path)
@ -67,13 +95,6 @@ class OAuth2Client(BaseOAuthClient):
error = self.get_request_arg("error", None) error = self.get_request_arg("error", None)
error_desc = self.get_request_arg("error_description", None) error_desc = self.get_request_arg("error_description", None)
return {"error": error_desc or error or _("No token received.")} return {"error": error_desc or error or _("No token received.")}
args = {
"redirect_uri": callback,
"code": code,
"grant_type": "authorization_code",
}
if SESSION_KEY_OAUTH_PKCE in self.request.session:
args["code_verifier"] = self.request.session[SESSION_KEY_OAUTH_PKCE]
try: try:
access_token_url = self.source.source_type.access_token_url or "" access_token_url = self.source.source_type.access_token_url or ""
if self.source.source_type.urls_customizable and self.source.access_token_url: if self.source.source_type.urls_customizable and self.source.access_token_url:
@ -81,8 +102,8 @@ class OAuth2Client(BaseOAuthClient):
response = self.do_request( response = self.do_request(
"post", "post",
access_token_url, access_token_url,
auth=(self.get_client_id(), self.get_client_secret()), auth=self.get_access_token_auth(),
data=args, data=self.get_access_token_args(callback, code),
headers=self._default_headers, headers=self._default_headers,
**request_kwargs, **request_kwargs,
) )

View File

@ -0,0 +1,25 @@
# Generated by Django 5.0.14 on 2025-04-11 18:09
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_sources_oauth", "0009_migrate_useroauthsourceconnection_identifier"),
]
operations = [
migrations.AddField(
model_name="oauthsource",
name="authorization_code_auth_method",
field=models.TextField(
choices=[
("basic_auth", "HTTP Basic Authentication"),
("post_body", "Include the client ID and secret as request parameters"),
],
default="basic_auth",
help_text="How to perform authentication during an authorization_code token request flow",
),
),
]

View File

@ -21,6 +21,11 @@ if TYPE_CHECKING:
from authentik.sources.oauth.types.registry import SourceType from authentik.sources.oauth.types.registry import SourceType
class AuthorizationCodeAuthMethod(models.TextChoices):
BASIC_AUTH = "basic_auth", _("HTTP Basic Authentication")
POST_BODY = "post_body", _("Include the client ID and secret as request parameters")
class OAuthSource(NonCreatableType, Source): class OAuthSource(NonCreatableType, Source):
"""Login using a Generic OAuth provider.""" """Login using a Generic OAuth provider."""
@ -61,6 +66,14 @@ class OAuthSource(NonCreatableType, Source):
oidc_jwks_url = models.TextField(default="", blank=True) oidc_jwks_url = models.TextField(default="", blank=True)
oidc_jwks = models.JSONField(default=dict, blank=True) oidc_jwks = models.JSONField(default=dict, blank=True)
authorization_code_auth_method = models.TextField(
choices=AuthorizationCodeAuthMethod.choices,
default=AuthorizationCodeAuthMethod.BASIC_AUTH,
help_text=_(
"How to perform authentication during an authorization_code token request flow"
),
)
@property @property
def source_type(self) -> type["SourceType"]: def source_type(self) -> type["SourceType"]:
"""Return the provider instance for this source""" """Return the provider instance for this source"""

View File

@ -0,0 +1,69 @@
from django.test import RequestFactory, TestCase
from guardian.shortcuts import get_anonymous_user
from authentik.lib.generators import generate_id
from authentik.sources.oauth.clients.oauth2 import OAuth2Client
from authentik.sources.oauth.models import AuthorizationCodeAuthMethod, OAuthSource
from authentik.sources.oauth.types.oidc import OpenIDConnectClient
class TestOAuthClient(TestCase):
"""OAuth Source tests"""
def setUp(self):
self.source = OAuthSource.objects.create(
name="test",
slug="test",
provider_type="openidconnect",
authorization_url="",
profile_url="",
consumer_key=generate_id(),
)
self.factory = RequestFactory()
def test_client_post_body_auth(self):
"""Test login_challenge"""
self.source.provider_type = "apple"
self.source.save()
request = self.factory.get("/")
request.session = {}
request.user = get_anonymous_user()
client = OAuth2Client(self.source, request)
self.assertIsNone(client.get_access_token_auth())
args = client.get_access_token_args("", "")
self.assertIn("client_id", args)
self.assertIn("client_secret", args)
def test_client_basic_auth(self):
"""Test login_challenge"""
self.source.provider_type = "reddit"
self.source.save()
request = self.factory.get("/")
request.session = {}
request.user = get_anonymous_user()
client = OAuth2Client(self.source, request)
self.assertIsNotNone(client.get_access_token_auth())
args = client.get_access_token_args("", "")
self.assertNotIn("client_id", args)
self.assertNotIn("client_secret", args)
def test_client_openid_auth(self):
"""Test login_challenge"""
request = self.factory.get("/")
request.session = {}
request.user = get_anonymous_user()
client = OpenIDConnectClient(self.source, request)
self.assertIsNotNone(client.get_access_token_auth())
args = client.get_access_token_args("", "")
self.assertNotIn("client_id", args)
self.assertNotIn("client_secret", args)
self.source.authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY
self.source.save()
client = OpenIDConnectClient(self.source, request)
self.assertIsNone(client.get_access_token_auth())
args = client.get_access_token_args("", "")
self.assertIn("client_id", args)
self.assertIn("client_secret", args)

View File

@ -11,7 +11,7 @@ from structlog.stdlib import get_logger
from authentik.flows.challenge import Challenge, ChallengeResponse from authentik.flows.challenge import Challenge, ChallengeResponse
from authentik.sources.oauth.clients.oauth2 import OAuth2Client from authentik.sources.oauth.clients.oauth2 import OAuth2Client
from authentik.sources.oauth.models import OAuthSource from authentik.sources.oauth.models import AuthorizationCodeAuthMethod, OAuthSource
from authentik.sources.oauth.types.registry import SourceType, registry from authentik.sources.oauth.types.registry import SourceType, registry
from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -105,6 +105,8 @@ class AppleType(SourceType):
access_token_url = "https://appleid.apple.com/auth/token" # nosec access_token_url = "https://appleid.apple.com/auth/token" # nosec
profile_url = "" profile_url = ""
authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY
def login_challenge(self, source: OAuthSource, request: HttpRequest) -> Challenge: def login_challenge(self, source: OAuthSource, request: HttpRequest) -> Challenge:
"""Pre-general all the things required for the JS SDK""" """Pre-general all the things required for the JS SDK"""
apple_client = AppleOAuthClient( apple_client = AppleOAuthClient(

View File

@ -6,6 +6,7 @@ from requests import RequestException
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
from authentik.sources.oauth.models import AuthorizationCodeAuthMethod
from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback
from authentik.sources.oauth.types.registry import SourceType, registry from authentik.sources.oauth.types.registry import SourceType, registry
from authentik.sources.oauth.views.redirect import OAuthRedirect from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -77,6 +78,8 @@ class AzureADType(SourceType):
) )
oidc_jwks_url = "https://login.microsoftonline.com/common/discovery/keys" oidc_jwks_url = "https://login.microsoftonline.com/common/discovery/keys"
authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]: def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
mail = info.get("mail", None) or info.get("otherMails", [None])[0] mail = info.get("mail", None) or info.get("otherMails", [None])[0]
# Format group info # Format group info

View File

@ -2,8 +2,8 @@
from typing import Any from typing import Any
from authentik.sources.oauth.models import AuthorizationCodeAuthMethod
from authentik.sources.oauth.types.registry import SourceType, registry from authentik.sources.oauth.types.registry import SourceType, registry
from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -16,15 +16,10 @@ class FacebookOAuthRedirect(OAuthRedirect):
} }
class FacebookOAuth2Callback(OAuthCallback):
"""Facebook OAuth2 Callback"""
@registry.register() @registry.register()
class FacebookType(SourceType): class FacebookType(SourceType):
"""Facebook Type definition""" """Facebook Type definition"""
callback_view = FacebookOAuth2Callback
redirect_view = FacebookOAuthRedirect redirect_view = FacebookOAuthRedirect
verbose_name = "Facebook" verbose_name = "Facebook"
name = "facebook" name = "facebook"
@ -33,6 +28,8 @@ class FacebookType(SourceType):
access_token_url = "https://graph.facebook.com/v7.0/oauth/access_token" # nosec access_token_url = "https://graph.facebook.com/v7.0/oauth/access_token" # nosec
profile_url = "https://graph.facebook.com/v7.0/me?fields=id,name,email" profile_url = "https://graph.facebook.com/v7.0/me?fields=id,name,email"
authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]: def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
return { return {
"username": info.get("name"), "username": info.get("name"),

View File

@ -5,7 +5,7 @@ from typing import Any
from requests.exceptions import RequestException from requests.exceptions import RequestException
from authentik.sources.oauth.clients.oauth2 import OAuth2Client from authentik.sources.oauth.clients.oauth2 import OAuth2Client
from authentik.sources.oauth.models import OAuthSource from authentik.sources.oauth.models import AuthorizationCodeAuthMethod, OAuthSource
from authentik.sources.oauth.types.registry import SourceType, registry from authentik.sources.oauth.types.registry import SourceType, registry
from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -63,6 +63,8 @@ class GitHubType(SourceType):
) )
oidc_jwks_url = "https://token.actions.githubusercontent.com/.well-known/jwks" oidc_jwks_url = "https://token.actions.githubusercontent.com/.well-known/jwks"
authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY
def get_base_user_properties( def get_base_user_properties(
self, self,
source: OAuthSource, source: OAuthSource,

View File

@ -7,9 +7,8 @@ and https://docs.gitlab.com/ee/integration/openid_connect_provider.html
from typing import Any from typing import Any
from authentik.sources.oauth.models import OAuthSource from authentik.sources.oauth.models import AuthorizationCodeAuthMethod, OAuthSource
from authentik.sources.oauth.types.registry import SourceType, registry from authentik.sources.oauth.types.registry import SourceType, registry
from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -22,15 +21,10 @@ class GitLabOAuthRedirect(OAuthRedirect):
} }
class GitLabOAuthCallback(OAuthCallback):
"""GitLab OAuth2 Callback"""
@registry.register() @registry.register()
class GitLabType(SourceType): class GitLabType(SourceType):
"""GitLab Type definition""" """GitLab Type definition"""
callback_view = GitLabOAuthCallback
redirect_view = GitLabOAuthRedirect redirect_view = GitLabOAuthRedirect
verbose_name = "GitLab" verbose_name = "GitLab"
name = "gitlab" name = "gitlab"
@ -43,6 +37,8 @@ class GitLabType(SourceType):
oidc_well_known_url = "https://gitlab.com/.well-known/openid-configuration" oidc_well_known_url = "https://gitlab.com/.well-known/openid-configuration"
oidc_jwks_url = "https://gitlab.com/oauth/discovery/keys" oidc_jwks_url = "https://gitlab.com/oauth/discovery/keys"
authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]: def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
return { return {
"username": info.get("preferred_username"), "username": info.get("preferred_username"),

View File

@ -2,8 +2,8 @@
from typing import Any from typing import Any
from authentik.sources.oauth.models import AuthorizationCodeAuthMethod
from authentik.sources.oauth.types.registry import SourceType, registry from authentik.sources.oauth.types.registry import SourceType, registry
from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -16,15 +16,10 @@ class GoogleOAuthRedirect(OAuthRedirect):
} }
class GoogleOAuth2Callback(OAuthCallback):
"""Google OAuth2 Callback"""
@registry.register() @registry.register()
class GoogleType(SourceType): class GoogleType(SourceType):
"""Google Type definition""" """Google Type definition"""
callback_view = GoogleOAuth2Callback
redirect_view = GoogleOAuthRedirect redirect_view = GoogleOAuthRedirect
verbose_name = "Google" verbose_name = "Google"
name = "google" name = "google"
@ -35,6 +30,8 @@ class GoogleType(SourceType):
oidc_well_known_url = "https://accounts.google.com/.well-known/openid-configuration" oidc_well_known_url = "https://accounts.google.com/.well-known/openid-configuration"
oidc_jwks_url = "https://www.googleapis.com/oauth2/v3/certs" oidc_jwks_url = "https://www.googleapis.com/oauth2/v3/certs"
authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]: def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
return { return {
"email": info.get("email"), "email": info.get("email"),

View File

@ -6,6 +6,7 @@ from requests.exceptions import RequestException
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.sources.oauth.clients.oauth2 import OAuth2Client from authentik.sources.oauth.clients.oauth2 import OAuth2Client
from authentik.sources.oauth.models import AuthorizationCodeAuthMethod
from authentik.sources.oauth.types.registry import SourceType, registry from authentik.sources.oauth.types.registry import SourceType, registry
from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -59,6 +60,8 @@ class MailcowType(SourceType):
urls_customizable = True urls_customizable = True
authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]: def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
return { return {
"username": info.get("full_name"), "username": info.get("full_name"),

View File

@ -2,8 +2,10 @@
from typing import Any from typing import Any
from requests.auth import AuthBase, HTTPBasicAuth
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
from authentik.sources.oauth.models import OAuthSource from authentik.sources.oauth.models import AuthorizationCodeAuthMethod, OAuthSource
from authentik.sources.oauth.types.registry import SourceType, registry from authentik.sources.oauth.types.registry import SourceType, registry
from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -18,10 +20,27 @@ class OpenIDConnectOAuthRedirect(OAuthRedirect):
} }
class OpenIDConnectClient(UserprofileHeaderAuthClient):
def get_access_token_args(self, callback: str, code: str) -> dict[str, Any]:
args = super().get_access_token_args(callback, code)
if self.source.authorization_code_auth_method == AuthorizationCodeAuthMethod.POST_BODY:
args["client_id"] = self.get_client_id()
args["client_secret"] = self.get_client_secret()
else:
args.pop("client_id", None)
args.pop("client_secret", None)
return args
def get_access_token_auth(self) -> AuthBase | None:
if self.source.authorization_code_auth_method == AuthorizationCodeAuthMethod.BASIC_AUTH:
return HTTPBasicAuth(self.get_client_id(), self.get_client_secret())
return None
class OpenIDConnectOAuth2Callback(OAuthCallback): class OpenIDConnectOAuth2Callback(OAuthCallback):
"""OpenIDConnect OAuth2 Callback""" """OpenIDConnect OAuth2 Callback"""
client_class = UserprofileHeaderAuthClient client_class = OpenIDConnectClient
def get_user_id(self, info: dict[str, str]) -> str: def get_user_id(self, info: dict[str, str]) -> str:
return info.get("sub", None) return info.get("sub", None)

View File

@ -2,7 +2,6 @@
from typing import Any from typing import Any
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
from authentik.sources.oauth.models import OAuthSource from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback
from authentik.sources.oauth.types.registry import SourceType, registry from authentik.sources.oauth.types.registry import SourceType, registry
@ -18,20 +17,11 @@ class OktaOAuthRedirect(OAuthRedirect):
} }
class OktaOAuth2Callback(OpenIDConnectOAuth2Callback):
"""Okta OAuth2 Callback"""
# Okta has the same quirk as azure and throws an error if the access token
# is set via query parameter, so we reuse the azure client
# see https://github.com/goauthentik/authentik/issues/1910
client_class = UserprofileHeaderAuthClient
@registry.register() @registry.register()
class OktaType(SourceType): class OktaType(SourceType):
"""Okta Type definition""" """Okta Type definition"""
callback_view = OktaOAuth2Callback callback_view = OpenIDConnectOAuth2Callback
redirect_view = OktaOAuthRedirect redirect_view = OktaOAuthRedirect
verbose_name = "Okta" verbose_name = "Okta"
name = "okta" name = "okta"

View File

@ -3,7 +3,7 @@
from typing import Any from typing import Any
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
from authentik.sources.oauth.models import OAuthSource from authentik.sources.oauth.models import AuthorizationCodeAuthMethod, OAuthSource
from authentik.sources.oauth.types.registry import SourceType, registry from authentik.sources.oauth.types.registry import SourceType, registry
from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -41,6 +41,8 @@ class PatreonType(SourceType):
access_token_url = "https://www.patreon.com/api/oauth2/token" # nosec access_token_url = "https://www.patreon.com/api/oauth2/token" # nosec
profile_url = "https://www.patreon.com/api/oauth2/api/current_user" profile_url = "https://www.patreon.com/api/oauth2/api/current_user"
authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]: def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
return { return {
"username": info.get("data", {}).get("attributes", {}).get("vanity"), "username": info.get("data", {}).get("attributes", {}).get("vanity"),

View File

@ -2,8 +2,6 @@
from typing import Any from typing import Any
from requests.auth import HTTPBasicAuth
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
from authentik.sources.oauth.types.registry import SourceType, registry from authentik.sources.oauth.types.registry import SourceType, registry
from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.callback import OAuthCallback
@ -20,21 +18,10 @@ class RedditOAuthRedirect(OAuthRedirect):
} }
class RedditOAuth2Client(UserprofileHeaderAuthClient):
"""Reddit OAuth2 Client"""
def get_access_token(self, **request_kwargs):
"Fetch access token from callback request."
request_kwargs["auth"] = HTTPBasicAuth(
self.source.consumer_key, self.source.consumer_secret
)
return super().get_access_token(**request_kwargs)
class RedditOAuth2Callback(OAuthCallback): class RedditOAuth2Callback(OAuthCallback):
"""Reddit OAuth2 Callback""" """Reddit OAuth2 Callback"""
client_class = RedditOAuth2Client client_class = UserprofileHeaderAuthClient
@registry.register() @registry.register()

View File

@ -10,7 +10,7 @@ from django.urls.base import reverse
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.flows.challenge import Challenge, RedirectChallenge from authentik.flows.challenge import Challenge, RedirectChallenge
from authentik.sources.oauth.models import OAuthSource from authentik.sources.oauth.models import AuthorizationCodeAuthMethod, OAuthSource
from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -41,6 +41,10 @@ class SourceType:
oidc_well_known_url: str | None = None oidc_well_known_url: str | None = None
oidc_jwks_url: str | None = None oidc_jwks_url: str | None = None
authorization_code_auth_method: AuthorizationCodeAuthMethod = (
AuthorizationCodeAuthMethod.BASIC_AUTH
)
def icon_url(self) -> str: def icon_url(self) -> str:
"""Get Icon URL for login""" """Get Icon URL for login"""
return static(f"authentik/sources/{self.name}.svg") return static(f"authentik/sources/{self.name}.svg")

View File

@ -4,6 +4,7 @@ from json import dumps
from typing import Any from typing import Any
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
from authentik.sources.oauth.models import AuthorizationCodeAuthMethod
from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback
from authentik.sources.oauth.types.registry import SourceType, registry from authentik.sources.oauth.types.registry import SourceType, registry
from authentik.sources.oauth.views.redirect import OAuthRedirect from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -47,6 +48,8 @@ class TwitchType(SourceType):
access_token_url = "https://id.twitch.tv/oauth2/token" # nosec access_token_url = "https://id.twitch.tv/oauth2/token" # nosec
profile_url = "https://id.twitch.tv/oauth2/userinfo" profile_url = "https://id.twitch.tv/oauth2/userinfo"
authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]: def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
return { return {
"username": info.get("preferred_username"), "username": info.get("preferred_username"),

View File

@ -12,23 +12,6 @@ from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect from authentik.sources.oauth.views.redirect import OAuthRedirect
class TwitterClient(UserprofileHeaderAuthClient):
"""Twitter has similar quirks to Azure AD, and additionally requires Basic auth on
the access token endpoint for some reason."""
# Twitter has the same quirk as azure and throws an error if the access token
# is set via query parameter, so we reuse the azure client
# see https://github.com/goauthentik/authentik/issues/1910
def get_access_token(self, **request_kwargs) -> dict[str, Any] | None:
return super().get_access_token(
auth=(
self.source.consumer_key,
self.source.consumer_secret,
)
)
class TwitterOAuthRedirect(OAuthRedirect): class TwitterOAuthRedirect(OAuthRedirect):
"""Twitter OAuth2 Redirect""" """Twitter OAuth2 Redirect"""
@ -44,7 +27,7 @@ class TwitterOAuthRedirect(OAuthRedirect):
class TwitterOAuthCallback(OAuthCallback): class TwitterOAuthCallback(OAuthCallback):
"""Twitter OAuth2 Callback""" """Twitter OAuth2 Callback"""
client_class = TwitterClient client_class = UserprofileHeaderAuthClient
def get_user_id(self, info: dict[str, str]) -> str: def get_user_id(self, info: dict[str, str]) -> str:
return info.get("data", {}).get("id", "") return info.get("data", {}).get("id", "")

View File

@ -8436,6 +8436,15 @@
"type": "object", "type": "object",
"additionalProperties": true, "additionalProperties": true,
"title": "Oidc jwks" "title": "Oidc jwks"
},
"authorization_code_auth_method": {
"type": "string",
"enum": [
"basic_auth",
"post_body"
],
"title": "Authorization code auth method",
"description": "How to perform authentication during an authorization_code token request flow"
} }
}, },
"required": [] "required": []

View File

@ -42110,6 +42110,11 @@ components:
format: uuid format: uuid
required: required:
- name - name
AuthorizationCodeAuthMethodEnum:
enum:
- basic_auth
- post_body
type: string
AutoSubmitChallengeResponseRequest: AutoSubmitChallengeResponseRequest:
type: object type: object
description: Pseudo class for autosubmit response description: Pseudo class for autosubmit response
@ -48742,6 +48747,11 @@ components:
oidc_jwks_url: oidc_jwks_url:
type: string type: string
oidc_jwks: {} oidc_jwks: {}
authorization_code_auth_method:
allOf:
- $ref: '#/components/schemas/AuthorizationCodeAuthMethodEnum'
description: How to perform authentication during an authorization_code
token request flow
required: required:
- callback_url - callback_url
- component - component
@ -48911,6 +48921,11 @@ components:
oidc_jwks_url: oidc_jwks_url:
type: string type: string
oidc_jwks: {} oidc_jwks: {}
authorization_code_auth_method:
allOf:
- $ref: '#/components/schemas/AuthorizationCodeAuthMethodEnum'
description: How to perform authentication during an authorization_code
token request flow
required: required:
- consumer_key - consumer_key
- consumer_secret - consumer_secret
@ -53009,6 +53024,11 @@ components:
oidc_jwks_url: oidc_jwks_url:
type: string type: string
oidc_jwks: {} oidc_jwks: {}
authorization_code_auth_method:
allOf:
- $ref: '#/components/schemas/AuthorizationCodeAuthMethodEnum'
description: How to perform authentication during an authorization_code
token request flow
PatchedOutpostRequest: PatchedOutpostRequest:
type: object type: object
description: Outpost Serializer description: Outpost Serializer

View File

@ -7,6 +7,7 @@ import {
} from "@goauthentik/admin/sources/oauth/utils"; } from "@goauthentik/admin/sources/oauth/utils";
import { DEFAULT_CONFIG, config } from "@goauthentik/common/api/config"; import { DEFAULT_CONFIG, config } from "@goauthentik/common/api/config";
import { first } from "@goauthentik/common/utils"; import { first } from "@goauthentik/common/utils";
import "@goauthentik/components/ak-radio-input";
import "@goauthentik/elements/CodeMirror"; import "@goauthentik/elements/CodeMirror";
import { CodeMirrorMode } from "@goauthentik/elements/CodeMirror"; import { CodeMirrorMode } from "@goauthentik/elements/CodeMirror";
import { import {
@ -16,6 +17,7 @@ import {
import "@goauthentik/elements/ak-dual-select/ak-dual-select-dynamic-selected-provider.js"; import "@goauthentik/elements/ak-dual-select/ak-dual-select-dynamic-selected-provider.js";
import "@goauthentik/elements/forms/FormGroup"; import "@goauthentik/elements/forms/FormGroup";
import "@goauthentik/elements/forms/HorizontalFormElement"; import "@goauthentik/elements/forms/HorizontalFormElement";
import "@goauthentik/elements/forms/Radio";
import "@goauthentik/elements/forms/SearchSelect"; import "@goauthentik/elements/forms/SearchSelect";
import { msg } from "@lit/localize"; import { msg } from "@lit/localize";
@ -24,6 +26,7 @@ import { customElement, property, state } from "lit/decorators.js";
import { ifDefined } from "lit/directives/if-defined.js"; import { ifDefined } from "lit/directives/if-defined.js";
import { import {
AuthorizationCodeAuthMethodEnum,
FlowsInstancesListDesignationEnum, FlowsInstancesListDesignationEnum,
GroupMatchingModeEnum, GroupMatchingModeEnum,
OAuthSource, OAuthSource,
@ -36,6 +39,18 @@ import {
import { propertyMappingsProvider, propertyMappingsSelector } from "./OAuthSourceFormHelpers.js"; import { propertyMappingsProvider, propertyMappingsSelector } from "./OAuthSourceFormHelpers.js";
const authorizationCodeAuthMethodOptions = [
{
label: msg("HTTP Basic Auth"),
value: AuthorizationCodeAuthMethodEnum.BasicAuth,
default: true,
},
{
label: msg("Include the client ID and secret as request parameters"),
value: AuthorizationCodeAuthMethodEnum.PostBody,
},
];
@customElement("ak-source-oauth-form") @customElement("ak-source-oauth-form")
export class OAuthSourceForm extends WithCapabilitiesConfig(BaseSourceForm<OAuthSource>) { export class OAuthSourceForm extends WithCapabilitiesConfig(BaseSourceForm<OAuthSource>) {
async loadInstance(pk: string): Promise<OAuthSource> { async loadInstance(pk: string): Promise<OAuthSource> {
@ -240,6 +255,19 @@ export class OAuthSourceForm extends WithCapabilitiesConfig(BaseSourceForm<OAuth
<p class="pf-c-form__helper-text">${msg("Raw JWKS data.")}</p> <p class="pf-c-form__helper-text">${msg("Raw JWKS data.")}</p>
</ak-form-element-horizontal>` </ak-form-element-horizontal>`
: html``} : html``}
${this.providerType.name === ProviderTypeEnum.Openidconnect
? html`<ak-radio-input
label=${msg("Authorization code authentication method")}
name="authorizationCodeAuthMethod"
required
.options=${authorizationCodeAuthMethodOptions}
.value=${this.instance?.authorizationCodeAuthMethod}
help=${msg(
"How to perform authentication during an authorization_code token request flow",
)}
>
</ak-radio-input>`
: html``}
</div> </div>
</ak-form-group>`; </ak-form-group>`;
} }