Files
authentik/authentik/lib/sync/outgoing/tasks.py
Jens L. b5a8957720 lib/sync/outgoing: add dry run (#13244)
* lib/sync/outgoing: add dry run

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add option to temporarily override dry run

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* web a

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* web b

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* format

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add some test

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add more tests

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add dry run label

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add support for entra too

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add web

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add entra test and improve error handling

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-03-01 19:44:17 +00:00

295 lines
13 KiB
Python

from collections.abc import Callable
from dataclasses import asdict
from celery.exceptions import Retry
from celery.result import allow_join_result
from django.core.paginator import Paginator
from django.db.models import Model, QuerySet
from django.db.models.query import Q
from django.utils.text import slugify
from django.utils.translation import gettext_lazy as _
from structlog.stdlib import BoundLogger, get_logger
from authentik.core.expression.exceptions import SkipObjectException
from authentik.core.models import Group, User
from authentik.events.logs import LogEvent
from authentik.events.models import TaskStatus
from authentik.events.system_tasks import SystemTask
from authentik.events.utils import sanitize_item
from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT
from authentik.lib.sync.outgoing.base import Direction
from authentik.lib.sync.outgoing.exceptions import (
BadRequestSyncException,
DryRunRejected,
StopSync,
TransientSyncException,
)
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
from authentik.lib.utils.reflection import class_to_path, path_to_class
class SyncTasks:
"""Container for all sync 'tasks' (this class doesn't actually contain celery
tasks due to celery's magic, however exposes a number of functions to be called from tasks)"""
logger: BoundLogger
def __init__(self, provider_model: type[OutgoingSyncProvider]) -> None:
super().__init__()
self._provider_model = provider_model
def sync_all(self, single_sync: Callable[[int], None]):
for provider in self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
):
self.trigger_single_task(provider, single_sync)
def trigger_single_task(self, provider: OutgoingSyncProvider, sync_task: Callable[[int], None]):
"""Wrapper single sync task that correctly sets time limits based
on the amount of objects that will be synced"""
users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE)
groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE)
soft_time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT
time_limit = soft_time_limit * 1.5
return sync_task.apply_async(
(provider.pk,), time_limit=int(time_limit), soft_time_limit=int(soft_time_limit)
)
def sync_single(
self,
task: SystemTask,
provider_pk: int,
sync_objects: Callable[[int, int], list[str]],
):
self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model),
provider_pk=provider_pk,
)
provider = self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False),
pk=provider_pk,
).first()
if not provider:
return
task.set_uid(slugify(provider.name))
messages = []
messages.append(_("Starting full provider sync"))
self.logger.debug("Starting provider sync")
users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE)
groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE)
with allow_join_result(), provider.sync_lock as lock_acquired:
if not lock_acquired:
self.logger.debug("Failed to acquire sync lock, skipping", provider=provider.name)
return
try:
for page in users_paginator.page_range:
messages.append(_("Syncing page {page} of users".format(page=page)))
for msg in sync_objects.apply_async(
args=(class_to_path(User), page, provider_pk),
time_limit=PAGE_TIMEOUT,
soft_time_limit=PAGE_TIMEOUT,
).get():
messages.append(LogEvent(**msg))
for page in groups_paginator.page_range:
messages.append(_("Syncing page {page} of groups".format(page=page)))
for msg in sync_objects.apply_async(
args=(class_to_path(Group), page, provider_pk),
time_limit=PAGE_TIMEOUT,
soft_time_limit=PAGE_TIMEOUT,
).get():
messages.append(LogEvent(**msg))
except TransientSyncException as exc:
self.logger.warning("transient sync exception", exc=exc)
raise task.retry(exc=exc) from exc
except StopSync as exc:
task.set_error(exc)
return
task.set_status(TaskStatus.SUCCESSFUL, *messages)
def sync_objects(
self, object_type: str, page: int, provider_pk: int, override_dry_run=False, **filter
):
_object_type = path_to_class(object_type)
self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model),
provider_pk=provider_pk,
object_type=object_type,
)
messages = []
provider = self._provider_model.objects.filter(pk=provider_pk).first()
if not provider:
return messages
# Override dry run mode if requested, however don't save the provider
# so that scheduled sync tasks still run in dry_run mode
if override_dry_run:
provider.dry_run = False
try:
client = provider.client_for_model(_object_type)
except TransientSyncException:
return messages
paginator = Paginator(provider.get_object_qs(_object_type).filter(**filter), PAGE_SIZE)
if client.can_discover:
self.logger.debug("starting discover")
client.discover()
self.logger.debug("starting sync for page", page=page)
for obj in paginator.page(page).object_list:
obj: Model
try:
client.write(obj)
except SkipObjectException:
self.logger.debug("skipping object due to SkipObject", obj=obj)
continue
except DryRunRejected as exc:
messages.append(
asdict(
LogEvent(
_("Dropping mutating request due to dry run"),
log_level="info",
logger=f"{provider._meta.verbose_name}@{object_type}",
attributes={
"obj": sanitize_item(obj),
"method": exc.method,
"url": exc.url,
"body": exc.body,
},
)
)
)
except BadRequestSyncException as exc:
self.logger.warning("failed to sync object", exc=exc, obj=obj)
messages.append(
asdict(
LogEvent(
_(
(
"Failed to sync {object_type} {object_name} "
"due to error: {error}"
).format_map(
{
"object_type": obj._meta.verbose_name,
"object_name": str(obj),
"error": str(exc),
}
)
),
log_level="warning",
logger=f"{provider._meta.verbose_name}@{object_type}",
attributes={"arguments": exc.args[1:], "obj": sanitize_item(obj)},
)
)
)
except TransientSyncException as exc:
self.logger.warning("failed to sync object", exc=exc, user=obj)
messages.append(
asdict(
LogEvent(
_(
(
"Failed to sync {object_type} {object_name} "
"due to transient error: {error}"
).format_map(
{
"object_type": obj._meta.verbose_name,
"object_name": str(obj),
"error": str(exc),
}
)
),
log_level="warning",
logger=f"{provider._meta.verbose_name}@{object_type}",
attributes={"obj": sanitize_item(obj)},
)
)
)
except StopSync as exc:
self.logger.warning("Stopping sync", exc=exc)
messages.append(
asdict(
LogEvent(
_(
"Stopping sync due to error: {error}".format_map(
{
"error": exc.detail(),
}
)
),
log_level="warning",
logger=f"{provider._meta.verbose_name}@{object_type}",
attributes={"obj": sanitize_item(obj)},
)
)
)
break
return messages
def sync_signal_direct(self, model: str, pk: str | int, raw_op: str):
self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model),
)
model_class: type[Model] = path_to_class(model)
instance = model_class.objects.filter(pk=pk).first()
if not instance:
return
operation = Direction(raw_op)
for provider in self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
):
client = provider.client_for_model(instance.__class__)
# Check if the object is allowed within the provider's restrictions
queryset = provider.get_object_qs(instance.__class__)
if not queryset:
continue
# The queryset we get from the provider must include the instance we've got given
# otherwise ignore this provider
if not queryset.filter(pk=instance.pk).exists():
continue
try:
if operation == Direction.add:
client.write(instance)
if operation == Direction.remove:
client.delete(instance)
except TransientSyncException as exc:
raise Retry() from exc
except SkipObjectException:
continue
except DryRunRejected as exc:
self.logger.info("Rejected dry-run event", exc=exc)
except StopSync as exc:
self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk)
def sync_signal_m2m(self, group_pk: str, action: str, pk_set: list[int]):
self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model),
)
group = Group.objects.filter(pk=group_pk).first()
if not group:
return
for provider in self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
):
# Check if the object is allowed within the provider's restrictions
queryset: QuerySet = provider.get_object_qs(Group)
# The queryset we get from the provider must include the instance we've got given
# otherwise ignore this provider
if not queryset.filter(pk=group_pk).exists():
continue
client = provider.client_for_model(Group)
try:
operation = None
if action == "post_add":
operation = Direction.add
if action == "post_remove":
operation = Direction.remove
client.update_group(group, operation, pk_set)
except TransientSyncException as exc:
raise Retry() from exc
except SkipObjectException:
continue
except DryRunRejected as exc:
self.logger.info("Rejected dry-run event", exc=exc)
except StopSync as exc:
self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk)