enterprise/providers/google: initial account sync to google workspace (#9384)

* providers/google: initial account sync to google workspace

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* start separating scim sync client

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* generalize more...ish

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* set dispatch_uid

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* start generalizing task

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fully separate tasks

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix more

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix signals...?

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* start google dedupe

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* drawing the rest of the owl

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* more

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* just use a whole lot less magic

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* member sync, better implement conflict/retry-able exceptions

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* make wizards taller

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* gen api, basic UI

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix some bugs

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix a bunch more bugs

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* generalize sync status API

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* rework sync chart

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add slugify to evaluator

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add test property mappings

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* rename to google workspace

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* handle existing objects

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix credential render

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* verify email has correct domain before syncing user

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix missing docstring

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix lock not being used

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* abstract more common stuff away

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* backport time limit fix

https://github.com/goauthentik/authentik/pull/9546
Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* start discovery

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* implement discover for google

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* prevent same issue as with https://github.com/goauthentik/authentik/pull/9557

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix sync status

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* make group name unique in API

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix reference to old wrapper

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* start adding tests

man this api client is awful

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add SkipObject

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* don't use weak ref

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add group tests

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add user and group delete options

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* set user agent

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* if the api's testing tools are awful, let's just make our own

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add more tests and already fix some more bugs

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add discover

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add preview banner

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add group import test

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* only import users/groups in the correct parent group

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix conflicting args

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix missing schedule

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix web ui

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add default_group_email_domain

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
Author: Jens L
Date: 2024-05-07 19:52:20 +02:00
Committed by: GitHub
Parent: 18b4b2d7b2
Commit: aeb1b450eb
84 changed files with 4307 additions and 619 deletions

@@ -0,0 +1,5 @@
"""Sync constants"""
PAGE_SIZE = 100
PAGE_TIMEOUT = 60 * 60 * 0.5 # Half an hour
HTTP_CONFLICT = 409

@@ -0,0 +1,54 @@
from collections.abc import Callable
from django.utils.text import slugify
from drf_spectacular.utils import OpenApiResponse, extend_schema
from guardian.shortcuts import get_objects_for_user
from rest_framework.decorators import action
from rest_framework.fields import BooleanField
from rest_framework.request import Request
from rest_framework.response import Response
from authentik.core.api.utils import PassiveSerializer
from authentik.events.api.tasks import SystemTaskSerializer
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
class SyncStatusSerializer(PassiveSerializer):
"""Provider sync status"""
is_running = BooleanField(read_only=True)
tasks = SystemTaskSerializer(many=True, read_only=True)
class OutgoingSyncProviderStatusMixin:
"""Common API Endpoints for Outgoing sync providers"""
sync_single_task: Callable = None
@extend_schema(
responses={
200: SyncStatusSerializer(),
404: OpenApiResponse(description="Task not found"),
}
)
@action(
methods=["GET"],
detail=True,
pagination_class=None,
url_path="sync/status",
filter_backends=[],
)
def sync_status(self, request: Request, pk: int) -> Response:
"""Get provider's sync status"""
provider: OutgoingSyncProvider = self.get_object()
tasks = list(
get_objects_for_user(request.user, "authentik_events.view_systemtask").filter(
name=self.sync_single_task.__name__,
uid=slugify(provider.name),
)
)
status = {
"tasks": tasks,
"is_running": provider.sync_lock.locked(),
}
return Response(SyncStatusSerializer(status).data)
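
A usage sketch for context: a provider viewset adopts this mixin and points sync_single_task at its single-sync celery task, which is what backs the generated GET .../{id}/sync/status/ endpoint. The MyProvider* names and the my_provider_sync task below are hypothetical placeholders, not part of this change.

from rest_framework.viewsets import ModelViewSet

class MyProviderViewSet(OutgoingSyncProviderStatusMixin, ModelViewSet):
    queryset = MyProvider.objects.all()
    serializer_class = MyProviderSerializer
    # celery task whose SystemTask entries are returned by sync/status
    sync_single_task = my_provider_sync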

@@ -0,0 +1,83 @@
"""Basic outgoing sync Client"""
from enum import StrEnum
from typing import TYPE_CHECKING
from django.db import DatabaseError
from structlog.stdlib import get_logger
from authentik.lib.sync.outgoing.exceptions import NotFoundSyncException
if TYPE_CHECKING:
from django.db.models import Model
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
class Direction(StrEnum):
add = "add"
remove = "remove"
class BaseOutgoingSyncClient[
TModel: "Model", TConnection: "Model", TSchema: dict, TProvider: "OutgoingSyncProvider"
]:
"""Basic Outgoing sync client Client"""
provider: TProvider
connection_type: type[TConnection]
connection_type_query: str
can_discover = False
def __init__(self, provider: TProvider):
self.logger = get_logger().bind(provider=provider.name)
self.provider = provider
def create(self, obj: TModel) -> TConnection:
"""Create object in remote destination"""
raise NotImplementedError()
def update(self, obj: TModel, connection: object):
"""Update object in remote destination"""
raise NotImplementedError()
def write(self, obj: TModel) -> tuple[TConnection, bool]:
"""Write object to destination. Uses self.create and self.update, but
can be overwritten for further logic"""
remote_obj = self.connection_type.objects.filter(
provider=self.provider, **{self.connection_type_query: obj}
).first()
connection: TConnection | None = None
try:
if not remote_obj:
connection = self.create(obj)
return connection, True
try:
self.update(obj, remote_obj)
return remote_obj, False
except NotFoundSyncException:
remote_obj.delete()
connection = self.create(obj)
return connection, True
except DatabaseError as exc:
self.logger.warning("Failed to write object", obj=obj, exc=exc)
if connection:
connection.delete()
return None, False
def delete(self, obj: TModel):
"""Delete object from destination"""
raise NotImplementedError()
def to_schema(self, obj: TModel) -> TSchema:
"""Convert object to destination schema"""
raise NotImplementedError()
def discover(self):
"""Optional method. Can be used to implement a "discovery" where
upon creation of this provider, this function will be called and can
pre-link any users/groups in the remote system with the respective
object in authentik based on a common identifier"""
raise NotImplementedError()
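
To show how the hooks fit together, here is a hedged sketch of a concrete client. MyProvider, MyProviderUser and the provider.api_* helpers are illustrative placeholders, not part of this change; only the BaseOutgoingSyncClient contract itself comes from this file.

from authentik.core.models import User

class MyUserClient(BaseOutgoingSyncClient[User, MyProviderUser, dict, MyProvider]):
    connection_type = MyProviderUser
    connection_type_query = "user"
    can_discover = True

    def to_schema(self, obj: User) -> dict:
        return {"primaryEmail": obj.email, "name": obj.name}

    def create(self, obj: User) -> MyProviderUser:
        remote = self.provider.api_create_user(self.to_schema(obj))
        # persist the link between the authentik user and the remote object
        return MyProviderUser.objects.create(
            provider=self.provider, user=obj, remote_id=remote["id"]
        )

    def update(self, obj: User, connection: MyProviderUser):
        self.provider.api_update_user(connection.remote_id, self.to_schema(obj))

    def delete(self, obj: User):
        connection = MyProviderUser.objects.filter(provider=self.provider, user=obj).first()
        if connection:
            self.provider.api_delete_user(connection.remote_id)
            connection.delete()

    def discover(self):
        # pre-link remote users to authentik users via their email address
        for remote in self.provider.api_list_users():
            user = User.objects.filter(email=remote["primaryEmail"]).first()
            if user:
                MyProviderUser.objects.get_or_create(
                    provider=self.provider, user=user, remote_id=remote["id"]
                )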

@@ -0,0 +1,37 @@
from authentik.lib.sentry import SentryIgnoredException
class BaseSyncException(SentryIgnoredException):
"""Base class for all sync exceptions"""
class TransientSyncException(BaseSyncException):
"""Transient sync exception which may be caused by network blips, etc"""
class NotFoundSyncException(BaseSyncException):
"""Exception when an object was not found in the remote system"""
class ObjectExistsSyncException(BaseSyncException):
"""Exception when an object already exists in the remote system"""
class StopSync(BaseSyncException):
"""Exception raised when a configuration error should stop the sync process"""
def __init__(
self, exc: Exception, obj: object | None = None, mapping: object | None = None
) -> None:
self.exc = exc
self.obj = obj
self.mapping = mapping
def detail(self) -> str:
"""Get human readable details of this error"""
msg = f"Error {str(self.exc)}"
if self.obj:
msg += f", caused by {self.obj}"
if self.mapping:
msg += f" (mapping {self.mapping})"
return msg
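
For illustration, a hedged sketch of how a provider's HTTP layer might map response codes onto this hierarchy, so the task layer can decide between retrying (transient), recreating (not found) and deduplicating (conflict). The helper itself is hypothetical and not part of this change.

from requests import Response

def raise_for_sync_response(response: Response):
    """Translate HTTP status codes into sync exceptions"""
    if response.status_code == 404:
        raise NotFoundSyncException("Object not found in remote system")
    if response.status_code == 409:  # HTTP_CONFLICT
        raise ObjectExistsSyncException("Object already exists in remote system")
    if response.status_code >= 500:
        raise TransientSyncException(f"Transient remote error: {response.status_code}")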

@@ -0,0 +1,32 @@
from typing import Any, Self
from django.core.cache import cache
from django.db.models import Model, QuerySet
from redis.lock import Lock
from authentik.core.models import Group, User
from authentik.lib.sync.outgoing import PAGE_TIMEOUT
from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient
class OutgoingSyncProvider(Model):
class Meta:
abstract = True
def client_for_model[
T: User | Group
](self, model: type[T]) -> BaseOutgoingSyncClient[T, Any, Any, Self]:
raise NotImplementedError
def get_object_qs[T: User | Group](self, type: type[T]) -> QuerySet[T]:
raise NotImplementedError
@property
def sync_lock(self) -> Lock:
"""Redis lock to prevent multiple parallel syncs happening"""
return Lock(
cache.client.get_client(),
name=f"goauthentik.io/providers/outgoing-sync/{str(self.pk)}",
timeout=(60 * 60 * PAGE_TIMEOUT) * 3,
)
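
A hedged sketch of what a concrete provider implements on top of this base. MyProvider, MyUserClient and MyGroupClient are placeholders (the real providers also inherit authentik's backchannel provider base), and the queryset filter is only an example of restricting which objects are synced.

class MyProvider(OutgoingSyncProvider):
    def client_for_model(self, model: type[User | Group]) -> BaseOutgoingSyncClient:
        if issubclass(model, User):
            return MyUserClient(self)
        if issubclass(model, Group):
            return MyGroupClient(self)
        raise ValueError(f"Invalid model {model}")

    def get_object_qs(self, type: type[User | Group]) -> QuerySet[User | Group]:
        if type == User:
            # example restriction: only sync active users
            return User.objects.filter(is_active=True)
        return Group.objects.all()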

@@ -0,0 +1,71 @@
from collections.abc import Callable
from django.core.paginator import Paginator
from django.db.models import Model
from django.db.models.signals import m2m_changed, post_save, pre_delete
from authentik.core.models import Group, User
from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT
from authentik.lib.sync.outgoing.base import Direction
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
from authentik.lib.utils.reflection import class_to_path
def register_signals(
provider_type: type[OutgoingSyncProvider],
task_sync_single: Callable[[int], None],
task_sync_direct: Callable[[int], None],
task_sync_m2m: Callable[[int], None],
):
"""Register sync signals"""
uid = class_to_path(provider_type)
def post_save_provider(sender: type[Model], instance: OutgoingSyncProvider, created: bool, **_):
"""Trigger sync when Provider is saved"""
users_paginator = Paginator(instance.get_object_qs(User), PAGE_SIZE)
groups_paginator = Paginator(instance.get_object_qs(Group), PAGE_SIZE)
soft_time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT
time_limit = soft_time_limit * 1.5
task_sync_single.apply_async(
(instance.pk,), time_limit=int(time_limit), soft_time_limit=int(soft_time_limit)
)
post_save.connect(post_save_provider, provider_type, dispatch_uid=uid, weak=False)
def model_post_save(sender: type[Model], instance: User | Group, created: bool, **_):
"""Post save handler"""
if not provider_type.objects.filter(backchannel_application__isnull=False).exists():
return
task_sync_direct.delay(class_to_path(instance.__class__), instance.pk, Direction.add.value)
post_save.connect(model_post_save, User, dispatch_uid=uid, weak=False)
post_save.connect(model_post_save, Group, dispatch_uid=uid, weak=False)
def model_pre_delete(sender: type[Model], instance: User | Group, **_):
"""Pre-delete handler"""
if not provider_type.objects.filter(backchannel_application__isnull=False).exists():
return
task_sync_direct.delay(
class_to_path(instance.__class__), instance.pk, Direction.remove.value
)
pre_delete.connect(model_pre_delete, User, dispatch_uid=uid, weak=False)
pre_delete.connect(model_pre_delete, Group, dispatch_uid=uid, weak=False)
def model_m2m_changed(
sender: type[Model], instance, action: str, pk_set: set, reverse: bool, **kwargs
):
"""Sync group membership"""
if action not in ["post_add", "post_remove"]:
return
if not provider_type.objects.filter(backchannel_application__isnull=False).exists():
return
        # reverse: instance is a Group, pk_set is a list of user pks
        # non-reverse: instance is a User, pk_set is a list of group pks
if reverse:
task_sync_m2m.delay(str(instance.pk), action, list(pk_set))
else:
for group_pk in pk_set:
task_sync_m2m.delay(group_pk, action, [instance.pk])
m2m_changed.connect(model_m2m_changed, User.ak_groups.through, dispatch_uid=uid, weak=False)
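
A provider package then wires its celery tasks to these signals once at import time, typically from its AppConfig.ready(). The MyProvider and my_provider_* names below are hypothetical stand-ins for the provider's actual tasks.

from authentik.lib.sync.outgoing.signals import register_signals

register_signals(
    MyProvider,
    task_sync_single=my_provider_sync,
    task_sync_direct=my_provider_sync_direct,
    task_sync_m2m=my_provider_sync_m2m,
)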

@@ -0,0 +1,215 @@
from collections.abc import Callable
from celery.result import allow_join_result
from django.core.paginator import Paginator
from django.db.models import Model, QuerySet
from django.utils.text import slugify
from django.utils.translation import gettext_lazy as _
from structlog.stdlib import BoundLogger, get_logger
from authentik.core.expression.exceptions import SkipObjectException
from authentik.core.models import Group, User
from authentik.events.logs import LogEvent
from authentik.events.models import TaskStatus
from authentik.events.system_tasks import SystemTask
from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT
from authentik.lib.sync.outgoing.base import Direction
from authentik.lib.sync.outgoing.exceptions import StopSync, TransientSyncException
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
from authentik.lib.utils.reflection import class_to_path, path_to_class
class SyncTasks:
"""Container for all sync 'tasks' (this class doesn't actually contain celery
tasks due to celery's magic, however exposes a number of functions to be called from tasks)"""
logger: BoundLogger
def __init__(self, provider_model: type[OutgoingSyncProvider]) -> None:
super().__init__()
self._provider_model = provider_model
def sync_all(self, single_sync: Callable[[int], None]):
for provider in self._provider_model.objects.filter(backchannel_application__isnull=False):
self.trigger_single_task(provider, single_sync)
def trigger_single_task(self, provider: OutgoingSyncProvider, sync_task: Callable[[int], None]):
"""Wrapper single sync task that correctly sets time limits based
on the amount of objects that will be synced"""
users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE)
groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE)
soft_time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT
time_limit = soft_time_limit * 1.5
return sync_task.apply_async(
(provider.pk,), time_limit=int(time_limit), soft_time_limit=int(soft_time_limit)
)
def sync_single(
self,
task: SystemTask,
provider_pk: int,
sync_objects: Callable[[int, int], list[str]],
):
self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model),
provider_pk=provider_pk,
)
provider = self._provider_model.objects.filter(
pk=provider_pk, backchannel_application__isnull=False
).first()
if not provider:
return
lock = provider.sync_lock
if lock.locked():
self.logger.debug("Sync locked, skipping task", source=provider.name)
return
task.set_uid(slugify(provider.name))
messages = []
messages.append(_("Starting full provider sync"))
self.logger.debug("Starting provider sync")
users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE)
groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE)
with allow_join_result(), lock:
try:
for page in users_paginator.page_range:
messages.append(_("Syncing page %(page)d of users" % {"page": page}))
for msg in sync_objects.apply_async(
args=(class_to_path(User), page, provider_pk),
time_limit=PAGE_TIMEOUT,
soft_time_limit=PAGE_TIMEOUT,
).get():
messages.append(msg)
for page in groups_paginator.page_range:
messages.append(_("Syncing page %(page)d of groups" % {"page": page}))
for msg in sync_objects.apply_async(
args=(class_to_path(Group), page, provider_pk),
time_limit=PAGE_TIMEOUT,
soft_time_limit=PAGE_TIMEOUT,
).get():
messages.append(msg)
except TransientSyncException as exc:
self.logger.warning("transient sync exception", exc=exc)
raise task.retry(exc=exc) from exc
except StopSync as exc:
task.set_error(exc)
return
task.set_status(TaskStatus.SUCCESSFUL, *messages)
def sync_objects(self, object_type: str, page: int, provider_pk: int):
_object_type = path_to_class(object_type)
self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model),
provider_pk=provider_pk,
object_type=object_type,
)
messages = []
provider = self._provider_model.objects.filter(pk=provider_pk).first()
if not provider:
return messages
try:
client = provider.client_for_model(_object_type)
except TransientSyncException:
return messages
paginator = Paginator(provider.get_object_qs(_object_type), PAGE_SIZE)
if client.can_discover:
self.logger.debug("starting discover")
client.discover()
self.logger.debug("starting sync for page", page=page)
for obj in paginator.page(page).object_list:
obj: Model
try:
client.write(obj)
except SkipObjectException:
continue
except TransientSyncException as exc:
self.logger.warning("failed to sync object", exc=exc, user=obj)
messages.append(
LogEvent(
_(
(
"Failed to sync {object_type} {object_name} "
"due to transient error: {error}"
).format_map(
{
"object_type": obj._meta.verbose_name,
"object_name": str(obj),
"error": str(exc),
}
)
),
log_level="warning",
logger="",
)
)
except StopSync as exc:
self.logger.warning("Stopping sync", exc=exc)
messages.append(
LogEvent(
_(
"Stopping sync due to error: {error}".format_map(
{
"error": exc.detail(),
}
)
),
log_level="warning",
logger="",
)
)
break
return messages
def sync_signal_direct(self, model: str, pk: str | int, raw_op: str):
self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model),
)
model_class: type[Model] = path_to_class(model)
instance = model_class.objects.filter(pk=pk).first()
if not instance:
return
operation = Direction(raw_op)
for provider in self._provider_model.objects.filter(backchannel_application__isnull=False):
client = provider.client_for_model(instance.__class__)
# Check if the object is allowed within the provider's restrictions
queryset = provider.get_object_qs(instance.__class__)
if not queryset:
continue
            # The queryset we get from the provider must include the instance we've been given;
            # otherwise, ignore this provider
if not queryset.filter(pk=instance.pk).exists():
continue
try:
if operation == Direction.add:
client.write(instance)
if operation == Direction.remove:
client.delete(instance)
except (StopSync, TransientSyncException) as exc:
self.logger.warning(exc, provider_pk=provider.pk)
def sync_signal_m2m(self, group_pk: str, action: str, pk_set: list[int]):
self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model),
)
group = Group.objects.filter(pk=group_pk).first()
if not group:
return
for provider in self._provider_model.objects.filter(backchannel_application__isnull=False):
# Check if the object is allowed within the provider's restrictions
queryset: QuerySet = provider.get_object_qs(Group)
            # The queryset we get from the provider must include the group we've been given;
            # otherwise, ignore this provider
if not queryset.filter(pk=group_pk).exists():
continue
client = provider.client_for_model(Group)
try:
operation = None
if action == "post_add":
operation = Direction.add
if action == "post_remove":
operation = Direction.remove
client.update_group(group, operation, pk_set)
except (StopSync, TransientSyncException) as exc:
self.logger.warning(exc, provider_pk=provider.pk)
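
Since SyncTasks is deliberately celery-free, each provider package declares thin celery tasks that delegate to a shared instance. A hedged sketch of that wiring follows; MyProvider and the my_provider_* task names are hypothetical, while CELERY_APP and SystemTask are existing authentik primitives.

from authentik.events.system_tasks import SystemTask
from authentik.root.celery import CELERY_APP

sync_tasks = SyncTasks(MyProvider)


@CELERY_APP.task()
def my_provider_sync_objects(*args, **kwargs):
    return sync_tasks.sync_objects(*args, **kwargs)


@CELERY_APP.task(base=SystemTask, bind=True)
def my_provider_sync(self, provider_pk: int, *args, **kwargs):
    """Full sync of a single provider"""
    return sync_tasks.sync_single(self, provider_pk, my_provider_sync_objects)


@CELERY_APP.task()
def my_provider_sync_all():
    return sync_tasks.sync_all(my_provider_sync)


@CELERY_APP.task()
def my_provider_sync_direct(*args, **kwargs):
    return sync_tasks.sync_signal_direct(*args, **kwargs)


@CELERY_APP.task()
def my_provider_sync_m2m(*args, **kwargs):
    return sync_tasks.sync_signal_m2m(*args, **kwargs)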