diff --git a/authentik/providers/scim/clients/group.py b/authentik/providers/scim/clients/group.py index 467bd93037..93457779e6 100644 --- a/authentik/providers/scim/clients/group.py +++ b/authentik/providers/scim/clients/group.py @@ -41,7 +41,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]): if not scim_group: self.logger.debug("Group does not exist in SCIM, skipping") return None - response = self._request("DELETE", f"/Groups/{scim_group.id}") + response = self._request("DELETE", f"/Groups/{scim_group.scim_id}") scim_group.delete() return response @@ -89,7 +89,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]): for user in connections: members.append( GroupMember( - value=user.id, + value=user.scim_id, ) ) if members: @@ -107,16 +107,19 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]): exclude_unset=True, ), ) - SCIMGroup.objects.create(provider=self.provider, group=group, id=response["id"]) + scim_id = response.get("id") + if not scim_id or scim_id == "": + raise StopSync("SCIM Response with missing or invalid `id`") + SCIMGroup.objects.create(provider=self.provider, group=group, scim_id=scim_id) def _update(self, group: Group, connection: SCIMGroup): """Update existing group""" scim_group = self.to_scim(group) - scim_group.id = connection.id + scim_group.id = connection.scim_id try: return self._request( "PUT", - f"/Groups/{scim_group.id}", + f"/Groups/{connection.scim_id}", json=scim_group.model_dump( mode="json", exclude_unset=True, @@ -185,13 +188,13 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]): return user_ids = list( SCIMUser.objects.filter(user__pk__in=users_set, provider=self.provider).values_list( - "id", flat=True + "scim_id", flat=True ) ) if len(user_ids) < 1: return self._patch( - scim_group.id, + scim_group.scim_id, PatchOperation( op=PatchOp.add, path="members", @@ -211,13 +214,13 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]): return user_ids = list( SCIMUser.objects.filter(user__pk__in=users_set, provider=self.provider).values_list( - "id", flat=True + "scim_id", flat=True ) ) if len(user_ids) < 1: return self._patch( - scim_group.id, + scim_group.scim_id, PatchOperation( op=PatchOp.remove, path="members", diff --git a/authentik/providers/scim/clients/user.py b/authentik/providers/scim/clients/user.py index 84b2df6394..da0b8df69b 100644 --- a/authentik/providers/scim/clients/user.py +++ b/authentik/providers/scim/clients/user.py @@ -34,7 +34,7 @@ class SCIMUserClient(SCIMClient[User, SCIMUserSchema]): if not scim_user: self.logger.debug("User does not exist in SCIM, skipping") return None - response = self._request("DELETE", f"/Users/{scim_user.id}") + response = self._request("DELETE", f"/Users/{scim_user.scim_id}") scim_user.delete() return response @@ -85,15 +85,18 @@ class SCIMUserClient(SCIMClient[User, SCIMUserSchema]): exclude_unset=True, ), ) - SCIMUser.objects.create(provider=self.provider, user=user, id=response["id"]) + scim_id = response.get("id") + if not scim_id or scim_id == "": + raise StopSync("SCIM Response with missing or invalid `id`") + SCIMUser.objects.create(provider=self.provider, user=user, scim_id=scim_id) def _update(self, user: User, connection: SCIMUser): """Update existing user""" scim_user = self.to_scim(user) - scim_user.id = connection.id + scim_user.id = connection.scim_id self._request( "PUT", - f"/Users/{connection.id}", + f"/Users/{connection.scim_id}", json=scim_user.model_dump( mode="json", exclude_unset=True, diff --git a/authentik/providers/scim/migrations/0007_scimgroup_scim_id_scimuser_scim_id_and_more.py b/authentik/providers/scim/migrations/0007_scimgroup_scim_id_scimuser_scim_id_and_more.py new file mode 100644 index 0000000000..61ee46c3dd --- /dev/null +++ b/authentik/providers/scim/migrations/0007_scimgroup_scim_id_scimuser_scim_id_and_more.py @@ -0,0 +1,76 @@ +# Generated by Django 5.0.4 on 2024-05-03 12:38 + +import uuid +from django.db import migrations, models +from django.apps.registry import Apps + +from django.db.backends.base.schema import BaseDatabaseSchemaEditor + +from authentik.lib.migrations import progress_bar + + +def fix_scim_user_group_pk(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): + SCIMUser = apps.get_model("authentik_providers_scim", "SCIMUser") + SCIMGroup = apps.get_model("authentik_providers_scim", "SCIMGroup") + db_alias = schema_editor.connection.alias + print("\nFixing primary key for SCIM users, this might take a couple of minutes...") + for user in progress_bar(SCIMUser.objects.using(db_alias).all()): + SCIMUser.objects.using(db_alias).filter( + pk=user.pk, user=user.user_id, provider=user.provider_id + ).update(scim_id=user.pk, id=uuid.uuid4()) + + print("\nFixing primary key for SCIM groups, this might take a couple of minutes...") + for group in progress_bar(SCIMGroup.objects.using(db_alias).all()): + SCIMGroup.objects.using(db_alias).filter( + pk=group.pk, group=group.group_id, provider=group.provider_id + ).update(scim_id=group.pk, id=uuid.uuid4()) + + +class Migration(migrations.Migration): + + dependencies = [ + ( + "authentik_providers_scim", + "0001_squashed_0006_rename_parent_group_scimprovider_filter_group", + ), + ] + + operations = [ + migrations.AddField( + model_name="scimgroup", + name="scim_id", + field=models.TextField(default="temp"), + preserve_default=False, + ), + migrations.AddField( + model_name="scimuser", + name="scim_id", + field=models.TextField(default="temp"), + preserve_default=False, + ), + migrations.RunPython(fix_scim_user_group_pk), + migrations.AlterField( + model_name="scimgroup", + name="id", + field=models.UUIDField( + default=uuid.uuid4, editable=False, primary_key=True, serialize=False + ), + ), + migrations.AlterField( + model_name="scimuser", + name="id", + field=models.UUIDField( + default=uuid.uuid4, editable=False, primary_key=True, serialize=False + ), + ), + migrations.AlterField(model_name="scimuser", name="scim_id", field=models.TextField()), + migrations.AlterField(model_name="scimgroup", name="scim_id", field=models.TextField()), + migrations.AlterUniqueTogether( + name="scimgroup", + unique_together={("scim_id", "group", "provider")}, + ), + migrations.AlterUniqueTogether( + name="scimuser", + unique_together={("scim_id", "user", "provider")}, + ), + ] diff --git a/authentik/providers/scim/models.py b/authentik/providers/scim/models.py index e85d796a79..016586e681 100644 --- a/authentik/providers/scim/models.py +++ b/authentik/providers/scim/models.py @@ -1,5 +1,7 @@ """SCIM Provider models""" +from uuid import uuid4 + from django.core.cache import cache from django.db import models from django.db.models import QuerySet @@ -97,12 +99,13 @@ class SCIMMapping(PropertyMapping): class SCIMUser(models.Model): """Mapping of a user and provider to a SCIM user ID""" - id = models.TextField(primary_key=True) + id = models.UUIDField(primary_key=True, editable=False, default=uuid4) + scim_id = models.TextField() user = models.ForeignKey(User, on_delete=models.CASCADE) provider = models.ForeignKey(SCIMProvider, on_delete=models.CASCADE) class Meta: - unique_together = (("id", "user", "provider"),) + unique_together = (("scim_id", "user", "provider"),) def __str__(self) -> str: return f"SCIM User {self.user_id} to {self.provider_id}" @@ -111,12 +114,13 @@ class SCIMUser(models.Model): class SCIMGroup(models.Model): """Mapping of a group and provider to a SCIM user ID""" - id = models.TextField(primary_key=True) + id = models.UUIDField(primary_key=True, editable=False, default=uuid4) + scim_id = models.TextField() group = models.ForeignKey(Group, on_delete=models.CASCADE) provider = models.ForeignKey(SCIMProvider, on_delete=models.CASCADE) class Meta: - unique_together = (("id", "group", "provider"),) + unique_together = (("scim_id", "group", "provider"),) def __str__(self) -> str: return f"SCIM Group {self.group_id} to {self.provider_id}" diff --git a/authentik/providers/scim/tests/test_user.py b/authentik/providers/scim/tests/test_user.py index bc1b3817f0..6d0f1e1f48 100644 --- a/authentik/providers/scim/tests/test_user.py +++ b/authentik/providers/scim/tests/test_user.py @@ -88,6 +88,72 @@ class SCIMUserTests(TestCase): }, ) + @Mocker() + def test_user_create_different_provider_same_id(self, mock: Mocker): + """Test user creation with multiple providers that happen + to return the same object ID""" + # Create duplicate provider + provider: SCIMProvider = SCIMProvider.objects.create( + name=generate_id(), + url="https://localhost", + token=generate_id(), + exclude_users_service_account=True, + ) + app: Application = Application.objects.create( + name=generate_id(), + slug=generate_id(), + ) + app.backchannel_providers.add(provider) + provider.property_mappings.add( + SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/user") + ) + provider.property_mappings_group.add( + SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/group") + ) + + scim_id = generate_id() + mock.get( + "https://localhost/ServiceProviderConfig", + json={}, + ) + mock.post( + "https://localhost/Users", + json={ + "id": scim_id, + }, + ) + uid = generate_id() + user = User.objects.create( + username=uid, + name=f"{uid} {uid}", + email=f"{uid}@goauthentik.io", + ) + self.assertEqual(mock.call_count, 4) + self.assertEqual(mock.request_history[0].method, "GET") + self.assertEqual(mock.request_history[1].method, "POST") + self.assertJSONEqual( + mock.request_history[1].body, + { + "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"], + "active": True, + "emails": [ + { + "primary": True, + "type": "other", + "value": f"{uid}@goauthentik.io", + } + ], + "externalId": user.uid, + "name": { + "familyName": uid, + "formatted": f"{uid} {uid}", + "givenName": uid, + }, + "displayName": f"{uid} {uid}", + "userName": uid, + }, + ) + @Mocker() def test_user_create_update(self, mock: Mocker): """Test user creation and update"""