diff --git a/authentik/outposts/apps.py b/authentik/outposts/apps.py index a7680a9aa5..063da78c39 100644 --- a/authentik/outposts/apps.py +++ b/authentik/outposts/apps.py @@ -5,6 +5,8 @@ from structlog.stdlib import get_logger from authentik.blueprints.apps import ManagedAppConfig from authentik.lib.config import CONFIG +from authentik.lib.utils.time import fqdn_rand +from authentik.tasks.schedules.lib import ScheduleSpec LOGGER = get_logger() @@ -60,3 +62,15 @@ class AuthentikOutpostConfig(ManagedAppConfig): outpost.save() else: Outpost.objects.filter(managed=MANAGED_OUTPOST).delete() + + def get_tenant_schedule_specs(self) -> list[ScheduleSpec]: + return [ + ScheduleSpec( + actor_name="authentik.outposts.tasks.outpost_token_ensurer", + crontab=f"{fqdn_rand('outpost_token_ensurer')} */8 * * *", + ), + ScheduleSpec( + actor_name="authentik.outposts.tasks.outpost_connection_discovery", + crontab=f"{fqdn_rand('outpost_connection_discovery')} */8 * * *", + ), + ] diff --git a/authentik/outposts/models.py b/authentik/outposts/models.py index a8a4efbf26..70bea1ded8 100644 --- a/authentik/outposts/models.py +++ b/authentik/outposts/models.py @@ -36,6 +36,7 @@ from authentik.lib.config import CONFIG from authentik.lib.models import InheritanceForeignKey, SerializerModel from authentik.lib.sentry import SentryIgnoredException from authentik.lib.utils.errors import exception_to_string +from authentik.lib.utils.time import fqdn_rand from authentik.outposts.controllers.k8s.utils import get_namespace from authentik.tasks.schedules.lib import ScheduleSpec from authentik.tasks.schedules.models import ScheduledModel @@ -164,8 +165,8 @@ class OutpostServiceConnection(ScheduledModel, models.Model): def schedule_specs(self) -> list[ScheduleSpec]: return [ ScheduleSpec( - uid=self.pk, actor_name="authentik.outposts.tasks.outpost_service_connection_monitor", + uid=self.pk, args=(self.pk,), crontab="3-59/15 * * * *", description=_(f"Update cached state of service connection {self.name}"), @@ -256,7 +257,7 @@ class KubernetesServiceConnection(SerializerModel, OutpostServiceConnection): return "ak-service-connection-kubernetes-form" -class Outpost(SerializerModel, ManagedModel): +class Outpost(ScheduledModel, SerializerModel, ManagedModel): """Outpost instance which manages a service user and token""" uuid = models.UUIDField(default=uuid4, editable=False, primary_key=True) @@ -310,6 +311,24 @@ class Outpost(SerializerModel, ManagedModel): """Username for service user""" return f"ak-outpost-{self.uuid.hex}" + @property + def schedule_specs(self) -> list[ScheduleSpec]: + specs = [] + if self.service_connection is not None: + specs.append( + ScheduleSpec( + actor_name="authentik.outposts.tasks.outpost_controller", + uid=self.pk, + args=(self.pk, "up"), + kwargs={"action": "up", "from_cache": False}, + crontab=f"{fqdn_rand('outpost_controller')} */4 * * *", + description=_( + f"Create/update/monitor/delete the deployment for the {self.name} outpost" + ), + ) + ) + return specs + def build_user_permissions(self, user: User): """Create per-object and global permissions for outpost service-account""" # To ensure the user only has the correct permissions, we delete all of them and re-add diff --git a/authentik/outposts/settings.py b/authentik/outposts/settings.py deleted file mode 100644 index 06f903f8ae..0000000000 --- a/authentik/outposts/settings.py +++ /dev/null @@ -1,23 +0,0 @@ -"""Outposts Settings""" - -from celery.schedules import crontab - -from authentik.lib.utils.time import fqdn_rand - -CELERY_BEAT_SCHEDULE = { - "outposts_controller": { - "task": "authentik.outposts.tasks.outpost_controller_all", - "schedule": crontab(minute=fqdn_rand("outposts_controller"), hour="*/4"), - "options": {"queue": "authentik_scheduled"}, - }, - "outpost_token_ensurer": { - "task": "authentik.outposts.tasks.outpost_token_ensurer", - "schedule": crontab(minute=fqdn_rand("outpost_token_ensurer"), hour="*/8"), - "options": {"queue": "authentik_scheduled"}, - }, - "outpost_connection_discovery": { - "task": "authentik.outposts.tasks.outpost_connection_discovery", - "schedule": crontab(minute=fqdn_rand("outpost_connection_discovery"), hour="*/8"), - "options": {"queue": "authentik_scheduled"}, - }, -} diff --git a/authentik/outposts/signals.py b/authentik/outposts/signals.py index 73d05a4b9a..37b9fb777b 100644 --- a/authentik/outposts/signals.py +++ b/authentik/outposts/signals.py @@ -39,14 +39,14 @@ def pre_save_outpost(sender, instance: Outpost, **_): if bool(dirty): LOGGER.info("Outpost needs re-deployment due to changes", instance=instance) cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, old_instance) - outpost_controller.delay(instance.pk.hex, action="down", from_cache=True) + outpost_controller.send(instance.pk.hex, action="down", from_cache=True) @receiver(m2m_changed, sender=Outpost.providers.through) def m2m_changed_update(sender, instance: Model, action: str, **_): """Update outpost on m2m change, when providers are added or removed""" if action in ["post_add", "post_remove", "post_clear"]: - outpost_post_save.delay(class_to_path(instance.__class__), instance.pk) + outpost_post_save.send(class_to_path(instance.__class__), instance.pk) @receiver(post_save) @@ -64,7 +64,7 @@ def post_save_update(sender, instance: Model, created: bool, **_): if isinstance(instance, Outpost) and created: LOGGER.info("New outpost saved, ensuring initial token and user are created") _ = instance.token - outpost_post_save.delay(class_to_path(instance.__class__), instance.pk) + outpost_post_save.send(class_to_path(instance.__class__), instance.pk) @receiver(pre_delete, sender=Outpost) @@ -72,4 +72,4 @@ def pre_delete_cleanup(sender, instance: Outpost, **_): """Ensure that Outpost's user is deleted (which will delete the token through cascade)""" instance.user.delete() cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, instance) - outpost_controller.delay(instance.pk.hex, action="down", from_cache=True) + outpost_controller.send(instance.pk.hex, action="down", from_cache=True) diff --git a/authentik/outposts/tasks.py b/authentik/outposts/tasks.py index a497bc7b2a..5885358918 100644 --- a/authentik/outposts/tasks.py +++ b/authentik/outposts/tasks.py @@ -19,13 +19,13 @@ from kubernetes.config.kube_config import KUBE_CONFIG_DEFAULT_LOCATION from structlog.stdlib import get_logger from yaml import safe_load -from authentik.events.models import TaskStatus -from authentik.events.system_tasks import SystemTask, prefill_task from authentik.lib.config import CONFIG from authentik.lib.utils.reflection import path_to_class from authentik.outposts.consumer import OUTPOST_GROUP from authentik.outposts.controllers.base import BaseController, ControllerException from authentik.outposts.controllers.docker import DockerClient +from authentik.tasks.middleware import CurrentTask +from authentik.tasks.models import TaskStatus, Task from authentik.outposts.controllers.kubernetes import KubernetesClient from authentik.outposts.models import ( DockerServiceConnection, @@ -103,20 +103,10 @@ def outpost_service_connection_monitor(connection_pk: Any): cache.set(connection.state_key, state, timeout=None) -@CELERY_APP.task( - throws=(DatabaseError, ProgrammingError, InternalError), -) -def outpost_controller_all(): - """Launch Controller for all Outposts which support it""" - for outpost in Outpost.objects.exclude(service_connection=None): - outpost_controller.delay(outpost.pk.hex, "up", from_cache=False) - - -@CELERY_APP.task(bind=True, base=SystemTask) -def outpost_controller( - self: SystemTask, outpost_pk: str, action: str = "up", from_cache: bool = False -): +@actor +def outpost_controller(outpost_pk: str, action: str = "up", from_cache: bool = False): """Create/update/monitor/delete the deployment of an Outpost""" + self: Task = CurrentTask.get_task() logs = [] if from_cache: outpost: Outpost = cache.get(CACHE_KEY_OUTPOST_DOWN % outpost_pk) @@ -144,11 +134,11 @@ def outpost_controller( self.set_status(TaskStatus.SUCCESSFUL, *logs) -@CELERY_APP.task(bind=True, base=SystemTask) -@prefill_task -def outpost_token_ensurer(self: SystemTask): +@actor +def outpost_token_ensurer(): """Periodically ensure that all Outposts have valid Service Accounts and Tokens""" + self: Task = CurrentTask.get_task() all_outposts = Outpost.objects.all() for outpost in all_outposts: _ = outpost.token @@ -159,7 +149,7 @@ def outpost_token_ensurer(self: SystemTask): ) -@CELERY_APP.task() +@actor def outpost_post_save(model_class: str, model_pk: Any): """If an Outpost is saved, Ensure that token is created/updated @@ -174,7 +164,7 @@ def outpost_post_save(model_class: str, model_pk: Any): if isinstance(instance, Outpost): LOGGER.debug("Trigger reconcile for outpost", instance=instance) - outpost_controller.delay(str(instance.pk)) + outpost_controller.send(instance.pk) if isinstance(instance, OutpostModel | Outpost): LOGGER.debug("triggering outpost update from outpostmodel/outpost", instance=instance) @@ -182,7 +172,7 @@ def outpost_post_save(model_class: str, model_pk: Any): if isinstance(instance, OutpostServiceConnection): LOGGER.debug("triggering ServiceConnection state update", instance=instance) - outpost_service_connection_monitor.send(str(instance.pk)) + outpost_service_connection_monitor.send(instance.pk) for field in instance._meta.get_fields(): # Each field is checked if it has a `related_model` attribute (when ForeginKeys or M2Ms) @@ -229,12 +219,10 @@ def _outpost_single_update(outpost: Outpost, layer=None): async_to_sync(layer.group_send)(group, {"type": "event.update"}) -@CELERY_APP.task( - base=SystemTask, - bind=True, -) -def outpost_connection_discovery(self: SystemTask): +@actor +def outpost_connection_discovery(): """Checks the local environment and create Service connections.""" + self: Task = CurrentTask.get_task() messages = [] if not CONFIG.get_bool("outposts.discover"): messages.append("Outpost integration discovery is disabled") diff --git a/authentik/outposts/tests/test_sa.py b/authentik/outposts/tests/test_sa.py index 59238a2cf8..4f36741a7a 100644 --- a/authentik/outposts/tests/test_sa.py +++ b/authentik/outposts/tests/test_sa.py @@ -8,9 +8,10 @@ from guardian.models import UserObjectPermission from authentik.core.tests.utils import create_test_cert, create_test_flow from authentik.outposts.models import Outpost, OutpostType from authentik.providers.proxy.models import ProxyProvider +from authentik.tasks.tests import TaskTestCase -class OutpostTests(TestCase): +class OutpostTests(TaskTestCase): """Outpost Tests""" def setUp(self) -> None: @@ -29,6 +30,7 @@ class OutpostTests(TestCase): name="test", type=OutpostType.PROXY, ) + self.tasks_join() # Before we add a provider, the user should only have access to the outpost permissions = UserObjectPermission.objects.filter(user=outpost.user) diff --git a/authentik/root/celery.py b/authentik/root/celery.py index 661bd1ed78..6f162e5c5c 100644 --- a/authentik/root/celery.py +++ b/authentik/root/celery.py @@ -89,10 +89,10 @@ def task_error_hook(task_id: str, exception: Exception, traceback, *args, **kwar def _get_startup_tasks_default_tenant() -> list[Callable]: """Get all tasks to be run on startup for the default tenant""" - from authentik.outposts.tasks import outpost_connection_discovery + # from authentik.outposts.tasks import outpost_connection_discovery return [ - outpost_connection_discovery, + # outpost_connection_discovery, ] diff --git a/authentik/tasks/broker.py b/authentik/tasks/broker.py index 5badf3c7d1..5eb4b89de0 100644 --- a/authentik/tasks/broker.py +++ b/authentik/tasks/broker.py @@ -21,7 +21,17 @@ from dramatiq.broker import Broker, Consumer, MessageProxy from dramatiq.common import compute_backoff, current_millis, dq_name, xq_name from dramatiq.errors import ConnectionError, QueueJoinTimeout from dramatiq.message import Message -from dramatiq.middleware import Middleware, Prometheus, default_middleware +from dramatiq.middleware import ( + AgeLimit, + Callbacks, + Middleware, + Pipelines, + Prometheus, + Retries, + ShutdownNotifications, + TimeLimit, + default_middleware, +) from dramatiq.results import Results from pglock.core import _cast_lock_id from psycopg import Notify, sql @@ -77,7 +87,7 @@ class TenantMiddleware(Middleware): class PostgresBroker(Broker): def __init__(self, *args, db_alias: str = DEFAULT_DB_ALIAS, results: bool = True, **kwargs): - super().__init__(*args, **kwargs) + super().__init__(*args, middleware=[], **kwargs) self.logger = get_logger().bind() self.queues = set() @@ -89,9 +99,14 @@ class PostgresBroker(Broker): self.middleware = [] self.add_middleware(DbConnectionMiddleware()) self.add_middleware(TenantMiddleware()) - for middleware in default_middleware: - if middleware == Prometheus: - pass + for middleware in ( + AgeLimit, + TimeLimit, + ShutdownNotifications, + Callbacks, + Pipelines, + Retries, + ): self.add_middleware(middleware()) if results: self.backend = PostgresBackend() diff --git a/authentik/tasks/tests.py b/authentik/tasks/tests.py index cfa85d66c9..e202c877e0 100644 --- a/authentik/tasks/tests.py +++ b/authentik/tasks/tests.py @@ -1,8 +1,8 @@ -from django.test import TransactionTestCase +from django.test import TestCase from dramatiq import Worker, get_broker -class TaskTestCase(TransactionTestCase): +class TaskTestCase(TestCase): def _pre_setup(self): super()._pre_setup() @@ -17,6 +17,10 @@ class TaskTestCase(TransactionTestCase): super()._post_teardown() - def tasks_join(self, queue_name: str): - self.broker.join(queue_name) + def tasks_join(self, queue_name: str | None = None): + if queue_name is None: + for queue in self.broker.get_declared_queues(): + self.broker.join(queue) + else: + self.broker.join(queue_name) self.worker.join()