uhhhhhhhhh
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
@ -45,11 +45,17 @@ class OutgoingSyncProvider(ScheduledModel, Model):
|
||||
def get_object_qs[T: User | Group](self, type: type[T]) -> QuerySet[T]:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_paginator[T: User | Group](self, type: type[T]) -> Paginator:
|
||||
return Paginator(self.get_object_qs(type), PAGE_SIZE)
|
||||
|
||||
def get_object_sync_time_limit[T: User | Group](self, type: type[T]) -> int:
|
||||
num_pages: int = self.get_paginator(type).num_pages
|
||||
return int(num_pages * PAGE_TIMEOUT * 1.5) * 1000
|
||||
|
||||
def get_sync_time_limit(self) -> int:
|
||||
users_paginator = Paginator(self.get_object_qs(User), PAGE_SIZE)
|
||||
groups_paginator = Paginator(self.get_object_qs(Group), PAGE_SIZE)
|
||||
time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT * 1.5
|
||||
return int(time_limit)
|
||||
return int(
|
||||
self.get_object_sync_time_limit(User) + self.get_object_sync_time_limit(Group) * 1.5
|
||||
)
|
||||
|
||||
@property
|
||||
def sync_lock(self) -> pglock.advisory:
|
||||
@ -72,7 +78,7 @@ class OutgoingSyncProvider(ScheduledModel, Model):
|
||||
uid=self.pk,
|
||||
args=(self.pk,),
|
||||
options={
|
||||
"time_limit": self.get_sync_time_limit() * 1000,
|
||||
"time_limit": self.get_sync_time_limit(),
|
||||
},
|
||||
send_on_save=True,
|
||||
crontab=f"{fqdn_rand(self.pk)} */4 * * *",
|
||||
|
@ -1,21 +1,19 @@
|
||||
from collections.abc import Callable
|
||||
from dataclasses import asdict
|
||||
|
||||
from celery import group
|
||||
from celery.exceptions import Retry
|
||||
from celery.result import allow_join_result
|
||||
from dramatiq.composition import group
|
||||
from django.core.paginator import Paginator
|
||||
from django.db.models import Model, QuerySet
|
||||
from django.db.models.query import Q
|
||||
from django.utils.text import slugify
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from dramatiq.actor import Actor
|
||||
from dramatiq.errors import Retry
|
||||
from structlog.stdlib import BoundLogger, get_logger
|
||||
|
||||
from authentik.core.expression.exceptions import SkipObjectException
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.events.logs import LogEvent
|
||||
from authentik.events.models import TaskStatus
|
||||
from authentik.events.system_tasks import SystemTask
|
||||
from authentik.events.utils import sanitize_item
|
||||
from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT
|
||||
from authentik.lib.sync.outgoing.base import Direction
|
||||
@ -27,11 +25,13 @@ from authentik.lib.sync.outgoing.exceptions import (
|
||||
)
|
||||
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
|
||||
from authentik.lib.utils.reflection import class_to_path, path_to_class
|
||||
from authentik.tasks.middleware import CurrentTask
|
||||
from authentik.tasks.models import Task
|
||||
|
||||
|
||||
class SyncTasks:
|
||||
"""Container for all sync 'tasks' (this class doesn't actually contain celery
|
||||
tasks due to celery's magic, however exposes a number of functions to be called from tasks)"""
|
||||
"""Container for all sync 'tasks' (this class doesn't actually contain
|
||||
tasks due to dramatiq's magic, however exposes a number of functions to be called from tasks)"""
|
||||
|
||||
logger: BoundLogger
|
||||
|
||||
@ -39,107 +39,97 @@ class SyncTasks:
|
||||
super().__init__()
|
||||
self._provider_model = provider_model
|
||||
|
||||
def sync_all(self, single_sync: Callable[[int], None]):
|
||||
for provider in self._provider_model.objects.filter(
|
||||
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
|
||||
):
|
||||
self.trigger_single_task(provider, single_sync)
|
||||
|
||||
def trigger_single_task(self, provider: OutgoingSyncProvider, sync_task: Callable[[int], None]):
|
||||
"""Wrapper single sync task that correctly sets time limits based
|
||||
on the amount of objects that will be synced"""
|
||||
users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE)
|
||||
groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE)
|
||||
soft_time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT
|
||||
time_limit = soft_time_limit * 1.5
|
||||
return sync_task.apply_async(
|
||||
(provider.pk,), time_limit=int(time_limit), soft_time_limit=int(soft_time_limit)
|
||||
)
|
||||
|
||||
def sync_single(
|
||||
def sync_paginator(
|
||||
self,
|
||||
task: SystemTask,
|
||||
provider_pk: int,
|
||||
sync_objects: Callable[[int, int], list[str]],
|
||||
current_task: Task,
|
||||
provider: OutgoingSyncProvider,
|
||||
sync_objects: Actor,
|
||||
paginator: Paginator,
|
||||
object_type: type[User | Group],
|
||||
**options,
|
||||
):
|
||||
tasks = []
|
||||
for page in paginator.page_range:
|
||||
page_sync = sync_objects.message_with_options(
|
||||
args=(class_to_path(object_type), page, provider.pk),
|
||||
time_limit=PAGE_TIMEOUT * 1000,
|
||||
# Assign tasks to the same schedule as the current one
|
||||
rel_obj=current_task.rel_obj,
|
||||
**options,
|
||||
)
|
||||
tasks.append(page_sync)
|
||||
return tasks
|
||||
|
||||
def sync(
|
||||
self,
|
||||
provider_pk: int,
|
||||
sync_objects: Actor,
|
||||
):
|
||||
task = CurrentTask.get_task()
|
||||
self.logger = get_logger().bind(
|
||||
provider_type=class_to_path(self._provider_model),
|
||||
provider_pk=provider_pk,
|
||||
)
|
||||
provider = self._provider_model.objects.filter(
|
||||
provider: OutgoingSyncProvider = self._provider_model.objects.filter(
|
||||
Q(backchannel_application__isnull=False) | Q(application__isnull=False),
|
||||
pk=provider_pk,
|
||||
).first()
|
||||
if not provider:
|
||||
return
|
||||
task.set_uid(slugify(provider.name))
|
||||
messages = []
|
||||
messages.append(_("Starting full provider sync"))
|
||||
task.info("Starting full provider sync")
|
||||
self.logger.debug("Starting provider sync")
|
||||
users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE)
|
||||
groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE)
|
||||
with allow_join_result(), provider.sync_lock as lock_acquired:
|
||||
with provider.sync_lock as lock_acquired:
|
||||
if not lock_acquired:
|
||||
self.logger.debug("Failed to acquire sync lock, skipping", provider=provider.name)
|
||||
return
|
||||
try:
|
||||
messages.append(_("Syncing users"))
|
||||
user_results = (
|
||||
group(
|
||||
[
|
||||
sync_objects.signature(
|
||||
args=(class_to_path(User), page, provider_pk),
|
||||
time_limit=PAGE_TIMEOUT,
|
||||
soft_time_limit=PAGE_TIMEOUT,
|
||||
)
|
||||
for page in users_paginator.page_range
|
||||
]
|
||||
users_tasks = group(
|
||||
self.sync_paginator(
|
||||
current_task=task,
|
||||
provider=provider,
|
||||
sync_objects=sync_objects,
|
||||
paginator=provider.get_paginator(User),
|
||||
object_type=User,
|
||||
)
|
||||
.apply_async()
|
||||
.get()
|
||||
)
|
||||
for result in user_results:
|
||||
for msg in result:
|
||||
messages.append(LogEvent(**msg))
|
||||
messages.append(_("Syncing groups"))
|
||||
group_results = (
|
||||
group(
|
||||
[
|
||||
sync_objects.signature(
|
||||
args=(class_to_path(Group), page, provider_pk),
|
||||
time_limit=PAGE_TIMEOUT,
|
||||
soft_time_limit=PAGE_TIMEOUT,
|
||||
)
|
||||
for page in groups_paginator.page_range
|
||||
]
|
||||
group_tasks = group(
|
||||
self.sync_paginator(
|
||||
current_task=task,
|
||||
provider=provider,
|
||||
sync_objects=sync_objects,
|
||||
paginator=provider.get_paginator(Group),
|
||||
object_type=Group,
|
||||
)
|
||||
.apply_async()
|
||||
.get()
|
||||
)
|
||||
for result in group_results:
|
||||
for msg in result:
|
||||
messages.append(LogEvent(**msg))
|
||||
users_tasks.run().wait(timeout=provider.get_object_sync_time_limit(User))
|
||||
group_tasks.run().wait(timeout=provider.get_object_sync_time_limit(Group))
|
||||
except TransientSyncException as exc:
|
||||
self.logger.warning("transient sync exception", exc=exc)
|
||||
raise task.retry(exc=exc) from exc
|
||||
raise Retry from exc
|
||||
except StopSync as exc:
|
||||
task.set_error(exc)
|
||||
task.error(exc)
|
||||
return
|
||||
task.set_status(TaskStatus.SUCCESSFUL, *messages)
|
||||
|
||||
def sync_objects(
|
||||
self, object_type: str, page: int, provider_pk: int, override_dry_run=False, **filter
|
||||
self,
|
||||
object_type: str,
|
||||
page: int,
|
||||
provider_pk: int,
|
||||
override_dry_run=False,
|
||||
**filter,
|
||||
):
|
||||
task = CurrentTask.get_task()
|
||||
_object_type: type[Model] = path_to_class(object_type)
|
||||
self.logger = get_logger().bind(
|
||||
provider_type=class_to_path(self._provider_model),
|
||||
provider_pk=provider_pk,
|
||||
object_type=object_type,
|
||||
)
|
||||
messages = []
|
||||
task.info(f"Syncing page {page} of {_object_type._meta.verbose_name_plural}")
|
||||
provider = self._provider_model.objects.filter(pk=provider_pk).first()
|
||||
if not provider:
|
||||
return messages
|
||||
return
|
||||
# Override dry run mode if requested, however don't save the provider
|
||||
# so that scheduled sync tasks still run in dry_run mode
|
||||
if override_dry_run:
|
||||
@ -147,25 +137,13 @@ class SyncTasks:
|
||||
try:
|
||||
client = provider.client_for_model(_object_type)
|
||||
except TransientSyncException:
|
||||
return messages
|
||||
return
|
||||
paginator = Paginator(provider.get_object_qs(_object_type).filter(**filter), PAGE_SIZE)
|
||||
if client.can_discover:
|
||||
self.logger.debug("starting discover")
|
||||
client.discover()
|
||||
self.logger.debug("starting sync for page", page=page)
|
||||
messages.append(
|
||||
asdict(
|
||||
LogEvent(
|
||||
_(
|
||||
"Syncing page {page} of {object_type}".format(
|
||||
page=page, object_type=_object_type._meta.verbose_name_plural
|
||||
)
|
||||
),
|
||||
log_level="info",
|
||||
logger=f"{provider._meta.verbose_name}@{object_type}",
|
||||
)
|
||||
)
|
||||
)
|
||||
task.info(f"Syncing page {page} or {_object_type._meta.verbose_name_plural}")
|
||||
for obj in paginator.page(page).object_list:
|
||||
obj: Model
|
||||
try:
|
||||
@ -174,87 +152,34 @@ class SyncTasks:
|
||||
self.logger.debug("skipping object due to SkipObject", obj=obj)
|
||||
continue
|
||||
except DryRunRejected as exc:
|
||||
messages.append(
|
||||
asdict(
|
||||
LogEvent(
|
||||
_("Dropping mutating request due to dry run"),
|
||||
log_level="info",
|
||||
logger=f"{provider._meta.verbose_name}@{object_type}",
|
||||
attributes={
|
||||
"obj": sanitize_item(obj),
|
||||
"method": exc.method,
|
||||
"url": exc.url,
|
||||
"body": exc.body,
|
||||
},
|
||||
)
|
||||
)
|
||||
task.info(
|
||||
"Dropping mutating request due to dry run",
|
||||
attributes={
|
||||
"obj": sanitize_item(obj),
|
||||
"method": exc.method,
|
||||
"url": exc.url,
|
||||
"body": exc.body,
|
||||
},
|
||||
)
|
||||
except BadRequestSyncException as exc:
|
||||
self.logger.warning("failed to sync object", exc=exc, obj=obj)
|
||||
messages.append(
|
||||
asdict(
|
||||
LogEvent(
|
||||
_(
|
||||
(
|
||||
"Failed to sync {object_type} {object_name} "
|
||||
"due to error: {error}"
|
||||
).format_map(
|
||||
{
|
||||
"object_type": obj._meta.verbose_name,
|
||||
"object_name": str(obj),
|
||||
"error": str(exc),
|
||||
}
|
||||
)
|
||||
),
|
||||
log_level="warning",
|
||||
logger=f"{provider._meta.verbose_name}@{object_type}",
|
||||
attributes={"arguments": exc.args[1:], "obj": sanitize_item(obj)},
|
||||
)
|
||||
)
|
||||
task.warning(
|
||||
f"Failed to sync {obj._meta.verbose_name} {str(obj)} due to error: {str(exc)}",
|
||||
attributes={"arguments": exc.args[1:], "obj": sanitize_item(obj)},
|
||||
)
|
||||
except TransientSyncException as exc:
|
||||
self.logger.warning("failed to sync object", exc=exc, user=obj)
|
||||
messages.append(
|
||||
asdict(
|
||||
LogEvent(
|
||||
_(
|
||||
(
|
||||
"Failed to sync {object_type} {object_name} "
|
||||
"due to transient error: {error}"
|
||||
).format_map(
|
||||
{
|
||||
"object_type": obj._meta.verbose_name,
|
||||
"object_name": str(obj),
|
||||
"error": str(exc),
|
||||
}
|
||||
)
|
||||
),
|
||||
log_level="warning",
|
||||
logger=f"{provider._meta.verbose_name}@{object_type}",
|
||||
attributes={"obj": sanitize_item(obj)},
|
||||
)
|
||||
)
|
||||
task.warning(
|
||||
f"Failed to sync {obj._meta.verbose_name} {str(obj)} due to transient error: {str(exc)}",
|
||||
attributes={"obj": sanitize_item(obj)},
|
||||
)
|
||||
except StopSync as exc:
|
||||
self.logger.warning("Stopping sync", exc=exc)
|
||||
messages.append(
|
||||
asdict(
|
||||
LogEvent(
|
||||
_(
|
||||
"Stopping sync due to error: {error}".format_map(
|
||||
{
|
||||
"error": exc.detail(),
|
||||
}
|
||||
)
|
||||
),
|
||||
log_level="warning",
|
||||
logger=f"{provider._meta.verbose_name}@{object_type}",
|
||||
attributes={"obj": sanitize_item(obj)},
|
||||
)
|
||||
)
|
||||
task.warning(
|
||||
f"Stopping sync due to error: {exc.detail()}",
|
||||
attributes={"obj": sanitize_item(obj)},
|
||||
)
|
||||
break
|
||||
return messages
|
||||
|
||||
def sync_signal_direct(self, model: str, pk: str | int, raw_op: str):
|
||||
self.logger = get_logger().bind(
|
||||
|
@ -130,7 +130,8 @@ def outpost_controller(outpost_pk: str, action: str = "up", from_cache: bool = F
|
||||
else:
|
||||
if from_cache:
|
||||
cache.delete(CACHE_KEY_OUTPOST_DOWN % outpost_pk)
|
||||
self.info(*logs)
|
||||
for log in logs:
|
||||
self.info(log)
|
||||
|
||||
|
||||
@actor
|
||||
|
@ -85,6 +85,6 @@ class SCIMClientTests(TestCase):
|
||||
self.assertEqual(mock.call_count, 1)
|
||||
self.assertEqual(mock.request_history[0].method, "GET")
|
||||
|
||||
def test_scim_sync_all(self):
|
||||
"""test scim_sync_all task"""
|
||||
def test_scim_sync(self):
|
||||
"""test scim_sync task"""
|
||||
scim_sync.send(self.provider.pk).get_result()
|
||||
|
@ -23,12 +23,14 @@ from authentik.sources.kerberos.models import (
|
||||
Krb5ConfContext,
|
||||
UserKerberosSourceConnection,
|
||||
)
|
||||
from authentik.tasks.models import Task
|
||||
|
||||
|
||||
class KerberosSync:
|
||||
"""Sync Kerberos users into authentik"""
|
||||
|
||||
_source: KerberosSource
|
||||
_task: Task
|
||||
_logger: BoundLogger
|
||||
_connection: KAdmin
|
||||
mapper: SourceMapper
|
||||
@ -36,11 +38,11 @@ class KerberosSync:
|
||||
group_manager: PropertyMappingManager
|
||||
matcher: SourceMatcher
|
||||
|
||||
def __init__(self, source: KerberosSource):
|
||||
def __init__(self, source: KerberosSource, task: Task):
|
||||
self._source = source
|
||||
self._task = task
|
||||
with Krb5ConfContext(self._source):
|
||||
self._connection = self._source.connection()
|
||||
self._messages = []
|
||||
self._logger = get_logger().bind(source=self._source, syncer=self.__class__.__name__)
|
||||
self.mapper = SourceMapper(self._source)
|
||||
self.user_manager = self.mapper.get_manager(User, ["principal", "principal_obj"])
|
||||
@ -56,17 +58,6 @@ class KerberosSync:
|
||||
"""UI name for the type of object this class synchronizes"""
|
||||
return "users"
|
||||
|
||||
@property
|
||||
def messages(self) -> list[str]:
|
||||
"""Get all UI messages"""
|
||||
return self._messages
|
||||
|
||||
def message(self, *args, **kwargs):
|
||||
"""Add message that is later added to the System Task and shown to the user"""
|
||||
formatted_message = " ".join(args)
|
||||
self._messages.append(formatted_message)
|
||||
self._logger.warning(*args, **kwargs)
|
||||
|
||||
def _handle_principal(self, principal: str) -> bool:
|
||||
try:
|
||||
# TODO: handle permission error
|
||||
@ -163,7 +154,7 @@ class KerberosSync:
|
||||
def sync(self) -> int:
|
||||
"""Iterate over all Kerberos users and create authentik_core.User instances"""
|
||||
if not self._source.enabled or not self._source.sync_users:
|
||||
self.message("Source is disabled or user syncing is disabled for this Source")
|
||||
self._task.info("Source is disabled or user syncing is disabled for this Source")
|
||||
return -1
|
||||
|
||||
user_count = 0
|
||||
|
@ -43,7 +43,6 @@ def kerberos_sync(pk: str):
|
||||
return
|
||||
syncer = KerberosSync(source)
|
||||
syncer.sync()
|
||||
self.info(*syncer.messages)
|
||||
except StopSync as exc:
|
||||
LOGGER.warning(exception_to_string(exc))
|
||||
self.error(exc)
|
||||
|
@ -10,22 +10,23 @@ from authentik.core.sources.mapper import SourceMapper
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.sync.mapper import PropertyMappingManager
|
||||
from authentik.sources.ldap.models import LDAPSource, flatten
|
||||
from authentik.tasks.models import Task
|
||||
|
||||
|
||||
class BaseLDAPSynchronizer:
|
||||
"""Sync LDAP Users and groups into authentik"""
|
||||
|
||||
_source: LDAPSource
|
||||
_task: Task
|
||||
_logger: BoundLogger
|
||||
_connection: Connection
|
||||
_messages: list[str]
|
||||
mapper: SourceMapper
|
||||
manager: PropertyMappingManager
|
||||
|
||||
def __init__(self, source: LDAPSource):
|
||||
def __init__(self, source: LDAPSource, task: Task):
|
||||
self._source = source
|
||||
self._task = task
|
||||
self._connection = source.connection()
|
||||
self._messages = []
|
||||
self._logger = get_logger().bind(source=source, syncer=self.__class__.__name__)
|
||||
|
||||
@staticmethod
|
||||
@ -46,11 +47,6 @@ class BaseLDAPSynchronizer:
|
||||
"""Sync function, implemented in subclass"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def messages(self) -> list[str]:
|
||||
"""Get all UI messages"""
|
||||
return self._messages
|
||||
|
||||
@property
|
||||
def base_dn_users(self) -> str:
|
||||
"""Shortcut to get full base_dn for user lookups"""
|
||||
@ -65,14 +61,6 @@ class BaseLDAPSynchronizer:
|
||||
return f"{self._source.additional_group_dn},{self._source.base_dn}"
|
||||
return self._source.base_dn
|
||||
|
||||
def message(self, *args, **kwargs):
|
||||
"""Add message that is later added to the System Task and shown to the user"""
|
||||
formatted_message = " ".join(args)
|
||||
if "dn" in kwargs:
|
||||
formatted_message += f"; DN: {kwargs['dn']}"
|
||||
self._messages.append(formatted_message)
|
||||
self._logger.warning(*args, **kwargs)
|
||||
|
||||
def get_objects(self, **kwargs) -> Generator:
|
||||
"""Get objects from LDAP, implemented in subclass"""
|
||||
raise NotImplementedError()
|
||||
|
@ -19,7 +19,7 @@ class GroupLDAPForwardDeletion(BaseLDAPSynchronizer):
|
||||
|
||||
def get_objects(self, **kwargs) -> Generator:
|
||||
if not self._source.sync_groups or not self._source.delete_not_found_objects:
|
||||
self.message("Group syncing is disabled for this Source")
|
||||
self._task.info("Group syncing is disabled for this Source")
|
||||
return iter(())
|
||||
|
||||
uuid = uuid4()
|
||||
@ -54,7 +54,7 @@ class GroupLDAPForwardDeletion(BaseLDAPSynchronizer):
|
||||
def sync(self, group_pks: tuple) -> int:
|
||||
"""Delete authentik groups"""
|
||||
if not self._source.sync_groups or not self._source.delete_not_found_objects:
|
||||
self.message("Group syncing is disabled for this Source")
|
||||
self._task.info("Group syncing is disabled for this Source")
|
||||
return -1
|
||||
self._logger.debug("Deleting groups", group_pks=group_pks)
|
||||
_, deleted_per_type = Group.objects.filter(pk__in=group_pks).delete()
|
||||
|
@ -21,7 +21,7 @@ class UserLDAPForwardDeletion(BaseLDAPSynchronizer):
|
||||
|
||||
def get_objects(self, **kwargs) -> Generator:
|
||||
if not self._source.sync_users or not self._source.delete_not_found_objects:
|
||||
self.message("User syncing is disabled for this Source")
|
||||
self._task.info("User syncing is disabled for this Source")
|
||||
return iter(())
|
||||
|
||||
uuid = uuid4()
|
||||
@ -56,7 +56,7 @@ class UserLDAPForwardDeletion(BaseLDAPSynchronizer):
|
||||
def sync(self, user_pks: tuple) -> int:
|
||||
"""Delete authentik users"""
|
||||
if not self._source.sync_users or not self._source.delete_not_found_objects:
|
||||
self.message("User syncing is disabled for this Source")
|
||||
self._task.info("User syncing is disabled for this Source")
|
||||
return -1
|
||||
self._logger.debug("Deleting users", user_pks=user_pks)
|
||||
_, deleted_per_type = User.objects.filter(pk__in=user_pks).delete()
|
||||
|
@ -37,7 +37,7 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer):
|
||||
|
||||
def get_objects(self, **kwargs) -> Generator:
|
||||
if not self._source.sync_groups:
|
||||
self.message("Group syncing is disabled for this Source")
|
||||
self._task.info("Group syncing is disabled for this Source")
|
||||
return iter(())
|
||||
return self.search_paginator(
|
||||
search_base=self.base_dn_groups,
|
||||
@ -54,7 +54,7 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer):
|
||||
def sync(self, page_data: list) -> int:
|
||||
"""Iterate over all LDAP Groups and create authentik_core.Group instances"""
|
||||
if not self._source.sync_groups:
|
||||
self.message("Group syncing is disabled for this Source")
|
||||
self._task.info("Group syncing is disabled for this Source")
|
||||
return -1
|
||||
group_count = 0
|
||||
for group in page_data:
|
||||
@ -62,7 +62,7 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer):
|
||||
continue
|
||||
group_dn = flatten(flatten(group.get("entryDN", group.get("dn"))))
|
||||
if not (uniq := self.get_identifier(attributes)):
|
||||
self.message(
|
||||
self._task.info(
|
||||
f"Uniqueness field not found/not set in attributes: '{group_dn}'",
|
||||
attributes=attributes.keys(),
|
||||
dn=group_dn,
|
||||
|
@ -26,7 +26,7 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer):
|
||||
|
||||
def get_objects(self, **kwargs) -> Generator:
|
||||
if not self._source.sync_groups:
|
||||
self.message("Group syncing is disabled for this Source")
|
||||
self._task.info("Group syncing is disabled for this Source")
|
||||
return iter(())
|
||||
|
||||
# If we are looking up groups from users, we don't need to fetch the group membership field
|
||||
@ -45,7 +45,7 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer):
|
||||
def sync(self, page_data: list) -> int:
|
||||
"""Iterate over all Users and assign Groups using memberOf Field"""
|
||||
if not self._source.sync_groups:
|
||||
self.message("Group syncing is disabled for this Source")
|
||||
self._task.info("Group syncing is disabled for this Source")
|
||||
return -1
|
||||
membership_count = 0
|
||||
for group in page_data:
|
||||
@ -94,7 +94,7 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer):
|
||||
# group_uniq might be a single string or an array with (hopefully) a single string
|
||||
if isinstance(group_uniq, list):
|
||||
if len(group_uniq) < 1:
|
||||
self.message(
|
||||
self._task.info(
|
||||
f"Group does not have a uniqueness attribute: '{group_dn}'",
|
||||
group=group_dn,
|
||||
)
|
||||
@ -104,7 +104,7 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer):
|
||||
groups = Group.objects.filter(**{f"attributes__{LDAP_UNIQUENESS}": group_uniq})
|
||||
if not groups.exists():
|
||||
if self._source.sync_groups:
|
||||
self.message(
|
||||
self._task.info(
|
||||
f"Group does not exist in our DB yet, run sync_groups first: '{group_dn}'",
|
||||
group=group_dn,
|
||||
)
|
||||
|
@ -39,7 +39,7 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer):
|
||||
|
||||
def get_objects(self, **kwargs) -> Generator:
|
||||
if not self._source.sync_users:
|
||||
self.message("User syncing is disabled for this Source")
|
||||
self._task.info("User syncing is disabled for this Source")
|
||||
return iter(())
|
||||
return self.search_paginator(
|
||||
search_base=self.base_dn_users,
|
||||
@ -56,7 +56,7 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer):
|
||||
def sync(self, page_data: list) -> int:
|
||||
"""Iterate over all LDAP Users and create authentik_core.User instances"""
|
||||
if not self._source.sync_users:
|
||||
self.message("User syncing is disabled for this Source")
|
||||
self._task.info("User syncing is disabled for this Source")
|
||||
return -1
|
||||
user_count = 0
|
||||
for user in page_data:
|
||||
@ -64,7 +64,7 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer):
|
||||
continue
|
||||
user_dn = flatten(user.get("entryDN", user.get("dn")))
|
||||
if not (uniq := self.get_identifier(attributes)):
|
||||
self.message(
|
||||
self._task.info(
|
||||
f"Uniqueness field not found/not set in attributes: '{user_dn}'",
|
||||
attributes=attributes.keys(),
|
||||
dn=user_dn,
|
||||
|
@ -30,7 +30,7 @@ class FreeIPA(BaseLDAPSynchronizer):
|
||||
pwd_last_set: datetime = attributes.get("krbLastPwdChange", datetime.now())
|
||||
pwd_last_set = pwd_last_set.replace(tzinfo=UTC)
|
||||
if created or pwd_last_set >= user.password_change_date:
|
||||
self.message(f"'{user.username}': Reset user's password")
|
||||
self._task.info(f"'{user.username}': Reset user's password")
|
||||
self._logger.debug(
|
||||
"Reset user's password",
|
||||
user=user.username,
|
||||
|
2
authentik/sources/ldap/sync/vendor/ms_ad.py
vendored
2
authentik/sources/ldap/sync/vendor/ms_ad.py
vendored
@ -60,7 +60,7 @@ class MicrosoftActiveDirectory(BaseLDAPSynchronizer):
|
||||
pwd_last_set: datetime = attributes.get("pwdLastSet", datetime.now())
|
||||
pwd_last_set = pwd_last_set.replace(tzinfo=UTC)
|
||||
if created or pwd_last_set >= user.password_change_date:
|
||||
self.message(f"'{user.username}': Reset user's password")
|
||||
self._task.info(f"'{user.username}': Reset user's password")
|
||||
self._logger.debug(
|
||||
"Reset user's password",
|
||||
user=user.username,
|
||||
|
@ -72,14 +72,12 @@ def ldap_sync(source_pk: str):
|
||||
)
|
||||
|
||||
# User and group sync can happen at once, they have no dependencies on each other
|
||||
user_group_tasks.run().get_results(
|
||||
block=True,
|
||||
timeout=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours") * 1000,
|
||||
user_group_tasks.run().wait(
|
||||
timeout=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours") * 1000
|
||||
)
|
||||
# Membership sync needs to run afterwards
|
||||
membership_tasks.run().get_results(
|
||||
block=True,
|
||||
timeout=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours") * 1000,
|
||||
membership_tasks.run().wait(
|
||||
timeout=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours") * 1000
|
||||
)
|
||||
# Finally, deletions. What we'd really like to do here is something like
|
||||
# ```
|
||||
@ -96,7 +94,9 @@ def ldap_sync(source_pk: str):
|
||||
# large chunks, and only queue the deletion step afterwards.
|
||||
# 3. Delete every unmarked item. This is slow, so we spread it over many tasks in
|
||||
# small chunks.
|
||||
deletion_tasks.run() # no need to block here, we don't have anything else to do afterwards
|
||||
deletion_tasks.run().wait(
|
||||
timeout=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours") * 1000,
|
||||
)
|
||||
|
||||
|
||||
def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) -> list[Message]:
|
||||
@ -139,9 +139,7 @@ def ldap_sync_page(source_pk: str, sync_class: str, page_cache_key: str):
|
||||
return
|
||||
cache.touch(page_cache_key)
|
||||
count = sync_inst.sync(page)
|
||||
messages = sync_inst.messages
|
||||
messages.append(f"Synced {count} objects.")
|
||||
self.info(*messages)
|
||||
self.info(f"Synced {count} objects.")
|
||||
cache.delete(page_cache_key)
|
||||
except (LDAPException, StopSync) as exc:
|
||||
# No explicit event is created here as .set_status with an error will do that
|
||||
|
@ -13,6 +13,7 @@ from authentik.sources.ldap.models import LDAPSource, LDAPSourcePropertyMapping
|
||||
from authentik.sources.ldap.sync.users import UserLDAPSynchronizer
|
||||
from authentik.sources.ldap.tests.mock_ad import mock_ad_connection
|
||||
from authentik.sources.ldap.tests.mock_slapd import mock_slapd_connection
|
||||
from authentik.tasks.models import Task
|
||||
|
||||
LDAP_PASSWORD = generate_key()
|
||||
|
||||
@ -43,7 +44,7 @@ class LDAPSyncTests(TestCase):
|
||||
raw_conn.bind = bind_mock
|
||||
connection = MagicMock(return_value=raw_conn)
|
||||
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
|
||||
user_sync = UserLDAPSynchronizer(self.source)
|
||||
user_sync = UserLDAPSynchronizer(self.source, Task())
|
||||
user_sync.sync_full()
|
||||
|
||||
user = User.objects.get(username="user0_sn")
|
||||
@ -71,7 +72,7 @@ class LDAPSyncTests(TestCase):
|
||||
)
|
||||
connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD))
|
||||
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
|
||||
user_sync = UserLDAPSynchronizer(self.source)
|
||||
user_sync = UserLDAPSynchronizer(self.source, Task())
|
||||
user_sync.sync_full()
|
||||
|
||||
user = User.objects.get(username="user0_sn")
|
||||
@ -98,7 +99,7 @@ class LDAPSyncTests(TestCase):
|
||||
self.source.save()
|
||||
connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD))
|
||||
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
|
||||
user_sync = UserLDAPSynchronizer(self.source)
|
||||
user_sync = UserLDAPSynchronizer(self.source, Task())
|
||||
user_sync.sync_full()
|
||||
|
||||
user = User.objects.get(username="user0_sn")
|
||||
|
@ -23,7 +23,7 @@ from authentik.sources.ldap.sync.forward_delete_users import DELETE_CHUNK_SIZE
|
||||
from authentik.sources.ldap.sync.groups import GroupLDAPSynchronizer
|
||||
from authentik.sources.ldap.sync.membership import MembershipLDAPSynchronizer
|
||||
from authentik.sources.ldap.sync.users import UserLDAPSynchronizer
|
||||
from authentik.sources.ldap.tasks import ldap_sync, ldap_sync_all, ldap_sync_page
|
||||
from authentik.sources.ldap.tasks import ldap_sync, ldap_sync_page
|
||||
from authentik.sources.ldap.tests.mock_ad import mock_ad_connection
|
||||
from authentik.sources.ldap.tests.mock_freeipa import mock_freeipa_connection
|
||||
from authentik.sources.ldap.tests.mock_slapd import (
|
||||
@ -33,6 +33,7 @@ from authentik.sources.ldap.tests.mock_slapd import (
|
||||
user_in_slapd_cn,
|
||||
user_in_slapd_uid,
|
||||
)
|
||||
from authentik.tasks.models import Task
|
||||
|
||||
LDAP_PASSWORD = generate_key()
|
||||
|
||||
@ -74,7 +75,7 @@ class LDAPSyncTests(TestCase):
|
||||
self.source.save()
|
||||
connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD))
|
||||
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
|
||||
user_sync = UserLDAPSynchronizer(self.source)
|
||||
user_sync = UserLDAPSynchronizer(self.source, Task())
|
||||
with self.assertRaises(StopSync):
|
||||
user_sync.sync_full()
|
||||
self.assertFalse(User.objects.filter(username="user0_sn").exists())
|
||||
@ -105,7 +106,7 @@ class LDAPSyncTests(TestCase):
|
||||
|
||||
# we basically just test that the mappings don't throw errors
|
||||
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
|
||||
user_sync = UserLDAPSynchronizer(self.source)
|
||||
user_sync = UserLDAPSynchronizer(self.source, Task())
|
||||
user_sync.sync_full()
|
||||
|
||||
def test_sync_users_ad(self):
|
||||
@ -133,7 +134,7 @@ class LDAPSyncTests(TestCase):
|
||||
)
|
||||
|
||||
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
|
||||
user_sync = UserLDAPSynchronizer(self.source)
|
||||
user_sync = UserLDAPSynchronizer(self.source, Task())
|
||||
user_sync.sync_full()
|
||||
user = User.objects.filter(username="user0_sn").first()
|
||||
self.assertEqual(user.attributes["foo"], "bar")
|
||||
@ -152,7 +153,7 @@ class LDAPSyncTests(TestCase):
|
||||
)
|
||||
connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD))
|
||||
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
|
||||
user_sync = UserLDAPSynchronizer(self.source)
|
||||
user_sync = UserLDAPSynchronizer(self.source, Task())
|
||||
user_sync.sync_full()
|
||||
self.assertTrue(User.objects.filter(username="user0_sn").exists())
|
||||
self.assertFalse(User.objects.filter(username="user1_sn").exists())
|
||||
@ -168,7 +169,7 @@ class LDAPSyncTests(TestCase):
|
||||
)
|
||||
connection = MagicMock(return_value=mock_freeipa_connection(LDAP_PASSWORD))
|
||||
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
|
||||
user_sync = UserLDAPSynchronizer(self.source)
|
||||
user_sync = UserLDAPSynchronizer(self.source, Task())
|
||||
user_sync.sync_full()
|
||||
self.assertTrue(User.objects.filter(username="user0_sn").exists())
|
||||
self.assertFalse(User.objects.filter(username="user1_sn").exists())
|
||||
@ -193,11 +194,11 @@ class LDAPSyncTests(TestCase):
|
||||
)
|
||||
connection = MagicMock(return_value=mock_freeipa_connection(LDAP_PASSWORD))
|
||||
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
|
||||
user_sync = UserLDAPSynchronizer(self.source)
|
||||
user_sync = UserLDAPSynchronizer(self.source, Task())
|
||||
user_sync.sync_full()
|
||||
group_sync = GroupLDAPSynchronizer(self.source)
|
||||
group_sync = GroupLDAPSynchronizer(self.source, Task())
|
||||
group_sync.sync_full()
|
||||
membership_sync = MembershipLDAPSynchronizer(self.source)
|
||||
membership_sync = MembershipLDAPSynchronizer(self.source, Task())
|
||||
membership_sync.sync_full()
|
||||
|
||||
self.assertTrue(
|
||||
@ -230,9 +231,9 @@ class LDAPSyncTests(TestCase):
|
||||
parent_group = Group.objects.get(name=_user.username)
|
||||
self.source.sync_parent_group = parent_group
|
||||
self.source.save()
|
||||
group_sync = GroupLDAPSynchronizer(self.source)
|
||||
group_sync = GroupLDAPSynchronizer(self.source, Task())
|
||||
group_sync.sync_full()
|
||||
membership_sync = MembershipLDAPSynchronizer(self.source)
|
||||
membership_sync = MembershipLDAPSynchronizer(self.source, Task())
|
||||
membership_sync.sync_full()
|
||||
group: Group = Group.objects.filter(name="test-group").first()
|
||||
self.assertIsNotNone(group)
|
||||
@ -256,9 +257,9 @@ class LDAPSyncTests(TestCase):
|
||||
connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD))
|
||||
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
|
||||
self.source.save()
|
||||
group_sync = GroupLDAPSynchronizer(self.source)
|
||||
group_sync = GroupLDAPSynchronizer(self.source, Task())
|
||||
group_sync.sync_full()
|
||||
membership_sync = MembershipLDAPSynchronizer(self.source)
|
||||
membership_sync = MembershipLDAPSynchronizer(self.source, Task())
|
||||
membership_sync.sync_full()
|
||||
group = Group.objects.filter(name="group1")
|
||||
self.assertTrue(group.exists())
|
||||
@ -290,11 +291,11 @@ class LDAPSyncTests(TestCase):
|
||||
connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD))
|
||||
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
|
||||
self.source.save()
|
||||
user_sync = UserLDAPSynchronizer(self.source)
|
||||
user_sync = UserLDAPSynchronizer(self.source, Task())
|
||||
user_sync.sync_full()
|
||||
group_sync = GroupLDAPSynchronizer(self.source)
|
||||
group_sync = GroupLDAPSynchronizer(self.source, Task())
|
||||
group_sync.sync_full()
|
||||
membership_sync = MembershipLDAPSynchronizer(self.source)
|
||||
membership_sync = MembershipLDAPSynchronizer(self.source, Task())
|
||||
membership_sync.sync_full()
|
||||
# Test if membership mapping based on memberUid works.
|
||||
posix_group = Group.objects.filter(name="group-posix").first()
|
||||
@ -327,11 +328,11 @@ class LDAPSyncTests(TestCase):
|
||||
connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD))
|
||||
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
|
||||
self.source.save()
|
||||
user_sync = UserLDAPSynchronizer(self.source)
|
||||
user_sync = UserLDAPSynchronizer(self.source, Task())
|
||||
user_sync.sync_full()
|
||||
group_sync = GroupLDAPSynchronizer(self.source)
|
||||
group_sync = GroupLDAPSynchronizer(self.source, Task())
|
||||
group_sync.sync_full()
|
||||
membership_sync = MembershipLDAPSynchronizer(self.source)
|
||||
membership_sync = MembershipLDAPSynchronizer(self.source, Task())
|
||||
membership_sync.sync_full()
|
||||
# Test if membership mapping based on memberUid works.
|
||||
posix_group = Group.objects.filter(name="group-posix").first()
|
||||
|
@ -112,23 +112,20 @@ class Task(SerializerModel):
|
||||
if save:
|
||||
self.save()
|
||||
|
||||
def log(self, status: TaskStatus, *messages: str | LogEvent | Exception, save: bool = False):
|
||||
def log(self, status: TaskStatus, message: str | Exception, save: bool = False, **attributes):
|
||||
self.messages: list
|
||||
for msg in messages:
|
||||
message = msg
|
||||
if isinstance(message, Exception):
|
||||
message = exception_to_string(message)
|
||||
if not isinstance(message, LogEvent):
|
||||
message = LogEvent(message, logger=self.uid, log_level=status.value)
|
||||
self.messages.append(sanitize_item(message))
|
||||
if isinstance(message, Exception):
|
||||
message = exception_to_string(message)
|
||||
message = LogEvent(message, logger=self.uid, log_level=status.value, attributes=attributes)
|
||||
self.messages.append(sanitize_item(message))
|
||||
if save:
|
||||
self.save()
|
||||
|
||||
def info(self, *messages: str | LogEvent | Exception, save: bool = False):
|
||||
self.log(TaskStatus.INFO, *messages, save=save)
|
||||
def info(self, message: str | Exception, save: bool = False, **attributes):
|
||||
self.log(TaskStatus.INFO, message, save=save, **attributes)
|
||||
|
||||
def warning(self, *messages: str | LogEvent | Exception, save: bool = False):
|
||||
self.log(TaskStatus.WARNING, *messages, save=save)
|
||||
def warning(self, message: str | Exception, save: bool = False, **attributes):
|
||||
self.log(TaskStatus.WARNING, message, save=save, **attributes)
|
||||
|
||||
def error(self, *messages: str | LogEvent | Exception, save: bool = False):
|
||||
self.log(TaskStatus.ERROR, *messages, save=save)
|
||||
def error(self, message: str | Exception, save: bool = False, **attributes):
|
||||
self.log(TaskStatus.ERROR, message, save=save, **attributes)
|
||||
|
Reference in New Issue
Block a user