sources/oauth: introduce authorization code auth method (#14034)
Co-authored-by: Rsgm <rsgm123@gmail.com>
This commit is contained in:
		 Marc 'risson' Schmitt
					Marc 'risson' Schmitt
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							c6f9d5df7b
						
					
				
				
					commit
					155a31fd70
				
			| @ -130,6 +130,7 @@ class OAuthSourceSerializer(SourceSerializer): | ||||
|             "oidc_well_known_url", | ||||
|             "oidc_jwks_url", | ||||
|             "oidc_jwks", | ||||
|             "authorization_code_auth_method", | ||||
|         ] | ||||
|         extra_kwargs = { | ||||
|             "consumer_secret": {"write_only": True}, | ||||
|  | ||||
| @ -6,11 +6,15 @@ from urllib.parse import parse_qsl | ||||
|  | ||||
| from django.utils.crypto import constant_time_compare, get_random_string | ||||
| from django.utils.translation import gettext as _ | ||||
| from requests.auth import AuthBase, HTTPBasicAuth | ||||
| from requests.exceptions import RequestException | ||||
| from requests.models import Response | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.sources.oauth.clients.base import BaseOAuthClient | ||||
| from authentik.sources.oauth.models import ( | ||||
|     AuthorizationCodeAuthMethod, | ||||
| ) | ||||
|  | ||||
| LOGGER = get_logger() | ||||
| SESSION_KEY_OAUTH_PKCE = "authentik/sources/oauth/pkce" | ||||
| @ -55,6 +59,30 @@ class OAuth2Client(BaseOAuthClient): | ||||
|         """Get client secret""" | ||||
|         return self.source.consumer_secret | ||||
|  | ||||
|     def get_access_token_args(self, callback: str, code: str) -> dict[str, Any]: | ||||
|         args = { | ||||
|             "redirect_uri": callback, | ||||
|             "code": code, | ||||
|             "grant_type": "authorization_code", | ||||
|         } | ||||
|         if SESSION_KEY_OAUTH_PKCE in self.request.session: | ||||
|             args["code_verifier"] = self.request.session[SESSION_KEY_OAUTH_PKCE] | ||||
|         if ( | ||||
|             self.source.source_type.authorization_code_auth_method | ||||
|             == AuthorizationCodeAuthMethod.POST_BODY | ||||
|         ): | ||||
|             args["client_id"] = self.get_client_id() | ||||
|             args["client_secret"] = self.get_client_secret() | ||||
|         return args | ||||
|  | ||||
|     def get_access_token_auth(self) -> AuthBase | None: | ||||
|         if ( | ||||
|             self.source.source_type.authorization_code_auth_method | ||||
|             == AuthorizationCodeAuthMethod.BASIC_AUTH | ||||
|         ): | ||||
|             return HTTPBasicAuth(self.get_client_id(), self.get_client_secret()) | ||||
|         return None | ||||
|  | ||||
|     def get_access_token(self, **request_kwargs) -> dict[str, Any] | None: | ||||
|         """Fetch access token from callback request.""" | ||||
|         callback = self.request.build_absolute_uri(self.callback or self.request.path) | ||||
| @ -67,13 +95,6 @@ class OAuth2Client(BaseOAuthClient): | ||||
|             error = self.get_request_arg("error", None) | ||||
|             error_desc = self.get_request_arg("error_description", None) | ||||
|             return {"error": error_desc or error or _("No token received.")} | ||||
|         args = { | ||||
|             "redirect_uri": callback, | ||||
|             "code": code, | ||||
|             "grant_type": "authorization_code", | ||||
|         } | ||||
|         if SESSION_KEY_OAUTH_PKCE in self.request.session: | ||||
|             args["code_verifier"] = self.request.session[SESSION_KEY_OAUTH_PKCE] | ||||
|         try: | ||||
|             access_token_url = self.source.source_type.access_token_url or "" | ||||
|             if self.source.source_type.urls_customizable and self.source.access_token_url: | ||||
| @ -81,8 +102,8 @@ class OAuth2Client(BaseOAuthClient): | ||||
|             response = self.do_request( | ||||
|                 "post", | ||||
|                 access_token_url, | ||||
|                 auth=(self.get_client_id(), self.get_client_secret()), | ||||
|                 data=args, | ||||
|                 auth=self.get_access_token_auth(), | ||||
|                 data=self.get_access_token_args(callback, code), | ||||
|                 headers=self._default_headers, | ||||
|                 **request_kwargs, | ||||
|             ) | ||||
|  | ||||
| @ -0,0 +1,25 @@ | ||||
| # Generated by Django 5.0.14 on 2025-04-11 18:09 | ||||
|  | ||||
| from django.db import migrations, models | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("authentik_sources_oauth", "0009_migrate_useroauthsourceconnection_identifier"), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.AddField( | ||||
|             model_name="oauthsource", | ||||
|             name="authorization_code_auth_method", | ||||
|             field=models.TextField( | ||||
|                 choices=[ | ||||
|                     ("basic_auth", "HTTP Basic Authentication"), | ||||
|                     ("post_body", "Include the client ID and secret as request parameters"), | ||||
|                 ], | ||||
|                 default="basic_auth", | ||||
|                 help_text="How to perform authentication during an authorization_code token request flow", | ||||
|             ), | ||||
|         ), | ||||
|     ] | ||||
| @ -21,6 +21,11 @@ if TYPE_CHECKING: | ||||
|     from authentik.sources.oauth.types.registry import SourceType | ||||
|  | ||||
|  | ||||
| class AuthorizationCodeAuthMethod(models.TextChoices): | ||||
|     BASIC_AUTH = "basic_auth", _("HTTP Basic Authentication") | ||||
|     POST_BODY = "post_body", _("Include the client ID and secret as request parameters") | ||||
|  | ||||
|  | ||||
| class OAuthSource(NonCreatableType, Source): | ||||
|     """Login using a Generic OAuth provider.""" | ||||
|  | ||||
| @ -61,6 +66,14 @@ class OAuthSource(NonCreatableType, Source): | ||||
|     oidc_jwks_url = models.TextField(default="", blank=True) | ||||
|     oidc_jwks = models.JSONField(default=dict, blank=True) | ||||
|  | ||||
|     authorization_code_auth_method = models.TextField( | ||||
|         choices=AuthorizationCodeAuthMethod.choices, | ||||
|         default=AuthorizationCodeAuthMethod.BASIC_AUTH, | ||||
|         help_text=_( | ||||
|             "How to perform authentication during an authorization_code token request flow" | ||||
|         ), | ||||
|     ) | ||||
|  | ||||
|     @property | ||||
|     def source_type(self) -> type["SourceType"]: | ||||
|         """Return the provider instance for this source""" | ||||
|  | ||||
							
								
								
									
										69
									
								
								authentik/sources/oauth/tests/test_client.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										69
									
								
								authentik/sources/oauth/tests/test_client.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,69 @@ | ||||
| from django.test import RequestFactory, TestCase | ||||
| from guardian.shortcuts import get_anonymous_user | ||||
|  | ||||
| from authentik.lib.generators import generate_id | ||||
| from authentik.sources.oauth.clients.oauth2 import OAuth2Client | ||||
| from authentik.sources.oauth.models import AuthorizationCodeAuthMethod, OAuthSource | ||||
| from authentik.sources.oauth.types.oidc import OpenIDConnectClient | ||||
|  | ||||
|  | ||||
| class TestOAuthClient(TestCase): | ||||
|     """OAuth Source tests""" | ||||
|  | ||||
|     def setUp(self): | ||||
|         self.source = OAuthSource.objects.create( | ||||
|             name="test", | ||||
|             slug="test", | ||||
|             provider_type="openidconnect", | ||||
|             authorization_url="", | ||||
|             profile_url="", | ||||
|             consumer_key=generate_id(), | ||||
|         ) | ||||
|         self.factory = RequestFactory() | ||||
|  | ||||
|     def test_client_post_body_auth(self): | ||||
|         """Test login_challenge""" | ||||
|         self.source.provider_type = "apple" | ||||
|         self.source.save() | ||||
|         request = self.factory.get("/") | ||||
|         request.session = {} | ||||
|         request.user = get_anonymous_user() | ||||
|         client = OAuth2Client(self.source, request) | ||||
|         self.assertIsNone(client.get_access_token_auth()) | ||||
|         args = client.get_access_token_args("", "") | ||||
|         self.assertIn("client_id", args) | ||||
|         self.assertIn("client_secret", args) | ||||
|  | ||||
|     def test_client_basic_auth(self): | ||||
|         """Test login_challenge""" | ||||
|         self.source.provider_type = "reddit" | ||||
|         self.source.save() | ||||
|         request = self.factory.get("/") | ||||
|         request.session = {} | ||||
|         request.user = get_anonymous_user() | ||||
|         client = OAuth2Client(self.source, request) | ||||
|         self.assertIsNotNone(client.get_access_token_auth()) | ||||
|         args = client.get_access_token_args("", "") | ||||
|         self.assertNotIn("client_id", args) | ||||
|         self.assertNotIn("client_secret", args) | ||||
|  | ||||
|     def test_client_openid_auth(self): | ||||
|         """Test login_challenge""" | ||||
|         request = self.factory.get("/") | ||||
|         request.session = {} | ||||
|         request.user = get_anonymous_user() | ||||
|         client = OpenIDConnectClient(self.source, request) | ||||
|  | ||||
|         self.assertIsNotNone(client.get_access_token_auth()) | ||||
|         args = client.get_access_token_args("", "") | ||||
|         self.assertNotIn("client_id", args) | ||||
|         self.assertNotIn("client_secret", args) | ||||
|  | ||||
|         self.source.authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY | ||||
|         self.source.save() | ||||
|         client = OpenIDConnectClient(self.source, request) | ||||
|  | ||||
|         self.assertIsNone(client.get_access_token_auth()) | ||||
|         args = client.get_access_token_args("", "") | ||||
|         self.assertIn("client_id", args) | ||||
|         self.assertIn("client_secret", args) | ||||
| @ -11,7 +11,7 @@ from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.flows.challenge import Challenge, ChallengeResponse | ||||
| from authentik.sources.oauth.clients.oauth2 import OAuth2Client | ||||
| from authentik.sources.oauth.models import OAuthSource | ||||
| from authentik.sources.oauth.models import AuthorizationCodeAuthMethod, OAuthSource | ||||
| from authentik.sources.oauth.types.registry import SourceType, registry | ||||
| from authentik.sources.oauth.views.callback import OAuthCallback | ||||
| from authentik.sources.oauth.views.redirect import OAuthRedirect | ||||
| @ -105,6 +105,8 @@ class AppleType(SourceType): | ||||
|     access_token_url = "https://appleid.apple.com/auth/token"  # nosec | ||||
|     profile_url = "" | ||||
|  | ||||
|     authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY | ||||
|  | ||||
|     def login_challenge(self, source: OAuthSource, request: HttpRequest) -> Challenge: | ||||
|         """Pre-general all the things required for the JS SDK""" | ||||
|         apple_client = AppleOAuthClient( | ||||
|  | ||||
| @ -6,6 +6,7 @@ from requests import RequestException | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient | ||||
| from authentik.sources.oauth.models import AuthorizationCodeAuthMethod | ||||
| from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback | ||||
| from authentik.sources.oauth.types.registry import SourceType, registry | ||||
| from authentik.sources.oauth.views.redirect import OAuthRedirect | ||||
| @ -77,6 +78,8 @@ class AzureADType(SourceType): | ||||
|     ) | ||||
|     oidc_jwks_url = "https://login.microsoftonline.com/common/discovery/keys" | ||||
|  | ||||
|     authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY | ||||
|  | ||||
|     def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]: | ||||
|         mail = info.get("mail", None) or info.get("otherMails", [None])[0] | ||||
|         # Format group info | ||||
|  | ||||
| @ -2,8 +2,8 @@ | ||||
|  | ||||
| from typing import Any | ||||
|  | ||||
| from authentik.sources.oauth.models import AuthorizationCodeAuthMethod | ||||
| from authentik.sources.oauth.types.registry import SourceType, registry | ||||
| from authentik.sources.oauth.views.callback import OAuthCallback | ||||
| from authentik.sources.oauth.views.redirect import OAuthRedirect | ||||
|  | ||||
|  | ||||
| @ -16,15 +16,10 @@ class FacebookOAuthRedirect(OAuthRedirect): | ||||
|         } | ||||
|  | ||||
|  | ||||
| class FacebookOAuth2Callback(OAuthCallback): | ||||
|     """Facebook OAuth2 Callback""" | ||||
|  | ||||
|  | ||||
| @registry.register() | ||||
| class FacebookType(SourceType): | ||||
|     """Facebook Type definition""" | ||||
|  | ||||
|     callback_view = FacebookOAuth2Callback | ||||
|     redirect_view = FacebookOAuthRedirect | ||||
|     verbose_name = "Facebook" | ||||
|     name = "facebook" | ||||
| @ -33,6 +28,8 @@ class FacebookType(SourceType): | ||||
|     access_token_url = "https://graph.facebook.com/v7.0/oauth/access_token"  # nosec | ||||
|     profile_url = "https://graph.facebook.com/v7.0/me?fields=id,name,email" | ||||
|  | ||||
|     authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY | ||||
|  | ||||
|     def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]: | ||||
|         return { | ||||
|             "username": info.get("name"), | ||||
|  | ||||
| @ -5,7 +5,7 @@ from typing import Any | ||||
| from requests.exceptions import RequestException | ||||
|  | ||||
| from authentik.sources.oauth.clients.oauth2 import OAuth2Client | ||||
| from authentik.sources.oauth.models import OAuthSource | ||||
| from authentik.sources.oauth.models import AuthorizationCodeAuthMethod, OAuthSource | ||||
| from authentik.sources.oauth.types.registry import SourceType, registry | ||||
| from authentik.sources.oauth.views.callback import OAuthCallback | ||||
| from authentik.sources.oauth.views.redirect import OAuthRedirect | ||||
| @ -63,6 +63,8 @@ class GitHubType(SourceType): | ||||
|     ) | ||||
|     oidc_jwks_url = "https://token.actions.githubusercontent.com/.well-known/jwks" | ||||
|  | ||||
|     authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY | ||||
|  | ||||
|     def get_base_user_properties( | ||||
|         self, | ||||
|         source: OAuthSource, | ||||
|  | ||||
| @ -7,9 +7,8 @@ and https://docs.gitlab.com/ee/integration/openid_connect_provider.html | ||||
|  | ||||
| from typing import Any | ||||
|  | ||||
| from authentik.sources.oauth.models import OAuthSource | ||||
| from authentik.sources.oauth.models import AuthorizationCodeAuthMethod, OAuthSource | ||||
| from authentik.sources.oauth.types.registry import SourceType, registry | ||||
| from authentik.sources.oauth.views.callback import OAuthCallback | ||||
| from authentik.sources.oauth.views.redirect import OAuthRedirect | ||||
|  | ||||
|  | ||||
| @ -22,15 +21,10 @@ class GitLabOAuthRedirect(OAuthRedirect): | ||||
|         } | ||||
|  | ||||
|  | ||||
| class GitLabOAuthCallback(OAuthCallback): | ||||
|     """GitLab OAuth2 Callback""" | ||||
|  | ||||
|  | ||||
| @registry.register() | ||||
| class GitLabType(SourceType): | ||||
|     """GitLab Type definition""" | ||||
|  | ||||
|     callback_view = GitLabOAuthCallback | ||||
|     redirect_view = GitLabOAuthRedirect | ||||
|     verbose_name = "GitLab" | ||||
|     name = "gitlab" | ||||
| @ -43,6 +37,8 @@ class GitLabType(SourceType): | ||||
|     oidc_well_known_url = "https://gitlab.com/.well-known/openid-configuration" | ||||
|     oidc_jwks_url = "https://gitlab.com/oauth/discovery/keys" | ||||
|  | ||||
|     authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY | ||||
|  | ||||
|     def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]: | ||||
|         return { | ||||
|             "username": info.get("preferred_username"), | ||||
|  | ||||
| @ -2,8 +2,8 @@ | ||||
|  | ||||
| from typing import Any | ||||
|  | ||||
| from authentik.sources.oauth.models import AuthorizationCodeAuthMethod | ||||
| from authentik.sources.oauth.types.registry import SourceType, registry | ||||
| from authentik.sources.oauth.views.callback import OAuthCallback | ||||
| from authentik.sources.oauth.views.redirect import OAuthRedirect | ||||
|  | ||||
|  | ||||
| @ -16,15 +16,10 @@ class GoogleOAuthRedirect(OAuthRedirect): | ||||
|         } | ||||
|  | ||||
|  | ||||
| class GoogleOAuth2Callback(OAuthCallback): | ||||
|     """Google OAuth2 Callback""" | ||||
|  | ||||
|  | ||||
| @registry.register() | ||||
| class GoogleType(SourceType): | ||||
|     """Google Type definition""" | ||||
|  | ||||
|     callback_view = GoogleOAuth2Callback | ||||
|     redirect_view = GoogleOAuthRedirect | ||||
|     verbose_name = "Google" | ||||
|     name = "google" | ||||
| @ -35,6 +30,8 @@ class GoogleType(SourceType): | ||||
|     oidc_well_known_url = "https://accounts.google.com/.well-known/openid-configuration" | ||||
|     oidc_jwks_url = "https://www.googleapis.com/oauth2/v3/certs" | ||||
|  | ||||
|     authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY | ||||
|  | ||||
|     def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]: | ||||
|         return { | ||||
|             "email": info.get("email"), | ||||
|  | ||||
| @ -6,6 +6,7 @@ from requests.exceptions import RequestException | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.sources.oauth.clients.oauth2 import OAuth2Client | ||||
| from authentik.sources.oauth.models import AuthorizationCodeAuthMethod | ||||
| from authentik.sources.oauth.types.registry import SourceType, registry | ||||
| from authentik.sources.oauth.views.callback import OAuthCallback | ||||
| from authentik.sources.oauth.views.redirect import OAuthRedirect | ||||
| @ -59,6 +60,8 @@ class MailcowType(SourceType): | ||||
|  | ||||
|     urls_customizable = True | ||||
|  | ||||
|     authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY | ||||
|  | ||||
|     def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]: | ||||
|         return { | ||||
|             "username": info.get("full_name"), | ||||
|  | ||||
| @ -2,8 +2,10 @@ | ||||
|  | ||||
| from typing import Any | ||||
|  | ||||
| from requests.auth import AuthBase, HTTPBasicAuth | ||||
|  | ||||
| from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient | ||||
| from authentik.sources.oauth.models import OAuthSource | ||||
| from authentik.sources.oauth.models import AuthorizationCodeAuthMethod, OAuthSource | ||||
| from authentik.sources.oauth.types.registry import SourceType, registry | ||||
| from authentik.sources.oauth.views.callback import OAuthCallback | ||||
| from authentik.sources.oauth.views.redirect import OAuthRedirect | ||||
| @ -18,10 +20,27 @@ class OpenIDConnectOAuthRedirect(OAuthRedirect): | ||||
|         } | ||||
|  | ||||
|  | ||||
| class OpenIDConnectClient(UserprofileHeaderAuthClient): | ||||
|     def get_access_token_args(self, callback: str, code: str) -> dict[str, Any]: | ||||
|         args = super().get_access_token_args(callback, code) | ||||
|         if self.source.authorization_code_auth_method == AuthorizationCodeAuthMethod.POST_BODY: | ||||
|             args["client_id"] = self.get_client_id() | ||||
|             args["client_secret"] = self.get_client_secret() | ||||
|         else: | ||||
|             args.pop("client_id", None) | ||||
|             args.pop("client_secret", None) | ||||
|         return args | ||||
|  | ||||
|     def get_access_token_auth(self) -> AuthBase | None: | ||||
|         if self.source.authorization_code_auth_method == AuthorizationCodeAuthMethod.BASIC_AUTH: | ||||
|             return HTTPBasicAuth(self.get_client_id(), self.get_client_secret()) | ||||
|         return None | ||||
|  | ||||
|  | ||||
| class OpenIDConnectOAuth2Callback(OAuthCallback): | ||||
|     """OpenIDConnect OAuth2 Callback""" | ||||
|  | ||||
|     client_class = UserprofileHeaderAuthClient | ||||
|     client_class = OpenIDConnectClient | ||||
|  | ||||
|     def get_user_id(self, info: dict[str, str]) -> str: | ||||
|         return info.get("sub", None) | ||||
|  | ||||
| @ -2,7 +2,6 @@ | ||||
|  | ||||
| from typing import Any | ||||
|  | ||||
| from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient | ||||
| from authentik.sources.oauth.models import OAuthSource | ||||
| from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback | ||||
| from authentik.sources.oauth.types.registry import SourceType, registry | ||||
| @ -18,20 +17,11 @@ class OktaOAuthRedirect(OAuthRedirect): | ||||
|         } | ||||
|  | ||||
|  | ||||
| class OktaOAuth2Callback(OpenIDConnectOAuth2Callback): | ||||
|     """Okta OAuth2 Callback""" | ||||
|  | ||||
|     # Okta has the same quirk as azure and throws an error if the access token | ||||
|     # is set via query parameter, so we reuse the azure client | ||||
|     # see https://github.com/goauthentik/authentik/issues/1910 | ||||
|     client_class = UserprofileHeaderAuthClient | ||||
|  | ||||
|  | ||||
| @registry.register() | ||||
| class OktaType(SourceType): | ||||
|     """Okta Type definition""" | ||||
|  | ||||
|     callback_view = OktaOAuth2Callback | ||||
|     callback_view = OpenIDConnectOAuth2Callback | ||||
|     redirect_view = OktaOAuthRedirect | ||||
|     verbose_name = "Okta" | ||||
|     name = "okta" | ||||
|  | ||||
| @ -3,7 +3,7 @@ | ||||
| from typing import Any | ||||
|  | ||||
| from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient | ||||
| from authentik.sources.oauth.models import OAuthSource | ||||
| from authentik.sources.oauth.models import AuthorizationCodeAuthMethod, OAuthSource | ||||
| from authentik.sources.oauth.types.registry import SourceType, registry | ||||
| from authentik.sources.oauth.views.callback import OAuthCallback | ||||
| from authentik.sources.oauth.views.redirect import OAuthRedirect | ||||
| @ -41,6 +41,8 @@ class PatreonType(SourceType): | ||||
|     access_token_url = "https://www.patreon.com/api/oauth2/token"  # nosec | ||||
|     profile_url = "https://www.patreon.com/api/oauth2/api/current_user" | ||||
|  | ||||
|     authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY | ||||
|  | ||||
|     def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]: | ||||
|         return { | ||||
|             "username": info.get("data", {}).get("attributes", {}).get("vanity"), | ||||
|  | ||||
| @ -2,8 +2,6 @@ | ||||
|  | ||||
| from typing import Any | ||||
|  | ||||
| from requests.auth import HTTPBasicAuth | ||||
|  | ||||
| from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient | ||||
| from authentik.sources.oauth.types.registry import SourceType, registry | ||||
| from authentik.sources.oauth.views.callback import OAuthCallback | ||||
| @ -20,21 +18,10 @@ class RedditOAuthRedirect(OAuthRedirect): | ||||
|         } | ||||
|  | ||||
|  | ||||
| class RedditOAuth2Client(UserprofileHeaderAuthClient): | ||||
|     """Reddit OAuth2 Client""" | ||||
|  | ||||
|     def get_access_token(self, **request_kwargs): | ||||
|         "Fetch access token from callback request." | ||||
|         request_kwargs["auth"] = HTTPBasicAuth( | ||||
|             self.source.consumer_key, self.source.consumer_secret | ||||
|         ) | ||||
|         return super().get_access_token(**request_kwargs) | ||||
|  | ||||
|  | ||||
| class RedditOAuth2Callback(OAuthCallback): | ||||
|     """Reddit OAuth2 Callback""" | ||||
|  | ||||
|     client_class = RedditOAuth2Client | ||||
|     client_class = UserprofileHeaderAuthClient | ||||
|  | ||||
|  | ||||
| @registry.register() | ||||
|  | ||||
| @ -10,7 +10,7 @@ from django.urls.base import reverse | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.flows.challenge import Challenge, RedirectChallenge | ||||
| from authentik.sources.oauth.models import OAuthSource | ||||
| from authentik.sources.oauth.models import AuthorizationCodeAuthMethod, OAuthSource | ||||
| from authentik.sources.oauth.views.callback import OAuthCallback | ||||
| from authentik.sources.oauth.views.redirect import OAuthRedirect | ||||
|  | ||||
| @ -41,6 +41,10 @@ class SourceType: | ||||
|     oidc_well_known_url: str | None = None | ||||
|     oidc_jwks_url: str | None = None | ||||
|  | ||||
|     authorization_code_auth_method: AuthorizationCodeAuthMethod = ( | ||||
|         AuthorizationCodeAuthMethod.BASIC_AUTH | ||||
|     ) | ||||
|  | ||||
|     def icon_url(self) -> str: | ||||
|         """Get Icon URL for login""" | ||||
|         return static(f"authentik/sources/{self.name}.svg") | ||||
|  | ||||
| @ -4,6 +4,7 @@ from json import dumps | ||||
| from typing import Any | ||||
|  | ||||
| from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient | ||||
| from authentik.sources.oauth.models import AuthorizationCodeAuthMethod | ||||
| from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback | ||||
| from authentik.sources.oauth.types.registry import SourceType, registry | ||||
| from authentik.sources.oauth.views.redirect import OAuthRedirect | ||||
| @ -47,6 +48,8 @@ class TwitchType(SourceType): | ||||
|     access_token_url = "https://id.twitch.tv/oauth2/token"  # nosec | ||||
|     profile_url = "https://id.twitch.tv/oauth2/userinfo" | ||||
|  | ||||
|     authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY | ||||
|  | ||||
|     def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]: | ||||
|         return { | ||||
|             "username": info.get("preferred_username"), | ||||
|  | ||||
| @ -12,23 +12,6 @@ from authentik.sources.oauth.views.callback import OAuthCallback | ||||
| from authentik.sources.oauth.views.redirect import OAuthRedirect | ||||
|  | ||||
|  | ||||
| class TwitterClient(UserprofileHeaderAuthClient): | ||||
|     """Twitter has similar quirks to Azure AD, and additionally requires Basic auth on | ||||
|     the access token endpoint for some reason.""" | ||||
|  | ||||
|     # Twitter has the same quirk as azure and throws an error if the access token | ||||
|     # is set via query parameter, so we reuse the azure client | ||||
|     # see https://github.com/goauthentik/authentik/issues/1910 | ||||
|  | ||||
|     def get_access_token(self, **request_kwargs) -> dict[str, Any] | None: | ||||
|         return super().get_access_token( | ||||
|             auth=( | ||||
|                 self.source.consumer_key, | ||||
|                 self.source.consumer_secret, | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class TwitterOAuthRedirect(OAuthRedirect): | ||||
|     """Twitter OAuth2 Redirect""" | ||||
|  | ||||
| @ -44,7 +27,7 @@ class TwitterOAuthRedirect(OAuthRedirect): | ||||
| class TwitterOAuthCallback(OAuthCallback): | ||||
|     """Twitter OAuth2 Callback""" | ||||
|  | ||||
|     client_class = TwitterClient | ||||
|     client_class = UserprofileHeaderAuthClient | ||||
|  | ||||
|     def get_user_id(self, info: dict[str, str]) -> str: | ||||
|         return info.get("data", {}).get("id", "") | ||||
|  | ||||
| @ -8436,6 +8436,15 @@ | ||||
|                     "type": "object", | ||||
|                     "additionalProperties": true, | ||||
|                     "title": "Oidc jwks" | ||||
|                 }, | ||||
|                 "authorization_code_auth_method": { | ||||
|                     "type": "string", | ||||
|                     "enum": [ | ||||
|                         "basic_auth", | ||||
|                         "post_body" | ||||
|                     ], | ||||
|                     "title": "Authorization code auth method", | ||||
|                     "description": "How to perform authentication during an authorization_code token request flow" | ||||
|                 } | ||||
|             }, | ||||
|             "required": [] | ||||
|  | ||||
							
								
								
									
										20
									
								
								schema.yml
									
									
									
									
									
								
							
							
						
						
									
										20
									
								
								schema.yml
									
									
									
									
									
								
							| @ -42110,6 +42110,11 @@ components: | ||||
|             format: uuid | ||||
|       required: | ||||
|       - name | ||||
|     AuthorizationCodeAuthMethodEnum: | ||||
|       enum: | ||||
|       - basic_auth | ||||
|       - post_body | ||||
|       type: string | ||||
|     AutoSubmitChallengeResponseRequest: | ||||
|       type: object | ||||
|       description: Pseudo class for autosubmit response | ||||
| @ -48742,6 +48747,11 @@ components: | ||||
|         oidc_jwks_url: | ||||
|           type: string | ||||
|         oidc_jwks: {} | ||||
|         authorization_code_auth_method: | ||||
|           allOf: | ||||
|           - $ref: '#/components/schemas/AuthorizationCodeAuthMethodEnum' | ||||
|           description: How to perform authentication during an authorization_code | ||||
|             token request flow | ||||
|       required: | ||||
|       - callback_url | ||||
|       - component | ||||
| @ -48911,6 +48921,11 @@ components: | ||||
|         oidc_jwks_url: | ||||
|           type: string | ||||
|         oidc_jwks: {} | ||||
|         authorization_code_auth_method: | ||||
|           allOf: | ||||
|           - $ref: '#/components/schemas/AuthorizationCodeAuthMethodEnum' | ||||
|           description: How to perform authentication during an authorization_code | ||||
|             token request flow | ||||
|       required: | ||||
|       - consumer_key | ||||
|       - consumer_secret | ||||
| @ -53009,6 +53024,11 @@ components: | ||||
|         oidc_jwks_url: | ||||
|           type: string | ||||
|         oidc_jwks: {} | ||||
|         authorization_code_auth_method: | ||||
|           allOf: | ||||
|           - $ref: '#/components/schemas/AuthorizationCodeAuthMethodEnum' | ||||
|           description: How to perform authentication during an authorization_code | ||||
|             token request flow | ||||
|     PatchedOutpostRequest: | ||||
|       type: object | ||||
|       description: Outpost Serializer | ||||
|  | ||||
| @ -7,6 +7,7 @@ import { | ||||
| } from "@goauthentik/admin/sources/oauth/utils"; | ||||
| import { DEFAULT_CONFIG, config } from "@goauthentik/common/api/config"; | ||||
| import { first } from "@goauthentik/common/utils"; | ||||
| import "@goauthentik/components/ak-radio-input"; | ||||
| import "@goauthentik/elements/CodeMirror"; | ||||
| import { CodeMirrorMode } from "@goauthentik/elements/CodeMirror"; | ||||
| import { | ||||
| @ -16,6 +17,7 @@ import { | ||||
| import "@goauthentik/elements/ak-dual-select/ak-dual-select-dynamic-selected-provider.js"; | ||||
| import "@goauthentik/elements/forms/FormGroup"; | ||||
| import "@goauthentik/elements/forms/HorizontalFormElement"; | ||||
| import "@goauthentik/elements/forms/Radio"; | ||||
| import "@goauthentik/elements/forms/SearchSelect"; | ||||
|  | ||||
| import { msg } from "@lit/localize"; | ||||
| @ -24,6 +26,7 @@ import { customElement, property, state } from "lit/decorators.js"; | ||||
| import { ifDefined } from "lit/directives/if-defined.js"; | ||||
|  | ||||
| import { | ||||
|     AuthorizationCodeAuthMethodEnum, | ||||
|     FlowsInstancesListDesignationEnum, | ||||
|     GroupMatchingModeEnum, | ||||
|     OAuthSource, | ||||
| @ -36,6 +39,18 @@ import { | ||||
|  | ||||
| import { propertyMappingsProvider, propertyMappingsSelector } from "./OAuthSourceFormHelpers.js"; | ||||
|  | ||||
| const authorizationCodeAuthMethodOptions = [ | ||||
|     { | ||||
|         label: msg("HTTP Basic Auth"), | ||||
|         value: AuthorizationCodeAuthMethodEnum.BasicAuth, | ||||
|         default: true, | ||||
|     }, | ||||
|     { | ||||
|         label: msg("Include the client ID and secret as request parameters"), | ||||
|         value: AuthorizationCodeAuthMethodEnum.PostBody, | ||||
|     }, | ||||
| ]; | ||||
|  | ||||
| @customElement("ak-source-oauth-form") | ||||
| export class OAuthSourceForm extends WithCapabilitiesConfig(BaseSourceForm<OAuthSource>) { | ||||
|     async loadInstance(pk: string): Promise<OAuthSource> { | ||||
| @ -240,6 +255,19 @@ export class OAuthSourceForm extends WithCapabilitiesConfig(BaseSourceForm<OAuth | ||||
|                               <p class="pf-c-form__helper-text">${msg("Raw JWKS data.")}</p> | ||||
|                           </ak-form-element-horizontal>` | ||||
|                     : html``} | ||||
|                 ${this.providerType.name === ProviderTypeEnum.Openidconnect | ||||
|                     ? html`<ak-radio-input | ||||
|                           label=${msg("Authorization code authentication method")} | ||||
|                           name="authorizationCodeAuthMethod" | ||||
|                           required | ||||
|                           .options=${authorizationCodeAuthMethodOptions} | ||||
|                           .value=${this.instance?.authorizationCodeAuthMethod} | ||||
|                           help=${msg( | ||||
|                               "How to perform authentication during an authorization_code token request flow", | ||||
|                           )} | ||||
|                       > | ||||
|                       </ak-radio-input>` | ||||
|                     : html``} | ||||
|             </div> | ||||
|         </ak-form-group>`; | ||||
|     } | ||||
|  | ||||
		Reference in New Issue
	
	Block a user