From 50a68c22c5c215128031251123ccecb0f910e70b Mon Sep 17 00:00:00 2001 From: "gcp-cherry-pick-bot[bot]" <98988430+gcp-cherry-pick-bot[bot]@users.noreply.github.com> Date: Wed, 19 Feb 2025 18:20:16 +0100 Subject: [PATCH] sources/oauth: add group sync for azure_ad (cherry-pick #12894) (#13123) sources/oauth: add group sync for azure_ad (#12894) * sources/oauth: add group sync for azure_ad * make group sync optional --------- Signed-off-by: Jens Langhammer Co-authored-by: Jens L. --- authentik/sources/oauth/types/azure_ad.py | 44 ++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/authentik/sources/oauth/types/azure_ad.py b/authentik/sources/oauth/types/azure_ad.py index 7d7f4e1592..74de7e194f 100644 --- a/authentik/sources/oauth/types/azure_ad.py +++ b/authentik/sources/oauth/types/azure_ad.py @@ -2,6 +2,7 @@ from typing import Any +from requests import RequestException from structlog.stdlib import get_logger from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient @@ -21,10 +22,35 @@ class AzureADOAuthRedirect(OAuthRedirect): } +class AzureADClient(UserprofileHeaderAuthClient): + """Fetch AzureAD group information""" + + def get_profile_info(self, token): + profile_data = super().get_profile_info(token) + if "https://graph.microsoft.com/GroupMember.Read.All" not in self.source.additional_scopes: + return profile_data + group_response = self.session.request( + "get", + "https://graph.microsoft.com/v1.0/me/memberOf", + headers={"Authorization": f"{token['token_type']} {token['access_token']}"}, + ) + try: + group_response.raise_for_status() + except RequestException as exc: + LOGGER.warning( + "Unable to fetch user profile", + exc=exc, + response=exc.response.text if exc.response else str(exc), + ) + return None + profile_data["raw_groups"] = group_response.json() + return profile_data + + class AzureADOAuthCallback(OpenIDConnectOAuth2Callback): """AzureAD OAuth2 Callback""" - client_class = UserprofileHeaderAuthClient + client_class = AzureADClient def get_user_id(self, info: dict[str, str]) -> str: # Default try to get `id` for the Graph API endpoint @@ -53,8 +79,24 @@ class AzureADType(SourceType): def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]: mail = info.get("mail", None) or info.get("otherMails", [None])[0] + # Format group info + groups = [] + group_id_dict = {} + for group in info.get("raw_groups", {}).get("value", []): + if group["@odata.type"] != "#microsoft.graph.group": + continue + groups.append(group["id"]) + group_id_dict[group["id"]] = group + info["raw_groups"] = group_id_dict return { "username": info.get("userPrincipalName"), "email": mail, "name": info.get("displayName"), + "groups": groups, + } + + def get_base_group_properties(self, source, group_id, **kwargs): + raw_group = kwargs["info"]["raw_groups"][group_id] + return { + "name": raw_group["displayName"], }