uhhhhhhhhh

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt
2025-06-05 18:15:20 +02:00
parent 59c8472628
commit 3766ca86e8
18 changed files with 164 additions and 257 deletions

View File

@ -45,11 +45,17 @@ class OutgoingSyncProvider(ScheduledModel, Model):
def get_object_qs[T: User | Group](self, type: type[T]) -> QuerySet[T]: def get_object_qs[T: User | Group](self, type: type[T]) -> QuerySet[T]:
raise NotImplementedError raise NotImplementedError
def get_paginator[T: User | Group](self, type: type[T]) -> Paginator:
return Paginator(self.get_object_qs(type), PAGE_SIZE)
def get_object_sync_time_limit[T: User | Group](self, type: type[T]) -> int:
num_pages: int = self.get_paginator(type).num_pages
return int(num_pages * PAGE_TIMEOUT * 1.5) * 1000
def get_sync_time_limit(self) -> int: def get_sync_time_limit(self) -> int:
users_paginator = Paginator(self.get_object_qs(User), PAGE_SIZE) return int(
groups_paginator = Paginator(self.get_object_qs(Group), PAGE_SIZE) self.get_object_sync_time_limit(User) + self.get_object_sync_time_limit(Group) * 1.5
time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT * 1.5 )
return int(time_limit)
@property @property
def sync_lock(self) -> pglock.advisory: def sync_lock(self) -> pglock.advisory:
@ -72,7 +78,7 @@ class OutgoingSyncProvider(ScheduledModel, Model):
uid=self.pk, uid=self.pk,
args=(self.pk,), args=(self.pk,),
options={ options={
"time_limit": self.get_sync_time_limit() * 1000, "time_limit": self.get_sync_time_limit(),
}, },
send_on_save=True, send_on_save=True,
crontab=f"{fqdn_rand(self.pk)} */4 * * *", crontab=f"{fqdn_rand(self.pk)} */4 * * *",

View File

@ -1,21 +1,19 @@
from collections.abc import Callable
from dataclasses import asdict from dataclasses import asdict
from celery import group from dramatiq.composition import group
from celery.exceptions import Retry
from celery.result import allow_join_result
from django.core.paginator import Paginator from django.core.paginator import Paginator
from django.db.models import Model, QuerySet from django.db.models import Model, QuerySet
from django.db.models.query import Q from django.db.models.query import Q
from django.utils.text import slugify from django.utils.text import slugify
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from dramatiq.actor import Actor
from dramatiq.errors import Retry
from structlog.stdlib import BoundLogger, get_logger from structlog.stdlib import BoundLogger, get_logger
from authentik.core.expression.exceptions import SkipObjectException from authentik.core.expression.exceptions import SkipObjectException
from authentik.core.models import Group, User from authentik.core.models import Group, User
from authentik.events.logs import LogEvent from authentik.events.logs import LogEvent
from authentik.events.models import TaskStatus from authentik.events.models import TaskStatus
from authentik.events.system_tasks import SystemTask
from authentik.events.utils import sanitize_item from authentik.events.utils import sanitize_item
from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT
from authentik.lib.sync.outgoing.base import Direction from authentik.lib.sync.outgoing.base import Direction
@ -27,11 +25,13 @@ from authentik.lib.sync.outgoing.exceptions import (
) )
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
from authentik.lib.utils.reflection import class_to_path, path_to_class from authentik.lib.utils.reflection import class_to_path, path_to_class
from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import Task
class SyncTasks: class SyncTasks:
"""Container for all sync 'tasks' (this class doesn't actually contain celery """Container for all sync 'tasks' (this class doesn't actually contain
tasks due to celery's magic, however exposes a number of functions to be called from tasks)""" tasks due to dramatiq's magic, however exposes a number of functions to be called from tasks)"""
logger: BoundLogger logger: BoundLogger
@ -39,107 +39,97 @@ class SyncTasks:
super().__init__() super().__init__()
self._provider_model = provider_model self._provider_model = provider_model
def sync_all(self, single_sync: Callable[[int], None]): def sync_paginator(
for provider in self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
):
self.trigger_single_task(provider, single_sync)
def trigger_single_task(self, provider: OutgoingSyncProvider, sync_task: Callable[[int], None]):
"""Wrapper single sync task that correctly sets time limits based
on the amount of objects that will be synced"""
users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE)
groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE)
soft_time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT
time_limit = soft_time_limit * 1.5
return sync_task.apply_async(
(provider.pk,), time_limit=int(time_limit), soft_time_limit=int(soft_time_limit)
)
def sync_single(
self, self,
task: SystemTask, current_task: Task,
provider_pk: int, provider: OutgoingSyncProvider,
sync_objects: Callable[[int, int], list[str]], sync_objects: Actor,
paginator: Paginator,
object_type: type[User | Group],
**options,
): ):
tasks = []
for page in paginator.page_range:
page_sync = sync_objects.message_with_options(
args=(class_to_path(object_type), page, provider.pk),
time_limit=PAGE_TIMEOUT * 1000,
# Assign tasks to the same schedule as the current one
rel_obj=current_task.rel_obj,
**options,
)
tasks.append(page_sync)
return tasks
def sync(
self,
provider_pk: int,
sync_objects: Actor,
):
task = CurrentTask.get_task()
self.logger = get_logger().bind( self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model), provider_type=class_to_path(self._provider_model),
provider_pk=provider_pk, provider_pk=provider_pk,
) )
provider = self._provider_model.objects.filter( provider: OutgoingSyncProvider = self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False), Q(backchannel_application__isnull=False) | Q(application__isnull=False),
pk=provider_pk, pk=provider_pk,
).first() ).first()
if not provider: if not provider:
return return
task.set_uid(slugify(provider.name)) task.set_uid(slugify(provider.name))
messages = [] task.info("Starting full provider sync")
messages.append(_("Starting full provider sync"))
self.logger.debug("Starting provider sync") self.logger.debug("Starting provider sync")
users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE) with provider.sync_lock as lock_acquired:
groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE)
with allow_join_result(), provider.sync_lock as lock_acquired:
if not lock_acquired: if not lock_acquired:
self.logger.debug("Failed to acquire sync lock, skipping", provider=provider.name) self.logger.debug("Failed to acquire sync lock, skipping", provider=provider.name)
return return
try: try:
messages.append(_("Syncing users")) users_tasks = group(
user_results = ( self.sync_paginator(
group( current_task=task,
[ provider=provider,
sync_objects.signature( sync_objects=sync_objects,
args=(class_to_path(User), page, provider_pk), paginator=provider.get_paginator(User),
time_limit=PAGE_TIMEOUT, object_type=User,
soft_time_limit=PAGE_TIMEOUT,
)
for page in users_paginator.page_range
]
) )
.apply_async()
.get()
) )
for result in user_results: group_tasks = group(
for msg in result: self.sync_paginator(
messages.append(LogEvent(**msg)) current_task=task,
messages.append(_("Syncing groups")) provider=provider,
group_results = ( sync_objects=sync_objects,
group( paginator=provider.get_paginator(Group),
[ object_type=Group,
sync_objects.signature(
args=(class_to_path(Group), page, provider_pk),
time_limit=PAGE_TIMEOUT,
soft_time_limit=PAGE_TIMEOUT,
)
for page in groups_paginator.page_range
]
) )
.apply_async()
.get()
) )
for result in group_results: users_tasks.run().wait(timeout=provider.get_object_sync_time_limit(User))
for msg in result: group_tasks.run().wait(timeout=provider.get_object_sync_time_limit(Group))
messages.append(LogEvent(**msg))
except TransientSyncException as exc: except TransientSyncException as exc:
self.logger.warning("transient sync exception", exc=exc) self.logger.warning("transient sync exception", exc=exc)
raise task.retry(exc=exc) from exc raise Retry from exc
except StopSync as exc: except StopSync as exc:
task.set_error(exc) task.error(exc)
return return
task.set_status(TaskStatus.SUCCESSFUL, *messages)
def sync_objects( def sync_objects(
self, object_type: str, page: int, provider_pk: int, override_dry_run=False, **filter self,
object_type: str,
page: int,
provider_pk: int,
override_dry_run=False,
**filter,
): ):
task = CurrentTask.get_task()
_object_type: type[Model] = path_to_class(object_type) _object_type: type[Model] = path_to_class(object_type)
self.logger = get_logger().bind( self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model), provider_type=class_to_path(self._provider_model),
provider_pk=provider_pk, provider_pk=provider_pk,
object_type=object_type, object_type=object_type,
) )
messages = [] task.info(f"Syncing page {page} of {_object_type._meta.verbose_name_plural}")
provider = self._provider_model.objects.filter(pk=provider_pk).first() provider = self._provider_model.objects.filter(pk=provider_pk).first()
if not provider: if not provider:
return messages return
# Override dry run mode if requested, however don't save the provider # Override dry run mode if requested, however don't save the provider
# so that scheduled sync tasks still run in dry_run mode # so that scheduled sync tasks still run in dry_run mode
if override_dry_run: if override_dry_run:
@ -147,25 +137,13 @@ class SyncTasks:
try: try:
client = provider.client_for_model(_object_type) client = provider.client_for_model(_object_type)
except TransientSyncException: except TransientSyncException:
return messages return
paginator = Paginator(provider.get_object_qs(_object_type).filter(**filter), PAGE_SIZE) paginator = Paginator(provider.get_object_qs(_object_type).filter(**filter), PAGE_SIZE)
if client.can_discover: if client.can_discover:
self.logger.debug("starting discover") self.logger.debug("starting discover")
client.discover() client.discover()
self.logger.debug("starting sync for page", page=page) self.logger.debug("starting sync for page", page=page)
messages.append( task.info(f"Syncing page {page} or {_object_type._meta.verbose_name_plural}")
asdict(
LogEvent(
_(
"Syncing page {page} of {object_type}".format(
page=page, object_type=_object_type._meta.verbose_name_plural
)
),
log_level="info",
logger=f"{provider._meta.verbose_name}@{object_type}",
)
)
)
for obj in paginator.page(page).object_list: for obj in paginator.page(page).object_list:
obj: Model obj: Model
try: try:
@ -174,87 +152,34 @@ class SyncTasks:
self.logger.debug("skipping object due to SkipObject", obj=obj) self.logger.debug("skipping object due to SkipObject", obj=obj)
continue continue
except DryRunRejected as exc: except DryRunRejected as exc:
messages.append( task.info(
asdict( "Dropping mutating request due to dry run",
LogEvent( attributes={
_("Dropping mutating request due to dry run"), "obj": sanitize_item(obj),
log_level="info", "method": exc.method,
logger=f"{provider._meta.verbose_name}@{object_type}", "url": exc.url,
attributes={ "body": exc.body,
"obj": sanitize_item(obj), },
"method": exc.method,
"url": exc.url,
"body": exc.body,
},
)
)
) )
except BadRequestSyncException as exc: except BadRequestSyncException as exc:
self.logger.warning("failed to sync object", exc=exc, obj=obj) self.logger.warning("failed to sync object", exc=exc, obj=obj)
messages.append( task.warning(
asdict( f"Failed to sync {obj._meta.verbose_name} {str(obj)} due to error: {str(exc)}",
LogEvent( attributes={"arguments": exc.args[1:], "obj": sanitize_item(obj)},
_(
(
"Failed to sync {object_type} {object_name} "
"due to error: {error}"
).format_map(
{
"object_type": obj._meta.verbose_name,
"object_name": str(obj),
"error": str(exc),
}
)
),
log_level="warning",
logger=f"{provider._meta.verbose_name}@{object_type}",
attributes={"arguments": exc.args[1:], "obj": sanitize_item(obj)},
)
)
) )
except TransientSyncException as exc: except TransientSyncException as exc:
self.logger.warning("failed to sync object", exc=exc, user=obj) self.logger.warning("failed to sync object", exc=exc, user=obj)
messages.append( task.warning(
asdict( f"Failed to sync {obj._meta.verbose_name} {str(obj)} due to transient error: {str(exc)}",
LogEvent( attributes={"obj": sanitize_item(obj)},
_(
(
"Failed to sync {object_type} {object_name} "
"due to transient error: {error}"
).format_map(
{
"object_type": obj._meta.verbose_name,
"object_name": str(obj),
"error": str(exc),
}
)
),
log_level="warning",
logger=f"{provider._meta.verbose_name}@{object_type}",
attributes={"obj": sanitize_item(obj)},
)
)
) )
except StopSync as exc: except StopSync as exc:
self.logger.warning("Stopping sync", exc=exc) self.logger.warning("Stopping sync", exc=exc)
messages.append( task.warning(
asdict( f"Stopping sync due to error: {exc.detail()}",
LogEvent( attributes={"obj": sanitize_item(obj)},
_(
"Stopping sync due to error: {error}".format_map(
{
"error": exc.detail(),
}
)
),
log_level="warning",
logger=f"{provider._meta.verbose_name}@{object_type}",
attributes={"obj": sanitize_item(obj)},
)
)
) )
break break
return messages
def sync_signal_direct(self, model: str, pk: str | int, raw_op: str): def sync_signal_direct(self, model: str, pk: str | int, raw_op: str):
self.logger = get_logger().bind( self.logger = get_logger().bind(

View File

@ -130,7 +130,8 @@ def outpost_controller(outpost_pk: str, action: str = "up", from_cache: bool = F
else: else:
if from_cache: if from_cache:
cache.delete(CACHE_KEY_OUTPOST_DOWN % outpost_pk) cache.delete(CACHE_KEY_OUTPOST_DOWN % outpost_pk)
self.info(*logs) for log in logs:
self.info(log)
@actor @actor

View File

@ -85,6 +85,6 @@ class SCIMClientTests(TestCase):
self.assertEqual(mock.call_count, 1) self.assertEqual(mock.call_count, 1)
self.assertEqual(mock.request_history[0].method, "GET") self.assertEqual(mock.request_history[0].method, "GET")
def test_scim_sync_all(self): def test_scim_sync(self):
"""test scim_sync_all task""" """test scim_sync task"""
scim_sync.send(self.provider.pk).get_result() scim_sync.send(self.provider.pk).get_result()

View File

@ -23,12 +23,14 @@ from authentik.sources.kerberos.models import (
Krb5ConfContext, Krb5ConfContext,
UserKerberosSourceConnection, UserKerberosSourceConnection,
) )
from authentik.tasks.models import Task
class KerberosSync: class KerberosSync:
"""Sync Kerberos users into authentik""" """Sync Kerberos users into authentik"""
_source: KerberosSource _source: KerberosSource
_task: Task
_logger: BoundLogger _logger: BoundLogger
_connection: KAdmin _connection: KAdmin
mapper: SourceMapper mapper: SourceMapper
@ -36,11 +38,11 @@ class KerberosSync:
group_manager: PropertyMappingManager group_manager: PropertyMappingManager
matcher: SourceMatcher matcher: SourceMatcher
def __init__(self, source: KerberosSource): def __init__(self, source: KerberosSource, task: Task):
self._source = source self._source = source
self._task = task
with Krb5ConfContext(self._source): with Krb5ConfContext(self._source):
self._connection = self._source.connection() self._connection = self._source.connection()
self._messages = []
self._logger = get_logger().bind(source=self._source, syncer=self.__class__.__name__) self._logger = get_logger().bind(source=self._source, syncer=self.__class__.__name__)
self.mapper = SourceMapper(self._source) self.mapper = SourceMapper(self._source)
self.user_manager = self.mapper.get_manager(User, ["principal", "principal_obj"]) self.user_manager = self.mapper.get_manager(User, ["principal", "principal_obj"])
@ -56,17 +58,6 @@ class KerberosSync:
"""UI name for the type of object this class synchronizes""" """UI name for the type of object this class synchronizes"""
return "users" return "users"
@property
def messages(self) -> list[str]:
"""Get all UI messages"""
return self._messages
def message(self, *args, **kwargs):
"""Add message that is later added to the System Task and shown to the user"""
formatted_message = " ".join(args)
self._messages.append(formatted_message)
self._logger.warning(*args, **kwargs)
def _handle_principal(self, principal: str) -> bool: def _handle_principal(self, principal: str) -> bool:
try: try:
# TODO: handle permission error # TODO: handle permission error
@ -163,7 +154,7 @@ class KerberosSync:
def sync(self) -> int: def sync(self) -> int:
"""Iterate over all Kerberos users and create authentik_core.User instances""" """Iterate over all Kerberos users and create authentik_core.User instances"""
if not self._source.enabled or not self._source.sync_users: if not self._source.enabled or not self._source.sync_users:
self.message("Source is disabled or user syncing is disabled for this Source") self._task.info("Source is disabled or user syncing is disabled for this Source")
return -1 return -1
user_count = 0 user_count = 0

View File

@ -43,7 +43,6 @@ def kerberos_sync(pk: str):
return return
syncer = KerberosSync(source) syncer = KerberosSync(source)
syncer.sync() syncer.sync()
self.info(*syncer.messages)
except StopSync as exc: except StopSync as exc:
LOGGER.warning(exception_to_string(exc)) LOGGER.warning(exception_to_string(exc))
self.error(exc) self.error(exc)

View File

@ -10,22 +10,23 @@ from authentik.core.sources.mapper import SourceMapper
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
from authentik.lib.sync.mapper import PropertyMappingManager from authentik.lib.sync.mapper import PropertyMappingManager
from authentik.sources.ldap.models import LDAPSource, flatten from authentik.sources.ldap.models import LDAPSource, flatten
from authentik.tasks.models import Task
class BaseLDAPSynchronizer: class BaseLDAPSynchronizer:
"""Sync LDAP Users and groups into authentik""" """Sync LDAP Users and groups into authentik"""
_source: LDAPSource _source: LDAPSource
_task: Task
_logger: BoundLogger _logger: BoundLogger
_connection: Connection _connection: Connection
_messages: list[str]
mapper: SourceMapper mapper: SourceMapper
manager: PropertyMappingManager manager: PropertyMappingManager
def __init__(self, source: LDAPSource): def __init__(self, source: LDAPSource, task: Task):
self._source = source self._source = source
self._task = task
self._connection = source.connection() self._connection = source.connection()
self._messages = []
self._logger = get_logger().bind(source=source, syncer=self.__class__.__name__) self._logger = get_logger().bind(source=source, syncer=self.__class__.__name__)
@staticmethod @staticmethod
@ -46,11 +47,6 @@ class BaseLDAPSynchronizer:
"""Sync function, implemented in subclass""" """Sync function, implemented in subclass"""
raise NotImplementedError() raise NotImplementedError()
@property
def messages(self) -> list[str]:
"""Get all UI messages"""
return self._messages
@property @property
def base_dn_users(self) -> str: def base_dn_users(self) -> str:
"""Shortcut to get full base_dn for user lookups""" """Shortcut to get full base_dn for user lookups"""
@ -65,14 +61,6 @@ class BaseLDAPSynchronizer:
return f"{self._source.additional_group_dn},{self._source.base_dn}" return f"{self._source.additional_group_dn},{self._source.base_dn}"
return self._source.base_dn return self._source.base_dn
def message(self, *args, **kwargs):
"""Add message that is later added to the System Task and shown to the user"""
formatted_message = " ".join(args)
if "dn" in kwargs:
formatted_message += f"; DN: {kwargs['dn']}"
self._messages.append(formatted_message)
self._logger.warning(*args, **kwargs)
def get_objects(self, **kwargs) -> Generator: def get_objects(self, **kwargs) -> Generator:
"""Get objects from LDAP, implemented in subclass""" """Get objects from LDAP, implemented in subclass"""
raise NotImplementedError() raise NotImplementedError()

View File

@ -19,7 +19,7 @@ class GroupLDAPForwardDeletion(BaseLDAPSynchronizer):
def get_objects(self, **kwargs) -> Generator: def get_objects(self, **kwargs) -> Generator:
if not self._source.sync_groups or not self._source.delete_not_found_objects: if not self._source.sync_groups or not self._source.delete_not_found_objects:
self.message("Group syncing is disabled for this Source") self._task.info("Group syncing is disabled for this Source")
return iter(()) return iter(())
uuid = uuid4() uuid = uuid4()
@ -54,7 +54,7 @@ class GroupLDAPForwardDeletion(BaseLDAPSynchronizer):
def sync(self, group_pks: tuple) -> int: def sync(self, group_pks: tuple) -> int:
"""Delete authentik groups""" """Delete authentik groups"""
if not self._source.sync_groups or not self._source.delete_not_found_objects: if not self._source.sync_groups or not self._source.delete_not_found_objects:
self.message("Group syncing is disabled for this Source") self._task.info("Group syncing is disabled for this Source")
return -1 return -1
self._logger.debug("Deleting groups", group_pks=group_pks) self._logger.debug("Deleting groups", group_pks=group_pks)
_, deleted_per_type = Group.objects.filter(pk__in=group_pks).delete() _, deleted_per_type = Group.objects.filter(pk__in=group_pks).delete()

View File

@ -21,7 +21,7 @@ class UserLDAPForwardDeletion(BaseLDAPSynchronizer):
def get_objects(self, **kwargs) -> Generator: def get_objects(self, **kwargs) -> Generator:
if not self._source.sync_users or not self._source.delete_not_found_objects: if not self._source.sync_users or not self._source.delete_not_found_objects:
self.message("User syncing is disabled for this Source") self._task.info("User syncing is disabled for this Source")
return iter(()) return iter(())
uuid = uuid4() uuid = uuid4()
@ -56,7 +56,7 @@ class UserLDAPForwardDeletion(BaseLDAPSynchronizer):
def sync(self, user_pks: tuple) -> int: def sync(self, user_pks: tuple) -> int:
"""Delete authentik users""" """Delete authentik users"""
if not self._source.sync_users or not self._source.delete_not_found_objects: if not self._source.sync_users or not self._source.delete_not_found_objects:
self.message("User syncing is disabled for this Source") self._task.info("User syncing is disabled for this Source")
return -1 return -1
self._logger.debug("Deleting users", user_pks=user_pks) self._logger.debug("Deleting users", user_pks=user_pks)
_, deleted_per_type = User.objects.filter(pk__in=user_pks).delete() _, deleted_per_type = User.objects.filter(pk__in=user_pks).delete()

View File

@ -37,7 +37,7 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer):
def get_objects(self, **kwargs) -> Generator: def get_objects(self, **kwargs) -> Generator:
if not self._source.sync_groups: if not self._source.sync_groups:
self.message("Group syncing is disabled for this Source") self._task.info("Group syncing is disabled for this Source")
return iter(()) return iter(())
return self.search_paginator( return self.search_paginator(
search_base=self.base_dn_groups, search_base=self.base_dn_groups,
@ -54,7 +54,7 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer):
def sync(self, page_data: list) -> int: def sync(self, page_data: list) -> int:
"""Iterate over all LDAP Groups and create authentik_core.Group instances""" """Iterate over all LDAP Groups and create authentik_core.Group instances"""
if not self._source.sync_groups: if not self._source.sync_groups:
self.message("Group syncing is disabled for this Source") self._task.info("Group syncing is disabled for this Source")
return -1 return -1
group_count = 0 group_count = 0
for group in page_data: for group in page_data:
@ -62,7 +62,7 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer):
continue continue
group_dn = flatten(flatten(group.get("entryDN", group.get("dn")))) group_dn = flatten(flatten(group.get("entryDN", group.get("dn"))))
if not (uniq := self.get_identifier(attributes)): if not (uniq := self.get_identifier(attributes)):
self.message( self._task.info(
f"Uniqueness field not found/not set in attributes: '{group_dn}'", f"Uniqueness field not found/not set in attributes: '{group_dn}'",
attributes=attributes.keys(), attributes=attributes.keys(),
dn=group_dn, dn=group_dn,

View File

@ -26,7 +26,7 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer):
def get_objects(self, **kwargs) -> Generator: def get_objects(self, **kwargs) -> Generator:
if not self._source.sync_groups: if not self._source.sync_groups:
self.message("Group syncing is disabled for this Source") self._task.info("Group syncing is disabled for this Source")
return iter(()) return iter(())
# If we are looking up groups from users, we don't need to fetch the group membership field # If we are looking up groups from users, we don't need to fetch the group membership field
@ -45,7 +45,7 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer):
def sync(self, page_data: list) -> int: def sync(self, page_data: list) -> int:
"""Iterate over all Users and assign Groups using memberOf Field""" """Iterate over all Users and assign Groups using memberOf Field"""
if not self._source.sync_groups: if not self._source.sync_groups:
self.message("Group syncing is disabled for this Source") self._task.info("Group syncing is disabled for this Source")
return -1 return -1
membership_count = 0 membership_count = 0
for group in page_data: for group in page_data:
@ -94,7 +94,7 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer):
# group_uniq might be a single string or an array with (hopefully) a single string # group_uniq might be a single string or an array with (hopefully) a single string
if isinstance(group_uniq, list): if isinstance(group_uniq, list):
if len(group_uniq) < 1: if len(group_uniq) < 1:
self.message( self._task.info(
f"Group does not have a uniqueness attribute: '{group_dn}'", f"Group does not have a uniqueness attribute: '{group_dn}'",
group=group_dn, group=group_dn,
) )
@ -104,7 +104,7 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer):
groups = Group.objects.filter(**{f"attributes__{LDAP_UNIQUENESS}": group_uniq}) groups = Group.objects.filter(**{f"attributes__{LDAP_UNIQUENESS}": group_uniq})
if not groups.exists(): if not groups.exists():
if self._source.sync_groups: if self._source.sync_groups:
self.message( self._task.info(
f"Group does not exist in our DB yet, run sync_groups first: '{group_dn}'", f"Group does not exist in our DB yet, run sync_groups first: '{group_dn}'",
group=group_dn, group=group_dn,
) )

View File

@ -39,7 +39,7 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer):
def get_objects(self, **kwargs) -> Generator: def get_objects(self, **kwargs) -> Generator:
if not self._source.sync_users: if not self._source.sync_users:
self.message("User syncing is disabled for this Source") self._task.info("User syncing is disabled for this Source")
return iter(()) return iter(())
return self.search_paginator( return self.search_paginator(
search_base=self.base_dn_users, search_base=self.base_dn_users,
@ -56,7 +56,7 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer):
def sync(self, page_data: list) -> int: def sync(self, page_data: list) -> int:
"""Iterate over all LDAP Users and create authentik_core.User instances""" """Iterate over all LDAP Users and create authentik_core.User instances"""
if not self._source.sync_users: if not self._source.sync_users:
self.message("User syncing is disabled for this Source") self._task.info("User syncing is disabled for this Source")
return -1 return -1
user_count = 0 user_count = 0
for user in page_data: for user in page_data:
@ -64,7 +64,7 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer):
continue continue
user_dn = flatten(user.get("entryDN", user.get("dn"))) user_dn = flatten(user.get("entryDN", user.get("dn")))
if not (uniq := self.get_identifier(attributes)): if not (uniq := self.get_identifier(attributes)):
self.message( self._task.info(
f"Uniqueness field not found/not set in attributes: '{user_dn}'", f"Uniqueness field not found/not set in attributes: '{user_dn}'",
attributes=attributes.keys(), attributes=attributes.keys(),
dn=user_dn, dn=user_dn,

View File

@ -30,7 +30,7 @@ class FreeIPA(BaseLDAPSynchronizer):
pwd_last_set: datetime = attributes.get("krbLastPwdChange", datetime.now()) pwd_last_set: datetime = attributes.get("krbLastPwdChange", datetime.now())
pwd_last_set = pwd_last_set.replace(tzinfo=UTC) pwd_last_set = pwd_last_set.replace(tzinfo=UTC)
if created or pwd_last_set >= user.password_change_date: if created or pwd_last_set >= user.password_change_date:
self.message(f"'{user.username}': Reset user's password") self._task.info(f"'{user.username}': Reset user's password")
self._logger.debug( self._logger.debug(
"Reset user's password", "Reset user's password",
user=user.username, user=user.username,

View File

@ -60,7 +60,7 @@ class MicrosoftActiveDirectory(BaseLDAPSynchronizer):
pwd_last_set: datetime = attributes.get("pwdLastSet", datetime.now()) pwd_last_set: datetime = attributes.get("pwdLastSet", datetime.now())
pwd_last_set = pwd_last_set.replace(tzinfo=UTC) pwd_last_set = pwd_last_set.replace(tzinfo=UTC)
if created or pwd_last_set >= user.password_change_date: if created or pwd_last_set >= user.password_change_date:
self.message(f"'{user.username}': Reset user's password") self._task.info(f"'{user.username}': Reset user's password")
self._logger.debug( self._logger.debug(
"Reset user's password", "Reset user's password",
user=user.username, user=user.username,

View File

@ -72,14 +72,12 @@ def ldap_sync(source_pk: str):
) )
# User and group sync can happen at once, they have no dependencies on each other # User and group sync can happen at once, they have no dependencies on each other
user_group_tasks.run().get_results( user_group_tasks.run().wait(
block=True, timeout=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours") * 1000
timeout=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours") * 1000,
) )
# Membership sync needs to run afterwards # Membership sync needs to run afterwards
membership_tasks.run().get_results( membership_tasks.run().wait(
block=True, timeout=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours") * 1000
timeout=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours") * 1000,
) )
# Finally, deletions. What we'd really like to do here is something like # Finally, deletions. What we'd really like to do here is something like
# ``` # ```
@ -96,7 +94,9 @@ def ldap_sync(source_pk: str):
# large chunks, and only queue the deletion step afterwards. # large chunks, and only queue the deletion step afterwards.
# 3. Delete every unmarked item. This is slow, so we spread it over many tasks in # 3. Delete every unmarked item. This is slow, so we spread it over many tasks in
# small chunks. # small chunks.
deletion_tasks.run() # no need to block here, we don't have anything else to do afterwards deletion_tasks.run().wait(
timeout=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours") * 1000,
)
def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) -> list[Message]: def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) -> list[Message]:
@ -139,9 +139,7 @@ def ldap_sync_page(source_pk: str, sync_class: str, page_cache_key: str):
return return
cache.touch(page_cache_key) cache.touch(page_cache_key)
count = sync_inst.sync(page) count = sync_inst.sync(page)
messages = sync_inst.messages self.info(f"Synced {count} objects.")
messages.append(f"Synced {count} objects.")
self.info(*messages)
cache.delete(page_cache_key) cache.delete(page_cache_key)
except (LDAPException, StopSync) as exc: except (LDAPException, StopSync) as exc:
# No explicit event is created here as .set_status with an error will do that # No explicit event is created here as .set_status with an error will do that

View File

@ -13,6 +13,7 @@ from authentik.sources.ldap.models import LDAPSource, LDAPSourcePropertyMapping
from authentik.sources.ldap.sync.users import UserLDAPSynchronizer from authentik.sources.ldap.sync.users import UserLDAPSynchronizer
from authentik.sources.ldap.tests.mock_ad import mock_ad_connection from authentik.sources.ldap.tests.mock_ad import mock_ad_connection
from authentik.sources.ldap.tests.mock_slapd import mock_slapd_connection from authentik.sources.ldap.tests.mock_slapd import mock_slapd_connection
from authentik.tasks.models import Task
LDAP_PASSWORD = generate_key() LDAP_PASSWORD = generate_key()
@ -43,7 +44,7 @@ class LDAPSyncTests(TestCase):
raw_conn.bind = bind_mock raw_conn.bind = bind_mock
connection = MagicMock(return_value=raw_conn) connection = MagicMock(return_value=raw_conn)
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source) user_sync = UserLDAPSynchronizer(self.source, Task())
user_sync.sync_full() user_sync.sync_full()
user = User.objects.get(username="user0_sn") user = User.objects.get(username="user0_sn")
@ -71,7 +72,7 @@ class LDAPSyncTests(TestCase):
) )
connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD)) connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source) user_sync = UserLDAPSynchronizer(self.source, Task())
user_sync.sync_full() user_sync.sync_full()
user = User.objects.get(username="user0_sn") user = User.objects.get(username="user0_sn")
@ -98,7 +99,7 @@ class LDAPSyncTests(TestCase):
self.source.save() self.source.save()
connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source) user_sync = UserLDAPSynchronizer(self.source, Task())
user_sync.sync_full() user_sync.sync_full()
user = User.objects.get(username="user0_sn") user = User.objects.get(username="user0_sn")

View File

@ -23,7 +23,7 @@ from authentik.sources.ldap.sync.forward_delete_users import DELETE_CHUNK_SIZE
from authentik.sources.ldap.sync.groups import GroupLDAPSynchronizer from authentik.sources.ldap.sync.groups import GroupLDAPSynchronizer
from authentik.sources.ldap.sync.membership import MembershipLDAPSynchronizer from authentik.sources.ldap.sync.membership import MembershipLDAPSynchronizer
from authentik.sources.ldap.sync.users import UserLDAPSynchronizer from authentik.sources.ldap.sync.users import UserLDAPSynchronizer
from authentik.sources.ldap.tasks import ldap_sync, ldap_sync_all, ldap_sync_page from authentik.sources.ldap.tasks import ldap_sync, ldap_sync_page
from authentik.sources.ldap.tests.mock_ad import mock_ad_connection from authentik.sources.ldap.tests.mock_ad import mock_ad_connection
from authentik.sources.ldap.tests.mock_freeipa import mock_freeipa_connection from authentik.sources.ldap.tests.mock_freeipa import mock_freeipa_connection
from authentik.sources.ldap.tests.mock_slapd import ( from authentik.sources.ldap.tests.mock_slapd import (
@ -33,6 +33,7 @@ from authentik.sources.ldap.tests.mock_slapd import (
user_in_slapd_cn, user_in_slapd_cn,
user_in_slapd_uid, user_in_slapd_uid,
) )
from authentik.tasks.models import Task
LDAP_PASSWORD = generate_key() LDAP_PASSWORD = generate_key()
@ -74,7 +75,7 @@ class LDAPSyncTests(TestCase):
self.source.save() self.source.save()
connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD)) connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source) user_sync = UserLDAPSynchronizer(self.source, Task())
with self.assertRaises(StopSync): with self.assertRaises(StopSync):
user_sync.sync_full() user_sync.sync_full()
self.assertFalse(User.objects.filter(username="user0_sn").exists()) self.assertFalse(User.objects.filter(username="user0_sn").exists())
@ -105,7 +106,7 @@ class LDAPSyncTests(TestCase):
# we basically just test that the mappings don't throw errors # we basically just test that the mappings don't throw errors
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source) user_sync = UserLDAPSynchronizer(self.source, Task())
user_sync.sync_full() user_sync.sync_full()
def test_sync_users_ad(self): def test_sync_users_ad(self):
@ -133,7 +134,7 @@ class LDAPSyncTests(TestCase):
) )
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source) user_sync = UserLDAPSynchronizer(self.source, Task())
user_sync.sync_full() user_sync.sync_full()
user = User.objects.filter(username="user0_sn").first() user = User.objects.filter(username="user0_sn").first()
self.assertEqual(user.attributes["foo"], "bar") self.assertEqual(user.attributes["foo"], "bar")
@ -152,7 +153,7 @@ class LDAPSyncTests(TestCase):
) )
connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source) user_sync = UserLDAPSynchronizer(self.source, Task())
user_sync.sync_full() user_sync.sync_full()
self.assertTrue(User.objects.filter(username="user0_sn").exists()) self.assertTrue(User.objects.filter(username="user0_sn").exists())
self.assertFalse(User.objects.filter(username="user1_sn").exists()) self.assertFalse(User.objects.filter(username="user1_sn").exists())
@ -168,7 +169,7 @@ class LDAPSyncTests(TestCase):
) )
connection = MagicMock(return_value=mock_freeipa_connection(LDAP_PASSWORD)) connection = MagicMock(return_value=mock_freeipa_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source) user_sync = UserLDAPSynchronizer(self.source, Task())
user_sync.sync_full() user_sync.sync_full()
self.assertTrue(User.objects.filter(username="user0_sn").exists()) self.assertTrue(User.objects.filter(username="user0_sn").exists())
self.assertFalse(User.objects.filter(username="user1_sn").exists()) self.assertFalse(User.objects.filter(username="user1_sn").exists())
@ -193,11 +194,11 @@ class LDAPSyncTests(TestCase):
) )
connection = MagicMock(return_value=mock_freeipa_connection(LDAP_PASSWORD)) connection = MagicMock(return_value=mock_freeipa_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source) user_sync = UserLDAPSynchronizer(self.source, Task())
user_sync.sync_full() user_sync.sync_full()
group_sync = GroupLDAPSynchronizer(self.source) group_sync = GroupLDAPSynchronizer(self.source, Task())
group_sync.sync_full() group_sync.sync_full()
membership_sync = MembershipLDAPSynchronizer(self.source) membership_sync = MembershipLDAPSynchronizer(self.source, Task())
membership_sync.sync_full() membership_sync.sync_full()
self.assertTrue( self.assertTrue(
@ -230,9 +231,9 @@ class LDAPSyncTests(TestCase):
parent_group = Group.objects.get(name=_user.username) parent_group = Group.objects.get(name=_user.username)
self.source.sync_parent_group = parent_group self.source.sync_parent_group = parent_group
self.source.save() self.source.save()
group_sync = GroupLDAPSynchronizer(self.source) group_sync = GroupLDAPSynchronizer(self.source, Task())
group_sync.sync_full() group_sync.sync_full()
membership_sync = MembershipLDAPSynchronizer(self.source) membership_sync = MembershipLDAPSynchronizer(self.source, Task())
membership_sync.sync_full() membership_sync.sync_full()
group: Group = Group.objects.filter(name="test-group").first() group: Group = Group.objects.filter(name="test-group").first()
self.assertIsNotNone(group) self.assertIsNotNone(group)
@ -256,9 +257,9 @@ class LDAPSyncTests(TestCase):
connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
self.source.save() self.source.save()
group_sync = GroupLDAPSynchronizer(self.source) group_sync = GroupLDAPSynchronizer(self.source, Task())
group_sync.sync_full() group_sync.sync_full()
membership_sync = MembershipLDAPSynchronizer(self.source) membership_sync = MembershipLDAPSynchronizer(self.source, Task())
membership_sync.sync_full() membership_sync.sync_full()
group = Group.objects.filter(name="group1") group = Group.objects.filter(name="group1")
self.assertTrue(group.exists()) self.assertTrue(group.exists())
@ -290,11 +291,11 @@ class LDAPSyncTests(TestCase):
connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
self.source.save() self.source.save()
user_sync = UserLDAPSynchronizer(self.source) user_sync = UserLDAPSynchronizer(self.source, Task())
user_sync.sync_full() user_sync.sync_full()
group_sync = GroupLDAPSynchronizer(self.source) group_sync = GroupLDAPSynchronizer(self.source, Task())
group_sync.sync_full() group_sync.sync_full()
membership_sync = MembershipLDAPSynchronizer(self.source) membership_sync = MembershipLDAPSynchronizer(self.source, Task())
membership_sync.sync_full() membership_sync.sync_full()
# Test if membership mapping based on memberUid works. # Test if membership mapping based on memberUid works.
posix_group = Group.objects.filter(name="group-posix").first() posix_group = Group.objects.filter(name="group-posix").first()
@ -327,11 +328,11 @@ class LDAPSyncTests(TestCase):
connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
self.source.save() self.source.save()
user_sync = UserLDAPSynchronizer(self.source) user_sync = UserLDAPSynchronizer(self.source, Task())
user_sync.sync_full() user_sync.sync_full()
group_sync = GroupLDAPSynchronizer(self.source) group_sync = GroupLDAPSynchronizer(self.source, Task())
group_sync.sync_full() group_sync.sync_full()
membership_sync = MembershipLDAPSynchronizer(self.source) membership_sync = MembershipLDAPSynchronizer(self.source, Task())
membership_sync.sync_full() membership_sync.sync_full()
# Test if membership mapping based on memberUid works. # Test if membership mapping based on memberUid works.
posix_group = Group.objects.filter(name="group-posix").first() posix_group = Group.objects.filter(name="group-posix").first()

View File

@ -112,23 +112,20 @@ class Task(SerializerModel):
if save: if save:
self.save() self.save()
def log(self, status: TaskStatus, *messages: str | LogEvent | Exception, save: bool = False): def log(self, status: TaskStatus, message: str | Exception, save: bool = False, **attributes):
self.messages: list self.messages: list
for msg in messages: if isinstance(message, Exception):
message = msg message = exception_to_string(message)
if isinstance(message, Exception): message = LogEvent(message, logger=self.uid, log_level=status.value, attributes=attributes)
message = exception_to_string(message) self.messages.append(sanitize_item(message))
if not isinstance(message, LogEvent):
message = LogEvent(message, logger=self.uid, log_level=status.value)
self.messages.append(sanitize_item(message))
if save: if save:
self.save() self.save()
def info(self, *messages: str | LogEvent | Exception, save: bool = False): def info(self, message: str | Exception, save: bool = False, **attributes):
self.log(TaskStatus.INFO, *messages, save=save) self.log(TaskStatus.INFO, message, save=save, **attributes)
def warning(self, *messages: str | LogEvent | Exception, save: bool = False): def warning(self, message: str | Exception, save: bool = False, **attributes):
self.log(TaskStatus.WARNING, *messages, save=save) self.log(TaskStatus.WARNING, message, save=save, **attributes)
def error(self, *messages: str | LogEvent | Exception, save: bool = False): def error(self, message: str | Exception, save: bool = False, **attributes):
self.log(TaskStatus.ERROR, *messages, save=save) self.log(TaskStatus.ERROR, message, save=save, **attributes)