providers/scim: optimize sending all members within a group (#9968)
* providers/scim: optimize sending all members within a group Signed-off-by: Jens Langhammer <jens@goauthentik.io> * correctly batch requests Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
@ -1,5 +1,7 @@
|
||||
"""Group client"""
|
||||
|
||||
from itertools import batched
|
||||
|
||||
from pydantic import ValidationError
|
||||
from pydanticscim.group import GroupMember
|
||||
from pydanticscim.responses import PatchOp, PatchOperation
|
||||
@ -56,17 +58,22 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
|
||||
if not scim_group.externalId:
|
||||
scim_group.externalId = str(obj.pk)
|
||||
|
||||
users = list(obj.users.order_by("id").values_list("id", flat=True))
|
||||
connections = SCIMProviderUser.objects.filter(provider=self.provider, user__pk__in=users)
|
||||
members = []
|
||||
for user in connections:
|
||||
members.append(
|
||||
GroupMember(
|
||||
value=user.scim_id,
|
||||
)
|
||||
if not self._config.patch.supported:
|
||||
users = list(obj.users.order_by("id").values_list("id", flat=True))
|
||||
connections = SCIMProviderUser.objects.filter(
|
||||
provider=self.provider, user__pk__in=users
|
||||
)
|
||||
if members:
|
||||
scim_group.members = members
|
||||
members = []
|
||||
for user in connections:
|
||||
members.append(
|
||||
GroupMember(
|
||||
value=user.scim_id,
|
||||
)
|
||||
)
|
||||
if members:
|
||||
scim_group.members = members
|
||||
else:
|
||||
del scim_group.members
|
||||
return scim_group
|
||||
|
||||
def delete(self, obj: Group):
|
||||
@ -93,16 +100,19 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
|
||||
scim_id = response.get("id")
|
||||
if not scim_id or scim_id == "":
|
||||
raise StopSync("SCIM Response with missing or invalid `id`")
|
||||
return SCIMProviderGroup.objects.create(
|
||||
connection = SCIMProviderGroup.objects.create(
|
||||
provider=self.provider, group=group, scim_id=scim_id
|
||||
)
|
||||
users = list(group.users.order_by("id").values_list("id", flat=True))
|
||||
self._patch_add_users(group, users)
|
||||
return connection
|
||||
|
||||
def update(self, group: Group, connection: SCIMProviderGroup):
|
||||
"""Update existing group"""
|
||||
scim_group = self.to_schema(group, connection)
|
||||
scim_group.id = connection.scim_id
|
||||
try:
|
||||
return self._request(
|
||||
self._request(
|
||||
"PUT",
|
||||
f"/Groups/{connection.scim_id}",
|
||||
json=scim_group.model_dump(
|
||||
@ -110,6 +120,8 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
|
||||
exclude_unset=True,
|
||||
),
|
||||
)
|
||||
users = list(group.users.order_by("id").values_list("id", flat=True))
|
||||
return self._patch_add_users(group, users)
|
||||
except NotFoundSyncException:
|
||||
# Resource missing is handled by self.write, which will re-create the group
|
||||
raise
|
||||
@ -152,14 +164,18 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
|
||||
group_id: str,
|
||||
*ops: PatchOperation,
|
||||
):
|
||||
req = PatchRequest(Operations=ops)
|
||||
self._request(
|
||||
"PATCH",
|
||||
f"/Groups/{group_id}",
|
||||
json=req.model_dump(
|
||||
mode="json",
|
||||
),
|
||||
)
|
||||
chunk_size = self._config.bulk.maxOperations
|
||||
if chunk_size < 1:
|
||||
chunk_size = len(ops)
|
||||
for chunk in batched(ops, chunk_size):
|
||||
req = PatchRequest(Operations=list(chunk))
|
||||
self._request(
|
||||
"PATCH",
|
||||
f"/Groups/{group_id}",
|
||||
json=req.model_dump(
|
||||
mode="json",
|
||||
),
|
||||
)
|
||||
|
||||
def _patch_add_users(self, group: Group, users_set: set[int]):
|
||||
"""Add users in users_set to group"""
|
||||
@ -180,11 +196,14 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
|
||||
return
|
||||
self._patch(
|
||||
scim_group.scim_id,
|
||||
PatchOperation(
|
||||
op=PatchOp.add,
|
||||
path="members",
|
||||
value=[{"value": x} for x in user_ids],
|
||||
),
|
||||
*[
|
||||
PatchOperation(
|
||||
op=PatchOp.add,
|
||||
path="members",
|
||||
value=[{"value": x}],
|
||||
)
|
||||
for x in user_ids
|
||||
],
|
||||
)
|
||||
|
||||
def _patch_remove_users(self, group: Group, users_set: set[int]):
|
||||
@ -206,9 +225,12 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
|
||||
return
|
||||
self._patch(
|
||||
scim_group.scim_id,
|
||||
PatchOperation(
|
||||
op=PatchOp.remove,
|
||||
path="members",
|
||||
value=[{"value": x} for x in user_ids],
|
||||
),
|
||||
*[
|
||||
PatchOperation(
|
||||
op=PatchOp.remove,
|
||||
path="members",
|
||||
value=[{"value": x}],
|
||||
)
|
||||
for x in user_ids
|
||||
],
|
||||
)
|
||||
|
@ -1,9 +1,11 @@
|
||||
"""Custom SCIM schemas"""
|
||||
|
||||
from pydantic import Field
|
||||
from pydanticscim.group import Group as BaseGroup
|
||||
from pydanticscim.responses import PatchRequest as BasePatchRequest
|
||||
from pydanticscim.responses import SCIMError as BaseSCIMError
|
||||
from pydanticscim.service_provider import Bulk, ChangePassword, Filter, Patch, Sort
|
||||
from pydanticscim.service_provider import Bulk as BaseBulk
|
||||
from pydanticscim.service_provider import ChangePassword, Filter, Patch, Sort
|
||||
from pydanticscim.service_provider import (
|
||||
ServiceProviderConfiguration as BaseServiceProviderConfiguration,
|
||||
)
|
||||
@ -29,10 +31,16 @@ class Group(BaseGroup):
|
||||
meta: dict | None = None
|
||||
|
||||
|
||||
class Bulk(BaseBulk):
|
||||
|
||||
maxOperations: int = Field()
|
||||
|
||||
|
||||
class ServiceProviderConfiguration(BaseServiceProviderConfiguration):
|
||||
"""ServiceProviderConfig with fallback"""
|
||||
|
||||
_is_fallback: bool | None = False
|
||||
bulk: Bulk = Field(..., description="A complex type that specifies bulk configuration options.")
|
||||
|
||||
@property
|
||||
def is_fallback(self) -> bool:
|
||||
@ -45,7 +53,7 @@ class ServiceProviderConfiguration(BaseServiceProviderConfiguration):
|
||||
"""Get default configuration, which doesn't support any optional features as fallback"""
|
||||
return ServiceProviderConfiguration(
|
||||
patch=Patch(supported=False),
|
||||
bulk=Bulk(supported=False),
|
||||
bulk=Bulk(supported=False, maxOperations=0),
|
||||
filter=Filter(supported=False),
|
||||
changePassword=ChangePassword(supported=False),
|
||||
sort=Sort(supported=False),
|
||||
|
Reference in New Issue
Block a user