providers/scim: optimize sending all members within a group (#9968)

* providers/scim: optimize sending all members within a group

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* correctly batch requests

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L.
2024-08-22 16:39:18 +02:00
committed by GitHub
parent 46acab3b2e
commit eafb7093c9
2 changed files with 62 additions and 32 deletions

View File

@ -1,5 +1,7 @@
"""Group client"""
from itertools import batched
from pydantic import ValidationError
from pydanticscim.group import GroupMember
from pydanticscim.responses import PatchOp, PatchOperation
@ -56,17 +58,22 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
if not scim_group.externalId:
scim_group.externalId = str(obj.pk)
users = list(obj.users.order_by("id").values_list("id", flat=True))
connections = SCIMProviderUser.objects.filter(provider=self.provider, user__pk__in=users)
members = []
for user in connections:
members.append(
GroupMember(
value=user.scim_id,
)
if not self._config.patch.supported:
users = list(obj.users.order_by("id").values_list("id", flat=True))
connections = SCIMProviderUser.objects.filter(
provider=self.provider, user__pk__in=users
)
if members:
scim_group.members = members
members = []
for user in connections:
members.append(
GroupMember(
value=user.scim_id,
)
)
if members:
scim_group.members = members
else:
del scim_group.members
return scim_group
def delete(self, obj: Group):
@ -93,16 +100,19 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
scim_id = response.get("id")
if not scim_id or scim_id == "":
raise StopSync("SCIM Response with missing or invalid `id`")
return SCIMProviderGroup.objects.create(
connection = SCIMProviderGroup.objects.create(
provider=self.provider, group=group, scim_id=scim_id
)
users = list(group.users.order_by("id").values_list("id", flat=True))
self._patch_add_users(group, users)
return connection
def update(self, group: Group, connection: SCIMProviderGroup):
"""Update existing group"""
scim_group = self.to_schema(group, connection)
scim_group.id = connection.scim_id
try:
return self._request(
self._request(
"PUT",
f"/Groups/{connection.scim_id}",
json=scim_group.model_dump(
@ -110,6 +120,8 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
exclude_unset=True,
),
)
users = list(group.users.order_by("id").values_list("id", flat=True))
return self._patch_add_users(group, users)
except NotFoundSyncException:
# Resource missing is handled by self.write, which will re-create the group
raise
@ -152,14 +164,18 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
group_id: str,
*ops: PatchOperation,
):
req = PatchRequest(Operations=ops)
self._request(
"PATCH",
f"/Groups/{group_id}",
json=req.model_dump(
mode="json",
),
)
chunk_size = self._config.bulk.maxOperations
if chunk_size < 1:
chunk_size = len(ops)
for chunk in batched(ops, chunk_size):
req = PatchRequest(Operations=list(chunk))
self._request(
"PATCH",
f"/Groups/{group_id}",
json=req.model_dump(
mode="json",
),
)
def _patch_add_users(self, group: Group, users_set: set[int]):
"""Add users in users_set to group"""
@ -180,11 +196,14 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
return
self._patch(
scim_group.scim_id,
PatchOperation(
op=PatchOp.add,
path="members",
value=[{"value": x} for x in user_ids],
),
*[
PatchOperation(
op=PatchOp.add,
path="members",
value=[{"value": x}],
)
for x in user_ids
],
)
def _patch_remove_users(self, group: Group, users_set: set[int]):
@ -206,9 +225,12 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
return
self._patch(
scim_group.scim_id,
PatchOperation(
op=PatchOp.remove,
path="members",
value=[{"value": x} for x in user_ids],
),
*[
PatchOperation(
op=PatchOp.remove,
path="members",
value=[{"value": x}],
)
for x in user_ids
],
)

View File

@ -1,9 +1,11 @@
"""Custom SCIM schemas"""
from pydantic import Field
from pydanticscim.group import Group as BaseGroup
from pydanticscim.responses import PatchRequest as BasePatchRequest
from pydanticscim.responses import SCIMError as BaseSCIMError
from pydanticscim.service_provider import Bulk, ChangePassword, Filter, Patch, Sort
from pydanticscim.service_provider import Bulk as BaseBulk
from pydanticscim.service_provider import ChangePassword, Filter, Patch, Sort
from pydanticscim.service_provider import (
ServiceProviderConfiguration as BaseServiceProviderConfiguration,
)
@ -29,10 +31,16 @@ class Group(BaseGroup):
meta: dict | None = None
class Bulk(BaseBulk):
maxOperations: int = Field()
class ServiceProviderConfiguration(BaseServiceProviderConfiguration):
"""ServiceProviderConfig with fallback"""
_is_fallback: bool | None = False
bulk: Bulk = Field(..., description="A complex type that specifies bulk configuration options.")
@property
def is_fallback(self) -> bool:
@ -45,7 +53,7 @@ class ServiceProviderConfiguration(BaseServiceProviderConfiguration):
"""Get default configuration, which doesn't support any optional features as fallback"""
return ServiceProviderConfiguration(
patch=Patch(supported=False),
bulk=Bulk(supported=False),
bulk=Bulk(supported=False, maxOperations=0),
filter=Filter(supported=False),
changePassword=ChangePassword(supported=False),
sort=Sort(supported=False),