providers/scim: optimize sending all members within a group (#9968)

* providers/scim: optimize sending all members within a group

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* correctly batch requests

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L.
2024-08-22 16:39:18 +02:00
committed by GitHub
parent 46acab3b2e
commit eafb7093c9
2 changed files with 62 additions and 32 deletions

View File

@ -1,5 +1,7 @@
"""Group client""" """Group client"""
from itertools import batched
from pydantic import ValidationError from pydantic import ValidationError
from pydanticscim.group import GroupMember from pydanticscim.group import GroupMember
from pydanticscim.responses import PatchOp, PatchOperation from pydanticscim.responses import PatchOp, PatchOperation
@ -56,8 +58,11 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
if not scim_group.externalId: if not scim_group.externalId:
scim_group.externalId = str(obj.pk) scim_group.externalId = str(obj.pk)
if not self._config.patch.supported:
users = list(obj.users.order_by("id").values_list("id", flat=True)) users = list(obj.users.order_by("id").values_list("id", flat=True))
connections = SCIMProviderUser.objects.filter(provider=self.provider, user__pk__in=users) connections = SCIMProviderUser.objects.filter(
provider=self.provider, user__pk__in=users
)
members = [] members = []
for user in connections: for user in connections:
members.append( members.append(
@ -67,6 +72,8 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
) )
if members: if members:
scim_group.members = members scim_group.members = members
else:
del scim_group.members
return scim_group return scim_group
def delete(self, obj: Group): def delete(self, obj: Group):
@ -93,16 +100,19 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
scim_id = response.get("id") scim_id = response.get("id")
if not scim_id or scim_id == "": if not scim_id or scim_id == "":
raise StopSync("SCIM Response with missing or invalid `id`") raise StopSync("SCIM Response with missing or invalid `id`")
return SCIMProviderGroup.objects.create( connection = SCIMProviderGroup.objects.create(
provider=self.provider, group=group, scim_id=scim_id provider=self.provider, group=group, scim_id=scim_id
) )
users = list(group.users.order_by("id").values_list("id", flat=True))
self._patch_add_users(group, users)
return connection
def update(self, group: Group, connection: SCIMProviderGroup): def update(self, group: Group, connection: SCIMProviderGroup):
"""Update existing group""" """Update existing group"""
scim_group = self.to_schema(group, connection) scim_group = self.to_schema(group, connection)
scim_group.id = connection.scim_id scim_group.id = connection.scim_id
try: try:
return self._request( self._request(
"PUT", "PUT",
f"/Groups/{connection.scim_id}", f"/Groups/{connection.scim_id}",
json=scim_group.model_dump( json=scim_group.model_dump(
@ -110,6 +120,8 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
exclude_unset=True, exclude_unset=True,
), ),
) )
users = list(group.users.order_by("id").values_list("id", flat=True))
return self._patch_add_users(group, users)
except NotFoundSyncException: except NotFoundSyncException:
# Resource missing is handled by self.write, which will re-create the group # Resource missing is handled by self.write, which will re-create the group
raise raise
@ -152,7 +164,11 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
group_id: str, group_id: str,
*ops: PatchOperation, *ops: PatchOperation,
): ):
req = PatchRequest(Operations=ops) chunk_size = self._config.bulk.maxOperations
if chunk_size < 1:
chunk_size = len(ops)
for chunk in batched(ops, chunk_size):
req = PatchRequest(Operations=list(chunk))
self._request( self._request(
"PATCH", "PATCH",
f"/Groups/{group_id}", f"/Groups/{group_id}",
@ -180,11 +196,14 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
return return
self._patch( self._patch(
scim_group.scim_id, scim_group.scim_id,
*[
PatchOperation( PatchOperation(
op=PatchOp.add, op=PatchOp.add,
path="members", path="members",
value=[{"value": x} for x in user_ids], value=[{"value": x}],
), )
for x in user_ids
],
) )
def _patch_remove_users(self, group: Group, users_set: set[int]): def _patch_remove_users(self, group: Group, users_set: set[int]):
@ -206,9 +225,12 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
return return
self._patch( self._patch(
scim_group.scim_id, scim_group.scim_id,
*[
PatchOperation( PatchOperation(
op=PatchOp.remove, op=PatchOp.remove,
path="members", path="members",
value=[{"value": x} for x in user_ids], value=[{"value": x}],
), )
for x in user_ids
],
) )

View File

@ -1,9 +1,11 @@
"""Custom SCIM schemas""" """Custom SCIM schemas"""
from pydantic import Field
from pydanticscim.group import Group as BaseGroup from pydanticscim.group import Group as BaseGroup
from pydanticscim.responses import PatchRequest as BasePatchRequest from pydanticscim.responses import PatchRequest as BasePatchRequest
from pydanticscim.responses import SCIMError as BaseSCIMError from pydanticscim.responses import SCIMError as BaseSCIMError
from pydanticscim.service_provider import Bulk, ChangePassword, Filter, Patch, Sort from pydanticscim.service_provider import Bulk as BaseBulk
from pydanticscim.service_provider import ChangePassword, Filter, Patch, Sort
from pydanticscim.service_provider import ( from pydanticscim.service_provider import (
ServiceProviderConfiguration as BaseServiceProviderConfiguration, ServiceProviderConfiguration as BaseServiceProviderConfiguration,
) )
@ -29,10 +31,16 @@ class Group(BaseGroup):
meta: dict | None = None meta: dict | None = None
class Bulk(BaseBulk):
maxOperations: int = Field()
class ServiceProviderConfiguration(BaseServiceProviderConfiguration): class ServiceProviderConfiguration(BaseServiceProviderConfiguration):
"""ServiceProviderConfig with fallback""" """ServiceProviderConfig with fallback"""
_is_fallback: bool | None = False _is_fallback: bool | None = False
bulk: Bulk = Field(..., description="A complex type that specifies bulk configuration options.")
@property @property
def is_fallback(self) -> bool: def is_fallback(self) -> bool:
@ -45,7 +53,7 @@ class ServiceProviderConfiguration(BaseServiceProviderConfiguration):
"""Get default configuration, which doesn't support any optional features as fallback""" """Get default configuration, which doesn't support any optional features as fallback"""
return ServiceProviderConfiguration( return ServiceProviderConfiguration(
patch=Patch(supported=False), patch=Patch(supported=False),
bulk=Bulk(supported=False), bulk=Bulk(supported=False, maxOperations=0),
filter=Filter(supported=False), filter=Filter(supported=False),
changePassword=ChangePassword(supported=False), changePassword=ChangePassword(supported=False),
sort=Sort(supported=False), sort=Sort(supported=False),