providers/scim: add comparison with existing group on update and delta update users (cherry-pick #11414) (#11796)

providers/scim: add comparison with existing group on update and delta update users (#11414)

* fix incorrect default group mapping



* providers/scim: add comparison with existing group on update and delta update users



* fix



* fix



* fix another exception when creating groups



* fix users to add check



---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
Co-authored-by: Jens L. <jens@goauthentik.io>
This commit is contained in:
gcp-cherry-pick-bot[bot]
2024-10-24 18:28:06 +02:00
committed by GitHub
parent eab3d9b411
commit de9fc5de6b
5 changed files with 249 additions and 40 deletions

View File

@ -21,7 +21,14 @@ class DebugSession(Session):
def send(self, req: PreparedRequest, *args, **kwargs):
request_id = str(uuid4())
LOGGER.debug("HTTP request sent", uid=request_id, path=req.path_url, headers=req.headers)
LOGGER.debug(
"HTTP request sent",
uid=request_id,
url=req.url,
method=req.method,
headers=req.headers,
body=req.body,
)
resp = super().send(req, *args, **kwargs)
LOGGER.debug(
"HTTP response received",

View File

@ -2,9 +2,10 @@
from itertools import batched
from django.db import transaction
from pydantic import ValidationError
from pydanticscim.group import GroupMember
from pydanticscim.responses import PatchOp, PatchOperation
from pydanticscim.responses import PatchOp
from authentik.core.models import Group
from authentik.lib.sync.mapper import PropertyMappingManager
@ -19,7 +20,7 @@ from authentik.providers.scim.clients.base import SCIMClient
from authentik.providers.scim.clients.exceptions import (
SCIMRequestException,
)
from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchRequest
from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchOperation, PatchRequest
from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema
from authentik.providers.scim.models import (
SCIMMapping,
@ -104,13 +105,47 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
provider=self.provider, group=group, scim_id=scim_id
)
users = list(group.users.order_by("id").values_list("id", flat=True))
self._patch_add_users(group, users)
self._patch_add_users(connection, users)
return connection
def update(self, group: Group, connection: SCIMProviderGroup):
"""Update existing group"""
scim_group = self.to_schema(group, connection)
scim_group.id = connection.scim_id
try:
if self._config.patch.supported:
return self._update_patch(group, scim_group, connection)
return self._update_put(group, scim_group, connection)
except NotFoundSyncException:
# Resource missing is handled by self.write, which will re-create the group
raise
def _update_patch(
self, group: Group, scim_group: SCIMGroupSchema, connection: SCIMProviderGroup
):
"""Update a group via PATCH request"""
# Patch group's attributes instead of replacing it and re-adding users if we can
self._request(
"PATCH",
f"/Groups/{connection.scim_id}",
json=PatchRequest(
Operations=[
PatchOperation(
op=PatchOp.replace,
path=None,
value=scim_group.model_dump(mode="json", exclude_unset=True),
)
]
).model_dump(
mode="json",
exclude_unset=True,
exclude_none=True,
),
)
return self.patch_compare_users(group)
def _update_put(self, group: Group, scim_group: SCIMGroupSchema, connection: SCIMProviderGroup):
"""Update a group via PUT request"""
try:
self._request(
"PUT",
@ -120,33 +155,25 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
exclude_unset=True,
),
)
users = list(group.users.order_by("id").values_list("id", flat=True))
return self._patch_add_users(group, users)
except NotFoundSyncException:
# Resource missing is handled by self.write, which will re-create the group
raise
return self.patch_compare_users(group)
except (SCIMRequestException, ObjectExistsSyncException):
# Some providers don't support PUT on groups, so this is mainly a fix for the initial
# sync, send patch add requests for all the users the group currently has
users = list(group.users.order_by("id").values_list("id", flat=True))
self._patch_add_users(group, users)
# Also update the group name
return self._patch(
scim_group.id,
PatchOperation(
op=PatchOp.replace,
path="displayName",
value=scim_group.displayName,
),
)
return self._update_patch(group, scim_group, connection)
def update_group(self, group: Group, action: Direction, users_set: set[int]):
"""Update a group, either using PUT to replace it or PATCH if supported"""
scim_group = SCIMProviderGroup.objects.filter(provider=self.provider, group=group).first()
if not scim_group:
self.logger.warning(
"could not sync group membership, group does not exist", group=group
)
return
if self._config.patch.supported:
if action == Direction.add:
return self._patch_add_users(group, users_set)
return self._patch_add_users(scim_group, users_set)
if action == Direction.remove:
return self._patch_remove_users(group, users_set)
return self._patch_remove_users(scim_group, users_set)
try:
return self.write(group)
except SCIMRequestException as exc:
@ -154,16 +181,19 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
# Assume that provider does not support PUT and also doesn't support
# ServiceProviderConfig, so try PATCH as a fallback
if action == Direction.add:
return self._patch_add_users(group, users_set)
return self._patch_add_users(scim_group, users_set)
if action == Direction.remove:
return self._patch_remove_users(group, users_set)
return self._patch_remove_users(scim_group, users_set)
raise exc
def _patch(
def _patch_chunked(
self,
group_id: str,
*ops: PatchOperation,
):
"""Helper function that chunks patch requests based on the maxOperations attribute.
This is not strictly according to specs but there's nothing in the schema that allows the
us to know what the maximum patch operations per request should be."""
chunk_size = self._config.bulk.maxOperations
if chunk_size < 1:
chunk_size = len(ops)
@ -177,16 +207,67 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
),
)
def _patch_add_users(self, group: Group, users_set: set[int]):
"""Add users in users_set to group"""
if len(users_set) < 1:
return
@transaction.atomic
def patch_compare_users(self, group: Group):
"""Compare users with a SCIM group and add/remove any differences"""
# Get scim group first
scim_group = SCIMProviderGroup.objects.filter(provider=self.provider, group=group).first()
if not scim_group:
self.logger.warning(
"could not sync group membership, group does not exist", group=group
)
return
# Get a list of all users in the authentik group
raw_users_should = list(group.users.order_by("id").values_list("id", flat=True))
# Lookup the SCIM IDs of the users
users_should: list[str] = list(
SCIMProviderUser.objects.filter(
user__pk__in=raw_users_should, provider=self.provider
).values_list("scim_id", flat=True)
)
if len(raw_users_should) != len(users_should):
self.logger.warning(
"User count mismatch, not all users in the group are synced to SCIM yet.",
group=group,
)
# Get current group status
current_group = SCIMGroupSchema.model_validate(
self._request("GET", f"/Groups/{scim_group.scim_id}")
)
users_to_add = []
users_to_remove = []
# Check users currently in group and if they shouldn't be in the group and remove them
for user in current_group.members:
if user.value not in users_should:
users_to_remove.append(user.value)
# Check users that should be in the group and add them
for user in users_should:
if len([x for x in current_group.members if x.value == user]) < 1:
users_to_add.append(user)
return self._patch_chunked(
scim_group.scim_id,
*[
PatchOperation(
op=PatchOp.add,
path="members",
value=[{"value": x}],
)
for x in users_to_add
],
*[
PatchOperation(
op=PatchOp.remove,
path="members",
value=[{"value": x}],
)
for x in users_to_remove
],
)
def _patch_add_users(self, scim_group: SCIMProviderGroup, users_set: set[int]):
"""Add users in users_set to group"""
if len(users_set) < 1:
return
user_ids = list(
SCIMProviderUser.objects.filter(
user__pk__in=users_set, provider=self.provider
@ -194,7 +275,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
)
if len(user_ids) < 1:
return
self._patch(
self._patch_chunked(
scim_group.scim_id,
*[
PatchOperation(
@ -206,16 +287,10 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
],
)
def _patch_remove_users(self, group: Group, users_set: set[int]):
def _patch_remove_users(self, scim_group: SCIMProviderGroup, users_set: set[int]):
"""Remove users in users_set from group"""
if len(users_set) < 1:
return
scim_group = SCIMProviderGroup.objects.filter(provider=self.provider, group=group).first()
if not scim_group:
self.logger.warning(
"could not sync group membership, group does not exist", group=group
)
return
user_ids = list(
SCIMProviderUser.objects.filter(
user__pk__in=users_set, provider=self.provider
@ -223,7 +298,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
)
if len(user_ids) < 1:
return
self._patch(
self._patch_chunked(
scim_group.scim_id,
*[
PatchOperation(

View File

@ -2,6 +2,7 @@
from pydantic import Field
from pydanticscim.group import Group as BaseGroup
from pydanticscim.responses import PatchOperation as BasePatchOperation
from pydanticscim.responses import PatchRequest as BasePatchRequest
from pydanticscim.responses import SCIMError as BaseSCIMError
from pydanticscim.service_provider import Bulk as BaseBulk
@ -68,6 +69,12 @@ class PatchRequest(BasePatchRequest):
schemas: tuple[str] = ("urn:ietf:params:scim:api:messages:2.0:PatchOp",)
class PatchOperation(BasePatchOperation):
"""PatchOperation with optional path"""
path: str | None
class SCIMError(BaseSCIMError):
"""SCIM error with optional status code"""

View File

@ -252,3 +252,118 @@ class SCIMMembershipTests(TestCase):
],
},
)
def test_member_add_save(self):
"""Test member add + save"""
config = ServiceProviderConfiguration.default()
config.patch.supported = True
user_scim_id = generate_id()
group_scim_id = generate_id()
uid = generate_id()
group = Group.objects.create(
name=uid,
)
user = User.objects.create(username=generate_id())
# Test initial sync of group creation
with Mocker() as mocker:
mocker.get(
"https://localhost/ServiceProviderConfig",
json=config.model_dump(),
)
mocker.post(
"https://localhost/Users",
json={
"id": user_scim_id,
},
)
mocker.post(
"https://localhost/Groups",
json={
"id": group_scim_id,
},
)
self.configure()
sync_tasks.trigger_single_task(self.provider, scim_sync).get()
self.assertEqual(mocker.call_count, 6)
self.assertEqual(mocker.request_history[0].method, "GET")
self.assertEqual(mocker.request_history[1].method, "GET")
self.assertEqual(mocker.request_history[2].method, "GET")
self.assertEqual(mocker.request_history[3].method, "POST")
self.assertEqual(mocker.request_history[4].method, "GET")
self.assertEqual(mocker.request_history[5].method, "POST")
self.assertJSONEqual(
mocker.request_history[3].body,
{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
"emails": [],
"active": True,
"externalId": user.uid,
"name": {"familyName": " ", "formatted": " ", "givenName": ""},
"displayName": "",
"userName": user.username,
},
)
self.assertJSONEqual(
mocker.request_history[5].body,
{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
"externalId": str(group.pk),
"displayName": group.name,
},
)
with Mocker() as mocker:
mocker.get(
"https://localhost/ServiceProviderConfig",
json=config.model_dump(),
)
mocker.get(
f"https://localhost/Groups/{group_scim_id}",
json={},
)
mocker.patch(
f"https://localhost/Groups/{group_scim_id}",
json={},
)
group.users.add(user)
group.save()
self.assertEqual(mocker.call_count, 5)
self.assertEqual(mocker.request_history[0].method, "GET")
self.assertEqual(mocker.request_history[1].method, "PATCH")
self.assertEqual(mocker.request_history[2].method, "GET")
self.assertEqual(mocker.request_history[3].method, "PATCH")
self.assertEqual(mocker.request_history[4].method, "GET")
self.assertJSONEqual(
mocker.request_history[1].body,
{
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
"Operations": [
{
"op": "add",
"path": "members",
"value": [{"value": user_scim_id}],
}
],
},
)
self.assertJSONEqual(
mocker.request_history[3].body,
{
"Operations": [
{
"op": "replace",
"value": {
"id": group_scim_id,
"displayName": group.name,
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
"externalId": str(group.pk),
},
}
]
},
)

View File

@ -38,12 +38,15 @@ export async function scimPropertyMappingsProvider(page = 1, search = "") {
};
}
export function makeSCIMPropertyMappingsSelector(instanceMappings: string[] | undefined) {
export function makeSCIMPropertyMappingsSelector(
instanceMappings: string[] | undefined,
defaultSelected: string,
) {
const localMappings = instanceMappings ? new Set(instanceMappings) : undefined;
return localMappings
? ([pk, _]: DualSelectPair) => localMappings.has(pk)
: ([_0, _1, _2, mapping]: DualSelectPair<SCIMMapping>) =>
mapping?.managed === "goauthentik.io/providers/scim/user";
mapping?.managed === defaultSelected;
}
@customElement("ak-provider-scim-form")
@ -172,6 +175,7 @@ export class SCIMProviderFormPage extends BaseProviderForm<SCIMProvider> {
.provider=${scimPropertyMappingsProvider}
.selector=${makeSCIMPropertyMappingsSelector(
this.instance?.propertyMappings,
"goauthentik.io/providers/scim/user",
)}
available-label=${msg("Available User Property Mappings")}
selected-label=${msg("Selected User Property Mappings")}
@ -188,6 +192,7 @@ export class SCIMProviderFormPage extends BaseProviderForm<SCIMProvider> {
.provider=${scimPropertyMappingsProvider}
.selector=${makeSCIMPropertyMappingsSelector(
this.instance?.propertyMappingsGroup,
"goauthentik.io/providers/scim/group",
)}
available-label=${msg("Available Group Property Mappings")}
selected-label=${msg("Selected Group Property Mappings")}