providers/scim: use lock for sync (#7948)

* providers/scim: use lock for sync

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L
2023-12-21 14:43:40 +01:00
committed by GitHub
parent ec8f2d4bf9
commit 2521073dba
16 changed files with 123 additions and 91 deletions

View File

@ -2,6 +2,7 @@
from django.utils.text import slugify
from drf_spectacular.utils import OpenApiResponse, extend_schema
from rest_framework.decorators import action
from rest_framework.fields import BooleanField
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet
@ -9,6 +10,7 @@ from rest_framework.viewsets import ModelViewSet
from authentik.admin.api.tasks import TaskSerializer
from authentik.core.api.providers import ProviderSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import PassiveSerializer
from authentik.events.monitored_tasks import TaskInfo
from authentik.providers.scim.models import SCIMProvider
@ -37,6 +39,13 @@ class SCIMProviderSerializer(ProviderSerializer):
extra_kwargs = {}
class SCIMSyncStatusSerializer(PassiveSerializer):
"""SCIM Provider sync status"""
is_running = BooleanField(read_only=True)
tasks = TaskSerializer(many=True, read_only=True)
class SCIMProviderViewSet(UsedByMixin, ModelViewSet):
"""SCIMProvider Viewset"""
@ -48,15 +57,18 @@ class SCIMProviderViewSet(UsedByMixin, ModelViewSet):
@extend_schema(
responses={
200: TaskSerializer(),
200: SCIMSyncStatusSerializer(),
404: OpenApiResponse(description="Task not found"),
}
)
@action(methods=["GET"], detail=True, pagination_class=None, filter_backends=[])
def sync_status(self, request: Request, pk: int) -> Response:
"""Get provider's sync status"""
provider = self.get_object()
provider: SCIMProvider = self.get_object()
task = TaskInfo.by_name(f"scim_sync:{slugify(provider.name)}")
if not task:
return Response(status=404)
return Response(TaskSerializer(task).data)
tasks = [task] if task else []
status = {
"tasks": tasks,
"is_running": provider.sync_lock.locked(),
}
return Response(SCIMSyncStatusSerializer(status).data)

View File

@ -1,11 +1,14 @@
"""SCIM Provider models"""
from django.core.cache import cache
from django.db import models
from django.db.models import QuerySet
from django.utils.translation import gettext_lazy as _
from guardian.shortcuts import get_anonymous_user
from redis.lock import Lock
from rest_framework.serializers import Serializer
from authentik.core.models import BackchannelProvider, Group, PropertyMapping, User, UserTypes
from authentik.providers.scim.clients import PAGE_TIMEOUT
class SCIMProvider(BackchannelProvider):
@ -27,6 +30,15 @@ class SCIMProvider(BackchannelProvider):
help_text=_("Property mappings used for group creation/updating."),
)
@property
def sync_lock(self) -> Lock:
"""Redis lock for syncing SCIM to prevent multiple parallel syncs happening"""
return Lock(
cache.client.get_client(),
name=f"goauthentik.io/providers/scim/sync-{str(self.pk)}",
timeout=(60 * 60 * PAGE_TIMEOUT) * 3,
)
def get_user_qs(self) -> QuerySet[User]:
"""Get queryset of all users with consistent ordering
according to the provider's settings"""

View File

@ -47,6 +47,10 @@ def scim_sync(self: MonitoredTask, provider_pk: int) -> None:
).first()
if not provider:
return
lock = provider.sync_lock
if lock.locked():
LOGGER.debug("SCIM sync locked, skipping task", source=provider.name)
return
self.set_uid(slugify(provider.name))
result = TaskResult(TaskResultStatus.SUCCESSFUL, [])
result.messages.append(_("Starting full SCIM sync"))