Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt
2025-03-27 18:51:47 +01:00
parent 8ae0f145f5
commit 75c13a8801
9 changed files with 86 additions and 67 deletions

View File

@ -5,6 +5,8 @@ from structlog.stdlib import get_logger
from authentik.blueprints.apps import ManagedAppConfig from authentik.blueprints.apps import ManagedAppConfig
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
from authentik.lib.utils.time import fqdn_rand
from authentik.tasks.schedules.lib import ScheduleSpec
LOGGER = get_logger() LOGGER = get_logger()
@ -60,3 +62,15 @@ class AuthentikOutpostConfig(ManagedAppConfig):
outpost.save() outpost.save()
else: else:
Outpost.objects.filter(managed=MANAGED_OUTPOST).delete() Outpost.objects.filter(managed=MANAGED_OUTPOST).delete()
def get_tenant_schedule_specs(self) -> list[ScheduleSpec]:
return [
ScheduleSpec(
actor_name="authentik.outposts.tasks.outpost_token_ensurer",
crontab=f"{fqdn_rand('outpost_token_ensurer')} */8 * * *",
),
ScheduleSpec(
actor_name="authentik.outposts.tasks.outpost_connection_discovery",
crontab=f"{fqdn_rand('outpost_connection_discovery')} */8 * * *",
),
]

View File

@ -36,6 +36,7 @@ from authentik.lib.config import CONFIG
from authentik.lib.models import InheritanceForeignKey, SerializerModel from authentik.lib.models import InheritanceForeignKey, SerializerModel
from authentik.lib.sentry import SentryIgnoredException from authentik.lib.sentry import SentryIgnoredException
from authentik.lib.utils.errors import exception_to_string from authentik.lib.utils.errors import exception_to_string
from authentik.lib.utils.time import fqdn_rand
from authentik.outposts.controllers.k8s.utils import get_namespace from authentik.outposts.controllers.k8s.utils import get_namespace
from authentik.tasks.schedules.lib import ScheduleSpec from authentik.tasks.schedules.lib import ScheduleSpec
from authentik.tasks.schedules.models import ScheduledModel from authentik.tasks.schedules.models import ScheduledModel
@ -164,8 +165,8 @@ class OutpostServiceConnection(ScheduledModel, models.Model):
def schedule_specs(self) -> list[ScheduleSpec]: def schedule_specs(self) -> list[ScheduleSpec]:
return [ return [
ScheduleSpec( ScheduleSpec(
uid=self.pk,
actor_name="authentik.outposts.tasks.outpost_service_connection_monitor", actor_name="authentik.outposts.tasks.outpost_service_connection_monitor",
uid=self.pk,
args=(self.pk,), args=(self.pk,),
crontab="3-59/15 * * * *", crontab="3-59/15 * * * *",
description=_(f"Update cached state of service connection {self.name}"), description=_(f"Update cached state of service connection {self.name}"),
@ -256,7 +257,7 @@ class KubernetesServiceConnection(SerializerModel, OutpostServiceConnection):
return "ak-service-connection-kubernetes-form" return "ak-service-connection-kubernetes-form"
class Outpost(SerializerModel, ManagedModel): class Outpost(ScheduledModel, SerializerModel, ManagedModel):
"""Outpost instance which manages a service user and token""" """Outpost instance which manages a service user and token"""
uuid = models.UUIDField(default=uuid4, editable=False, primary_key=True) uuid = models.UUIDField(default=uuid4, editable=False, primary_key=True)
@ -310,6 +311,24 @@ class Outpost(SerializerModel, ManagedModel):
"""Username for service user""" """Username for service user"""
return f"ak-outpost-{self.uuid.hex}" return f"ak-outpost-{self.uuid.hex}"
@property
def schedule_specs(self) -> list[ScheduleSpec]:
specs = []
if self.service_connection is not None:
specs.append(
ScheduleSpec(
actor_name="authentik.outposts.tasks.outpost_controller",
uid=self.pk,
args=(self.pk, "up"),
kwargs={"action": "up", "from_cache": False},
crontab=f"{fqdn_rand('outpost_controller')} */4 * * *",
description=_(
f"Create/update/monitor/delete the deployment for the {self.name} outpost"
),
)
)
return specs
def build_user_permissions(self, user: User): def build_user_permissions(self, user: User):
"""Create per-object and global permissions for outpost service-account""" """Create per-object and global permissions for outpost service-account"""
# To ensure the user only has the correct permissions, we delete all of them and re-add # To ensure the user only has the correct permissions, we delete all of them and re-add

View File

@ -1,23 +0,0 @@
"""Outposts Settings"""
from celery.schedules import crontab
from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = {
"outposts_controller": {
"task": "authentik.outposts.tasks.outpost_controller_all",
"schedule": crontab(minute=fqdn_rand("outposts_controller"), hour="*/4"),
"options": {"queue": "authentik_scheduled"},
},
"outpost_token_ensurer": {
"task": "authentik.outposts.tasks.outpost_token_ensurer",
"schedule": crontab(minute=fqdn_rand("outpost_token_ensurer"), hour="*/8"),
"options": {"queue": "authentik_scheduled"},
},
"outpost_connection_discovery": {
"task": "authentik.outposts.tasks.outpost_connection_discovery",
"schedule": crontab(minute=fqdn_rand("outpost_connection_discovery"), hour="*/8"),
"options": {"queue": "authentik_scheduled"},
},
}

View File

@ -39,14 +39,14 @@ def pre_save_outpost(sender, instance: Outpost, **_):
if bool(dirty): if bool(dirty):
LOGGER.info("Outpost needs re-deployment due to changes", instance=instance) LOGGER.info("Outpost needs re-deployment due to changes", instance=instance)
cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, old_instance) cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, old_instance)
outpost_controller.delay(instance.pk.hex, action="down", from_cache=True) outpost_controller.send(instance.pk.hex, action="down", from_cache=True)
@receiver(m2m_changed, sender=Outpost.providers.through) @receiver(m2m_changed, sender=Outpost.providers.through)
def m2m_changed_update(sender, instance: Model, action: str, **_): def m2m_changed_update(sender, instance: Model, action: str, **_):
"""Update outpost on m2m change, when providers are added or removed""" """Update outpost on m2m change, when providers are added or removed"""
if action in ["post_add", "post_remove", "post_clear"]: if action in ["post_add", "post_remove", "post_clear"]:
outpost_post_save.delay(class_to_path(instance.__class__), instance.pk) outpost_post_save.send(class_to_path(instance.__class__), instance.pk)
@receiver(post_save) @receiver(post_save)
@ -64,7 +64,7 @@ def post_save_update(sender, instance: Model, created: bool, **_):
if isinstance(instance, Outpost) and created: if isinstance(instance, Outpost) and created:
LOGGER.info("New outpost saved, ensuring initial token and user are created") LOGGER.info("New outpost saved, ensuring initial token and user are created")
_ = instance.token _ = instance.token
outpost_post_save.delay(class_to_path(instance.__class__), instance.pk) outpost_post_save.send(class_to_path(instance.__class__), instance.pk)
@receiver(pre_delete, sender=Outpost) @receiver(pre_delete, sender=Outpost)
@ -72,4 +72,4 @@ def pre_delete_cleanup(sender, instance: Outpost, **_):
"""Ensure that Outpost's user is deleted (which will delete the token through cascade)""" """Ensure that Outpost's user is deleted (which will delete the token through cascade)"""
instance.user.delete() instance.user.delete()
cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, instance) cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, instance)
outpost_controller.delay(instance.pk.hex, action="down", from_cache=True) outpost_controller.send(instance.pk.hex, action="down", from_cache=True)

View File

@ -19,13 +19,13 @@ from kubernetes.config.kube_config import KUBE_CONFIG_DEFAULT_LOCATION
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from yaml import safe_load from yaml import safe_load
from authentik.events.models import TaskStatus
from authentik.events.system_tasks import SystemTask, prefill_task
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
from authentik.lib.utils.reflection import path_to_class from authentik.lib.utils.reflection import path_to_class
from authentik.outposts.consumer import OUTPOST_GROUP from authentik.outposts.consumer import OUTPOST_GROUP
from authentik.outposts.controllers.base import BaseController, ControllerException from authentik.outposts.controllers.base import BaseController, ControllerException
from authentik.outposts.controllers.docker import DockerClient from authentik.outposts.controllers.docker import DockerClient
from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import TaskStatus, Task
from authentik.outposts.controllers.kubernetes import KubernetesClient from authentik.outposts.controllers.kubernetes import KubernetesClient
from authentik.outposts.models import ( from authentik.outposts.models import (
DockerServiceConnection, DockerServiceConnection,
@ -103,20 +103,10 @@ def outpost_service_connection_monitor(connection_pk: Any):
cache.set(connection.state_key, state, timeout=None) cache.set(connection.state_key, state, timeout=None)
@CELERY_APP.task( @actor
throws=(DatabaseError, ProgrammingError, InternalError), def outpost_controller(outpost_pk: str, action: str = "up", from_cache: bool = False):
)
def outpost_controller_all():
"""Launch Controller for all Outposts which support it"""
for outpost in Outpost.objects.exclude(service_connection=None):
outpost_controller.delay(outpost.pk.hex, "up", from_cache=False)
@CELERY_APP.task(bind=True, base=SystemTask)
def outpost_controller(
self: SystemTask, outpost_pk: str, action: str = "up", from_cache: bool = False
):
"""Create/update/monitor/delete the deployment of an Outpost""" """Create/update/monitor/delete the deployment of an Outpost"""
self: Task = CurrentTask.get_task()
logs = [] logs = []
if from_cache: if from_cache:
outpost: Outpost = cache.get(CACHE_KEY_OUTPOST_DOWN % outpost_pk) outpost: Outpost = cache.get(CACHE_KEY_OUTPOST_DOWN % outpost_pk)
@ -144,11 +134,11 @@ def outpost_controller(
self.set_status(TaskStatus.SUCCESSFUL, *logs) self.set_status(TaskStatus.SUCCESSFUL, *logs)
@CELERY_APP.task(bind=True, base=SystemTask) @actor
@prefill_task def outpost_token_ensurer():
def outpost_token_ensurer(self: SystemTask):
"""Periodically ensure that all Outposts have valid Service Accounts """Periodically ensure that all Outposts have valid Service Accounts
and Tokens""" and Tokens"""
self: Task = CurrentTask.get_task()
all_outposts = Outpost.objects.all() all_outposts = Outpost.objects.all()
for outpost in all_outposts: for outpost in all_outposts:
_ = outpost.token _ = outpost.token
@ -159,7 +149,7 @@ def outpost_token_ensurer(self: SystemTask):
) )
@CELERY_APP.task() @actor
def outpost_post_save(model_class: str, model_pk: Any): def outpost_post_save(model_class: str, model_pk: Any):
"""If an Outpost is saved, Ensure that token is created/updated """If an Outpost is saved, Ensure that token is created/updated
@ -174,7 +164,7 @@ def outpost_post_save(model_class: str, model_pk: Any):
if isinstance(instance, Outpost): if isinstance(instance, Outpost):
LOGGER.debug("Trigger reconcile for outpost", instance=instance) LOGGER.debug("Trigger reconcile for outpost", instance=instance)
outpost_controller.delay(str(instance.pk)) outpost_controller.send(instance.pk)
if isinstance(instance, OutpostModel | Outpost): if isinstance(instance, OutpostModel | Outpost):
LOGGER.debug("triggering outpost update from outpostmodel/outpost", instance=instance) LOGGER.debug("triggering outpost update from outpostmodel/outpost", instance=instance)
@ -182,7 +172,7 @@ def outpost_post_save(model_class: str, model_pk: Any):
if isinstance(instance, OutpostServiceConnection): if isinstance(instance, OutpostServiceConnection):
LOGGER.debug("triggering ServiceConnection state update", instance=instance) LOGGER.debug("triggering ServiceConnection state update", instance=instance)
outpost_service_connection_monitor.send(str(instance.pk)) outpost_service_connection_monitor.send(instance.pk)
for field in instance._meta.get_fields(): for field in instance._meta.get_fields():
# Each field is checked if it has a `related_model` attribute (when ForeginKeys or M2Ms) # Each field is checked if it has a `related_model` attribute (when ForeginKeys or M2Ms)
@ -229,12 +219,10 @@ def _outpost_single_update(outpost: Outpost, layer=None):
async_to_sync(layer.group_send)(group, {"type": "event.update"}) async_to_sync(layer.group_send)(group, {"type": "event.update"})
@CELERY_APP.task( @actor
base=SystemTask, def outpost_connection_discovery():
bind=True,
)
def outpost_connection_discovery(self: SystemTask):
"""Checks the local environment and create Service connections.""" """Checks the local environment and create Service connections."""
self: Task = CurrentTask.get_task()
messages = [] messages = []
if not CONFIG.get_bool("outposts.discover"): if not CONFIG.get_bool("outposts.discover"):
messages.append("Outpost integration discovery is disabled") messages.append("Outpost integration discovery is disabled")

View File

@ -8,9 +8,10 @@ from guardian.models import UserObjectPermission
from authentik.core.tests.utils import create_test_cert, create_test_flow from authentik.core.tests.utils import create_test_cert, create_test_flow
from authentik.outposts.models import Outpost, OutpostType from authentik.outposts.models import Outpost, OutpostType
from authentik.providers.proxy.models import ProxyProvider from authentik.providers.proxy.models import ProxyProvider
from authentik.tasks.tests import TaskTestCase
class OutpostTests(TestCase): class OutpostTests(TaskTestCase):
"""Outpost Tests""" """Outpost Tests"""
def setUp(self) -> None: def setUp(self) -> None:
@ -29,6 +30,7 @@ class OutpostTests(TestCase):
name="test", name="test",
type=OutpostType.PROXY, type=OutpostType.PROXY,
) )
self.tasks_join()
# Before we add a provider, the user should only have access to the outpost # Before we add a provider, the user should only have access to the outpost
permissions = UserObjectPermission.objects.filter(user=outpost.user) permissions = UserObjectPermission.objects.filter(user=outpost.user)

View File

@ -89,10 +89,10 @@ def task_error_hook(task_id: str, exception: Exception, traceback, *args, **kwar
def _get_startup_tasks_default_tenant() -> list[Callable]: def _get_startup_tasks_default_tenant() -> list[Callable]:
"""Get all tasks to be run on startup for the default tenant""" """Get all tasks to be run on startup for the default tenant"""
from authentik.outposts.tasks import outpost_connection_discovery # from authentik.outposts.tasks import outpost_connection_discovery
return [ return [
outpost_connection_discovery, # outpost_connection_discovery,
] ]

View File

@ -21,7 +21,17 @@ from dramatiq.broker import Broker, Consumer, MessageProxy
from dramatiq.common import compute_backoff, current_millis, dq_name, xq_name from dramatiq.common import compute_backoff, current_millis, dq_name, xq_name
from dramatiq.errors import ConnectionError, QueueJoinTimeout from dramatiq.errors import ConnectionError, QueueJoinTimeout
from dramatiq.message import Message from dramatiq.message import Message
from dramatiq.middleware import Middleware, Prometheus, default_middleware from dramatiq.middleware import (
AgeLimit,
Callbacks,
Middleware,
Pipelines,
Prometheus,
Retries,
ShutdownNotifications,
TimeLimit,
default_middleware,
)
from dramatiq.results import Results from dramatiq.results import Results
from pglock.core import _cast_lock_id from pglock.core import _cast_lock_id
from psycopg import Notify, sql from psycopg import Notify, sql
@ -77,7 +87,7 @@ class TenantMiddleware(Middleware):
class PostgresBroker(Broker): class PostgresBroker(Broker):
def __init__(self, *args, db_alias: str = DEFAULT_DB_ALIAS, results: bool = True, **kwargs): def __init__(self, *args, db_alias: str = DEFAULT_DB_ALIAS, results: bool = True, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, middleware=[], **kwargs)
self.logger = get_logger().bind() self.logger = get_logger().bind()
self.queues = set() self.queues = set()
@ -89,9 +99,14 @@ class PostgresBroker(Broker):
self.middleware = [] self.middleware = []
self.add_middleware(DbConnectionMiddleware()) self.add_middleware(DbConnectionMiddleware())
self.add_middleware(TenantMiddleware()) self.add_middleware(TenantMiddleware())
for middleware in default_middleware: for middleware in (
if middleware == Prometheus: AgeLimit,
pass TimeLimit,
ShutdownNotifications,
Callbacks,
Pipelines,
Retries,
):
self.add_middleware(middleware()) self.add_middleware(middleware())
if results: if results:
self.backend = PostgresBackend() self.backend = PostgresBackend()

View File

@ -1,8 +1,8 @@
from django.test import TransactionTestCase from django.test import TestCase
from dramatiq import Worker, get_broker from dramatiq import Worker, get_broker
class TaskTestCase(TransactionTestCase): class TaskTestCase(TestCase):
def _pre_setup(self): def _pre_setup(self):
super()._pre_setup() super()._pre_setup()
@ -17,6 +17,10 @@ class TaskTestCase(TransactionTestCase):
super()._post_teardown() super()._post_teardown()
def tasks_join(self, queue_name: str): def tasks_join(self, queue_name: str | None = None):
self.broker.join(queue_name) if queue_name is None:
for queue in self.broker.get_declared_queues():
self.broker.join(queue)
else:
self.broker.join(queue_name)
self.worker.join() self.worker.join()