providers/scim: fix SCIM ID incorrectly used as primary key (#9557)
* providers/scim: fix SCIM ID incorrectly used as primary key Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix unique together Signed-off-by: Jens Langhammer <jens@goauthentik.io> * add test Signed-off-by: Jens Langhammer <jens@goauthentik.io> * add check for empty scim ID Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
		| @ -41,7 +41,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]): | ||||
|         if not scim_group: | ||||
|             self.logger.debug("Group does not exist in SCIM, skipping") | ||||
|             return None | ||||
|         response = self._request("DELETE", f"/Groups/{scim_group.id}") | ||||
|         response = self._request("DELETE", f"/Groups/{scim_group.scim_id}") | ||||
|         scim_group.delete() | ||||
|         return response | ||||
|  | ||||
| @ -89,7 +89,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]): | ||||
|         for user in connections: | ||||
|             members.append( | ||||
|                 GroupMember( | ||||
|                     value=user.id, | ||||
|                     value=user.scim_id, | ||||
|                 ) | ||||
|             ) | ||||
|         if members: | ||||
| @ -107,16 +107,19 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]): | ||||
|                 exclude_unset=True, | ||||
|             ), | ||||
|         ) | ||||
|         SCIMGroup.objects.create(provider=self.provider, group=group, id=response["id"]) | ||||
|         scim_id = response.get("id") | ||||
|         if not scim_id or scim_id == "": | ||||
|             raise StopSync("SCIM Response with missing or invalid `id`") | ||||
|         SCIMGroup.objects.create(provider=self.provider, group=group, scim_id=scim_id) | ||||
|  | ||||
|     def _update(self, group: Group, connection: SCIMGroup): | ||||
|         """Update existing group""" | ||||
|         scim_group = self.to_scim(group) | ||||
|         scim_group.id = connection.id | ||||
|         scim_group.id = connection.scim_id | ||||
|         try: | ||||
|             return self._request( | ||||
|                 "PUT", | ||||
|                 f"/Groups/{scim_group.id}", | ||||
|                 f"/Groups/{connection.scim_id}", | ||||
|                 json=scim_group.model_dump( | ||||
|                     mode="json", | ||||
|                     exclude_unset=True, | ||||
| @ -185,13 +188,13 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]): | ||||
|             return | ||||
|         user_ids = list( | ||||
|             SCIMUser.objects.filter(user__pk__in=users_set, provider=self.provider).values_list( | ||||
|                 "id", flat=True | ||||
|                 "scim_id", flat=True | ||||
|             ) | ||||
|         ) | ||||
|         if len(user_ids) < 1: | ||||
|             return | ||||
|         self._patch( | ||||
|             scim_group.id, | ||||
|             scim_group.scim_id, | ||||
|             PatchOperation( | ||||
|                 op=PatchOp.add, | ||||
|                 path="members", | ||||
| @ -211,13 +214,13 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]): | ||||
|             return | ||||
|         user_ids = list( | ||||
|             SCIMUser.objects.filter(user__pk__in=users_set, provider=self.provider).values_list( | ||||
|                 "id", flat=True | ||||
|                 "scim_id", flat=True | ||||
|             ) | ||||
|         ) | ||||
|         if len(user_ids) < 1: | ||||
|             return | ||||
|         self._patch( | ||||
|             scim_group.id, | ||||
|             scim_group.scim_id, | ||||
|             PatchOperation( | ||||
|                 op=PatchOp.remove, | ||||
|                 path="members", | ||||
|  | ||||
| @ -34,7 +34,7 @@ class SCIMUserClient(SCIMClient[User, SCIMUserSchema]): | ||||
|         if not scim_user: | ||||
|             self.logger.debug("User does not exist in SCIM, skipping") | ||||
|             return None | ||||
|         response = self._request("DELETE", f"/Users/{scim_user.id}") | ||||
|         response = self._request("DELETE", f"/Users/{scim_user.scim_id}") | ||||
|         scim_user.delete() | ||||
|         return response | ||||
|  | ||||
| @ -85,15 +85,18 @@ class SCIMUserClient(SCIMClient[User, SCIMUserSchema]): | ||||
|                 exclude_unset=True, | ||||
|             ), | ||||
|         ) | ||||
|         SCIMUser.objects.create(provider=self.provider, user=user, id=response["id"]) | ||||
|         scim_id = response.get("id") | ||||
|         if not scim_id or scim_id == "": | ||||
|             raise StopSync("SCIM Response with missing or invalid `id`") | ||||
|         SCIMUser.objects.create(provider=self.provider, user=user, scim_id=scim_id) | ||||
|  | ||||
|     def _update(self, user: User, connection: SCIMUser): | ||||
|         """Update existing user""" | ||||
|         scim_user = self.to_scim(user) | ||||
|         scim_user.id = connection.id | ||||
|         scim_user.id = connection.scim_id | ||||
|         self._request( | ||||
|             "PUT", | ||||
|             f"/Users/{connection.id}", | ||||
|             f"/Users/{connection.scim_id}", | ||||
|             json=scim_user.model_dump( | ||||
|                 mode="json", | ||||
|                 exclude_unset=True, | ||||
|  | ||||
| @ -0,0 +1,76 @@ | ||||
| # Generated by Django 5.0.4 on 2024-05-03 12:38 | ||||
|  | ||||
| import uuid | ||||
| from django.db import migrations, models | ||||
| from django.apps.registry import Apps | ||||
|  | ||||
| from django.db.backends.base.schema import BaseDatabaseSchemaEditor | ||||
|  | ||||
| from authentik.lib.migrations import progress_bar | ||||
|  | ||||
|  | ||||
| def fix_scim_user_group_pk(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | ||||
|     SCIMUser = apps.get_model("authentik_providers_scim", "SCIMUser") | ||||
|     SCIMGroup = apps.get_model("authentik_providers_scim", "SCIMGroup") | ||||
|     db_alias = schema_editor.connection.alias | ||||
|     print("\nFixing primary key for SCIM users, this might take a couple of minutes...") | ||||
|     for user in progress_bar(SCIMUser.objects.using(db_alias).all()): | ||||
|         SCIMUser.objects.using(db_alias).filter( | ||||
|             pk=user.pk, user=user.user_id, provider=user.provider_id | ||||
|         ).update(scim_id=user.pk, id=uuid.uuid4()) | ||||
|  | ||||
|     print("\nFixing primary key for SCIM groups, this might take a couple of minutes...") | ||||
|     for group in progress_bar(SCIMGroup.objects.using(db_alias).all()): | ||||
|         SCIMGroup.objects.using(db_alias).filter( | ||||
|             pk=group.pk, group=group.group_id, provider=group.provider_id | ||||
|         ).update(scim_id=group.pk, id=uuid.uuid4()) | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ( | ||||
|             "authentik_providers_scim", | ||||
|             "0001_squashed_0006_rename_parent_group_scimprovider_filter_group", | ||||
|         ), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.AddField( | ||||
|             model_name="scimgroup", | ||||
|             name="scim_id", | ||||
|             field=models.TextField(default="temp"), | ||||
|             preserve_default=False, | ||||
|         ), | ||||
|         migrations.AddField( | ||||
|             model_name="scimuser", | ||||
|             name="scim_id", | ||||
|             field=models.TextField(default="temp"), | ||||
|             preserve_default=False, | ||||
|         ), | ||||
|         migrations.RunPython(fix_scim_user_group_pk), | ||||
|         migrations.AlterField( | ||||
|             model_name="scimgroup", | ||||
|             name="id", | ||||
|             field=models.UUIDField( | ||||
|                 default=uuid.uuid4, editable=False, primary_key=True, serialize=False | ||||
|             ), | ||||
|         ), | ||||
|         migrations.AlterField( | ||||
|             model_name="scimuser", | ||||
|             name="id", | ||||
|             field=models.UUIDField( | ||||
|                 default=uuid.uuid4, editable=False, primary_key=True, serialize=False | ||||
|             ), | ||||
|         ), | ||||
|         migrations.AlterField(model_name="scimuser", name="scim_id", field=models.TextField()), | ||||
|         migrations.AlterField(model_name="scimgroup", name="scim_id", field=models.TextField()), | ||||
|         migrations.AlterUniqueTogether( | ||||
|             name="scimgroup", | ||||
|             unique_together={("scim_id", "group", "provider")}, | ||||
|         ), | ||||
|         migrations.AlterUniqueTogether( | ||||
|             name="scimuser", | ||||
|             unique_together={("scim_id", "user", "provider")}, | ||||
|         ), | ||||
|     ] | ||||
| @ -1,5 +1,7 @@ | ||||
| """SCIM Provider models""" | ||||
|  | ||||
| from uuid import uuid4 | ||||
|  | ||||
| from django.core.cache import cache | ||||
| from django.db import models | ||||
| from django.db.models import QuerySet | ||||
| @ -97,12 +99,13 @@ class SCIMMapping(PropertyMapping): | ||||
| class SCIMUser(models.Model): | ||||
|     """Mapping of a user and provider to a SCIM user ID""" | ||||
|  | ||||
|     id = models.TextField(primary_key=True) | ||||
|     id = models.UUIDField(primary_key=True, editable=False, default=uuid4) | ||||
|     scim_id = models.TextField() | ||||
|     user = models.ForeignKey(User, on_delete=models.CASCADE) | ||||
|     provider = models.ForeignKey(SCIMProvider, on_delete=models.CASCADE) | ||||
|  | ||||
|     class Meta: | ||||
|         unique_together = (("id", "user", "provider"),) | ||||
|         unique_together = (("scim_id", "user", "provider"),) | ||||
|  | ||||
|     def __str__(self) -> str: | ||||
|         return f"SCIM User {self.user_id} to {self.provider_id}" | ||||
| @ -111,12 +114,13 @@ class SCIMUser(models.Model): | ||||
| class SCIMGroup(models.Model): | ||||
|     """Mapping of a group and provider to a SCIM user ID""" | ||||
|  | ||||
|     id = models.TextField(primary_key=True) | ||||
|     id = models.UUIDField(primary_key=True, editable=False, default=uuid4) | ||||
|     scim_id = models.TextField() | ||||
|     group = models.ForeignKey(Group, on_delete=models.CASCADE) | ||||
|     provider = models.ForeignKey(SCIMProvider, on_delete=models.CASCADE) | ||||
|  | ||||
|     class Meta: | ||||
|         unique_together = (("id", "group", "provider"),) | ||||
|         unique_together = (("scim_id", "group", "provider"),) | ||||
|  | ||||
|     def __str__(self) -> str: | ||||
|         return f"SCIM Group {self.group_id} to {self.provider_id}" | ||||
|  | ||||
| @ -88,6 +88,72 @@ class SCIMUserTests(TestCase): | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     @Mocker() | ||||
|     def test_user_create_different_provider_same_id(self, mock: Mocker): | ||||
|         """Test user creation with multiple providers that happen | ||||
|         to return the same object ID""" | ||||
|         # Create duplicate provider | ||||
|         provider: SCIMProvider = SCIMProvider.objects.create( | ||||
|             name=generate_id(), | ||||
|             url="https://localhost", | ||||
|             token=generate_id(), | ||||
|             exclude_users_service_account=True, | ||||
|         ) | ||||
|         app: Application = Application.objects.create( | ||||
|             name=generate_id(), | ||||
|             slug=generate_id(), | ||||
|         ) | ||||
|         app.backchannel_providers.add(provider) | ||||
|         provider.property_mappings.add( | ||||
|             SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/user") | ||||
|         ) | ||||
|         provider.property_mappings_group.add( | ||||
|             SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/group") | ||||
|         ) | ||||
|  | ||||
|         scim_id = generate_id() | ||||
|         mock.get( | ||||
|             "https://localhost/ServiceProviderConfig", | ||||
|             json={}, | ||||
|         ) | ||||
|         mock.post( | ||||
|             "https://localhost/Users", | ||||
|             json={ | ||||
|                 "id": scim_id, | ||||
|             }, | ||||
|         ) | ||||
|         uid = generate_id() | ||||
|         user = User.objects.create( | ||||
|             username=uid, | ||||
|             name=f"{uid} {uid}", | ||||
|             email=f"{uid}@goauthentik.io", | ||||
|         ) | ||||
|         self.assertEqual(mock.call_count, 4) | ||||
|         self.assertEqual(mock.request_history[0].method, "GET") | ||||
|         self.assertEqual(mock.request_history[1].method, "POST") | ||||
|         self.assertJSONEqual( | ||||
|             mock.request_history[1].body, | ||||
|             { | ||||
|                 "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"], | ||||
|                 "active": True, | ||||
|                 "emails": [ | ||||
|                     { | ||||
|                         "primary": True, | ||||
|                         "type": "other", | ||||
|                         "value": f"{uid}@goauthentik.io", | ||||
|                     } | ||||
|                 ], | ||||
|                 "externalId": user.uid, | ||||
|                 "name": { | ||||
|                     "familyName": uid, | ||||
|                     "formatted": f"{uid} {uid}", | ||||
|                     "givenName": uid, | ||||
|                 }, | ||||
|                 "displayName": f"{uid} {uid}", | ||||
|                 "userName": uid, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     @Mocker() | ||||
|     def test_user_create_update(self, mock: Mocker): | ||||
|         """Test user creation and update""" | ||||
|  | ||||
		Reference in New Issue
	
	Block a user
	 Jens L
					Jens L