providers/scim: improve compatibility (#5425)
* providers/scim: improve compatibility Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix lint and tests Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
		@ -2,19 +2,12 @@
 | 
			
		||||
from typing import Generic, TypeVar
 | 
			
		||||
 | 
			
		||||
from pydantic import ValidationError
 | 
			
		||||
from pydanticscim.service_provider import (
 | 
			
		||||
    Bulk,
 | 
			
		||||
    ChangePassword,
 | 
			
		||||
    Filter,
 | 
			
		||||
    Patch,
 | 
			
		||||
    ServiceProviderConfiguration,
 | 
			
		||||
    Sort,
 | 
			
		||||
)
 | 
			
		||||
from requests import RequestException, Session
 | 
			
		||||
from structlog.stdlib import get_logger
 | 
			
		||||
 | 
			
		||||
from authentik.lib.utils.http import get_http_session
 | 
			
		||||
from authentik.providers.scim.clients.exceptions import ResourceMissing, SCIMRequestException
 | 
			
		||||
from authentik.providers.scim.clients.schema import ServiceProviderConfiguration
 | 
			
		||||
from authentik.providers.scim.models import SCIMProvider
 | 
			
		||||
 | 
			
		||||
T = TypeVar("T")
 | 
			
		||||
@ -22,18 +15,6 @@ T = TypeVar("T")
 | 
			
		||||
SchemaType = TypeVar("SchemaType")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def default_service_provider_config() -> ServiceProviderConfiguration:
 | 
			
		||||
    """Fallback service provider configuration"""
 | 
			
		||||
    return ServiceProviderConfiguration(
 | 
			
		||||
        patch=Patch(supported=False),
 | 
			
		||||
        bulk=Bulk(supported=False),
 | 
			
		||||
        filter=Filter(supported=False),
 | 
			
		||||
        changePassword=ChangePassword(supported=False),
 | 
			
		||||
        sort=Sort(supported=False),
 | 
			
		||||
        authenticationSchemes=[],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SCIMClient(Generic[T, SchemaType]):
 | 
			
		||||
    """SCIM Client"""
 | 
			
		||||
 | 
			
		||||
@ -85,7 +66,7 @@ class SCIMClient(Generic[T, SchemaType]):
 | 
			
		||||
 | 
			
		||||
    def get_service_provider_config(self):
 | 
			
		||||
        """Get Service provider config"""
 | 
			
		||||
        default_config = default_service_provider_config()
 | 
			
		||||
        default_config = ServiceProviderConfiguration.default()
 | 
			
		||||
        try:
 | 
			
		||||
            return ServiceProviderConfiguration.parse_obj(
 | 
			
		||||
                self._request("GET", "/ServiceProviderConfig")
 | 
			
		||||
 | 
			
		||||
@ -2,7 +2,7 @@
 | 
			
		||||
from deepmerge import always_merger
 | 
			
		||||
from pydantic import ValidationError
 | 
			
		||||
from pydanticscim.group import GroupMember
 | 
			
		||||
from pydanticscim.responses import PatchOp, PatchOperation, PatchRequest
 | 
			
		||||
from pydanticscim.responses import PatchOp, PatchOperation
 | 
			
		||||
 | 
			
		||||
from authentik.core.exceptions import PropertyMappingExpressionException
 | 
			
		||||
from authentik.core.models import Group
 | 
			
		||||
@ -10,8 +10,13 @@ from authentik.events.models import Event, EventAction
 | 
			
		||||
from authentik.lib.utils.errors import exception_to_string
 | 
			
		||||
from authentik.policies.utils import delete_none_keys
 | 
			
		||||
from authentik.providers.scim.clients.base import SCIMClient
 | 
			
		||||
from authentik.providers.scim.clients.exceptions import ResourceMissing, StopSync
 | 
			
		||||
from authentik.providers.scim.clients.exceptions import (
 | 
			
		||||
    ResourceMissing,
 | 
			
		||||
    SCIMRequestException,
 | 
			
		||||
    StopSync,
 | 
			
		||||
)
 | 
			
		||||
from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema
 | 
			
		||||
from authentik.providers.scim.clients.schema import PatchRequest
 | 
			
		||||
from authentik.providers.scim.models import SCIMGroup, SCIMMapping, SCIMUser
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -104,13 +109,20 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
 | 
			
		||||
        """Update existing group"""
 | 
			
		||||
        scim_group = self.to_scim(group)
 | 
			
		||||
        scim_group.id = connection.id
 | 
			
		||||
        return self._request(
 | 
			
		||||
            "PUT",
 | 
			
		||||
            f"/Groups/{scim_group.id}",
 | 
			
		||||
            data=scim_group.json(
 | 
			
		||||
                exclude_unset=True,
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
        try:
 | 
			
		||||
            return self._request(
 | 
			
		||||
                "PUT",
 | 
			
		||||
                f"/Groups/{scim_group.id}",
 | 
			
		||||
                data=scim_group.json(
 | 
			
		||||
                    exclude_unset=True,
 | 
			
		||||
                ),
 | 
			
		||||
            )
 | 
			
		||||
        except SCIMRequestException:
 | 
			
		||||
            # Some providers don't support PUT on groups, so this is mainly a fix for the initial
 | 
			
		||||
            # sync, send patch add requests for all the users the group currently has
 | 
			
		||||
            # TODO: send patch request for group name
 | 
			
		||||
            users = list(group.users.order_by("id").values_list("id", flat=True))
 | 
			
		||||
            return self._patch_add_users(group, users)
 | 
			
		||||
 | 
			
		||||
    def _patch(
 | 
			
		||||
        self,
 | 
			
		||||
@ -118,7 +130,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
 | 
			
		||||
        *ops: PatchOperation,
 | 
			
		||||
    ):
 | 
			
		||||
        req = PatchRequest(Operations=ops)
 | 
			
		||||
        self._request("PATCH", f"/Groups/{group_id}", data=req.json(exclude_unset=True))
 | 
			
		||||
        self._request("PATCH", f"/Groups/{group_id}", data=req.json())
 | 
			
		||||
 | 
			
		||||
    def update_group(self, group: Group, action: PatchOp, users_set: set[int]):
 | 
			
		||||
        """Update a group, either using PUT to replace it or PATCH if supported"""
 | 
			
		||||
@ -127,7 +139,17 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
 | 
			
		||||
                return self._patch_add_users(group, users_set)
 | 
			
		||||
            if action == PatchOp.remove:
 | 
			
		||||
                return self._patch_remove_users(group, users_set)
 | 
			
		||||
        return self.write(group)
 | 
			
		||||
        try:
 | 
			
		||||
            return self.write(group)
 | 
			
		||||
        except SCIMRequestException as exc:
 | 
			
		||||
            if self._config.is_fallback:
 | 
			
		||||
                # Assume that provider does not support PUT and also doesn't support
 | 
			
		||||
                # ServiceProviderConfig, so try PATCH as a fallback
 | 
			
		||||
                if action == PatchOp.add:
 | 
			
		||||
                    return self._patch_add_users(group, users_set)
 | 
			
		||||
                if action == PatchOp.remove:
 | 
			
		||||
                    return self._patch_remove_users(group, users_set)
 | 
			
		||||
            raise exc
 | 
			
		||||
 | 
			
		||||
    def _patch_add_users(self, group: Group, users_set: set[int]):
 | 
			
		||||
        """Add users in users_set to group"""
 | 
			
		||||
@ -144,6 +166,8 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
 | 
			
		||||
                "id", flat=True
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        if len(user_ids) < 1:
 | 
			
		||||
            return
 | 
			
		||||
        self._patch(
 | 
			
		||||
            scim_group.id,
 | 
			
		||||
            PatchOperation(
 | 
			
		||||
@ -168,6 +192,8 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
 | 
			
		||||
                "id", flat=True
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        if len(user_ids) < 1:
 | 
			
		||||
            return
 | 
			
		||||
        self._patch(
 | 
			
		||||
            scim_group.id,
 | 
			
		||||
            PatchOperation(
 | 
			
		||||
 | 
			
		||||
@ -1,17 +1,54 @@
 | 
			
		||||
"""Custom SCIM schemas"""
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
from pydanticscim.group import Group as SCIMGroupSchema
 | 
			
		||||
from pydanticscim.user import User as SCIMUserSchema
 | 
			
		||||
from pydanticscim.group import Group as BaseGroup
 | 
			
		||||
from pydanticscim.responses import PatchRequest as BasePatchRequest
 | 
			
		||||
from pydanticscim.service_provider import Bulk, ChangePassword, Filter, Patch
 | 
			
		||||
from pydanticscim.service_provider import (
 | 
			
		||||
    ServiceProviderConfiguration as BaseServiceProviderConfiguration,
 | 
			
		||||
)
 | 
			
		||||
from pydanticscim.service_provider import Sort
 | 
			
		||||
from pydanticscim.user import User as BaseUser
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class User(SCIMUserSchema):
 | 
			
		||||
class User(BaseUser):
 | 
			
		||||
    """Modified User schema with added externalId field"""
 | 
			
		||||
 | 
			
		||||
    externalId: Optional[str] = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Group(SCIMGroupSchema):
 | 
			
		||||
class Group(BaseGroup):
 | 
			
		||||
    """Modified Group schema with added externalId field"""
 | 
			
		||||
 | 
			
		||||
    externalId: Optional[str] = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ServiceProviderConfiguration(BaseServiceProviderConfiguration):
 | 
			
		||||
    """ServiceProviderConfig with fallback"""
 | 
			
		||||
 | 
			
		||||
    _is_fallback: Optional[bool] = False
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def is_fallback(self) -> bool:
 | 
			
		||||
        """Check if this service provider config was retrieved from the API endpoint
 | 
			
		||||
        or a fallback was used"""
 | 
			
		||||
        return self._is_fallback
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def default() -> "ServiceProviderConfiguration":
 | 
			
		||||
        """Get default configuration, which doesn't support any optional features as fallback"""
 | 
			
		||||
        return ServiceProviderConfiguration(
 | 
			
		||||
            patch=Patch(supported=False),
 | 
			
		||||
            bulk=Bulk(supported=False),
 | 
			
		||||
            filter=Filter(supported=False),
 | 
			
		||||
            changePassword=ChangePassword(supported=False),
 | 
			
		||||
            sort=Sort(supported=False),
 | 
			
		||||
            authenticationSchemes=[],
 | 
			
		||||
            _is_fallback=True,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PatchRequest(BasePatchRequest):
 | 
			
		||||
    """PatchRequest which correctly sets schemas"""
 | 
			
		||||
 | 
			
		||||
    schemas: tuple[str] = ["urn:ietf:params:scim:api:messages:2.0:PatchOp"]
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										0
									
								
								authentik/providers/scim/management/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								authentik/providers/scim/management/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										23
									
								
								authentik/providers/scim/management/commands/scim_sync.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								authentik/providers/scim/management/commands/scim_sync.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,23 @@
 | 
			
		||||
"""SCIM Sync"""
 | 
			
		||||
from django.core.management.base import BaseCommand
 | 
			
		||||
from structlog.stdlib import get_logger
 | 
			
		||||
 | 
			
		||||
from authentik.providers.scim.models import SCIMProvider
 | 
			
		||||
from authentik.providers.scim.tasks import scim_sync
 | 
			
		||||
 | 
			
		||||
LOGGER = get_logger()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Command(BaseCommand):
 | 
			
		||||
    """Run sync for an SCIM Provider"""
 | 
			
		||||
 | 
			
		||||
    def add_arguments(self, parser):
 | 
			
		||||
        parser.add_argument("providers", nargs="+", type=str)
 | 
			
		||||
 | 
			
		||||
    def handle(self, **options):
 | 
			
		||||
        for provider_name in options["providers"]:
 | 
			
		||||
            provider = SCIMProvider.objects.filter(name=provider_name).first()
 | 
			
		||||
            if not provider:
 | 
			
		||||
                LOGGER.warning("Provider does not exist", name=provider_name)
 | 
			
		||||
                continue
 | 
			
		||||
            scim_sync.delay(provider.pk).get()
 | 
			
		||||
@ -94,7 +94,8 @@ def scim_sync_users(page: int, provider_pk: int):
 | 
			
		||||
                    }
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
        except StopSync:
 | 
			
		||||
        except StopSync as exc:
 | 
			
		||||
            LOGGER.warning("Stopping sync", exc=exc)
 | 
			
		||||
            break
 | 
			
		||||
    return messages
 | 
			
		||||
 | 
			
		||||
@ -126,7 +127,8 @@ def scim_sync_group(page: int, provider_pk: int):
 | 
			
		||||
                    }
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
        except StopSync:
 | 
			
		||||
        except StopSync as exc:
 | 
			
		||||
            LOGGER.warning("Stopping sync", exc=exc)
 | 
			
		||||
            break
 | 
			
		||||
    return messages
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -6,7 +6,7 @@ from requests_mock import Mocker
 | 
			
		||||
from authentik.blueprints.tests import apply_blueprint
 | 
			
		||||
from authentik.core.models import Group, User
 | 
			
		||||
from authentik.lib.generators import generate_id
 | 
			
		||||
from authentik.providers.scim.clients.base import default_service_provider_config
 | 
			
		||||
from authentik.providers.scim.clients.schema import ServiceProviderConfiguration
 | 
			
		||||
from authentik.providers.scim.models import SCIMMapping, SCIMProvider
 | 
			
		||||
from authentik.providers.scim.tasks import scim_sync
 | 
			
		||||
 | 
			
		||||
@ -39,7 +39,7 @@ class SCIMMembershipTests(TestCase):
 | 
			
		||||
 | 
			
		||||
    def test_member_add(self):
 | 
			
		||||
        """Test member add"""
 | 
			
		||||
        config = default_service_provider_config()
 | 
			
		||||
        config = ServiceProviderConfiguration.default()
 | 
			
		||||
        config.patch.supported = True
 | 
			
		||||
        user_scim_id = generate_id()
 | 
			
		||||
        group_scim_id = generate_id()
 | 
			
		||||
@ -117,13 +117,14 @@ class SCIMMembershipTests(TestCase):
 | 
			
		||||
                            "path": "members",
 | 
			
		||||
                            "value": [{"value": user_scim_id}],
 | 
			
		||||
                        }
 | 
			
		||||
                    ]
 | 
			
		||||
                    ],
 | 
			
		||||
                    "schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
 | 
			
		||||
                },
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    def test_member_remove(self):
 | 
			
		||||
        """Test member remove"""
 | 
			
		||||
        config = default_service_provider_config()
 | 
			
		||||
        config = ServiceProviderConfiguration.default()
 | 
			
		||||
        config.patch.supported = True
 | 
			
		||||
        user_scim_id = generate_id()
 | 
			
		||||
        group_scim_id = generate_id()
 | 
			
		||||
@ -201,7 +202,8 @@ class SCIMMembershipTests(TestCase):
 | 
			
		||||
                            "path": "members",
 | 
			
		||||
                            "value": [{"value": user_scim_id}],
 | 
			
		||||
                        }
 | 
			
		||||
                    ]
 | 
			
		||||
                    ],
 | 
			
		||||
                    "schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
 | 
			
		||||
                },
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
@ -227,6 +229,7 @@ class SCIMMembershipTests(TestCase):
 | 
			
		||||
                            "path": "members",
 | 
			
		||||
                            "value": [{"value": user_scim_id}],
 | 
			
		||||
                        }
 | 
			
		||||
                    ]
 | 
			
		||||
                    ],
 | 
			
		||||
                    "schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
 | 
			
		||||
                },
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user