From eafb7093c946276cbca2418cdfd18e932836b249 Mon Sep 17 00:00:00 2001 From: "Jens L." Date: Thu, 22 Aug 2024 16:39:18 +0200 Subject: [PATCH] providers/scim: optimize sending all members within a group (#9968) * providers/scim: optimize sending all members within a group Signed-off-by: Jens Langhammer * correctly batch requests Signed-off-by: Jens Langhammer --------- Signed-off-by: Jens Langhammer --- authentik/providers/scim/clients/groups.py | 82 ++++++++++++++-------- authentik/providers/scim/clients/schema.py | 12 +++- 2 files changed, 62 insertions(+), 32 deletions(-) diff --git a/authentik/providers/scim/clients/groups.py b/authentik/providers/scim/clients/groups.py index b1dc657dcb..1f39eea8f5 100644 --- a/authentik/providers/scim/clients/groups.py +++ b/authentik/providers/scim/clients/groups.py @@ -1,5 +1,7 @@ """Group client""" +from itertools import batched + from pydantic import ValidationError from pydanticscim.group import GroupMember from pydanticscim.responses import PatchOp, PatchOperation @@ -56,17 +58,22 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]): if not scim_group.externalId: scim_group.externalId = str(obj.pk) - users = list(obj.users.order_by("id").values_list("id", flat=True)) - connections = SCIMProviderUser.objects.filter(provider=self.provider, user__pk__in=users) - members = [] - for user in connections: - members.append( - GroupMember( - value=user.scim_id, - ) + if not self._config.patch.supported: + users = list(obj.users.order_by("id").values_list("id", flat=True)) + connections = SCIMProviderUser.objects.filter( + provider=self.provider, user__pk__in=users ) - if members: - scim_group.members = members + members = [] + for user in connections: + members.append( + GroupMember( + value=user.scim_id, + ) + ) + if members: + scim_group.members = members + else: + del scim_group.members return scim_group def delete(self, obj: Group): @@ -93,16 +100,19 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]): scim_id = response.get("id") if not scim_id or scim_id == "": raise StopSync("SCIM Response with missing or invalid `id`") - return SCIMProviderGroup.objects.create( + connection = SCIMProviderGroup.objects.create( provider=self.provider, group=group, scim_id=scim_id ) + users = list(group.users.order_by("id").values_list("id", flat=True)) + self._patch_add_users(group, users) + return connection def update(self, group: Group, connection: SCIMProviderGroup): """Update existing group""" scim_group = self.to_schema(group, connection) scim_group.id = connection.scim_id try: - return self._request( + self._request( "PUT", f"/Groups/{connection.scim_id}", json=scim_group.model_dump( @@ -110,6 +120,8 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]): exclude_unset=True, ), ) + users = list(group.users.order_by("id").values_list("id", flat=True)) + return self._patch_add_users(group, users) except NotFoundSyncException: # Resource missing is handled by self.write, which will re-create the group raise @@ -152,14 +164,18 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]): group_id: str, *ops: PatchOperation, ): - req = PatchRequest(Operations=ops) - self._request( - "PATCH", - f"/Groups/{group_id}", - json=req.model_dump( - mode="json", - ), - ) + chunk_size = self._config.bulk.maxOperations + if chunk_size < 1: + chunk_size = len(ops) + for chunk in batched(ops, chunk_size): + req = PatchRequest(Operations=list(chunk)) + self._request( + "PATCH", + f"/Groups/{group_id}", + json=req.model_dump( + mode="json", + ), + ) def _patch_add_users(self, group: Group, users_set: set[int]): """Add users in users_set to group""" @@ -180,11 +196,14 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]): return self._patch( scim_group.scim_id, - PatchOperation( - op=PatchOp.add, - path="members", - value=[{"value": x} for x in user_ids], - ), + *[ + PatchOperation( + op=PatchOp.add, + path="members", + value=[{"value": x}], + ) + for x in user_ids + ], ) def _patch_remove_users(self, group: Group, users_set: set[int]): @@ -206,9 +225,12 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]): return self._patch( scim_group.scim_id, - PatchOperation( - op=PatchOp.remove, - path="members", - value=[{"value": x} for x in user_ids], - ), + *[ + PatchOperation( + op=PatchOp.remove, + path="members", + value=[{"value": x}], + ) + for x in user_ids + ], ) diff --git a/authentik/providers/scim/clients/schema.py b/authentik/providers/scim/clients/schema.py index f56d6b0e46..b4444b3734 100644 --- a/authentik/providers/scim/clients/schema.py +++ b/authentik/providers/scim/clients/schema.py @@ -1,9 +1,11 @@ """Custom SCIM schemas""" +from pydantic import Field from pydanticscim.group import Group as BaseGroup from pydanticscim.responses import PatchRequest as BasePatchRequest from pydanticscim.responses import SCIMError as BaseSCIMError -from pydanticscim.service_provider import Bulk, ChangePassword, Filter, Patch, Sort +from pydanticscim.service_provider import Bulk as BaseBulk +from pydanticscim.service_provider import ChangePassword, Filter, Patch, Sort from pydanticscim.service_provider import ( ServiceProviderConfiguration as BaseServiceProviderConfiguration, ) @@ -29,10 +31,16 @@ class Group(BaseGroup): meta: dict | None = None +class Bulk(BaseBulk): + + maxOperations: int = Field() + + class ServiceProviderConfiguration(BaseServiceProviderConfiguration): """ServiceProviderConfig with fallback""" _is_fallback: bool | None = False + bulk: Bulk = Field(..., description="A complex type that specifies bulk configuration options.") @property def is_fallback(self) -> bool: @@ -45,7 +53,7 @@ class ServiceProviderConfiguration(BaseServiceProviderConfiguration): """Get default configuration, which doesn't support any optional features as fallback""" return ServiceProviderConfiguration( patch=Patch(supported=False), - bulk=Bulk(supported=False), + bulk=Bulk(supported=False, maxOperations=0), filter=Filter(supported=False), changePassword=ChangePassword(supported=False), sort=Sort(supported=False),