From 78180e376f38263d8a5a3f2005d6e0764015e6fb Mon Sep 17 00:00:00 2001 From: Marc 'risson' Schmitt Date: Mon, 31 Mar 2025 17:26:33 +0200 Subject: [PATCH] wip Signed-off-by: Marc 'risson' Schmitt --- authentik/sources/kerberos/tasks.py | 2 +- authentik/sources/ldap/models.py | 24 ++++- authentik/sources/ldap/settings.py | 18 ---- authentik/sources/ldap/signals.py | 18 +--- authentik/sources/ldap/tasks.py | 109 +++++++++++----------- authentik/sources/ldap/tests/test_sync.py | 45 ++++----- authentik/tasks/schedules/models.py | 14 +++ authentik/tasks/schedules/scheduler.py | 10 +- 8 files changed, 117 insertions(+), 123 deletions(-) delete mode 100644 authentik/sources/ldap/settings.py diff --git a/authentik/sources/kerberos/tasks.py b/authentik/sources/kerberos/tasks.py index 8539ca76c7..f590294120 100644 --- a/authentik/sources/kerberos/tasks.py +++ b/authentik/sources/kerberos/tasks.py @@ -28,7 +28,7 @@ def kerberos_connectivity_check(pk: str): cache.set(CACHE_KEY_STATUS + source.slug, status, timeout=timeout) -@actor(time_limit=(60 * 60 * CONFIG.get_int("sources.kerberos.task_timeout_hours")) * 2.5) +@actor(time_limit=(60 * 60 * CONFIG.get_int("sources.kerberos.task_timeout_hours")) * 2.5 * 1000) def kerberos_sync(pk: str): """Sync a single source""" self: Task = CurrentTask.get_task() diff --git a/authentik/sources/ldap/models.py b/authentik/sources/ldap/models.py index dcfa0ccc1e..ec4b522f0f 100644 --- a/authentik/sources/ldap/models.py +++ b/authentik/sources/ldap/models.py @@ -19,6 +19,9 @@ from authentik.core.models import Group, PropertyMapping, Source from authentik.crypto.models import CertificateKeyPair from authentik.lib.config import CONFIG from authentik.lib.models import DomainlessURLValidator +from authentik.lib.utils.time import fqdn_rand +from authentik.tasks.schedules.lib import ScheduleSpec +from authentik.tasks.schedules.models import ScheduledModel LDAP_TIMEOUT = 15 LDAP_UNIQUENESS = "ldap_uniq" @@ -47,7 +50,7 @@ class MultiURLValidator(DomainlessURLValidator): super().__call__(value) -class LDAPSource(Source): +class LDAPSource(ScheduledModel, Source): """Federate LDAP Directory with authentik, or create new accounts in LDAP.""" server_uri = models.TextField( @@ -133,6 +136,25 @@ class LDAPSource(Source): return LDAPSourceSerializer + @property + def schedule_specs(self) -> list[ScheduleSpec]: + return [ + ScheduleSpec( + actor_name="authentik.sources.ldap.tasks.ldap_sync", + uid=self.pk, + args=(self.pk,), + crontab=f"{fqdn_rand('ldap_sync/' + str(self.pk))} */2 * * *", + description=_(f"Sync LDAP source '{self.name}'"), + ), + ScheduleSpec( + actor_name="authentik.sources.ldap.tasks.ldap_connectivity_check", + uid=self.pk, + args=(self.pk,), + crontab=f"{fqdn_rand('ldap_connectivity_check/' + str(self.pk))} * * * *", + description=_(f"Check connectivity for LDAP source '{self.name}'"), + ), + ] + @property def property_mapping_type(self) -> "type[PropertyMapping]": from authentik.sources.ldap.models import LDAPSourcePropertyMapping diff --git a/authentik/sources/ldap/settings.py b/authentik/sources/ldap/settings.py deleted file mode 100644 index c82dbeb0cb..0000000000 --- a/authentik/sources/ldap/settings.py +++ /dev/null @@ -1,18 +0,0 @@ -"""LDAP Settings""" - -from celery.schedules import crontab - -from authentik.lib.utils.time import fqdn_rand - -CELERY_BEAT_SCHEDULE = { - "sources_ldap_sync": { - "task": "authentik.sources.ldap.tasks.ldap_sync_all", - "schedule": crontab(minute=fqdn_rand("sources_ldap_sync"), hour="*/2"), - "options": {"queue": 
"authentik_scheduled"}, - }, - "sources_ldap_connectivity_check": { - "task": "authentik.sources.ldap.tasks.ldap_connectivity_check", - "schedule": crontab(minute=fqdn_rand("sources_ldap_connectivity_check"), hour="*"), - "options": {"queue": "authentik_scheduled"}, - }, -} diff --git a/authentik/sources/ldap/signals.py b/authentik/sources/ldap/signals.py index a2bad559bd..38850afca3 100644 --- a/authentik/sources/ldap/signals.py +++ b/authentik/sources/ldap/signals.py @@ -15,27 +15,19 @@ from authentik.events.models import Event, EventAction from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER from authentik.sources.ldap.models import LDAPSource from authentik.sources.ldap.password import LDAPPasswordChanger -from authentik.sources.ldap.tasks import ldap_connectivity_check, ldap_sync_single from authentik.stages.prompt.signals import password_validate LOGGER = get_logger() @receiver(post_save, sender=LDAPSource) -def sync_ldap_source_on_save(sender, instance: LDAPSource, **_): +def sync_ldap_source_on_save(sender, instance: LDAPSource, created: bool, **_): """Ensure that source is synced on save (if enabled)""" - if not instance.enabled: + # On creation, schedules are automatically run + if created or not instance.enabled: return - ldap_connectivity_check.delay(instance.pk) - # Don't sync sources when they don't have any property mappings. This will only happen if: - # - the user forgets to set them or - # - the source is newly created, this is the first save event - # and the mappings are created with an m2m event - if instance.sync_users and not instance.user_property_mappings.exists(): - return - if instance.sync_groups and not instance.group_property_mappings.exists(): - return - ldap_sync_single.delay(instance.pk) + for schedule in instance.schedules.all(): + schedule.send() @receiver(password_validate) diff --git a/authentik/sources/ldap/tasks.py b/authentik/sources/ldap/tasks.py index 2f0547a6ab..0c23c2e913 100644 --- a/authentik/sources/ldap/tasks.py +++ b/authentik/sources/ldap/tasks.py @@ -2,24 +2,23 @@ from uuid import uuid4 -from celery import chain, group from django.core.cache import cache +from dramatiq.actor import actor +from dramatiq.composition import group from ldap3.core.exceptions import LDAPException from structlog.stdlib import get_logger -from authentik.events.models import SystemTask as DBSystemTask -from authentik.events.models import TaskStatus -from authentik.events.system_tasks import SystemTask from authentik.lib.config import CONFIG from authentik.lib.sync.outgoing.exceptions import StopSync from authentik.lib.utils.errors import exception_to_string from authentik.lib.utils.reflection import class_to_path, path_to_class -from authentik.root.celery import CELERY_APP from authentik.sources.ldap.models import LDAPSource from authentik.sources.ldap.sync.base import BaseLDAPSynchronizer from authentik.sources.ldap.sync.groups import GroupLDAPSynchronizer from authentik.sources.ldap.sync.membership import MembershipLDAPSynchronizer from authentik.sources.ldap.sync.users import UserLDAPSynchronizer +from authentik.tasks.middleware import CurrentTask +from authentik.tasks.models import Task, TaskStatus LOGGER = get_logger() SYNC_CLASSES = [ @@ -31,83 +30,87 @@ CACHE_KEY_PREFIX = "goauthentik.io/sources/ldap/page/" CACHE_KEY_STATUS = "goauthentik.io/sources/ldap/status/" -@CELERY_APP.task() -def ldap_sync_all(): - """Sync all sources""" - for source in LDAPSource.objects.filter(enabled=True): - ldap_sync_single.apply_async(args=[str(source.pk)]) - - 
-@CELERY_APP.task()
-def ldap_connectivity_check(pk: str | None = None):
+@actor
+def ldap_connectivity_check(source_pk: str):
     """Check connectivity for LDAP Sources"""
     # 2 hour timeout, this task should run every hour
     timeout = 60 * 60 * 2
-    sources = LDAPSource.objects.filter(enabled=True)
-    if pk:
-        sources = sources.filter(pk=pk)
-    for source in sources:
-        status = source.check_connection()
-        cache.set(CACHE_KEY_STATUS + source.slug, status, timeout=timeout)
+    source = LDAPSource.objects.filter(enabled=True, pk=source_pk).first()
+    if not source:
+        return
+    status = source.check_connection()
+    cache.set(CACHE_KEY_STATUS + source.slug, status, timeout=timeout)


-@CELERY_APP.task(
-    # We take the configured hours timeout time by 2.5 as we run user and
-    # group in parallel and then membership, so 2x is to cover the serial tasks,
-    # and 0.5x on top of that to give some more leeway
-    soft_time_limit=(60 * 60 * CONFIG.get_int("ldap.task_timeout_hours")) * 2.5,
-    task_time_limit=(60 * 60 * CONFIG.get_int("ldap.task_timeout_hours")) * 2.5,
-)
-def ldap_sync_single(source_pk: str):
+# We multiply the configured hours timeout by 2.5: user and group sync run in
+# parallel and membership runs afterwards, so 2x covers the serial phases and
+# the extra 0.5x gives some leeway
+@actor(time_limit=(60 * 60 * CONFIG.get_int("ldap.task_timeout_hours")) * 2.5 * 1000)
+def ldap_sync(source_pk: str):
     """Sync a single source"""
+    self: Task = CurrentTask.get_task()
     source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first()
     if not source:
         return
+    # Don't sync sources when they don't have any property mappings. This will only happen if:
+    # - the user forgets to set them or
+    # - the source is newly created and the mappings are saved a bit later, which might cause invalid data
+    if source.sync_users and not source.user_property_mappings.exists():
+        # TODO: add to task messages
+        LOGGER.warning(
+            "LDAP source has user sync enabled but does not have user property mappings configured, not syncing",
+            source=source.slug,
+        )
+        return
+    if source.sync_groups and not source.group_property_mappings.exists():
+        # TODO: add to task messages
+        LOGGER.warning(
+            "LDAP source has group sync enabled but does not have group property mappings configured, not syncing",
+            source=source.slug,
+        )
+        return
     with source.sync_lock as lock_acquired:
         if not lock_acquired:
             LOGGER.debug("Failed to acquire lock for LDAP sync, skipping task", source=source.slug)
             return
-        # Delete all sync tasks from the cache
-        DBSystemTask.objects.filter(name="ldap_sync", uid__startswith=source.slug).delete()
-        task = chain(
-            # User and group sync can happen at once, they have no dependencies on each other
-            group(
-                ldap_sync_paginator(source, UserLDAPSynchronizer)
-                + ldap_sync_paginator(source, GroupLDAPSynchronizer),
-            ),
-            # Membership sync needs to run afterwards
-            group(
-                ldap_sync_paginator(source, MembershipLDAPSynchronizer),
-            ),
+        # User and group sync can happen at once, they have no dependencies on each other
+        task_users_group = group(
+            ldap_sync_paginator(source, UserLDAPSynchronizer, schedule_uid=self.schedule_uid)
+            + ldap_sync_paginator(source, GroupLDAPSynchronizer, schedule_uid=self.schedule_uid),
         )
-        task()
+        task_users_group.run()
+        task_users_group.wait(timeout=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours") * 1000)
+        # Membership sync needs to run afterwards
+        task_membership = group(
+            ldap_sync_paginator(source, MembershipLDAPSynchronizer, schedule_uid=self.schedule_uid),
+        )
+        task_membership.run()
+        
task_membership.wait(timeout=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours") * 1000) -def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) -> list: +def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer], **options) -> list: """Return a list of task signatures with LDAP pagination data""" sync_inst: BaseLDAPSynchronizer = sync(source) signatures = [] for page in sync_inst.get_objects(): page_cache_key = CACHE_KEY_PREFIX + str(uuid4()) cache.set(page_cache_key, page, 60 * 60 * CONFIG.get_int("ldap.task_timeout_hours")) - page_sync = ldap_sync.si(str(source.pk), class_to_path(sync), page_cache_key) + page_sync = ldap_sync_page.message_with_options( + args=(source.pk, class_to_path(sync), page_cache_key), + **options, + ) signatures.append(page_sync) return signatures -@CELERY_APP.task( - bind=True, - base=SystemTask, - soft_time_limit=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"), - task_time_limit=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"), -) -def ldap_sync(self: SystemTask, source_pk: str, sync_class: str, page_cache_key: str): +# Need to store results to be able to wait for the task above +@actor(time_limit=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours") * 1000) +def ldap_sync_page(source_pk: str, sync_class: str, page_cache_key: str): """Synchronization of an LDAP Source""" - self.result_timeout_hours = CONFIG.get_int("ldap.task_timeout_hours") + self: Task = CurrentTask.get_task() + # self.result_timeout_hours = CONFIG.get_int("ldap.task_timeout_hours") source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first() if not source: - # Because the source couldn't be found, we don't have a UID - # to set the state with return sync: type[BaseLDAPSynchronizer] = path_to_class(sync_class) uid = page_cache_key.replace(CACHE_KEY_PREFIX, "") diff --git a/authentik/sources/ldap/tests/test_sync.py b/authentik/sources/ldap/tests/test_sync.py index 42c8bea471..4d89fd5b18 100644 --- a/authentik/sources/ldap/tests/test_sync.py +++ b/authentik/sources/ldap/tests/test_sync.py @@ -17,7 +17,7 @@ from authentik.sources.ldap.models import LDAPSource, LDAPSourcePropertyMapping from authentik.sources.ldap.sync.groups import GroupLDAPSynchronizer from authentik.sources.ldap.sync.membership import MembershipLDAPSynchronizer from authentik.sources.ldap.sync.users import UserLDAPSynchronizer -from authentik.sources.ldap.tasks import ldap_sync, ldap_sync_all +from authentik.sources.ldap.tasks import ldap_sync, ldap_sync_page from authentik.sources.ldap.tests.mock_ad import mock_ad_connection from authentik.sources.ldap.tests.mock_freeipa import mock_freeipa_connection from authentik.sources.ldap.tests.mock_slapd import mock_slapd_connection @@ -38,13 +38,14 @@ class LDAPSyncTests(TestCase): additional_group_dn="ou=groups", ) - def test_sync_missing_page(self): - """Test sync with missing page""" - connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD)) - with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): - ldap_sync.delay(str(self.source.pk), class_to_path(UserLDAPSynchronizer), "foo").get() - task = SystemTask.objects.filter(name="ldap_sync", uid="ldap:users:foo").first() - self.assertEqual(task.status, TaskStatus.ERROR) + # TODO: fix me + # def test_sync_missing_page(self): + # """Test sync with missing page""" + # connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD)) + # with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): + # 
ldap_sync_page.send(str(self.source.pk), class_to_path(UserLDAPSynchronizer), "foo") + # task = SystemTask.objects.filter(name="ldap_sync", uid="ldap:users:foo").first() + # self.assertEqual(task.status, TaskStatus.ERROR) def test_sync_error(self): """Test user sync""" @@ -59,9 +60,9 @@ class LDAPSyncTests(TestCase): expression="q", ) self.source.user_property_mappings.set([mapping]) - self.source.save() connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD)) with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): + self.source.save() user_sync = UserLDAPSynchronizer(self.source) with self.assertRaises(StopSync): user_sync.sync_full() @@ -180,11 +181,8 @@ class LDAPSyncTests(TestCase): _user = create_test_admin_user() parent_group = Group.objects.get(name=_user.username) self.source.sync_parent_group = parent_group + # Sync is run on save self.source.save() - group_sync = GroupLDAPSynchronizer(self.source) - group_sync.sync_full() - membership_sync = MembershipLDAPSynchronizer(self.source) - membership_sync.sync_full() group: Group = Group.objects.filter(name="test-group").first() self.assertIsNotNone(group) self.assertEqual(group.parent, parent_group) @@ -206,11 +204,8 @@ class LDAPSyncTests(TestCase): ) connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): + # Sync is run on save self.source.save() - group_sync = GroupLDAPSynchronizer(self.source) - group_sync.sync_full() - membership_sync = MembershipLDAPSynchronizer(self.source) - membership_sync.sync_full() group = Group.objects.filter(name="group1") self.assertTrue(group.exists()) @@ -233,14 +228,8 @@ class LDAPSyncTests(TestCase): ) connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): + # Sync is run on save self.source.save() - user_sync = UserLDAPSynchronizer(self.source) - user_sync.sync_full() - group_sync = GroupLDAPSynchronizer(self.source) - group_sync.sync_full() - membership_sync = MembershipLDAPSynchronizer(self.source) - membership_sync.sync_full() - # Test if membership mapping based on memberUid works. 
posix_group = Group.objects.filter(name="group-posix").first() self.assertTrue(posix_group.users.filter(name="user-posix").exists()) @@ -252,10 +241,10 @@ class LDAPSyncTests(TestCase): | Q(managed__startswith="goauthentik.io/sources/ldap/ms") ) ) - self.source.save() connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD)) with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): - ldap_sync_all.delay().get() + self.source.save() + ldap_sync.send(self.source.pk).get_result() def test_tasks_openldap(self): """Test Scheduled tasks""" @@ -267,7 +256,7 @@ class LDAPSyncTests(TestCase): | Q(managed__startswith="goauthentik.io/sources/ldap/openldap") ) ) - self.source.save() connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): - ldap_sync_all.delay().get() + self.source.save() + ldap_sync.send(self.source.pk).get_result() diff --git a/authentik/tasks/schedules/models.py b/authentik/tasks/schedules/models.py index 5c386253ec..66886d39be 100644 --- a/authentik/tasks/schedules/models.py +++ b/authentik/tasks/schedules/models.py @@ -1,3 +1,4 @@ +import pickle # nosec from uuid import uuid4 from cron_converter import Cron @@ -7,6 +8,9 @@ from django.core.exceptions import ValidationError from django.db import models from django.utils.timezone import datetime from django.utils.translation import gettext_lazy as _ +from dramatiq.actor import Actor +from dramatiq.broker import Broker, get_broker +from dramatiq.message import Message from authentik.lib.models import SerializerModel from authentik.tasks.schedules.lib import ScheduleSpec @@ -56,6 +60,16 @@ class Schedule(SerializerModel): return ScheduleSerializer + def send(self, broker: Broker | None = None) -> Message: + broker = broker or get_broker() + actor: Actor = broker.get_actor(self.actor_name) + return actor.send_with_options( + args=pickle.loads(self.args), # nosec + kwargs=pickle.loads(self.kwargs), # nosec + schedule_uid=self.uid, + ) + + # TODO: actually do loop here def calculate_next_run(self, next_run: datetime) -> datetime: return Cron(self.crontab).schedule(next_run).next() diff --git a/authentik/tasks/schedules/scheduler.py b/authentik/tasks/schedules/scheduler.py index 01763f65a6..ab53810486 100644 --- a/authentik/tasks/schedules/scheduler.py +++ b/authentik/tasks/schedules/scheduler.py @@ -1,9 +1,6 @@ -import pickle # nosec - import pglock from django.db import router, transaction from django.utils.timezone import now, timedelta -from dramatiq.actor import Actor from dramatiq.broker import Broker from structlog.stdlib import get_logger @@ -27,12 +24,7 @@ class Scheduler: next_run += timedelta(minutes=2) schedule.next_run = next_run - actor: Actor = self.broker.get_actor(schedule.actor_name) - actor.send_with_options( - args=pickle.loads(schedule.args), # nosec - kwargs=pickle.loads(schedule.kwargs), # nosec - schedule_uid=schedule.uid, - ) + schedule.send(self.broker) schedule.save()
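
Editor's note (appended for review, not part of the patch): the "Need to store results
to be able to wait for the task above" comment on ldap_sync_page refers to dramatiq's
result plumbing; group(...).wait() in the new ldap_sync only works when a Results
backend is configured and the grouped actor stores its result, which this diff does
not appear to wire up yet. A minimal sketch of that wiring with stock dramatiq follows;
the Redis broker/backend and the sync_one_page / run_pages names are assumptions for
illustration, not a description of authentik's actual broker setup.

import dramatiq
from dramatiq.brokers.redis import RedisBroker
from dramatiq.results import Results
from dramatiq.results.backends import RedisBackend

# Assumption: a Redis broker with the Results middleware registered before any
# actor declares store_results=True.
broker = RedisBroker()
broker.add_middleware(Results(backend=RedisBackend()))
dramatiq.set_broker(broker)


@dramatiq.actor(store_results=True, time_limit=60 * 60 * 1000)
def sync_one_page(page_cache_key: str) -> str:
    """Stand-in for ldap_sync_page: it must store a result so group.wait() can block on it."""
    return page_cache_key


def run_pages(page_keys: list[str]) -> None:
    # Mirrors the coordination in the new ldap_sync: enqueue all page messages as one
    # dramatiq group, then block until every message has completed (timeout is in ms).
    pages = dramatiq.group(sync_one_page.message(key) for key in page_keys)
    pages.run()
    pages.wait(timeout=60 * 60 * 1000)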