sources/ldap: improve scalability (#6056)
* sources/ldap: improve scalability Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix tests Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix lint Signed-off-by: Jens Langhammer <jens@goauthentik.io> * use cache instead of call signature for page data Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
		| @ -70,8 +70,10 @@ class TaskInfo: | |||||||
|         return cache.get_many(cache.keys(CACHE_KEY_PREFIX + "*")) |         return cache.get_many(cache.keys(CACHE_KEY_PREFIX + "*")) | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def by_name(name: str) -> Optional["TaskInfo"]: |     def by_name(name: str) -> Optional["TaskInfo"] | Optional[list["TaskInfo"]]: | ||||||
|         """Get TaskInfo Object by name""" |         """Get TaskInfo Object by name""" | ||||||
|  |         if "*" in name: | ||||||
|  |             return cache.get_many(cache.keys(CACHE_KEY_PREFIX + name)).values() | ||||||
|         return cache.get(CACHE_KEY_PREFIX + name, None) |         return cache.get(CACHE_KEY_PREFIX + name, None) | ||||||
|  |  | ||||||
|     def delete(self): |     def delete(self): | ||||||
|  | |||||||
| @ -118,10 +118,9 @@ class LDAPSourceViewSet(UsedByMixin, ModelViewSet): | |||||||
|         """Get source's sync status""" |         """Get source's sync status""" | ||||||
|         source = self.get_object() |         source = self.get_object() | ||||||
|         results = [] |         results = [] | ||||||
|         for sync_class in SYNC_CLASSES: |         tasks = TaskInfo.by_name(f"ldap_sync:{source.slug}:*") | ||||||
|             sync_name = sync_class.__name__.replace("LDAPSynchronizer", "").lower() |         if tasks: | ||||||
|             task = TaskInfo.by_name(f"ldap_sync:{source.slug}:{sync_name}") |             for task in tasks: | ||||||
|             if task: |  | ||||||
|                 results.append(task) |                 results.append(task) | ||||||
|         return Response(TaskSerializer(results, many=True).data) |         return Response(TaskSerializer(results, many=True).data) | ||||||
|  |  | ||||||
| @ -143,7 +142,7 @@ class LDAPSourceViewSet(UsedByMixin, ModelViewSet): | |||||||
|         source = self.get_object() |         source = self.get_object() | ||||||
|         all_objects = {} |         all_objects = {} | ||||||
|         for sync_class in SYNC_CLASSES: |         for sync_class in SYNC_CLASSES: | ||||||
|             class_name = sync_class.__name__.replace("LDAPSynchronizer", "").lower() |             class_name = sync_class.name() | ||||||
|             all_objects.setdefault(class_name, []) |             all_objects.setdefault(class_name, []) | ||||||
|             for obj in sync_class(source).get_objects(size_limit=10): |             for obj in sync_class(source).get_objects(size_limit=10): | ||||||
|                 obj: dict |                 obj: dict | ||||||
|  | |||||||
| @ -2,9 +2,8 @@ | |||||||
| from django.core.management.base import BaseCommand | from django.core.management.base import BaseCommand | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.lib.utils.reflection import class_to_path |  | ||||||
| from authentik.sources.ldap.models import LDAPSource | from authentik.sources.ldap.models import LDAPSource | ||||||
| from authentik.sources.ldap.tasks import SYNC_CLASSES, ldap_sync | from authentik.sources.ldap.tasks import ldap_sync_single | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
| @ -21,7 +20,4 @@ class Command(BaseCommand): | |||||||
|             if not source: |             if not source: | ||||||
|                 LOGGER.warning("Source does not exist", slug=source_slug) |                 LOGGER.warning("Source does not exist", slug=source_slug) | ||||||
|                 continue |                 continue | ||||||
|             for sync_class in SYNC_CLASSES: |             ldap_sync_single(source) | ||||||
|                 LOGGER.info("Starting sync", cls=sync_class) |  | ||||||
|                 # pylint: disable=no-value-for-parameter |  | ||||||
|                 ldap_sync(source.pk, class_to_path(sync_class)) |  | ||||||
|  | |||||||
| @ -12,13 +12,9 @@ from authentik.core.models import User | |||||||
| from authentik.core.signals import password_changed | from authentik.core.signals import password_changed | ||||||
| from authentik.events.models import Event, EventAction | from authentik.events.models import Event, EventAction | ||||||
| from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER | from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER | ||||||
| from authentik.lib.utils.reflection import class_to_path |  | ||||||
| from authentik.sources.ldap.models import LDAPSource | from authentik.sources.ldap.models import LDAPSource | ||||||
| from authentik.sources.ldap.password import LDAPPasswordChanger | from authentik.sources.ldap.password import LDAPPasswordChanger | ||||||
| from authentik.sources.ldap.sync.groups import GroupLDAPSynchronizer | from authentik.sources.ldap.tasks import ldap_sync_single | ||||||
| from authentik.sources.ldap.sync.membership import MembershipLDAPSynchronizer |  | ||||||
| from authentik.sources.ldap.sync.users import UserLDAPSynchronizer |  | ||||||
| from authentik.sources.ldap.tasks import ldap_sync |  | ||||||
| from authentik.stages.prompt.signals import password_validate | from authentik.stages.prompt.signals import password_validate | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
| @ -35,12 +31,7 @@ def sync_ldap_source_on_save(sender, instance: LDAPSource, **_): | |||||||
|     #   and the mappings are created with an m2m event |     #   and the mappings are created with an m2m event | ||||||
|     if not instance.property_mappings.exists() or not instance.property_mappings_group.exists(): |     if not instance.property_mappings.exists() or not instance.property_mappings_group.exists(): | ||||||
|         return |         return | ||||||
|     for sync_class in [ |     ldap_sync_single.delay(instance.pk) | ||||||
|         UserLDAPSynchronizer, |  | ||||||
|         GroupLDAPSynchronizer, |  | ||||||
|         MembershipLDAPSynchronizer, |  | ||||||
|     ]: |  | ||||||
|         ldap_sync.delay(instance.pk, class_to_path(sync_class)) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(password_validate) | @receiver(password_validate) | ||||||
|  | |||||||
| @ -1,9 +1,10 @@ | |||||||
| """Sync LDAP Users and groups into authentik""" | """Sync LDAP Users and groups into authentik""" | ||||||
| from typing import Any, Generator | from typing import Any, Generator | ||||||
|  |  | ||||||
|  | from django.conf import settings | ||||||
| from django.db.models.base import Model | from django.db.models.base import Model | ||||||
| from django.db.models.query import QuerySet | from django.db.models.query import QuerySet | ||||||
| from ldap3 import Connection | from ldap3 import DEREF_ALWAYS, SUBTREE, Connection | ||||||
| from structlog.stdlib import BoundLogger, get_logger | from structlog.stdlib import BoundLogger, get_logger | ||||||
|  |  | ||||||
| from authentik.core.exceptions import PropertyMappingExpressionException | from authentik.core.exceptions import PropertyMappingExpressionException | ||||||
| @ -29,6 +30,24 @@ class BaseLDAPSynchronizer: | |||||||
|         self._messages = [] |         self._messages = [] | ||||||
|         self._logger = get_logger().bind(source=source, syncer=self.__class__.__name__) |         self._logger = get_logger().bind(source=source, syncer=self.__class__.__name__) | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def name() -> str: | ||||||
|  |         """UI name for the type of object this class synchronizes""" | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def sync_full(self): | ||||||
|  |         """Run full sync, this function should only be used in tests""" | ||||||
|  |         if not settings.TEST:  # noqa | ||||||
|  |             raise RuntimeError( | ||||||
|  |                 f"{self.__class__.__name__}.sync_full() should only be used in tests" | ||||||
|  |             ) | ||||||
|  |         for page in self.get_objects(): | ||||||
|  |             self.sync(page) | ||||||
|  |  | ||||||
|  |     def sync(self, page_data: list) -> int: | ||||||
|  |         """Sync function, implemented in subclass""" | ||||||
|  |         raise NotImplementedError() | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def messages(self) -> list[str]: |     def messages(self) -> list[str]: | ||||||
|         """Get all UI messages""" |         """Get all UI messages""" | ||||||
| @ -60,9 +79,47 @@ class BaseLDAPSynchronizer: | |||||||
|         """Get objects from LDAP, implemented in subclass""" |         """Get objects from LDAP, implemented in subclass""" | ||||||
|         raise NotImplementedError() |         raise NotImplementedError() | ||||||
|  |  | ||||||
|     def sync(self) -> int: |     # pylint: disable=too-many-arguments | ||||||
|         """Sync function, implemented in subclass""" |     def search_paginator( | ||||||
|         raise NotImplementedError() |         self, | ||||||
|  |         search_base, | ||||||
|  |         search_filter, | ||||||
|  |         search_scope=SUBTREE, | ||||||
|  |         dereference_aliases=DEREF_ALWAYS, | ||||||
|  |         attributes=None, | ||||||
|  |         size_limit=0, | ||||||
|  |         time_limit=0, | ||||||
|  |         types_only=False, | ||||||
|  |         get_operational_attributes=False, | ||||||
|  |         controls=None, | ||||||
|  |         paged_size=5, | ||||||
|  |         paged_criticality=False, | ||||||
|  |     ): | ||||||
|  |         """Search in pages, returns each page""" | ||||||
|  |         cookie = True | ||||||
|  |         while cookie: | ||||||
|  |             self._connection.search( | ||||||
|  |                 search_base, | ||||||
|  |                 search_filter, | ||||||
|  |                 search_scope, | ||||||
|  |                 dereference_aliases, | ||||||
|  |                 attributes, | ||||||
|  |                 size_limit, | ||||||
|  |                 time_limit, | ||||||
|  |                 types_only, | ||||||
|  |                 get_operational_attributes, | ||||||
|  |                 controls, | ||||||
|  |                 paged_size, | ||||||
|  |                 paged_criticality, | ||||||
|  |                 None if cookie is True else cookie, | ||||||
|  |             ) | ||||||
|  |             try: | ||||||
|  |                 cookie = self._connection.result["controls"]["1.2.840.113556.1.4.319"]["value"][ | ||||||
|  |                     "cookie" | ||||||
|  |                 ] | ||||||
|  |             except KeyError: | ||||||
|  |                 cookie = None | ||||||
|  |             yield self._connection.response | ||||||
|  |  | ||||||
|     def _flatten(self, value: Any) -> Any: |     def _flatten(self, value: Any) -> Any: | ||||||
|         """Flatten `value` if its a list""" |         """Flatten `value` if its a list""" | ||||||
|  | |||||||
| @ -13,8 +13,12 @@ from authentik.sources.ldap.sync.base import LDAP_UNIQUENESS, BaseLDAPSynchroniz | |||||||
| class GroupLDAPSynchronizer(BaseLDAPSynchronizer): | class GroupLDAPSynchronizer(BaseLDAPSynchronizer): | ||||||
|     """Sync LDAP Users and groups into authentik""" |     """Sync LDAP Users and groups into authentik""" | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def name() -> str: | ||||||
|  |         return "groups" | ||||||
|  |  | ||||||
|     def get_objects(self, **kwargs) -> Generator: |     def get_objects(self, **kwargs) -> Generator: | ||||||
|         return self._connection.extend.standard.paged_search( |         return self.search_paginator( | ||||||
|             search_base=self.base_dn_groups, |             search_base=self.base_dn_groups, | ||||||
|             search_filter=self._source.group_object_filter, |             search_filter=self._source.group_object_filter, | ||||||
|             search_scope=SUBTREE, |             search_scope=SUBTREE, | ||||||
| @ -22,13 +26,13 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer): | |||||||
|             **kwargs, |             **kwargs, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def sync(self) -> int: |     def sync(self, page_data: list) -> int: | ||||||
|         """Iterate over all LDAP Groups and create authentik_core.Group instances""" |         """Iterate over all LDAP Groups and create authentik_core.Group instances""" | ||||||
|         if not self._source.sync_groups: |         if not self._source.sync_groups: | ||||||
|             self.message("Group syncing is disabled for this Source") |             self.message("Group syncing is disabled for this Source") | ||||||
|             return -1 |             return -1 | ||||||
|         group_count = 0 |         group_count = 0 | ||||||
|         for group in self.get_objects(): |         for group in page_data: | ||||||
|             if "attributes" not in group: |             if "attributes" not in group: | ||||||
|                 continue |                 continue | ||||||
|             attributes = group.get("attributes", {}) |             attributes = group.get("attributes", {}) | ||||||
|  | |||||||
| @ -19,8 +19,12 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer): | |||||||
|         super().__init__(source) |         super().__init__(source) | ||||||
|         self.group_cache: dict[str, Group] = {} |         self.group_cache: dict[str, Group] = {} | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def name() -> str: | ||||||
|  |         return "membership" | ||||||
|  |  | ||||||
|     def get_objects(self, **kwargs) -> Generator: |     def get_objects(self, **kwargs) -> Generator: | ||||||
|         return self._connection.extend.standard.paged_search( |         return self.search_paginator( | ||||||
|             search_base=self.base_dn_groups, |             search_base=self.base_dn_groups, | ||||||
|             search_filter=self._source.group_object_filter, |             search_filter=self._source.group_object_filter, | ||||||
|             search_scope=SUBTREE, |             search_scope=SUBTREE, | ||||||
| @ -32,13 +36,13 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer): | |||||||
|             **kwargs, |             **kwargs, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def sync(self) -> int: |     def sync(self, page_data: list) -> int: | ||||||
|         """Iterate over all Users and assign Groups using memberOf Field""" |         """Iterate over all Users and assign Groups using memberOf Field""" | ||||||
|         if not self._source.sync_groups: |         if not self._source.sync_groups: | ||||||
|             self.message("Group syncing is disabled for this Source") |             self.message("Group syncing is disabled for this Source") | ||||||
|             return -1 |             return -1 | ||||||
|         membership_count = 0 |         membership_count = 0 | ||||||
|         for group in self.get_objects(): |         for group in page_data: | ||||||
|             if "attributes" not in group: |             if "attributes" not in group: | ||||||
|                 continue |                 continue | ||||||
|             members = group.get("attributes", {}).get(self._source.group_membership_field, []) |             members = group.get("attributes", {}).get(self._source.group_membership_field, []) | ||||||
|  | |||||||
| @ -15,8 +15,12 @@ from authentik.sources.ldap.sync.vendor.ms_ad import MicrosoftActiveDirectory | |||||||
| class UserLDAPSynchronizer(BaseLDAPSynchronizer): | class UserLDAPSynchronizer(BaseLDAPSynchronizer): | ||||||
|     """Sync LDAP Users into authentik""" |     """Sync LDAP Users into authentik""" | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def name() -> str: | ||||||
|  |         return "users" | ||||||
|  |  | ||||||
|     def get_objects(self, **kwargs) -> Generator: |     def get_objects(self, **kwargs) -> Generator: | ||||||
|         return self._connection.extend.standard.paged_search( |         return self.search_paginator( | ||||||
|             search_base=self.base_dn_users, |             search_base=self.base_dn_users, | ||||||
|             search_filter=self._source.user_object_filter, |             search_filter=self._source.user_object_filter, | ||||||
|             search_scope=SUBTREE, |             search_scope=SUBTREE, | ||||||
| @ -24,13 +28,13 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer): | |||||||
|             **kwargs, |             **kwargs, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def sync(self) -> int: |     def sync(self, page_data: list) -> int: | ||||||
|         """Iterate over all LDAP Users and create authentik_core.User instances""" |         """Iterate over all LDAP Users and create authentik_core.User instances""" | ||||||
|         if not self._source.sync_users: |         if not self._source.sync_users: | ||||||
|             self.message("User syncing is disabled for this Source") |             self.message("User syncing is disabled for this Source") | ||||||
|             return -1 |             return -1 | ||||||
|         user_count = 0 |         user_count = 0 | ||||||
|         for user in self.get_objects(): |         for user in page_data: | ||||||
|             if "attributes" not in user: |             if "attributes" not in user: | ||||||
|                 continue |                 continue | ||||||
|             attributes = user.get("attributes", {}) |             attributes = user.get("attributes", {}) | ||||||
|  | |||||||
| @ -11,6 +11,10 @@ from authentik.sources.ldap.sync.base import BaseLDAPSynchronizer | |||||||
| class FreeIPA(BaseLDAPSynchronizer): | class FreeIPA(BaseLDAPSynchronizer): | ||||||
|     """FreeIPA-specific LDAP""" |     """FreeIPA-specific LDAP""" | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def name() -> str: | ||||||
|  |         return "freeipa" | ||||||
|  |  | ||||||
|     def get_objects(self, **kwargs) -> Generator: |     def get_objects(self, **kwargs) -> Generator: | ||||||
|         yield None |         yield None | ||||||
|  |  | ||||||
|  | |||||||
							
								
								
									
										4
									
								
								authentik/sources/ldap/sync/vendor/ms_ad.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								authentik/sources/ldap/sync/vendor/ms_ad.py
									
									
									
									
										vendored
									
									
								
							| @ -42,6 +42,10 @@ class UserAccountControl(IntFlag): | |||||||
| class MicrosoftActiveDirectory(BaseLDAPSynchronizer): | class MicrosoftActiveDirectory(BaseLDAPSynchronizer): | ||||||
|     """Microsoft-specific LDAP""" |     """Microsoft-specific LDAP""" | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def name() -> str: | ||||||
|  |         return "microsoft_ad" | ||||||
|  |  | ||||||
|     def get_objects(self, **kwargs) -> Generator: |     def get_objects(self, **kwargs) -> Generator: | ||||||
|         yield None |         yield None | ||||||
|  |  | ||||||
|  | |||||||
| @ -1,4 +1,8 @@ | |||||||
| """LDAP Sync tasks""" | """LDAP Sync tasks""" | ||||||
|  | from uuid import uuid4 | ||||||
|  |  | ||||||
|  | from celery import chain, group | ||||||
|  | from django.core.cache import cache | ||||||
| from ldap3.core.exceptions import LDAPException | from ldap3.core.exceptions import LDAPException | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| @ -8,6 +12,7 @@ from authentik.lib.utils.errors import exception_to_string | |||||||
| from authentik.lib.utils.reflection import class_to_path, path_to_class | from authentik.lib.utils.reflection import class_to_path, path_to_class | ||||||
| from authentik.root.celery import CELERY_APP | from authentik.root.celery import CELERY_APP | ||||||
| from authentik.sources.ldap.models import LDAPSource | from authentik.sources.ldap.models import LDAPSource | ||||||
|  | from authentik.sources.ldap.sync.base import BaseLDAPSynchronizer | ||||||
| from authentik.sources.ldap.sync.groups import GroupLDAPSynchronizer | from authentik.sources.ldap.sync.groups import GroupLDAPSynchronizer | ||||||
| from authentik.sources.ldap.sync.membership import MembershipLDAPSynchronizer | from authentik.sources.ldap.sync.membership import MembershipLDAPSynchronizer | ||||||
| from authentik.sources.ldap.sync.users import UserLDAPSynchronizer | from authentik.sources.ldap.sync.users import UserLDAPSynchronizer | ||||||
| @ -18,14 +23,43 @@ SYNC_CLASSES = [ | |||||||
|     GroupLDAPSynchronizer, |     GroupLDAPSynchronizer, | ||||||
|     MembershipLDAPSynchronizer, |     MembershipLDAPSynchronizer, | ||||||
| ] | ] | ||||||
|  | CACHE_KEY_PREFIX = "goauthentik.io/sources/ldap/page/" | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task() | @CELERY_APP.task() | ||||||
| def ldap_sync_all(): | def ldap_sync_all(): | ||||||
|     """Sync all sources""" |     """Sync all sources""" | ||||||
|     for source in LDAPSource.objects.filter(enabled=True): |     for source in LDAPSource.objects.filter(enabled=True): | ||||||
|         for sync_class in SYNC_CLASSES: |         ldap_sync_single(source) | ||||||
|             ldap_sync.delay(source.pk, class_to_path(sync_class)) |  | ||||||
|  |  | ||||||
|  | @CELERY_APP.task() | ||||||
|  | def ldap_sync_single(source: LDAPSource): | ||||||
|  |     """Sync a single source""" | ||||||
|  |     task = chain( | ||||||
|  |         # User and group sync can happen at once, they have no dependencies on each other | ||||||
|  |         group( | ||||||
|  |             ldap_sync_paginator(source, UserLDAPSynchronizer) | ||||||
|  |             + ldap_sync_paginator(source, GroupLDAPSynchronizer), | ||||||
|  |         ), | ||||||
|  |         # Membership sync needs to run afterwards | ||||||
|  |         group( | ||||||
|  |             ldap_sync_paginator(source, MembershipLDAPSynchronizer), | ||||||
|  |         ), | ||||||
|  |     ) | ||||||
|  |     task() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) -> list: | ||||||
|  |     """Return a list of task signatures with LDAP pagination data""" | ||||||
|  |     sync_inst: BaseLDAPSynchronizer = sync(source) | ||||||
|  |     signatures = [] | ||||||
|  |     for page in sync_inst.get_objects(): | ||||||
|  |         page_cache_key = CACHE_KEY_PREFIX + str(uuid4()) | ||||||
|  |         cache.set(page_cache_key, page) | ||||||
|  |         page_sync = ldap_sync.si(source.pk, class_to_path(sync), page_cache_key) | ||||||
|  |         signatures.append(page_sync) | ||||||
|  |     return signatures | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task( | @CELERY_APP.task( | ||||||
| @ -34,7 +68,7 @@ def ldap_sync_all(): | |||||||
|     soft_time_limit=60 * 60 * int(CONFIG.y("ldap.task_timeout_hours")), |     soft_time_limit=60 * 60 * int(CONFIG.y("ldap.task_timeout_hours")), | ||||||
|     task_time_limit=60 * 60 * int(CONFIG.y("ldap.task_timeout_hours")), |     task_time_limit=60 * 60 * int(CONFIG.y("ldap.task_timeout_hours")), | ||||||
| ) | ) | ||||||
| def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str): | def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str, page_cache_key: str): | ||||||
|     """Synchronization of an LDAP Source""" |     """Synchronization of an LDAP Source""" | ||||||
|     self.result_timeout_hours = int(CONFIG.y("ldap.task_timeout_hours")) |     self.result_timeout_hours = int(CONFIG.y("ldap.task_timeout_hours")) | ||||||
|     try: |     try: | ||||||
| @ -43,11 +77,16 @@ def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str): | |||||||
|         # Because the source couldn't be found, we don't have a UID |         # Because the source couldn't be found, we don't have a UID | ||||||
|         # to set the state with |         # to set the state with | ||||||
|         return |         return | ||||||
|     sync = path_to_class(sync_class) |     sync: type[BaseLDAPSynchronizer] = path_to_class(sync_class) | ||||||
|     self.set_uid(f"{source.slug}:{sync.__name__.replace('LDAPSynchronizer', '').lower()}") |     uid = page_cache_key.replace(CACHE_KEY_PREFIX, "") | ||||||
|  |     self.set_uid(f"{source.slug}:{sync.name()}:{uid}") | ||||||
|     try: |     try: | ||||||
|         sync_inst = sync(source) |         sync_inst: BaseLDAPSynchronizer = sync(source) | ||||||
|         count = sync_inst.sync() |         page = cache.get(page_cache_key) | ||||||
|  |         if not page: | ||||||
|  |             return | ||||||
|  |         cache.touch(page_cache_key) | ||||||
|  |         count = sync_inst.sync(page) | ||||||
|         messages = sync_inst.messages |         messages = sync_inst.messages | ||||||
|         messages.append(f"Synced {count} objects.") |         messages.append(f"Synced {count} objects.") | ||||||
|         self.set_status( |         self.set_status( | ||||||
| @ -56,6 +95,7 @@ def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str): | |||||||
|                 messages, |                 messages, | ||||||
|             ) |             ) | ||||||
|         ) |         ) | ||||||
|  |         cache.delete(page_cache_key) | ||||||
|     except LDAPException as exc: |     except LDAPException as exc: | ||||||
|         # No explicit event is created here as .set_status with an error will do that |         # No explicit event is created here as .set_status with an error will do that | ||||||
|         LOGGER.warning(exception_to_string(exc)) |         LOGGER.warning(exception_to_string(exc)) | ||||||
|  | |||||||
| @ -43,7 +43,7 @@ class LDAPSyncTests(TestCase): | |||||||
|         connection = MagicMock(return_value=raw_conn) |         connection = MagicMock(return_value=raw_conn) | ||||||
|         with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): |         with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): | ||||||
|             user_sync = UserLDAPSynchronizer(self.source) |             user_sync = UserLDAPSynchronizer(self.source) | ||||||
|             user_sync.sync() |             user_sync.sync_full() | ||||||
|  |  | ||||||
|             user = User.objects.get(username="user0_sn") |             user = User.objects.get(username="user0_sn") | ||||||
|             # auth_user_by_bind = Mock(return_value=user) |             # auth_user_by_bind = Mock(return_value=user) | ||||||
| @ -71,7 +71,7 @@ class LDAPSyncTests(TestCase): | |||||||
|         connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD)) |         connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD)) | ||||||
|         with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): |         with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): | ||||||
|             user_sync = UserLDAPSynchronizer(self.source) |             user_sync = UserLDAPSynchronizer(self.source) | ||||||
|             user_sync.sync() |             user_sync.sync_full() | ||||||
|  |  | ||||||
|             user = User.objects.get(username="user0_sn") |             user = User.objects.get(username="user0_sn") | ||||||
|             auth_user_by_bind = Mock(return_value=user) |             auth_user_by_bind = Mock(return_value=user) | ||||||
| @ -98,7 +98,7 @@ class LDAPSyncTests(TestCase): | |||||||
|         connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) |         connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) | ||||||
|         with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): |         with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): | ||||||
|             user_sync = UserLDAPSynchronizer(self.source) |             user_sync = UserLDAPSynchronizer(self.source) | ||||||
|             user_sync.sync() |             user_sync.sync_full() | ||||||
|  |  | ||||||
|             user = User.objects.get(username="user0_sn") |             user = User.objects.get(username="user0_sn") | ||||||
|             auth_user_by_bind = Mock(return_value=user) |             auth_user_by_bind = Mock(return_value=user) | ||||||
|  | |||||||
| @ -51,7 +51,7 @@ class LDAPSyncTests(TestCase): | |||||||
|         connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD)) |         connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD)) | ||||||
|         with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): |         with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): | ||||||
|             user_sync = UserLDAPSynchronizer(self.source) |             user_sync = UserLDAPSynchronizer(self.source) | ||||||
|             user_sync.sync() |             user_sync.sync_full() | ||||||
|             self.assertFalse(User.objects.filter(username="user0_sn").exists()) |             self.assertFalse(User.objects.filter(username="user0_sn").exists()) | ||||||
|             self.assertFalse(User.objects.filter(username="user1_sn").exists()) |             self.assertFalse(User.objects.filter(username="user1_sn").exists()) | ||||||
|         events = Event.objects.filter( |         events = Event.objects.filter( | ||||||
| @ -87,7 +87,7 @@ class LDAPSyncTests(TestCase): | |||||||
|  |  | ||||||
|         with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): |         with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): | ||||||
|             user_sync = UserLDAPSynchronizer(self.source) |             user_sync = UserLDAPSynchronizer(self.source) | ||||||
|             user_sync.sync() |             user_sync.sync_full() | ||||||
|             user = User.objects.filter(username="user0_sn").first() |             user = User.objects.filter(username="user0_sn").first() | ||||||
|             self.assertEqual(user.attributes["foo"], "bar") |             self.assertEqual(user.attributes["foo"], "bar") | ||||||
|             self.assertFalse(user.is_active) |             self.assertFalse(user.is_active) | ||||||
| @ -106,7 +106,7 @@ class LDAPSyncTests(TestCase): | |||||||
|         connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) |         connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) | ||||||
|         with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): |         with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): | ||||||
|             user_sync = UserLDAPSynchronizer(self.source) |             user_sync = UserLDAPSynchronizer(self.source) | ||||||
|             user_sync.sync() |             user_sync.sync_full() | ||||||
|             self.assertTrue(User.objects.filter(username="user0_sn").exists()) |             self.assertTrue(User.objects.filter(username="user0_sn").exists()) | ||||||
|             self.assertFalse(User.objects.filter(username="user1_sn").exists()) |             self.assertFalse(User.objects.filter(username="user1_sn").exists()) | ||||||
|  |  | ||||||
| @ -128,9 +128,9 @@ class LDAPSyncTests(TestCase): | |||||||
|             self.source.sync_parent_group = parent_group |             self.source.sync_parent_group = parent_group | ||||||
|             self.source.save() |             self.source.save() | ||||||
|             group_sync = GroupLDAPSynchronizer(self.source) |             group_sync = GroupLDAPSynchronizer(self.source) | ||||||
|             group_sync.sync() |             group_sync.sync_full() | ||||||
|             membership_sync = MembershipLDAPSynchronizer(self.source) |             membership_sync = MembershipLDAPSynchronizer(self.source) | ||||||
|             membership_sync.sync() |             membership_sync.sync_full() | ||||||
|             group: Group = Group.objects.filter(name="test-group").first() |             group: Group = Group.objects.filter(name="test-group").first() | ||||||
|             self.assertIsNotNone(group) |             self.assertIsNotNone(group) | ||||||
|             self.assertEqual(group.parent, parent_group) |             self.assertEqual(group.parent, parent_group) | ||||||
| @ -152,9 +152,9 @@ class LDAPSyncTests(TestCase): | |||||||
|         with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): |         with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): | ||||||
|             self.source.save() |             self.source.save() | ||||||
|             group_sync = GroupLDAPSynchronizer(self.source) |             group_sync = GroupLDAPSynchronizer(self.source) | ||||||
|             group_sync.sync() |             group_sync.sync_full() | ||||||
|             membership_sync = MembershipLDAPSynchronizer(self.source) |             membership_sync = MembershipLDAPSynchronizer(self.source) | ||||||
|             membership_sync.sync() |             membership_sync.sync_full() | ||||||
|             group = Group.objects.filter(name="group1") |             group = Group.objects.filter(name="group1") | ||||||
|             self.assertTrue(group.exists()) |             self.assertTrue(group.exists()) | ||||||
|  |  | ||||||
| @ -177,11 +177,11 @@ class LDAPSyncTests(TestCase): | |||||||
|         with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): |         with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): | ||||||
|             self.source.save() |             self.source.save() | ||||||
|             user_sync = UserLDAPSynchronizer(self.source) |             user_sync = UserLDAPSynchronizer(self.source) | ||||||
|             user_sync.sync() |             user_sync.sync_full() | ||||||
|             group_sync = GroupLDAPSynchronizer(self.source) |             group_sync = GroupLDAPSynchronizer(self.source) | ||||||
|             group_sync.sync() |             group_sync.sync_full() | ||||||
|             membership_sync = MembershipLDAPSynchronizer(self.source) |             membership_sync = MembershipLDAPSynchronizer(self.source) | ||||||
|             membership_sync.sync() |             membership_sync.sync_full() | ||||||
|             # Test if membership mapping based on memberUid works. |             # Test if membership mapping based on memberUid works. | ||||||
|             posix_group = Group.objects.filter(name="group-posix").first() |             posix_group = Group.objects.filter(name="group-posix").first() | ||||||
|             self.assertTrue(posix_group.users.filter(name="user-posix").exists()) |             self.assertTrue(posix_group.users.filter(name="user-posix").exists()) | ||||||
|  | |||||||
| @ -63,7 +63,7 @@ class TestSourceLDAPSamba(SeleniumTestCase): | |||||||
|         source.property_mappings_group.set( |         source.property_mappings_group.set( | ||||||
|             LDAPPropertyMapping.objects.filter(name="goauthentik.io/sources/ldap/default-name") |             LDAPPropertyMapping.objects.filter(name="goauthentik.io/sources/ldap/default-name") | ||||||
|         ) |         ) | ||||||
|         UserLDAPSynchronizer(source).sync() |         UserLDAPSynchronizer(source).sync_full() | ||||||
|         self.assertTrue(User.objects.filter(username="bob").exists()) |         self.assertTrue(User.objects.filter(username="bob").exists()) | ||||||
|         self.assertTrue(User.objects.filter(username="james").exists()) |         self.assertTrue(User.objects.filter(username="james").exists()) | ||||||
|         self.assertTrue(User.objects.filter(username="john").exists()) |         self.assertTrue(User.objects.filter(username="john").exists()) | ||||||
| @ -94,9 +94,9 @@ class TestSourceLDAPSamba(SeleniumTestCase): | |||||||
|         source.property_mappings_group.set( |         source.property_mappings_group.set( | ||||||
|             LDAPPropertyMapping.objects.filter(managed="goauthentik.io/sources/ldap/default-name") |             LDAPPropertyMapping.objects.filter(managed="goauthentik.io/sources/ldap/default-name") | ||||||
|         ) |         ) | ||||||
|         GroupLDAPSynchronizer(source).sync() |         GroupLDAPSynchronizer(source).sync_full() | ||||||
|         UserLDAPSynchronizer(source).sync() |         UserLDAPSynchronizer(source).sync_full() | ||||||
|         MembershipLDAPSynchronizer(source).sync() |         MembershipLDAPSynchronizer(source).sync_full() | ||||||
|         self.assertIsNotNone(User.objects.get(username="bob")) |         self.assertIsNotNone(User.objects.get(username="bob")) | ||||||
|         self.assertIsNotNone(User.objects.get(username="james")) |         self.assertIsNotNone(User.objects.get(username="james")) | ||||||
|         self.assertIsNotNone(User.objects.get(username="john")) |         self.assertIsNotNone(User.objects.get(username="john")) | ||||||
| @ -137,7 +137,7 @@ class TestSourceLDAPSamba(SeleniumTestCase): | |||||||
|         source.property_mappings_group.set( |         source.property_mappings_group.set( | ||||||
|             LDAPPropertyMapping.objects.filter(name="goauthentik.io/sources/ldap/default-name") |             LDAPPropertyMapping.objects.filter(name="goauthentik.io/sources/ldap/default-name") | ||||||
|         ) |         ) | ||||||
|         UserLDAPSynchronizer(source).sync() |         UserLDAPSynchronizer(source).sync_full() | ||||||
|         username = "bob" |         username = "bob" | ||||||
|         password = generate_id() |         password = generate_id() | ||||||
|         result = self.container.exec_run( |         result = self.container.exec_run( | ||||||
| @ -160,7 +160,7 @@ class TestSourceLDAPSamba(SeleniumTestCase): | |||||||
|         ) |         ) | ||||||
|         self.assertEqual(result.exit_code, 0) |         self.assertEqual(result.exit_code, 0) | ||||||
|         # Sync again |         # Sync again | ||||||
|         UserLDAPSynchronizer(source).sync() |         UserLDAPSynchronizer(source).sync_full() | ||||||
|         user.refresh_from_db() |         user.refresh_from_db() | ||||||
|         # Since password in samba was checked, it should be invalidated here too |         # Since password in samba was checked, it should be invalidated here too | ||||||
|         self.assertFalse(user.has_usable_password()) |         self.assertFalse(user.has_usable_password()) | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user
	 Jens L
					Jens L