sources/oauth: introduce authorization code auth method (#14034)
Co-authored-by: Rsgm <rsgm123@gmail.com>
This commit is contained in:

committed by
GitHub

parent
c6f9d5df7b
commit
155a31fd70
@ -130,6 +130,7 @@ class OAuthSourceSerializer(SourceSerializer):
|
|||||||
"oidc_well_known_url",
|
"oidc_well_known_url",
|
||||||
"oidc_jwks_url",
|
"oidc_jwks_url",
|
||||||
"oidc_jwks",
|
"oidc_jwks",
|
||||||
|
"authorization_code_auth_method",
|
||||||
]
|
]
|
||||||
extra_kwargs = {
|
extra_kwargs = {
|
||||||
"consumer_secret": {"write_only": True},
|
"consumer_secret": {"write_only": True},
|
||||||
|
@ -6,11 +6,15 @@ from urllib.parse import parse_qsl
|
|||||||
|
|
||||||
from django.utils.crypto import constant_time_compare, get_random_string
|
from django.utils.crypto import constant_time_compare, get_random_string
|
||||||
from django.utils.translation import gettext as _
|
from django.utils.translation import gettext as _
|
||||||
|
from requests.auth import AuthBase, HTTPBasicAuth
|
||||||
from requests.exceptions import RequestException
|
from requests.exceptions import RequestException
|
||||||
from requests.models import Response
|
from requests.models import Response
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik.sources.oauth.clients.base import BaseOAuthClient
|
from authentik.sources.oauth.clients.base import BaseOAuthClient
|
||||||
|
from authentik.sources.oauth.models import (
|
||||||
|
AuthorizationCodeAuthMethod,
|
||||||
|
)
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
SESSION_KEY_OAUTH_PKCE = "authentik/sources/oauth/pkce"
|
SESSION_KEY_OAUTH_PKCE = "authentik/sources/oauth/pkce"
|
||||||
@ -55,6 +59,30 @@ class OAuth2Client(BaseOAuthClient):
|
|||||||
"""Get client secret"""
|
"""Get client secret"""
|
||||||
return self.source.consumer_secret
|
return self.source.consumer_secret
|
||||||
|
|
||||||
|
def get_access_token_args(self, callback: str, code: str) -> dict[str, Any]:
|
||||||
|
args = {
|
||||||
|
"redirect_uri": callback,
|
||||||
|
"code": code,
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
}
|
||||||
|
if SESSION_KEY_OAUTH_PKCE in self.request.session:
|
||||||
|
args["code_verifier"] = self.request.session[SESSION_KEY_OAUTH_PKCE]
|
||||||
|
if (
|
||||||
|
self.source.source_type.authorization_code_auth_method
|
||||||
|
== AuthorizationCodeAuthMethod.POST_BODY
|
||||||
|
):
|
||||||
|
args["client_id"] = self.get_client_id()
|
||||||
|
args["client_secret"] = self.get_client_secret()
|
||||||
|
return args
|
||||||
|
|
||||||
|
def get_access_token_auth(self) -> AuthBase | None:
|
||||||
|
if (
|
||||||
|
self.source.source_type.authorization_code_auth_method
|
||||||
|
== AuthorizationCodeAuthMethod.BASIC_AUTH
|
||||||
|
):
|
||||||
|
return HTTPBasicAuth(self.get_client_id(), self.get_client_secret())
|
||||||
|
return None
|
||||||
|
|
||||||
def get_access_token(self, **request_kwargs) -> dict[str, Any] | None:
|
def get_access_token(self, **request_kwargs) -> dict[str, Any] | None:
|
||||||
"""Fetch access token from callback request."""
|
"""Fetch access token from callback request."""
|
||||||
callback = self.request.build_absolute_uri(self.callback or self.request.path)
|
callback = self.request.build_absolute_uri(self.callback or self.request.path)
|
||||||
@ -67,13 +95,6 @@ class OAuth2Client(BaseOAuthClient):
|
|||||||
error = self.get_request_arg("error", None)
|
error = self.get_request_arg("error", None)
|
||||||
error_desc = self.get_request_arg("error_description", None)
|
error_desc = self.get_request_arg("error_description", None)
|
||||||
return {"error": error_desc or error or _("No token received.")}
|
return {"error": error_desc or error or _("No token received.")}
|
||||||
args = {
|
|
||||||
"redirect_uri": callback,
|
|
||||||
"code": code,
|
|
||||||
"grant_type": "authorization_code",
|
|
||||||
}
|
|
||||||
if SESSION_KEY_OAUTH_PKCE in self.request.session:
|
|
||||||
args["code_verifier"] = self.request.session[SESSION_KEY_OAUTH_PKCE]
|
|
||||||
try:
|
try:
|
||||||
access_token_url = self.source.source_type.access_token_url or ""
|
access_token_url = self.source.source_type.access_token_url or ""
|
||||||
if self.source.source_type.urls_customizable and self.source.access_token_url:
|
if self.source.source_type.urls_customizable and self.source.access_token_url:
|
||||||
@ -81,8 +102,8 @@ class OAuth2Client(BaseOAuthClient):
|
|||||||
response = self.do_request(
|
response = self.do_request(
|
||||||
"post",
|
"post",
|
||||||
access_token_url,
|
access_token_url,
|
||||||
auth=(self.get_client_id(), self.get_client_secret()),
|
auth=self.get_access_token_auth(),
|
||||||
data=args,
|
data=self.get_access_token_args(callback, code),
|
||||||
headers=self._default_headers,
|
headers=self._default_headers,
|
||||||
**request_kwargs,
|
**request_kwargs,
|
||||||
)
|
)
|
||||||
|
@ -0,0 +1,25 @@
|
|||||||
|
# Generated by Django 5.0.14 on 2025-04-11 18:09
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
("authentik_sources_oauth", "0009_migrate_useroauthsourceconnection_identifier"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="oauthsource",
|
||||||
|
name="authorization_code_auth_method",
|
||||||
|
field=models.TextField(
|
||||||
|
choices=[
|
||||||
|
("basic_auth", "HTTP Basic Authentication"),
|
||||||
|
("post_body", "Include the client ID and secret as request parameters"),
|
||||||
|
],
|
||||||
|
default="basic_auth",
|
||||||
|
help_text="How to perform authentication during an authorization_code token request flow",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
@ -21,6 +21,11 @@ if TYPE_CHECKING:
|
|||||||
from authentik.sources.oauth.types.registry import SourceType
|
from authentik.sources.oauth.types.registry import SourceType
|
||||||
|
|
||||||
|
|
||||||
|
class AuthorizationCodeAuthMethod(models.TextChoices):
|
||||||
|
BASIC_AUTH = "basic_auth", _("HTTP Basic Authentication")
|
||||||
|
POST_BODY = "post_body", _("Include the client ID and secret as request parameters")
|
||||||
|
|
||||||
|
|
||||||
class OAuthSource(NonCreatableType, Source):
|
class OAuthSource(NonCreatableType, Source):
|
||||||
"""Login using a Generic OAuth provider."""
|
"""Login using a Generic OAuth provider."""
|
||||||
|
|
||||||
@ -61,6 +66,14 @@ class OAuthSource(NonCreatableType, Source):
|
|||||||
oidc_jwks_url = models.TextField(default="", blank=True)
|
oidc_jwks_url = models.TextField(default="", blank=True)
|
||||||
oidc_jwks = models.JSONField(default=dict, blank=True)
|
oidc_jwks = models.JSONField(default=dict, blank=True)
|
||||||
|
|
||||||
|
authorization_code_auth_method = models.TextField(
|
||||||
|
choices=AuthorizationCodeAuthMethod.choices,
|
||||||
|
default=AuthorizationCodeAuthMethod.BASIC_AUTH,
|
||||||
|
help_text=_(
|
||||||
|
"How to perform authentication during an authorization_code token request flow"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def source_type(self) -> type["SourceType"]:
|
def source_type(self) -> type["SourceType"]:
|
||||||
"""Return the provider instance for this source"""
|
"""Return the provider instance for this source"""
|
||||||
|
69
authentik/sources/oauth/tests/test_client.py
Normal file
69
authentik/sources/oauth/tests/test_client.py
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
from django.test import RequestFactory, TestCase
|
||||||
|
from guardian.shortcuts import get_anonymous_user
|
||||||
|
|
||||||
|
from authentik.lib.generators import generate_id
|
||||||
|
from authentik.sources.oauth.clients.oauth2 import OAuth2Client
|
||||||
|
from authentik.sources.oauth.models import AuthorizationCodeAuthMethod, OAuthSource
|
||||||
|
from authentik.sources.oauth.types.oidc import OpenIDConnectClient
|
||||||
|
|
||||||
|
|
||||||
|
class TestOAuthClient(TestCase):
|
||||||
|
"""OAuth Source tests"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.source = OAuthSource.objects.create(
|
||||||
|
name="test",
|
||||||
|
slug="test",
|
||||||
|
provider_type="openidconnect",
|
||||||
|
authorization_url="",
|
||||||
|
profile_url="",
|
||||||
|
consumer_key=generate_id(),
|
||||||
|
)
|
||||||
|
self.factory = RequestFactory()
|
||||||
|
|
||||||
|
def test_client_post_body_auth(self):
|
||||||
|
"""Test login_challenge"""
|
||||||
|
self.source.provider_type = "apple"
|
||||||
|
self.source.save()
|
||||||
|
request = self.factory.get("/")
|
||||||
|
request.session = {}
|
||||||
|
request.user = get_anonymous_user()
|
||||||
|
client = OAuth2Client(self.source, request)
|
||||||
|
self.assertIsNone(client.get_access_token_auth())
|
||||||
|
args = client.get_access_token_args("", "")
|
||||||
|
self.assertIn("client_id", args)
|
||||||
|
self.assertIn("client_secret", args)
|
||||||
|
|
||||||
|
def test_client_basic_auth(self):
|
||||||
|
"""Test login_challenge"""
|
||||||
|
self.source.provider_type = "reddit"
|
||||||
|
self.source.save()
|
||||||
|
request = self.factory.get("/")
|
||||||
|
request.session = {}
|
||||||
|
request.user = get_anonymous_user()
|
||||||
|
client = OAuth2Client(self.source, request)
|
||||||
|
self.assertIsNotNone(client.get_access_token_auth())
|
||||||
|
args = client.get_access_token_args("", "")
|
||||||
|
self.assertNotIn("client_id", args)
|
||||||
|
self.assertNotIn("client_secret", args)
|
||||||
|
|
||||||
|
def test_client_openid_auth(self):
|
||||||
|
"""Test login_challenge"""
|
||||||
|
request = self.factory.get("/")
|
||||||
|
request.session = {}
|
||||||
|
request.user = get_anonymous_user()
|
||||||
|
client = OpenIDConnectClient(self.source, request)
|
||||||
|
|
||||||
|
self.assertIsNotNone(client.get_access_token_auth())
|
||||||
|
args = client.get_access_token_args("", "")
|
||||||
|
self.assertNotIn("client_id", args)
|
||||||
|
self.assertNotIn("client_secret", args)
|
||||||
|
|
||||||
|
self.source.authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY
|
||||||
|
self.source.save()
|
||||||
|
client = OpenIDConnectClient(self.source, request)
|
||||||
|
|
||||||
|
self.assertIsNone(client.get_access_token_auth())
|
||||||
|
args = client.get_access_token_args("", "")
|
||||||
|
self.assertIn("client_id", args)
|
||||||
|
self.assertIn("client_secret", args)
|
@ -11,7 +11,7 @@ from structlog.stdlib import get_logger
|
|||||||
|
|
||||||
from authentik.flows.challenge import Challenge, ChallengeResponse
|
from authentik.flows.challenge import Challenge, ChallengeResponse
|
||||||
from authentik.sources.oauth.clients.oauth2 import OAuth2Client
|
from authentik.sources.oauth.clients.oauth2 import OAuth2Client
|
||||||
from authentik.sources.oauth.models import OAuthSource
|
from authentik.sources.oauth.models import AuthorizationCodeAuthMethod, OAuthSource
|
||||||
from authentik.sources.oauth.types.registry import SourceType, registry
|
from authentik.sources.oauth.types.registry import SourceType, registry
|
||||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||||
@ -105,6 +105,8 @@ class AppleType(SourceType):
|
|||||||
access_token_url = "https://appleid.apple.com/auth/token" # nosec
|
access_token_url = "https://appleid.apple.com/auth/token" # nosec
|
||||||
profile_url = ""
|
profile_url = ""
|
||||||
|
|
||||||
|
authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY
|
||||||
|
|
||||||
def login_challenge(self, source: OAuthSource, request: HttpRequest) -> Challenge:
|
def login_challenge(self, source: OAuthSource, request: HttpRequest) -> Challenge:
|
||||||
"""Pre-general all the things required for the JS SDK"""
|
"""Pre-general all the things required for the JS SDK"""
|
||||||
apple_client = AppleOAuthClient(
|
apple_client = AppleOAuthClient(
|
||||||
|
@ -6,6 +6,7 @@ from requests import RequestException
|
|||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
|
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
|
||||||
|
from authentik.sources.oauth.models import AuthorizationCodeAuthMethod
|
||||||
from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback
|
from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback
|
||||||
from authentik.sources.oauth.types.registry import SourceType, registry
|
from authentik.sources.oauth.types.registry import SourceType, registry
|
||||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||||
@ -77,6 +78,8 @@ class AzureADType(SourceType):
|
|||||||
)
|
)
|
||||||
oidc_jwks_url = "https://login.microsoftonline.com/common/discovery/keys"
|
oidc_jwks_url = "https://login.microsoftonline.com/common/discovery/keys"
|
||||||
|
|
||||||
|
authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY
|
||||||
|
|
||||||
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
|
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
|
||||||
mail = info.get("mail", None) or info.get("otherMails", [None])[0]
|
mail = info.get("mail", None) or info.get("otherMails", [None])[0]
|
||||||
# Format group info
|
# Format group info
|
||||||
|
@ -2,8 +2,8 @@
|
|||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from authentik.sources.oauth.models import AuthorizationCodeAuthMethod
|
||||||
from authentik.sources.oauth.types.registry import SourceType, registry
|
from authentik.sources.oauth.types.registry import SourceType, registry
|
||||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
|
||||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||||
|
|
||||||
|
|
||||||
@ -16,15 +16,10 @@ class FacebookOAuthRedirect(OAuthRedirect):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class FacebookOAuth2Callback(OAuthCallback):
|
|
||||||
"""Facebook OAuth2 Callback"""
|
|
||||||
|
|
||||||
|
|
||||||
@registry.register()
|
@registry.register()
|
||||||
class FacebookType(SourceType):
|
class FacebookType(SourceType):
|
||||||
"""Facebook Type definition"""
|
"""Facebook Type definition"""
|
||||||
|
|
||||||
callback_view = FacebookOAuth2Callback
|
|
||||||
redirect_view = FacebookOAuthRedirect
|
redirect_view = FacebookOAuthRedirect
|
||||||
verbose_name = "Facebook"
|
verbose_name = "Facebook"
|
||||||
name = "facebook"
|
name = "facebook"
|
||||||
@ -33,6 +28,8 @@ class FacebookType(SourceType):
|
|||||||
access_token_url = "https://graph.facebook.com/v7.0/oauth/access_token" # nosec
|
access_token_url = "https://graph.facebook.com/v7.0/oauth/access_token" # nosec
|
||||||
profile_url = "https://graph.facebook.com/v7.0/me?fields=id,name,email"
|
profile_url = "https://graph.facebook.com/v7.0/me?fields=id,name,email"
|
||||||
|
|
||||||
|
authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY
|
||||||
|
|
||||||
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
|
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"username": info.get("name"),
|
"username": info.get("name"),
|
||||||
|
@ -5,7 +5,7 @@ from typing import Any
|
|||||||
from requests.exceptions import RequestException
|
from requests.exceptions import RequestException
|
||||||
|
|
||||||
from authentik.sources.oauth.clients.oauth2 import OAuth2Client
|
from authentik.sources.oauth.clients.oauth2 import OAuth2Client
|
||||||
from authentik.sources.oauth.models import OAuthSource
|
from authentik.sources.oauth.models import AuthorizationCodeAuthMethod, OAuthSource
|
||||||
from authentik.sources.oauth.types.registry import SourceType, registry
|
from authentik.sources.oauth.types.registry import SourceType, registry
|
||||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||||
@ -63,6 +63,8 @@ class GitHubType(SourceType):
|
|||||||
)
|
)
|
||||||
oidc_jwks_url = "https://token.actions.githubusercontent.com/.well-known/jwks"
|
oidc_jwks_url = "https://token.actions.githubusercontent.com/.well-known/jwks"
|
||||||
|
|
||||||
|
authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY
|
||||||
|
|
||||||
def get_base_user_properties(
|
def get_base_user_properties(
|
||||||
self,
|
self,
|
||||||
source: OAuthSource,
|
source: OAuthSource,
|
||||||
|
@ -7,9 +7,8 @@ and https://docs.gitlab.com/ee/integration/openid_connect_provider.html
|
|||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from authentik.sources.oauth.models import OAuthSource
|
from authentik.sources.oauth.models import AuthorizationCodeAuthMethod, OAuthSource
|
||||||
from authentik.sources.oauth.types.registry import SourceType, registry
|
from authentik.sources.oauth.types.registry import SourceType, registry
|
||||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
|
||||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||||
|
|
||||||
|
|
||||||
@ -22,15 +21,10 @@ class GitLabOAuthRedirect(OAuthRedirect):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class GitLabOAuthCallback(OAuthCallback):
|
|
||||||
"""GitLab OAuth2 Callback"""
|
|
||||||
|
|
||||||
|
|
||||||
@registry.register()
|
@registry.register()
|
||||||
class GitLabType(SourceType):
|
class GitLabType(SourceType):
|
||||||
"""GitLab Type definition"""
|
"""GitLab Type definition"""
|
||||||
|
|
||||||
callback_view = GitLabOAuthCallback
|
|
||||||
redirect_view = GitLabOAuthRedirect
|
redirect_view = GitLabOAuthRedirect
|
||||||
verbose_name = "GitLab"
|
verbose_name = "GitLab"
|
||||||
name = "gitlab"
|
name = "gitlab"
|
||||||
@ -43,6 +37,8 @@ class GitLabType(SourceType):
|
|||||||
oidc_well_known_url = "https://gitlab.com/.well-known/openid-configuration"
|
oidc_well_known_url = "https://gitlab.com/.well-known/openid-configuration"
|
||||||
oidc_jwks_url = "https://gitlab.com/oauth/discovery/keys"
|
oidc_jwks_url = "https://gitlab.com/oauth/discovery/keys"
|
||||||
|
|
||||||
|
authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY
|
||||||
|
|
||||||
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
|
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"username": info.get("preferred_username"),
|
"username": info.get("preferred_username"),
|
||||||
|
@ -2,8 +2,8 @@
|
|||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from authentik.sources.oauth.models import AuthorizationCodeAuthMethod
|
||||||
from authentik.sources.oauth.types.registry import SourceType, registry
|
from authentik.sources.oauth.types.registry import SourceType, registry
|
||||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
|
||||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||||
|
|
||||||
|
|
||||||
@ -16,15 +16,10 @@ class GoogleOAuthRedirect(OAuthRedirect):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class GoogleOAuth2Callback(OAuthCallback):
|
|
||||||
"""Google OAuth2 Callback"""
|
|
||||||
|
|
||||||
|
|
||||||
@registry.register()
|
@registry.register()
|
||||||
class GoogleType(SourceType):
|
class GoogleType(SourceType):
|
||||||
"""Google Type definition"""
|
"""Google Type definition"""
|
||||||
|
|
||||||
callback_view = GoogleOAuth2Callback
|
|
||||||
redirect_view = GoogleOAuthRedirect
|
redirect_view = GoogleOAuthRedirect
|
||||||
verbose_name = "Google"
|
verbose_name = "Google"
|
||||||
name = "google"
|
name = "google"
|
||||||
@ -35,6 +30,8 @@ class GoogleType(SourceType):
|
|||||||
oidc_well_known_url = "https://accounts.google.com/.well-known/openid-configuration"
|
oidc_well_known_url = "https://accounts.google.com/.well-known/openid-configuration"
|
||||||
oidc_jwks_url = "https://www.googleapis.com/oauth2/v3/certs"
|
oidc_jwks_url = "https://www.googleapis.com/oauth2/v3/certs"
|
||||||
|
|
||||||
|
authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY
|
||||||
|
|
||||||
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
|
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"email": info.get("email"),
|
"email": info.get("email"),
|
||||||
|
@ -6,6 +6,7 @@ from requests.exceptions import RequestException
|
|||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik.sources.oauth.clients.oauth2 import OAuth2Client
|
from authentik.sources.oauth.clients.oauth2 import OAuth2Client
|
||||||
|
from authentik.sources.oauth.models import AuthorizationCodeAuthMethod
|
||||||
from authentik.sources.oauth.types.registry import SourceType, registry
|
from authentik.sources.oauth.types.registry import SourceType, registry
|
||||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||||
@ -59,6 +60,8 @@ class MailcowType(SourceType):
|
|||||||
|
|
||||||
urls_customizable = True
|
urls_customizable = True
|
||||||
|
|
||||||
|
authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY
|
||||||
|
|
||||||
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
|
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"username": info.get("full_name"),
|
"username": info.get("full_name"),
|
||||||
|
@ -2,8 +2,10 @@
|
|||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from requests.auth import AuthBase, HTTPBasicAuth
|
||||||
|
|
||||||
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
|
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
|
||||||
from authentik.sources.oauth.models import OAuthSource
|
from authentik.sources.oauth.models import AuthorizationCodeAuthMethod, OAuthSource
|
||||||
from authentik.sources.oauth.types.registry import SourceType, registry
|
from authentik.sources.oauth.types.registry import SourceType, registry
|
||||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||||
@ -18,10 +20,27 @@ class OpenIDConnectOAuthRedirect(OAuthRedirect):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class OpenIDConnectClient(UserprofileHeaderAuthClient):
|
||||||
|
def get_access_token_args(self, callback: str, code: str) -> dict[str, Any]:
|
||||||
|
args = super().get_access_token_args(callback, code)
|
||||||
|
if self.source.authorization_code_auth_method == AuthorizationCodeAuthMethod.POST_BODY:
|
||||||
|
args["client_id"] = self.get_client_id()
|
||||||
|
args["client_secret"] = self.get_client_secret()
|
||||||
|
else:
|
||||||
|
args.pop("client_id", None)
|
||||||
|
args.pop("client_secret", None)
|
||||||
|
return args
|
||||||
|
|
||||||
|
def get_access_token_auth(self) -> AuthBase | None:
|
||||||
|
if self.source.authorization_code_auth_method == AuthorizationCodeAuthMethod.BASIC_AUTH:
|
||||||
|
return HTTPBasicAuth(self.get_client_id(), self.get_client_secret())
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class OpenIDConnectOAuth2Callback(OAuthCallback):
|
class OpenIDConnectOAuth2Callback(OAuthCallback):
|
||||||
"""OpenIDConnect OAuth2 Callback"""
|
"""OpenIDConnect OAuth2 Callback"""
|
||||||
|
|
||||||
client_class = UserprofileHeaderAuthClient
|
client_class = OpenIDConnectClient
|
||||||
|
|
||||||
def get_user_id(self, info: dict[str, str]) -> str:
|
def get_user_id(self, info: dict[str, str]) -> str:
|
||||||
return info.get("sub", None)
|
return info.get("sub", None)
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
|
|
||||||
from authentik.sources.oauth.models import OAuthSource
|
from authentik.sources.oauth.models import OAuthSource
|
||||||
from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback
|
from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback
|
||||||
from authentik.sources.oauth.types.registry import SourceType, registry
|
from authentik.sources.oauth.types.registry import SourceType, registry
|
||||||
@ -18,20 +17,11 @@ class OktaOAuthRedirect(OAuthRedirect):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class OktaOAuth2Callback(OpenIDConnectOAuth2Callback):
|
|
||||||
"""Okta OAuth2 Callback"""
|
|
||||||
|
|
||||||
# Okta has the same quirk as azure and throws an error if the access token
|
|
||||||
# is set via query parameter, so we reuse the azure client
|
|
||||||
# see https://github.com/goauthentik/authentik/issues/1910
|
|
||||||
client_class = UserprofileHeaderAuthClient
|
|
||||||
|
|
||||||
|
|
||||||
@registry.register()
|
@registry.register()
|
||||||
class OktaType(SourceType):
|
class OktaType(SourceType):
|
||||||
"""Okta Type definition"""
|
"""Okta Type definition"""
|
||||||
|
|
||||||
callback_view = OktaOAuth2Callback
|
callback_view = OpenIDConnectOAuth2Callback
|
||||||
redirect_view = OktaOAuthRedirect
|
redirect_view = OktaOAuthRedirect
|
||||||
verbose_name = "Okta"
|
verbose_name = "Okta"
|
||||||
name = "okta"
|
name = "okta"
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
|
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
|
||||||
from authentik.sources.oauth.models import OAuthSource
|
from authentik.sources.oauth.models import AuthorizationCodeAuthMethod, OAuthSource
|
||||||
from authentik.sources.oauth.types.registry import SourceType, registry
|
from authentik.sources.oauth.types.registry import SourceType, registry
|
||||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||||
@ -41,6 +41,8 @@ class PatreonType(SourceType):
|
|||||||
access_token_url = "https://www.patreon.com/api/oauth2/token" # nosec
|
access_token_url = "https://www.patreon.com/api/oauth2/token" # nosec
|
||||||
profile_url = "https://www.patreon.com/api/oauth2/api/current_user"
|
profile_url = "https://www.patreon.com/api/oauth2/api/current_user"
|
||||||
|
|
||||||
|
authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY
|
||||||
|
|
||||||
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
|
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"username": info.get("data", {}).get("attributes", {}).get("vanity"),
|
"username": info.get("data", {}).get("attributes", {}).get("vanity"),
|
||||||
|
@ -2,8 +2,6 @@
|
|||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from requests.auth import HTTPBasicAuth
|
|
||||||
|
|
||||||
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
|
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
|
||||||
from authentik.sources.oauth.types.registry import SourceType, registry
|
from authentik.sources.oauth.types.registry import SourceType, registry
|
||||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||||
@ -20,21 +18,10 @@ class RedditOAuthRedirect(OAuthRedirect):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class RedditOAuth2Client(UserprofileHeaderAuthClient):
|
|
||||||
"""Reddit OAuth2 Client"""
|
|
||||||
|
|
||||||
def get_access_token(self, **request_kwargs):
|
|
||||||
"Fetch access token from callback request."
|
|
||||||
request_kwargs["auth"] = HTTPBasicAuth(
|
|
||||||
self.source.consumer_key, self.source.consumer_secret
|
|
||||||
)
|
|
||||||
return super().get_access_token(**request_kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class RedditOAuth2Callback(OAuthCallback):
|
class RedditOAuth2Callback(OAuthCallback):
|
||||||
"""Reddit OAuth2 Callback"""
|
"""Reddit OAuth2 Callback"""
|
||||||
|
|
||||||
client_class = RedditOAuth2Client
|
client_class = UserprofileHeaderAuthClient
|
||||||
|
|
||||||
|
|
||||||
@registry.register()
|
@registry.register()
|
||||||
|
@ -10,7 +10,7 @@ from django.urls.base import reverse
|
|||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik.flows.challenge import Challenge, RedirectChallenge
|
from authentik.flows.challenge import Challenge, RedirectChallenge
|
||||||
from authentik.sources.oauth.models import OAuthSource
|
from authentik.sources.oauth.models import AuthorizationCodeAuthMethod, OAuthSource
|
||||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||||
|
|
||||||
@ -41,6 +41,10 @@ class SourceType:
|
|||||||
oidc_well_known_url: str | None = None
|
oidc_well_known_url: str | None = None
|
||||||
oidc_jwks_url: str | None = None
|
oidc_jwks_url: str | None = None
|
||||||
|
|
||||||
|
authorization_code_auth_method: AuthorizationCodeAuthMethod = (
|
||||||
|
AuthorizationCodeAuthMethod.BASIC_AUTH
|
||||||
|
)
|
||||||
|
|
||||||
def icon_url(self) -> str:
|
def icon_url(self) -> str:
|
||||||
"""Get Icon URL for login"""
|
"""Get Icon URL for login"""
|
||||||
return static(f"authentik/sources/{self.name}.svg")
|
return static(f"authentik/sources/{self.name}.svg")
|
||||||
|
@ -4,6 +4,7 @@ from json import dumps
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
|
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
|
||||||
|
from authentik.sources.oauth.models import AuthorizationCodeAuthMethod
|
||||||
from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback
|
from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback
|
||||||
from authentik.sources.oauth.types.registry import SourceType, registry
|
from authentik.sources.oauth.types.registry import SourceType, registry
|
||||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||||
@ -47,6 +48,8 @@ class TwitchType(SourceType):
|
|||||||
access_token_url = "https://id.twitch.tv/oauth2/token" # nosec
|
access_token_url = "https://id.twitch.tv/oauth2/token" # nosec
|
||||||
profile_url = "https://id.twitch.tv/oauth2/userinfo"
|
profile_url = "https://id.twitch.tv/oauth2/userinfo"
|
||||||
|
|
||||||
|
authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY
|
||||||
|
|
||||||
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
|
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"username": info.get("preferred_username"),
|
"username": info.get("preferred_username"),
|
||||||
|
@ -12,23 +12,6 @@ from authentik.sources.oauth.views.callback import OAuthCallback
|
|||||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||||
|
|
||||||
|
|
||||||
class TwitterClient(UserprofileHeaderAuthClient):
|
|
||||||
"""Twitter has similar quirks to Azure AD, and additionally requires Basic auth on
|
|
||||||
the access token endpoint for some reason."""
|
|
||||||
|
|
||||||
# Twitter has the same quirk as azure and throws an error if the access token
|
|
||||||
# is set via query parameter, so we reuse the azure client
|
|
||||||
# see https://github.com/goauthentik/authentik/issues/1910
|
|
||||||
|
|
||||||
def get_access_token(self, **request_kwargs) -> dict[str, Any] | None:
|
|
||||||
return super().get_access_token(
|
|
||||||
auth=(
|
|
||||||
self.source.consumer_key,
|
|
||||||
self.source.consumer_secret,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TwitterOAuthRedirect(OAuthRedirect):
|
class TwitterOAuthRedirect(OAuthRedirect):
|
||||||
"""Twitter OAuth2 Redirect"""
|
"""Twitter OAuth2 Redirect"""
|
||||||
|
|
||||||
@ -44,7 +27,7 @@ class TwitterOAuthRedirect(OAuthRedirect):
|
|||||||
class TwitterOAuthCallback(OAuthCallback):
|
class TwitterOAuthCallback(OAuthCallback):
|
||||||
"""Twitter OAuth2 Callback"""
|
"""Twitter OAuth2 Callback"""
|
||||||
|
|
||||||
client_class = TwitterClient
|
client_class = UserprofileHeaderAuthClient
|
||||||
|
|
||||||
def get_user_id(self, info: dict[str, str]) -> str:
|
def get_user_id(self, info: dict[str, str]) -> str:
|
||||||
return info.get("data", {}).get("id", "")
|
return info.get("data", {}).get("id", "")
|
||||||
|
@ -8436,6 +8436,15 @@
|
|||||||
"type": "object",
|
"type": "object",
|
||||||
"additionalProperties": true,
|
"additionalProperties": true,
|
||||||
"title": "Oidc jwks"
|
"title": "Oidc jwks"
|
||||||
|
},
|
||||||
|
"authorization_code_auth_method": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": [
|
||||||
|
"basic_auth",
|
||||||
|
"post_body"
|
||||||
|
],
|
||||||
|
"title": "Authorization code auth method",
|
||||||
|
"description": "How to perform authentication during an authorization_code token request flow"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"required": []
|
"required": []
|
||||||
|
20
schema.yml
20
schema.yml
@ -42110,6 +42110,11 @@ components:
|
|||||||
format: uuid
|
format: uuid
|
||||||
required:
|
required:
|
||||||
- name
|
- name
|
||||||
|
AuthorizationCodeAuthMethodEnum:
|
||||||
|
enum:
|
||||||
|
- basic_auth
|
||||||
|
- post_body
|
||||||
|
type: string
|
||||||
AutoSubmitChallengeResponseRequest:
|
AutoSubmitChallengeResponseRequest:
|
||||||
type: object
|
type: object
|
||||||
description: Pseudo class for autosubmit response
|
description: Pseudo class for autosubmit response
|
||||||
@ -48742,6 +48747,11 @@ components:
|
|||||||
oidc_jwks_url:
|
oidc_jwks_url:
|
||||||
type: string
|
type: string
|
||||||
oidc_jwks: {}
|
oidc_jwks: {}
|
||||||
|
authorization_code_auth_method:
|
||||||
|
allOf:
|
||||||
|
- $ref: '#/components/schemas/AuthorizationCodeAuthMethodEnum'
|
||||||
|
description: How to perform authentication during an authorization_code
|
||||||
|
token request flow
|
||||||
required:
|
required:
|
||||||
- callback_url
|
- callback_url
|
||||||
- component
|
- component
|
||||||
@ -48911,6 +48921,11 @@ components:
|
|||||||
oidc_jwks_url:
|
oidc_jwks_url:
|
||||||
type: string
|
type: string
|
||||||
oidc_jwks: {}
|
oidc_jwks: {}
|
||||||
|
authorization_code_auth_method:
|
||||||
|
allOf:
|
||||||
|
- $ref: '#/components/schemas/AuthorizationCodeAuthMethodEnum'
|
||||||
|
description: How to perform authentication during an authorization_code
|
||||||
|
token request flow
|
||||||
required:
|
required:
|
||||||
- consumer_key
|
- consumer_key
|
||||||
- consumer_secret
|
- consumer_secret
|
||||||
@ -53009,6 +53024,11 @@ components:
|
|||||||
oidc_jwks_url:
|
oidc_jwks_url:
|
||||||
type: string
|
type: string
|
||||||
oidc_jwks: {}
|
oidc_jwks: {}
|
||||||
|
authorization_code_auth_method:
|
||||||
|
allOf:
|
||||||
|
- $ref: '#/components/schemas/AuthorizationCodeAuthMethodEnum'
|
||||||
|
description: How to perform authentication during an authorization_code
|
||||||
|
token request flow
|
||||||
PatchedOutpostRequest:
|
PatchedOutpostRequest:
|
||||||
type: object
|
type: object
|
||||||
description: Outpost Serializer
|
description: Outpost Serializer
|
||||||
|
@ -7,6 +7,7 @@ import {
|
|||||||
} from "@goauthentik/admin/sources/oauth/utils";
|
} from "@goauthentik/admin/sources/oauth/utils";
|
||||||
import { DEFAULT_CONFIG, config } from "@goauthentik/common/api/config";
|
import { DEFAULT_CONFIG, config } from "@goauthentik/common/api/config";
|
||||||
import { first } from "@goauthentik/common/utils";
|
import { first } from "@goauthentik/common/utils";
|
||||||
|
import "@goauthentik/components/ak-radio-input";
|
||||||
import "@goauthentik/elements/CodeMirror";
|
import "@goauthentik/elements/CodeMirror";
|
||||||
import { CodeMirrorMode } from "@goauthentik/elements/CodeMirror";
|
import { CodeMirrorMode } from "@goauthentik/elements/CodeMirror";
|
||||||
import {
|
import {
|
||||||
@ -16,6 +17,7 @@ import {
|
|||||||
import "@goauthentik/elements/ak-dual-select/ak-dual-select-dynamic-selected-provider.js";
|
import "@goauthentik/elements/ak-dual-select/ak-dual-select-dynamic-selected-provider.js";
|
||||||
import "@goauthentik/elements/forms/FormGroup";
|
import "@goauthentik/elements/forms/FormGroup";
|
||||||
import "@goauthentik/elements/forms/HorizontalFormElement";
|
import "@goauthentik/elements/forms/HorizontalFormElement";
|
||||||
|
import "@goauthentik/elements/forms/Radio";
|
||||||
import "@goauthentik/elements/forms/SearchSelect";
|
import "@goauthentik/elements/forms/SearchSelect";
|
||||||
|
|
||||||
import { msg } from "@lit/localize";
|
import { msg } from "@lit/localize";
|
||||||
@ -24,6 +26,7 @@ import { customElement, property, state } from "lit/decorators.js";
|
|||||||
import { ifDefined } from "lit/directives/if-defined.js";
|
import { ifDefined } from "lit/directives/if-defined.js";
|
||||||
|
|
||||||
import {
|
import {
|
||||||
|
AuthorizationCodeAuthMethodEnum,
|
||||||
FlowsInstancesListDesignationEnum,
|
FlowsInstancesListDesignationEnum,
|
||||||
GroupMatchingModeEnum,
|
GroupMatchingModeEnum,
|
||||||
OAuthSource,
|
OAuthSource,
|
||||||
@ -36,6 +39,18 @@ import {
|
|||||||
|
|
||||||
import { propertyMappingsProvider, propertyMappingsSelector } from "./OAuthSourceFormHelpers.js";
|
import { propertyMappingsProvider, propertyMappingsSelector } from "./OAuthSourceFormHelpers.js";
|
||||||
|
|
||||||
|
const authorizationCodeAuthMethodOptions = [
|
||||||
|
{
|
||||||
|
label: msg("HTTP Basic Auth"),
|
||||||
|
value: AuthorizationCodeAuthMethodEnum.BasicAuth,
|
||||||
|
default: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: msg("Include the client ID and secret as request parameters"),
|
||||||
|
value: AuthorizationCodeAuthMethodEnum.PostBody,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
@customElement("ak-source-oauth-form")
|
@customElement("ak-source-oauth-form")
|
||||||
export class OAuthSourceForm extends WithCapabilitiesConfig(BaseSourceForm<OAuthSource>) {
|
export class OAuthSourceForm extends WithCapabilitiesConfig(BaseSourceForm<OAuthSource>) {
|
||||||
async loadInstance(pk: string): Promise<OAuthSource> {
|
async loadInstance(pk: string): Promise<OAuthSource> {
|
||||||
@ -240,6 +255,19 @@ export class OAuthSourceForm extends WithCapabilitiesConfig(BaseSourceForm<OAuth
|
|||||||
<p class="pf-c-form__helper-text">${msg("Raw JWKS data.")}</p>
|
<p class="pf-c-form__helper-text">${msg("Raw JWKS data.")}</p>
|
||||||
</ak-form-element-horizontal>`
|
</ak-form-element-horizontal>`
|
||||||
: html``}
|
: html``}
|
||||||
|
${this.providerType.name === ProviderTypeEnum.Openidconnect
|
||||||
|
? html`<ak-radio-input
|
||||||
|
label=${msg("Authorization code authentication method")}
|
||||||
|
name="authorizationCodeAuthMethod"
|
||||||
|
required
|
||||||
|
.options=${authorizationCodeAuthMethodOptions}
|
||||||
|
.value=${this.instance?.authorizationCodeAuthMethod}
|
||||||
|
help=${msg(
|
||||||
|
"How to perform authentication during an authorization_code token request flow",
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
</ak-radio-input>`
|
||||||
|
: html``}
|
||||||
</div>
|
</div>
|
||||||
</ak-form-group>`;
|
</ak-form-group>`;
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user