uhhhhhhhhh
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
@@ -45,11 +45,17 @@ class OutgoingSyncProvider(ScheduledModel, Model):
     def get_object_qs[T: User | Group](self, type: type[T]) -> QuerySet[T]:
        raise NotImplementedError

+    def get_paginator[T: User | Group](self, type: type[T]) -> Paginator:
+        return Paginator(self.get_object_qs(type), PAGE_SIZE)
+
+    def get_object_sync_time_limit[T: User | Group](self, type: type[T]) -> int:
+        num_pages: int = self.get_paginator(type).num_pages
+        return int(num_pages * PAGE_TIMEOUT * 1.5) * 1000
+
     def get_sync_time_limit(self) -> int:
-        users_paginator = Paginator(self.get_object_qs(User), PAGE_SIZE)
-        groups_paginator = Paginator(self.get_object_qs(Group), PAGE_SIZE)
-        time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT * 1.5
-        return int(time_limit)
+        return int(
+            self.get_object_sync_time_limit(User) + self.get_object_sync_time_limit(Group) * 1.5
+        )

     @property
     def sync_lock(self) -> pglock.advisory:
@@ -72,7 +78,7 @@ class OutgoingSyncProvider(ScheduledModel, Model):
                 uid=self.pk,
                 args=(self.pk,),
                 options={
-                    "time_limit": self.get_sync_time_limit() * 1000,
+                    "time_limit": self.get_sync_time_limit(),
                 },
                 send_on_save=True,
                 crontab=f"{fqdn_rand(self.pk)} */4 * * *",
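Note: a rough worked example of the per-object-type time limit computed above. The PAGE_SIZE and PAGE_TIMEOUT values below are illustrative assumptions, not the project's actual constants; the arithmetic mirrors get_object_sync_time_limit (1.5x headroom per page, result in milliseconds).

# Standalone sketch of the time-limit math, values assumed for illustration only.
from math import ceil

PAGE_SIZE = 100      # objects per page (assumed)
PAGE_TIMEOUT = 60    # seconds allowed per page (assumed)

def object_sync_time_limit_ms(object_count: int) -> int:
    num_pages = max(1, ceil(object_count / PAGE_SIZE))
    # 1.5x headroom per page, converted to milliseconds
    return int(num_pages * PAGE_TIMEOUT * 1.5) * 1000

# e.g. 2_500 users -> 25 pages -> 2_250 s -> 2_250_000 ms
print(object_sync_time_limit_ms(2_500))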
@@ -1,21 +1,19 @@
-from collections.abc import Callable
 from dataclasses import asdict

-from celery import group
-from celery.exceptions import Retry
-from celery.result import allow_join_result
+from dramatiq.composition import group
 from django.core.paginator import Paginator
 from django.db.models import Model, QuerySet
 from django.db.models.query import Q
 from django.utils.text import slugify
 from django.utils.translation import gettext_lazy as _
+from dramatiq.actor import Actor
+from dramatiq.errors import Retry
 from structlog.stdlib import BoundLogger, get_logger

 from authentik.core.expression.exceptions import SkipObjectException
 from authentik.core.models import Group, User
 from authentik.events.logs import LogEvent
 from authentik.events.models import TaskStatus
-from authentik.events.system_tasks import SystemTask
 from authentik.events.utils import sanitize_item
 from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT
 from authentik.lib.sync.outgoing.base import Direction
@@ -27,11 +25,13 @@ from authentik.lib.sync.outgoing.exceptions import (
 )
 from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
 from authentik.lib.utils.reflection import class_to_path, path_to_class
+from authentik.tasks.middleware import CurrentTask
+from authentik.tasks.models import Task


 class SyncTasks:
-    """Container for all sync 'tasks' (this class doesn't actually contain celery
-    tasks due to celery's magic, however exposes a number of functions to be called from tasks)"""
+    """Container for all sync 'tasks' (this class doesn't actually contain
+    tasks due to dramatiq's magic, however exposes a number of functions to be called from tasks)"""

     logger: BoundLogger

@@ -39,107 +39,97 @@ class SyncTasks:
         super().__init__()
         self._provider_model = provider_model

-    def sync_all(self, single_sync: Callable[[int], None]):
-        for provider in self._provider_model.objects.filter(
-            Q(backchannel_application__isnull=False) | Q(application__isnull=False)
-        ):
-            self.trigger_single_task(provider, single_sync)
-
-    def trigger_single_task(self, provider: OutgoingSyncProvider, sync_task: Callable[[int], None]):
-        """Wrapper single sync task that correctly sets time limits based
-        on the amount of objects that will be synced"""
-        users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE)
-        groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE)
-        soft_time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT
-        time_limit = soft_time_limit * 1.5
-        return sync_task.apply_async(
-            (provider.pk,), time_limit=int(time_limit), soft_time_limit=int(soft_time_limit)
-        )
-
-    def sync_single(
+    def sync_paginator(
         self,
-        task: SystemTask,
-        provider_pk: int,
-        sync_objects: Callable[[int, int], list[str]],
+        current_task: Task,
+        provider: OutgoingSyncProvider,
+        sync_objects: Actor,
+        paginator: Paginator,
+        object_type: type[User | Group],
+        **options,
     ):
+        tasks = []
+        for page in paginator.page_range:
+            page_sync = sync_objects.message_with_options(
+                args=(class_to_path(object_type), page, provider.pk),
+                time_limit=PAGE_TIMEOUT * 1000,
+                # Assign tasks to the same schedule as the current one
+                rel_obj=current_task.rel_obj,
+                **options,
+            )
+            tasks.append(page_sync)
+        return tasks
+
+    def sync(
+        self,
+        provider_pk: int,
+        sync_objects: Actor,
+    ):
+        task = CurrentTask.get_task()
         self.logger = get_logger().bind(
             provider_type=class_to_path(self._provider_model),
             provider_pk=provider_pk,
         )
-        provider = self._provider_model.objects.filter(
+        provider: OutgoingSyncProvider = self._provider_model.objects.filter(
             Q(backchannel_application__isnull=False) | Q(application__isnull=False),
             pk=provider_pk,
         ).first()
         if not provider:
             return
         task.set_uid(slugify(provider.name))
-        messages = []
-        messages.append(_("Starting full provider sync"))
+        task.info("Starting full provider sync")
         self.logger.debug("Starting provider sync")
-        users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE)
-        groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE)
-        with allow_join_result(), provider.sync_lock as lock_acquired:
+        with provider.sync_lock as lock_acquired:
             if not lock_acquired:
                 self.logger.debug("Failed to acquire sync lock, skipping", provider=provider.name)
                 return
             try:
-                messages.append(_("Syncing users"))
-                user_results = (
-                    group(
-                        [
-                            sync_objects.signature(
-                                args=(class_to_path(User), page, provider_pk),
-                                time_limit=PAGE_TIMEOUT,
-                                soft_time_limit=PAGE_TIMEOUT,
-                            )
-                            for page in users_paginator.page_range
-                        ]
+                users_tasks = group(
+                    self.sync_paginator(
+                        current_task=task,
+                        provider=provider,
+                        sync_objects=sync_objects,
+                        paginator=provider.get_paginator(User),
+                        object_type=User,
                     )
-                    .apply_async()
-                    .get()
                 )
-                for result in user_results:
-                    for msg in result:
-                        messages.append(LogEvent(**msg))
-                messages.append(_("Syncing groups"))
-                group_results = (
-                    group(
-                        [
-                            sync_objects.signature(
-                                args=(class_to_path(Group), page, provider_pk),
-                                time_limit=PAGE_TIMEOUT,
-                                soft_time_limit=PAGE_TIMEOUT,
-                            )
-                            for page in groups_paginator.page_range
-                        ]
+                group_tasks = group(
+                    self.sync_paginator(
+                        current_task=task,
+                        provider=provider,
+                        sync_objects=sync_objects,
+                        paginator=provider.get_paginator(Group),
+                        object_type=Group,
                     )
-                    .apply_async()
-                    .get()
                 )
-                for result in group_results:
-                    for msg in result:
-                        messages.append(LogEvent(**msg))
+                users_tasks.run().wait(timeout=provider.get_object_sync_time_limit(User))
+                group_tasks.run().wait(timeout=provider.get_object_sync_time_limit(Group))
             except TransientSyncException as exc:
                 self.logger.warning("transient sync exception", exc=exc)
-                raise task.retry(exc=exc) from exc
+                raise Retry from exc
             except StopSync as exc:
-                task.set_error(exc)
+                task.error(exc)
                 return
-        task.set_status(TaskStatus.SUCCESSFUL, *messages)

     def sync_objects(
-        self, object_type: str, page: int, provider_pk: int, override_dry_run=False, **filter
+        self,
+        object_type: str,
+        page: int,
+        provider_pk: int,
+        override_dry_run=False,
+        **filter,
     ):
+        task = CurrentTask.get_task()
         _object_type: type[Model] = path_to_class(object_type)
         self.logger = get_logger().bind(
             provider_type=class_to_path(self._provider_model),
             provider_pk=provider_pk,
             object_type=object_type,
         )
-        messages = []
+        task.info(f"Syncing page {page} of {_object_type._meta.verbose_name_plural}")
         provider = self._provider_model.objects.filter(pk=provider_pk).first()
         if not provider:
-            return messages
+            return
         # Override dry run mode if requested, however don't save the provider
         # so that scheduled sync tasks still run in dry_run mode
         if override_dry_run:
@@ -147,25 +137,13 @@ class SyncTasks:
         try:
             client = provider.client_for_model(_object_type)
         except TransientSyncException:
-            return messages
+            return
         paginator = Paginator(provider.get_object_qs(_object_type).filter(**filter), PAGE_SIZE)
         if client.can_discover:
             self.logger.debug("starting discover")
             client.discover()
         self.logger.debug("starting sync for page", page=page)
-        messages.append(
-            asdict(
-                LogEvent(
-                    _(
-                        "Syncing page {page} of {object_type}".format(
-                            page=page, object_type=_object_type._meta.verbose_name_plural
-                        )
-                    ),
-                    log_level="info",
-                    logger=f"{provider._meta.verbose_name}@{object_type}",
-                )
-            )
-        )
+        task.info(f"Syncing page {page} of {_object_type._meta.verbose_name_plural}")
         for obj in paginator.page(page).object_list:
             obj: Model
             try:
@@ -174,87 +152,34 @@ class SyncTasks:
                 self.logger.debug("skipping object due to SkipObject", obj=obj)
                 continue
             except DryRunRejected as exc:
-                messages.append(
-                    asdict(
-                        LogEvent(
-                            _("Dropping mutating request due to dry run"),
-                            log_level="info",
-                            logger=f"{provider._meta.verbose_name}@{object_type}",
-                            attributes={
-                                "obj": sanitize_item(obj),
-                                "method": exc.method,
-                                "url": exc.url,
-                                "body": exc.body,
-                            },
-                        )
-                    )
+                task.info(
+                    "Dropping mutating request due to dry run",
+                    attributes={
+                        "obj": sanitize_item(obj),
+                        "method": exc.method,
+                        "url": exc.url,
+                        "body": exc.body,
+                    },
                 )
             except BadRequestSyncException as exc:
                 self.logger.warning("failed to sync object", exc=exc, obj=obj)
-                messages.append(
-                    asdict(
-                        LogEvent(
-                            _(
-                                (
-                                    "Failed to sync {object_type} {object_name} "
-                                    "due to error: {error}"
-                                ).format_map(
-                                    {
-                                        "object_type": obj._meta.verbose_name,
-                                        "object_name": str(obj),
-                                        "error": str(exc),
-                                    }
-                                )
-                            ),
-                            log_level="warning",
-                            logger=f"{provider._meta.verbose_name}@{object_type}",
-                            attributes={"arguments": exc.args[1:], "obj": sanitize_item(obj)},
-                        )
-                    )
+                task.warning(
+                    f"Failed to sync {obj._meta.verbose_name} {str(obj)} due to error: {str(exc)}",
+                    attributes={"arguments": exc.args[1:], "obj": sanitize_item(obj)},
                 )
             except TransientSyncException as exc:
                 self.logger.warning("failed to sync object", exc=exc, user=obj)
-                messages.append(
-                    asdict(
-                        LogEvent(
-                            _(
-                                (
-                                    "Failed to sync {object_type} {object_name} "
-                                    "due to transient error: {error}"
-                                ).format_map(
-                                    {
-                                        "object_type": obj._meta.verbose_name,
-                                        "object_name": str(obj),
-                                        "error": str(exc),
-                                    }
-                                )
-                            ),
-                            log_level="warning",
-                            logger=f"{provider._meta.verbose_name}@{object_type}",
-                            attributes={"obj": sanitize_item(obj)},
-                        )
-                    )
+                task.warning(
+                    f"Failed to sync {obj._meta.verbose_name} {str(obj)} due to transient error: {str(exc)}",
+                    attributes={"obj": sanitize_item(obj)},
                 )
             except StopSync as exc:
                 self.logger.warning("Stopping sync", exc=exc)
-                messages.append(
-                    asdict(
-                        LogEvent(
-                            _(
-                                "Stopping sync due to error: {error}".format_map(
-                                    {
-                                        "error": exc.detail(),
-                                    }
-                                )
-                            ),
-                            log_level="warning",
-                            logger=f"{provider._meta.verbose_name}@{object_type}",
-                            attributes={"obj": sanitize_item(obj)},
-                        )
-                    )
+                task.warning(
+                    f"Stopping sync due to error: {exc.detail()}",
+                    attributes={"obj": sanitize_item(obj)},
                 )
                 break
-        return messages

     def sync_signal_direct(self, model: str, pk: str | int, raw_op: str):
         self.logger = get_logger().bind(
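Note: a minimal sketch of the fan-out pattern that sync_paginator and sync now use, assuming a dramatiq broker and result backend are configured. The actor, page count, and timeout below are illustrative and not part of the commit.

# Sketch only: one message per page, enqueued together as a dramatiq group.
import dramatiq
from dramatiq.composition import group

@dramatiq.actor
def sync_page(object_type: str, page: int, provider_pk: int):
    # Placeholder body; the real actor is the provider's sync_objects task.
    print(f"syncing {object_type} page {page} for provider {provider_pk}")

def fan_out(object_type: str, num_pages: int, provider_pk: int, page_timeout_ms: int):
    messages = [
        sync_page.message_with_options(
            args=(object_type, page, provider_pk),
            time_limit=page_timeout_ms,
        )
        for page in range(1, num_pages + 1)
    ]
    # run() enqueues every page; wait() blocks until all of them finish,
    # which requires a result backend (as the callers above assume).
    group(messages).run().wait(timeout=page_timeout_ms * num_pages)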
@@ -130,7 +130,8 @@ def outpost_controller(outpost_pk: str, action: str = "up", from_cache: bool = F
     else:
         if from_cache:
             cache.delete(CACHE_KEY_OUTPOST_DOWN % outpost_pk)
-        self.info(*logs)
+        for log in logs:
+            self.info(log)


 @actor
@@ -85,6 +85,6 @@ class SCIMClientTests(TestCase):
         self.assertEqual(mock.call_count, 1)
         self.assertEqual(mock.request_history[0].method, "GET")

-    def test_scim_sync_all(self):
-        """test scim_sync_all task"""
+    def test_scim_sync(self):
+        """test scim_sync task"""
         scim_sync.send(self.provider.pk).get_result()
@@ -23,12 +23,14 @@ from authentik.sources.kerberos.models import (
     Krb5ConfContext,
     UserKerberosSourceConnection,
 )
+from authentik.tasks.models import Task


 class KerberosSync:
     """Sync Kerberos users into authentik"""

     _source: KerberosSource
+    _task: Task
     _logger: BoundLogger
     _connection: KAdmin
     mapper: SourceMapper
@@ -36,11 +38,11 @@ class KerberosSync:
     group_manager: PropertyMappingManager
     matcher: SourceMatcher

-    def __init__(self, source: KerberosSource):
+    def __init__(self, source: KerberosSource, task: Task):
         self._source = source
+        self._task = task
         with Krb5ConfContext(self._source):
             self._connection = self._source.connection()
-        self._messages = []
         self._logger = get_logger().bind(source=self._source, syncer=self.__class__.__name__)
         self.mapper = SourceMapper(self._source)
         self.user_manager = self.mapper.get_manager(User, ["principal", "principal_obj"])
@@ -56,17 +58,6 @@ class KerberosSync:
         """UI name for the type of object this class synchronizes"""
         return "users"

-    @property
-    def messages(self) -> list[str]:
-        """Get all UI messages"""
-        return self._messages
-
-    def message(self, *args, **kwargs):
-        """Add message that is later added to the System Task and shown to the user"""
-        formatted_message = " ".join(args)
-        self._messages.append(formatted_message)
-        self._logger.warning(*args, **kwargs)
-
     def _handle_principal(self, principal: str) -> bool:
         try:
             # TODO: handle permission error
@@ -163,7 +154,7 @@ class KerberosSync:
     def sync(self) -> int:
         """Iterate over all Kerberos users and create authentik_core.User instances"""
         if not self._source.enabled or not self._source.sync_users:
-            self.message("Source is disabled or user syncing is disabled for this Source")
+            self._task.info("Source is disabled or user syncing is disabled for this Source")
             return -1

         user_count = 0
@@ -43,7 +43,6 @@ def kerberos_sync(pk: str):
             return
         syncer = KerberosSync(source)
         syncer.sync()
-        self.info(*syncer.messages)
     except StopSync as exc:
         LOGGER.warning(exception_to_string(exc))
         self.error(exc)
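Note: with the message buffer gone, the synchronizer reports through the Task it is handed. The exact call-site wiring is not part of this excerpt; the sketch below is an assumption about how a caller would construct KerberosSync under the new two-argument signature, using the CurrentTask helper added elsewhere in this commit (module path for KerberosSync assumed).

# Hedged sketch: pass the currently running Task into the synchronizer.
from authentik.sources.kerberos.sync import KerberosSync  # path assumed
from authentik.tasks.middleware import CurrentTask

def run_kerberos_sync(source):
    task = CurrentTask.get_task()          # Task row backing the current actor run
    syncer = KerberosSync(source, task)    # new signature: (source, task)
    return syncer.sync()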
@@ -10,22 +10,23 @@ from authentik.core.sources.mapper import SourceMapper
 from authentik.lib.config import CONFIG
 from authentik.lib.sync.mapper import PropertyMappingManager
 from authentik.sources.ldap.models import LDAPSource, flatten
+from authentik.tasks.models import Task


 class BaseLDAPSynchronizer:
     """Sync LDAP Users and groups into authentik"""

     _source: LDAPSource
+    _task: Task
     _logger: BoundLogger
     _connection: Connection
-    _messages: list[str]
     mapper: SourceMapper
     manager: PropertyMappingManager

-    def __init__(self, source: LDAPSource):
+    def __init__(self, source: LDAPSource, task: Task):
         self._source = source
+        self._task = task
         self._connection = source.connection()
-        self._messages = []
         self._logger = get_logger().bind(source=source, syncer=self.__class__.__name__)

     @staticmethod
@@ -46,11 +47,6 @@ class BaseLDAPSynchronizer:
         """Sync function, implemented in subclass"""
         raise NotImplementedError()

-    @property
-    def messages(self) -> list[str]:
-        """Get all UI messages"""
-        return self._messages
-
     @property
     def base_dn_users(self) -> str:
         """Shortcut to get full base_dn for user lookups"""
@@ -65,14 +61,6 @@ class BaseLDAPSynchronizer:
             return f"{self._source.additional_group_dn},{self._source.base_dn}"
         return self._source.base_dn

-    def message(self, *args, **kwargs):
-        """Add message that is later added to the System Task and shown to the user"""
-        formatted_message = " ".join(args)
-        if "dn" in kwargs:
-            formatted_message += f"; DN: {kwargs['dn']}"
-        self._messages.append(formatted_message)
-        self._logger.warning(*args, **kwargs)
-
     def get_objects(self, **kwargs) -> Generator:
         """Get objects from LDAP, implemented in subclass"""
         raise NotImplementedError()
@@ -19,7 +19,7 @@ class GroupLDAPForwardDeletion(BaseLDAPSynchronizer):

     def get_objects(self, **kwargs) -> Generator:
         if not self._source.sync_groups or not self._source.delete_not_found_objects:
-            self.message("Group syncing is disabled for this Source")
+            self._task.info("Group syncing is disabled for this Source")
             return iter(())

         uuid = uuid4()
@@ -54,7 +54,7 @@ class GroupLDAPForwardDeletion(BaseLDAPSynchronizer):
     def sync(self, group_pks: tuple) -> int:
         """Delete authentik groups"""
         if not self._source.sync_groups or not self._source.delete_not_found_objects:
-            self.message("Group syncing is disabled for this Source")
+            self._task.info("Group syncing is disabled for this Source")
             return -1
         self._logger.debug("Deleting groups", group_pks=group_pks)
         _, deleted_per_type = Group.objects.filter(pk__in=group_pks).delete()
@@ -21,7 +21,7 @@ class UserLDAPForwardDeletion(BaseLDAPSynchronizer):

     def get_objects(self, **kwargs) -> Generator:
         if not self._source.sync_users or not self._source.delete_not_found_objects:
-            self.message("User syncing is disabled for this Source")
+            self._task.info("User syncing is disabled for this Source")
             return iter(())

         uuid = uuid4()
@@ -56,7 +56,7 @@ class UserLDAPForwardDeletion(BaseLDAPSynchronizer):
     def sync(self, user_pks: tuple) -> int:
         """Delete authentik users"""
         if not self._source.sync_users or not self._source.delete_not_found_objects:
-            self.message("User syncing is disabled for this Source")
+            self._task.info("User syncing is disabled for this Source")
             return -1
         self._logger.debug("Deleting users", user_pks=user_pks)
         _, deleted_per_type = User.objects.filter(pk__in=user_pks).delete()
@@ -37,7 +37,7 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer):

     def get_objects(self, **kwargs) -> Generator:
         if not self._source.sync_groups:
-            self.message("Group syncing is disabled for this Source")
+            self._task.info("Group syncing is disabled for this Source")
             return iter(())
         return self.search_paginator(
             search_base=self.base_dn_groups,
@@ -54,7 +54,7 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer):
     def sync(self, page_data: list) -> int:
         """Iterate over all LDAP Groups and create authentik_core.Group instances"""
         if not self._source.sync_groups:
-            self.message("Group syncing is disabled for this Source")
+            self._task.info("Group syncing is disabled for this Source")
             return -1
         group_count = 0
         for group in page_data:
@@ -62,7 +62,7 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer):
                 continue
             group_dn = flatten(flatten(group.get("entryDN", group.get("dn"))))
             if not (uniq := self.get_identifier(attributes)):
-                self.message(
+                self._task.info(
                     f"Uniqueness field not found/not set in attributes: '{group_dn}'",
                     attributes=attributes.keys(),
                     dn=group_dn,
@@ -26,7 +26,7 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer):

     def get_objects(self, **kwargs) -> Generator:
         if not self._source.sync_groups:
-            self.message("Group syncing is disabled for this Source")
+            self._task.info("Group syncing is disabled for this Source")
             return iter(())

         # If we are looking up groups from users, we don't need to fetch the group membership field
@@ -45,7 +45,7 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer):
     def sync(self, page_data: list) -> int:
         """Iterate over all Users and assign Groups using memberOf Field"""
         if not self._source.sync_groups:
-            self.message("Group syncing is disabled for this Source")
+            self._task.info("Group syncing is disabled for this Source")
             return -1
         membership_count = 0
         for group in page_data:
@@ -94,7 +94,7 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer):
             # group_uniq might be a single string or an array with (hopefully) a single string
             if isinstance(group_uniq, list):
                 if len(group_uniq) < 1:
-                    self.message(
+                    self._task.info(
                         f"Group does not have a uniqueness attribute: '{group_dn}'",
                         group=group_dn,
                     )
@@ -104,7 +104,7 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer):
             groups = Group.objects.filter(**{f"attributes__{LDAP_UNIQUENESS}": group_uniq})
             if not groups.exists():
                 if self._source.sync_groups:
-                    self.message(
+                    self._task.info(
                         f"Group does not exist in our DB yet, run sync_groups first: '{group_dn}'",
                         group=group_dn,
                     )
@@ -39,7 +39,7 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer):

     def get_objects(self, **kwargs) -> Generator:
         if not self._source.sync_users:
-            self.message("User syncing is disabled for this Source")
+            self._task.info("User syncing is disabled for this Source")
             return iter(())
         return self.search_paginator(
             search_base=self.base_dn_users,
@@ -56,7 +56,7 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer):
     def sync(self, page_data: list) -> int:
         """Iterate over all LDAP Users and create authentik_core.User instances"""
         if not self._source.sync_users:
-            self.message("User syncing is disabled for this Source")
+            self._task.info("User syncing is disabled for this Source")
             return -1
         user_count = 0
         for user in page_data:
@@ -64,7 +64,7 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer):
                 continue
             user_dn = flatten(user.get("entryDN", user.get("dn")))
             if not (uniq := self.get_identifier(attributes)):
-                self.message(
+                self._task.info(
                     f"Uniqueness field not found/not set in attributes: '{user_dn}'",
                     attributes=attributes.keys(),
                     dn=user_dn,
@@ -30,7 +30,7 @@ class FreeIPA(BaseLDAPSynchronizer):
         pwd_last_set: datetime = attributes.get("krbLastPwdChange", datetime.now())
         pwd_last_set = pwd_last_set.replace(tzinfo=UTC)
         if created or pwd_last_set >= user.password_change_date:
-            self.message(f"'{user.username}': Reset user's password")
+            self._task.info(f"'{user.username}': Reset user's password")
             self._logger.debug(
                 "Reset user's password",
                 user=user.username,
authentik/sources/ldap/sync/vendor/ms_ad.py (vendored)
@@ -60,7 +60,7 @@ class MicrosoftActiveDirectory(BaseLDAPSynchronizer):
         pwd_last_set: datetime = attributes.get("pwdLastSet", datetime.now())
         pwd_last_set = pwd_last_set.replace(tzinfo=UTC)
         if created or pwd_last_set >= user.password_change_date:
-            self.message(f"'{user.username}': Reset user's password")
+            self._task.info(f"'{user.username}': Reset user's password")
             self._logger.debug(
                 "Reset user's password",
                 user=user.username,
@@ -72,14 +72,12 @@ def ldap_sync(source_pk: str):
     )

     # User and group sync can happen at once, they have no dependencies on each other
-    user_group_tasks.run().get_results(
-        block=True,
-        timeout=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours") * 1000,
+    user_group_tasks.run().wait(
+        timeout=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours") * 1000
     )
     # Membership sync needs to run afterwards
-    membership_tasks.run().get_results(
-        block=True,
-        timeout=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours") * 1000,
+    membership_tasks.run().wait(
+        timeout=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours") * 1000
     )
     # Finally, deletions. What we'd really like to do here is something like
     # ```
@@ -96,7 +94,9 @@ def ldap_sync(source_pk: str):
     # large chunks, and only queue the deletion step afterwards.
     # 3. Delete every unmarked item. This is slow, so we spread it over many tasks in
     # small chunks.
-    deletion_tasks.run()  # no need to block here, we don't have anything else to do afterwards
+    deletion_tasks.run().wait(
+        timeout=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours") * 1000,
+    )


 def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) -> list[Message]:
@@ -139,9 +139,7 @@ def ldap_sync_page(source_pk: str, sync_class: str, page_cache_key: str):
             return
         cache.touch(page_cache_key)
         count = sync_inst.sync(page)
-        messages = sync_inst.messages
-        messages.append(f"Synced {count} objects.")
-        self.info(*messages)
+        self.info(f"Synced {count} objects.")
         cache.delete(page_cache_key)
     except (LDAPException, StopSync) as exc:
         # No explicit event is created here as .set_status with an error will do that
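Note: the wait() timeouts above are expressed in milliseconds, hence the trailing * 1000. A quick worked example, assuming ldap.task_timeout_hours is set to 2 (an assumption; check the config default):

# Worked example of the timeout handed to .wait(); units are milliseconds.
task_timeout_hours = 2  # assumed value of CONFIG.get_int("ldap.task_timeout_hours")
timeout_ms = 60 * 60 * task_timeout_hours * 1000
assert timeout_ms == 7_200_000  # 2 hours expressed in ms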
@@ -13,6 +13,7 @@ from authentik.sources.ldap.models import LDAPSource, LDAPSourcePropertyMapping
 from authentik.sources.ldap.sync.users import UserLDAPSynchronizer
 from authentik.sources.ldap.tests.mock_ad import mock_ad_connection
 from authentik.sources.ldap.tests.mock_slapd import mock_slapd_connection
+from authentik.tasks.models import Task

 LDAP_PASSWORD = generate_key()

@@ -43,7 +44,7 @@ class LDAPSyncTests(TestCase):
         raw_conn.bind = bind_mock
         connection = MagicMock(return_value=raw_conn)
         with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
-            user_sync = UserLDAPSynchronizer(self.source)
+            user_sync = UserLDAPSynchronizer(self.source, Task())
             user_sync.sync_full()

         user = User.objects.get(username="user0_sn")
@@ -71,7 +72,7 @@ class LDAPSyncTests(TestCase):
         )
         connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD))
         with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
-            user_sync = UserLDAPSynchronizer(self.source)
+            user_sync = UserLDAPSynchronizer(self.source, Task())
             user_sync.sync_full()

         user = User.objects.get(username="user0_sn")
@@ -98,7 +99,7 @@ class LDAPSyncTests(TestCase):
         self.source.save()
         connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD))
         with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
-            user_sync = UserLDAPSynchronizer(self.source)
+            user_sync = UserLDAPSynchronizer(self.source, Task())
             user_sync.sync_full()

         user = User.objects.get(username="user0_sn")
@@ -23,7 +23,7 @@ from authentik.sources.ldap.sync.forward_delete_users import DELETE_CHUNK_SIZE
 from authentik.sources.ldap.sync.groups import GroupLDAPSynchronizer
 from authentik.sources.ldap.sync.membership import MembershipLDAPSynchronizer
 from authentik.sources.ldap.sync.users import UserLDAPSynchronizer
-from authentik.sources.ldap.tasks import ldap_sync, ldap_sync_all, ldap_sync_page
+from authentik.sources.ldap.tasks import ldap_sync, ldap_sync_page
 from authentik.sources.ldap.tests.mock_ad import mock_ad_connection
 from authentik.sources.ldap.tests.mock_freeipa import mock_freeipa_connection
 from authentik.sources.ldap.tests.mock_slapd import (
@@ -33,6 +33,7 @@ from authentik.sources.ldap.tests.mock_slapd import (
     user_in_slapd_cn,
     user_in_slapd_uid,
 )
+from authentik.tasks.models import Task

 LDAP_PASSWORD = generate_key()

@@ -74,7 +75,7 @@ class LDAPSyncTests(TestCase):
         self.source.save()
         connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD))
         with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
-            user_sync = UserLDAPSynchronizer(self.source)
+            user_sync = UserLDAPSynchronizer(self.source, Task())
             with self.assertRaises(StopSync):
                 user_sync.sync_full()
         self.assertFalse(User.objects.filter(username="user0_sn").exists())
@@ -105,7 +106,7 @@ class LDAPSyncTests(TestCase):

         # we basically just test that the mappings don't throw errors
         with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
-            user_sync = UserLDAPSynchronizer(self.source)
+            user_sync = UserLDAPSynchronizer(self.source, Task())
             user_sync.sync_full()

     def test_sync_users_ad(self):
@@ -133,7 +134,7 @@ class LDAPSyncTests(TestCase):
         )

         with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
-            user_sync = UserLDAPSynchronizer(self.source)
+            user_sync = UserLDAPSynchronizer(self.source, Task())
             user_sync.sync_full()
             user = User.objects.filter(username="user0_sn").first()
             self.assertEqual(user.attributes["foo"], "bar")
@@ -152,7 +153,7 @@ class LDAPSyncTests(TestCase):
         )
         connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD))
         with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
-            user_sync = UserLDAPSynchronizer(self.source)
+            user_sync = UserLDAPSynchronizer(self.source, Task())
             user_sync.sync_full()
         self.assertTrue(User.objects.filter(username="user0_sn").exists())
         self.assertFalse(User.objects.filter(username="user1_sn").exists())
@@ -168,7 +169,7 @@ class LDAPSyncTests(TestCase):
         )
         connection = MagicMock(return_value=mock_freeipa_connection(LDAP_PASSWORD))
         with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
-            user_sync = UserLDAPSynchronizer(self.source)
+            user_sync = UserLDAPSynchronizer(self.source, Task())
             user_sync.sync_full()
         self.assertTrue(User.objects.filter(username="user0_sn").exists())
         self.assertFalse(User.objects.filter(username="user1_sn").exists())
@@ -193,11 +194,11 @@ class LDAPSyncTests(TestCase):
         )
         connection = MagicMock(return_value=mock_freeipa_connection(LDAP_PASSWORD))
         with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
-            user_sync = UserLDAPSynchronizer(self.source)
+            user_sync = UserLDAPSynchronizer(self.source, Task())
             user_sync.sync_full()
-            group_sync = GroupLDAPSynchronizer(self.source)
+            group_sync = GroupLDAPSynchronizer(self.source, Task())
             group_sync.sync_full()
-            membership_sync = MembershipLDAPSynchronizer(self.source)
+            membership_sync = MembershipLDAPSynchronizer(self.source, Task())
             membership_sync.sync_full()

         self.assertTrue(
@@ -230,9 +231,9 @@ class LDAPSyncTests(TestCase):
             parent_group = Group.objects.get(name=_user.username)
             self.source.sync_parent_group = parent_group
             self.source.save()
-            group_sync = GroupLDAPSynchronizer(self.source)
+            group_sync = GroupLDAPSynchronizer(self.source, Task())
             group_sync.sync_full()
-            membership_sync = MembershipLDAPSynchronizer(self.source)
+            membership_sync = MembershipLDAPSynchronizer(self.source, Task())
             membership_sync.sync_full()
             group: Group = Group.objects.filter(name="test-group").first()
             self.assertIsNotNone(group)
@@ -256,9 +257,9 @@ class LDAPSyncTests(TestCase):
         connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD))
         with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
             self.source.save()
-            group_sync = GroupLDAPSynchronizer(self.source)
+            group_sync = GroupLDAPSynchronizer(self.source, Task())
             group_sync.sync_full()
-            membership_sync = MembershipLDAPSynchronizer(self.source)
+            membership_sync = MembershipLDAPSynchronizer(self.source, Task())
             membership_sync.sync_full()
         group = Group.objects.filter(name="group1")
         self.assertTrue(group.exists())
@@ -290,11 +291,11 @@ class LDAPSyncTests(TestCase):
         connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD))
         with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
             self.source.save()
-            user_sync = UserLDAPSynchronizer(self.source)
+            user_sync = UserLDAPSynchronizer(self.source, Task())
             user_sync.sync_full()
-            group_sync = GroupLDAPSynchronizer(self.source)
+            group_sync = GroupLDAPSynchronizer(self.source, Task())
             group_sync.sync_full()
-            membership_sync = MembershipLDAPSynchronizer(self.source)
+            membership_sync = MembershipLDAPSynchronizer(self.source, Task())
             membership_sync.sync_full()
         # Test if membership mapping based on memberUid works.
         posix_group = Group.objects.filter(name="group-posix").first()
@@ -327,11 +328,11 @@ class LDAPSyncTests(TestCase):
         connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD))
         with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
             self.source.save()
-            user_sync = UserLDAPSynchronizer(self.source)
+            user_sync = UserLDAPSynchronizer(self.source, Task())
             user_sync.sync_full()
-            group_sync = GroupLDAPSynchronizer(self.source)
+            group_sync = GroupLDAPSynchronizer(self.source, Task())
             group_sync.sync_full()
-            membership_sync = MembershipLDAPSynchronizer(self.source)
+            membership_sync = MembershipLDAPSynchronizer(self.source, Task())
             membership_sync.sync_full()
         # Test if membership mapping based on memberUid works.
         posix_group = Group.objects.filter(name="group-posix").first()
@@ -112,23 +112,20 @@ class Task(SerializerModel):
         if save:
             self.save()

-    def log(self, status: TaskStatus, *messages: str | LogEvent | Exception, save: bool = False):
+    def log(self, status: TaskStatus, message: str | Exception, save: bool = False, **attributes):
         self.messages: list
-        for msg in messages:
-            message = msg
-            if isinstance(message, Exception):
-                message = exception_to_string(message)
-            if not isinstance(message, LogEvent):
-                message = LogEvent(message, logger=self.uid, log_level=status.value)
-            self.messages.append(sanitize_item(message))
+        if isinstance(message, Exception):
+            message = exception_to_string(message)
+        message = LogEvent(message, logger=self.uid, log_level=status.value, attributes=attributes)
+        self.messages.append(sanitize_item(message))
         if save:
             self.save()

-    def info(self, *messages: str | LogEvent | Exception, save: bool = False):
-        self.log(TaskStatus.INFO, *messages, save=save)
+    def info(self, message: str | Exception, save: bool = False, **attributes):
+        self.log(TaskStatus.INFO, message, save=save, **attributes)

-    def warning(self, *messages: str | LogEvent | Exception, save: bool = False):
-        self.log(TaskStatus.WARNING, *messages, save=save)
+    def warning(self, message: str | Exception, save: bool = False, **attributes):
+        self.log(TaskStatus.WARNING, message, save=save, **attributes)

-    def error(self, *messages: str | LogEvent | Exception, save: bool = False):
-        self.log(TaskStatus.ERROR, *messages, save=save)
+    def error(self, message: str | Exception, save: bool = False, **attributes):
+        self.log(TaskStatus.ERROR, message, save=save, **attributes)
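Note: a short usage sketch of the reworked logging API above (one message per call, structured context via **attributes). The bare Task() instance is illustrative; in real actors the task comes from CurrentTask.get_task().

# Sketch: how callers use the single-message log/info/warning/error API.
from authentik.tasks.models import Task

task = Task()  # illustrative; normally CurrentTask.get_task()
task.info("Starting full provider sync")
task.warning(
    "Failed to sync object due to error: timeout",
    attributes={"obj": {"pk": 42}},
)
try:
    raise RuntimeError("boom")
except RuntimeError as exc:
    # Exceptions are converted with exception_to_string() before being stored
    task.error(exc)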