sources/oauth: add group sync for azure_ad (cherry-pick #12894) (#13123)

sources/oauth: add group sync for azure_ad (#12894)

* sources/oauth: add group sync for azure_ad



* make group sync optional



---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
Co-authored-by: Jens L. <jens@goauthentik.io>
This commit is contained in:
gcp-cherry-pick-bot[bot]
2025-02-19 18:20:16 +01:00
committed by GitHub
parent 13c99c8546
commit 50a68c22c5

View File

@ -2,6 +2,7 @@
from typing import Any
from requests import RequestException
from structlog.stdlib import get_logger
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
@ -21,10 +22,35 @@ class AzureADOAuthRedirect(OAuthRedirect):
}
class AzureADClient(UserprofileHeaderAuthClient):
"""Fetch AzureAD group information"""
def get_profile_info(self, token):
profile_data = super().get_profile_info(token)
if "https://graph.microsoft.com/GroupMember.Read.All" not in self.source.additional_scopes:
return profile_data
group_response = self.session.request(
"get",
"https://graph.microsoft.com/v1.0/me/memberOf",
headers={"Authorization": f"{token['token_type']} {token['access_token']}"},
)
try:
group_response.raise_for_status()
except RequestException as exc:
LOGGER.warning(
"Unable to fetch user profile",
exc=exc,
response=exc.response.text if exc.response else str(exc),
)
return None
profile_data["raw_groups"] = group_response.json()
return profile_data
class AzureADOAuthCallback(OpenIDConnectOAuth2Callback):
"""AzureAD OAuth2 Callback"""
client_class = UserprofileHeaderAuthClient
client_class = AzureADClient
def get_user_id(self, info: dict[str, str]) -> str:
# Default try to get `id` for the Graph API endpoint
@ -53,8 +79,24 @@ class AzureADType(SourceType):
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
mail = info.get("mail", None) or info.get("otherMails", [None])[0]
# Format group info
groups = []
group_id_dict = {}
for group in info.get("raw_groups", {}).get("value", []):
if group["@odata.type"] != "#microsoft.graph.group":
continue
groups.append(group["id"])
group_id_dict[group["id"]] = group
info["raw_groups"] = group_id_dict
return {
"username": info.get("userPrincipalName"),
"email": mail,
"name": info.get("displayName"),
"groups": groups,
}
def get_base_group_properties(self, source, group_id, **kwargs):
raw_group = kwargs["info"]["raw_groups"][group_id]
return {
"name": raw_group["displayName"],
}