sync: compute time limits per object type, fan pages out via dramatiq groups, and log directly to Task

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
Marc 'risson' Schmitt
2025-06-05 18:15:20 +02:00
parent 59c8472628
commit 3766ca86e8
18 changed files with 164 additions and 257 deletions

View File

@@ -45,11 +45,17 @@ class OutgoingSyncProvider(ScheduledModel, Model):
def get_object_qs[T: User | Group](self, type: type[T]) -> QuerySet[T]:
raise NotImplementedError
def get_paginator[T: User | Group](self, type: type[T]) -> Paginator:
return Paginator(self.get_object_qs(type), PAGE_SIZE)
def get_object_sync_time_limit[T: User | Group](self, type: type[T]) -> int:
num_pages: int = self.get_paginator(type).num_pages
return int(num_pages * PAGE_TIMEOUT * 1.5) * 1000
def get_sync_time_limit(self) -> int:
users_paginator = Paginator(self.get_object_qs(User), PAGE_SIZE)
groups_paginator = Paginator(self.get_object_qs(Group), PAGE_SIZE)
time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT * 1.5
return int(time_limit)
return int(
(self.get_object_sync_time_limit(User) + self.get_object_sync_time_limit(Group)) * 1.5
)
@property
def sync_lock(self) -> pglock.advisory:
@@ -72,7 +78,7 @@ class OutgoingSyncProvider(ScheduledModel, Model):
uid=self.pk,
args=(self.pk,),
options={
"time_limit": self.get_sync_time_limit() * 1000,
"time_limit": self.get_sync_time_limit(),
},
send_on_save=True,
crontab=f"{fqdn_rand(self.pk)} */4 * * *",
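The per-type helper makes the time-limit arithmetic explicit: number of pages times the per-page timeout, padded by 50% and converted to milliseconds (dramatiq time limits are expressed in ms). A standalone sketch of the same math, with illustrative PAGE_SIZE and PAGE_TIMEOUT values assumed:

import math

PAGE_SIZE = 100      # objects per page (assumed value for illustration)
PAGE_TIMEOUT = 120   # seconds budgeted per page (assumed value)

def object_sync_time_limit_ms(object_count: int) -> int:
    # Mirrors get_object_sync_time_limit: page count times the per-page
    # timeout, padded by 50%, converted to milliseconds.
    num_pages = max(1, math.ceil(object_count / PAGE_SIZE))
    return int(num_pages * PAGE_TIMEOUT * 1.5) * 1000

# 250 users -> 3 pages -> 3 * 120 s * 1.5 = 540 s -> 540_000 ms
assert object_sync_time_limit_ms(250) == 540_000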

View File

@@ -1,21 +1,19 @@
from collections.abc import Callable
from dataclasses import asdict
from celery import group
from celery.exceptions import Retry
from celery.result import allow_join_result
from dramatiq.composition import group
from django.core.paginator import Paginator
from django.db.models import Model, QuerySet
from django.db.models.query import Q
from django.utils.text import slugify
from django.utils.translation import gettext_lazy as _
from dramatiq.actor import Actor
from dramatiq.errors import Retry
from structlog.stdlib import BoundLogger, get_logger
from authentik.core.expression.exceptions import SkipObjectException
from authentik.core.models import Group, User
from authentik.events.logs import LogEvent
from authentik.events.models import TaskStatus
from authentik.events.system_tasks import SystemTask
from authentik.events.utils import sanitize_item
from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT
from authentik.lib.sync.outgoing.base import Direction
@@ -27,11 +25,13 @@ from authentik.lib.sync.outgoing.exceptions import (
)
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
from authentik.lib.utils.reflection import class_to_path, path_to_class
from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import Task
class SyncTasks:
"""Container for all sync 'tasks' (this class doesn't actually contain celery
tasks due to celery's magic, however exposes a number of functions to be called from tasks)"""
"""Container for all sync 'tasks' (this class doesn't actually contain
tasks due to dramatiq's magic, however exposes a number of functions to be called from tasks)"""
logger: BoundLogger
@@ -39,107 +39,97 @@ class SyncTasks:
super().__init__()
self._provider_model = provider_model
def sync_all(self, single_sync: Callable[[int], None]):
for provider in self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
):
self.trigger_single_task(provider, single_sync)
def trigger_single_task(self, provider: OutgoingSyncProvider, sync_task: Callable[[int], None]):
"""Wrapper single sync task that correctly sets time limits based
on the amount of objects that will be synced"""
users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE)
groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE)
soft_time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT
time_limit = soft_time_limit * 1.5
return sync_task.apply_async(
(provider.pk,), time_limit=int(time_limit), soft_time_limit=int(soft_time_limit)
)
def sync_single(
def sync_paginator(
self,
task: SystemTask,
provider_pk: int,
sync_objects: Callable[[int, int], list[str]],
current_task: Task,
provider: OutgoingSyncProvider,
sync_objects: Actor,
paginator: Paginator,
object_type: type[User | Group],
**options,
):
tasks = []
for page in paginator.page_range:
page_sync = sync_objects.message_with_options(
args=(class_to_path(object_type), page, provider.pk),
time_limit=PAGE_TIMEOUT * 1000,
# Assign tasks to the same schedule as the current one
rel_obj=current_task.rel_obj,
**options,
)
tasks.append(page_sync)
return tasks
def sync(
self,
provider_pk: int,
sync_objects: Actor,
):
task = CurrentTask.get_task()
self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model),
provider_pk=provider_pk,
)
provider = self._provider_model.objects.filter(
provider: OutgoingSyncProvider = self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False),
pk=provider_pk,
).first()
if not provider:
return
task.set_uid(slugify(provider.name))
messages = []
messages.append(_("Starting full provider sync"))
task.info("Starting full provider sync")
self.logger.debug("Starting provider sync")
users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE)
groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE)
with allow_join_result(), provider.sync_lock as lock_acquired:
with provider.sync_lock as lock_acquired:
if not lock_acquired:
self.logger.debug("Failed to acquire sync lock, skipping", provider=provider.name)
return
try:
messages.append(_("Syncing users"))
user_results = (
group(
[
sync_objects.signature(
args=(class_to_path(User), page, provider_pk),
time_limit=PAGE_TIMEOUT,
soft_time_limit=PAGE_TIMEOUT,
)
for page in users_paginator.page_range
]
users_tasks = group(
self.sync_paginator(
current_task=task,
provider=provider,
sync_objects=sync_objects,
paginator=provider.get_paginator(User),
object_type=User,
)
.apply_async()
.get()
)
for result in user_results:
for msg in result:
messages.append(LogEvent(**msg))
messages.append(_("Syncing groups"))
group_results = (
group(
[
sync_objects.signature(
args=(class_to_path(Group), page, provider_pk),
time_limit=PAGE_TIMEOUT,
soft_time_limit=PAGE_TIMEOUT,
)
for page in groups_paginator.page_range
]
group_tasks = group(
self.sync_paginator(
current_task=task,
provider=provider,
sync_objects=sync_objects,
paginator=provider.get_paginator(Group),
object_type=Group,
)
.apply_async()
.get()
)
for result in group_results:
for msg in result:
messages.append(LogEvent(**msg))
users_tasks.run().wait(timeout=provider.get_object_sync_time_limit(User))
group_tasks.run().wait(timeout=provider.get_object_sync_time_limit(Group))
except TransientSyncException as exc:
self.logger.warning("transient sync exception", exc=exc)
raise task.retry(exc=exc) from exc
raise Retry from exc
except StopSync as exc:
task.set_error(exc)
task.error(exc)
return
task.set_status(TaskStatus.SUCCESSFUL, *messages)
def sync_objects(
self, object_type: str, page: int, provider_pk: int, override_dry_run=False, **filter
self,
object_type: str,
page: int,
provider_pk: int,
override_dry_run=False,
**filter,
):
task = CurrentTask.get_task()
_object_type: type[Model] = path_to_class(object_type)
self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model),
provider_pk=provider_pk,
object_type=object_type,
)
messages = []
task.info(f"Syncing page {page} of {_object_type._meta.verbose_name_plural}")
provider = self._provider_model.objects.filter(pk=provider_pk).first()
if not provider:
return messages
return
# Override dry run mode if requested, however don't save the provider
# so that scheduled sync tasks still run in dry_run mode
if override_dry_run:
@@ -147,25 +137,13 @@ class SyncTasks:
try:
client = provider.client_for_model(_object_type)
except TransientSyncException:
return messages
return
paginator = Paginator(provider.get_object_qs(_object_type).filter(**filter), PAGE_SIZE)
if client.can_discover:
self.logger.debug("starting discover")
client.discover()
self.logger.debug("starting sync for page", page=page)
messages.append(
asdict(
LogEvent(
_(
"Syncing page {page} of {object_type}".format(
page=page, object_type=_object_type._meta.verbose_name_plural
)
),
log_level="info",
logger=f"{provider._meta.verbose_name}@{object_type}",
)
)
)
task.info(f"Syncing page {page} or {_object_type._meta.verbose_name_plural}")
for obj in paginator.page(page).object_list:
obj: Model
try:
@@ -174,87 +152,34 @@
self.logger.debug("skipping object due to SkipObject", obj=obj)
continue
except DryRunRejected as exc:
messages.append(
asdict(
LogEvent(
_("Dropping mutating request due to dry run"),
log_level="info",
logger=f"{provider._meta.verbose_name}@{object_type}",
attributes={
"obj": sanitize_item(obj),
"method": exc.method,
"url": exc.url,
"body": exc.body,
},
)
)
task.info(
"Dropping mutating request due to dry run",
attributes={
"obj": sanitize_item(obj),
"method": exc.method,
"url": exc.url,
"body": exc.body,
},
)
except BadRequestSyncException as exc:
self.logger.warning("failed to sync object", exc=exc, obj=obj)
messages.append(
asdict(
LogEvent(
_(
(
"Failed to sync {object_type} {object_name} "
"due to error: {error}"
).format_map(
{
"object_type": obj._meta.verbose_name,
"object_name": str(obj),
"error": str(exc),
}
)
),
log_level="warning",
logger=f"{provider._meta.verbose_name}@{object_type}",
attributes={"arguments": exc.args[1:], "obj": sanitize_item(obj)},
)
)
task.warning(
f"Failed to sync {obj._meta.verbose_name} {str(obj)} due to error: {str(exc)}",
attributes={"arguments": exc.args[1:], "obj": sanitize_item(obj)},
)
except TransientSyncException as exc:
self.logger.warning("failed to sync object", exc=exc, user=obj)
messages.append(
asdict(
LogEvent(
_(
(
"Failed to sync {object_type} {object_name} "
"due to transient error: {error}"
).format_map(
{
"object_type": obj._meta.verbose_name,
"object_name": str(obj),
"error": str(exc),
}
)
),
log_level="warning",
logger=f"{provider._meta.verbose_name}@{object_type}",
attributes={"obj": sanitize_item(obj)},
)
)
task.warning(
f"Failed to sync {obj._meta.verbose_name} {str(obj)} due to transient error: {str(exc)}",
attributes={"obj": sanitize_item(obj)},
)
except StopSync as exc:
self.logger.warning("Stopping sync", exc=exc)
messages.append(
asdict(
LogEvent(
_(
"Stopping sync due to error: {error}".format_map(
{
"error": exc.detail(),
}
)
),
log_level="warning",
logger=f"{provider._meta.verbose_name}@{object_type}",
attributes={"obj": sanitize_item(obj)},
)
)
task.warning(
f"Stopping sync due to error: {exc.detail()}",
attributes={"obj": sanitize_item(obj)},
)
break
return messages
def sync_signal_direct(self, model: str, pk: str | int, raw_op: str):
self.logger = get_logger().bind(
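The fan-out above is the heart of the celery-to-dramatiq translation: celery's group(...).apply_async().get() becomes a list of dramatiq messages that are run() as a group and then wait()-ed on with an explicit millisecond timeout. A minimal self-contained sketch of that pattern, assuming a stub broker with the Results middleware (group.wait() needs a result backend):

import dramatiq
from dramatiq import Worker, group
from dramatiq.brokers.stub import StubBroker
from dramatiq.results import Results
from dramatiq.results.backends import StubBackend

broker = StubBroker()
broker.add_middleware(Results(backend=StubBackend()))
dramatiq.set_broker(broker)

@dramatiq.actor(store_results=True)
def sync_page(page: int) -> int:
    return page  # stand-in for syncing one page of objects

# One message per page, each with its own time limit in milliseconds,
# the same way sync_paginator builds them.
messages = [
    sync_page.message_with_options(args=(page,), time_limit=120_000)
    for page in (1, 2, 3)
]

worker = Worker(broker, worker_timeout=100)
worker.start()
try:
    g = group(messages)
    g.run()
    g.wait(timeout=60_000)  # block until every page finishes or the timeout hits
finally:
    worker.stop()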

View File

@@ -130,7 +130,8 @@ def outpost_controller(outpost_pk: str, action: str = "up", from_cache: bool = F
else:
if from_cache:
cache.delete(CACHE_KEY_OUTPOST_DOWN % outpost_pk)
self.info(*logs)
for log in logs:
self.info(log)
@actor

View File

@@ -85,6 +85,6 @@ class SCIMClientTests(TestCase):
self.assertEqual(mock.call_count, 1)
self.assertEqual(mock.request_history[0].method, "GET")
def test_scim_sync_all(self):
"""test scim_sync_all task"""
def test_scim_sync(self):
"""test scim_sync task"""
scim_sync.send(self.provider.pk).get_result()

View File

@@ -23,12 +23,14 @@ from authentik.sources.kerberos.models import (
Krb5ConfContext,
UserKerberosSourceConnection,
)
from authentik.tasks.models import Task
class KerberosSync:
"""Sync Kerberos users into authentik"""
_source: KerberosSource
_task: Task
_logger: BoundLogger
_connection: KAdmin
mapper: SourceMapper
@@ -36,11 +38,11 @@
group_manager: PropertyMappingManager
matcher: SourceMatcher
def __init__(self, source: KerberosSource):
def __init__(self, source: KerberosSource, task: Task):
self._source = source
self._task = task
with Krb5ConfContext(self._source):
self._connection = self._source.connection()
self._messages = []
self._logger = get_logger().bind(source=self._source, syncer=self.__class__.__name__)
self.mapper = SourceMapper(self._source)
self.user_manager = self.mapper.get_manager(User, ["principal", "principal_obj"])
@@ -56,17 +58,6 @@
"""UI name for the type of object this class synchronizes"""
return "users"
@property
def messages(self) -> list[str]:
"""Get all UI messages"""
return self._messages
def message(self, *args, **kwargs):
"""Add message that is later added to the System Task and shown to the user"""
formatted_message = " ".join(args)
self._messages.append(formatted_message)
self._logger.warning(*args, **kwargs)
def _handle_principal(self, principal: str) -> bool:
try:
# TODO: handle permission error
@@ -163,7 +154,7 @@ class KerberosSync:
def sync(self) -> int:
"""Iterate over all Kerberos users and create authentik_core.User instances"""
if not self._source.enabled or not self._source.sync_users:
self.message("Source is disabled or user syncing is disabled for this Source")
self._task.info("Source is disabled or user syncing is disabled for this Source")
return -1
user_count = 0

View File

@@ -43,7 +43,6 @@ def kerberos_sync(pk: str):
return
syncer = KerberosSync(source)
syncer.sync()
self.info(*syncer.messages)
except StopSync as exc:
LOGGER.warning(exception_to_string(exc))
self.error(exc)

View File

@@ -10,22 +10,23 @@ from authentik.core.sources.mapper import SourceMapper
from authentik.lib.config import CONFIG
from authentik.lib.sync.mapper import PropertyMappingManager
from authentik.sources.ldap.models import LDAPSource, flatten
from authentik.tasks.models import Task
class BaseLDAPSynchronizer:
"""Sync LDAP Users and groups into authentik"""
_source: LDAPSource
_task: Task
_logger: BoundLogger
_connection: Connection
_messages: list[str]
mapper: SourceMapper
manager: PropertyMappingManager
def __init__(self, source: LDAPSource):
def __init__(self, source: LDAPSource, task: Task):
self._source = source
self._task = task
self._connection = source.connection()
self._messages = []
self._logger = get_logger().bind(source=source, syncer=self.__class__.__name__)
@staticmethod
@@ -46,11 +47,6 @@ class BaseLDAPSynchronizer:
"""Sync function, implemented in subclass"""
raise NotImplementedError()
@property
def messages(self) -> list[str]:
"""Get all UI messages"""
return self._messages
@property
def base_dn_users(self) -> str:
"""Shortcut to get full base_dn for user lookups"""
@@ -65,14 +61,6 @@
return f"{self._source.additional_group_dn},{self._source.base_dn}"
return self._source.base_dn
def message(self, *args, **kwargs):
"""Add message that is later added to the System Task and shown to the user"""
formatted_message = " ".join(args)
if "dn" in kwargs:
formatted_message += f"; DN: {kwargs['dn']}"
self._messages.append(formatted_message)
self._logger.warning(*args, **kwargs)
def get_objects(self, **kwargs) -> Generator:
"""Get objects from LDAP, implemented in subclass"""
raise NotImplementedError()
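This constructor change sets the pattern for every synchronizer below: the running Task is injected once, and messages that were previously buffered in self._messages (and attached to the system task at the end) are now written straight to the task as they happen. A condensed before/after sketch, with names mirroring the diff (the class itself is illustrative):

from authentik.tasks.models import Task

class ExampleSynchronizer:
    """Illustrative only; real synchronizers subclass BaseLDAPSynchronizer."""

    def __init__(self, source, task: Task):
        self._source = source
        self._task = task  # replaces the removed self._messages buffer

    def sync(self) -> int:
        if not self._source.sync_users:
            # before: self.message("User syncing is disabled for this Source")
            # after: logged immediately against the running task
            self._task.info("User syncing is disabled for this Source")
            return -1
        return 0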

View File

@@ -19,7 +19,7 @@ class GroupLDAPForwardDeletion(BaseLDAPSynchronizer):
def get_objects(self, **kwargs) -> Generator:
if not self._source.sync_groups or not self._source.delete_not_found_objects:
self.message("Group syncing is disabled for this Source")
self._task.info("Group syncing is disabled for this Source")
return iter(())
uuid = uuid4()
@@ -54,7 +54,7 @@ class GroupLDAPForwardDeletion(BaseLDAPSynchronizer):
def sync(self, group_pks: tuple) -> int:
"""Delete authentik groups"""
if not self._source.sync_groups or not self._source.delete_not_found_objects:
self.message("Group syncing is disabled for this Source")
self._task.info("Group syncing is disabled for this Source")
return -1
self._logger.debug("Deleting groups", group_pks=group_pks)
_, deleted_per_type = Group.objects.filter(pk__in=group_pks).delete()

View File

@@ -21,7 +21,7 @@ class UserLDAPForwardDeletion(BaseLDAPSynchronizer):
def get_objects(self, **kwargs) -> Generator:
if not self._source.sync_users or not self._source.delete_not_found_objects:
self.message("User syncing is disabled for this Source")
self._task.info("User syncing is disabled for this Source")
return iter(())
uuid = uuid4()
@@ -56,7 +56,7 @@ class UserLDAPForwardDeletion(BaseLDAPSynchronizer):
def sync(self, user_pks: tuple) -> int:
"""Delete authentik users"""
if not self._source.sync_users or not self._source.delete_not_found_objects:
self.message("User syncing is disabled for this Source")
self._task.info("User syncing is disabled for this Source")
return -1
self._logger.debug("Deleting users", user_pks=user_pks)
_, deleted_per_type = User.objects.filter(pk__in=user_pks).delete()

View File

@@ -37,7 +37,7 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer):
def get_objects(self, **kwargs) -> Generator:
if not self._source.sync_groups:
self.message("Group syncing is disabled for this Source")
self._task.info("Group syncing is disabled for this Source")
return iter(())
return self.search_paginator(
search_base=self.base_dn_groups,
@@ -54,7 +54,7 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer):
def sync(self, page_data: list) -> int:
"""Iterate over all LDAP Groups and create authentik_core.Group instances"""
if not self._source.sync_groups:
self.message("Group syncing is disabled for this Source")
self._task.info("Group syncing is disabled for this Source")
return -1
group_count = 0
for group in page_data:
@@ -62,7 +62,7 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer):
continue
group_dn = flatten(flatten(group.get("entryDN", group.get("dn"))))
if not (uniq := self.get_identifier(attributes)):
self.message(
self._task.info(
f"Uniqueness field not found/not set in attributes: '{group_dn}'",
attributes=attributes.keys(),
dn=group_dn,

View File

@@ -26,7 +26,7 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer):
def get_objects(self, **kwargs) -> Generator:
if not self._source.sync_groups:
self.message("Group syncing is disabled for this Source")
self._task.info("Group syncing is disabled for this Source")
return iter(())
# If we are looking up groups from users, we don't need to fetch the group membership field
@@ -45,7 +45,7 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer):
def sync(self, page_data: list) -> int:
"""Iterate over all Users and assign Groups using memberOf Field"""
if not self._source.sync_groups:
self.message("Group syncing is disabled for this Source")
self._task.info("Group syncing is disabled for this Source")
return -1
membership_count = 0
for group in page_data:
@@ -94,7 +94,7 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer):
# group_uniq might be a single string or an array with (hopefully) a single string
if isinstance(group_uniq, list):
if len(group_uniq) < 1:
self.message(
self._task.info(
f"Group does not have a uniqueness attribute: '{group_dn}'",
group=group_dn,
)
@@ -104,7 +104,7 @@
groups = Group.objects.filter(**{f"attributes__{LDAP_UNIQUENESS}": group_uniq})
if not groups.exists():
if self._source.sync_groups:
self.message(
self._task.info(
f"Group does not exist in our DB yet, run sync_groups first: '{group_dn}'",
group=group_dn,
)

View File

@@ -39,7 +39,7 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer):
def get_objects(self, **kwargs) -> Generator:
if not self._source.sync_users:
self.message("User syncing is disabled for this Source")
self._task.info("User syncing is disabled for this Source")
return iter(())
return self.search_paginator(
search_base=self.base_dn_users,
@@ -56,7 +56,7 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer):
def sync(self, page_data: list) -> int:
"""Iterate over all LDAP Users and create authentik_core.User instances"""
if not self._source.sync_users:
self.message("User syncing is disabled for this Source")
self._task.info("User syncing is disabled for this Source")
return -1
user_count = 0
for user in page_data:
@@ -64,7 +64,7 @@
continue
user_dn = flatten(user.get("entryDN", user.get("dn")))
if not (uniq := self.get_identifier(attributes)):
self.message(
self._task.info(
f"Uniqueness field not found/not set in attributes: '{user_dn}'",
attributes=attributes.keys(),
dn=user_dn,

View File

@@ -30,7 +30,7 @@ class FreeIPA(BaseLDAPSynchronizer):
pwd_last_set: datetime = attributes.get("krbLastPwdChange", datetime.now())
pwd_last_set = pwd_last_set.replace(tzinfo=UTC)
if created or pwd_last_set >= user.password_change_date:
self.message(f"'{user.username}': Reset user's password")
self._task.info(f"'{user.username}': Reset user's password")
self._logger.debug(
"Reset user's password",
user=user.username,

View File

@@ -60,7 +60,7 @@ class MicrosoftActiveDirectory(BaseLDAPSynchronizer):
pwd_last_set: datetime = attributes.get("pwdLastSet", datetime.now())
pwd_last_set = pwd_last_set.replace(tzinfo=UTC)
if created or pwd_last_set >= user.password_change_date:
self.message(f"'{user.username}': Reset user's password")
self._task.info(f"'{user.username}': Reset user's password")
self._logger.debug(
"Reset user's password",
user=user.username,

View File

@@ -72,14 +72,12 @@ def ldap_sync(source_pk: str):
)
# User and group sync can happen at once, they have no dependencies on each other
user_group_tasks.run().get_results(
block=True,
timeout=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours") * 1000,
user_group_tasks.run().wait(
timeout=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours") * 1000
)
# Membership sync needs to run afterwards
membership_tasks.run().get_results(
block=True,
timeout=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours") * 1000,
membership_tasks.run().wait(
timeout=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours") * 1000
)
# Finally, deletions. What we'd really like to do here is something like
# ```
@@ -96,7 +94,9 @@
# large chunks, and only queue the deletion step afterwards.
# 3. Delete every unmarked item. This is slow, so we spread it over many tasks in
# small chunks.
deletion_tasks.run() # no need to block here, we don't have anything else to do afterwards
deletion_tasks.run().wait(
timeout=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours") * 1000,
)
def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) -> list[Message]:
@@ -139,9 +139,7 @@ def ldap_sync_page(source_pk: str, sync_class: str, page_cache_key: str):
return
cache.touch(page_cache_key)
count = sync_inst.sync(page)
messages = sync_inst.messages
messages.append(f"Synced {count} objects.")
self.info(*messages)
self.info(f"Synced {count} objects.")
cache.delete(page_cache_key)
except (LDAPException, StopSync) as exc:
# No explicit event is created here as .set_status with an error will do that
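After this change the LDAP sync pipeline runs its three stages strictly in sequence, each bounded by the same configured timeout (hours converted to milliseconds). The control flow, condensed (the group objects, CONFIG lookup, and timeout expression are taken from the diff above):

timeout_ms = 60 * 60 * CONFIG.get_int("ldap.task_timeout_hours") * 1000

# users/groups first, then memberships, then deletions; each stage must
# finish (or time out) before the next one starts
for stage in (user_group_tasks, membership_tasks, deletion_tasks):
    stage.run().wait(timeout=timeout_ms)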

View File

@@ -13,6 +13,7 @@ from authentik.sources.ldap.models import LDAPSource, LDAPSourcePropertyMapping
from authentik.sources.ldap.sync.users import UserLDAPSynchronizer
from authentik.sources.ldap.tests.mock_ad import mock_ad_connection
from authentik.sources.ldap.tests.mock_slapd import mock_slapd_connection
from authentik.tasks.models import Task
LDAP_PASSWORD = generate_key()
@@ -43,7 +44,7 @@ class LDAPSyncTests(TestCase):
raw_conn.bind = bind_mock
connection = MagicMock(return_value=raw_conn)
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source)
user_sync = UserLDAPSynchronizer(self.source, Task())
user_sync.sync_full()
user = User.objects.get(username="user0_sn")
@@ -71,7 +72,7 @@
)
connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source)
user_sync = UserLDAPSynchronizer(self.source, Task())
user_sync.sync_full()
user = User.objects.get(username="user0_sn")
@@ -98,7 +99,7 @@
self.source.save()
connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source)
user_sync = UserLDAPSynchronizer(self.source, Task())
user_sync.sync_full()
user = User.objects.get(username="user0_sn")

View File

@@ -23,7 +23,7 @@ from authentik.sources.ldap.sync.forward_delete_users import DELETE_CHUNK_SIZE
from authentik.sources.ldap.sync.groups import GroupLDAPSynchronizer
from authentik.sources.ldap.sync.membership import MembershipLDAPSynchronizer
from authentik.sources.ldap.sync.users import UserLDAPSynchronizer
from authentik.sources.ldap.tasks import ldap_sync, ldap_sync_all, ldap_sync_page
from authentik.sources.ldap.tasks import ldap_sync, ldap_sync_page
from authentik.sources.ldap.tests.mock_ad import mock_ad_connection
from authentik.sources.ldap.tests.mock_freeipa import mock_freeipa_connection
from authentik.sources.ldap.tests.mock_slapd import (
@@ -33,6 +33,7 @@ from authentik.sources.ldap.tests.mock_slapd import (
user_in_slapd_cn,
user_in_slapd_uid,
)
from authentik.tasks.models import Task
LDAP_PASSWORD = generate_key()
@@ -74,7 +75,7 @@ class LDAPSyncTests(TestCase):
self.source.save()
connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source)
user_sync = UserLDAPSynchronizer(self.source, Task())
with self.assertRaises(StopSync):
user_sync.sync_full()
self.assertFalse(User.objects.filter(username="user0_sn").exists())
@@ -105,7 +106,7 @@
# we basically just test that the mappings don't throw errors
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source)
user_sync = UserLDAPSynchronizer(self.source, Task())
user_sync.sync_full()
def test_sync_users_ad(self):
@@ -133,7 +134,7 @@
)
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source)
user_sync = UserLDAPSynchronizer(self.source, Task())
user_sync.sync_full()
user = User.objects.filter(username="user0_sn").first()
self.assertEqual(user.attributes["foo"], "bar")
@@ -152,7 +153,7 @@
)
connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source)
user_sync = UserLDAPSynchronizer(self.source, Task())
user_sync.sync_full()
self.assertTrue(User.objects.filter(username="user0_sn").exists())
self.assertFalse(User.objects.filter(username="user1_sn").exists())
@@ -168,7 +169,7 @@
)
connection = MagicMock(return_value=mock_freeipa_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source)
user_sync = UserLDAPSynchronizer(self.source, Task())
user_sync.sync_full()
self.assertTrue(User.objects.filter(username="user0_sn").exists())
self.assertFalse(User.objects.filter(username="user1_sn").exists())
@@ -193,11 +194,11 @@
)
connection = MagicMock(return_value=mock_freeipa_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source)
user_sync = UserLDAPSynchronizer(self.source, Task())
user_sync.sync_full()
group_sync = GroupLDAPSynchronizer(self.source)
group_sync = GroupLDAPSynchronizer(self.source, Task())
group_sync.sync_full()
membership_sync = MembershipLDAPSynchronizer(self.source)
membership_sync = MembershipLDAPSynchronizer(self.source, Task())
membership_sync.sync_full()
self.assertTrue(
@@ -230,9 +231,9 @@
parent_group = Group.objects.get(name=_user.username)
self.source.sync_parent_group = parent_group
self.source.save()
group_sync = GroupLDAPSynchronizer(self.source)
group_sync = GroupLDAPSynchronizer(self.source, Task())
group_sync.sync_full()
membership_sync = MembershipLDAPSynchronizer(self.source)
membership_sync = MembershipLDAPSynchronizer(self.source, Task())
membership_sync.sync_full()
group: Group = Group.objects.filter(name="test-group").first()
self.assertIsNotNone(group)
@@ -256,9 +257,9 @@
connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
self.source.save()
group_sync = GroupLDAPSynchronizer(self.source)
group_sync = GroupLDAPSynchronizer(self.source, Task())
group_sync.sync_full()
membership_sync = MembershipLDAPSynchronizer(self.source)
membership_sync = MembershipLDAPSynchronizer(self.source, Task())
membership_sync.sync_full()
group = Group.objects.filter(name="group1")
self.assertTrue(group.exists())
@@ -290,11 +291,11 @@
connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
self.source.save()
user_sync = UserLDAPSynchronizer(self.source)
user_sync = UserLDAPSynchronizer(self.source, Task())
user_sync.sync_full()
group_sync = GroupLDAPSynchronizer(self.source)
group_sync = GroupLDAPSynchronizer(self.source, Task())
group_sync.sync_full()
membership_sync = MembershipLDAPSynchronizer(self.source)
membership_sync = MembershipLDAPSynchronizer(self.source, Task())
membership_sync.sync_full()
# Test if membership mapping based on memberUid works.
posix_group = Group.objects.filter(name="group-posix").first()
@@ -327,11 +328,11 @@
connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
self.source.save()
user_sync = UserLDAPSynchronizer(self.source)
user_sync = UserLDAPSynchronizer(self.source, Task())
user_sync.sync_full()
group_sync = GroupLDAPSynchronizer(self.source)
group_sync = GroupLDAPSynchronizer(self.source, Task())
group_sync.sync_full()
membership_sync = MembershipLDAPSynchronizer(self.source)
membership_sync = MembershipLDAPSynchronizer(self.source, Task())
membership_sync.sync_full()
# Test if membership mapping based on memberUid works.
posix_group = Group.objects.filter(name="group-posix").first()

View File

@@ -112,23 +112,20 @@ class Task(SerializerModel):
if save:
self.save()
def log(self, status: TaskStatus, *messages: str | LogEvent | Exception, save: bool = False):
def log(self, status: TaskStatus, message: str | Exception, save: bool = False, **attributes):
self.messages: list
for msg in messages:
message = msg
if isinstance(message, Exception):
message = exception_to_string(message)
if not isinstance(message, LogEvent):
message = LogEvent(message, logger=self.uid, log_level=status.value)
self.messages.append(sanitize_item(message))
if isinstance(message, Exception):
message = exception_to_string(message)
message = LogEvent(message, logger=self.uid, log_level=status.value, attributes=attributes)
self.messages.append(sanitize_item(message))
if save:
self.save()
def info(self, *messages: str | LogEvent | Exception, save: bool = False):
self.log(TaskStatus.INFO, *messages, save=save)
def info(self, message: str | Exception, save: bool = False, **attributes):
self.log(TaskStatus.INFO, message, save=save, **attributes)
def warning(self, *messages: str | LogEvent | Exception, save: bool = False):
self.log(TaskStatus.WARNING, *messages, save=save)
def warning(self, message: str | Exception, save: bool = False, **attributes):
self.log(TaskStatus.WARNING, message, save=save, **attributes)
def error(self, *messages: str | LogEvent | Exception, save: bool = False):
self.log(TaskStatus.ERROR, *messages, save=save)
def error(self, message: str | Exception, save: bool = False, **attributes):
self.log(TaskStatus.ERROR, message, save=save, **attributes)