Files
authentik/authentik/providers/scim/tasks.py
Jens L 28ddeb124f providers: SCIM (#4835)
* basic user sync

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add group sync and some refactor

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* start API

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* allow null authorization flow

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add UI

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* make task monitored

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add missing dependency

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* make authorization_flow required for most providers via API

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* more UI

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* make task result more readable, exclude anonymous user

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add task UI

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add scheduled task for all sync

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* make scim errors more readable

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add mappings, migrate to mappings

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add mapping UI and more

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add scim docs to web

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* start implementing membership

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* migrate signals to tasks

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* migrate fully to tasks

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* strip none keys, fix lint errors

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix things

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* start adding tests

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix saml

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add scim schemas and validate against it

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* improve error handling

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add group put support, add group tests

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* send correct application/scim+json headers

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* stop sync if no mappings are configured

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add test for task sync

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add membership tests

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* use decorator for tests

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* make tests better

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2023-03-06 19:39:08 +01:00

177 lines
6.4 KiB
Python

"""SCIM Provider tasks"""
from typing import Any
from celery.result import allow_join_result
from django.core.paginator import Paginator
from django.db.models import Model
from django.utils.text import slugify
from django.utils.translation import gettext_lazy as _
from guardian.shortcuts import get_anonymous_user
from pydanticscim.responses import PatchOp
from structlog.stdlib import get_logger
from authentik.core.models import Group, User
from authentik.events.monitored_tasks import MonitoredTask, TaskResult, TaskResultStatus
from authentik.lib.utils.reflection import path_to_class
from authentik.providers.scim.clients import PAGE_SIZE
from authentik.providers.scim.clients.base import SCIMClient
from authentik.providers.scim.clients.exceptions import SCIMRequestException, StopSync
from authentik.providers.scim.clients.group import SCIMGroupClient
from authentik.providers.scim.clients.user import SCIMUserClient
from authentik.providers.scim.models import SCIMProvider
from authentik.root.celery import CELERY_APP
# Module-level structlog logger shared by every task in this module.
LOGGER = get_logger(__name__)
def client_for_model(provider: SCIMProvider, model: Model) -> SCIMClient:
    """Return the SCIM client that handles objects of ``model``'s type.

    :param provider: SCIM provider the returned client is bound to.
    :param model: A ``User`` or ``Group`` instance.
    :returns: ``SCIMUserClient`` for users, ``SCIMGroupClient`` for groups.
    :raises ValueError: if ``model`` is neither a ``User`` nor a ``Group``.
    """
    # Checked in order: User first, then Group.
    supported = ((User, SCIMUserClient), (Group, SCIMGroupClient))
    for model_type, client_class in supported:
        if isinstance(model, model_type):
            return client_class(provider)
    raise ValueError(f"Invalid model {model}")
@CELERY_APP.task()
def scim_sync_all():
    """Queue a full SCIM sync for every configured provider.

    This task only dispatches work: each provider gets its own asynchronous
    ``scim_sync`` task, keyed by the provider's primary key.
    """
    providers = SCIMProvider.objects.all()
    for scim_provider in providers:
        scim_sync.delay(scim_provider.pk)
@CELERY_APP.task(bind=True, base=MonitoredTask)
def scim_sync(self: MonitoredTask, provider_pk: int) -> None:
    """Run a full SCIM sync (all users, then all groups) for one provider.

    Users and groups are paged with ``PAGE_SIZE`` items per page, ordered by
    primary key. Each page is handed to a ``scim_sync_users`` /
    ``scim_sync_group`` subtask, whose returned messages are collected into
    this monitored task's result.

    :param provider_pk: Primary key of the ``SCIMProvider`` to sync. If no
        such provider exists, the task returns without recording a status.
    """
    provider: SCIMProvider = SCIMProvider.objects.filter(pk=provider_pk).first()
    if not provider:
        return
    self.set_uid(slugify(provider.name))
    result = TaskResult(TaskResultStatus.SUCCESSFUL, [])
    result.messages.append(_("Starting full SCIM sync"))
    # TODO: Filtering
    LOGGER.debug("Starting SCIM sync")
    # The anonymous user is never synced.
    users_paginator = Paginator(
        User.objects.all().exclude(pk=get_anonymous_user().pk).order_by("pk"), PAGE_SIZE
    )
    groups_paginator = Paginator(Group.objects.all().order_by("pk"), PAGE_SIZE)
    # Celery refuses a blocking .get() on a subtask result from inside a task
    # unless the call is wrapped in allow_join_result().
    with allow_join_result():
        try:
            for page in users_paginator.page_range:
                # Translate the template first, then interpolate, so the msgid
                # matches the translation catalog.
                result.messages.append(_("Syncing page %(page)d of users") % {"page": page})
                result.messages.extend(scim_sync_users.delay(page, provider_pk).get())
            for page in groups_paginator.page_range:
                result.messages.append(_("Syncing page %(page)d of groups") % {"page": page})
                result.messages.extend(scim_sync_group.delay(page, provider_pk).get())
        except StopSync as exc:
            # A StopSync re-raised from a subtask aborts the whole sync and
            # replaces the collected messages with an error status.
            self.set_status(TaskResult(TaskResultStatus.ERROR).with_error(exc))
            return
    self.set_status(result)
@CELERY_APP.task()
def scim_sync_users(page: int, provider_pk: int, **kwargs):
    """Sync one page of users to a SCIM provider.

    :param page: Page number into the user list (``PAGE_SIZE`` users per page,
        ordered by primary key, anonymous user excluded), as produced by
        ``Paginator.page_range``.
    :param provider_pk: Primary key of the ``SCIMProvider`` to sync to.
    :param kwargs: Optional ``User`` queryset filter arguments that narrow
        which users are paged.
    :returns: Human-readable messages for users that failed to sync. The list
        is empty if the provider does not exist or its client could not be
        created.
    """
    messages = []
    provider: SCIMProvider = SCIMProvider.objects.filter(pk=provider_pk).first()
    if not provider:
        return messages
    try:
        client = SCIMUserClient(provider)
    except SCIMRequestException as exc:
        # Previously swallowed silently; log it so a misconfigured provider is visible.
        LOGGER.warning("failed to create SCIM user client", exc=exc, provider=provider)
        return messages
    paginator = Paginator(
        User.objects.all().filter(**kwargs).exclude(pk=get_anonymous_user().pk).order_by("pk"),
        PAGE_SIZE,
    )
    LOGGER.debug("starting user sync for page", page=page)
    for user in paginator.page(page).object_list:
        try:
            client.write(user)
        except SCIMRequestException as exc:
            LOGGER.warning("failed to sync user", exc=exc, user=user)
            # Translate the template first, then interpolate, so the msgid
            # matches the translation catalog.
            messages.append(
                _("Failed to sync user due to remote error %(name)s: %(error)s")
                % {
                    "name": user.username,
                    "error": str(exc),
                }
            )
        except StopSync:
            # Stop processing the remaining users on this page; messages
            # gathered so far are still returned.
            break
    return messages
@CELERY_APP.task()
def scim_sync_group(page: int, provider_pk: int, **kwargs):
    """Sync one page of groups to a SCIM provider.

    :param page: Page number into the group list (``PAGE_SIZE`` groups per
        page, ordered by primary key), as produced by ``Paginator.page_range``.
    :param provider_pk: Primary key of the ``SCIMProvider`` to sync to.
    :param kwargs: Optional ``Group`` queryset filter arguments that narrow
        which groups are paged.
    :returns: Human-readable messages for groups that failed to sync. The list
        is empty if the provider does not exist or its client could not be
        created.
    """
    messages = []
    provider: SCIMProvider = SCIMProvider.objects.filter(pk=provider_pk).first()
    if not provider:
        return messages
    try:
        client = SCIMGroupClient(provider)
    except SCIMRequestException as exc:
        # Previously swallowed silently; log it so a misconfigured provider is visible.
        LOGGER.warning("failed to create SCIM group client", exc=exc, provider=provider)
        return messages
    paginator = Paginator(Group.objects.all().filter(**kwargs).order_by("pk"), PAGE_SIZE)
    LOGGER.debug("starting group sync for page", page=page)
    for group in paginator.page(page).object_list:
        try:
            client.write(group)
        except SCIMRequestException as exc:
            LOGGER.warning("failed to sync group", exc=exc, group=group)
            # Translate the template first, then interpolate, so the msgid
            # matches the translation catalog.
            messages.append(
                _("Failed to sync group due to remote error %(name)s: %(error)s")
                % {
                    "name": group.name,
                    "error": str(exc),
                }
            )
        except StopSync:
            # Stop processing the remaining groups on this page; messages
            # gathered so far are still returned.
            break
    return messages
@CELERY_APP.task()
def scim_signal_direct(model: str, pk: Any, raw_op: str):
    """Propagate a single user/group change to every SCIM provider.

    Handler for post_save and pre_delete signal.

    :param model: Dotted import path of the model class, resolved with
        ``path_to_class``.
    :param pk: Primary key of the changed instance.
    :param raw_op: Value of a ``PatchOp``. ``add`` writes the instance to each
        provider; ``remove`` deletes it.
    :raises ValueError: from ``client_for_model`` if the model is neither a
        ``User`` nor a ``Group``.
    """
    model_class: type[Model] = path_to_class(model)
    instance = model_class.objects.filter(pk=pk).first()
    if not instance:
        # NOTE(review): for a "remove" queued from pre_delete, the row may
        # already be gone when this task runs, in which case no delete is sent
        # to the providers. Confirm against the signal handler.
        return
    operation = PatchOp(raw_op)
    for provider in SCIMProvider.objects.all():
        try:
            # Client construction can raise SCIMRequestException. Keep it
            # inside the try so one broken provider does not abort
            # propagation to the remaining providers.
            client = client_for_model(provider, instance)
            if operation == PatchOp.add:
                client.write(instance)
            elif operation == PatchOp.remove:
                client.delete(instance)
        except (StopSync, SCIMRequestException) as exc:
            LOGGER.warning(
                "failed to propagate change to SCIM provider",
                exc=exc,
                provider=provider,
                instance=instance,
            )
@CELERY_APP.task()
def scim_signal_m2m(group_pk: str, action: str, pk_set: set[int]):
    """Propagate a group membership change to every SCIM provider.

    :param group_pk: Primary key of the group whose membership changed.
    :param action: The m2m change action. Only ``"post_add"`` and
        ``"post_remove"`` are propagated; any other action is ignored.
    :param pk_set: Primary keys of the changed members, passed through to
        ``SCIMGroupClient.update_group`` (presumably user pks; verify
        against the signal handler).
    """
    group = Group.objects.filter(pk=group_pk).first()
    if not group:
        return
    # The operation depends only on ``action``, so resolve it once, outside
    # the provider loop.
    operation = {
        "post_add": PatchOp.add,
        "post_remove": PatchOp.remove,
    }.get(action)
    if operation is None:
        # Unrecognised action: previously update_group was called with a None
        # operation. Skip instead.
        return
    for provider in SCIMProvider.objects.all():
        try:
            # Client construction can raise SCIMRequestException. Keep it
            # inside the try so one broken provider does not abort the others.
            client = SCIMGroupClient(provider)
            client.update_group(group, operation, pk_set)
        except (StopSync, SCIMRequestException) as exc:
            LOGGER.warning(
                "failed to propagate group membership to SCIM provider",
                exc=exc,
                provider=provider,
                group=group,
            )