providers/scim: fix SCIM ID incorrectly used as primary key (#9557)

* providers/scim: fix SCIM ID incorrectly used as primary key

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix unique together

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add test

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add check for empty scim ID

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L
2024-05-05 17:32:19 +02:00
committed by GitHub
parent 26daaeb57d
commit 3c54e94c6e
5 changed files with 169 additions and 17 deletions

View File

@ -41,7 +41,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
if not scim_group: if not scim_group:
self.logger.debug("Group does not exist in SCIM, skipping") self.logger.debug("Group does not exist in SCIM, skipping")
return None return None
response = self._request("DELETE", f"/Groups/{scim_group.id}") response = self._request("DELETE", f"/Groups/{scim_group.scim_id}")
scim_group.delete() scim_group.delete()
return response return response
@ -89,7 +89,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
for user in connections: for user in connections:
members.append( members.append(
GroupMember( GroupMember(
value=user.id, value=user.scim_id,
) )
) )
if members: if members:
@ -107,16 +107,19 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
exclude_unset=True, exclude_unset=True,
), ),
) )
SCIMGroup.objects.create(provider=self.provider, group=group, id=response["id"]) scim_id = response.get("id")
if not scim_id or scim_id == "":
raise StopSync("SCIM Response with missing or invalid `id`")
SCIMGroup.objects.create(provider=self.provider, group=group, scim_id=scim_id)
def _update(self, group: Group, connection: SCIMGroup): def _update(self, group: Group, connection: SCIMGroup):
"""Update existing group""" """Update existing group"""
scim_group = self.to_scim(group) scim_group = self.to_scim(group)
scim_group.id = connection.id scim_group.id = connection.scim_id
try: try:
return self._request( return self._request(
"PUT", "PUT",
f"/Groups/{scim_group.id}", f"/Groups/{connection.scim_id}",
json=scim_group.model_dump( json=scim_group.model_dump(
mode="json", mode="json",
exclude_unset=True, exclude_unset=True,
@ -185,13 +188,13 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
return return
user_ids = list( user_ids = list(
SCIMUser.objects.filter(user__pk__in=users_set, provider=self.provider).values_list( SCIMUser.objects.filter(user__pk__in=users_set, provider=self.provider).values_list(
"id", flat=True "scim_id", flat=True
) )
) )
if len(user_ids) < 1: if len(user_ids) < 1:
return return
self._patch( self._patch(
scim_group.id, scim_group.scim_id,
PatchOperation( PatchOperation(
op=PatchOp.add, op=PatchOp.add,
path="members", path="members",
@ -211,13 +214,13 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
return return
user_ids = list( user_ids = list(
SCIMUser.objects.filter(user__pk__in=users_set, provider=self.provider).values_list( SCIMUser.objects.filter(user__pk__in=users_set, provider=self.provider).values_list(
"id", flat=True "scim_id", flat=True
) )
) )
if len(user_ids) < 1: if len(user_ids) < 1:
return return
self._patch( self._patch(
scim_group.id, scim_group.scim_id,
PatchOperation( PatchOperation(
op=PatchOp.remove, op=PatchOp.remove,
path="members", path="members",

View File

@ -34,7 +34,7 @@ class SCIMUserClient(SCIMClient[User, SCIMUserSchema]):
if not scim_user: if not scim_user:
self.logger.debug("User does not exist in SCIM, skipping") self.logger.debug("User does not exist in SCIM, skipping")
return None return None
response = self._request("DELETE", f"/Users/{scim_user.id}") response = self._request("DELETE", f"/Users/{scim_user.scim_id}")
scim_user.delete() scim_user.delete()
return response return response
@ -85,15 +85,18 @@ class SCIMUserClient(SCIMClient[User, SCIMUserSchema]):
exclude_unset=True, exclude_unset=True,
), ),
) )
SCIMUser.objects.create(provider=self.provider, user=user, id=response["id"]) scim_id = response.get("id")
if not scim_id or scim_id == "":
raise StopSync("SCIM Response with missing or invalid `id`")
SCIMUser.objects.create(provider=self.provider, user=user, scim_id=scim_id)
def _update(self, user: User, connection: SCIMUser): def _update(self, user: User, connection: SCIMUser):
"""Update existing user""" """Update existing user"""
scim_user = self.to_scim(user) scim_user = self.to_scim(user)
scim_user.id = connection.id scim_user.id = connection.scim_id
self._request( self._request(
"PUT", "PUT",
f"/Users/{connection.id}", f"/Users/{connection.scim_id}",
json=scim_user.model_dump( json=scim_user.model_dump(
mode="json", mode="json",
exclude_unset=True, exclude_unset=True,

View File

@ -0,0 +1,76 @@
# Generated by Django 5.0.4 on 2024-05-03 12:38
import uuid
from django.db import migrations, models
from django.apps.registry import Apps
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from authentik.lib.migrations import progress_bar
def fix_scim_user_group_pk(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
SCIMUser = apps.get_model("authentik_providers_scim", "SCIMUser")
SCIMGroup = apps.get_model("authentik_providers_scim", "SCIMGroup")
db_alias = schema_editor.connection.alias
print("\nFixing primary key for SCIM users, this might take a couple of minutes...")
for user in progress_bar(SCIMUser.objects.using(db_alias).all()):
SCIMUser.objects.using(db_alias).filter(
pk=user.pk, user=user.user_id, provider=user.provider_id
).update(scim_id=user.pk, id=uuid.uuid4())
print("\nFixing primary key for SCIM groups, this might take a couple of minutes...")
for group in progress_bar(SCIMGroup.objects.using(db_alias).all()):
SCIMGroup.objects.using(db_alias).filter(
pk=group.pk, group=group.group_id, provider=group.provider_id
).update(scim_id=group.pk, id=uuid.uuid4())
class Migration(migrations.Migration):
dependencies = [
(
"authentik_providers_scim",
"0001_squashed_0006_rename_parent_group_scimprovider_filter_group",
),
]
operations = [
migrations.AddField(
model_name="scimgroup",
name="scim_id",
field=models.TextField(default="temp"),
preserve_default=False,
),
migrations.AddField(
model_name="scimuser",
name="scim_id",
field=models.TextField(default="temp"),
preserve_default=False,
),
migrations.RunPython(fix_scim_user_group_pk),
migrations.AlterField(
model_name="scimgroup",
name="id",
field=models.UUIDField(
default=uuid.uuid4, editable=False, primary_key=True, serialize=False
),
),
migrations.AlterField(
model_name="scimuser",
name="id",
field=models.UUIDField(
default=uuid.uuid4, editable=False, primary_key=True, serialize=False
),
),
migrations.AlterField(model_name="scimuser", name="scim_id", field=models.TextField()),
migrations.AlterField(model_name="scimgroup", name="scim_id", field=models.TextField()),
migrations.AlterUniqueTogether(
name="scimgroup",
unique_together={("scim_id", "group", "provider")},
),
migrations.AlterUniqueTogether(
name="scimuser",
unique_together={("scim_id", "user", "provider")},
),
]

View File

@ -1,5 +1,7 @@
"""SCIM Provider models""" """SCIM Provider models"""
from uuid import uuid4
from django.core.cache import cache from django.core.cache import cache
from django.db import models from django.db import models
from django.db.models import QuerySet from django.db.models import QuerySet
@ -97,12 +99,13 @@ class SCIMMapping(PropertyMapping):
class SCIMUser(models.Model): class SCIMUser(models.Model):
"""Mapping of a user and provider to a SCIM user ID""" """Mapping of a user and provider to a SCIM user ID"""
id = models.TextField(primary_key=True) id = models.UUIDField(primary_key=True, editable=False, default=uuid4)
scim_id = models.TextField()
user = models.ForeignKey(User, on_delete=models.CASCADE) user = models.ForeignKey(User, on_delete=models.CASCADE)
provider = models.ForeignKey(SCIMProvider, on_delete=models.CASCADE) provider = models.ForeignKey(SCIMProvider, on_delete=models.CASCADE)
class Meta: class Meta:
unique_together = (("id", "user", "provider"),) unique_together = (("scim_id", "user", "provider"),)
def __str__(self) -> str: def __str__(self) -> str:
return f"SCIM User {self.user_id} to {self.provider_id}" return f"SCIM User {self.user_id} to {self.provider_id}"
@ -111,12 +114,13 @@ class SCIMUser(models.Model):
class SCIMGroup(models.Model): class SCIMGroup(models.Model):
"""Mapping of a group and provider to a SCIM user ID""" """Mapping of a group and provider to a SCIM user ID"""
id = models.TextField(primary_key=True) id = models.UUIDField(primary_key=True, editable=False, default=uuid4)
scim_id = models.TextField()
group = models.ForeignKey(Group, on_delete=models.CASCADE) group = models.ForeignKey(Group, on_delete=models.CASCADE)
provider = models.ForeignKey(SCIMProvider, on_delete=models.CASCADE) provider = models.ForeignKey(SCIMProvider, on_delete=models.CASCADE)
class Meta: class Meta:
unique_together = (("id", "group", "provider"),) unique_together = (("scim_id", "group", "provider"),)
def __str__(self) -> str: def __str__(self) -> str:
return f"SCIM Group {self.group_id} to {self.provider_id}" return f"SCIM Group {self.group_id} to {self.provider_id}"

View File

@ -88,6 +88,72 @@ class SCIMUserTests(TestCase):
}, },
) )
@Mocker()
def test_user_create_different_provider_same_id(self, mock: Mocker):
"""Test user creation with multiple providers that happen
to return the same object ID"""
# Create duplicate provider
provider: SCIMProvider = SCIMProvider.objects.create(
name=generate_id(),
url="https://localhost",
token=generate_id(),
exclude_users_service_account=True,
)
app: Application = Application.objects.create(
name=generate_id(),
slug=generate_id(),
)
app.backchannel_providers.add(provider)
provider.property_mappings.add(
SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/user")
)
provider.property_mappings_group.add(
SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/group")
)
scim_id = generate_id()
mock.get(
"https://localhost/ServiceProviderConfig",
json={},
)
mock.post(
"https://localhost/Users",
json={
"id": scim_id,
},
)
uid = generate_id()
user = User.objects.create(
username=uid,
name=f"{uid} {uid}",
email=f"{uid}@goauthentik.io",
)
self.assertEqual(mock.call_count, 4)
self.assertEqual(mock.request_history[0].method, "GET")
self.assertEqual(mock.request_history[1].method, "POST")
self.assertJSONEqual(
mock.request_history[1].body,
{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
"active": True,
"emails": [
{
"primary": True,
"type": "other",
"value": f"{uid}@goauthentik.io",
}
],
"externalId": user.uid,
"name": {
"familyName": uid,
"formatted": f"{uid} {uid}",
"givenName": uid,
},
"displayName": f"{uid} {uid}",
"userName": uid,
},
)
@Mocker() @Mocker()
def test_user_create_update(self, mock: Mocker): def test_user_create_update(self, mock: Mocker):
"""Test user creation and update""" """Test user creation and update"""