providers/scim: use lock for sync (#7948)
* providers/scim: use lock for sync Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
@ -2,6 +2,7 @@
|
||||
from django.utils.text import slugify
|
||||
from drf_spectacular.utils import OpenApiResponse, extend_schema
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.fields import BooleanField
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
@ -9,6 +10,7 @@ from rest_framework.viewsets import ModelViewSet
|
||||
from authentik.admin.api.tasks import TaskSerializer
|
||||
from authentik.core.api.providers import ProviderSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import PassiveSerializer
|
||||
from authentik.events.monitored_tasks import TaskInfo
|
||||
from authentik.providers.scim.models import SCIMProvider
|
||||
|
||||
@ -37,6 +39,13 @@ class SCIMProviderSerializer(ProviderSerializer):
|
||||
extra_kwargs = {}
|
||||
|
||||
|
||||
class SCIMSyncStatusSerializer(PassiveSerializer):
|
||||
"""SCIM Provider sync status"""
|
||||
|
||||
is_running = BooleanField(read_only=True)
|
||||
tasks = TaskSerializer(many=True, read_only=True)
|
||||
|
||||
|
||||
class SCIMProviderViewSet(UsedByMixin, ModelViewSet):
|
||||
"""SCIMProvider Viewset"""
|
||||
|
||||
@ -48,15 +57,18 @@ class SCIMProviderViewSet(UsedByMixin, ModelViewSet):
|
||||
|
||||
@extend_schema(
|
||||
responses={
|
||||
200: TaskSerializer(),
|
||||
200: SCIMSyncStatusSerializer(),
|
||||
404: OpenApiResponse(description="Task not found"),
|
||||
}
|
||||
)
|
||||
@action(methods=["GET"], detail=True, pagination_class=None, filter_backends=[])
|
||||
def sync_status(self, request: Request, pk: int) -> Response:
|
||||
"""Get provider's sync status"""
|
||||
provider = self.get_object()
|
||||
provider: SCIMProvider = self.get_object()
|
||||
task = TaskInfo.by_name(f"scim_sync:{slugify(provider.name)}")
|
||||
if not task:
|
||||
return Response(status=404)
|
||||
return Response(TaskSerializer(task).data)
|
||||
tasks = [task] if task else []
|
||||
status = {
|
||||
"tasks": tasks,
|
||||
"is_running": provider.sync_lock.locked(),
|
||||
}
|
||||
return Response(SCIMSyncStatusSerializer(status).data)
|
||||
|
||||
@ -1,11 +1,14 @@
|
||||
"""SCIM Provider models"""
|
||||
from django.core.cache import cache
|
||||
from django.db import models
|
||||
from django.db.models import QuerySet
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from guardian.shortcuts import get_anonymous_user
|
||||
from redis.lock import Lock
|
||||
from rest_framework.serializers import Serializer
|
||||
|
||||
from authentik.core.models import BackchannelProvider, Group, PropertyMapping, User, UserTypes
|
||||
from authentik.providers.scim.clients import PAGE_TIMEOUT
|
||||
|
||||
|
||||
class SCIMProvider(BackchannelProvider):
|
||||
@ -27,6 +30,15 @@ class SCIMProvider(BackchannelProvider):
|
||||
help_text=_("Property mappings used for group creation/updating."),
|
||||
)
|
||||
|
||||
@property
|
||||
def sync_lock(self) -> Lock:
|
||||
"""Redis lock for syncing SCIM to prevent multiple parallel syncs happening"""
|
||||
return Lock(
|
||||
cache.client.get_client(),
|
||||
name=f"goauthentik.io/providers/scim/sync-{str(self.pk)}",
|
||||
timeout=(60 * 60 * PAGE_TIMEOUT) * 3,
|
||||
)
|
||||
|
||||
def get_user_qs(self) -> QuerySet[User]:
|
||||
"""Get queryset of all users with consistent ordering
|
||||
according to the provider's settings"""
|
||||
|
||||
@ -47,6 +47,10 @@ def scim_sync(self: MonitoredTask, provider_pk: int) -> None:
|
||||
).first()
|
||||
if not provider:
|
||||
return
|
||||
lock = provider.sync_lock
|
||||
if lock.locked():
|
||||
LOGGER.debug("SCIM sync locked, skipping task", source=provider.name)
|
||||
return
|
||||
self.set_uid(slugify(provider.name))
|
||||
result = TaskResult(TaskResultStatus.SUCCESSFUL, [])
|
||||
result.messages.append(_("Starting full SCIM sync"))
|
||||
|
||||
Reference in New Issue
Block a user