providers/scim: optimize sending all members within a group (#9968)
* providers/scim: optimize sending all members within a group Signed-off-by: Jens Langhammer <jens@goauthentik.io> * correctly batch requests Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
@ -1,5 +1,7 @@
|
|||||||
"""Group client"""
|
"""Group client"""
|
||||||
|
|
||||||
|
from itertools import batched
|
||||||
|
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
from pydanticscim.group import GroupMember
|
from pydanticscim.group import GroupMember
|
||||||
from pydanticscim.responses import PatchOp, PatchOperation
|
from pydanticscim.responses import PatchOp, PatchOperation
|
||||||
@ -56,8 +58,11 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
|
|||||||
if not scim_group.externalId:
|
if not scim_group.externalId:
|
||||||
scim_group.externalId = str(obj.pk)
|
scim_group.externalId = str(obj.pk)
|
||||||
|
|
||||||
|
if not self._config.patch.supported:
|
||||||
users = list(obj.users.order_by("id").values_list("id", flat=True))
|
users = list(obj.users.order_by("id").values_list("id", flat=True))
|
||||||
connections = SCIMProviderUser.objects.filter(provider=self.provider, user__pk__in=users)
|
connections = SCIMProviderUser.objects.filter(
|
||||||
|
provider=self.provider, user__pk__in=users
|
||||||
|
)
|
||||||
members = []
|
members = []
|
||||||
for user in connections:
|
for user in connections:
|
||||||
members.append(
|
members.append(
|
||||||
@ -67,6 +72,8 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
|
|||||||
)
|
)
|
||||||
if members:
|
if members:
|
||||||
scim_group.members = members
|
scim_group.members = members
|
||||||
|
else:
|
||||||
|
del scim_group.members
|
||||||
return scim_group
|
return scim_group
|
||||||
|
|
||||||
def delete(self, obj: Group):
|
def delete(self, obj: Group):
|
||||||
@ -93,16 +100,19 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
|
|||||||
scim_id = response.get("id")
|
scim_id = response.get("id")
|
||||||
if not scim_id or scim_id == "":
|
if not scim_id or scim_id == "":
|
||||||
raise StopSync("SCIM Response with missing or invalid `id`")
|
raise StopSync("SCIM Response with missing or invalid `id`")
|
||||||
return SCIMProviderGroup.objects.create(
|
connection = SCIMProviderGroup.objects.create(
|
||||||
provider=self.provider, group=group, scim_id=scim_id
|
provider=self.provider, group=group, scim_id=scim_id
|
||||||
)
|
)
|
||||||
|
users = list(group.users.order_by("id").values_list("id", flat=True))
|
||||||
|
self._patch_add_users(group, users)
|
||||||
|
return connection
|
||||||
|
|
||||||
def update(self, group: Group, connection: SCIMProviderGroup):
|
def update(self, group: Group, connection: SCIMProviderGroup):
|
||||||
"""Update existing group"""
|
"""Update existing group"""
|
||||||
scim_group = self.to_schema(group, connection)
|
scim_group = self.to_schema(group, connection)
|
||||||
scim_group.id = connection.scim_id
|
scim_group.id = connection.scim_id
|
||||||
try:
|
try:
|
||||||
return self._request(
|
self._request(
|
||||||
"PUT",
|
"PUT",
|
||||||
f"/Groups/{connection.scim_id}",
|
f"/Groups/{connection.scim_id}",
|
||||||
json=scim_group.model_dump(
|
json=scim_group.model_dump(
|
||||||
@ -110,6 +120,8 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
|
|||||||
exclude_unset=True,
|
exclude_unset=True,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
users = list(group.users.order_by("id").values_list("id", flat=True))
|
||||||
|
return self._patch_add_users(group, users)
|
||||||
except NotFoundSyncException:
|
except NotFoundSyncException:
|
||||||
# Resource missing is handled by self.write, which will re-create the group
|
# Resource missing is handled by self.write, which will re-create the group
|
||||||
raise
|
raise
|
||||||
@ -152,7 +164,11 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
|
|||||||
group_id: str,
|
group_id: str,
|
||||||
*ops: PatchOperation,
|
*ops: PatchOperation,
|
||||||
):
|
):
|
||||||
req = PatchRequest(Operations=ops)
|
chunk_size = self._config.bulk.maxOperations
|
||||||
|
if chunk_size < 1:
|
||||||
|
chunk_size = len(ops)
|
||||||
|
for chunk in batched(ops, chunk_size):
|
||||||
|
req = PatchRequest(Operations=list(chunk))
|
||||||
self._request(
|
self._request(
|
||||||
"PATCH",
|
"PATCH",
|
||||||
f"/Groups/{group_id}",
|
f"/Groups/{group_id}",
|
||||||
@ -180,11 +196,14 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
|
|||||||
return
|
return
|
||||||
self._patch(
|
self._patch(
|
||||||
scim_group.scim_id,
|
scim_group.scim_id,
|
||||||
|
*[
|
||||||
PatchOperation(
|
PatchOperation(
|
||||||
op=PatchOp.add,
|
op=PatchOp.add,
|
||||||
path="members",
|
path="members",
|
||||||
value=[{"value": x} for x in user_ids],
|
value=[{"value": x}],
|
||||||
),
|
)
|
||||||
|
for x in user_ids
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def _patch_remove_users(self, group: Group, users_set: set[int]):
|
def _patch_remove_users(self, group: Group, users_set: set[int]):
|
||||||
@ -206,9 +225,12 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
|
|||||||
return
|
return
|
||||||
self._patch(
|
self._patch(
|
||||||
scim_group.scim_id,
|
scim_group.scim_id,
|
||||||
|
*[
|
||||||
PatchOperation(
|
PatchOperation(
|
||||||
op=PatchOp.remove,
|
op=PatchOp.remove,
|
||||||
path="members",
|
path="members",
|
||||||
value=[{"value": x} for x in user_ids],
|
value=[{"value": x}],
|
||||||
),
|
)
|
||||||
|
for x in user_ids
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
"""Custom SCIM schemas"""
|
"""Custom SCIM schemas"""
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
from pydanticscim.group import Group as BaseGroup
|
from pydanticscim.group import Group as BaseGroup
|
||||||
from pydanticscim.responses import PatchRequest as BasePatchRequest
|
from pydanticscim.responses import PatchRequest as BasePatchRequest
|
||||||
from pydanticscim.responses import SCIMError as BaseSCIMError
|
from pydanticscim.responses import SCIMError as BaseSCIMError
|
||||||
from pydanticscim.service_provider import Bulk, ChangePassword, Filter, Patch, Sort
|
from pydanticscim.service_provider import Bulk as BaseBulk
|
||||||
|
from pydanticscim.service_provider import ChangePassword, Filter, Patch, Sort
|
||||||
from pydanticscim.service_provider import (
|
from pydanticscim.service_provider import (
|
||||||
ServiceProviderConfiguration as BaseServiceProviderConfiguration,
|
ServiceProviderConfiguration as BaseServiceProviderConfiguration,
|
||||||
)
|
)
|
||||||
@ -29,10 +31,16 @@ class Group(BaseGroup):
|
|||||||
meta: dict | None = None
|
meta: dict | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class Bulk(BaseBulk):
|
||||||
|
|
||||||
|
maxOperations: int = Field()
|
||||||
|
|
||||||
|
|
||||||
class ServiceProviderConfiguration(BaseServiceProviderConfiguration):
|
class ServiceProviderConfiguration(BaseServiceProviderConfiguration):
|
||||||
"""ServiceProviderConfig with fallback"""
|
"""ServiceProviderConfig with fallback"""
|
||||||
|
|
||||||
_is_fallback: bool | None = False
|
_is_fallback: bool | None = False
|
||||||
|
bulk: Bulk = Field(..., description="A complex type that specifies bulk configuration options.")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_fallback(self) -> bool:
|
def is_fallback(self) -> bool:
|
||||||
@ -45,7 +53,7 @@ class ServiceProviderConfiguration(BaseServiceProviderConfiguration):
|
|||||||
"""Get default configuration, which doesn't support any optional features as fallback"""
|
"""Get default configuration, which doesn't support any optional features as fallback"""
|
||||||
return ServiceProviderConfiguration(
|
return ServiceProviderConfiguration(
|
||||||
patch=Patch(supported=False),
|
patch=Patch(supported=False),
|
||||||
bulk=Bulk(supported=False),
|
bulk=Bulk(supported=False, maxOperations=0),
|
||||||
filter=Filter(supported=False),
|
filter=Filter(supported=False),
|
||||||
changePassword=ChangePassword(supported=False),
|
changePassword=ChangePassword(supported=False),
|
||||||
sort=Sort(supported=False),
|
sort=Sort(supported=False),
|
||||||
|
Reference in New Issue
Block a user