Compare commits
	
		
			66 Commits
		
	
	
		
			celery-2-d
			...
			eap-but-ac
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 06848be14b | |||
| 4bae3bbe60 | |||
| e33f839d7f | |||
| f5eb827d14 | |||
| 9045f5ba73 | |||
| 7b97e92094 | |||
| 3027cdcc4b | |||
| 67f627a925 | |||
| f1101e0c01 | |||
| fb01a117ad | |||
| fad18db70b | |||
| e0c837257c | |||
| 2a567ccc85 | |||
| e36373ceab | |||
| d8a625be03 | |||
| 4d944f7444 | |||
| c49274042b | |||
| 10fc15ffe0 | |||
| 7c996d9d9d | |||
| 5d25f68b71 | |||
| 8da54d5811 | |||
| 4571f5e644 | |||
| ee234ea3aa | |||
| 82c177b7eb | |||
| 1155ccb3e8 | |||
| 1575b96262 | |||
| 19bb77638a | |||
| d6cf129eaa | |||
| b6686cff14 | |||
| 8cf8f1e199 | |||
| 50c50c4109 | |||
| 51f4a8d83d | |||
| 3ada3a7e0e | |||
| fa06c9fe4e | |||
| 2a024238fe | |||
| 91c87b7c3c | |||
| 318443f270 | |||
| ac88784089 | |||
| 855afa7b9f | |||
| 240abfef41 | |||
| 03075f1890 | |||
| 5bc0ed6e11 | |||
| 8f4cfc28c7 | |||
| 6d77eaaab7 | |||
| 9cee59537c | |||
| fc5c0e2789 | |||
| 573446689f | |||
| fd4bfe604d | |||
| 06e76a5b37 | |||
| 3c228bf5c3 | |||
| 8a80f07db2 | |||
| ae59a3e576 | |||
| df21e678d6 | |||
| a71532b3e3 | |||
| d7cb0b3ea1 | |||
| ba8f137885 | |||
| 958ff66070 | |||
| ad57c66a32 | |||
| 2bba0ddd74 | |||
| 767c0a8e45 | |||
| b10c795a26 | |||
| 8088e08fd9 | |||
| eab6e288d7 | |||
| 91c2863358 | |||
| 1638e95bc7 | |||
| 8f75131541 | 
							
								
								
									
										5
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										5
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@ -100,6 +100,9 @@ ipython_config.py
 | 
			
		||||
# pyenv
 | 
			
		||||
.python-version
 | 
			
		||||
 | 
			
		||||
# celery beat schedule file
 | 
			
		||||
celerybeat-schedule
 | 
			
		||||
 | 
			
		||||
# SageMath parsed files
 | 
			
		||||
*.sage.py
 | 
			
		||||
 | 
			
		||||
@ -163,6 +166,8 @@ dmypy.json
 | 
			
		||||
 | 
			
		||||
# pyenv
 | 
			
		||||
 | 
			
		||||
# celery beat schedule file
 | 
			
		||||
 | 
			
		||||
# SageMath parsed files
 | 
			
		||||
 | 
			
		||||
# Environments
 | 
			
		||||
 | 
			
		||||
@ -122,7 +122,6 @@ ENV UV_NO_BINARY_PACKAGE="cryptography lxml python-kadmin-rs xmlsec"
 | 
			
		||||
 | 
			
		||||
RUN --mount=type=bind,target=pyproject.toml,src=pyproject.toml \
 | 
			
		||||
    --mount=type=bind,target=uv.lock,src=uv.lock \
 | 
			
		||||
    --mount=type=bind,target=packages,src=packages \
 | 
			
		||||
    --mount=type=cache,target=/root/.cache/uv \
 | 
			
		||||
    uv sync --frozen --no-install-project --no-dev
 | 
			
		||||
 | 
			
		||||
@ -168,7 +167,6 @@ COPY ./blueprints /blueprints
 | 
			
		||||
COPY ./lifecycle/ /lifecycle
 | 
			
		||||
COPY ./authentik/sources/kerberos/krb5.conf /etc/krb5.conf
 | 
			
		||||
COPY --from=go-builder /go/authentik /bin/authentik
 | 
			
		||||
COPY ./packages/ /ak-root/packages
 | 
			
		||||
COPY --from=python-deps /ak-root/.venv /ak-root/.venv
 | 
			
		||||
COPY --from=node-builder /work/web/dist/ /web/dist/
 | 
			
		||||
COPY --from=node-builder /work/web/authentik/ /web/authentik/
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										2
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								Makefile
									
									
									
									
									
								
							@ -6,7 +6,7 @@ PWD = $(shell pwd)
 | 
			
		||||
UID = $(shell id -u)
 | 
			
		||||
GID = $(shell id -g)
 | 
			
		||||
NPM_VERSION = $(shell python -m scripts.generate_semver)
 | 
			
		||||
PY_SOURCES = authentik packages tests scripts lifecycle .github
 | 
			
		||||
PY_SOURCES = authentik tests scripts lifecycle .github
 | 
			
		||||
DOCKER_IMAGE ?= "authentik:test"
 | 
			
		||||
 | 
			
		||||
GEN_API_TS = gen-ts-api
 | 
			
		||||
 | 
			
		||||
@ -41,7 +41,7 @@ class VersionSerializer(PassiveSerializer):
 | 
			
		||||
            return __version__
 | 
			
		||||
        version_in_cache = cache.get(VERSION_CACHE_KEY)
 | 
			
		||||
        if not version_in_cache:  # pragma: no cover
 | 
			
		||||
            update_latest_version.send()
 | 
			
		||||
            update_latest_version.delay()
 | 
			
		||||
            return __version__
 | 
			
		||||
        return version_in_cache
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										57
									
								
								authentik/admin/api/workers.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										57
									
								
								authentik/admin/api/workers.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,57 @@
 | 
			
		||||
"""authentik administration overview"""
 | 
			
		||||
 | 
			
		||||
from socket import gethostname
 | 
			
		||||
 | 
			
		||||
from django.conf import settings
 | 
			
		||||
from drf_spectacular.utils import extend_schema, inline_serializer
 | 
			
		||||
from packaging.version import parse
 | 
			
		||||
from rest_framework.fields import BooleanField, CharField
 | 
			
		||||
from rest_framework.request import Request
 | 
			
		||||
from rest_framework.response import Response
 | 
			
		||||
from rest_framework.views import APIView
 | 
			
		||||
 | 
			
		||||
from authentik import get_full_version
 | 
			
		||||
from authentik.rbac.permissions import HasPermission
 | 
			
		||||
from authentik.root.celery import CELERY_APP
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class WorkerView(APIView):
 | 
			
		||||
    """Get currently connected worker count."""
 | 
			
		||||
 | 
			
		||||
    permission_classes = [HasPermission("authentik_rbac.view_system_info")]
 | 
			
		||||
 | 
			
		||||
    @extend_schema(
 | 
			
		||||
        responses=inline_serializer(
 | 
			
		||||
            "Worker",
 | 
			
		||||
            fields={
 | 
			
		||||
                "worker_id": CharField(),
 | 
			
		||||
                "version": CharField(),
 | 
			
		||||
                "version_matching": BooleanField(),
 | 
			
		||||
            },
 | 
			
		||||
            many=True,
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    def get(self, request: Request) -> Response:
 | 
			
		||||
        """Get currently connected worker count."""
 | 
			
		||||
        raw: list[dict[str, dict]] = CELERY_APP.control.ping(timeout=0.5)
 | 
			
		||||
        our_version = parse(get_full_version())
 | 
			
		||||
        response = []
 | 
			
		||||
        for worker in raw:
 | 
			
		||||
            key = list(worker.keys())[0]
 | 
			
		||||
            version = worker[key].get("version")
 | 
			
		||||
            version_matching = False
 | 
			
		||||
            if version:
 | 
			
		||||
                version_matching = parse(version) == our_version
 | 
			
		||||
            response.append(
 | 
			
		||||
                {"worker_id": key, "version": version, "version_matching": version_matching}
 | 
			
		||||
            )
 | 
			
		||||
        # In debug we run with `task_always_eager`, so tasks are ran on the main process
 | 
			
		||||
        if settings.DEBUG:  # pragma: no cover
 | 
			
		||||
            response.append(
 | 
			
		||||
                {
 | 
			
		||||
                    "worker_id": f"authentik-debug@{gethostname()}",
 | 
			
		||||
                    "version": get_full_version(),
 | 
			
		||||
                    "version_matching": True,
 | 
			
		||||
                }
 | 
			
		||||
            )
 | 
			
		||||
        return Response(response)
 | 
			
		||||
@ -3,9 +3,6 @@
 | 
			
		||||
from prometheus_client import Info
 | 
			
		||||
 | 
			
		||||
from authentik.blueprints.apps import ManagedAppConfig
 | 
			
		||||
from authentik.lib.config import CONFIG
 | 
			
		||||
from authentik.lib.utils.time import fqdn_rand
 | 
			
		||||
from authentik.tasks.schedules.lib import ScheduleSpec
 | 
			
		||||
 | 
			
		||||
PROM_INFO = Info("authentik_version", "Currently running authentik version")
 | 
			
		||||
 | 
			
		||||
@ -33,15 +30,3 @@ class AuthentikAdminConfig(ManagedAppConfig):
 | 
			
		||||
            notification_version = notification.event.context["new_version"]
 | 
			
		||||
            if LOCAL_VERSION >= parse(notification_version):
 | 
			
		||||
                notification.delete()
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def global_schedule_specs(self) -> list[ScheduleSpec]:
 | 
			
		||||
        from authentik.admin.tasks import update_latest_version
 | 
			
		||||
 | 
			
		||||
        return [
 | 
			
		||||
            ScheduleSpec(
 | 
			
		||||
                actor=update_latest_version,
 | 
			
		||||
                crontab=f"{fqdn_rand('admin_latest_version')} * * * *",
 | 
			
		||||
                paused=CONFIG.get_bool("disable_update_check"),
 | 
			
		||||
            ),
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										15
									
								
								authentik/admin/settings.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								authentik/admin/settings.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,15 @@
 | 
			
		||||
"""authentik admin settings"""
 | 
			
		||||
 | 
			
		||||
from celery.schedules import crontab
 | 
			
		||||
from django_tenants.utils import get_public_schema_name
 | 
			
		||||
 | 
			
		||||
from authentik.lib.utils.time import fqdn_rand
 | 
			
		||||
 | 
			
		||||
CELERY_BEAT_SCHEDULE = {
 | 
			
		||||
    "admin_latest_version": {
 | 
			
		||||
        "task": "authentik.admin.tasks.update_latest_version",
 | 
			
		||||
        "schedule": crontab(minute=fqdn_rand("admin_latest_version"), hour="*"),
 | 
			
		||||
        "tenant_schemas": [get_public_schema_name()],
 | 
			
		||||
        "options": {"queue": "authentik_scheduled"},
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										35
									
								
								authentik/admin/signals.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								authentik/admin/signals.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,35 @@
 | 
			
		||||
"""admin signals"""
 | 
			
		||||
 | 
			
		||||
from django.dispatch import receiver
 | 
			
		||||
from packaging.version import parse
 | 
			
		||||
from prometheus_client import Gauge
 | 
			
		||||
 | 
			
		||||
from authentik import get_full_version
 | 
			
		||||
from authentik.root.celery import CELERY_APP
 | 
			
		||||
from authentik.root.monitoring import monitoring_set
 | 
			
		||||
 | 
			
		||||
GAUGE_WORKERS = Gauge(
 | 
			
		||||
    "authentik_admin_workers",
 | 
			
		||||
    "Currently connected workers, their versions and if they are the same version as authentik",
 | 
			
		||||
    ["version", "version_matched"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
_version = parse(get_full_version())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@receiver(monitoring_set)
 | 
			
		||||
def monitoring_set_workers(sender, **kwargs):
 | 
			
		||||
    """Set worker gauge"""
 | 
			
		||||
    raw: list[dict[str, dict]] = CELERY_APP.control.ping(timeout=0.5)
 | 
			
		||||
    worker_version_count = {}
 | 
			
		||||
    for worker in raw:
 | 
			
		||||
        key = list(worker.keys())[0]
 | 
			
		||||
        version = worker[key].get("version")
 | 
			
		||||
        version_matching = False
 | 
			
		||||
        if version:
 | 
			
		||||
            version_matching = parse(version) == _version
 | 
			
		||||
        worker_version_count.setdefault(version, {"count": 0, "matching": version_matching})
 | 
			
		||||
        worker_version_count[version]["count"] += 1
 | 
			
		||||
    for version, stats in worker_version_count.items():
 | 
			
		||||
        GAUGE_WORKERS.labels(version, stats["matching"]).set(stats["count"])
 | 
			
		||||
@ -2,8 +2,6 @@
 | 
			
		||||
 | 
			
		||||
from django.core.cache import cache
 | 
			
		||||
from django.utils.translation import gettext_lazy as _
 | 
			
		||||
from django_dramatiq_postgres.middleware import CurrentTask
 | 
			
		||||
from dramatiq import actor
 | 
			
		||||
from packaging.version import parse
 | 
			
		||||
from requests import RequestException
 | 
			
		||||
from structlog.stdlib import get_logger
 | 
			
		||||
@ -11,9 +9,10 @@ from structlog.stdlib import get_logger
 | 
			
		||||
from authentik import __version__, get_build_hash
 | 
			
		||||
from authentik.admin.apps import PROM_INFO
 | 
			
		||||
from authentik.events.models import Event, EventAction
 | 
			
		||||
from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task
 | 
			
		||||
from authentik.lib.config import CONFIG
 | 
			
		||||
from authentik.lib.utils.http import get_http_session
 | 
			
		||||
from authentik.tasks.models import Task
 | 
			
		||||
from authentik.root.celery import CELERY_APP
 | 
			
		||||
 | 
			
		||||
LOGGER = get_logger()
 | 
			
		||||
VERSION_NULL = "0.0.0"
 | 
			
		||||
@ -33,12 +32,13 @@ def _set_prom_info():
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Update latest version info."))
 | 
			
		||||
def update_latest_version():
 | 
			
		||||
    self: Task = CurrentTask.get_task()
 | 
			
		||||
@CELERY_APP.task(bind=True, base=SystemTask)
 | 
			
		||||
@prefill_task
 | 
			
		||||
def update_latest_version(self: SystemTask):
 | 
			
		||||
    """Update latest version info"""
 | 
			
		||||
    if CONFIG.get_bool("disable_update_check"):
 | 
			
		||||
        cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT)
 | 
			
		||||
        self.info("Version check disabled.")
 | 
			
		||||
        self.set_status(TaskStatus.WARNING, "Version check disabled.")
 | 
			
		||||
        return
 | 
			
		||||
    try:
 | 
			
		||||
        response = get_http_session().get(
 | 
			
		||||
@ -48,7 +48,7 @@ def update_latest_version():
 | 
			
		||||
        data = response.json()
 | 
			
		||||
        upstream_version = data.get("stable", {}).get("version")
 | 
			
		||||
        cache.set(VERSION_CACHE_KEY, upstream_version, VERSION_CACHE_TIMEOUT)
 | 
			
		||||
        self.info("Successfully updated latest Version")
 | 
			
		||||
        self.set_status(TaskStatus.SUCCESSFUL, "Successfully updated latest Version")
 | 
			
		||||
        _set_prom_info()
 | 
			
		||||
        # Check if upstream version is newer than what we're running,
 | 
			
		||||
        # and if no event exists yet, create one.
 | 
			
		||||
@ -71,7 +71,7 @@ def update_latest_version():
 | 
			
		||||
            ).save()
 | 
			
		||||
    except (RequestException, IndexError) as exc:
 | 
			
		||||
        cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT)
 | 
			
		||||
        raise exc
 | 
			
		||||
        self.set_error(exc)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
_set_prom_info()
 | 
			
		||||
 | 
			
		||||
@ -29,6 +29,13 @@ class TestAdminAPI(TestCase):
 | 
			
		||||
        body = loads(response.content)
 | 
			
		||||
        self.assertEqual(body["version_current"], __version__)
 | 
			
		||||
 | 
			
		||||
    def test_workers(self):
 | 
			
		||||
        """Test Workers API"""
 | 
			
		||||
        response = self.client.get(reverse("authentik_api:admin_workers"))
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        body = loads(response.content)
 | 
			
		||||
        self.assertEqual(len(body), 0)
 | 
			
		||||
 | 
			
		||||
    def test_apps(self):
 | 
			
		||||
        """Test apps API"""
 | 
			
		||||
        response = self.client.get(reverse("authentik_api:apps-list"))
 | 
			
		||||
 | 
			
		||||
@ -30,7 +30,7 @@ class TestAdminTasks(TestCase):
 | 
			
		||||
        """Test Update checker with valid response"""
 | 
			
		||||
        with Mocker() as mocker, CONFIG.patch("disable_update_check", False):
 | 
			
		||||
            mocker.get("https://version.goauthentik.io/version.json", json=RESPONSE_VALID)
 | 
			
		||||
            update_latest_version.send()
 | 
			
		||||
            update_latest_version.delay().get()
 | 
			
		||||
            self.assertEqual(cache.get(VERSION_CACHE_KEY), "99999999.9999999")
 | 
			
		||||
            self.assertTrue(
 | 
			
		||||
                Event.objects.filter(
 | 
			
		||||
@ -40,7 +40,7 @@ class TestAdminTasks(TestCase):
 | 
			
		||||
                ).exists()
 | 
			
		||||
            )
 | 
			
		||||
            # test that a consecutive check doesn't create a duplicate event
 | 
			
		||||
            update_latest_version.send()
 | 
			
		||||
            update_latest_version.delay().get()
 | 
			
		||||
            self.assertEqual(
 | 
			
		||||
                len(
 | 
			
		||||
                    Event.objects.filter(
 | 
			
		||||
@ -56,7 +56,7 @@ class TestAdminTasks(TestCase):
 | 
			
		||||
        """Test Update checker with invalid response"""
 | 
			
		||||
        with Mocker() as mocker:
 | 
			
		||||
            mocker.get("https://version.goauthentik.io/version.json", status_code=400)
 | 
			
		||||
            update_latest_version.send()
 | 
			
		||||
            update_latest_version.delay().get()
 | 
			
		||||
            self.assertEqual(cache.get(VERSION_CACHE_KEY), "0.0.0")
 | 
			
		||||
            self.assertFalse(
 | 
			
		||||
                Event.objects.filter(
 | 
			
		||||
@ -67,15 +67,14 @@ class TestAdminTasks(TestCase):
 | 
			
		||||
    def test_version_disabled(self):
 | 
			
		||||
        """Test Update checker while its disabled"""
 | 
			
		||||
        with CONFIG.patch("disable_update_check", True):
 | 
			
		||||
            update_latest_version.send()
 | 
			
		||||
            update_latest_version.delay().get()
 | 
			
		||||
            self.assertEqual(cache.get(VERSION_CACHE_KEY), "0.0.0")
 | 
			
		||||
 | 
			
		||||
    def test_clear_update_notifications(self):
 | 
			
		||||
        """Test clear of previous notification"""
 | 
			
		||||
        admin_config = apps.get_app_config("authentik_admin")
 | 
			
		||||
        Event.objects.create(
 | 
			
		||||
            action=EventAction.UPDATE_AVAILABLE,
 | 
			
		||||
            context={"new_version": "99999999.9999999.9999999"},
 | 
			
		||||
            action=EventAction.UPDATE_AVAILABLE, context={"new_version": "99999999.9999999.9999999"}
 | 
			
		||||
        )
 | 
			
		||||
        Event.objects.create(action=EventAction.UPDATE_AVAILABLE, context={"new_version": "1.1.1"})
 | 
			
		||||
        Event.objects.create(action=EventAction.UPDATE_AVAILABLE, context={})
 | 
			
		||||
 | 
			
		||||
@ -6,11 +6,13 @@ from authentik.admin.api.meta import AppsViewSet, ModelViewSet
 | 
			
		||||
from authentik.admin.api.system import SystemView
 | 
			
		||||
from authentik.admin.api.version import VersionView
 | 
			
		||||
from authentik.admin.api.version_history import VersionHistoryViewSet
 | 
			
		||||
from authentik.admin.api.workers import WorkerView
 | 
			
		||||
 | 
			
		||||
api_urlpatterns = [
 | 
			
		||||
    ("admin/apps", AppsViewSet, "apps"),
 | 
			
		||||
    ("admin/models", ModelViewSet, "models"),
 | 
			
		||||
    path("admin/version/", VersionView.as_view(), name="admin_version"),
 | 
			
		||||
    ("admin/version/history", VersionHistoryViewSet, "version_history"),
 | 
			
		||||
    path("admin/workers/", WorkerView.as_view(), name="admin_workers"),
 | 
			
		||||
    path("admin/system/", SystemView.as_view(), name="admin_system"),
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
@ -39,7 +39,7 @@ class BlueprintInstanceSerializer(ModelSerializer):
 | 
			
		||||
        """Ensure the path (if set) specified is retrievable"""
 | 
			
		||||
        if path == "" or path.startswith(OCI_PREFIX):
 | 
			
		||||
            return path
 | 
			
		||||
        files: list[dict] = blueprints_find_dict.send().get_result(block=True)
 | 
			
		||||
        files: list[dict] = blueprints_find_dict.delay().get()
 | 
			
		||||
        if path not in [file["path"] for file in files]:
 | 
			
		||||
            raise ValidationError(_("Blueprint file does not exist"))
 | 
			
		||||
        return path
 | 
			
		||||
@ -115,7 +115,7 @@ class BlueprintInstanceViewSet(UsedByMixin, ModelViewSet):
 | 
			
		||||
    @action(detail=False, pagination_class=None, filter_backends=[])
 | 
			
		||||
    def available(self, request: Request) -> Response:
 | 
			
		||||
        """Get blueprints"""
 | 
			
		||||
        files: list[dict] = blueprints_find_dict.send().get_result(block=True)
 | 
			
		||||
        files: list[dict] = blueprints_find_dict.delay().get()
 | 
			
		||||
        return Response(files)
 | 
			
		||||
 | 
			
		||||
    @permission_required("authentik_blueprints.view_blueprintinstance")
 | 
			
		||||
@ -129,5 +129,5 @@ class BlueprintInstanceViewSet(UsedByMixin, ModelViewSet):
 | 
			
		||||
    def apply(self, request: Request, *args, **kwargs) -> Response:
 | 
			
		||||
        """Apply a blueprint"""
 | 
			
		||||
        blueprint = self.get_object()
 | 
			
		||||
        apply_blueprint.send_with_options(args=(blueprint.pk,), rel_obj=blueprint)
 | 
			
		||||
        apply_blueprint.delay(str(blueprint.pk)).get()
 | 
			
		||||
        return self.retrieve(request, *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
@ -6,12 +6,9 @@ from inspect import ismethod
 | 
			
		||||
 | 
			
		||||
from django.apps import AppConfig
 | 
			
		||||
from django.db import DatabaseError, InternalError, ProgrammingError
 | 
			
		||||
from dramatiq.broker import get_broker
 | 
			
		||||
from structlog.stdlib import BoundLogger, get_logger
 | 
			
		||||
 | 
			
		||||
from authentik.lib.utils.time import fqdn_rand
 | 
			
		||||
from authentik.root.signals import startup
 | 
			
		||||
from authentik.tasks.schedules.lib import ScheduleSpec
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ManagedAppConfig(AppConfig):
 | 
			
		||||
@ -37,7 +34,7 @@ class ManagedAppConfig(AppConfig):
 | 
			
		||||
 | 
			
		||||
    def import_related(self):
 | 
			
		||||
        """Automatically import related modules which rely on just being imported
 | 
			
		||||
        to register themselves (mainly django signals and tasks)"""
 | 
			
		||||
        to register themselves (mainly django signals and celery tasks)"""
 | 
			
		||||
 | 
			
		||||
        def import_relative(rel_module: str):
 | 
			
		||||
            try:
 | 
			
		||||
@ -83,16 +80,6 @@ class ManagedAppConfig(AppConfig):
 | 
			
		||||
        func._authentik_managed_reconcile = ManagedAppConfig.RECONCILE_GLOBAL_CATEGORY
 | 
			
		||||
        return func
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def tenant_schedule_specs(self) -> list[ScheduleSpec]:
 | 
			
		||||
        """Get a list of schedule specs that must exist in each tenant"""
 | 
			
		||||
        return []
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def global_schedule_specs(self) -> list[ScheduleSpec]:
 | 
			
		||||
        """Get a list of schedule specs that must exist in the default tenant"""
 | 
			
		||||
        return []
 | 
			
		||||
 | 
			
		||||
    def _reconcile_tenant(self) -> None:
 | 
			
		||||
        """reconcile ourselves for tenanted methods"""
 | 
			
		||||
        from authentik.tenants.models import Tenant
 | 
			
		||||
@ -113,12 +100,8 @@ class ManagedAppConfig(AppConfig):
 | 
			
		||||
        """
 | 
			
		||||
        from django_tenants.utils import get_public_schema_name, schema_context
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            with schema_context(get_public_schema_name()):
 | 
			
		||||
                self._reconcile(self.RECONCILE_GLOBAL_CATEGORY)
 | 
			
		||||
        except (DatabaseError, ProgrammingError, InternalError) as exc:
 | 
			
		||||
            self.logger.debug("Failed to access database to run reconcile", exc=exc)
 | 
			
		||||
            return
 | 
			
		||||
        with schema_context(get_public_schema_name()):
 | 
			
		||||
            self._reconcile(self.RECONCILE_GLOBAL_CATEGORY)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AuthentikBlueprintsConfig(ManagedAppConfig):
 | 
			
		||||
@ -129,29 +112,19 @@ class AuthentikBlueprintsConfig(ManagedAppConfig):
 | 
			
		||||
    verbose_name = "authentik Blueprints"
 | 
			
		||||
    default = True
 | 
			
		||||
 | 
			
		||||
    @ManagedAppConfig.reconcile_global
 | 
			
		||||
    def load_blueprints_v1_tasks(self):
 | 
			
		||||
        """Load v1 tasks"""
 | 
			
		||||
        self.import_module("authentik.blueprints.v1.tasks")
 | 
			
		||||
 | 
			
		||||
    @ManagedAppConfig.reconcile_tenant
 | 
			
		||||
    def blueprints_discovery(self):
 | 
			
		||||
        """Run blueprint discovery"""
 | 
			
		||||
        from authentik.blueprints.v1.tasks import blueprints_discovery, clear_failed_blueprints
 | 
			
		||||
 | 
			
		||||
        blueprints_discovery.delay()
 | 
			
		||||
        clear_failed_blueprints.delay()
 | 
			
		||||
 | 
			
		||||
    def import_models(self):
 | 
			
		||||
        super().import_models()
 | 
			
		||||
        self.import_module("authentik.blueprints.v1.meta.apply_blueprint")
 | 
			
		||||
 | 
			
		||||
    @ManagedAppConfig.reconcile_global
 | 
			
		||||
    def tasks_middlewares(self):
 | 
			
		||||
        from authentik.blueprints.v1.tasks import BlueprintWatcherMiddleware
 | 
			
		||||
 | 
			
		||||
        get_broker().add_middleware(BlueprintWatcherMiddleware())
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def tenant_schedule_specs(self) -> list[ScheduleSpec]:
 | 
			
		||||
        from authentik.blueprints.v1.tasks import blueprints_discovery, clear_failed_blueprints
 | 
			
		||||
 | 
			
		||||
        return [
 | 
			
		||||
            ScheduleSpec(
 | 
			
		||||
                actor=blueprints_discovery,
 | 
			
		||||
                crontab=f"{fqdn_rand('blueprints_v1_discover')} * * * *",
 | 
			
		||||
                send_on_startup=True,
 | 
			
		||||
            ),
 | 
			
		||||
            ScheduleSpec(
 | 
			
		||||
                actor=clear_failed_blueprints,
 | 
			
		||||
                crontab=f"{fqdn_rand('blueprints_v1_cleanup')} * * * *",
 | 
			
		||||
                send_on_startup=True,
 | 
			
		||||
            ),
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
@ -3,7 +3,6 @@
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from uuid import uuid4
 | 
			
		||||
 | 
			
		||||
from django.contrib.contenttypes.fields import GenericRelation
 | 
			
		||||
from django.contrib.postgres.fields import ArrayField
 | 
			
		||||
from django.db import models
 | 
			
		||||
from django.utils.translation import gettext_lazy as _
 | 
			
		||||
@ -72,13 +71,6 @@ class BlueprintInstance(SerializerModel, ManagedModel, CreatedUpdatedModel):
 | 
			
		||||
    enabled = models.BooleanField(default=True)
 | 
			
		||||
    managed_models = ArrayField(models.TextField(), default=list)
 | 
			
		||||
 | 
			
		||||
    # Manual link to tasks instead of using TasksModel because of loop imports
 | 
			
		||||
    tasks = GenericRelation(
 | 
			
		||||
        "authentik_tasks.Task",
 | 
			
		||||
        content_type_field="rel_obj_content_type",
 | 
			
		||||
        object_id_field="rel_obj_id",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    class Meta:
 | 
			
		||||
        verbose_name = _("Blueprint Instance")
 | 
			
		||||
        verbose_name_plural = _("Blueprint Instances")
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										18
									
								
								authentik/blueprints/settings.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								authentik/blueprints/settings.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,18 @@
 | 
			
		||||
"""blueprint Settings"""
 | 
			
		||||
 | 
			
		||||
from celery.schedules import crontab
 | 
			
		||||
 | 
			
		||||
from authentik.lib.utils.time import fqdn_rand
 | 
			
		||||
 | 
			
		||||
CELERY_BEAT_SCHEDULE = {
 | 
			
		||||
    "blueprints_v1_discover": {
 | 
			
		||||
        "task": "authentik.blueprints.v1.tasks.blueprints_discovery",
 | 
			
		||||
        "schedule": crontab(minute=fqdn_rand("blueprints_v1_discover"), hour="*"),
 | 
			
		||||
        "options": {"queue": "authentik_scheduled"},
 | 
			
		||||
    },
 | 
			
		||||
    "blueprints_v1_cleanup": {
 | 
			
		||||
        "task": "authentik.blueprints.v1.tasks.clear_failed_blueprints",
 | 
			
		||||
        "schedule": crontab(minute=fqdn_rand("blueprints_v1_cleanup"), hour="*"),
 | 
			
		||||
        "options": {"queue": "authentik_scheduled"},
 | 
			
		||||
    },
 | 
			
		||||
}
 | 
			
		||||
@ -1,2 +0,0 @@
 | 
			
		||||
# Import all v1 tasks for auto task discovery
 | 
			
		||||
from authentik.blueprints.v1.tasks import *  # noqa: F403
 | 
			
		||||
@ -54,7 +54,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase):
 | 
			
		||||
            file.seek(0)
 | 
			
		||||
            file_hash = sha512(file.read().encode()).hexdigest()
 | 
			
		||||
            file.flush()
 | 
			
		||||
            blueprints_discovery.send()
 | 
			
		||||
            blueprints_discovery()
 | 
			
		||||
            instance = BlueprintInstance.objects.filter(name=blueprint_id).first()
 | 
			
		||||
            self.assertEqual(instance.last_applied_hash, file_hash)
 | 
			
		||||
            self.assertEqual(
 | 
			
		||||
@ -82,7 +82,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase):
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
            file.flush()
 | 
			
		||||
            blueprints_discovery.send()
 | 
			
		||||
            blueprints_discovery()
 | 
			
		||||
            blueprint = BlueprintInstance.objects.filter(name="foo").first()
 | 
			
		||||
            self.assertEqual(
 | 
			
		||||
                blueprint.last_applied_hash,
 | 
			
		||||
@ -107,7 +107,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase):
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
            file.flush()
 | 
			
		||||
            blueprints_discovery.send()
 | 
			
		||||
            blueprints_discovery()
 | 
			
		||||
            blueprint.refresh_from_db()
 | 
			
		||||
            self.assertEqual(
 | 
			
		||||
                blueprint.last_applied_hash,
 | 
			
		||||
 | 
			
		||||
@ -57,6 +57,7 @@ from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import (
 | 
			
		||||
    EndpointDeviceConnection,
 | 
			
		||||
)
 | 
			
		||||
from authentik.events.logs import LogEvent, capture_logs
 | 
			
		||||
from authentik.events.models import SystemTask
 | 
			
		||||
from authentik.events.utils import cleanse_dict
 | 
			
		||||
from authentik.flows.models import FlowToken, Stage
 | 
			
		||||
from authentik.lib.models import SerializerModel
 | 
			
		||||
@ -76,7 +77,6 @@ from authentik.providers.scim.models import SCIMProviderGroup, SCIMProviderUser
 | 
			
		||||
from authentik.rbac.models import Role
 | 
			
		||||
from authentik.sources.scim.models import SCIMSourceGroup, SCIMSourceUser
 | 
			
		||||
from authentik.stages.authenticator_webauthn.models import WebAuthnDeviceType
 | 
			
		||||
from authentik.tasks.models import Task
 | 
			
		||||
from authentik.tenants.models import Tenant
 | 
			
		||||
 | 
			
		||||
# Context set when the serializer is created in a blueprint context
 | 
			
		||||
@ -118,7 +118,7 @@ def excluded_models() -> list[type[Model]]:
 | 
			
		||||
        SCIMProviderGroup,
 | 
			
		||||
        SCIMProviderUser,
 | 
			
		||||
        Tenant,
 | 
			
		||||
        Task,
 | 
			
		||||
        SystemTask,
 | 
			
		||||
        ConnectionToken,
 | 
			
		||||
        AuthorizationCode,
 | 
			
		||||
        AccessToken,
 | 
			
		||||
 | 
			
		||||
@ -44,7 +44,7 @@ class ApplyBlueprintMetaSerializer(PassiveSerializer):
 | 
			
		||||
            return MetaResult()
 | 
			
		||||
        LOGGER.debug("Applying blueprint from meta model", blueprint=self.blueprint_instance)
 | 
			
		||||
 | 
			
		||||
        apply_blueprint(self.blueprint_instance.pk)
 | 
			
		||||
        apply_blueprint(str(self.blueprint_instance.pk))
 | 
			
		||||
        return MetaResult()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -4,17 +4,12 @@ from dataclasses import asdict, dataclass, field
 | 
			
		||||
from hashlib import sha512
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from sys import platform
 | 
			
		||||
from uuid import UUID
 | 
			
		||||
 | 
			
		||||
from dacite.core import from_dict
 | 
			
		||||
from django.conf import settings
 | 
			
		||||
from django.db import DatabaseError, InternalError, ProgrammingError
 | 
			
		||||
from django.utils.text import slugify
 | 
			
		||||
from django.utils.timezone import now
 | 
			
		||||
from django.utils.translation import gettext_lazy as _
 | 
			
		||||
from django_dramatiq_postgres.middleware import CurrentTask, CurrentTaskNotFound
 | 
			
		||||
from dramatiq.actor import actor
 | 
			
		||||
from dramatiq.middleware import Middleware
 | 
			
		||||
from structlog.stdlib import get_logger
 | 
			
		||||
from watchdog.events import (
 | 
			
		||||
    FileCreatedEvent,
 | 
			
		||||
@ -36,13 +31,15 @@ from authentik.blueprints.v1.importer import Importer
 | 
			
		||||
from authentik.blueprints.v1.labels import LABEL_AUTHENTIK_INSTANTIATE
 | 
			
		||||
from authentik.blueprints.v1.oci import OCI_PREFIX
 | 
			
		||||
from authentik.events.logs import capture_logs
 | 
			
		||||
from authentik.events.models import TaskStatus
 | 
			
		||||
from authentik.events.system_tasks import SystemTask, prefill_task
 | 
			
		||||
from authentik.events.utils import sanitize_dict
 | 
			
		||||
from authentik.lib.config import CONFIG
 | 
			
		||||
from authentik.tasks.models import Task
 | 
			
		||||
from authentik.tasks.schedules.models import Schedule
 | 
			
		||||
from authentik.root.celery import CELERY_APP
 | 
			
		||||
from authentik.tenants.models import Tenant
 | 
			
		||||
 | 
			
		||||
LOGGER = get_logger()
 | 
			
		||||
_file_watcher_started = False
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
@ -56,21 +53,22 @@ class BlueprintFile:
 | 
			
		||||
    meta: BlueprintMetadata | None = field(default=None)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BlueprintWatcherMiddleware(Middleware):
 | 
			
		||||
    def start_blueprint_watcher(self):
 | 
			
		||||
        """Start blueprint watcher"""
 | 
			
		||||
        observer = Observer()
 | 
			
		||||
        kwargs = {}
 | 
			
		||||
        if platform.startswith("linux"):
 | 
			
		||||
            kwargs["event_filter"] = (FileCreatedEvent, FileModifiedEvent)
 | 
			
		||||
        observer.schedule(
 | 
			
		||||
            BlueprintEventHandler(), CONFIG.get("blueprints_dir"), recursive=True, **kwargs
 | 
			
		||||
        )
 | 
			
		||||
        observer.start()
 | 
			
		||||
def start_blueprint_watcher():
 | 
			
		||||
    """Start blueprint watcher, if it's not running already."""
 | 
			
		||||
    # This function might be called twice since it's called on celery startup
 | 
			
		||||
 | 
			
		||||
    def after_worker_boot(self, broker, worker):
 | 
			
		||||
        if not settings.TEST:
 | 
			
		||||
            self.start_blueprint_watcher()
 | 
			
		||||
    global _file_watcher_started  # noqa: PLW0603
 | 
			
		||||
    if _file_watcher_started:
 | 
			
		||||
        return
 | 
			
		||||
    observer = Observer()
 | 
			
		||||
    kwargs = {}
 | 
			
		||||
    if platform.startswith("linux"):
 | 
			
		||||
        kwargs["event_filter"] = (FileCreatedEvent, FileModifiedEvent)
 | 
			
		||||
    observer.schedule(
 | 
			
		||||
        BlueprintEventHandler(), CONFIG.get("blueprints_dir"), recursive=True, **kwargs
 | 
			
		||||
    )
 | 
			
		||||
    observer.start()
 | 
			
		||||
    _file_watcher_started = True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BlueprintEventHandler(FileSystemEventHandler):
 | 
			
		||||
@ -94,7 +92,7 @@ class BlueprintEventHandler(FileSystemEventHandler):
 | 
			
		||||
        LOGGER.debug("new blueprint file created, starting discovery")
 | 
			
		||||
        for tenant in Tenant.objects.filter(ready=True):
 | 
			
		||||
            with tenant:
 | 
			
		||||
                Schedule.dispatch_by_actor(blueprints_discovery)
 | 
			
		||||
                blueprints_discovery.delay()
 | 
			
		||||
 | 
			
		||||
    def on_modified(self, event: FileSystemEvent):
 | 
			
		||||
        """Process file modification"""
 | 
			
		||||
@ -105,14 +103,14 @@ class BlueprintEventHandler(FileSystemEventHandler):
 | 
			
		||||
            with tenant:
 | 
			
		||||
                for instance in BlueprintInstance.objects.filter(path=rel_path, enabled=True):
 | 
			
		||||
                    LOGGER.debug("modified blueprint file, starting apply", instance=instance)
 | 
			
		||||
                    apply_blueprint.send_with_options(args=(instance.pk,), rel_obj=instance)
 | 
			
		||||
                    apply_blueprint.delay(instance.pk.hex)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(
 | 
			
		||||
    description=_("Find blueprints as `blueprints_find` does, but return a safe dict."),
 | 
			
		||||
@CELERY_APP.task(
 | 
			
		||||
    throws=(DatabaseError, ProgrammingError, InternalError),
 | 
			
		||||
)
 | 
			
		||||
def blueprints_find_dict():
 | 
			
		||||
    """Find blueprints as `blueprints_find` does, but return a safe dict"""
 | 
			
		||||
    blueprints = []
 | 
			
		||||
    for blueprint in blueprints_find():
 | 
			
		||||
        blueprints.append(sanitize_dict(asdict(blueprint)))
 | 
			
		||||
@ -148,19 +146,21 @@ def blueprints_find() -> list[BlueprintFile]:
 | 
			
		||||
    return blueprints
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(
 | 
			
		||||
    description=_("Find blueprints and check if they need to be created in the database."),
 | 
			
		||||
    throws=(DatabaseError, ProgrammingError, InternalError),
 | 
			
		||||
@CELERY_APP.task(
 | 
			
		||||
    throws=(DatabaseError, ProgrammingError, InternalError), base=SystemTask, bind=True
 | 
			
		||||
)
 | 
			
		||||
def blueprints_discovery(path: str | None = None):
 | 
			
		||||
    self: Task = CurrentTask.get_task()
 | 
			
		||||
@prefill_task
 | 
			
		||||
def blueprints_discovery(self: SystemTask, path: str | None = None):
 | 
			
		||||
    """Find blueprints and check if they need to be created in the database"""
 | 
			
		||||
    count = 0
 | 
			
		||||
    for blueprint in blueprints_find():
 | 
			
		||||
        if path and blueprint.path != path:
 | 
			
		||||
            continue
 | 
			
		||||
        check_blueprint_v1_file(blueprint)
 | 
			
		||||
        count += 1
 | 
			
		||||
    self.info(f"Successfully imported {count} files.")
 | 
			
		||||
    self.set_status(
 | 
			
		||||
        TaskStatus.SUCCESSFUL, _("Successfully imported {count} files.".format(count=count))
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def check_blueprint_v1_file(blueprint: BlueprintFile):
 | 
			
		||||
@ -187,26 +187,22 @@ def check_blueprint_v1_file(blueprint: BlueprintFile):
 | 
			
		||||
        )
 | 
			
		||||
    if instance.last_applied_hash != blueprint.hash:
 | 
			
		||||
        LOGGER.info("Applying blueprint due to changed file", instance=instance, path=instance.path)
 | 
			
		||||
        apply_blueprint.send_with_options(args=(instance.pk,), rel_obj=instance)
 | 
			
		||||
        apply_blueprint.delay(str(instance.pk))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Apply single blueprint."))
 | 
			
		||||
def apply_blueprint(instance_pk: UUID):
 | 
			
		||||
    try:
 | 
			
		||||
        self: Task = CurrentTask.get_task()
 | 
			
		||||
    except CurrentTaskNotFound:
 | 
			
		||||
        self = Task()
 | 
			
		||||
    self.set_uid(str(instance_pk))
 | 
			
		||||
@CELERY_APP.task(
 | 
			
		||||
    bind=True,
 | 
			
		||||
    base=SystemTask,
 | 
			
		||||
)
 | 
			
		||||
def apply_blueprint(self: SystemTask, instance_pk: str):
 | 
			
		||||
    """Apply single blueprint"""
 | 
			
		||||
    self.save_on_success = False
 | 
			
		||||
    instance: BlueprintInstance | None = None
 | 
			
		||||
    try:
 | 
			
		||||
        instance: BlueprintInstance = BlueprintInstance.objects.filter(pk=instance_pk).first()
 | 
			
		||||
        if not instance:
 | 
			
		||||
            self.warning(f"Could not find blueprint {instance_pk}, skipping")
 | 
			
		||||
        if not instance or not instance.enabled:
 | 
			
		||||
            return
 | 
			
		||||
        self.set_uid(slugify(instance.name))
 | 
			
		||||
        if not instance.enabled:
 | 
			
		||||
            self.info(f"Blueprint {instance.name} is disabled, skipping")
 | 
			
		||||
            return
 | 
			
		||||
        blueprint_content = instance.retrieve()
 | 
			
		||||
        file_hash = sha512(blueprint_content.encode()).hexdigest()
 | 
			
		||||
        importer = Importer.from_string(blueprint_content, instance.context)
 | 
			
		||||
@ -216,18 +212,19 @@ def apply_blueprint(instance_pk: UUID):
 | 
			
		||||
        if not valid:
 | 
			
		||||
            instance.status = BlueprintInstanceStatus.ERROR
 | 
			
		||||
            instance.save()
 | 
			
		||||
            self.logs(logs)
 | 
			
		||||
            self.set_status(TaskStatus.ERROR, *logs)
 | 
			
		||||
            return
 | 
			
		||||
        with capture_logs() as logs:
 | 
			
		||||
            applied = importer.apply()
 | 
			
		||||
            if not applied:
 | 
			
		||||
                instance.status = BlueprintInstanceStatus.ERROR
 | 
			
		||||
                instance.save()
 | 
			
		||||
                self.logs(logs)
 | 
			
		||||
                self.set_status(TaskStatus.ERROR, *logs)
 | 
			
		||||
                return
 | 
			
		||||
        instance.status = BlueprintInstanceStatus.SUCCESSFUL
 | 
			
		||||
        instance.last_applied_hash = file_hash
 | 
			
		||||
        instance.last_applied = now()
 | 
			
		||||
        self.set_status(TaskStatus.SUCCESSFUL)
 | 
			
		||||
    except (
 | 
			
		||||
        OSError,
 | 
			
		||||
        DatabaseError,
 | 
			
		||||
@ -238,14 +235,15 @@ def apply_blueprint(instance_pk: UUID):
 | 
			
		||||
    ) as exc:
 | 
			
		||||
        if instance:
 | 
			
		||||
            instance.status = BlueprintInstanceStatus.ERROR
 | 
			
		||||
        self.error(exc)
 | 
			
		||||
        self.set_error(exc)
 | 
			
		||||
    finally:
 | 
			
		||||
        if instance:
 | 
			
		||||
            instance.save()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Remove blueprints which couldn't be fetched."))
 | 
			
		||||
@CELERY_APP.task()
 | 
			
		||||
def clear_failed_blueprints():
 | 
			
		||||
    """Remove blueprints which couldn't be fetched"""
 | 
			
		||||
    # Exclude OCI blueprints as those might be temporarily unavailable
 | 
			
		||||
    for blueprint in BlueprintInstance.objects.exclude(path__startswith=OCI_PREFIX):
 | 
			
		||||
        try:
 | 
			
		||||
 | 
			
		||||
@ -9,7 +9,6 @@ class AuthentikBrandsConfig(ManagedAppConfig):
 | 
			
		||||
    name = "authentik.brands"
 | 
			
		||||
    label = "authentik_brands"
 | 
			
		||||
    verbose_name = "authentik Brands"
 | 
			
		||||
    default = True
 | 
			
		||||
    mountpoints = {
 | 
			
		||||
        "authentik.brands.urls_root": "",
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -1,7 +1,8 @@
 | 
			
		||||
"""authentik core app config"""
 | 
			
		||||
 | 
			
		||||
from django.conf import settings
 | 
			
		||||
 | 
			
		||||
from authentik.blueprints.apps import ManagedAppConfig
 | 
			
		||||
from authentik.tasks.schedules.lib import ScheduleSpec
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AuthentikCoreConfig(ManagedAppConfig):
 | 
			
		||||
@ -13,6 +14,14 @@ class AuthentikCoreConfig(ManagedAppConfig):
 | 
			
		||||
    mountpoint = ""
 | 
			
		||||
    default = True
 | 
			
		||||
 | 
			
		||||
    @ManagedAppConfig.reconcile_global
 | 
			
		||||
    def debug_worker_hook(self):
 | 
			
		||||
        """Dispatch startup tasks inline when debugging"""
 | 
			
		||||
        if settings.DEBUG:
 | 
			
		||||
            from authentik.root.celery import worker_ready_hook
 | 
			
		||||
 | 
			
		||||
            worker_ready_hook()
 | 
			
		||||
 | 
			
		||||
    @ManagedAppConfig.reconcile_tenant
 | 
			
		||||
    def source_inbuilt(self):
 | 
			
		||||
        """Reconcile inbuilt source"""
 | 
			
		||||
@ -25,18 +34,3 @@ class AuthentikCoreConfig(ManagedAppConfig):
 | 
			
		||||
            },
 | 
			
		||||
            managed=Source.MANAGED_INBUILT,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def tenant_schedule_specs(self) -> list[ScheduleSpec]:
 | 
			
		||||
        from authentik.core.tasks import clean_expired_models, clean_temporary_users
 | 
			
		||||
 | 
			
		||||
        return [
 | 
			
		||||
            ScheduleSpec(
 | 
			
		||||
                actor=clean_expired_models,
 | 
			
		||||
                crontab="2-59/5 * * * *",
 | 
			
		||||
            ),
 | 
			
		||||
            ScheduleSpec(
 | 
			
		||||
                actor=clean_temporary_users,
 | 
			
		||||
                crontab="9-59/5 * * * *",
 | 
			
		||||
            ),
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										21
									
								
								authentik/core/management/commands/bootstrap_tasks.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								authentik/core/management/commands/bootstrap_tasks.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,21 @@
 | 
			
		||||
"""Run bootstrap tasks"""
 | 
			
		||||
 | 
			
		||||
from django.core.management.base import BaseCommand
 | 
			
		||||
from django_tenants.utils import get_public_schema_name
 | 
			
		||||
 | 
			
		||||
from authentik.root.celery import _get_startup_tasks_all_tenants, _get_startup_tasks_default_tenant
 | 
			
		||||
from authentik.tenants.models import Tenant
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Command(BaseCommand):
 | 
			
		||||
    """Run bootstrap tasks to ensure certain objects are created"""
 | 
			
		||||
 | 
			
		||||
    def handle(self, **options):
 | 
			
		||||
        for task in _get_startup_tasks_default_tenant():
 | 
			
		||||
            with Tenant.objects.get(schema_name=get_public_schema_name()):
 | 
			
		||||
                task()
 | 
			
		||||
 | 
			
		||||
        for task in _get_startup_tasks_all_tenants():
 | 
			
		||||
            for tenant in Tenant.objects.filter(ready=True):
 | 
			
		||||
                with tenant:
 | 
			
		||||
                    task()
 | 
			
		||||
							
								
								
									
										47
									
								
								authentik/core/management/commands/worker.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										47
									
								
								authentik/core/management/commands/worker.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,47 @@
 | 
			
		||||
"""Run worker"""
 | 
			
		||||
 | 
			
		||||
from sys import exit as sysexit
 | 
			
		||||
from tempfile import tempdir
 | 
			
		||||
 | 
			
		||||
from celery.apps.worker import Worker
 | 
			
		||||
from django.core.management.base import BaseCommand
 | 
			
		||||
from django.db import close_old_connections
 | 
			
		||||
from structlog.stdlib import get_logger
 | 
			
		||||
 | 
			
		||||
from authentik.lib.config import CONFIG
 | 
			
		||||
from authentik.lib.debug import start_debug_server
 | 
			
		||||
from authentik.root.celery import CELERY_APP
 | 
			
		||||
 | 
			
		||||
LOGGER = get_logger()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Command(BaseCommand):
 | 
			
		||||
    """Run worker"""
 | 
			
		||||
 | 
			
		||||
    def add_arguments(self, parser):
 | 
			
		||||
        parser.add_argument(
 | 
			
		||||
            "-b",
 | 
			
		||||
            "--beat",
 | 
			
		||||
            action="store_false",
 | 
			
		||||
            help="When set, this worker will _not_ run Beat (scheduled) tasks",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def handle(self, **options):
 | 
			
		||||
        LOGGER.debug("Celery options", **options)
 | 
			
		||||
        close_old_connections()
 | 
			
		||||
        start_debug_server()
 | 
			
		||||
        worker: Worker = CELERY_APP.Worker(
 | 
			
		||||
            no_color=False,
 | 
			
		||||
            quiet=True,
 | 
			
		||||
            optimization="fair",
 | 
			
		||||
            autoscale=(CONFIG.get_int("worker.concurrency"), 1),
 | 
			
		||||
            task_events=True,
 | 
			
		||||
            beat=options.get("beat", True),
 | 
			
		||||
            schedule_filename=f"{tempdir}/celerybeat-schedule",
 | 
			
		||||
            queues=["authentik", "authentik_scheduled", "authentik_events"],
 | 
			
		||||
        )
 | 
			
		||||
        for task in CELERY_APP.tasks:
 | 
			
		||||
            LOGGER.debug("Registered task", task=task)
 | 
			
		||||
 | 
			
		||||
        worker.start()
 | 
			
		||||
        sysexit(worker.exitcode)
 | 
			
		||||
@ -3,9 +3,6 @@
 | 
			
		||||
from datetime import datetime, timedelta
 | 
			
		||||
 | 
			
		||||
from django.utils.timezone import now
 | 
			
		||||
from django.utils.translation import gettext_lazy as _
 | 
			
		||||
from django_dramatiq_postgres.middleware import CurrentTask
 | 
			
		||||
from dramatiq.actor import actor
 | 
			
		||||
from structlog.stdlib import get_logger
 | 
			
		||||
 | 
			
		||||
from authentik.core.models import (
 | 
			
		||||
@ -14,14 +11,17 @@ from authentik.core.models import (
 | 
			
		||||
    ExpiringModel,
 | 
			
		||||
    User,
 | 
			
		||||
)
 | 
			
		||||
from authentik.tasks.models import Task
 | 
			
		||||
from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task
 | 
			
		||||
from authentik.root.celery import CELERY_APP
 | 
			
		||||
 | 
			
		||||
LOGGER = get_logger()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Remove expired objects."))
 | 
			
		||||
def clean_expired_models():
 | 
			
		||||
    self: Task = CurrentTask.get_task()
 | 
			
		||||
@CELERY_APP.task(bind=True, base=SystemTask)
 | 
			
		||||
@prefill_task
 | 
			
		||||
def clean_expired_models(self: SystemTask):
 | 
			
		||||
    """Remove expired objects"""
 | 
			
		||||
    messages = []
 | 
			
		||||
    for cls in ExpiringModel.__subclasses__():
 | 
			
		||||
        cls: ExpiringModel
 | 
			
		||||
        objects = (
 | 
			
		||||
@ -31,13 +31,16 @@ def clean_expired_models():
 | 
			
		||||
        for obj in objects:
 | 
			
		||||
            obj.expire_action()
 | 
			
		||||
        LOGGER.debug("Expired models", model=cls, amount=amount)
 | 
			
		||||
        self.info(f"Expired {amount} {cls._meta.verbose_name_plural}")
 | 
			
		||||
        messages.append(f"Expired {amount} {cls._meta.verbose_name_plural}")
 | 
			
		||||
    self.set_status(TaskStatus.SUCCESSFUL, *messages)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Remove temporary users created by SAML Sources."))
 | 
			
		||||
def clean_temporary_users():
 | 
			
		||||
    self: Task = CurrentTask.get_task()
 | 
			
		||||
@CELERY_APP.task(bind=True, base=SystemTask)
 | 
			
		||||
@prefill_task
 | 
			
		||||
def clean_temporary_users(self: SystemTask):
 | 
			
		||||
    """Remove temporary users created by SAML Sources"""
 | 
			
		||||
    _now = datetime.now()
 | 
			
		||||
    messages = []
 | 
			
		||||
    deleted_users = 0
 | 
			
		||||
    for user in User.objects.filter(**{f"attributes__{USER_ATTRIBUTE_GENERATED}": True}):
 | 
			
		||||
        if not user.attributes.get(USER_ATTRIBUTE_EXPIRES):
 | 
			
		||||
@ -49,4 +52,5 @@ def clean_temporary_users():
 | 
			
		||||
            LOGGER.debug("User is expired and will be deleted.", user=user, delta=delta)
 | 
			
		||||
            user.delete()
 | 
			
		||||
            deleted_users += 1
 | 
			
		||||
    self.info(f"Successfully deleted {deleted_users} users.")
 | 
			
		||||
    messages.append(f"Successfully deleted {deleted_users} users.")
 | 
			
		||||
    self.set_status(TaskStatus.SUCCESSFUL, *messages)
 | 
			
		||||
 | 
			
		||||
@ -36,7 +36,7 @@ class TestTasks(APITestCase):
 | 
			
		||||
            expires=now(), user=get_anonymous_user(), intent=TokenIntents.INTENT_API
 | 
			
		||||
        )
 | 
			
		||||
        key = token.key
 | 
			
		||||
        clean_expired_models.send()
 | 
			
		||||
        clean_expired_models.delay().get()
 | 
			
		||||
        token.refresh_from_db()
 | 
			
		||||
        self.assertNotEqual(key, token.key)
 | 
			
		||||
 | 
			
		||||
@ -50,5 +50,5 @@ class TestTasks(APITestCase):
 | 
			
		||||
                USER_ATTRIBUTE_EXPIRES: mktime(now().timetuple()),
 | 
			
		||||
            },
 | 
			
		||||
        )
 | 
			
		||||
        clean_temporary_users.send()
 | 
			
		||||
        clean_temporary_users.delay().get()
 | 
			
		||||
        self.assertFalse(User.objects.filter(username=username))
 | 
			
		||||
 | 
			
		||||
@ -4,8 +4,6 @@ from datetime import UTC, datetime
 | 
			
		||||
 | 
			
		||||
from authentik.blueprints.apps import ManagedAppConfig
 | 
			
		||||
from authentik.lib.generators import generate_id
 | 
			
		||||
from authentik.lib.utils.time import fqdn_rand
 | 
			
		||||
from authentik.tasks.schedules.lib import ScheduleSpec
 | 
			
		||||
 | 
			
		||||
MANAGED_KEY = "goauthentik.io/crypto/jwt-managed"
 | 
			
		||||
 | 
			
		||||
@ -69,14 +67,3 @@ class AuthentikCryptoConfig(ManagedAppConfig):
 | 
			
		||||
                "key_data": builder.private_key,
 | 
			
		||||
            },
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def tenant_schedule_specs(self) -> list[ScheduleSpec]:
 | 
			
		||||
        from authentik.crypto.tasks import certificate_discovery
 | 
			
		||||
 | 
			
		||||
        return [
 | 
			
		||||
            ScheduleSpec(
 | 
			
		||||
                actor=certificate_discovery,
 | 
			
		||||
                crontab=f"{fqdn_rand('crypto_certificate_discovery')} * * * *",
 | 
			
		||||
            ),
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										13
									
								
								authentik/crypto/settings.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								authentik/crypto/settings.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,13 @@
 | 
			
		||||
"""Crypto task Settings"""
 | 
			
		||||
 | 
			
		||||
from celery.schedules import crontab
 | 
			
		||||
 | 
			
		||||
from authentik.lib.utils.time import fqdn_rand
 | 
			
		||||
 | 
			
		||||
CELERY_BEAT_SCHEDULE = {
 | 
			
		||||
    "crypto_certificate_discovery": {
 | 
			
		||||
        "task": "authentik.crypto.tasks.certificate_discovery",
 | 
			
		||||
        "schedule": crontab(minute=fqdn_rand("crypto_certificate_discovery"), hour="*"),
 | 
			
		||||
        "options": {"queue": "authentik_scheduled"},
 | 
			
		||||
    },
 | 
			
		||||
}
 | 
			
		||||
@ -7,13 +7,13 @@ from cryptography.hazmat.backends import default_backend
 | 
			
		||||
from cryptography.hazmat.primitives.serialization import load_pem_private_key
 | 
			
		||||
from cryptography.x509.base import load_pem_x509_certificate
 | 
			
		||||
from django.utils.translation import gettext_lazy as _
 | 
			
		||||
from django_dramatiq_postgres.middleware import CurrentTask
 | 
			
		||||
from dramatiq.actor import actor
 | 
			
		||||
from structlog.stdlib import get_logger
 | 
			
		||||
 | 
			
		||||
from authentik.crypto.models import CertificateKeyPair
 | 
			
		||||
from authentik.events.models import TaskStatus
 | 
			
		||||
from authentik.events.system_tasks import SystemTask, prefill_task
 | 
			
		||||
from authentik.lib.config import CONFIG
 | 
			
		||||
from authentik.tasks.models import Task
 | 
			
		||||
from authentik.root.celery import CELERY_APP
 | 
			
		||||
 | 
			
		||||
LOGGER = get_logger()
 | 
			
		||||
 | 
			
		||||
@ -36,9 +36,10 @@ def ensure_certificate_valid(body: str):
 | 
			
		||||
    return body
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Discover, import and update certificates from the filesystem."))
 | 
			
		||||
def certificate_discovery():
 | 
			
		||||
    self: Task = CurrentTask.get_task()
 | 
			
		||||
@CELERY_APP.task(bind=True, base=SystemTask)
 | 
			
		||||
@prefill_task
 | 
			
		||||
def certificate_discovery(self: SystemTask):
 | 
			
		||||
    """Discover, import and update certificates from the filesystem"""
 | 
			
		||||
    certs = {}
 | 
			
		||||
    private_keys = {}
 | 
			
		||||
    discovered = 0
 | 
			
		||||
@ -83,4 +84,6 @@ def certificate_discovery():
 | 
			
		||||
                dirty = True
 | 
			
		||||
        if dirty:
 | 
			
		||||
            cert.save()
 | 
			
		||||
    self.info(f"Successfully imported {discovered} files.")
 | 
			
		||||
    self.set_status(
 | 
			
		||||
        TaskStatus.SUCCESSFUL, _("Successfully imported {count} files.".format(count=discovered))
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
@ -338,7 +338,7 @@ class TestCrypto(APITestCase):
 | 
			
		||||
            with open(f"{temp_dir}/foo.bar/privkey.pem", "w+", encoding="utf-8") as _key:
 | 
			
		||||
                _key.write(builder.private_key)
 | 
			
		||||
            with CONFIG.patch("cert_discovery_dir", temp_dir):
 | 
			
		||||
                certificate_discovery.send()
 | 
			
		||||
                certificate_discovery()
 | 
			
		||||
        keypair: CertificateKeyPair = CertificateKeyPair.objects.filter(
 | 
			
		||||
            managed=MANAGED_DISCOVERED % "foo"
 | 
			
		||||
        ).first()
 | 
			
		||||
 | 
			
		||||
@ -3,8 +3,6 @@
 | 
			
		||||
from django.conf import settings
 | 
			
		||||
 | 
			
		||||
from authentik.blueprints.apps import ManagedAppConfig
 | 
			
		||||
from authentik.lib.utils.time import fqdn_rand
 | 
			
		||||
from authentik.tasks.schedules.lib import ScheduleSpec
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class EnterpriseConfig(ManagedAppConfig):
 | 
			
		||||
@ -28,14 +26,3 @@ class AuthentikEnterpriseConfig(EnterpriseConfig):
 | 
			
		||||
        from authentik.enterprise.license import LicenseKey
 | 
			
		||||
 | 
			
		||||
        return LicenseKey.cached_summary().status.is_valid
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def tenant_schedule_specs(self) -> list[ScheduleSpec]:
 | 
			
		||||
        from authentik.enterprise.tasks import enterprise_update_usage
 | 
			
		||||
 | 
			
		||||
        return [
 | 
			
		||||
            ScheduleSpec(
 | 
			
		||||
                actor=enterprise_update_usage,
 | 
			
		||||
                crontab=f"{fqdn_rand('enterprise_update_usage')} */2 * * *",
 | 
			
		||||
            ),
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
@ -1,8 +1,6 @@
 | 
			
		||||
"""authentik Unique Password policy app config"""
 | 
			
		||||
 | 
			
		||||
from authentik.enterprise.apps import EnterpriseConfig
 | 
			
		||||
from authentik.lib.utils.time import fqdn_rand
 | 
			
		||||
from authentik.tasks.schedules.lib import ScheduleSpec
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AuthentikEnterprisePoliciesUniquePasswordConfig(EnterpriseConfig):
 | 
			
		||||
@ -10,21 +8,3 @@ class AuthentikEnterprisePoliciesUniquePasswordConfig(EnterpriseConfig):
 | 
			
		||||
    label = "authentik_policies_unique_password"
 | 
			
		||||
    verbose_name = "authentik Enterprise.Policies.Unique Password"
 | 
			
		||||
    default = True
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def tenant_schedule_specs(self) -> list[ScheduleSpec]:
 | 
			
		||||
        from authentik.enterprise.policies.unique_password.tasks import (
 | 
			
		||||
            check_and_purge_password_history,
 | 
			
		||||
            trim_password_histories,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        return [
 | 
			
		||||
            ScheduleSpec(
 | 
			
		||||
                actor=trim_password_histories,
 | 
			
		||||
                crontab=f"{fqdn_rand('policies_unique_password_trim')} */12 * * *",
 | 
			
		||||
            ),
 | 
			
		||||
            ScheduleSpec(
 | 
			
		||||
                actor=check_and_purge_password_history,
 | 
			
		||||
                crontab=f"{fqdn_rand('policies_unique_password_purge')} */24 * * *",
 | 
			
		||||
            ),
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										20
									
								
								authentik/enterprise/policies/unique_password/settings.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								authentik/enterprise/policies/unique_password/settings.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,20 @@
 | 
			
		||||
"""Unique Password Policy settings"""
 | 
			
		||||
 | 
			
		||||
from celery.schedules import crontab
 | 
			
		||||
 | 
			
		||||
from authentik.lib.utils.time import fqdn_rand
 | 
			
		||||
 | 
			
		||||
CELERY_BEAT_SCHEDULE = {
 | 
			
		||||
    "policies_unique_password_trim_history": {
 | 
			
		||||
        "task": "authentik.enterprise.policies.unique_password.tasks.trim_password_histories",
 | 
			
		||||
        "schedule": crontab(minute=fqdn_rand("policies_unique_password_trim"), hour="*/12"),
 | 
			
		||||
        "options": {"queue": "authentik_scheduled"},
 | 
			
		||||
    },
 | 
			
		||||
    "policies_unique_password_check_purge": {
 | 
			
		||||
        "task": (
 | 
			
		||||
            "authentik.enterprise.policies.unique_password.tasks.check_and_purge_password_history"
 | 
			
		||||
        ),
 | 
			
		||||
        "schedule": crontab(minute=fqdn_rand("policies_unique_password_purge"), hour="*/24"),
 | 
			
		||||
        "options": {"queue": "authentik_scheduled"},
 | 
			
		||||
    },
 | 
			
		||||
}
 | 
			
		||||
@ -1,37 +1,35 @@
 | 
			
		||||
from django.db.models.aggregates import Count
 | 
			
		||||
from django.utils.translation import gettext_lazy as _
 | 
			
		||||
from django_dramatiq_postgres.middleware import CurrentTask
 | 
			
		||||
from dramatiq.actor import actor
 | 
			
		||||
from structlog import get_logger
 | 
			
		||||
 | 
			
		||||
from authentik.enterprise.policies.unique_password.models import (
 | 
			
		||||
    UniquePasswordPolicy,
 | 
			
		||||
    UserPasswordHistory,
 | 
			
		||||
)
 | 
			
		||||
from authentik.tasks.models import Task
 | 
			
		||||
from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task
 | 
			
		||||
from authentik.root.celery import CELERY_APP
 | 
			
		||||
 | 
			
		||||
LOGGER = get_logger()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(
 | 
			
		||||
    description=_(
 | 
			
		||||
        "Check if any UniquePasswordPolicy exists, and if not, purge the password history table."
 | 
			
		||||
    )
 | 
			
		||||
)
 | 
			
		||||
def check_and_purge_password_history():
 | 
			
		||||
    self: Task = CurrentTask.get_task()
 | 
			
		||||
 | 
			
		||||
@CELERY_APP.task(bind=True, base=SystemTask)
 | 
			
		||||
@prefill_task
 | 
			
		||||
def check_and_purge_password_history(self: SystemTask):
 | 
			
		||||
    """Check if any UniquePasswordPolicy exists, and if not, purge the password history table.
 | 
			
		||||
    This is run on a schedule instead of being triggered by policy binding deletion.
 | 
			
		||||
    """
 | 
			
		||||
    if not UniquePasswordPolicy.objects.exists():
 | 
			
		||||
        UserPasswordHistory.objects.all().delete()
 | 
			
		||||
        LOGGER.debug("Purged UserPasswordHistory table as no policies are in use")
 | 
			
		||||
        self.info("Successfully purged UserPasswordHistory")
 | 
			
		||||
        self.set_status(TaskStatus.SUCCESSFUL, "Successfully purged UserPasswordHistory")
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    self.info("Not purging password histories, a unique password policy exists")
 | 
			
		||||
    self.set_status(
 | 
			
		||||
        TaskStatus.SUCCESSFUL, "Not purging password histories, a unique password policy exists"
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Remove user password history that are too old."))
 | 
			
		||||
def trim_password_histories():
 | 
			
		||||
@CELERY_APP.task(bind=True, base=SystemTask)
 | 
			
		||||
def trim_password_histories(self: SystemTask):
 | 
			
		||||
    """Removes rows from UserPasswordHistory older than
 | 
			
		||||
    the `n` most recent entries.
 | 
			
		||||
 | 
			
		||||
@ -39,8 +37,6 @@ def trim_password_histories():
 | 
			
		||||
    UniquePasswordPolicy policies.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    self: Task = CurrentTask.get_task()
 | 
			
		||||
 | 
			
		||||
    # No policy, we'll let the cleanup above do its thing
 | 
			
		||||
    if not UniquePasswordPolicy.objects.exists():
 | 
			
		||||
        return
 | 
			
		||||
@ -67,4 +63,4 @@ def trim_password_histories():
 | 
			
		||||
 | 
			
		||||
    num_deleted, _ = UserPasswordHistory.objects.exclude(pk__in=all_pks_to_keep).delete()
 | 
			
		||||
    LOGGER.debug("Deleted stale password history records", count=num_deleted)
 | 
			
		||||
    self.info(f"Delete {num_deleted} stale password history records")
 | 
			
		||||
    self.set_status(TaskStatus.SUCCESSFUL, f"Delete {num_deleted} stale password history records")
 | 
			
		||||
 | 
			
		||||
@ -76,7 +76,7 @@ class TestCheckAndPurgePasswordHistory(TestCase):
 | 
			
		||||
        self.assertTrue(UserPasswordHistory.objects.exists())
 | 
			
		||||
 | 
			
		||||
        # Run the task - should purge since no policy is in use
 | 
			
		||||
        check_and_purge_password_history.send()
 | 
			
		||||
        check_and_purge_password_history()
 | 
			
		||||
 | 
			
		||||
        # Verify the table is empty
 | 
			
		||||
        self.assertFalse(UserPasswordHistory.objects.exists())
 | 
			
		||||
@ -99,7 +99,7 @@ class TestCheckAndPurgePasswordHistory(TestCase):
 | 
			
		||||
        self.assertTrue(UserPasswordHistory.objects.exists())
 | 
			
		||||
 | 
			
		||||
        # Run the task - should NOT purge since a policy is in use
 | 
			
		||||
        check_and_purge_password_history.send()
 | 
			
		||||
        check_and_purge_password_history()
 | 
			
		||||
 | 
			
		||||
        # Verify the entries still exist
 | 
			
		||||
        self.assertTrue(UserPasswordHistory.objects.exists())
 | 
			
		||||
@ -142,7 +142,7 @@ class TestTrimPasswordHistory(TestCase):
 | 
			
		||||
            enabled=True,
 | 
			
		||||
            order=0,
 | 
			
		||||
        )
 | 
			
		||||
        trim_password_histories.send()
 | 
			
		||||
        trim_password_histories.delay()
 | 
			
		||||
        user_pwd_history_qs = UserPasswordHistory.objects.filter(user=self.user)
 | 
			
		||||
        self.assertEqual(len(user_pwd_history_qs), 1)
 | 
			
		||||
 | 
			
		||||
@ -159,7 +159,7 @@ class TestTrimPasswordHistory(TestCase):
 | 
			
		||||
            enabled=False,
 | 
			
		||||
            order=0,
 | 
			
		||||
        )
 | 
			
		||||
        trim_password_histories.send()
 | 
			
		||||
        trim_password_histories.delay()
 | 
			
		||||
        self.assertTrue(UserPasswordHistory.objects.filter(user=self.user).exists())
 | 
			
		||||
 | 
			
		||||
    def test_trim_password_history_fewer_records_than_maximum_is_no_op(self):
 | 
			
		||||
@ -174,5 +174,5 @@ class TestTrimPasswordHistory(TestCase):
 | 
			
		||||
            enabled=True,
 | 
			
		||||
            order=0,
 | 
			
		||||
        )
 | 
			
		||||
        trim_password_histories.send()
 | 
			
		||||
        trim_password_histories.delay()
 | 
			
		||||
        self.assertTrue(UserPasswordHistory.objects.filter(user=self.user).exists())
 | 
			
		||||
 | 
			
		||||
@ -55,5 +55,5 @@ class GoogleWorkspaceProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixi
 | 
			
		||||
    ]
 | 
			
		||||
    search_fields = ["name"]
 | 
			
		||||
    ordering = ["name"]
 | 
			
		||||
    sync_task = google_workspace_sync
 | 
			
		||||
    sync_single_task = google_workspace_sync
 | 
			
		||||
    sync_objects_task = google_workspace_sync_objects
 | 
			
		||||
 | 
			
		||||
@ -7,7 +7,6 @@ from django.db import models
 | 
			
		||||
from django.db.models import QuerySet
 | 
			
		||||
from django.templatetags.static import static
 | 
			
		||||
from django.utils.translation import gettext_lazy as _
 | 
			
		||||
from dramatiq.actor import Actor
 | 
			
		||||
from google.oauth2.service_account import Credentials
 | 
			
		||||
from rest_framework.serializers import Serializer
 | 
			
		||||
 | 
			
		||||
@ -111,12 +110,6 @@ class GoogleWorkspaceProvider(OutgoingSyncProvider, BackchannelProvider):
 | 
			
		||||
        help_text=_("Property mappings used for group creation/updating."),
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def sync_actor(self) -> Actor:
 | 
			
		||||
        from authentik.enterprise.providers.google_workspace.tasks import google_workspace_sync
 | 
			
		||||
 | 
			
		||||
        return google_workspace_sync
 | 
			
		||||
 | 
			
		||||
    def client_for_model(
 | 
			
		||||
        self,
 | 
			
		||||
        model: type[User | Group | GoogleWorkspaceProviderUser | GoogleWorkspaceProviderGroup],
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										13
									
								
								authentik/enterprise/providers/google_workspace/settings.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								authentik/enterprise/providers/google_workspace/settings.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,13 @@
 | 
			
		||||
"""Google workspace provider task Settings"""
 | 
			
		||||
 | 
			
		||||
from celery.schedules import crontab
 | 
			
		||||
 | 
			
		||||
from authentik.lib.utils.time import fqdn_rand
 | 
			
		||||
 | 
			
		||||
CELERY_BEAT_SCHEDULE = {
 | 
			
		||||
    "providers_google_workspace_sync": {
 | 
			
		||||
        "task": "authentik.enterprise.providers.google_workspace.tasks.google_workspace_sync_all",
 | 
			
		||||
        "schedule": crontab(minute=fqdn_rand("google_workspace_sync_all"), hour="*/4"),
 | 
			
		||||
        "options": {"queue": "authentik_scheduled"},
 | 
			
		||||
    },
 | 
			
		||||
}
 | 
			
		||||
@ -2,13 +2,15 @@
 | 
			
		||||
 | 
			
		||||
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider
 | 
			
		||||
from authentik.enterprise.providers.google_workspace.tasks import (
 | 
			
		||||
    google_workspace_sync_direct_dispatch,
 | 
			
		||||
    google_workspace_sync_m2m_dispatch,
 | 
			
		||||
    google_workspace_sync,
 | 
			
		||||
    google_workspace_sync_direct,
 | 
			
		||||
    google_workspace_sync_m2m,
 | 
			
		||||
)
 | 
			
		||||
from authentik.lib.sync.outgoing.signals import register_signals
 | 
			
		||||
 | 
			
		||||
register_signals(
 | 
			
		||||
    GoogleWorkspaceProvider,
 | 
			
		||||
    task_sync_direct_dispatch=google_workspace_sync_direct_dispatch,
 | 
			
		||||
    task_sync_m2m_dispatch=google_workspace_sync_m2m_dispatch,
 | 
			
		||||
    task_sync_single=google_workspace_sync,
 | 
			
		||||
    task_sync_direct=google_workspace_sync_direct,
 | 
			
		||||
    task_sync_m2m=google_workspace_sync_m2m,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -1,48 +1,37 @@
 | 
			
		||||
"""Google Provider tasks"""
 | 
			
		||||
 | 
			
		||||
from django.utils.translation import gettext_lazy as _
 | 
			
		||||
from dramatiq.actor import actor
 | 
			
		||||
 | 
			
		||||
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider
 | 
			
		||||
from authentik.events.system_tasks import SystemTask
 | 
			
		||||
from authentik.lib.sync.outgoing.exceptions import TransientSyncException
 | 
			
		||||
from authentik.lib.sync.outgoing.tasks import SyncTasks
 | 
			
		||||
from authentik.root.celery import CELERY_APP
 | 
			
		||||
 | 
			
		||||
sync_tasks = SyncTasks(GoogleWorkspaceProvider)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Sync Google Workspace provider objects."))
 | 
			
		||||
@CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True)
 | 
			
		||||
def google_workspace_sync_objects(*args, **kwargs):
 | 
			
		||||
    return sync_tasks.sync_objects(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Full sync for Google Workspace provider."))
 | 
			
		||||
def google_workspace_sync(provider_pk: int, *args, **kwargs):
 | 
			
		||||
@CELERY_APP.task(
 | 
			
		||||
    base=SystemTask, bind=True, autoretry_for=(TransientSyncException,), retry_backoff=True
 | 
			
		||||
)
 | 
			
		||||
def google_workspace_sync(self, provider_pk: int, *args, **kwargs):
 | 
			
		||||
    """Run full sync for Google Workspace provider"""
 | 
			
		||||
    return sync_tasks.sync(provider_pk, google_workspace_sync_objects)
 | 
			
		||||
    return sync_tasks.sync_single(self, provider_pk, google_workspace_sync_objects)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Sync a direct object (user, group) for Google Workspace provider."))
 | 
			
		||||
@CELERY_APP.task()
 | 
			
		||||
def google_workspace_sync_all():
 | 
			
		||||
    return sync_tasks.sync_all(google_workspace_sync)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True)
 | 
			
		||||
def google_workspace_sync_direct(*args, **kwargs):
 | 
			
		||||
    return sync_tasks.sync_signal_direct(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(
 | 
			
		||||
    description=_(
 | 
			
		||||
        "Dispatch syncs for a direct object (user, group) for Google Workspace providers."
 | 
			
		||||
    )
 | 
			
		||||
)
 | 
			
		||||
def google_workspace_sync_direct_dispatch(*args, **kwargs):
 | 
			
		||||
    return sync_tasks.sync_signal_direct_dispatch(google_workspace_sync_direct, *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Sync a related object (memberships) for Google Workspace provider."))
 | 
			
		||||
@CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True)
 | 
			
		||||
def google_workspace_sync_m2m(*args, **kwargs):
 | 
			
		||||
    return sync_tasks.sync_signal_m2m(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(
 | 
			
		||||
    description=_(
 | 
			
		||||
        "Dispatch syncs for a related object (memberships) for Google Workspace providers."
 | 
			
		||||
    )
 | 
			
		||||
)
 | 
			
		||||
def google_workspace_sync_m2m_dispatch(*args, **kwargs):
 | 
			
		||||
    return sync_tasks.sync_signal_m2m_dispatch(google_workspace_sync_m2m, *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
@ -324,7 +324,7 @@ class GoogleWorkspaceGroupTests(TestCase):
 | 
			
		||||
            "authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
 | 
			
		||||
            MagicMock(return_value={"developerKey": self.api_key, "http": http}),
 | 
			
		||||
        ):
 | 
			
		||||
            google_workspace_sync.send(self.provider.pk).get_result()
 | 
			
		||||
            google_workspace_sync.delay(self.provider.pk).get()
 | 
			
		||||
            self.assertTrue(
 | 
			
		||||
                GoogleWorkspaceProviderGroup.objects.filter(
 | 
			
		||||
                    group=different_group, provider=self.provider
 | 
			
		||||
 | 
			
		||||
@ -302,7 +302,7 @@ class GoogleWorkspaceUserTests(TestCase):
 | 
			
		||||
            "authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
 | 
			
		||||
            MagicMock(return_value={"developerKey": self.api_key, "http": http}),
 | 
			
		||||
        ):
 | 
			
		||||
            google_workspace_sync.send(self.provider.pk).get_result()
 | 
			
		||||
            google_workspace_sync.delay(self.provider.pk).get()
 | 
			
		||||
            self.assertTrue(
 | 
			
		||||
                GoogleWorkspaceProviderUser.objects.filter(
 | 
			
		||||
                    user=different_user, provider=self.provider
 | 
			
		||||
 | 
			
		||||
@ -53,5 +53,5 @@ class MicrosoftEntraProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixin
 | 
			
		||||
    ]
 | 
			
		||||
    search_fields = ["name"]
 | 
			
		||||
    ordering = ["name"]
 | 
			
		||||
    sync_task = microsoft_entra_sync
 | 
			
		||||
    sync_single_task = microsoft_entra_sync
 | 
			
		||||
    sync_objects_task = microsoft_entra_sync_objects
 | 
			
		||||
 | 
			
		||||
@ -8,7 +8,6 @@ from django.db import models
 | 
			
		||||
from django.db.models import QuerySet
 | 
			
		||||
from django.templatetags.static import static
 | 
			
		||||
from django.utils.translation import gettext_lazy as _
 | 
			
		||||
from dramatiq.actor import Actor
 | 
			
		||||
from rest_framework.serializers import Serializer
 | 
			
		||||
 | 
			
		||||
from authentik.core.models import (
 | 
			
		||||
@ -100,12 +99,6 @@ class MicrosoftEntraProvider(OutgoingSyncProvider, BackchannelProvider):
 | 
			
		||||
        help_text=_("Property mappings used for group creation/updating."),
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def sync_actor(self) -> Actor:
 | 
			
		||||
        from authentik.enterprise.providers.microsoft_entra.tasks import microsoft_entra_sync
 | 
			
		||||
 | 
			
		||||
        return microsoft_entra_sync
 | 
			
		||||
 | 
			
		||||
    def client_for_model(
 | 
			
		||||
        self,
 | 
			
		||||
        model: type[User | Group | MicrosoftEntraProviderUser | MicrosoftEntraProviderGroup],
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										13
									
								
								authentik/enterprise/providers/microsoft_entra/settings.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								authentik/enterprise/providers/microsoft_entra/settings.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,13 @@
 | 
			
		||||
"""Microsoft Entra provider task Settings"""
 | 
			
		||||
 | 
			
		||||
from celery.schedules import crontab
 | 
			
		||||
 | 
			
		||||
from authentik.lib.utils.time import fqdn_rand
 | 
			
		||||
 | 
			
		||||
CELERY_BEAT_SCHEDULE = {
 | 
			
		||||
    "providers_microsoft_entra_sync": {
 | 
			
		||||
        "task": "authentik.enterprise.providers.microsoft_entra.tasks.microsoft_entra_sync_all",
 | 
			
		||||
        "schedule": crontab(minute=fqdn_rand("microsoft_entra_sync_all"), hour="*/4"),
 | 
			
		||||
        "options": {"queue": "authentik_scheduled"},
 | 
			
		||||
    },
 | 
			
		||||
}
 | 
			
		||||
@ -2,13 +2,15 @@
 | 
			
		||||
 | 
			
		||||
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider
 | 
			
		||||
from authentik.enterprise.providers.microsoft_entra.tasks import (
 | 
			
		||||
    microsoft_entra_sync_direct_dispatch,
 | 
			
		||||
    microsoft_entra_sync_m2m_dispatch,
 | 
			
		||||
    microsoft_entra_sync,
 | 
			
		||||
    microsoft_entra_sync_direct,
 | 
			
		||||
    microsoft_entra_sync_m2m,
 | 
			
		||||
)
 | 
			
		||||
from authentik.lib.sync.outgoing.signals import register_signals
 | 
			
		||||
 | 
			
		||||
register_signals(
 | 
			
		||||
    MicrosoftEntraProvider,
 | 
			
		||||
    task_sync_direct_dispatch=microsoft_entra_sync_direct_dispatch,
 | 
			
		||||
    task_sync_m2m_dispatch=microsoft_entra_sync_m2m_dispatch,
 | 
			
		||||
    task_sync_single=microsoft_entra_sync,
 | 
			
		||||
    task_sync_direct=microsoft_entra_sync_direct,
 | 
			
		||||
    task_sync_m2m=microsoft_entra_sync_m2m,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -1,46 +1,37 @@
 | 
			
		||||
"""Microsoft Entra Provider tasks"""
 | 
			
		||||
 | 
			
		||||
from django.utils.translation import gettext_lazy as _
 | 
			
		||||
from dramatiq.actor import actor
 | 
			
		||||
 | 
			
		||||
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider
 | 
			
		||||
from authentik.events.system_tasks import SystemTask
 | 
			
		||||
from authentik.lib.sync.outgoing.exceptions import TransientSyncException
 | 
			
		||||
from authentik.lib.sync.outgoing.tasks import SyncTasks
 | 
			
		||||
from authentik.root.celery import CELERY_APP
 | 
			
		||||
 | 
			
		||||
sync_tasks = SyncTasks(MicrosoftEntraProvider)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Sync Microsoft Entra provider objects."))
 | 
			
		||||
@CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True)
 | 
			
		||||
def microsoft_entra_sync_objects(*args, **kwargs):
 | 
			
		||||
    return sync_tasks.sync_objects(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Full sync for Microsoft Entra provider."))
 | 
			
		||||
def microsoft_entra_sync(provider_pk: int, *args, **kwargs):
 | 
			
		||||
@CELERY_APP.task(
 | 
			
		||||
    base=SystemTask, bind=True, autoretry_for=(TransientSyncException,), retry_backoff=True
 | 
			
		||||
)
 | 
			
		||||
def microsoft_entra_sync(self, provider_pk: int, *args, **kwargs):
 | 
			
		||||
    """Run full sync for Microsoft Entra provider"""
 | 
			
		||||
    return sync_tasks.sync(provider_pk, microsoft_entra_sync_objects)
 | 
			
		||||
    return sync_tasks.sync_single(self, provider_pk, microsoft_entra_sync_objects)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Sync a direct object (user, group) for Microsoft Entra provider."))
 | 
			
		||||
@CELERY_APP.task()
 | 
			
		||||
def microsoft_entra_sync_all():
 | 
			
		||||
    return sync_tasks.sync_all(microsoft_entra_sync)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True)
 | 
			
		||||
def microsoft_entra_sync_direct(*args, **kwargs):
 | 
			
		||||
    return sync_tasks.sync_signal_direct(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(
 | 
			
		||||
    description=_("Dispatch syncs for a direct object (user, group) for Microsoft Entra providers.")
 | 
			
		||||
)
 | 
			
		||||
def microsoft_entra_sync_direct_dispatch(*args, **kwargs):
 | 
			
		||||
    return sync_tasks.sync_signal_direct_dispatch(microsoft_entra_sync_direct, *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Sync a related object (memberships) for Microsoft Entra provider."))
 | 
			
		||||
@CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True)
 | 
			
		||||
def microsoft_entra_sync_m2m(*args, **kwargs):
 | 
			
		||||
    return sync_tasks.sync_signal_m2m(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(
 | 
			
		||||
    description=_(
 | 
			
		||||
        "Dispatch syncs for a related object (memberships) for Microsoft Entra providers."
 | 
			
		||||
    )
 | 
			
		||||
)
 | 
			
		||||
def microsoft_entra_sync_m2m_dispatch(*args, **kwargs):
 | 
			
		||||
    return sync_tasks.sync_signal_m2m_dispatch(microsoft_entra_sync_m2m, *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
@ -252,13 +252,9 @@ class MicrosoftEntraGroupTests(TestCase):
 | 
			
		||||
            member_add.assert_called_once()
 | 
			
		||||
            self.assertEqual(
 | 
			
		||||
                member_add.call_args[0][0].odata_id,
 | 
			
		||||
                f"https://graph.microsoft.com/v1.0/directoryObjects/{
 | 
			
		||||
                    MicrosoftEntraProviderUser.objects.filter(
 | 
			
		||||
                f"https://graph.microsoft.com/v1.0/directoryObjects/{MicrosoftEntraProviderUser.objects.filter(
 | 
			
		||||
                        provider=self.provider,
 | 
			
		||||
                    )
 | 
			
		||||
                    .first()
 | 
			
		||||
                    .microsoft_id
 | 
			
		||||
                }",
 | 
			
		||||
                    ).first().microsoft_id}",
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    def test_group_create_member_remove(self):
 | 
			
		||||
@ -315,13 +311,9 @@ class MicrosoftEntraGroupTests(TestCase):
 | 
			
		||||
            member_add.assert_called_once()
 | 
			
		||||
            self.assertEqual(
 | 
			
		||||
                member_add.call_args[0][0].odata_id,
 | 
			
		||||
                f"https://graph.microsoft.com/v1.0/directoryObjects/{
 | 
			
		||||
                    MicrosoftEntraProviderUser.objects.filter(
 | 
			
		||||
                f"https://graph.microsoft.com/v1.0/directoryObjects/{MicrosoftEntraProviderUser.objects.filter(
 | 
			
		||||
                        provider=self.provider,
 | 
			
		||||
                    )
 | 
			
		||||
                    .first()
 | 
			
		||||
                    .microsoft_id
 | 
			
		||||
                }",
 | 
			
		||||
                    ).first().microsoft_id}",
 | 
			
		||||
            )
 | 
			
		||||
            member_remove.assert_called_once()
 | 
			
		||||
 | 
			
		||||
@ -421,7 +413,7 @@ class MicrosoftEntraGroupTests(TestCase):
 | 
			
		||||
                ),
 | 
			
		||||
            ) as group_list,
 | 
			
		||||
        ):
 | 
			
		||||
            microsoft_entra_sync.send(self.provider.pk).get_result()
 | 
			
		||||
            microsoft_entra_sync.delay(self.provider.pk).get()
 | 
			
		||||
            self.assertTrue(
 | 
			
		||||
                MicrosoftEntraProviderGroup.objects.filter(
 | 
			
		||||
                    group=different_group, provider=self.provider
 | 
			
		||||
 | 
			
		||||
@ -397,7 +397,7 @@ class MicrosoftEntraUserTests(APITestCase):
 | 
			
		||||
                AsyncMock(return_value=GroupCollectionResponse(value=[])),
 | 
			
		||||
            ),
 | 
			
		||||
        ):
 | 
			
		||||
            microsoft_entra_sync.send(self.provider.pk).get_result()
 | 
			
		||||
            microsoft_entra_sync.delay(self.provider.pk).get()
 | 
			
		||||
            self.assertTrue(
 | 
			
		||||
                MicrosoftEntraProviderUser.objects.filter(
 | 
			
		||||
                    user=different_user, provider=self.provider
 | 
			
		||||
 | 
			
		||||
@ -17,7 +17,6 @@ from authentik.crypto.models import CertificateKeyPair
 | 
			
		||||
from authentik.lib.models import CreatedUpdatedModel
 | 
			
		||||
from authentik.lib.utils.time import timedelta_from_string, timedelta_string_validator
 | 
			
		||||
from authentik.providers.oauth2.models import JWTAlgorithms, OAuth2Provider
 | 
			
		||||
from authentik.tasks.models import TasksModel
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class EventTypes(models.TextChoices):
 | 
			
		||||
@ -43,7 +42,7 @@ class SSFEventStatus(models.TextChoices):
 | 
			
		||||
    SENT = "sent"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SSFProvider(TasksModel, BackchannelProvider):
 | 
			
		||||
class SSFProvider(BackchannelProvider):
 | 
			
		||||
    """Shared Signals Framework provider to allow applications to
 | 
			
		||||
    receive user events from authentik."""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -18,7 +18,7 @@ from authentik.enterprise.providers.ssf.models import (
 | 
			
		||||
    EventTypes,
 | 
			
		||||
    SSFProvider,
 | 
			
		||||
)
 | 
			
		||||
from authentik.enterprise.providers.ssf.tasks import send_ssf_events
 | 
			
		||||
from authentik.enterprise.providers.ssf.tasks import send_ssf_event
 | 
			
		||||
from authentik.events.middleware import audit_ignore
 | 
			
		||||
from authentik.stages.authenticator.models import Device
 | 
			
		||||
from authentik.stages.authenticator_duo.models import DuoDevice
 | 
			
		||||
@ -66,7 +66,7 @@ def ssf_user_session_delete_session_revoked(sender, instance: AuthenticatedSessi
 | 
			
		||||
 | 
			
		||||
    As this signal is also triggered with a regular logout, we can't be sure
 | 
			
		||||
    if the session has been deleted by an admin or by the user themselves."""
 | 
			
		||||
    send_ssf_events(
 | 
			
		||||
    send_ssf_event(
 | 
			
		||||
        EventTypes.CAEP_SESSION_REVOKED,
 | 
			
		||||
        {
 | 
			
		||||
            "initiating_entity": "user",
 | 
			
		||||
@ -88,7 +88,7 @@ def ssf_user_session_delete_session_revoked(sender, instance: AuthenticatedSessi
 | 
			
		||||
@receiver(password_changed)
 | 
			
		||||
def ssf_password_changed_cred_change(sender, user: User, password: str | None, **_):
 | 
			
		||||
    """Credential change trigger (password changed)"""
 | 
			
		||||
    send_ssf_events(
 | 
			
		||||
    send_ssf_event(
 | 
			
		||||
        EventTypes.CAEP_CREDENTIAL_CHANGE,
 | 
			
		||||
        {
 | 
			
		||||
            "credential_type": "password",
 | 
			
		||||
@ -126,7 +126,7 @@ def ssf_device_post_save(sender: type[Model], instance: Device, created: bool, *
 | 
			
		||||
    }
 | 
			
		||||
    if isinstance(instance, WebAuthnDevice) and instance.aaguid != UNKNOWN_DEVICE_TYPE_AAGUID:
 | 
			
		||||
        data["fido2_aaguid"] = instance.aaguid
 | 
			
		||||
    send_ssf_events(
 | 
			
		||||
    send_ssf_event(
 | 
			
		||||
        EventTypes.CAEP_CREDENTIAL_CHANGE,
 | 
			
		||||
        data,
 | 
			
		||||
        sub_id={
 | 
			
		||||
@ -153,7 +153,7 @@ def ssf_device_post_delete(sender: type[Model], instance: Device, **_):
 | 
			
		||||
    }
 | 
			
		||||
    if isinstance(instance, WebAuthnDevice) and instance.aaguid != UNKNOWN_DEVICE_TYPE_AAGUID:
 | 
			
		||||
        data["fido2_aaguid"] = instance.aaguid
 | 
			
		||||
    send_ssf_events(
 | 
			
		||||
    send_ssf_event(
 | 
			
		||||
        EventTypes.CAEP_CREDENTIAL_CHANGE,
 | 
			
		||||
        data,
 | 
			
		||||
        sub_id={
 | 
			
		||||
 | 
			
		||||
@ -1,11 +1,7 @@
 | 
			
		||||
from typing import Any
 | 
			
		||||
from uuid import UUID
 | 
			
		||||
 | 
			
		||||
from celery import group
 | 
			
		||||
from django.http import HttpRequest
 | 
			
		||||
from django.utils.timezone import now
 | 
			
		||||
from django.utils.translation import gettext_lazy as _
 | 
			
		||||
from django_dramatiq_postgres.middleware import CurrentTask
 | 
			
		||||
from dramatiq.actor import actor
 | 
			
		||||
from requests.exceptions import RequestException
 | 
			
		||||
from structlog.stdlib import get_logger
 | 
			
		||||
 | 
			
		||||
@ -17,16 +13,19 @@ from authentik.enterprise.providers.ssf.models import (
 | 
			
		||||
    Stream,
 | 
			
		||||
    StreamEvent,
 | 
			
		||||
)
 | 
			
		||||
from authentik.events.logs import LogEvent
 | 
			
		||||
from authentik.events.models import TaskStatus
 | 
			
		||||
from authentik.events.system_tasks import SystemTask
 | 
			
		||||
from authentik.lib.utils.http import get_http_session
 | 
			
		||||
from authentik.lib.utils.time import timedelta_from_string
 | 
			
		||||
from authentik.policies.engine import PolicyEngine
 | 
			
		||||
from authentik.tasks.models import Task
 | 
			
		||||
from authentik.root.celery import CELERY_APP
 | 
			
		||||
 | 
			
		||||
session = get_http_session()
 | 
			
		||||
LOGGER = get_logger()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def send_ssf_events(
 | 
			
		||||
def send_ssf_event(
 | 
			
		||||
    event_type: EventTypes,
 | 
			
		||||
    data: dict,
 | 
			
		||||
    stream_filter: dict | None = None,
 | 
			
		||||
@ -34,7 +33,7 @@ def send_ssf_events(
 | 
			
		||||
    **extra_data,
 | 
			
		||||
):
 | 
			
		||||
    """Wrapper to send an SSF event to multiple streams"""
 | 
			
		||||
    events_data = {}
 | 
			
		||||
    payload = []
 | 
			
		||||
    if not stream_filter:
 | 
			
		||||
        stream_filter = {}
 | 
			
		||||
    stream_filter["events_requested__contains"] = [event_type]
 | 
			
		||||
@ -42,22 +41,16 @@ def send_ssf_events(
 | 
			
		||||
        extra_data.setdefault("txn", request.request_id)
 | 
			
		||||
    for stream in Stream.objects.filter(**stream_filter):
 | 
			
		||||
        event_data = stream.prepare_event_payload(event_type, data, **extra_data)
 | 
			
		||||
        events_data[stream.uuid] = event_data
 | 
			
		||||
    ssf_events_dispatch.send(events_data)
 | 
			
		||||
        payload.append((str(stream.uuid), event_data))
 | 
			
		||||
    return _send_ssf_event.delay(payload)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Dispatch SSF events."))
 | 
			
		||||
def ssf_events_dispatch(events_data: dict[str, dict[str, Any]]):
 | 
			
		||||
    for stream_uuid, event_data in events_data.items():
 | 
			
		||||
        stream = Stream.objects.filter(pk=stream_uuid).first()
 | 
			
		||||
        if not stream:
 | 
			
		||||
            continue
 | 
			
		||||
        send_ssf_event.send_with_options(args=(stream_uuid, event_data), rel_obj=stream.provider)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _check_app_access(stream: Stream, event_data: dict) -> bool:
 | 
			
		||||
def _check_app_access(stream_uuid: str, event_data: dict) -> bool:
 | 
			
		||||
    """Check if event is related to user and if so, check
 | 
			
		||||
    if the user has access to the application"""
 | 
			
		||||
    stream = Stream.objects.filter(pk=stream_uuid).first()
 | 
			
		||||
    if not stream:
 | 
			
		||||
        return False
 | 
			
		||||
    # `event_data` is a dict version of a StreamEvent
 | 
			
		||||
    sub_id = event_data.get("payload", {}).get("sub_id", {})
 | 
			
		||||
    email = sub_id.get("user", {}).get("email", None)
 | 
			
		||||
@ -72,22 +65,42 @@ def _check_app_access(stream: Stream, event_data: dict) -> bool:
 | 
			
		||||
    return engine.passing
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Send an SSF event."))
 | 
			
		||||
def send_ssf_event(stream_uuid: UUID, event_data: dict[str, Any]):
 | 
			
		||||
    self: Task = CurrentTask.get_task()
 | 
			
		||||
@CELERY_APP.task()
 | 
			
		||||
def _send_ssf_event(event_data: list[tuple[str, dict]]):
 | 
			
		||||
    tasks = []
 | 
			
		||||
    for stream, data in event_data:
 | 
			
		||||
        if not _check_app_access(stream, data):
 | 
			
		||||
            continue
 | 
			
		||||
        event = StreamEvent.objects.create(**data)
 | 
			
		||||
        tasks.extend(send_single_ssf_event(stream, str(event.uuid)))
 | 
			
		||||
    main_task = group(*tasks)
 | 
			
		||||
    main_task()
 | 
			
		||||
 | 
			
		||||
    stream = Stream.objects.filter(pk=stream_uuid).first()
 | 
			
		||||
 | 
			
		||||
def send_single_ssf_event(stream_id: str, evt_id: str):
 | 
			
		||||
    stream = Stream.objects.filter(pk=stream_id).first()
 | 
			
		||||
    if not stream:
 | 
			
		||||
        return
 | 
			
		||||
    if not _check_app_access(stream, event_data):
 | 
			
		||||
    event = StreamEvent.objects.filter(pk=evt_id).first()
 | 
			
		||||
    if not event:
 | 
			
		||||
        return
 | 
			
		||||
    event = StreamEvent.objects.create(**event_data)
 | 
			
		||||
    self.set_uid(event.pk)
 | 
			
		||||
    if event.status == SSFEventStatus.SENT:
 | 
			
		||||
        return
 | 
			
		||||
    if stream.delivery_method != DeliveryMethods.RISC_PUSH:
 | 
			
		||||
        return
 | 
			
		||||
    if stream.delivery_method == DeliveryMethods.RISC_PUSH:
 | 
			
		||||
        return [ssf_push_event.si(str(event.pk))]
 | 
			
		||||
    return []
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@CELERY_APP.task(bind=True, base=SystemTask)
 | 
			
		||||
def ssf_push_event(self: SystemTask, event_id: str):
 | 
			
		||||
    self.save_on_success = False
 | 
			
		||||
    event = StreamEvent.objects.filter(pk=event_id).first()
 | 
			
		||||
    if not event:
 | 
			
		||||
        return
 | 
			
		||||
    self.set_uid(event_id)
 | 
			
		||||
    if event.status == SSFEventStatus.SENT:
 | 
			
		||||
        self.set_status(TaskStatus.SUCCESSFUL)
 | 
			
		||||
        return
 | 
			
		||||
    try:
 | 
			
		||||
        response = session.post(
 | 
			
		||||
            event.stream.endpoint_url,
 | 
			
		||||
@ -97,17 +110,26 @@ def send_ssf_event(stream_uuid: UUID, event_data: dict[str, Any]):
 | 
			
		||||
        response.raise_for_status()
 | 
			
		||||
        event.status = SSFEventStatus.SENT
 | 
			
		||||
        event.save()
 | 
			
		||||
        self.set_status(TaskStatus.SUCCESSFUL)
 | 
			
		||||
        return
 | 
			
		||||
    except RequestException as exc:
 | 
			
		||||
        LOGGER.warning("Failed to send SSF event", exc=exc)
 | 
			
		||||
        self.set_status(TaskStatus.ERROR)
 | 
			
		||||
        attrs = {}
 | 
			
		||||
        if exc.response:
 | 
			
		||||
            attrs["response"] = {
 | 
			
		||||
                "content": exc.response.text,
 | 
			
		||||
                "status": exc.response.status_code,
 | 
			
		||||
            }
 | 
			
		||||
        self.warning(exc)
 | 
			
		||||
        self.warning("Failed to send request", **attrs)
 | 
			
		||||
        self.set_error(
 | 
			
		||||
            exc,
 | 
			
		||||
            LogEvent(
 | 
			
		||||
                _("Failed to send request"),
 | 
			
		||||
                log_level="warning",
 | 
			
		||||
                logger=self.__name__,
 | 
			
		||||
                attributes=attrs,
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
        # Re-up the expiry of the stream event
 | 
			
		||||
        event.expires = now() + timedelta_from_string(event.stream.provider.event_retention)
 | 
			
		||||
        event.status = SSFEventStatus.PENDING_FAILED
 | 
			
		||||
 | 
			
		||||
@ -13,7 +13,7 @@ from authentik.enterprise.providers.ssf.models import (
 | 
			
		||||
    SSFProvider,
 | 
			
		||||
    Stream,
 | 
			
		||||
)
 | 
			
		||||
from authentik.enterprise.providers.ssf.tasks import send_ssf_events
 | 
			
		||||
from authentik.enterprise.providers.ssf.tasks import send_ssf_event
 | 
			
		||||
from authentik.enterprise.providers.ssf.views.base import SSFView
 | 
			
		||||
 | 
			
		||||
LOGGER = get_logger()
 | 
			
		||||
@ -109,7 +109,7 @@ class StreamView(SSFView):
 | 
			
		||||
                "User does not have permission to create stream for this provider."
 | 
			
		||||
            )
 | 
			
		||||
        instance: Stream = stream.save(provider=self.provider)
 | 
			
		||||
        send_ssf_events(
 | 
			
		||||
        send_ssf_event(
 | 
			
		||||
            EventTypes.SET_VERIFICATION,
 | 
			
		||||
            {
 | 
			
		||||
                "state": None,
 | 
			
		||||
 | 
			
		||||
@ -1,5 +1,17 @@
 | 
			
		||||
"""Enterprise additional settings"""
 | 
			
		||||
 | 
			
		||||
from celery.schedules import crontab
 | 
			
		||||
 | 
			
		||||
from authentik.lib.utils.time import fqdn_rand
 | 
			
		||||
 | 
			
		||||
CELERY_BEAT_SCHEDULE = {
 | 
			
		||||
    "enterprise_update_usage": {
 | 
			
		||||
        "task": "authentik.enterprise.tasks.enterprise_update_usage",
 | 
			
		||||
        "schedule": crontab(minute=fqdn_rand("enterprise_update_usage"), hour="*/2"),
 | 
			
		||||
        "options": {"queue": "authentik_scheduled"},
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TENANT_APPS = [
 | 
			
		||||
    "authentik.enterprise.audit",
 | 
			
		||||
    "authentik.enterprise.policies.unique_password",
 | 
			
		||||
 | 
			
		||||
@ -10,7 +10,6 @@ from django.utils.timezone import get_current_timezone
 | 
			
		||||
from authentik.enterprise.license import CACHE_KEY_ENTERPRISE_LICENSE
 | 
			
		||||
from authentik.enterprise.models import License
 | 
			
		||||
from authentik.enterprise.tasks import enterprise_update_usage
 | 
			
		||||
from authentik.tasks.schedules.models import Schedule
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@receiver(pre_save, sender=License)
 | 
			
		||||
@ -27,7 +26,7 @@ def pre_save_license(sender: type[License], instance: License, **_):
 | 
			
		||||
def post_save_license(sender: type[License], instance: License, **_):
 | 
			
		||||
    """Trigger license usage calculation when license is saved"""
 | 
			
		||||
    cache.delete(CACHE_KEY_ENTERPRISE_LICENSE)
 | 
			
		||||
    Schedule.dispatch_by_actor(enterprise_update_usage)
 | 
			
		||||
    enterprise_update_usage.delay()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@receiver(post_delete, sender=License)
 | 
			
		||||
 | 
			
		||||
@ -1,11 +1,14 @@
 | 
			
		||||
"""Enterprise tasks"""
 | 
			
		||||
 | 
			
		||||
from django.utils.translation import gettext_lazy as _
 | 
			
		||||
from dramatiq.actor import actor
 | 
			
		||||
 | 
			
		||||
from authentik.enterprise.license import LicenseKey
 | 
			
		||||
from authentik.events.models import TaskStatus
 | 
			
		||||
from authentik.events.system_tasks import SystemTask, prefill_task
 | 
			
		||||
from authentik.root.celery import CELERY_APP
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Update enterprise license status."))
 | 
			
		||||
def enterprise_update_usage():
 | 
			
		||||
@CELERY_APP.task(bind=True, base=SystemTask)
 | 
			
		||||
@prefill_task
 | 
			
		||||
def enterprise_update_usage(self: SystemTask):
 | 
			
		||||
    """Update enterprise license status"""
 | 
			
		||||
    LicenseKey.get_total().record_usage()
 | 
			
		||||
    self.set_status(TaskStatus.SUCCESSFUL)
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										104
									
								
								authentik/events/api/tasks.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										104
									
								
								authentik/events/api/tasks.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,104 @@
 | 
			
		||||
"""Tasks API"""
 | 
			
		||||
 | 
			
		||||
from importlib import import_module
 | 
			
		||||
 | 
			
		||||
from django.contrib import messages
 | 
			
		||||
from django.utils.translation import gettext_lazy as _
 | 
			
		||||
from drf_spectacular.types import OpenApiTypes
 | 
			
		||||
from drf_spectacular.utils import OpenApiResponse, extend_schema
 | 
			
		||||
from rest_framework.decorators import action
 | 
			
		||||
from rest_framework.fields import (
 | 
			
		||||
    CharField,
 | 
			
		||||
    ChoiceField,
 | 
			
		||||
    DateTimeField,
 | 
			
		||||
    FloatField,
 | 
			
		||||
    SerializerMethodField,
 | 
			
		||||
)
 | 
			
		||||
from rest_framework.request import Request
 | 
			
		||||
from rest_framework.response import Response
 | 
			
		||||
from rest_framework.viewsets import ReadOnlyModelViewSet
 | 
			
		||||
from structlog.stdlib import get_logger
 | 
			
		||||
 | 
			
		||||
from authentik.core.api.utils import ModelSerializer
 | 
			
		||||
from authentik.events.logs import LogEventSerializer
 | 
			
		||||
from authentik.events.models import SystemTask, TaskStatus
 | 
			
		||||
from authentik.rbac.decorators import permission_required
 | 
			
		||||
 | 
			
		||||
LOGGER = get_logger()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SystemTaskSerializer(ModelSerializer):
 | 
			
		||||
    """Serialize TaskInfo and TaskResult"""
 | 
			
		||||
 | 
			
		||||
    name = CharField()
 | 
			
		||||
    full_name = SerializerMethodField()
 | 
			
		||||
    uid = CharField(required=False)
 | 
			
		||||
    description = CharField()
 | 
			
		||||
    start_timestamp = DateTimeField(read_only=True)
 | 
			
		||||
    finish_timestamp = DateTimeField(read_only=True)
 | 
			
		||||
    duration = FloatField(read_only=True)
 | 
			
		||||
 | 
			
		||||
    status = ChoiceField(choices=[(x.value, x.name) for x in TaskStatus])
 | 
			
		||||
    messages = LogEventSerializer(many=True)
 | 
			
		||||
 | 
			
		||||
    def get_full_name(self, instance: SystemTask) -> str:
 | 
			
		||||
        """Get full name with UID"""
 | 
			
		||||
        if instance.uid:
 | 
			
		||||
            return f"{instance.name}:{instance.uid}"
 | 
			
		||||
        return instance.name
 | 
			
		||||
 | 
			
		||||
    class Meta:
 | 
			
		||||
        model = SystemTask
 | 
			
		||||
        fields = [
 | 
			
		||||
            "uuid",
 | 
			
		||||
            "name",
 | 
			
		||||
            "full_name",
 | 
			
		||||
            "uid",
 | 
			
		||||
            "description",
 | 
			
		||||
            "start_timestamp",
 | 
			
		||||
            "finish_timestamp",
 | 
			
		||||
            "duration",
 | 
			
		||||
            "status",
 | 
			
		||||
            "messages",
 | 
			
		||||
            "expires",
 | 
			
		||||
            "expiring",
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SystemTaskViewSet(ReadOnlyModelViewSet):
 | 
			
		||||
    """Read-only view set that returns all background tasks"""
 | 
			
		||||
 | 
			
		||||
    queryset = SystemTask.objects.all()
 | 
			
		||||
    serializer_class = SystemTaskSerializer
 | 
			
		||||
    filterset_fields = ["name", "uid", "status"]
 | 
			
		||||
    ordering = ["name", "uid", "status"]
 | 
			
		||||
    search_fields = ["name", "description", "uid", "status"]
 | 
			
		||||
 | 
			
		||||
    @permission_required(None, ["authentik_events.run_task"])
 | 
			
		||||
    @extend_schema(
 | 
			
		||||
        request=OpenApiTypes.NONE,
 | 
			
		||||
        responses={
 | 
			
		||||
            204: OpenApiResponse(description="Task retried successfully"),
 | 
			
		||||
            404: OpenApiResponse(description="Task not found"),
 | 
			
		||||
            500: OpenApiResponse(description="Failed to retry task"),
 | 
			
		||||
        },
 | 
			
		||||
    )
 | 
			
		||||
    @action(detail=True, methods=["POST"], permission_classes=[])
 | 
			
		||||
    def run(self, request: Request, pk=None) -> Response:
 | 
			
		||||
        """Run task"""
 | 
			
		||||
        task: SystemTask = self.get_object()
 | 
			
		||||
        try:
 | 
			
		||||
            task_module = import_module(task.task_call_module)
 | 
			
		||||
            task_func = getattr(task_module, task.task_call_func)
 | 
			
		||||
            LOGGER.info("Running task", task=task_func)
 | 
			
		||||
            task_func.delay(*task.task_call_args, **task.task_call_kwargs)
 | 
			
		||||
            messages.success(
 | 
			
		||||
                self.request,
 | 
			
		||||
                _("Successfully started task {name}.".format_map({"name": task.name})),
 | 
			
		||||
            )
 | 
			
		||||
            return Response(status=204)
 | 
			
		||||
        except (ImportError, AttributeError) as exc:  # pragma: no cover
 | 
			
		||||
            LOGGER.warning("Failed to run task, remove state", task=task.name, exc=exc)
 | 
			
		||||
            # if we get an import error, the module path has probably changed
 | 
			
		||||
            task.delete()
 | 
			
		||||
            return Response(status=500)
 | 
			
		||||
@ -1,11 +1,12 @@
 | 
			
		||||
"""authentik events app"""
 | 
			
		||||
 | 
			
		||||
from celery.schedules import crontab
 | 
			
		||||
from prometheus_client import Gauge, Histogram
 | 
			
		||||
 | 
			
		||||
from authentik.blueprints.apps import ManagedAppConfig
 | 
			
		||||
from authentik.lib.config import CONFIG, ENV_PREFIX
 | 
			
		||||
from authentik.lib.utils.time import fqdn_rand
 | 
			
		||||
from authentik.tasks.schedules.lib import ScheduleSpec
 | 
			
		||||
from authentik.lib.utils.reflection import path_to_class
 | 
			
		||||
from authentik.root.celery import CELERY_APP
 | 
			
		||||
 | 
			
		||||
# TODO: Deprecated metric - remove in 2024.2 or later
 | 
			
		||||
GAUGE_TASKS = Gauge(
 | 
			
		||||
@ -34,17 +35,6 @@ class AuthentikEventsConfig(ManagedAppConfig):
 | 
			
		||||
    verbose_name = "authentik Events"
 | 
			
		||||
    default = True
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def tenant_schedule_specs(self) -> list[ScheduleSpec]:
 | 
			
		||||
        from authentik.events.tasks import notification_cleanup
 | 
			
		||||
 | 
			
		||||
        return [
 | 
			
		||||
            ScheduleSpec(
 | 
			
		||||
                actor=notification_cleanup,
 | 
			
		||||
                crontab=f"{fqdn_rand('notification_cleanup')} */8 * * *",
 | 
			
		||||
            ),
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
    @ManagedAppConfig.reconcile_global
 | 
			
		||||
    def check_deprecations(self):
 | 
			
		||||
        """Check for config deprecations"""
 | 
			
		||||
@ -66,3 +56,41 @@ class AuthentikEventsConfig(ManagedAppConfig):
 | 
			
		||||
                replacement_env=replace_env,
 | 
			
		||||
                message=msg,
 | 
			
		||||
            ).save()
 | 
			
		||||
 | 
			
		||||
    @ManagedAppConfig.reconcile_tenant
 | 
			
		||||
    def prefill_tasks(self):
 | 
			
		||||
        """Prefill tasks"""
 | 
			
		||||
        from authentik.events.models import SystemTask
 | 
			
		||||
        from authentik.events.system_tasks import _prefill_tasks
 | 
			
		||||
 | 
			
		||||
        for task in _prefill_tasks:
 | 
			
		||||
            if SystemTask.objects.filter(name=task.name).exists():
 | 
			
		||||
                continue
 | 
			
		||||
            task.save()
 | 
			
		||||
            self.logger.debug("prefilled task", task_name=task.name)
 | 
			
		||||
 | 
			
		||||
    @ManagedAppConfig.reconcile_tenant
 | 
			
		||||
    def run_scheduled_tasks(self):
 | 
			
		||||
        """Run schedule tasks which are behind schedule (only applies
 | 
			
		||||
        to tasks of which we keep metrics)"""
 | 
			
		||||
        from authentik.events.models import TaskStatus
 | 
			
		||||
        from authentik.events.system_tasks import SystemTask as CelerySystemTask
 | 
			
		||||
 | 
			
		||||
        for task in CELERY_APP.conf["beat_schedule"].values():
 | 
			
		||||
            schedule = task["schedule"]
 | 
			
		||||
            if not isinstance(schedule, crontab):
 | 
			
		||||
                continue
 | 
			
		||||
            task_class: CelerySystemTask = path_to_class(task["task"])
 | 
			
		||||
            if not isinstance(task_class, CelerySystemTask):
 | 
			
		||||
                continue
 | 
			
		||||
            db_task = task_class.db()
 | 
			
		||||
            if not db_task:
 | 
			
		||||
                continue
 | 
			
		||||
            due, _ = schedule.is_due(db_task.finish_timestamp)
 | 
			
		||||
            if due or db_task.status == TaskStatus.UNKNOWN:
 | 
			
		||||
                self.logger.debug("Running past-due scheduled task", task=task["task"])
 | 
			
		||||
                task_class.apply_async(
 | 
			
		||||
                    args=task.get("args", None),
 | 
			
		||||
                    kwargs=task.get("kwargs", None),
 | 
			
		||||
                    **task.get("options", {}),
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
@ -1,22 +0,0 @@
 | 
			
		||||
# Generated by Django 5.1.11 on 2025-06-24 15:36
 | 
			
		||||
 | 
			
		||||
from django.db import migrations
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Migration(migrations.Migration):
 | 
			
		||||
 | 
			
		||||
    dependencies = [
 | 
			
		||||
        ("authentik_events", "0010_rename_group_notificationrule_destination_group_and_more"),
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    operations = [
 | 
			
		||||
        migrations.AlterModelOptions(
 | 
			
		||||
            name="systemtask",
 | 
			
		||||
            options={
 | 
			
		||||
                "default_permissions": (),
 | 
			
		||||
                "permissions": (),
 | 
			
		||||
                "verbose_name": "System Task",
 | 
			
		||||
                "verbose_name_plural": "System Tasks",
 | 
			
		||||
            },
 | 
			
		||||
        ),
 | 
			
		||||
    ]
 | 
			
		||||
@ -5,11 +5,12 @@ from datetime import timedelta
 | 
			
		||||
from difflib import get_close_matches
 | 
			
		||||
from functools import lru_cache
 | 
			
		||||
from inspect import currentframe
 | 
			
		||||
from smtplib import SMTPException
 | 
			
		||||
from typing import Any
 | 
			
		||||
from uuid import uuid4
 | 
			
		||||
 | 
			
		||||
from django.apps import apps
 | 
			
		||||
from django.db import models
 | 
			
		||||
from django.db import connection, models
 | 
			
		||||
from django.http import HttpRequest
 | 
			
		||||
from django.http.request import QueryDict
 | 
			
		||||
from django.utils.timezone import now
 | 
			
		||||
@ -26,6 +27,7 @@ from authentik.core.middleware import (
 | 
			
		||||
    SESSION_KEY_IMPERSONATE_USER,
 | 
			
		||||
)
 | 
			
		||||
from authentik.core.models import ExpiringModel, Group, PropertyMapping, User
 | 
			
		||||
from authentik.events.apps import GAUGE_TASKS, SYSTEM_TASK_STATUS, SYSTEM_TASK_TIME
 | 
			
		||||
from authentik.events.context_processors.base import get_context_processors
 | 
			
		||||
from authentik.events.utils import (
 | 
			
		||||
    cleanse_dict,
 | 
			
		||||
@ -41,7 +43,6 @@ from authentik.lib.utils.time import timedelta_from_string
 | 
			
		||||
from authentik.policies.models import PolicyBindingModel
 | 
			
		||||
from authentik.root.middleware import ClientIPMiddleware
 | 
			
		||||
from authentik.stages.email.utils import TemplateEmailMessage
 | 
			
		||||
from authentik.tasks.models import TasksModel
 | 
			
		||||
from authentik.tenants.models import Tenant
 | 
			
		||||
from authentik.tenants.utils import get_current_tenant
 | 
			
		||||
 | 
			
		||||
@ -266,8 +267,7 @@ class Event(SerializerModel, ExpiringModel):
 | 
			
		||||
            models.Index(fields=["created"]),
 | 
			
		||||
            models.Index(fields=["client_ip"]),
 | 
			
		||||
            models.Index(
 | 
			
		||||
                models.F("context__authorized_application"),
 | 
			
		||||
                name="authentik_e_ctx_app__idx",
 | 
			
		||||
                models.F("context__authorized_application"), name="authentik_e_ctx_app__idx"
 | 
			
		||||
            ),
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
@ -281,7 +281,7 @@ class TransportMode(models.TextChoices):
 | 
			
		||||
    EMAIL = "email", _("Email")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class NotificationTransport(TasksModel, SerializerModel):
 | 
			
		||||
class NotificationTransport(SerializerModel):
 | 
			
		||||
    """Action which is executed when a Rule matches"""
 | 
			
		||||
 | 
			
		||||
    uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
 | 
			
		||||
@ -446,8 +446,6 @@ class NotificationTransport(TasksModel, SerializerModel):
 | 
			
		||||
 | 
			
		||||
    def send_email(self, notification: "Notification") -> list[str]:
 | 
			
		||||
        """Send notification via global email configuration"""
 | 
			
		||||
        from authentik.stages.email.tasks import send_mail
 | 
			
		||||
 | 
			
		||||
        if notification.user.email.strip() == "":
 | 
			
		||||
            LOGGER.info(
 | 
			
		||||
                "Discarding notification as user has no email address",
 | 
			
		||||
@ -489,14 +487,17 @@ class NotificationTransport(TasksModel, SerializerModel):
 | 
			
		||||
            template_name="email/event_notification.html",
 | 
			
		||||
            template_context=context,
 | 
			
		||||
        )
 | 
			
		||||
        send_mail.send_with_options(args=(mail.__dict__,), rel_obj=self)
 | 
			
		||||
        return []
 | 
			
		||||
        # Email is sent directly here, as the call to send() should have been from a task.
 | 
			
		||||
        try:
 | 
			
		||||
            from authentik.stages.email.tasks import send_mail
 | 
			
		||||
 | 
			
		||||
            return send_mail(mail.__dict__)
 | 
			
		||||
        except (SMTPException, ConnectionError, OSError) as exc:
 | 
			
		||||
            raise NotificationTransportError(exc) from exc
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def serializer(self) -> type[Serializer]:
 | 
			
		||||
        from authentik.events.api.notification_transports import (
 | 
			
		||||
            NotificationTransportSerializer,
 | 
			
		||||
        )
 | 
			
		||||
        from authentik.events.api.notification_transports import NotificationTransportSerializer
 | 
			
		||||
 | 
			
		||||
        return NotificationTransportSerializer
 | 
			
		||||
 | 
			
		||||
@ -546,7 +547,7 @@ class Notification(SerializerModel):
 | 
			
		||||
        verbose_name_plural = _("Notifications")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class NotificationRule(TasksModel, SerializerModel, PolicyBindingModel):
 | 
			
		||||
class NotificationRule(SerializerModel, PolicyBindingModel):
 | 
			
		||||
    """Decide when to create a Notification based on policies attached to this object."""
 | 
			
		||||
 | 
			
		||||
    name = models.TextField(unique=True)
 | 
			
		||||
@ -610,9 +611,7 @@ class NotificationWebhookMapping(PropertyMapping):
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def serializer(self) -> type[type[Serializer]]:
 | 
			
		||||
        from authentik.events.api.notification_mappings import (
 | 
			
		||||
            NotificationWebhookMappingSerializer,
 | 
			
		||||
        )
 | 
			
		||||
        from authentik.events.api.notification_mappings import NotificationWebhookMappingSerializer
 | 
			
		||||
 | 
			
		||||
        return NotificationWebhookMappingSerializer
 | 
			
		||||
 | 
			
		||||
@ -625,7 +624,7 @@ class NotificationWebhookMapping(PropertyMapping):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TaskStatus(models.TextChoices):
 | 
			
		||||
    """DEPRECATED do not use"""
 | 
			
		||||
    """Possible states of tasks"""
 | 
			
		||||
 | 
			
		||||
    UNKNOWN = "unknown"
 | 
			
		||||
    SUCCESSFUL = "successful"
 | 
			
		||||
@ -633,8 +632,8 @@ class TaskStatus(models.TextChoices):
 | 
			
		||||
    ERROR = "error"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SystemTask(ExpiringModel):
 | 
			
		||||
    """DEPRECATED do not use"""
 | 
			
		||||
class SystemTask(SerializerModel, ExpiringModel):
 | 
			
		||||
    """Info about a system task running in the background along with details to restart the task"""
 | 
			
		||||
 | 
			
		||||
    uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
 | 
			
		||||
    name = models.TextField()
 | 
			
		||||
@ -654,13 +653,41 @@ class SystemTask(ExpiringModel):
 | 
			
		||||
    task_call_args = models.JSONField(default=list)
 | 
			
		||||
    task_call_kwargs = models.JSONField(default=dict)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def serializer(self) -> type[Serializer]:
 | 
			
		||||
        from authentik.events.api.tasks import SystemTaskSerializer
 | 
			
		||||
 | 
			
		||||
        return SystemTaskSerializer
 | 
			
		||||
 | 
			
		||||
    def update_metrics(self):
 | 
			
		||||
        """Update prometheus metrics"""
 | 
			
		||||
        # TODO: Deprecated metric - remove in 2024.2 or later
 | 
			
		||||
        GAUGE_TASKS.labels(
 | 
			
		||||
            tenant=connection.schema_name,
 | 
			
		||||
            task_name=self.name,
 | 
			
		||||
            task_uid=self.uid or "",
 | 
			
		||||
            status=self.status.lower(),
 | 
			
		||||
        ).set(self.duration)
 | 
			
		||||
        SYSTEM_TASK_TIME.labels(
 | 
			
		||||
            tenant=connection.schema_name,
 | 
			
		||||
            task_name=self.name,
 | 
			
		||||
            task_uid=self.uid or "",
 | 
			
		||||
        ).observe(self.duration)
 | 
			
		||||
        SYSTEM_TASK_STATUS.labels(
 | 
			
		||||
            tenant=connection.schema_name,
 | 
			
		||||
            task_name=self.name,
 | 
			
		||||
            task_uid=self.uid or "",
 | 
			
		||||
            status=self.status.lower(),
 | 
			
		||||
        ).inc()
 | 
			
		||||
 | 
			
		||||
    def __str__(self) -> str:
 | 
			
		||||
        return f"System Task {self.name}"
 | 
			
		||||
 | 
			
		||||
    class Meta:
 | 
			
		||||
        unique_together = (("name", "uid"),)
 | 
			
		||||
        default_permissions = ()
 | 
			
		||||
        permissions = ()
 | 
			
		||||
        # Remove "add", "change" and "delete" permissions as those are not used
 | 
			
		||||
        default_permissions = ["view"]
 | 
			
		||||
        permissions = [("run_task", _("Run task"))]
 | 
			
		||||
        verbose_name = _("System Task")
 | 
			
		||||
        verbose_name_plural = _("System Tasks")
 | 
			
		||||
        indexes = ExpiringModel.Meta.indexes
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										13
									
								
								authentik/events/settings.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								authentik/events/settings.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,13 @@
 | 
			
		||||
"""Event Settings"""
 | 
			
		||||
 | 
			
		||||
from celery.schedules import crontab
 | 
			
		||||
 | 
			
		||||
from authentik.lib.utils.time import fqdn_rand
 | 
			
		||||
 | 
			
		||||
CELERY_BEAT_SCHEDULE = {
 | 
			
		||||
    "events_notification_cleanup": {
 | 
			
		||||
        "task": "authentik.events.tasks.notification_cleanup",
 | 
			
		||||
        "schedule": crontab(minute=fqdn_rand("notification_cleanup"), hour="*/8"),
 | 
			
		||||
        "options": {"queue": "authentik_scheduled"},
 | 
			
		||||
    },
 | 
			
		||||
}
 | 
			
		||||
@ -12,10 +12,13 @@ from rest_framework.request import Request
 | 
			
		||||
 | 
			
		||||
from authentik.core.models import AuthenticatedSession, User
 | 
			
		||||
from authentik.core.signals import login_failed, password_changed
 | 
			
		||||
from authentik.events.models import Event, EventAction
 | 
			
		||||
from authentik.events.apps import SYSTEM_TASK_STATUS
 | 
			
		||||
from authentik.events.models import Event, EventAction, SystemTask
 | 
			
		||||
from authentik.events.tasks import event_notification_handler, gdpr_cleanup
 | 
			
		||||
from authentik.flows.models import Stage
 | 
			
		||||
from authentik.flows.planner import PLAN_CONTEXT_OUTPOST, PLAN_CONTEXT_SOURCE, FlowPlan
 | 
			
		||||
from authentik.flows.views.executor import SESSION_KEY_PLAN
 | 
			
		||||
from authentik.root.monitoring import monitoring_set
 | 
			
		||||
from authentik.stages.invitation.models import Invitation
 | 
			
		||||
from authentik.stages.invitation.signals import invitation_used
 | 
			
		||||
from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_METHOD_ARGS
 | 
			
		||||
@ -111,15 +114,19 @@ def on_password_changed(sender, user: User, password: str, request: HttpRequest
 | 
			
		||||
@receiver(post_save, sender=Event)
 | 
			
		||||
def event_post_save_notification(sender, instance: Event, **_):
 | 
			
		||||
    """Start task to check if any policies trigger an notification on this event"""
 | 
			
		||||
    from authentik.events.tasks import event_trigger_dispatch
 | 
			
		||||
 | 
			
		||||
    event_trigger_dispatch.send(instance.event_uuid)
 | 
			
		||||
    event_notification_handler.delay(instance.event_uuid.hex)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@receiver(pre_delete, sender=User)
 | 
			
		||||
def event_user_pre_delete_cleanup(sender, instance: User, **_):
 | 
			
		||||
    """If gdpr_compliance is enabled, remove all the user's events"""
 | 
			
		||||
    from authentik.events.tasks import gdpr_cleanup
 | 
			
		||||
 | 
			
		||||
    if get_current_tenant().gdpr_compliance:
 | 
			
		||||
        gdpr_cleanup.send(instance.pk)
 | 
			
		||||
        gdpr_cleanup.delay(instance.pk)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@receiver(monitoring_set)
 | 
			
		||||
def monitoring_system_task(sender, **_):
 | 
			
		||||
    """Update metrics when task is saved"""
 | 
			
		||||
    SYSTEM_TASK_STATUS.clear()
 | 
			
		||||
    for task in SystemTask.objects.all():
 | 
			
		||||
        task.update_metrics()
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										156
									
								
								authentik/events/system_tasks.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										156
									
								
								authentik/events/system_tasks.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,156 @@
 | 
			
		||||
"""Monitored tasks"""
 | 
			
		||||
 | 
			
		||||
from datetime import datetime, timedelta
 | 
			
		||||
from time import perf_counter
 | 
			
		||||
from typing import Any
 | 
			
		||||
 | 
			
		||||
from django.utils.timezone import now
 | 
			
		||||
from django.utils.translation import gettext_lazy as _
 | 
			
		||||
from structlog.stdlib import BoundLogger, get_logger
 | 
			
		||||
from tenant_schemas_celery.task import TenantTask
 | 
			
		||||
 | 
			
		||||
from authentik.events.logs import LogEvent
 | 
			
		||||
from authentik.events.models import Event, EventAction, TaskStatus
 | 
			
		||||
from authentik.events.models import SystemTask as DBSystemTask
 | 
			
		||||
from authentik.events.utils import sanitize_item
 | 
			
		||||
from authentik.lib.utils.errors import exception_to_string
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SystemTask(TenantTask):
 | 
			
		||||
    """Task which can save its state to the cache"""
 | 
			
		||||
 | 
			
		||||
    logger: BoundLogger
 | 
			
		||||
 | 
			
		||||
    # For tasks that should only be listed if they failed, set this to False
 | 
			
		||||
    save_on_success: bool
 | 
			
		||||
 | 
			
		||||
    _status: TaskStatus
 | 
			
		||||
    _messages: list[LogEvent]
 | 
			
		||||
 | 
			
		||||
    _uid: str | None
 | 
			
		||||
    # Precise start time from perf_counter
 | 
			
		||||
    _start_precise: float | None = None
 | 
			
		||||
    _start: datetime | None = None
 | 
			
		||||
 | 
			
		||||
    def __init__(self, *args, **kwargs) -> None:
 | 
			
		||||
        super().__init__(*args, **kwargs)
 | 
			
		||||
        self._status = TaskStatus.SUCCESSFUL
 | 
			
		||||
        self.save_on_success = True
 | 
			
		||||
        self._uid = None
 | 
			
		||||
        self._status = None
 | 
			
		||||
        self._messages = []
 | 
			
		||||
        self.result_timeout_hours = 6
 | 
			
		||||
 | 
			
		||||
    def set_uid(self, uid: str):
 | 
			
		||||
        """Set UID, so in the case of an unexpected error its saved correctly"""
 | 
			
		||||
        self._uid = uid
 | 
			
		||||
 | 
			
		||||
    def set_status(self, status: TaskStatus, *messages: LogEvent):
 | 
			
		||||
        """Set result for current run, will overwrite previous result."""
 | 
			
		||||
        self._status = status
 | 
			
		||||
        self._messages = list(messages)
 | 
			
		||||
        for idx, msg in enumerate(self._messages):
 | 
			
		||||
            if not isinstance(msg, LogEvent):
 | 
			
		||||
                self._messages[idx] = LogEvent(msg, logger=self.__name__, log_level="info")
 | 
			
		||||
 | 
			
		||||
    def set_error(self, exception: Exception, *messages: LogEvent):
 | 
			
		||||
        """Set result to error and save exception"""
 | 
			
		||||
        self._status = TaskStatus.ERROR
 | 
			
		||||
        self._messages = list(messages)
 | 
			
		||||
        self._messages.extend(
 | 
			
		||||
            [LogEvent(exception_to_string(exception), logger=self.__name__, log_level="error")]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def before_start(self, task_id, args, kwargs):
 | 
			
		||||
        self._start_precise = perf_counter()
 | 
			
		||||
        self._start = now()
 | 
			
		||||
        self.logger = get_logger().bind(task_id=task_id)
 | 
			
		||||
        return super().before_start(task_id, args, kwargs)
 | 
			
		||||
 | 
			
		||||
    def db(self) -> DBSystemTask | None:
 | 
			
		||||
        """Get DB object for latest task"""
 | 
			
		||||
        return DBSystemTask.objects.filter(
 | 
			
		||||
            name=self.__name__,
 | 
			
		||||
            uid=self._uid,
 | 
			
		||||
        ).first()
 | 
			
		||||
 | 
			
		||||
    def after_return(self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo):
 | 
			
		||||
        super().after_return(status, retval, task_id, args, kwargs, einfo=einfo)
 | 
			
		||||
        if not self._status:
 | 
			
		||||
            return
 | 
			
		||||
        if self._status == TaskStatus.SUCCESSFUL and not self.save_on_success:
 | 
			
		||||
            DBSystemTask.objects.filter(
 | 
			
		||||
                name=self.__name__,
 | 
			
		||||
                uid=self._uid,
 | 
			
		||||
            ).delete()
 | 
			
		||||
            return
 | 
			
		||||
        DBSystemTask.objects.update_or_create(
 | 
			
		||||
            name=self.__name__,
 | 
			
		||||
            uid=self._uid,
 | 
			
		||||
            defaults={
 | 
			
		||||
                "description": self.__doc__,
 | 
			
		||||
                "start_timestamp": self._start or now(),
 | 
			
		||||
                "finish_timestamp": now(),
 | 
			
		||||
                "duration": max(perf_counter() - self._start_precise, 0),
 | 
			
		||||
                "task_call_module": self.__module__,
 | 
			
		||||
                "task_call_func": self.__name__,
 | 
			
		||||
                "task_call_args": sanitize_item(args),
 | 
			
		||||
                "task_call_kwargs": sanitize_item(kwargs),
 | 
			
		||||
                "status": self._status,
 | 
			
		||||
                "messages": sanitize_item(self._messages),
 | 
			
		||||
                "expires": now() + timedelta(hours=self.result_timeout_hours),
 | 
			
		||||
                "expiring": True,
 | 
			
		||||
            },
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def on_failure(self, exc, task_id, args, kwargs, einfo):
 | 
			
		||||
        super().on_failure(exc, task_id, args, kwargs, einfo=einfo)
 | 
			
		||||
        if not self._status:
 | 
			
		||||
            self.set_error(exc)
 | 
			
		||||
        DBSystemTask.objects.update_or_create(
 | 
			
		||||
            name=self.__name__,
 | 
			
		||||
            uid=self._uid,
 | 
			
		||||
            defaults={
 | 
			
		||||
                "description": self.__doc__,
 | 
			
		||||
                "start_timestamp": self._start or now(),
 | 
			
		||||
                "finish_timestamp": now(),
 | 
			
		||||
                "duration": max(perf_counter() - self._start_precise, 0),
 | 
			
		||||
                "task_call_module": self.__module__,
 | 
			
		||||
                "task_call_func": self.__name__,
 | 
			
		||||
                "task_call_args": sanitize_item(args),
 | 
			
		||||
                "task_call_kwargs": sanitize_item(kwargs),
 | 
			
		||||
                "status": self._status,
 | 
			
		||||
                "messages": sanitize_item(self._messages),
 | 
			
		||||
                "expires": now() + timedelta(hours=self.result_timeout_hours + 3),
 | 
			
		||||
                "expiring": True,
 | 
			
		||||
            },
 | 
			
		||||
        )
 | 
			
		||||
        Event.new(
 | 
			
		||||
            EventAction.SYSTEM_TASK_EXCEPTION,
 | 
			
		||||
            message=f"Task {self.__name__} encountered an error: {exception_to_string(exc)}",
 | 
			
		||||
        ).save()
 | 
			
		||||
 | 
			
		||||
    def run(self, *args, **kwargs):
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def prefill_task(func):
 | 
			
		||||
    """Ensure a task's details are always in cache, so it can always be triggered via API"""
 | 
			
		||||
    _prefill_tasks.append(
 | 
			
		||||
        DBSystemTask(
 | 
			
		||||
            name=func.__name__,
 | 
			
		||||
            description=func.__doc__,
 | 
			
		||||
            start_timestamp=now(),
 | 
			
		||||
            finish_timestamp=now(),
 | 
			
		||||
            status=TaskStatus.UNKNOWN,
 | 
			
		||||
            messages=sanitize_item([_("Task has not been run yet.")]),
 | 
			
		||||
            task_call_module=func.__module__,
 | 
			
		||||
            task_call_func=func.__name__,
 | 
			
		||||
            expiring=False,
 | 
			
		||||
            duration=0,
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    return func
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
_prefill_tasks = []
 | 
			
		||||
@ -1,49 +1,41 @@
 | 
			
		||||
"""Event notification tasks"""
 | 
			
		||||
 | 
			
		||||
from uuid import UUID
 | 
			
		||||
 | 
			
		||||
from django.db.models.query_utils import Q
 | 
			
		||||
from django.utils.translation import gettext_lazy as _
 | 
			
		||||
from django_dramatiq_postgres.middleware import CurrentTask
 | 
			
		||||
from dramatiq.actor import actor
 | 
			
		||||
from guardian.shortcuts import get_anonymous_user
 | 
			
		||||
from structlog.stdlib import get_logger
 | 
			
		||||
 | 
			
		||||
from authentik.core.expression.exceptions import PropertyMappingExpressionException
 | 
			
		||||
from authentik.core.models import User
 | 
			
		||||
from authentik.events.models import (
 | 
			
		||||
    Event,
 | 
			
		||||
    Notification,
 | 
			
		||||
    NotificationRule,
 | 
			
		||||
    NotificationTransport,
 | 
			
		||||
    NotificationTransportError,
 | 
			
		||||
    TaskStatus,
 | 
			
		||||
)
 | 
			
		||||
from authentik.events.system_tasks import SystemTask, prefill_task
 | 
			
		||||
from authentik.policies.engine import PolicyEngine
 | 
			
		||||
from authentik.policies.models import PolicyBinding, PolicyEngineMode
 | 
			
		||||
from authentik.tasks.models import Task
 | 
			
		||||
from authentik.root.celery import CELERY_APP
 | 
			
		||||
 | 
			
		||||
LOGGER = get_logger()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Dispatch new event notifications."))
 | 
			
		||||
def event_trigger_dispatch(event_uuid: UUID):
 | 
			
		||||
@CELERY_APP.task()
 | 
			
		||||
def event_notification_handler(event_uuid: str):
 | 
			
		||||
    """Start task for each trigger definition"""
 | 
			
		||||
    for trigger in NotificationRule.objects.all():
 | 
			
		||||
        event_trigger_handler.send_with_options(args=(event_uuid, trigger.name), rel_obj=trigger)
 | 
			
		||||
        event_trigger_handler.apply_async(args=[event_uuid, trigger.name], queue="authentik_events")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(
 | 
			
		||||
    description=_(
 | 
			
		||||
        "Check if policies attached to NotificationRule match event "
 | 
			
		||||
        "and dispatch notification tasks."
 | 
			
		||||
    )
 | 
			
		||||
)
 | 
			
		||||
def event_trigger_handler(event_uuid: UUID, trigger_name: str):
 | 
			
		||||
@CELERY_APP.task()
 | 
			
		||||
def event_trigger_handler(event_uuid: str, trigger_name: str):
 | 
			
		||||
    """Check if policies attached to NotificationRule match event"""
 | 
			
		||||
    self: Task = CurrentTask.get_task()
 | 
			
		||||
 | 
			
		||||
    event: Event = Event.objects.filter(event_uuid=event_uuid).first()
 | 
			
		||||
    if not event:
 | 
			
		||||
        self.warning("event doesn't exist yet or anymore", event_uuid=event_uuid)
 | 
			
		||||
        LOGGER.warning("event doesn't exist yet or anymore", event_uuid=event_uuid)
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    trigger: NotificationRule | None = NotificationRule.objects.filter(name=trigger_name).first()
 | 
			
		||||
    if not trigger:
 | 
			
		||||
        return
 | 
			
		||||
@ -78,46 +70,57 @@ def event_trigger_handler(event_uuid: UUID, trigger_name: str):
 | 
			
		||||
 | 
			
		||||
    LOGGER.debug("e(trigger): event trigger matched", trigger=trigger)
 | 
			
		||||
    # Create the notification objects
 | 
			
		||||
    count = 0
 | 
			
		||||
    for transport in trigger.transports.all():
 | 
			
		||||
        for user in trigger.destination_users(event):
 | 
			
		||||
            notification_transport.send_with_options(
 | 
			
		||||
                args=(
 | 
			
		||||
            LOGGER.debug("created notification")
 | 
			
		||||
            notification_transport.apply_async(
 | 
			
		||||
                args=[
 | 
			
		||||
                    transport.pk,
 | 
			
		||||
                    event.pk,
 | 
			
		||||
                    str(event.pk),
 | 
			
		||||
                    user.pk,
 | 
			
		||||
                    trigger.pk,
 | 
			
		||||
                ),
 | 
			
		||||
                rel_obj=transport,
 | 
			
		||||
                    str(trigger.pk),
 | 
			
		||||
                ],
 | 
			
		||||
                queue="authentik_events",
 | 
			
		||||
            )
 | 
			
		||||
            count += 1
 | 
			
		||||
            if transport.send_once:
 | 
			
		||||
                break
 | 
			
		||||
    self.info(f"Created {count} notification tasks")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Send notification."))
 | 
			
		||||
def notification_transport(transport_pk: int, event_pk: str, user_pk: int, trigger_pk: str):
 | 
			
		||||
@CELERY_APP.task(
 | 
			
		||||
    bind=True,
 | 
			
		||||
    autoretry_for=(NotificationTransportError,),
 | 
			
		||||
    retry_backoff=True,
 | 
			
		||||
    base=SystemTask,
 | 
			
		||||
)
 | 
			
		||||
def notification_transport(
 | 
			
		||||
    self: SystemTask, transport_pk: int, event_pk: str, user_pk: int, trigger_pk: str
 | 
			
		||||
):
 | 
			
		||||
    """Send notification over specified transport"""
 | 
			
		||||
    event = Event.objects.filter(pk=event_pk).first()
 | 
			
		||||
    if not event:
 | 
			
		||||
        return
 | 
			
		||||
    user = User.objects.filter(pk=user_pk).first()
 | 
			
		||||
    if not user:
 | 
			
		||||
        return
 | 
			
		||||
    trigger = NotificationRule.objects.filter(pk=trigger_pk).first()
 | 
			
		||||
    if not trigger:
 | 
			
		||||
        return
 | 
			
		||||
    notification = Notification(
 | 
			
		||||
        severity=trigger.severity, body=event.summary, event=event, user=user
 | 
			
		||||
    )
 | 
			
		||||
    transport: NotificationTransport = NotificationTransport.objects.filter(pk=transport_pk).first()
 | 
			
		||||
    if not transport:
 | 
			
		||||
        return
 | 
			
		||||
    transport.send(notification)
 | 
			
		||||
    self.save_on_success = False
 | 
			
		||||
    try:
 | 
			
		||||
        event = Event.objects.filter(pk=event_pk).first()
 | 
			
		||||
        if not event:
 | 
			
		||||
            return
 | 
			
		||||
        user = User.objects.filter(pk=user_pk).first()
 | 
			
		||||
        if not user:
 | 
			
		||||
            return
 | 
			
		||||
        trigger = NotificationRule.objects.filter(pk=trigger_pk).first()
 | 
			
		||||
        if not trigger:
 | 
			
		||||
            return
 | 
			
		||||
        notification = Notification(
 | 
			
		||||
            severity=trigger.severity, body=event.summary, event=event, user=user
 | 
			
		||||
        )
 | 
			
		||||
        transport = NotificationTransport.objects.filter(pk=transport_pk).first()
 | 
			
		||||
        if not transport:
 | 
			
		||||
            return
 | 
			
		||||
        transport.send(notification)
 | 
			
		||||
        self.set_status(TaskStatus.SUCCESSFUL)
 | 
			
		||||
    except (NotificationTransportError, PropertyMappingExpressionException) as exc:
 | 
			
		||||
        self.set_error(exc)
 | 
			
		||||
        raise exc
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Cleanup events for GDPR compliance."))
 | 
			
		||||
@CELERY_APP.task()
 | 
			
		||||
def gdpr_cleanup(user_pk: int):
 | 
			
		||||
    """cleanup events from gdpr_compliance"""
 | 
			
		||||
    events = Event.objects.filter(user__pk=user_pk)
 | 
			
		||||
@ -125,12 +128,12 @@ def gdpr_cleanup(user_pk: int):
 | 
			
		||||
    events.delete()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Cleanup seen notifications and notifications whose event expired."))
 | 
			
		||||
def notification_cleanup():
 | 
			
		||||
@CELERY_APP.task(bind=True, base=SystemTask)
 | 
			
		||||
@prefill_task
 | 
			
		||||
def notification_cleanup(self: SystemTask):
 | 
			
		||||
    """Cleanup seen notifications and notifications whose event expired."""
 | 
			
		||||
    self: Task = CurrentTask.get_task()
 | 
			
		||||
    notifications = Notification.objects.filter(Q(event=None) | Q(seen=True))
 | 
			
		||||
    amount = notifications.count()
 | 
			
		||||
    notifications.delete()
 | 
			
		||||
    LOGGER.debug("Expired notifications", amount=amount)
 | 
			
		||||
    self.info(f"Expired {amount} Notifications")
 | 
			
		||||
    self.set_status(TaskStatus.SUCCESSFUL, f"Expired {amount} Notifications")
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										103
									
								
								authentik/events/tests/test_tasks.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										103
									
								
								authentik/events/tests/test_tasks.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,103 @@
 | 
			
		||||
"""Test Monitored tasks"""
 | 
			
		||||
 | 
			
		||||
from json import loads
 | 
			
		||||
 | 
			
		||||
from django.urls import reverse
 | 
			
		||||
from rest_framework.test import APITestCase
 | 
			
		||||
 | 
			
		||||
from authentik.core.tasks import clean_expired_models
 | 
			
		||||
from authentik.core.tests.utils import create_test_admin_user
 | 
			
		||||
from authentik.events.models import SystemTask as DBSystemTask
 | 
			
		||||
from authentik.events.models import TaskStatus
 | 
			
		||||
from authentik.events.system_tasks import SystemTask
 | 
			
		||||
from authentik.lib.generators import generate_id
 | 
			
		||||
from authentik.root.celery import CELERY_APP
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestSystemTasks(APITestCase):
 | 
			
		||||
    """Test Monitored tasks"""
 | 
			
		||||
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        super().setUp()
 | 
			
		||||
        self.user = create_test_admin_user()
 | 
			
		||||
        self.client.force_login(self.user)
 | 
			
		||||
 | 
			
		||||
    def test_failed_successful_remove_state(self):
 | 
			
		||||
        """Test that a task with `save_on_success` set to `False` that failed saves
 | 
			
		||||
        a state, and upon successful completion will delete the state"""
 | 
			
		||||
        should_fail = True
 | 
			
		||||
        uid = generate_id()
 | 
			
		||||
 | 
			
		||||
        @CELERY_APP.task(
 | 
			
		||||
            bind=True,
 | 
			
		||||
            base=SystemTask,
 | 
			
		||||
        )
 | 
			
		||||
        def test_task(self: SystemTask):
 | 
			
		||||
            self.save_on_success = False
 | 
			
		||||
            self.set_uid(uid)
 | 
			
		||||
            self.set_status(TaskStatus.ERROR if should_fail else TaskStatus.SUCCESSFUL)
 | 
			
		||||
 | 
			
		||||
        # First test successful run
 | 
			
		||||
        should_fail = False
 | 
			
		||||
        test_task.delay().get()
 | 
			
		||||
        self.assertIsNone(DBSystemTask.objects.filter(name="test_task", uid=uid).first())
 | 
			
		||||
 | 
			
		||||
        # Then test failed
 | 
			
		||||
        should_fail = True
 | 
			
		||||
        test_task.delay().get()
 | 
			
		||||
        task = DBSystemTask.objects.filter(name="test_task", uid=uid).first()
 | 
			
		||||
        self.assertEqual(task.status, TaskStatus.ERROR)
 | 
			
		||||
 | 
			
		||||
        # Then after that, the state should be removed
 | 
			
		||||
        should_fail = False
 | 
			
		||||
        test_task.delay().get()
 | 
			
		||||
        self.assertIsNone(DBSystemTask.objects.filter(name="test_task", uid=uid).first())
 | 
			
		||||
 | 
			
		||||
    def test_tasks(self):
 | 
			
		||||
        """Test Task API"""
 | 
			
		||||
        clean_expired_models.delay().get()
 | 
			
		||||
        response = self.client.get(reverse("authentik_api:systemtask-list"))
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        body = loads(response.content)
 | 
			
		||||
        self.assertTrue(any(task["name"] == "clean_expired_models" for task in body["results"]))
 | 
			
		||||
 | 
			
		||||
    def test_tasks_single(self):
 | 
			
		||||
        """Test Task API (read single)"""
 | 
			
		||||
        clean_expired_models.delay().get()
 | 
			
		||||
        task = DBSystemTask.objects.filter(name="clean_expired_models").first()
 | 
			
		||||
        response = self.client.get(
 | 
			
		||||
            reverse(
 | 
			
		||||
                "authentik_api:systemtask-detail",
 | 
			
		||||
                kwargs={"pk": str(task.pk)},
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        body = loads(response.content)
 | 
			
		||||
        self.assertEqual(body["status"], TaskStatus.SUCCESSFUL.value)
 | 
			
		||||
        self.assertEqual(body["name"], "clean_expired_models")
 | 
			
		||||
        response = self.client.get(
 | 
			
		||||
            reverse("authentik_api:systemtask-detail", kwargs={"pk": "qwerqwer"})
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(response.status_code, 404)
 | 
			
		||||
 | 
			
		||||
    def test_tasks_run(self):
 | 
			
		||||
        """Test Task API (run)"""
 | 
			
		||||
        clean_expired_models.delay().get()
 | 
			
		||||
        task = DBSystemTask.objects.filter(name="clean_expired_models").first()
 | 
			
		||||
        response = self.client.post(
 | 
			
		||||
            reverse(
 | 
			
		||||
                "authentik_api:systemtask-run",
 | 
			
		||||
                kwargs={"pk": str(task.pk)},
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(response.status_code, 204)
 | 
			
		||||
 | 
			
		||||
    def test_tasks_run_404(self):
 | 
			
		||||
        """Test Task API (run, 404)"""
 | 
			
		||||
        response = self.client.post(
 | 
			
		||||
            reverse(
 | 
			
		||||
                "authentik_api:systemtask-run",
 | 
			
		||||
                kwargs={"pk": "qwerqewrqrqewrqewr"},
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(response.status_code, 404)
 | 
			
		||||
@ -5,11 +5,13 @@ from authentik.events.api.notification_mappings import NotificationWebhookMappin
 | 
			
		||||
from authentik.events.api.notification_rules import NotificationRuleViewSet
 | 
			
		||||
from authentik.events.api.notification_transports import NotificationTransportViewSet
 | 
			
		||||
from authentik.events.api.notifications import NotificationViewSet
 | 
			
		||||
from authentik.events.api.tasks import SystemTaskViewSet
 | 
			
		||||
 | 
			
		||||
api_urlpatterns = [
 | 
			
		||||
    ("events/events", EventViewSet),
 | 
			
		||||
    ("events/notifications", NotificationViewSet),
 | 
			
		||||
    ("events/transports", NotificationTransportViewSet),
 | 
			
		||||
    ("events/rules", NotificationRuleViewSet),
 | 
			
		||||
    ("events/system_tasks", SystemTaskViewSet),
 | 
			
		||||
    ("propertymappings/notification", NotificationWebhookMappingViewSet),
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
@ -41,7 +41,6 @@ REDIS_ENV_KEYS = [
 | 
			
		||||
# Old key -> new key
 | 
			
		||||
DEPRECATIONS = {
 | 
			
		||||
    "geoip": "events.context_processors.geoip",
 | 
			
		||||
    "worker.concurrency": "worker.processes",
 | 
			
		||||
    "redis.broker_url": "broker.url",
 | 
			
		||||
    "redis.broker_transport_options": "broker.transport_options",
 | 
			
		||||
    "redis.cache_timeout": "cache.timeout",
 | 
			
		||||
 | 
			
		||||
@ -21,10 +21,6 @@ def start_debug_server(**kwargs) -> bool:
 | 
			
		||||
 | 
			
		||||
    listen: str = CONFIG.get("listen.listen_debug_py", "127.0.0.1:9901")
 | 
			
		||||
    host, _, port = listen.rpartition(":")
 | 
			
		||||
    try:
 | 
			
		||||
        debugpy.listen((host, int(port)), **kwargs)  # nosec
 | 
			
		||||
    except RuntimeError:
 | 
			
		||||
        LOGGER.warning("Could not start debug server. Continuing without")
 | 
			
		||||
        return False
 | 
			
		||||
    debugpy.listen((host, int(port)), **kwargs)  # nosec
 | 
			
		||||
    LOGGER.debug("Starting debug server", host=host, port=port)
 | 
			
		||||
    return True
 | 
			
		||||
 | 
			
		||||
@ -8,9 +8,9 @@
 | 
			
		||||
# make gen-dev-config
 | 
			
		||||
# ```
 | 
			
		||||
#
 | 
			
		||||
# You may edit the generated file to override the configuration below.
 | 
			
		||||
# You may edit the generated file to override the configuration below.  
 | 
			
		||||
#
 | 
			
		||||
# When making modifying the default configuration file,
 | 
			
		||||
# When making modifying the default configuration file, 
 | 
			
		||||
# ensure that the corresponding documentation is updated to match.
 | 
			
		||||
#
 | 
			
		||||
# @see {@link ../../website/docs/install-config/configuration/configuration.mdx Configuration documentation} for more information.
 | 
			
		||||
@ -157,14 +157,7 @@ web:
 | 
			
		||||
  path: /
 | 
			
		||||
 | 
			
		||||
worker:
 | 
			
		||||
  processes: 2
 | 
			
		||||
  threads: 1
 | 
			
		||||
  consumer_listen_timeout: "seconds=30"
 | 
			
		||||
  task_max_retries: 20
 | 
			
		||||
  task_default_time_limit: "minutes=10"
 | 
			
		||||
  task_purge_interval: "days=1"
 | 
			
		||||
  task_expiration: "days=30"
 | 
			
		||||
  scheduler_interval: "seconds=60"
 | 
			
		||||
  concurrency: 2
 | 
			
		||||
 | 
			
		||||
storage:
 | 
			
		||||
  media:
 | 
			
		||||
 | 
			
		||||
@ -88,6 +88,7 @@ def get_logger_config():
 | 
			
		||||
        "authentik": global_level,
 | 
			
		||||
        "django": "WARNING",
 | 
			
		||||
        "django.request": "ERROR",
 | 
			
		||||
        "celery": "WARNING",
 | 
			
		||||
        "selenium": "WARNING",
 | 
			
		||||
        "docker": "WARNING",
 | 
			
		||||
        "urllib3": "WARNING",
 | 
			
		||||
 | 
			
		||||
@ -3,6 +3,8 @@
 | 
			
		||||
from asyncio.exceptions import CancelledError
 | 
			
		||||
from typing import Any
 | 
			
		||||
 | 
			
		||||
from billiard.exceptions import SoftTimeLimitExceeded, WorkerLostError
 | 
			
		||||
from celery.exceptions import CeleryError
 | 
			
		||||
from channels_redis.core import ChannelFull
 | 
			
		||||
from django.conf import settings
 | 
			
		||||
from django.core.exceptions import ImproperlyConfigured, SuspiciousOperation, ValidationError
 | 
			
		||||
@ -20,6 +22,7 @@ from sentry_sdk import HttpTransport, get_current_scope
 | 
			
		||||
from sentry_sdk import init as sentry_sdk_init
 | 
			
		||||
from sentry_sdk.api import set_tag
 | 
			
		||||
from sentry_sdk.integrations.argv import ArgvIntegration
 | 
			
		||||
from sentry_sdk.integrations.celery import CeleryIntegration
 | 
			
		||||
from sentry_sdk.integrations.django import DjangoIntegration
 | 
			
		||||
from sentry_sdk.integrations.redis import RedisIntegration
 | 
			
		||||
from sentry_sdk.integrations.socket import SocketIntegration
 | 
			
		||||
@ -68,6 +71,10 @@ ignored_classes = (
 | 
			
		||||
    LocalProtocolError,
 | 
			
		||||
    # rest_framework error
 | 
			
		||||
    APIException,
 | 
			
		||||
    # celery errors
 | 
			
		||||
    WorkerLostError,
 | 
			
		||||
    CeleryError,
 | 
			
		||||
    SoftTimeLimitExceeded,
 | 
			
		||||
    # custom baseclass
 | 
			
		||||
    SentryIgnoredException,
 | 
			
		||||
    # ldap errors
 | 
			
		||||
@ -108,6 +115,7 @@ def sentry_init(**sentry_init_kwargs):
 | 
			
		||||
            ArgvIntegration(),
 | 
			
		||||
            StdlibIntegration(),
 | 
			
		||||
            DjangoIntegration(transaction_style="function_name", cache_spans=True),
 | 
			
		||||
            CeleryIntegration(),
 | 
			
		||||
            RedisIntegration(),
 | 
			
		||||
            ThreadingIntegration(propagate_hub=True),
 | 
			
		||||
            SocketIntegration(),
 | 
			
		||||
@ -152,11 +160,14 @@ def before_send(event: dict, hint: dict) -> dict | None:
 | 
			
		||||
            return None
 | 
			
		||||
    if "logger" in event:
 | 
			
		||||
        if event["logger"] in [
 | 
			
		||||
            "kombu",
 | 
			
		||||
            "asyncio",
 | 
			
		||||
            "multiprocessing",
 | 
			
		||||
            "django_redis",
 | 
			
		||||
            "django.security.DisallowedHost",
 | 
			
		||||
            "django_redis.cache",
 | 
			
		||||
            "celery.backends.redis",
 | 
			
		||||
            "celery.worker",
 | 
			
		||||
            "paramiko.transport",
 | 
			
		||||
        ]:
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
@ -1,12 +0,0 @@
 | 
			
		||||
from rest_framework.fields import BooleanField, ChoiceField, DateTimeField
 | 
			
		||||
 | 
			
		||||
from authentik.core.api.utils import PassiveSerializer
 | 
			
		||||
from authentik.tasks.models import TaskStatus
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SyncStatusSerializer(PassiveSerializer):
 | 
			
		||||
    """Provider/source sync status"""
 | 
			
		||||
 | 
			
		||||
    is_running = BooleanField()
 | 
			
		||||
    last_successful_sync = DateTimeField(required=False)
 | 
			
		||||
    last_sync_status = ChoiceField(required=False, choices=TaskStatus.choices)
 | 
			
		||||
@ -1,7 +1,7 @@
 | 
			
		||||
"""Sync constants"""
 | 
			
		||||
 | 
			
		||||
PAGE_SIZE = 100
 | 
			
		||||
PAGE_TIMEOUT_MS = 60 * 60 * 0.5 * 1000  # Half an hour
 | 
			
		||||
PAGE_TIMEOUT = 60 * 60 * 0.5  # Half an hour
 | 
			
		||||
HTTP_CONFLICT = 409
 | 
			
		||||
HTTP_NO_CONTENT = 204
 | 
			
		||||
HTTP_SERVICE_UNAVAILABLE = 503
 | 
			
		||||
 | 
			
		||||
@ -1,5 +1,7 @@
 | 
			
		||||
from dramatiq.actor import Actor
 | 
			
		||||
from drf_spectacular.utils import extend_schema
 | 
			
		||||
from celery import Task
 | 
			
		||||
from django.utils.text import slugify
 | 
			
		||||
from drf_spectacular.utils import OpenApiResponse, extend_schema
 | 
			
		||||
from guardian.shortcuts import get_objects_for_user
 | 
			
		||||
from rest_framework.decorators import action
 | 
			
		||||
from rest_framework.fields import BooleanField, CharField, ChoiceField
 | 
			
		||||
from rest_framework.request import Request
 | 
			
		||||
@ -7,12 +9,18 @@ from rest_framework.response import Response
 | 
			
		||||
 | 
			
		||||
from authentik.core.api.utils import ModelSerializer, PassiveSerializer
 | 
			
		||||
from authentik.core.models import Group, User
 | 
			
		||||
from authentik.events.logs import LogEventSerializer
 | 
			
		||||
from authentik.lib.sync.api import SyncStatusSerializer
 | 
			
		||||
from authentik.events.api.tasks import SystemTaskSerializer
 | 
			
		||||
from authentik.events.logs import LogEvent, LogEventSerializer
 | 
			
		||||
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
 | 
			
		||||
from authentik.lib.utils.reflection import class_to_path
 | 
			
		||||
from authentik.rbac.filters import ObjectFilter
 | 
			
		||||
from authentik.tasks.models import Task, TaskStatus
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SyncStatusSerializer(PassiveSerializer):
 | 
			
		||||
    """Provider sync status"""
 | 
			
		||||
 | 
			
		||||
    is_running = BooleanField(read_only=True)
 | 
			
		||||
    tasks = SystemTaskSerializer(many=True, read_only=True)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SyncObjectSerializer(PassiveSerializer):
 | 
			
		||||
@ -37,10 +45,15 @@ class SyncObjectResultSerializer(PassiveSerializer):
 | 
			
		||||
class OutgoingSyncProviderStatusMixin:
 | 
			
		||||
    """Common API Endpoints for Outgoing sync providers"""
 | 
			
		||||
 | 
			
		||||
    sync_task: Actor
 | 
			
		||||
    sync_objects_task: Actor
 | 
			
		||||
    sync_single_task: type[Task] = None
 | 
			
		||||
    sync_objects_task: type[Task] = None
 | 
			
		||||
 | 
			
		||||
    @extend_schema(responses={200: SyncStatusSerializer()})
 | 
			
		||||
    @extend_schema(
 | 
			
		||||
        responses={
 | 
			
		||||
            200: SyncStatusSerializer(),
 | 
			
		||||
            404: OpenApiResponse(description="Task not found"),
 | 
			
		||||
        }
 | 
			
		||||
    )
 | 
			
		||||
    @action(
 | 
			
		||||
        methods=["GET"],
 | 
			
		||||
        detail=True,
 | 
			
		||||
@ -51,39 +64,18 @@ class OutgoingSyncProviderStatusMixin:
 | 
			
		||||
    def sync_status(self, request: Request, pk: int) -> Response:
 | 
			
		||||
        """Get provider's sync status"""
 | 
			
		||||
        provider: OutgoingSyncProvider = self.get_object()
 | 
			
		||||
 | 
			
		||||
        status = {}
 | 
			
		||||
 | 
			
		||||
        with provider.sync_lock as lock_acquired:
 | 
			
		||||
            # If we could not acquire the lock, it means a task is using it, and thus is running
 | 
			
		||||
            status["is_running"] = not lock_acquired
 | 
			
		||||
 | 
			
		||||
        sync_schedule = None
 | 
			
		||||
        for schedule in provider.schedules.all():
 | 
			
		||||
            if schedule.actor_name == self.sync_task.actor_name:
 | 
			
		||||
                sync_schedule = schedule
 | 
			
		||||
 | 
			
		||||
        if not sync_schedule:
 | 
			
		||||
            return Response(SyncStatusSerializer(status).data)
 | 
			
		||||
 | 
			
		||||
        last_task: Task = (
 | 
			
		||||
            sync_schedule.tasks.exclude(
 | 
			
		||||
                aggregated_status__in=(TaskStatus.CONSUMED, TaskStatus.QUEUED)
 | 
			
		||||
        tasks = list(
 | 
			
		||||
            get_objects_for_user(request.user, "authentik_events.view_systemtask").filter(
 | 
			
		||||
                name=self.sync_single_task.__name__,
 | 
			
		||||
                uid=slugify(provider.name),
 | 
			
		||||
            )
 | 
			
		||||
            .order_by("-mtime")
 | 
			
		||||
            .first()
 | 
			
		||||
        )
 | 
			
		||||
        last_successful_task: Task = (
 | 
			
		||||
            sync_schedule.tasks.filter(aggregated_status__in=(TaskStatus.DONE, TaskStatus.INFO))
 | 
			
		||||
            .order_by("-mtime")
 | 
			
		||||
            .first()
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if last_task:
 | 
			
		||||
            status["last_sync_status"] = last_task.aggregated_status
 | 
			
		||||
        if last_successful_task:
 | 
			
		||||
            status["last_successful_sync"] = last_successful_task.mtime
 | 
			
		||||
 | 
			
		||||
        with provider.sync_lock as lock_acquired:
 | 
			
		||||
            status = {
 | 
			
		||||
                "tasks": tasks,
 | 
			
		||||
                # If we could not acquire the lock, it means a task is using it, and thus is running
 | 
			
		||||
                "is_running": not lock_acquired,
 | 
			
		||||
            }
 | 
			
		||||
        return Response(SyncStatusSerializer(status).data)
 | 
			
		||||
 | 
			
		||||
    @extend_schema(
 | 
			
		||||
@ -102,20 +94,14 @@ class OutgoingSyncProviderStatusMixin:
 | 
			
		||||
        provider: OutgoingSyncProvider = self.get_object()
 | 
			
		||||
        params = SyncObjectSerializer(data=request.data)
 | 
			
		||||
        params.is_valid(raise_exception=True)
 | 
			
		||||
        msg = self.sync_objects_task.send_with_options(
 | 
			
		||||
            kwargs={
 | 
			
		||||
                "object_type": params.validated_data["sync_object_model"],
 | 
			
		||||
                "page": 1,
 | 
			
		||||
                "provider_pk": provider.pk,
 | 
			
		||||
                "override_dry_run": params.validated_data["override_dry_run"],
 | 
			
		||||
                "pk": params.validated_data["sync_object_id"],
 | 
			
		||||
            },
 | 
			
		||||
            rel_obj=provider,
 | 
			
		||||
        )
 | 
			
		||||
        msg.get_result(block=True)
 | 
			
		||||
        task: Task = msg.options["task"]
 | 
			
		||||
        task.refresh_from_db()
 | 
			
		||||
        return Response(SyncObjectResultSerializer(instance={"messages": task._messages}).data)
 | 
			
		||||
        res: list[LogEvent] = self.sync_objects_task.delay(
 | 
			
		||||
            params.validated_data["sync_object_model"],
 | 
			
		||||
            page=1,
 | 
			
		||||
            provider_pk=provider.pk,
 | 
			
		||||
            pk=params.validated_data["sync_object_id"],
 | 
			
		||||
            override_dry_run=params.validated_data["override_dry_run"],
 | 
			
		||||
        ).get()
 | 
			
		||||
        return Response(SyncObjectResultSerializer(instance={"messages": res}).data)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class OutgoingSyncConnectionCreateMixin:
 | 
			
		||||
 | 
			
		||||
@ -1,18 +1,12 @@
 | 
			
		||||
from typing import Any, Self
 | 
			
		||||
 | 
			
		||||
import pglock
 | 
			
		||||
from django.core.paginator import Paginator
 | 
			
		||||
from django.db import connection, models
 | 
			
		||||
from django.db.models import Model, QuerySet, TextChoices
 | 
			
		||||
from django.utils.translation import gettext_lazy as _
 | 
			
		||||
from dramatiq.actor import Actor
 | 
			
		||||
 | 
			
		||||
from authentik.core.models import Group, User
 | 
			
		||||
from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT_MS
 | 
			
		||||
from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient
 | 
			
		||||
from authentik.lib.utils.time import fqdn_rand
 | 
			
		||||
from authentik.tasks.schedules.lib import ScheduleSpec
 | 
			
		||||
from authentik.tasks.schedules.models import ScheduledModel
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class OutgoingSyncDeleteAction(TextChoices):
 | 
			
		||||
@ -24,7 +18,7 @@ class OutgoingSyncDeleteAction(TextChoices):
 | 
			
		||||
    SUSPEND = "suspend"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class OutgoingSyncProvider(ScheduledModel, Model):
 | 
			
		||||
class OutgoingSyncProvider(Model):
 | 
			
		||||
    """Base abstract models for providers implementing outgoing sync"""
 | 
			
		||||
 | 
			
		||||
    dry_run = models.BooleanField(
 | 
			
		||||
@ -45,19 +39,6 @@ class OutgoingSyncProvider(ScheduledModel, Model):
 | 
			
		||||
    def get_object_qs[T: User | Group](self, type: type[T]) -> QuerySet[T]:
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
    def get_paginator[T: User | Group](self, type: type[T]) -> Paginator:
 | 
			
		||||
        return Paginator(self.get_object_qs(type), PAGE_SIZE)
 | 
			
		||||
 | 
			
		||||
    def get_object_sync_time_limit_ms[T: User | Group](self, type: type[T]) -> int:
 | 
			
		||||
        num_pages: int = self.get_paginator(type).num_pages
 | 
			
		||||
        return int(num_pages * PAGE_TIMEOUT_MS * 1.5)
 | 
			
		||||
 | 
			
		||||
    def get_sync_time_limit_ms(self) -> int:
 | 
			
		||||
        return int(
 | 
			
		||||
            (self.get_object_sync_time_limit_ms(User) + self.get_object_sync_time_limit_ms(Group))
 | 
			
		||||
            * 1.5
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def sync_lock(self) -> pglock.advisory:
 | 
			
		||||
        """Postgres lock for syncing to prevent multiple parallel syncs happening"""
 | 
			
		||||
@ -66,22 +47,3 @@ class OutgoingSyncProvider(ScheduledModel, Model):
 | 
			
		||||
            timeout=0,
 | 
			
		||||
            side_effect=pglock.Return,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def sync_actor(self) -> Actor:
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def schedule_specs(self) -> list[ScheduleSpec]:
 | 
			
		||||
        return [
 | 
			
		||||
            ScheduleSpec(
 | 
			
		||||
                actor=self.sync_actor,
 | 
			
		||||
                uid=self.pk,
 | 
			
		||||
                args=(self.pk,),
 | 
			
		||||
                options={
 | 
			
		||||
                    "time_limit": self.get_sync_time_limit_ms(),
 | 
			
		||||
                },
 | 
			
		||||
                send_on_save=True,
 | 
			
		||||
                crontab=f"{fqdn_rand(self.pk)} */4 * * *",
 | 
			
		||||
            ),
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
@ -1,8 +1,12 @@
 | 
			
		||||
from collections.abc import Callable
 | 
			
		||||
 | 
			
		||||
from django.core.paginator import Paginator
 | 
			
		||||
from django.db.models import Model
 | 
			
		||||
from django.db.models.query import Q
 | 
			
		||||
from django.db.models.signals import m2m_changed, post_save, pre_delete
 | 
			
		||||
from dramatiq.actor import Actor
 | 
			
		||||
 | 
			
		||||
from authentik.core.models import Group, User
 | 
			
		||||
from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT
 | 
			
		||||
from authentik.lib.sync.outgoing.base import Direction
 | 
			
		||||
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
 | 
			
		||||
from authentik.lib.utils.reflection import class_to_path
 | 
			
		||||
@ -10,30 +14,45 @@ from authentik.lib.utils.reflection import class_to_path
 | 
			
		||||
 | 
			
		||||
def register_signals(
 | 
			
		||||
    provider_type: type[OutgoingSyncProvider],
 | 
			
		||||
    task_sync_direct_dispatch: Actor[[str, str | int, str], None],
 | 
			
		||||
    task_sync_m2m_dispatch: Actor[[str, str, list[str], bool], None],
 | 
			
		||||
    task_sync_single: Callable[[int], None],
 | 
			
		||||
    task_sync_direct: Callable[[int], None],
 | 
			
		||||
    task_sync_m2m: Callable[[int], None],
 | 
			
		||||
):
 | 
			
		||||
    """Register sync signals"""
 | 
			
		||||
    uid = class_to_path(provider_type)
 | 
			
		||||
 | 
			
		||||
    def post_save_provider(sender: type[Model], instance: OutgoingSyncProvider, created: bool, **_):
 | 
			
		||||
        """Trigger sync when Provider is saved"""
 | 
			
		||||
        users_paginator = Paginator(instance.get_object_qs(User), PAGE_SIZE)
 | 
			
		||||
        groups_paginator = Paginator(instance.get_object_qs(Group), PAGE_SIZE)
 | 
			
		||||
        soft_time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT
 | 
			
		||||
        time_limit = soft_time_limit * 1.5
 | 
			
		||||
        task_sync_single.apply_async(
 | 
			
		||||
            (instance.pk,), time_limit=int(time_limit), soft_time_limit=int(soft_time_limit)
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    post_save.connect(post_save_provider, provider_type, dispatch_uid=uid, weak=False)
 | 
			
		||||
 | 
			
		||||
    def model_post_save(sender: type[Model], instance: User | Group, created: bool, **_):
 | 
			
		||||
        """Post save handler"""
 | 
			
		||||
        task_sync_direct_dispatch.send(
 | 
			
		||||
            class_to_path(instance.__class__),
 | 
			
		||||
            instance.pk,
 | 
			
		||||
            Direction.add.value,
 | 
			
		||||
        )
 | 
			
		||||
        if not provider_type.objects.filter(
 | 
			
		||||
            Q(backchannel_application__isnull=False) | Q(application__isnull=False)
 | 
			
		||||
        ).exists():
 | 
			
		||||
            return
 | 
			
		||||
        task_sync_direct.delay(class_to_path(instance.__class__), instance.pk, Direction.add.value)
 | 
			
		||||
 | 
			
		||||
    post_save.connect(model_post_save, User, dispatch_uid=uid, weak=False)
 | 
			
		||||
    post_save.connect(model_post_save, Group, dispatch_uid=uid, weak=False)
 | 
			
		||||
 | 
			
		||||
    def model_pre_delete(sender: type[Model], instance: User | Group, **_):
 | 
			
		||||
        """Pre-delete handler"""
 | 
			
		||||
        task_sync_direct_dispatch.send(
 | 
			
		||||
            class_to_path(instance.__class__),
 | 
			
		||||
            instance.pk,
 | 
			
		||||
            Direction.remove.value,
 | 
			
		||||
        )
 | 
			
		||||
        if not provider_type.objects.filter(
 | 
			
		||||
            Q(backchannel_application__isnull=False) | Q(application__isnull=False)
 | 
			
		||||
        ).exists():
 | 
			
		||||
            return
 | 
			
		||||
        task_sync_direct.delay(
 | 
			
		||||
            class_to_path(instance.__class__), instance.pk, Direction.remove.value
 | 
			
		||||
        ).get(propagate=False)
 | 
			
		||||
 | 
			
		||||
    pre_delete.connect(model_pre_delete, User, dispatch_uid=uid, weak=False)
 | 
			
		||||
    pre_delete.connect(model_pre_delete, Group, dispatch_uid=uid, weak=False)
 | 
			
		||||
@ -44,6 +63,16 @@ def register_signals(
 | 
			
		||||
        """Sync group membership"""
 | 
			
		||||
        if action not in ["post_add", "post_remove"]:
 | 
			
		||||
            return
 | 
			
		||||
        task_sync_m2m_dispatch.send(instance.pk, action, list(pk_set), reverse)
 | 
			
		||||
        if not provider_type.objects.filter(
 | 
			
		||||
            Q(backchannel_application__isnull=False) | Q(application__isnull=False)
 | 
			
		||||
        ).exists():
 | 
			
		||||
            return
 | 
			
		||||
        # reverse: instance is a Group, pk_set is a list of user pks
 | 
			
		||||
        # non-reverse: instance is a User, pk_set is a list of groups
 | 
			
		||||
        if reverse:
 | 
			
		||||
            task_sync_m2m.delay(str(instance.pk), action, list(pk_set))
 | 
			
		||||
        else:
 | 
			
		||||
            for group_pk in pk_set:
 | 
			
		||||
                task_sync_m2m.delay(group_pk, action, [instance.pk])
 | 
			
		||||
 | 
			
		||||
    m2m_changed.connect(model_m2m_changed, User.ak_groups.through, dispatch_uid=uid, weak=False)
 | 
			
		||||
 | 
			
		||||
@ -1,17 +1,23 @@
 | 
			
		||||
from collections.abc import Callable
 | 
			
		||||
from dataclasses import asdict
 | 
			
		||||
 | 
			
		||||
from celery import group
 | 
			
		||||
from celery.exceptions import Retry
 | 
			
		||||
from celery.result import allow_join_result
 | 
			
		||||
from django.core.paginator import Paginator
 | 
			
		||||
from django.db.models import Model, QuerySet
 | 
			
		||||
from django.db.models.query import Q
 | 
			
		||||
from django.utils.text import slugify
 | 
			
		||||
from django_dramatiq_postgres.middleware import CurrentTask
 | 
			
		||||
from dramatiq.actor import Actor
 | 
			
		||||
from dramatiq.composition import group
 | 
			
		||||
from dramatiq.errors import Retry
 | 
			
		||||
from django.utils.translation import gettext_lazy as _
 | 
			
		||||
from structlog.stdlib import BoundLogger, get_logger
 | 
			
		||||
 | 
			
		||||
from authentik.core.expression.exceptions import SkipObjectException
 | 
			
		||||
from authentik.core.models import Group, User
 | 
			
		||||
from authentik.events.logs import LogEvent
 | 
			
		||||
from authentik.events.models import TaskStatus
 | 
			
		||||
from authentik.events.system_tasks import SystemTask
 | 
			
		||||
from authentik.events.utils import sanitize_item
 | 
			
		||||
from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT_MS
 | 
			
		||||
from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT
 | 
			
		||||
from authentik.lib.sync.outgoing.base import Direction
 | 
			
		||||
from authentik.lib.sync.outgoing.exceptions import (
 | 
			
		||||
    BadRequestSyncException,
 | 
			
		||||
@ -21,12 +27,11 @@ from authentik.lib.sync.outgoing.exceptions import (
 | 
			
		||||
)
 | 
			
		||||
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
 | 
			
		||||
from authentik.lib.utils.reflection import class_to_path, path_to_class
 | 
			
		||||
from authentik.tasks.models import Task
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SyncTasks:
 | 
			
		||||
    """Container for all sync 'tasks' (this class doesn't actually contain
 | 
			
		||||
    tasks due to dramatiq's magic, however exposes a number of functions to be called from tasks)"""
 | 
			
		||||
    """Container for all sync 'tasks' (this class doesn't actually contain celery
 | 
			
		||||
    tasks due to celery's magic, however exposes a number of functions to be called from tasks)"""
 | 
			
		||||
 | 
			
		||||
    logger: BoundLogger
 | 
			
		||||
 | 
			
		||||
@ -34,104 +39,107 @@ class SyncTasks:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self._provider_model = provider_model
 | 
			
		||||
 | 
			
		||||
    def sync_paginator(
 | 
			
		||||
        self,
 | 
			
		||||
        current_task: Task,
 | 
			
		||||
        provider: OutgoingSyncProvider,
 | 
			
		||||
        sync_objects: Actor[[str, int, int, bool], None],
 | 
			
		||||
        paginator: Paginator,
 | 
			
		||||
        object_type: type[User | Group],
 | 
			
		||||
        **options,
 | 
			
		||||
    ):
 | 
			
		||||
        tasks = []
 | 
			
		||||
        for page in paginator.page_range:
 | 
			
		||||
            page_sync = sync_objects.message_with_options(
 | 
			
		||||
                args=(class_to_path(object_type), page, provider.pk),
 | 
			
		||||
                time_limit=PAGE_TIMEOUT_MS,
 | 
			
		||||
                # Assign tasks to the same schedule as the current one
 | 
			
		||||
                rel_obj=current_task.rel_obj,
 | 
			
		||||
                **options,
 | 
			
		||||
            )
 | 
			
		||||
            tasks.append(page_sync)
 | 
			
		||||
        return tasks
 | 
			
		||||
    def sync_all(self, single_sync: Callable[[int], None]):
 | 
			
		||||
        for provider in self._provider_model.objects.filter(
 | 
			
		||||
            Q(backchannel_application__isnull=False) | Q(application__isnull=False)
 | 
			
		||||
        ):
 | 
			
		||||
            self.trigger_single_task(provider, single_sync)
 | 
			
		||||
 | 
			
		||||
    def sync(
 | 
			
		||||
    def trigger_single_task(self, provider: OutgoingSyncProvider, sync_task: Callable[[int], None]):
 | 
			
		||||
        """Wrapper single sync task that correctly sets time limits based
 | 
			
		||||
        on the amount of objects that will be synced"""
 | 
			
		||||
        users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE)
 | 
			
		||||
        groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE)
 | 
			
		||||
        soft_time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT
 | 
			
		||||
        time_limit = soft_time_limit * 1.5
 | 
			
		||||
        return sync_task.apply_async(
 | 
			
		||||
            (provider.pk,), time_limit=int(time_limit), soft_time_limit=int(soft_time_limit)
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def sync_single(
 | 
			
		||||
        self,
 | 
			
		||||
        task: SystemTask,
 | 
			
		||||
        provider_pk: int,
 | 
			
		||||
        sync_objects: Actor[[str, int, int, bool], None],
 | 
			
		||||
        sync_objects: Callable[[int, int], list[str]],
 | 
			
		||||
    ):
 | 
			
		||||
        task: Task = CurrentTask.get_task()
 | 
			
		||||
        self.logger = get_logger().bind(
 | 
			
		||||
            provider_type=class_to_path(self._provider_model),
 | 
			
		||||
            provider_pk=provider_pk,
 | 
			
		||||
        )
 | 
			
		||||
        provider: OutgoingSyncProvider = self._provider_model.objects.filter(
 | 
			
		||||
        provider = self._provider_model.objects.filter(
 | 
			
		||||
            Q(backchannel_application__isnull=False) | Q(application__isnull=False),
 | 
			
		||||
            pk=provider_pk,
 | 
			
		||||
        ).first()
 | 
			
		||||
        if not provider:
 | 
			
		||||
            task.warning("No provider found. Is it assigned to an application?")
 | 
			
		||||
            return
 | 
			
		||||
        task.set_uid(slugify(provider.name))
 | 
			
		||||
        task.info("Starting full provider sync")
 | 
			
		||||
        messages = []
 | 
			
		||||
        messages.append(_("Starting full provider sync"))
 | 
			
		||||
        self.logger.debug("Starting provider sync")
 | 
			
		||||
        with provider.sync_lock as lock_acquired:
 | 
			
		||||
        users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE)
 | 
			
		||||
        groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE)
 | 
			
		||||
        with allow_join_result(), provider.sync_lock as lock_acquired:
 | 
			
		||||
            if not lock_acquired:
 | 
			
		||||
                task.info("Synchronization is already running. Skipping.")
 | 
			
		||||
                self.logger.debug("Failed to acquire sync lock, skipping", provider=provider.name)
 | 
			
		||||
                return
 | 
			
		||||
            try:
 | 
			
		||||
                users_tasks = group(
 | 
			
		||||
                    self.sync_paginator(
 | 
			
		||||
                        current_task=task,
 | 
			
		||||
                        provider=provider,
 | 
			
		||||
                        sync_objects=sync_objects,
 | 
			
		||||
                        paginator=provider.get_paginator(User),
 | 
			
		||||
                        object_type=User,
 | 
			
		||||
                messages.append(_("Syncing users"))
 | 
			
		||||
                user_results = (
 | 
			
		||||
                    group(
 | 
			
		||||
                        [
 | 
			
		||||
                            sync_objects.signature(
 | 
			
		||||
                                args=(class_to_path(User), page, provider_pk),
 | 
			
		||||
                                time_limit=PAGE_TIMEOUT,
 | 
			
		||||
                                soft_time_limit=PAGE_TIMEOUT,
 | 
			
		||||
                            )
 | 
			
		||||
                            for page in users_paginator.page_range
 | 
			
		||||
                        ]
 | 
			
		||||
                    )
 | 
			
		||||
                    .apply_async()
 | 
			
		||||
                    .get()
 | 
			
		||||
                )
 | 
			
		||||
                group_tasks = group(
 | 
			
		||||
                    self.sync_paginator(
 | 
			
		||||
                        current_task=task,
 | 
			
		||||
                        provider=provider,
 | 
			
		||||
                        sync_objects=sync_objects,
 | 
			
		||||
                        paginator=provider.get_paginator(Group),
 | 
			
		||||
                        object_type=Group,
 | 
			
		||||
                for result in user_results:
 | 
			
		||||
                    for msg in result:
 | 
			
		||||
                        messages.append(LogEvent(**msg))
 | 
			
		||||
                messages.append(_("Syncing groups"))
 | 
			
		||||
                group_results = (
 | 
			
		||||
                    group(
 | 
			
		||||
                        [
 | 
			
		||||
                            sync_objects.signature(
 | 
			
		||||
                                args=(class_to_path(Group), page, provider_pk),
 | 
			
		||||
                                time_limit=PAGE_TIMEOUT,
 | 
			
		||||
                                soft_time_limit=PAGE_TIMEOUT,
 | 
			
		||||
                            )
 | 
			
		||||
                            for page in groups_paginator.page_range
 | 
			
		||||
                        ]
 | 
			
		||||
                    )
 | 
			
		||||
                    .apply_async()
 | 
			
		||||
                    .get()
 | 
			
		||||
                )
 | 
			
		||||
                users_tasks.run().wait(timeout=provider.get_object_sync_time_limit_ms(User))
 | 
			
		||||
                group_tasks.run().wait(timeout=provider.get_object_sync_time_limit_ms(Group))
 | 
			
		||||
                for result in group_results:
 | 
			
		||||
                    for msg in result:
 | 
			
		||||
                        messages.append(LogEvent(**msg))
 | 
			
		||||
            except TransientSyncException as exc:
 | 
			
		||||
                self.logger.warning("transient sync exception", exc=exc)
 | 
			
		||||
                task.warning("Sync encountered a transient exception. Retrying", exc=exc)
 | 
			
		||||
                raise Retry() from exc
 | 
			
		||||
                raise task.retry(exc=exc) from exc
 | 
			
		||||
            except StopSync as exc:
 | 
			
		||||
                task.error(exc)
 | 
			
		||||
                task.set_error(exc)
 | 
			
		||||
                return
 | 
			
		||||
        task.set_status(TaskStatus.SUCCESSFUL, *messages)
 | 
			
		||||
 | 
			
		||||
    def sync_objects(
 | 
			
		||||
        self,
 | 
			
		||||
        object_type: str,
 | 
			
		||||
        page: int,
 | 
			
		||||
        provider_pk: int,
 | 
			
		||||
        override_dry_run=False,
 | 
			
		||||
        **filter,
 | 
			
		||||
        self, object_type: str, page: int, provider_pk: int, override_dry_run=False, **filter
 | 
			
		||||
    ):
 | 
			
		||||
        task: Task = CurrentTask.get_task()
 | 
			
		||||
        _object_type: type[Model] = path_to_class(object_type)
 | 
			
		||||
        self.logger = get_logger().bind(
 | 
			
		||||
            provider_type=class_to_path(self._provider_model),
 | 
			
		||||
            provider_pk=provider_pk,
 | 
			
		||||
            object_type=object_type,
 | 
			
		||||
        )
 | 
			
		||||
        provider: OutgoingSyncProvider = self._provider_model.objects.filter(
 | 
			
		||||
            Q(backchannel_application__isnull=False) | Q(application__isnull=False),
 | 
			
		||||
            pk=provider_pk,
 | 
			
		||||
        ).first()
 | 
			
		||||
        messages = []
 | 
			
		||||
        provider = self._provider_model.objects.filter(pk=provider_pk).first()
 | 
			
		||||
        if not provider:
 | 
			
		||||
            task.warning("No provider found. Is it assigned to an application?")
 | 
			
		||||
            return
 | 
			
		||||
        task.set_uid(slugify(provider.name))
 | 
			
		||||
            return messages
 | 
			
		||||
        # Override dry run mode if requested, however don't save the provider
 | 
			
		||||
        # so that scheduled sync tasks still run in dry_run mode
 | 
			
		||||
        if override_dry_run:
 | 
			
		||||
@ -139,13 +147,25 @@ class SyncTasks:
 | 
			
		||||
        try:
 | 
			
		||||
            client = provider.client_for_model(_object_type)
 | 
			
		||||
        except TransientSyncException:
 | 
			
		||||
            return
 | 
			
		||||
            return messages
 | 
			
		||||
        paginator = Paginator(provider.get_object_qs(_object_type).filter(**filter), PAGE_SIZE)
 | 
			
		||||
        if client.can_discover:
 | 
			
		||||
            self.logger.debug("starting discover")
 | 
			
		||||
            client.discover()
 | 
			
		||||
        self.logger.debug("starting sync for page", page=page)
 | 
			
		||||
        task.info(f"Syncing page {page} or {_object_type._meta.verbose_name_plural}")
 | 
			
		||||
        messages.append(
 | 
			
		||||
            asdict(
 | 
			
		||||
                LogEvent(
 | 
			
		||||
                    _(
 | 
			
		||||
                        "Syncing page {page} of {object_type}".format(
 | 
			
		||||
                            page=page, object_type=_object_type._meta.verbose_name_plural
 | 
			
		||||
                        )
 | 
			
		||||
                    ),
 | 
			
		||||
                    log_level="info",
 | 
			
		||||
                    logger=f"{provider._meta.verbose_name}@{object_type}",
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        for obj in paginator.page(page).object_list:
 | 
			
		||||
            obj: Model
 | 
			
		||||
            try:
 | 
			
		||||
@ -154,58 +174,89 @@ class SyncTasks:
 | 
			
		||||
                self.logger.debug("skipping object due to SkipObject", obj=obj)
 | 
			
		||||
                continue
 | 
			
		||||
            except DryRunRejected as exc:
 | 
			
		||||
                task.info(
 | 
			
		||||
                    "Dropping mutating request due to dry run",
 | 
			
		||||
                    obj=sanitize_item(obj),
 | 
			
		||||
                    method=exc.method,
 | 
			
		||||
                    url=exc.url,
 | 
			
		||||
                    body=exc.body,
 | 
			
		||||
                messages.append(
 | 
			
		||||
                    asdict(
 | 
			
		||||
                        LogEvent(
 | 
			
		||||
                            _("Dropping mutating request due to dry run"),
 | 
			
		||||
                            log_level="info",
 | 
			
		||||
                            logger=f"{provider._meta.verbose_name}@{object_type}",
 | 
			
		||||
                            attributes={
 | 
			
		||||
                                "obj": sanitize_item(obj),
 | 
			
		||||
                                "method": exc.method,
 | 
			
		||||
                                "url": exc.url,
 | 
			
		||||
                                "body": exc.body,
 | 
			
		||||
                            },
 | 
			
		||||
                        )
 | 
			
		||||
                    )
 | 
			
		||||
                )
 | 
			
		||||
            except BadRequestSyncException as exc:
 | 
			
		||||
                self.logger.warning("failed to sync object", exc=exc, obj=obj)
 | 
			
		||||
                task.warning(
 | 
			
		||||
                    f"Failed to sync {obj._meta.verbose_name} {str(obj)} due to error: {str(exc)}",
 | 
			
		||||
                    arguments=exc.args[1:],
 | 
			
		||||
                    obj=sanitize_item(obj),
 | 
			
		||||
                messages.append(
 | 
			
		||||
                    asdict(
 | 
			
		||||
                        LogEvent(
 | 
			
		||||
                            _(
 | 
			
		||||
                                (
 | 
			
		||||
                                    "Failed to sync {object_type} {object_name} "
 | 
			
		||||
                                    "due to error: {error}"
 | 
			
		||||
                                ).format_map(
 | 
			
		||||
                                    {
 | 
			
		||||
                                        "object_type": obj._meta.verbose_name,
 | 
			
		||||
                                        "object_name": str(obj),
 | 
			
		||||
                                        "error": str(exc),
 | 
			
		||||
                                    }
 | 
			
		||||
                                )
 | 
			
		||||
                            ),
 | 
			
		||||
                            log_level="warning",
 | 
			
		||||
                            logger=f"{provider._meta.verbose_name}@{object_type}",
 | 
			
		||||
                            attributes={"arguments": exc.args[1:], "obj": sanitize_item(obj)},
 | 
			
		||||
                        )
 | 
			
		||||
                    )
 | 
			
		||||
                )
 | 
			
		||||
            except TransientSyncException as exc:
 | 
			
		||||
                self.logger.warning("failed to sync object", exc=exc, user=obj)
 | 
			
		||||
                task.warning(
 | 
			
		||||
                    f"Failed to sync {obj._meta.verbose_name} {str(obj)} due to "
 | 
			
		||||
                    "transient error: {str(exc)}",
 | 
			
		||||
                    obj=sanitize_item(obj),
 | 
			
		||||
                messages.append(
 | 
			
		||||
                    asdict(
 | 
			
		||||
                        LogEvent(
 | 
			
		||||
                            _(
 | 
			
		||||
                                (
 | 
			
		||||
                                    "Failed to sync {object_type} {object_name} "
 | 
			
		||||
                                    "due to transient error: {error}"
 | 
			
		||||
                                ).format_map(
 | 
			
		||||
                                    {
 | 
			
		||||
                                        "object_type": obj._meta.verbose_name,
 | 
			
		||||
                                        "object_name": str(obj),
 | 
			
		||||
                                        "error": str(exc),
 | 
			
		||||
                                    }
 | 
			
		||||
                                )
 | 
			
		||||
                            ),
 | 
			
		||||
                            log_level="warning",
 | 
			
		||||
                            logger=f"{provider._meta.verbose_name}@{object_type}",
 | 
			
		||||
                            attributes={"obj": sanitize_item(obj)},
 | 
			
		||||
                        )
 | 
			
		||||
                    )
 | 
			
		||||
                )
 | 
			
		||||
            except StopSync as exc:
 | 
			
		||||
                self.logger.warning("Stopping sync", exc=exc)
 | 
			
		||||
                task.warning(
 | 
			
		||||
                    f"Stopping sync due to error: {exc.detail()}",
 | 
			
		||||
                    obj=sanitize_item(obj),
 | 
			
		||||
                messages.append(
 | 
			
		||||
                    asdict(
 | 
			
		||||
                        LogEvent(
 | 
			
		||||
                            _(
 | 
			
		||||
                                "Stopping sync due to error: {error}".format_map(
 | 
			
		||||
                                    {
 | 
			
		||||
                                        "error": exc.detail(),
 | 
			
		||||
                                    }
 | 
			
		||||
                                )
 | 
			
		||||
                            ),
 | 
			
		||||
                            log_level="warning",
 | 
			
		||||
                            logger=f"{provider._meta.verbose_name}@{object_type}",
 | 
			
		||||
                            attributes={"obj": sanitize_item(obj)},
 | 
			
		||||
                        )
 | 
			
		||||
                    )
 | 
			
		||||
                )
 | 
			
		||||
                break
 | 
			
		||||
        return messages
 | 
			
		||||
 | 
			
		||||
    def sync_signal_direct_dispatch(
 | 
			
		||||
        self,
 | 
			
		||||
        task_sync_signal_direct: Actor[[str, str | int, int, str], None],
 | 
			
		||||
        model: str,
 | 
			
		||||
        pk: str | int,
 | 
			
		||||
        raw_op: str,
 | 
			
		||||
    ):
 | 
			
		||||
        for provider in self._provider_model.objects.filter(
 | 
			
		||||
            Q(backchannel_application__isnull=False) | Q(application__isnull=False)
 | 
			
		||||
        ):
 | 
			
		||||
            task_sync_signal_direct.send_with_options(
 | 
			
		||||
                args=(model, pk, provider.pk, raw_op),
 | 
			
		||||
                rel_obj=provider,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    def sync_signal_direct(
 | 
			
		||||
        self,
 | 
			
		||||
        model: str,
 | 
			
		||||
        pk: str | int,
 | 
			
		||||
        provider_pk: int,
 | 
			
		||||
        raw_op: str,
 | 
			
		||||
    ):
 | 
			
		||||
        task: Task = CurrentTask.get_task()
 | 
			
		||||
    def sync_signal_direct(self, model: str, pk: str | int, raw_op: str):
 | 
			
		||||
        self.logger = get_logger().bind(
 | 
			
		||||
            provider_type=class_to_path(self._provider_model),
 | 
			
		||||
        )
 | 
			
		||||
@ -213,108 +264,65 @@ class SyncTasks:
 | 
			
		||||
        instance = model_class.objects.filter(pk=pk).first()
 | 
			
		||||
        if not instance:
 | 
			
		||||
            return
 | 
			
		||||
        provider: OutgoingSyncProvider = self._provider_model.objects.filter(
 | 
			
		||||
            Q(backchannel_application__isnull=False) | Q(application__isnull=False),
 | 
			
		||||
            pk=provider_pk,
 | 
			
		||||
        ).first()
 | 
			
		||||
        if not provider:
 | 
			
		||||
            task.warning("No provider found. Is it assigned to an application?")
 | 
			
		||||
            return
 | 
			
		||||
        task.set_uid(slugify(provider.name))
 | 
			
		||||
        operation = Direction(raw_op)
 | 
			
		||||
        client = provider.client_for_model(instance.__class__)
 | 
			
		||||
        # Check if the object is allowed within the provider's restrictions
 | 
			
		||||
        queryset = provider.get_object_qs(instance.__class__)
 | 
			
		||||
        if not queryset:
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        # The queryset we get from the provider must include the instance we've got given
 | 
			
		||||
        # otherwise ignore this provider
 | 
			
		||||
        if not queryset.filter(pk=instance.pk).exists():
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            if operation == Direction.add:
 | 
			
		||||
                client.write(instance)
 | 
			
		||||
            if operation == Direction.remove:
 | 
			
		||||
                client.delete(instance)
 | 
			
		||||
        except TransientSyncException as exc:
 | 
			
		||||
            raise Retry() from exc
 | 
			
		||||
        except SkipObjectException:
 | 
			
		||||
            return
 | 
			
		||||
        except DryRunRejected as exc:
 | 
			
		||||
            self.logger.info("Rejected dry-run event", exc=exc)
 | 
			
		||||
        except StopSync as exc:
 | 
			
		||||
            self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk)
 | 
			
		||||
 | 
			
		||||
    def sync_signal_m2m_dispatch(
 | 
			
		||||
        self,
 | 
			
		||||
        task_sync_signal_m2m: Actor[[str, int, str, list[int]], None],
 | 
			
		||||
        instance_pk: str,
 | 
			
		||||
        action: str,
 | 
			
		||||
        pk_set: list[int],
 | 
			
		||||
        reverse: bool,
 | 
			
		||||
    ):
 | 
			
		||||
        for provider in self._provider_model.objects.filter(
 | 
			
		||||
            Q(backchannel_application__isnull=False) | Q(application__isnull=False)
 | 
			
		||||
        ):
 | 
			
		||||
            # reverse: instance is a Group, pk_set is a list of user pks
 | 
			
		||||
            # non-reverse: instance is a User, pk_set is a list of groups
 | 
			
		||||
            if reverse:
 | 
			
		||||
                task_sync_signal_m2m.send_with_options(
 | 
			
		||||
                    args=(instance_pk, provider.pk, action, list(pk_set)),
 | 
			
		||||
                    rel_obj=provider,
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                for pk in pk_set:
 | 
			
		||||
                    task_sync_signal_m2m.send_with_options(
 | 
			
		||||
                        args=(pk, provider.pk, action, [instance_pk]),
 | 
			
		||||
                        rel_obj=provider,
 | 
			
		||||
                    )
 | 
			
		||||
            client = provider.client_for_model(instance.__class__)
 | 
			
		||||
            # Check if the object is allowed within the provider's restrictions
 | 
			
		||||
            queryset = provider.get_object_qs(instance.__class__)
 | 
			
		||||
            if not queryset:
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
    def sync_signal_m2m(
 | 
			
		||||
        self,
 | 
			
		||||
        group_pk: str,
 | 
			
		||||
        provider_pk: int,
 | 
			
		||||
        action: str,
 | 
			
		||||
        pk_set: list[int],
 | 
			
		||||
    ):
 | 
			
		||||
        task: Task = CurrentTask.get_task()
 | 
			
		||||
            # The queryset we get from the provider must include the instance we've got given
 | 
			
		||||
            # otherwise ignore this provider
 | 
			
		||||
            if not queryset.filter(pk=instance.pk).exists():
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            try:
 | 
			
		||||
                if operation == Direction.add:
 | 
			
		||||
                    client.write(instance)
 | 
			
		||||
                if operation == Direction.remove:
 | 
			
		||||
                    client.delete(instance)
 | 
			
		||||
            except TransientSyncException as exc:
 | 
			
		||||
                raise Retry() from exc
 | 
			
		||||
            except SkipObjectException:
 | 
			
		||||
                continue
 | 
			
		||||
            except DryRunRejected as exc:
 | 
			
		||||
                self.logger.info("Rejected dry-run event", exc=exc)
 | 
			
		||||
            except StopSync as exc:
 | 
			
		||||
                self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk)
 | 
			
		||||
 | 
			
		||||
    def sync_signal_m2m(self, group_pk: str, action: str, pk_set: list[int]):
 | 
			
		||||
        self.logger = get_logger().bind(
 | 
			
		||||
            provider_type=class_to_path(self._provider_model),
 | 
			
		||||
        )
 | 
			
		||||
        group = Group.objects.filter(pk=group_pk).first()
 | 
			
		||||
        if not group:
 | 
			
		||||
            return
 | 
			
		||||
        provider: OutgoingSyncProvider = self._provider_model.objects.filter(
 | 
			
		||||
            Q(backchannel_application__isnull=False) | Q(application__isnull=False),
 | 
			
		||||
            pk=provider_pk,
 | 
			
		||||
        ).first()
 | 
			
		||||
        if not provider:
 | 
			
		||||
            task.warning("No provider found. Is it assigned to an application?")
 | 
			
		||||
            return
 | 
			
		||||
        task.set_uid(slugify(provider.name))
 | 
			
		||||
        for provider in self._provider_model.objects.filter(
 | 
			
		||||
            Q(backchannel_application__isnull=False) | Q(application__isnull=False)
 | 
			
		||||
        ):
 | 
			
		||||
            # Check if the object is allowed within the provider's restrictions
 | 
			
		||||
            queryset: QuerySet = provider.get_object_qs(Group)
 | 
			
		||||
            # The queryset we get from the provider must include the instance we've got given
 | 
			
		||||
            # otherwise ignore this provider
 | 
			
		||||
            if not queryset.filter(pk=group_pk).exists():
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
        # Check if the object is allowed within the provider's restrictions
 | 
			
		||||
        queryset: QuerySet = provider.get_object_qs(Group)
 | 
			
		||||
        # The queryset we get from the provider must include the instance we've got given
 | 
			
		||||
        # otherwise ignore this provider
 | 
			
		||||
        if not queryset.filter(pk=group_pk).exists():
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        client = provider.client_for_model(Group)
 | 
			
		||||
        try:
 | 
			
		||||
            operation = None
 | 
			
		||||
            if action == "post_add":
 | 
			
		||||
                operation = Direction.add
 | 
			
		||||
            if action == "post_remove":
 | 
			
		||||
                operation = Direction.remove
 | 
			
		||||
            client.update_group(group, operation, pk_set)
 | 
			
		||||
        except TransientSyncException as exc:
 | 
			
		||||
            raise Retry() from exc
 | 
			
		||||
        except SkipObjectException:
 | 
			
		||||
            return
 | 
			
		||||
        except DryRunRejected as exc:
 | 
			
		||||
            self.logger.info("Rejected dry-run event", exc=exc)
 | 
			
		||||
        except StopSync as exc:
 | 
			
		||||
            self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk)
 | 
			
		||||
            client = provider.client_for_model(Group)
 | 
			
		||||
            try:
 | 
			
		||||
                operation = None
 | 
			
		||||
                if action == "post_add":
 | 
			
		||||
                    operation = Direction.add
 | 
			
		||||
                if action == "post_remove":
 | 
			
		||||
                    operation = Direction.remove
 | 
			
		||||
                client.update_group(group, operation, pk_set)
 | 
			
		||||
            except TransientSyncException as exc:
 | 
			
		||||
                raise Retry() from exc
 | 
			
		||||
            except SkipObjectException:
 | 
			
		||||
                continue
 | 
			
		||||
            except DryRunRejected as exc:
 | 
			
		||||
                self.logger.info("Rejected dry-run event", exc=exc)
 | 
			
		||||
            except StopSync as exc:
 | 
			
		||||
                self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk)
 | 
			
		||||
 | 
			
		||||
@ -5,8 +5,6 @@ from structlog.stdlib import get_logger
 | 
			
		||||
 | 
			
		||||
from authentik.blueprints.apps import ManagedAppConfig
 | 
			
		||||
from authentik.lib.config import CONFIG
 | 
			
		||||
from authentik.lib.utils.time import fqdn_rand
 | 
			
		||||
from authentik.tasks.schedules.lib import ScheduleSpec
 | 
			
		||||
 | 
			
		||||
LOGGER = get_logger()
 | 
			
		||||
 | 
			
		||||
@ -62,27 +60,3 @@ class AuthentikOutpostConfig(ManagedAppConfig):
 | 
			
		||||
                outpost.save()
 | 
			
		||||
        else:
 | 
			
		||||
            Outpost.objects.filter(managed=MANAGED_OUTPOST).delete()
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def tenant_schedule_specs(self) -> list[ScheduleSpec]:
 | 
			
		||||
        from authentik.outposts.tasks import outpost_token_ensurer
 | 
			
		||||
 | 
			
		||||
        return [
 | 
			
		||||
            ScheduleSpec(
 | 
			
		||||
                actor=outpost_token_ensurer,
 | 
			
		||||
                crontab=f"{fqdn_rand('outpost_token_ensurer')} */8 * * *",
 | 
			
		||||
            ),
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def global_schedule_specs(self) -> list[ScheduleSpec]:
 | 
			
		||||
        from authentik.outposts.tasks import outpost_connection_discovery
 | 
			
		||||
 | 
			
		||||
        return [
 | 
			
		||||
            ScheduleSpec(
 | 
			
		||||
                actor=outpost_connection_discovery,
 | 
			
		||||
                crontab=f"{fqdn_rand('outpost_connection_discovery')} */8 * * *",
 | 
			
		||||
                send_on_startup=True,
 | 
			
		||||
                paused=not CONFIG.get_bool("outposts.discover"),
 | 
			
		||||
            ),
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
@ -101,13 +101,7 @@ class KubernetesController(BaseController):
 | 
			
		||||
            all_logs = []
 | 
			
		||||
            for reconcile_key in self.reconcile_order:
 | 
			
		||||
                if reconcile_key in self.outpost.config.kubernetes_disabled_components:
 | 
			
		||||
                    all_logs.append(
 | 
			
		||||
                        LogEvent(
 | 
			
		||||
                            log_level="info",
 | 
			
		||||
                            event=f"{reconcile_key.title()}: Disabled",
 | 
			
		||||
                            logger=str(type(self)),
 | 
			
		||||
                        )
 | 
			
		||||
                    )
 | 
			
		||||
                    all_logs += [f"{reconcile_key.title()}: Disabled"]
 | 
			
		||||
                    continue
 | 
			
		||||
                with capture_logs() as logs:
 | 
			
		||||
                    reconciler_cls = self.reconcilers.get(reconcile_key)
 | 
			
		||||
@ -140,13 +134,7 @@ class KubernetesController(BaseController):
 | 
			
		||||
            all_logs = []
 | 
			
		||||
            for reconcile_key in self.reconcile_order:
 | 
			
		||||
                if reconcile_key in self.outpost.config.kubernetes_disabled_components:
 | 
			
		||||
                    all_logs.append(
 | 
			
		||||
                        LogEvent(
 | 
			
		||||
                            log_level="info",
 | 
			
		||||
                            event=f"{reconcile_key.title()}: Disabled",
 | 
			
		||||
                            logger=str(type(self)),
 | 
			
		||||
                        )
 | 
			
		||||
                    )
 | 
			
		||||
                    all_logs += [f"{reconcile_key.title()}: Disabled"]
 | 
			
		||||
                    continue
 | 
			
		||||
                with capture_logs() as logs:
 | 
			
		||||
                    reconciler_cls = self.reconcilers.get(reconcile_key)
 | 
			
		||||
 | 
			
		||||
@ -36,10 +36,7 @@ from authentik.lib.config import CONFIG
 | 
			
		||||
from authentik.lib.models import InheritanceForeignKey, SerializerModel
 | 
			
		||||
from authentik.lib.sentry import SentryIgnoredException
 | 
			
		||||
from authentik.lib.utils.errors import exception_to_string
 | 
			
		||||
from authentik.lib.utils.time import fqdn_rand
 | 
			
		||||
from authentik.outposts.controllers.k8s.utils import get_namespace
 | 
			
		||||
from authentik.tasks.schedules.lib import ScheduleSpec
 | 
			
		||||
from authentik.tasks.schedules.models import ScheduledModel
 | 
			
		||||
 | 
			
		||||
OUR_VERSION = parse(__version__)
 | 
			
		||||
OUTPOST_HELLO_INTERVAL = 10
 | 
			
		||||
@ -118,7 +115,7 @@ class OutpostServiceConnectionState:
 | 
			
		||||
    healthy: bool
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class OutpostServiceConnection(ScheduledModel, models.Model):
 | 
			
		||||
class OutpostServiceConnection(models.Model):
 | 
			
		||||
    """Connection details for an Outpost Controller, like Docker or Kubernetes"""
 | 
			
		||||
 | 
			
		||||
    uuid = models.UUIDField(default=uuid4, editable=False, primary_key=True)
 | 
			
		||||
@ -148,11 +145,11 @@ class OutpostServiceConnection(ScheduledModel, models.Model):
 | 
			
		||||
    @property
 | 
			
		||||
    def state(self) -> OutpostServiceConnectionState:
 | 
			
		||||
        """Get state of service connection"""
 | 
			
		||||
        from authentik.outposts.tasks import outpost_service_connection_monitor
 | 
			
		||||
        from authentik.outposts.tasks import outpost_service_connection_state
 | 
			
		||||
 | 
			
		||||
        state = cache.get(self.state_key, None)
 | 
			
		||||
        if not state:
 | 
			
		||||
            outpost_service_connection_monitor.send_with_options(args=(self.pk), rel_obj=self)
 | 
			
		||||
            outpost_service_connection_state.delay(self.pk)
 | 
			
		||||
            return OutpostServiceConnectionState("", False)
 | 
			
		||||
        return state
 | 
			
		||||
 | 
			
		||||
@ -163,20 +160,6 @@ class OutpostServiceConnection(ScheduledModel, models.Model):
 | 
			
		||||
        # since the response doesn't use the correct inheritance
 | 
			
		||||
        return ""
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def schedule_specs(self) -> list[ScheduleSpec]:
 | 
			
		||||
        from authentik.outposts.tasks import outpost_service_connection_monitor
 | 
			
		||||
 | 
			
		||||
        return [
 | 
			
		||||
            ScheduleSpec(
 | 
			
		||||
                actor=outpost_service_connection_monitor,
 | 
			
		||||
                uid=self.pk,
 | 
			
		||||
                args=(self.pk,),
 | 
			
		||||
                crontab="3-59/15 * * * *",
 | 
			
		||||
                send_on_save=True,
 | 
			
		||||
            ),
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DockerServiceConnection(SerializerModel, OutpostServiceConnection):
 | 
			
		||||
    """Service Connection to a Docker endpoint"""
 | 
			
		||||
@ -261,7 +244,7 @@ class KubernetesServiceConnection(SerializerModel, OutpostServiceConnection):
 | 
			
		||||
        return "ak-service-connection-kubernetes-form"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Outpost(ScheduledModel, SerializerModel, ManagedModel):
 | 
			
		||||
class Outpost(SerializerModel, ManagedModel):
 | 
			
		||||
    """Outpost instance which manages a service user and token"""
 | 
			
		||||
 | 
			
		||||
    uuid = models.UUIDField(default=uuid4, editable=False, primary_key=True)
 | 
			
		||||
@ -315,21 +298,6 @@ class Outpost(ScheduledModel, SerializerModel, ManagedModel):
 | 
			
		||||
        """Username for service user"""
 | 
			
		||||
        return f"ak-outpost-{self.uuid.hex}"
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def schedule_specs(self) -> list[ScheduleSpec]:
 | 
			
		||||
        from authentik.outposts.tasks import outpost_controller
 | 
			
		||||
 | 
			
		||||
        return [
 | 
			
		||||
            ScheduleSpec(
 | 
			
		||||
                actor=outpost_controller,
 | 
			
		||||
                uid=self.pk,
 | 
			
		||||
                args=(self.pk,),
 | 
			
		||||
                kwargs={"action": "up", "from_cache": False},
 | 
			
		||||
                crontab=f"{fqdn_rand('outpost_controller')} */4 * * *",
 | 
			
		||||
                send_on_save=True,
 | 
			
		||||
            ),
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
    def build_user_permissions(self, user: User):
 | 
			
		||||
        """Create per-object and global permissions for outpost service-account"""
 | 
			
		||||
        # To ensure the user only has the correct permissions, we delete all of them and re-add
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										28
									
								
								authentik/outposts/settings.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								authentik/outposts/settings.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,28 @@
 | 
			
		||||
"""Outposts Settings"""
 | 
			
		||||
 | 
			
		||||
from celery.schedules import crontab
 | 
			
		||||
 | 
			
		||||
from authentik.lib.utils.time import fqdn_rand
 | 
			
		||||
 | 
			
		||||
CELERY_BEAT_SCHEDULE = {
 | 
			
		||||
    "outposts_controller": {
 | 
			
		||||
        "task": "authentik.outposts.tasks.outpost_controller_all",
 | 
			
		||||
        "schedule": crontab(minute=fqdn_rand("outposts_controller"), hour="*/4"),
 | 
			
		||||
        "options": {"queue": "authentik_scheduled"},
 | 
			
		||||
    },
 | 
			
		||||
    "outposts_service_connection_check": {
 | 
			
		||||
        "task": "authentik.outposts.tasks.outpost_service_connection_monitor",
 | 
			
		||||
        "schedule": crontab(minute="3-59/15"),
 | 
			
		||||
        "options": {"queue": "authentik_scheduled"},
 | 
			
		||||
    },
 | 
			
		||||
    "outpost_token_ensurer": {
 | 
			
		||||
        "task": "authentik.outposts.tasks.outpost_token_ensurer",
 | 
			
		||||
        "schedule": crontab(minute=fqdn_rand("outpost_token_ensurer"), hour="*/8"),
 | 
			
		||||
        "options": {"queue": "authentik_scheduled"},
 | 
			
		||||
    },
 | 
			
		||||
    "outpost_connection_discovery": {
 | 
			
		||||
        "task": "authentik.outposts.tasks.outpost_connection_discovery",
 | 
			
		||||
        "schedule": crontab(minute=fqdn_rand("outpost_connection_discovery"), hour="*/8"),
 | 
			
		||||
        "options": {"queue": "authentik_scheduled"},
 | 
			
		||||
    },
 | 
			
		||||
}
 | 
			
		||||
@ -1,6 +1,7 @@
 | 
			
		||||
"""authentik outpost signals"""
 | 
			
		||||
 | 
			
		||||
from django.core.cache import cache
 | 
			
		||||
from django.db.models import Model
 | 
			
		||||
from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_save
 | 
			
		||||
from django.dispatch import receiver
 | 
			
		||||
from structlog.stdlib import get_logger
 | 
			
		||||
@ -8,19 +9,27 @@ from structlog.stdlib import get_logger
 | 
			
		||||
from authentik.brands.models import Brand
 | 
			
		||||
from authentik.core.models import AuthenticatedSession, Provider
 | 
			
		||||
from authentik.crypto.models import CertificateKeyPair
 | 
			
		||||
from authentik.outposts.models import Outpost, OutpostModel, OutpostServiceConnection
 | 
			
		||||
from authentik.lib.utils.reflection import class_to_path
 | 
			
		||||
from authentik.outposts.models import Outpost, OutpostServiceConnection
 | 
			
		||||
from authentik.outposts.tasks import (
 | 
			
		||||
    CACHE_KEY_OUTPOST_DOWN,
 | 
			
		||||
    outpost_controller,
 | 
			
		||||
    outpost_send_update,
 | 
			
		||||
    outpost_post_save,
 | 
			
		||||
    outpost_session_end,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
LOGGER = get_logger()
 | 
			
		||||
UPDATE_TRIGGERING_MODELS = (
 | 
			
		||||
    Outpost,
 | 
			
		||||
    OutpostServiceConnection,
 | 
			
		||||
    Provider,
 | 
			
		||||
    CertificateKeyPair,
 | 
			
		||||
    Brand,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@receiver(pre_save, sender=Outpost)
 | 
			
		||||
def outpost_pre_save(sender, instance: Outpost, **_):
 | 
			
		||||
def pre_save_outpost(sender, instance: Outpost, **_):
 | 
			
		||||
    """Pre-save checks for an outpost, if the name or config.kubernetes_namespace changes,
 | 
			
		||||
    we call down and then wait for the up after save"""
 | 
			
		||||
    old_instances = Outpost.objects.filter(pk=instance.pk)
 | 
			
		||||
@ -35,89 +44,43 @@ def outpost_pre_save(sender, instance: Outpost, **_):
 | 
			
		||||
    if bool(dirty):
 | 
			
		||||
        LOGGER.info("Outpost needs re-deployment due to changes", instance=instance)
 | 
			
		||||
        cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, old_instance)
 | 
			
		||||
        outpost_controller.send_with_options(
 | 
			
		||||
            args=(instance.pk.hex,),
 | 
			
		||||
            kwargs={"action": "down", "from_cache": True},
 | 
			
		||||
            rel_obj=instance,
 | 
			
		||||
        )
 | 
			
		||||
        outpost_controller.delay(instance.pk.hex, action="down", from_cache=True)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@receiver(m2m_changed, sender=Outpost.providers.through)
 | 
			
		||||
def outpost_m2m_changed(sender, instance: Outpost | Provider, action: str, **_):
 | 
			
		||||
def m2m_changed_update(sender, instance: Model, action: str, **_):
 | 
			
		||||
    """Update outpost on m2m change, when providers are added or removed"""
 | 
			
		||||
    if action not in ["post_add", "post_remove", "post_clear"]:
 | 
			
		||||
    if action in ["post_add", "post_remove", "post_clear"]:
 | 
			
		||||
        outpost_post_save.delay(class_to_path(instance.__class__), instance.pk)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@receiver(post_save)
 | 
			
		||||
def post_save_update(sender, instance: Model, created: bool, **_):
 | 
			
		||||
    """If an Outpost is saved, Ensure that token is created/updated
 | 
			
		||||
 | 
			
		||||
    If an OutpostModel, or a model that is somehow connected to an OutpostModel is saved,
 | 
			
		||||
    we send a message down the relevant OutpostModels WS connection to trigger an update"""
 | 
			
		||||
    if instance.__module__ == "django.db.migrations.recorder":
 | 
			
		||||
        return
 | 
			
		||||
    if isinstance(instance, Outpost):
 | 
			
		||||
        outpost_controller.send_with_options(
 | 
			
		||||
            args=(instance.pk,),
 | 
			
		||||
            rel_obj=instance.service_connection,
 | 
			
		||||
        )
 | 
			
		||||
        outpost_send_update.send_with_options(args=(instance.pk,), rel_obj=instance)
 | 
			
		||||
    elif isinstance(instance, OutpostModel):
 | 
			
		||||
        for outpost in instance.outpost_set.all():
 | 
			
		||||
            outpost_controller.send_with_options(
 | 
			
		||||
                args=(instance.pk,),
 | 
			
		||||
                rel_obj=instance.service_connection,
 | 
			
		||||
            )
 | 
			
		||||
            outpost_send_update.send_with_options(args=(outpost.pk,), rel_obj=outpost)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@receiver(post_save, sender=Outpost)
 | 
			
		||||
def outpost_post_save(sender, instance: Outpost, created: bool, **_):
 | 
			
		||||
    if created:
 | 
			
		||||
    if instance.__module__ == "__fake__":
 | 
			
		||||
        return
 | 
			
		||||
    if not isinstance(instance, UPDATE_TRIGGERING_MODELS):
 | 
			
		||||
        return
 | 
			
		||||
    if isinstance(instance, Outpost) and created:
 | 
			
		||||
        LOGGER.info("New outpost saved, ensuring initial token and user are created")
 | 
			
		||||
        _ = instance.token
 | 
			
		||||
    outpost_controller.send_with_options(args=(instance.pk,), rel_obj=instance.service_connection)
 | 
			
		||||
    outpost_send_update.send_with_options(args=(instance.pk,), rel_obj=instance)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def outpost_related_post_save(sender, instance: OutpostServiceConnection | OutpostModel, **_):
 | 
			
		||||
    for outpost in instance.outpost_set.all():
 | 
			
		||||
        outpost_send_update.send_with_options(args=(outpost.pk,), rel_obj=outpost)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
post_save.connect(outpost_related_post_save, sender=OutpostServiceConnection, weak=False)
 | 
			
		||||
for subclass in OutpostModel.__subclasses__():
 | 
			
		||||
    post_save.connect(outpost_related_post_save, sender=subclass, weak=False)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def outpost_reverse_related_post_save(sender, instance: CertificateKeyPair | Brand, **_):
 | 
			
		||||
    for field in instance._meta.get_fields():
 | 
			
		||||
        # Each field is checked if it has a `related_model` attribute (when ForeginKeys or M2Ms)
 | 
			
		||||
        # are used, and if it has a value
 | 
			
		||||
        if not hasattr(field, "related_model"):
 | 
			
		||||
            continue
 | 
			
		||||
        if not field.related_model:
 | 
			
		||||
            continue
 | 
			
		||||
        if not issubclass(field.related_model, OutpostModel):
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        field_name = f"{field.name}_set"
 | 
			
		||||
        if not hasattr(instance, field_name):
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        LOGGER.debug("triggering outpost update from field", field=field.name)
 | 
			
		||||
        # Because the Outpost Model has an M2M to Provider,
 | 
			
		||||
        # we have to iterate over the entire QS
 | 
			
		||||
        for reverse in getattr(instance, field_name).all():
 | 
			
		||||
            if isinstance(reverse, OutpostModel):
 | 
			
		||||
                for outpost in reverse.outpost_set.all():
 | 
			
		||||
                    outpost_send_update.send_with_options(args=(outpost.pk,), rel_obj=outpost)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
post_save.connect(outpost_reverse_related_post_save, sender=Brand, weak=False)
 | 
			
		||||
post_save.connect(outpost_reverse_related_post_save, sender=CertificateKeyPair, weak=False)
 | 
			
		||||
    outpost_post_save.delay(class_to_path(instance.__class__), instance.pk)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@receiver(pre_delete, sender=Outpost)
 | 
			
		||||
def outpost_pre_delete_cleanup(sender, instance: Outpost, **_):
 | 
			
		||||
def pre_delete_cleanup(sender, instance: Outpost, **_):
 | 
			
		||||
    """Ensure that Outpost's user is deleted (which will delete the token through cascade)"""
 | 
			
		||||
    instance.user.delete()
 | 
			
		||||
    cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, instance)
 | 
			
		||||
    outpost_controller.send(instance.pk.hex, action="down", from_cache=True)
 | 
			
		||||
    outpost_controller.delay(instance.pk.hex, action="down", from_cache=True)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@receiver(pre_delete, sender=AuthenticatedSession)
 | 
			
		||||
def outpost_logout_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_):
 | 
			
		||||
def logout_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_):
 | 
			
		||||
    """Catch logout by expiring sessions being deleted"""
 | 
			
		||||
    outpost_session_end.send(instance.session.session_key)
 | 
			
		||||
    outpost_session_end.delay(instance.session.session_key)
 | 
			
		||||
 | 
			
		||||
@ -10,17 +10,19 @@ from urllib.parse import urlparse
 | 
			
		||||
from asgiref.sync import async_to_sync
 | 
			
		||||
from channels.layers import get_channel_layer
 | 
			
		||||
from django.core.cache import cache
 | 
			
		||||
from django.db import DatabaseError, InternalError, ProgrammingError
 | 
			
		||||
from django.db.models.base import Model
 | 
			
		||||
from django.utils.text import slugify
 | 
			
		||||
from django.utils.translation import gettext_lazy as _
 | 
			
		||||
from django_dramatiq_postgres.middleware import CurrentTask
 | 
			
		||||
from docker.constants import DEFAULT_UNIX_SOCKET
 | 
			
		||||
from dramatiq.actor import actor
 | 
			
		||||
from kubernetes.config.incluster_config import SERVICE_TOKEN_FILENAME
 | 
			
		||||
from kubernetes.config.kube_config import KUBE_CONFIG_DEFAULT_LOCATION
 | 
			
		||||
from structlog.stdlib import get_logger
 | 
			
		||||
from yaml import safe_load
 | 
			
		||||
 | 
			
		||||
from authentik.events.models import TaskStatus
 | 
			
		||||
from authentik.events.system_tasks import SystemTask, prefill_task
 | 
			
		||||
from authentik.lib.config import CONFIG
 | 
			
		||||
from authentik.lib.utils.reflection import path_to_class
 | 
			
		||||
from authentik.outposts.consumer import OUTPOST_GROUP
 | 
			
		||||
from authentik.outposts.controllers.base import BaseController, ControllerException
 | 
			
		||||
from authentik.outposts.controllers.docker import DockerClient
 | 
			
		||||
@ -29,6 +31,7 @@ from authentik.outposts.models import (
 | 
			
		||||
    DockerServiceConnection,
 | 
			
		||||
    KubernetesServiceConnection,
 | 
			
		||||
    Outpost,
 | 
			
		||||
    OutpostModel,
 | 
			
		||||
    OutpostServiceConnection,
 | 
			
		||||
    OutpostType,
 | 
			
		||||
    ServiceConnectionInvalid,
 | 
			
		||||
@ -41,7 +44,7 @@ from authentik.providers.rac.controllers.docker import RACDockerController
 | 
			
		||||
from authentik.providers.rac.controllers.kubernetes import RACKubernetesController
 | 
			
		||||
from authentik.providers.radius.controllers.docker import RadiusDockerController
 | 
			
		||||
from authentik.providers.radius.controllers.kubernetes import RadiusKubernetesController
 | 
			
		||||
from authentik.tasks.models import Task
 | 
			
		||||
from authentik.root.celery import CELERY_APP
 | 
			
		||||
 | 
			
		||||
LOGGER = get_logger()
 | 
			
		||||
CACHE_KEY_OUTPOST_DOWN = "goauthentik.io/outposts/teardown/%s"
 | 
			
		||||
@ -80,8 +83,8 @@ def controller_for_outpost(outpost: Outpost) -> type[BaseController] | None:
 | 
			
		||||
    return None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Update cached state of service connection."))
 | 
			
		||||
def outpost_service_connection_monitor(connection_pk: Any):
 | 
			
		||||
@CELERY_APP.task()
 | 
			
		||||
def outpost_service_connection_state(connection_pk: Any):
 | 
			
		||||
    """Update cached state of a service connection"""
 | 
			
		||||
    connection: OutpostServiceConnection = (
 | 
			
		||||
        OutpostServiceConnection.objects.filter(pk=connection_pk).select_subclasses().first()
 | 
			
		||||
@ -105,11 +108,37 @@ def outpost_service_connection_monitor(connection_pk: Any):
 | 
			
		||||
    cache.set(connection.state_key, state, timeout=None)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Create/update/monitor/delete the deployment of an Outpost."))
 | 
			
		||||
def outpost_controller(outpost_pk: str, action: str = "up", from_cache: bool = False):
 | 
			
		||||
@CELERY_APP.task(
 | 
			
		||||
    bind=True,
 | 
			
		||||
    base=SystemTask,
 | 
			
		||||
    throws=(DatabaseError, ProgrammingError, InternalError),
 | 
			
		||||
)
 | 
			
		||||
@prefill_task
 | 
			
		||||
def outpost_service_connection_monitor(self: SystemTask):
 | 
			
		||||
    """Regularly check the state of Outpost Service Connections"""
 | 
			
		||||
    connections = OutpostServiceConnection.objects.all()
 | 
			
		||||
    for connection in connections.iterator():
 | 
			
		||||
        outpost_service_connection_state.delay(connection.pk)
 | 
			
		||||
    self.set_status(
 | 
			
		||||
        TaskStatus.SUCCESSFUL,
 | 
			
		||||
        f"Successfully updated {len(connections)} connections.",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@CELERY_APP.task(
 | 
			
		||||
    throws=(DatabaseError, ProgrammingError, InternalError),
 | 
			
		||||
)
 | 
			
		||||
def outpost_controller_all():
 | 
			
		||||
    """Launch Controller for all Outposts which support it"""
 | 
			
		||||
    for outpost in Outpost.objects.exclude(service_connection=None):
 | 
			
		||||
        outpost_controller.delay(outpost.pk.hex, "up", from_cache=False)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@CELERY_APP.task(bind=True, base=SystemTask)
 | 
			
		||||
def outpost_controller(
 | 
			
		||||
    self: SystemTask, outpost_pk: str, action: str = "up", from_cache: bool = False
 | 
			
		||||
):
 | 
			
		||||
    """Create/update/monitor/delete the deployment of an Outpost"""
 | 
			
		||||
    self: Task = CurrentTask.get_task()
 | 
			
		||||
    self.set_uid(outpost_pk)
 | 
			
		||||
    logs = []
 | 
			
		||||
    if from_cache:
 | 
			
		||||
        outpost: Outpost = cache.get(CACHE_KEY_OUTPOST_DOWN % outpost_pk)
 | 
			
		||||
@ -130,65 +159,125 @@ def outpost_controller(outpost_pk: str, action: str = "up", from_cache: bool = F
 | 
			
		||||
            logs = getattr(controller, f"{action}_with_logs")()
 | 
			
		||||
            LOGGER.debug("-----------------Outpost Controller logs end-------------------")
 | 
			
		||||
    except (ControllerException, ServiceConnectionInvalid) as exc:
 | 
			
		||||
        self.error(exc)
 | 
			
		||||
        self.set_error(exc)
 | 
			
		||||
    else:
 | 
			
		||||
        if from_cache:
 | 
			
		||||
            cache.delete(CACHE_KEY_OUTPOST_DOWN % outpost_pk)
 | 
			
		||||
        self.logs(logs)
 | 
			
		||||
        self.set_status(TaskStatus.SUCCESSFUL, *logs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Ensure that all Outposts have valid Service Accounts and Tokens."))
 | 
			
		||||
def outpost_token_ensurer():
 | 
			
		||||
    """
 | 
			
		||||
    Periodically ensure that all Outposts have valid Service Accounts and Tokens
 | 
			
		||||
    """
 | 
			
		||||
    self: Task = CurrentTask.get_task()
 | 
			
		||||
@CELERY_APP.task(bind=True, base=SystemTask)
 | 
			
		||||
@prefill_task
 | 
			
		||||
def outpost_token_ensurer(self: SystemTask):
 | 
			
		||||
    """Periodically ensure that all Outposts have valid Service Accounts
 | 
			
		||||
    and Tokens"""
 | 
			
		||||
    all_outposts = Outpost.objects.all()
 | 
			
		||||
    for outpost in all_outposts:
 | 
			
		||||
        _ = outpost.token
 | 
			
		||||
        outpost.build_user_permissions(outpost.user)
 | 
			
		||||
    self.info(f"Successfully checked {len(all_outposts)} Outposts.")
 | 
			
		||||
    self.set_status(
 | 
			
		||||
        TaskStatus.SUCCESSFUL,
 | 
			
		||||
        f"Successfully checked {len(all_outposts)} Outposts.",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Send update to outpost"))
 | 
			
		||||
def outpost_send_update(pk: Any):
 | 
			
		||||
    """Update outpost instance"""
 | 
			
		||||
    outpost = Outpost.objects.filter(pk=pk).first()
 | 
			
		||||
    if not outpost:
 | 
			
		||||
@CELERY_APP.task()
 | 
			
		||||
def outpost_post_save(model_class: str, model_pk: Any):
 | 
			
		||||
    """If an Outpost is saved, Ensure that token is created/updated
 | 
			
		||||
 | 
			
		||||
    If an OutpostModel, or a model that is somehow connected to an OutpostModel is saved,
 | 
			
		||||
    we send a message down the relevant OutpostModels WS connection to trigger an update"""
 | 
			
		||||
    model: Model = path_to_class(model_class)
 | 
			
		||||
    try:
 | 
			
		||||
        instance = model.objects.get(pk=model_pk)
 | 
			
		||||
    except model.DoesNotExist:
 | 
			
		||||
        LOGGER.warning("Model does not exist", model=model, pk=model_pk)
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    if isinstance(instance, Outpost):
 | 
			
		||||
        LOGGER.debug("Trigger reconcile for outpost", instance=instance)
 | 
			
		||||
        outpost_controller.delay(str(instance.pk))
 | 
			
		||||
 | 
			
		||||
    if isinstance(instance, OutpostModel | Outpost):
 | 
			
		||||
        LOGGER.debug("triggering outpost update from outpostmodel/outpost", instance=instance)
 | 
			
		||||
        outpost_send_update(instance)
 | 
			
		||||
 | 
			
		||||
    if isinstance(instance, OutpostServiceConnection):
 | 
			
		||||
        LOGGER.debug("triggering ServiceConnection state update", instance=instance)
 | 
			
		||||
        outpost_service_connection_state.delay(str(instance.pk))
 | 
			
		||||
 | 
			
		||||
    for field in instance._meta.get_fields():
 | 
			
		||||
        # Each field is checked if it has a `related_model` attribute (when ForeginKeys or M2Ms)
 | 
			
		||||
        # are used, and if it has a value
 | 
			
		||||
        if not hasattr(field, "related_model"):
 | 
			
		||||
            continue
 | 
			
		||||
        if not field.related_model:
 | 
			
		||||
            continue
 | 
			
		||||
        if not issubclass(field.related_model, OutpostModel):
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        field_name = f"{field.name}_set"
 | 
			
		||||
        if not hasattr(instance, field_name):
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        LOGGER.debug("triggering outpost update from field", field=field.name)
 | 
			
		||||
        # Because the Outpost Model has an M2M to Provider,
 | 
			
		||||
        # we have to iterate over the entire QS
 | 
			
		||||
        for reverse in getattr(instance, field_name).all():
 | 
			
		||||
            outpost_send_update(reverse)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def outpost_send_update(model_instance: Model):
 | 
			
		||||
    """Send outpost update to all registered outposts, regardless to which authentik
 | 
			
		||||
    instance they are connected"""
 | 
			
		||||
    channel_layer = get_channel_layer()
 | 
			
		||||
    if isinstance(model_instance, OutpostModel):
 | 
			
		||||
        for outpost in model_instance.outpost_set.all():
 | 
			
		||||
            _outpost_single_update(outpost, channel_layer)
 | 
			
		||||
    elif isinstance(model_instance, Outpost):
 | 
			
		||||
        _outpost_single_update(model_instance, channel_layer)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _outpost_single_update(outpost: Outpost, layer=None):
 | 
			
		||||
    """Update outpost instances connected to a single outpost"""
 | 
			
		||||
    # Ensure token again, because this function is called when anything related to an
 | 
			
		||||
    # OutpostModel is saved, so we can be sure permissions are right
 | 
			
		||||
    _ = outpost.token
 | 
			
		||||
    outpost.build_user_permissions(outpost.user)
 | 
			
		||||
    layer = get_channel_layer()
 | 
			
		||||
    if not layer:  # pragma: no cover
 | 
			
		||||
        layer = get_channel_layer()
 | 
			
		||||
    group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
 | 
			
		||||
    LOGGER.debug("sending update", channel=group, outpost=outpost)
 | 
			
		||||
    async_to_sync(layer.group_send)(group, {"type": "event.update"})
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Checks the local environment and create Service connections."))
 | 
			
		||||
def outpost_connection_discovery():
 | 
			
		||||
@CELERY_APP.task(
 | 
			
		||||
    base=SystemTask,
 | 
			
		||||
    bind=True,
 | 
			
		||||
)
 | 
			
		||||
def outpost_connection_discovery(self: SystemTask):
 | 
			
		||||
    """Checks the local environment and create Service connections."""
 | 
			
		||||
    self: Task = CurrentTask.get_task()
 | 
			
		||||
    messages = []
 | 
			
		||||
    if not CONFIG.get_bool("outposts.discover"):
 | 
			
		||||
        self.info("Outpost integration discovery is disabled")
 | 
			
		||||
        messages.append("Outpost integration discovery is disabled")
 | 
			
		||||
        self.set_status(TaskStatus.SUCCESSFUL, *messages)
 | 
			
		||||
        return
 | 
			
		||||
    # Explicitly check against token filename, as that's
 | 
			
		||||
    # only present when the integration is enabled
 | 
			
		||||
    if Path(SERVICE_TOKEN_FILENAME).exists():
 | 
			
		||||
        self.info("Detected in-cluster Kubernetes Config")
 | 
			
		||||
        messages.append("Detected in-cluster Kubernetes Config")
 | 
			
		||||
        if not KubernetesServiceConnection.objects.filter(local=True).exists():
 | 
			
		||||
            self.info("Created Service Connection for in-cluster")
 | 
			
		||||
            messages.append("Created Service Connection for in-cluster")
 | 
			
		||||
            KubernetesServiceConnection.objects.create(
 | 
			
		||||
                name="Local Kubernetes Cluster", local=True, kubeconfig={}
 | 
			
		||||
            )
 | 
			
		||||
    # For development, check for the existence of a kubeconfig file
 | 
			
		||||
    kubeconfig_path = Path(KUBE_CONFIG_DEFAULT_LOCATION).expanduser()
 | 
			
		||||
    if kubeconfig_path.exists():
 | 
			
		||||
        self.info("Detected kubeconfig")
 | 
			
		||||
        messages.append("Detected kubeconfig")
 | 
			
		||||
        kubeconfig_local_name = f"k8s-{gethostname()}"
 | 
			
		||||
        if not KubernetesServiceConnection.objects.filter(name=kubeconfig_local_name).exists():
 | 
			
		||||
            self.info("Creating kubeconfig Service Connection")
 | 
			
		||||
            messages.append("Creating kubeconfig Service Connection")
 | 
			
		||||
            with kubeconfig_path.open("r", encoding="utf8") as _kubeconfig:
 | 
			
		||||
                KubernetesServiceConnection.objects.create(
 | 
			
		||||
                    name=kubeconfig_local_name,
 | 
			
		||||
@ -197,18 +286,20 @@ def outpost_connection_discovery():
 | 
			
		||||
    unix_socket_path = urlparse(DEFAULT_UNIX_SOCKET).path
 | 
			
		||||
    socket = Path(unix_socket_path)
 | 
			
		||||
    if socket.exists() and access(socket, R_OK):
 | 
			
		||||
        self.info("Detected local docker socket")
 | 
			
		||||
        messages.append("Detected local docker socket")
 | 
			
		||||
        if len(DockerServiceConnection.objects.filter(local=True)) == 0:
 | 
			
		||||
            self.info("Created Service Connection for docker")
 | 
			
		||||
            messages.append("Created Service Connection for docker")
 | 
			
		||||
            DockerServiceConnection.objects.create(
 | 
			
		||||
                name="Local Docker connection",
 | 
			
		||||
                local=True,
 | 
			
		||||
                url=unix_socket_path,
 | 
			
		||||
            )
 | 
			
		||||
    self.set_status(TaskStatus.SUCCESSFUL, *messages)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Terminate session on all outposts."))
 | 
			
		||||
@CELERY_APP.task()
 | 
			
		||||
def outpost_session_end(session_id: str):
 | 
			
		||||
    """Update outpost instances connected to a single outpost"""
 | 
			
		||||
    layer = get_channel_layer()
 | 
			
		||||
    hashed_session_id = hash_session_key(session_id)
 | 
			
		||||
    for outpost in Outpost.objects.all():
 | 
			
		||||
 | 
			
		||||
@ -37,7 +37,6 @@ class OutpostTests(TestCase):
 | 
			
		||||
 | 
			
		||||
        # We add a provider, user should only have access to outpost and provider
 | 
			
		||||
        outpost.providers.add(provider)
 | 
			
		||||
        provider.refresh_from_db()
 | 
			
		||||
        permissions = UserObjectPermission.objects.filter(user=outpost.user).order_by(
 | 
			
		||||
            "content_type__model"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
@ -15,7 +15,6 @@ class AuthentikProviderProxyConfig(ManagedAppConfig):
 | 
			
		||||
    def proxy_set_defaults(self):
 | 
			
		||||
        from authentik.providers.proxy.models import ProxyProvider
 | 
			
		||||
 | 
			
		||||
        # TODO: figure out if this can be in pre_save + post_save signals
 | 
			
		||||
        for provider in ProxyProvider.objects.all():
 | 
			
		||||
            provider.set_oauth_defaults()
 | 
			
		||||
            provider.save()
 | 
			
		||||
 | 
			
		||||
@ -1,13 +0,0 @@
 | 
			
		||||
"""Proxy provider signals"""
 | 
			
		||||
 | 
			
		||||
from django.db.models.signals import pre_delete
 | 
			
		||||
from django.dispatch import receiver
 | 
			
		||||
 | 
			
		||||
from authentik.core.models import AuthenticatedSession
 | 
			
		||||
from authentik.providers.proxy.tasks import proxy_on_logout
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@receiver(pre_delete, sender=AuthenticatedSession)
 | 
			
		||||
def logout_proxy_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_):
 | 
			
		||||
    """Catch logout by expiring sessions being deleted"""
 | 
			
		||||
    proxy_on_logout.send(instance.session.session_key)
 | 
			
		||||
@ -1,26 +0,0 @@
 | 
			
		||||
"""proxy provider tasks"""
 | 
			
		||||
 | 
			
		||||
from asgiref.sync import async_to_sync
 | 
			
		||||
from channels.layers import get_channel_layer
 | 
			
		||||
from django.utils.translation import gettext_lazy as _
 | 
			
		||||
from dramatiq.actor import actor
 | 
			
		||||
 | 
			
		||||
from authentik.outposts.consumer import OUTPOST_GROUP
 | 
			
		||||
from authentik.outposts.models import Outpost, OutpostType
 | 
			
		||||
from authentik.providers.oauth2.id_token import hash_session_key
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Terminate session on Proxy outpost."))
 | 
			
		||||
def proxy_on_logout(session_id: str):
 | 
			
		||||
    layer = get_channel_layer()
 | 
			
		||||
    hashed_session_id = hash_session_key(session_id)
 | 
			
		||||
    for outpost in Outpost.objects.filter(type=OutpostType.PROXY):
 | 
			
		||||
        group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
 | 
			
		||||
        async_to_sync(layer.group_send)(
 | 
			
		||||
            group,
 | 
			
		||||
            {
 | 
			
		||||
                "type": "event.provider.specific",
 | 
			
		||||
                "sub_type": "logout",
 | 
			
		||||
                "session_id": hashed_session_id,
 | 
			
		||||
            },
 | 
			
		||||
        )
 | 
			
		||||
@ -17,7 +17,6 @@ from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User
 | 
			
		||||
from authentik.events.models import Event, EventAction
 | 
			
		||||
from authentik.lib.models import SerializerModel
 | 
			
		||||
from authentik.lib.utils.time import timedelta_string_validator
 | 
			
		||||
from authentik.outposts.models import OutpostModel
 | 
			
		||||
from authentik.policies.models import PolicyBindingModel
 | 
			
		||||
 | 
			
		||||
LOGGER = get_logger()
 | 
			
		||||
@ -38,7 +37,7 @@ class AuthenticationMode(models.TextChoices):
 | 
			
		||||
    PROMPT = "prompt"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RACProvider(OutpostModel, Provider):
 | 
			
		||||
class RACProvider(Provider):
 | 
			
		||||
    """Remotely access computers/servers via RDP/SSH/VNC."""
 | 
			
		||||
 | 
			
		||||
    settings = models.JSONField(default=dict)
 | 
			
		||||
 | 
			
		||||
@ -44,6 +44,7 @@ class RadiusProviderSerializer(ProviderSerializer):
 | 
			
		||||
            "shared_secret",
 | 
			
		||||
            "outpost_set",
 | 
			
		||||
            "mfa_support",
 | 
			
		||||
            "certificate",
 | 
			
		||||
        ]
 | 
			
		||||
        extra_kwargs = ProviderSerializer.Meta.extra_kwargs
 | 
			
		||||
 | 
			
		||||
@ -79,6 +80,7 @@ class RadiusOutpostConfigSerializer(ModelSerializer):
 | 
			
		||||
            "client_networks",
 | 
			
		||||
            "shared_secret",
 | 
			
		||||
            "mfa_support",
 | 
			
		||||
            "certificate",
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -0,0 +1,25 @@
 | 
			
		||||
# Generated by Django 5.1.9 on 2025-05-16 13:53
 | 
			
		||||
 | 
			
		||||
import django.db.models.deletion
 | 
			
		||||
from django.db import migrations, models
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Migration(migrations.Migration):
 | 
			
		||||
 | 
			
		||||
    dependencies = [
 | 
			
		||||
        ("authentik_crypto", "0004_alter_certificatekeypair_name"),
 | 
			
		||||
        ("authentik_providers_radius", "0004_alter_radiusproviderpropertymapping_options"),
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    operations = [
 | 
			
		||||
        migrations.AddField(
 | 
			
		||||
            model_name="radiusprovider",
 | 
			
		||||
            name="certificate",
 | 
			
		||||
            field=models.ForeignKey(
 | 
			
		||||
                default=None,
 | 
			
		||||
                null=True,
 | 
			
		||||
                on_delete=django.db.models.deletion.CASCADE,
 | 
			
		||||
                to="authentik_crypto.certificatekeypair",
 | 
			
		||||
            ),
 | 
			
		||||
        ),
 | 
			
		||||
    ]
 | 
			
		||||
@ -1,11 +1,14 @@
 | 
			
		||||
"""Radius Provider"""
 | 
			
		||||
 | 
			
		||||
from collections.abc import Iterable
 | 
			
		||||
 | 
			
		||||
from django.db import models
 | 
			
		||||
from django.templatetags.static import static
 | 
			
		||||
from django.utils.translation import gettext_lazy as _
 | 
			
		||||
from rest_framework.serializers import Serializer
 | 
			
		||||
 | 
			
		||||
from authentik.core.models import PropertyMapping, Provider
 | 
			
		||||
from authentik.crypto.models import CertificateKeyPair
 | 
			
		||||
from authentik.lib.generators import generate_id
 | 
			
		||||
from authentik.outposts.models import OutpostModel
 | 
			
		||||
 | 
			
		||||
@ -38,6 +41,10 @@ class RadiusProvider(OutpostModel, Provider):
 | 
			
		||||
        ),
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    certificate = models.ForeignKey(
 | 
			
		||||
        CertificateKeyPair, on_delete=models.CASCADE, default=None, null=True
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def launch_url(self) -> str | None:
 | 
			
		||||
        """Radius never has a launch URL"""
 | 
			
		||||
@ -57,6 +64,12 @@ class RadiusProvider(OutpostModel, Provider):
 | 
			
		||||
 | 
			
		||||
        return RadiusProviderSerializer
 | 
			
		||||
 | 
			
		||||
    def get_required_objects(self) -> Iterable[models.Model | str]:
 | 
			
		||||
        required_models = [self, "authentik_stages_mtls.pass_outpost_certificate"]
 | 
			
		||||
        if self.certificate is not None:
 | 
			
		||||
            required_models.append(self.certificate)
 | 
			
		||||
        return required_models
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        return f"Radius Provider {self.name}"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -44,5 +44,5 @@ class SCIMProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixin, ModelVie
 | 
			
		||||
    filterset_fields = ["name", "exclude_users_service_account", "url", "filter_group"]
 | 
			
		||||
    search_fields = ["name", "url"]
 | 
			
		||||
    ordering = ["name", "url"]
 | 
			
		||||
    sync_task = scim_sync
 | 
			
		||||
    sync_single_task = scim_sync
 | 
			
		||||
    sync_objects_task = scim_sync_objects
 | 
			
		||||
 | 
			
		||||
@ -3,6 +3,7 @@
 | 
			
		||||
from structlog.stdlib import get_logger
 | 
			
		||||
 | 
			
		||||
from authentik.providers.scim.models import SCIMProvider
 | 
			
		||||
from authentik.providers.scim.tasks import scim_sync, sync_tasks
 | 
			
		||||
from authentik.tenants.management import TenantCommand
 | 
			
		||||
 | 
			
		||||
LOGGER = get_logger()
 | 
			
		||||
@ -20,5 +21,4 @@ class Command(TenantCommand):
 | 
			
		||||
            if not provider:
 | 
			
		||||
                LOGGER.warning("Provider does not exist", name=provider_name)
 | 
			
		||||
                continue
 | 
			
		||||
            for schedule in provider.schedules.all():
 | 
			
		||||
                schedule.send().get_result()
 | 
			
		||||
            sync_tasks.trigger_single_task(provider, scim_sync).get()
 | 
			
		||||
 | 
			
		||||
@ -7,7 +7,6 @@ from django.db import models
 | 
			
		||||
from django.db.models import QuerySet
 | 
			
		||||
from django.templatetags.static import static
 | 
			
		||||
from django.utils.translation import gettext_lazy as _
 | 
			
		||||
from dramatiq.actor import Actor
 | 
			
		||||
from rest_framework.serializers import Serializer
 | 
			
		||||
 | 
			
		||||
from authentik.core.models import BackchannelProvider, Group, PropertyMapping, User, UserTypes
 | 
			
		||||
@ -100,12 +99,6 @@ class SCIMProvider(OutgoingSyncProvider, BackchannelProvider):
 | 
			
		||||
    def icon_url(self) -> str | None:
 | 
			
		||||
        return static("authentik/sources/scim.png")
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def sync_actor(self) -> Actor:
 | 
			
		||||
        from authentik.providers.scim.tasks import scim_sync
 | 
			
		||||
 | 
			
		||||
        return scim_sync
 | 
			
		||||
 | 
			
		||||
    def client_for_model(
 | 
			
		||||
        self, model: type[User | Group | SCIMProviderUser | SCIMProviderGroup]
 | 
			
		||||
    ) -> BaseOutgoingSyncClient[User | Group, Any, Any, Self]:
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										13
									
								
								authentik/providers/scim/settings.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								authentik/providers/scim/settings.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,13 @@
 | 
			
		||||
"""SCIM task Settings"""
 | 
			
		||||
 | 
			
		||||
from celery.schedules import crontab
 | 
			
		||||
 | 
			
		||||
from authentik.lib.utils.time import fqdn_rand
 | 
			
		||||
 | 
			
		||||
CELERY_BEAT_SCHEDULE = {
 | 
			
		||||
    "providers_scim_sync": {
 | 
			
		||||
        "task": "authentik.providers.scim.tasks.scim_sync_all",
 | 
			
		||||
        "schedule": crontab(minute=fqdn_rand("scim_sync_all"), hour="*/4"),
 | 
			
		||||
        "options": {"queue": "authentik_scheduled"},
 | 
			
		||||
    },
 | 
			
		||||
}
 | 
			
		||||
@ -2,10 +2,11 @@
 | 
			
		||||
 | 
			
		||||
from authentik.lib.sync.outgoing.signals import register_signals
 | 
			
		||||
from authentik.providers.scim.models import SCIMProvider
 | 
			
		||||
from authentik.providers.scim.tasks import scim_sync_direct_dispatch, scim_sync_m2m_dispatch
 | 
			
		||||
from authentik.providers.scim.tasks import scim_sync, scim_sync_direct, scim_sync_m2m
 | 
			
		||||
 | 
			
		||||
register_signals(
 | 
			
		||||
    SCIMProvider,
 | 
			
		||||
    task_sync_direct_dispatch=scim_sync_direct_dispatch,
 | 
			
		||||
    task_sync_m2m_dispatch=scim_sync_m2m_dispatch,
 | 
			
		||||
    task_sync_single=scim_sync,
 | 
			
		||||
    task_sync_direct=scim_sync_direct,
 | 
			
		||||
    task_sync_m2m=scim_sync_m2m,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -1,40 +1,37 @@
 | 
			
		||||
"""SCIM Provider tasks"""
 | 
			
		||||
 | 
			
		||||
from django.utils.translation import gettext_lazy as _
 | 
			
		||||
from dramatiq.actor import actor
 | 
			
		||||
 | 
			
		||||
from authentik.events.system_tasks import SystemTask
 | 
			
		||||
from authentik.lib.sync.outgoing.exceptions import TransientSyncException
 | 
			
		||||
from authentik.lib.sync.outgoing.tasks import SyncTasks
 | 
			
		||||
from authentik.providers.scim.models import SCIMProvider
 | 
			
		||||
from authentik.root.celery import CELERY_APP
 | 
			
		||||
 | 
			
		||||
sync_tasks = SyncTasks(SCIMProvider)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Sync SCIM provider objects."))
 | 
			
		||||
@CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True)
 | 
			
		||||
def scim_sync_objects(*args, **kwargs):
 | 
			
		||||
    return sync_tasks.sync_objects(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Full sync for SCIM provider."))
 | 
			
		||||
def scim_sync(provider_pk: int, *args, **kwargs):
 | 
			
		||||
@CELERY_APP.task(
 | 
			
		||||
    base=SystemTask, bind=True, autoretry_for=(TransientSyncException,), retry_backoff=True
 | 
			
		||||
)
 | 
			
		||||
def scim_sync(self, provider_pk: int, *args, **kwargs):
 | 
			
		||||
    """Run full sync for SCIM provider"""
 | 
			
		||||
    return sync_tasks.sync(provider_pk, scim_sync_objects)
 | 
			
		||||
    return sync_tasks.sync_single(self, provider_pk, scim_sync_objects)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Sync a direct object (user, group) for SCIM provider."))
 | 
			
		||||
@CELERY_APP.task()
 | 
			
		||||
def scim_sync_all():
 | 
			
		||||
    return sync_tasks.sync_all(scim_sync)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True)
 | 
			
		||||
def scim_sync_direct(*args, **kwargs):
 | 
			
		||||
    return sync_tasks.sync_signal_direct(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Dispatch syncs for a direct object (user, group) for SCIM providers."))
 | 
			
		||||
def scim_sync_direct_dispatch(*args, **kwargs):
 | 
			
		||||
    return sync_tasks.sync_signal_direct_dispatch(scim_sync_direct, *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Sync a related object (memberships) for SCIM provider."))
 | 
			
		||||
@CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True)
 | 
			
		||||
def scim_sync_m2m(*args, **kwargs):
 | 
			
		||||
    return sync_tasks.sync_signal_m2m(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@actor(description=_("Dispatch syncs for a related object (memberships) for SCIM providers."))
 | 
			
		||||
def scim_sync_m2m_dispatch(*args, **kwargs):
 | 
			
		||||
    return sync_tasks.sync_signal_m2m_dispatch(scim_sync_m2m, *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
@ -8,7 +8,7 @@ from authentik.core.models import Application
 | 
			
		||||
from authentik.lib.generators import generate_id
 | 
			
		||||
from authentik.providers.scim.clients.base import SCIMClient
 | 
			
		||||
from authentik.providers.scim.models import SCIMMapping, SCIMProvider
 | 
			
		||||
from authentik.providers.scim.tasks import scim_sync
 | 
			
		||||
from authentik.providers.scim.tasks import scim_sync_all
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SCIMClientTests(TestCase):
 | 
			
		||||
@ -85,6 +85,6 @@ class SCIMClientTests(TestCase):
 | 
			
		||||
            self.assertEqual(mock.call_count, 1)
 | 
			
		||||
            self.assertEqual(mock.request_history[0].method, "GET")
 | 
			
		||||
 | 
			
		||||
    def test_scim_sync(self):
 | 
			
		||||
        """test scim_sync task"""
 | 
			
		||||
        scim_sync.send(self.provider.pk).get_result()
 | 
			
		||||
    def test_scim_sync_all(self):
 | 
			
		||||
        """test scim_sync_all task"""
 | 
			
		||||
        scim_sync_all()
 | 
			
		||||
 | 
			
		||||
@ -8,7 +8,7 @@ from authentik.core.models import Application, Group, User
 | 
			
		||||
from authentik.lib.generators import generate_id
 | 
			
		||||
from authentik.providers.scim.clients.schema import ServiceProviderConfiguration
 | 
			
		||||
from authentik.providers.scim.models import SCIMMapping, SCIMProvider
 | 
			
		||||
from authentik.providers.scim.tasks import scim_sync
 | 
			
		||||
from authentik.providers.scim.tasks import scim_sync, sync_tasks
 | 
			
		||||
from authentik.tenants.models import Tenant
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -79,15 +79,17 @@ class SCIMMembershipTests(TestCase):
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            self.configure()
 | 
			
		||||
            scim_sync.send(self.provider.pk)
 | 
			
		||||
            sync_tasks.trigger_single_task(self.provider, scim_sync).get()
 | 
			
		||||
 | 
			
		||||
            self.assertEqual(mocker.call_count, 4)
 | 
			
		||||
            self.assertEqual(mocker.call_count, 6)
 | 
			
		||||
            self.assertEqual(mocker.request_history[0].method, "GET")
 | 
			
		||||
            self.assertEqual(mocker.request_history[1].method, "POST")
 | 
			
		||||
            self.assertEqual(mocker.request_history[1].method, "GET")
 | 
			
		||||
            self.assertEqual(mocker.request_history[2].method, "GET")
 | 
			
		||||
            self.assertEqual(mocker.request_history[3].method, "POST")
 | 
			
		||||
            self.assertEqual(mocker.request_history[4].method, "GET")
 | 
			
		||||
            self.assertEqual(mocker.request_history[5].method, "POST")
 | 
			
		||||
            self.assertJSONEqual(
 | 
			
		||||
                mocker.request_history[1].body,
 | 
			
		||||
                mocker.request_history[3].body,
 | 
			
		||||
                {
 | 
			
		||||
                    "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
 | 
			
		||||
                    "emails": [],
 | 
			
		||||
@ -99,7 +101,7 @@ class SCIMMembershipTests(TestCase):
 | 
			
		||||
                },
 | 
			
		||||
            )
 | 
			
		||||
            self.assertJSONEqual(
 | 
			
		||||
                mocker.request_history[3].body,
 | 
			
		||||
                mocker.request_history[5].body,
 | 
			
		||||
                {
 | 
			
		||||
                    "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
 | 
			
		||||
                    "externalId": str(group.pk),
 | 
			
		||||
@ -167,15 +169,17 @@ class SCIMMembershipTests(TestCase):
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            self.configure()
 | 
			
		||||
            scim_sync.send(self.provider.pk)
 | 
			
		||||
            sync_tasks.trigger_single_task(self.provider, scim_sync).get()
 | 
			
		||||
 | 
			
		||||
            self.assertEqual(mocker.call_count, 4)
 | 
			
		||||
            self.assertEqual(mocker.call_count, 6)
 | 
			
		||||
            self.assertEqual(mocker.request_history[0].method, "GET")
 | 
			
		||||
            self.assertEqual(mocker.request_history[1].method, "POST")
 | 
			
		||||
            self.assertEqual(mocker.request_history[1].method, "GET")
 | 
			
		||||
            self.assertEqual(mocker.request_history[2].method, "GET")
 | 
			
		||||
            self.assertEqual(mocker.request_history[3].method, "POST")
 | 
			
		||||
            self.assertEqual(mocker.request_history[4].method, "GET")
 | 
			
		||||
            self.assertEqual(mocker.request_history[5].method, "POST")
 | 
			
		||||
            self.assertJSONEqual(
 | 
			
		||||
                mocker.request_history[1].body,
 | 
			
		||||
                mocker.request_history[3].body,
 | 
			
		||||
                {
 | 
			
		||||
                    "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
 | 
			
		||||
                    "active": True,
 | 
			
		||||
@ -187,7 +191,7 @@ class SCIMMembershipTests(TestCase):
 | 
			
		||||
                },
 | 
			
		||||
            )
 | 
			
		||||
            self.assertJSONEqual(
 | 
			
		||||
                mocker.request_history[3].body,
 | 
			
		||||
                mocker.request_history[5].body,
 | 
			
		||||
                {
 | 
			
		||||
                    "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
 | 
			
		||||
                    "externalId": str(group.pk),
 | 
			
		||||
@ -283,15 +287,17 @@ class SCIMMembershipTests(TestCase):
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            self.configure()
 | 
			
		||||
            scim_sync.send(self.provider.pk)
 | 
			
		||||
            sync_tasks.trigger_single_task(self.provider, scim_sync).get()
 | 
			
		||||
 | 
			
		||||
            self.assertEqual(mocker.call_count, 4)
 | 
			
		||||
            self.assertEqual(mocker.call_count, 6)
 | 
			
		||||
            self.assertEqual(mocker.request_history[0].method, "GET")
 | 
			
		||||
            self.assertEqual(mocker.request_history[1].method, "POST")
 | 
			
		||||
            self.assertEqual(mocker.request_history[1].method, "GET")
 | 
			
		||||
            self.assertEqual(mocker.request_history[2].method, "GET")
 | 
			
		||||
            self.assertEqual(mocker.request_history[3].method, "POST")
 | 
			
		||||
            self.assertEqual(mocker.request_history[4].method, "GET")
 | 
			
		||||
            self.assertEqual(mocker.request_history[5].method, "POST")
 | 
			
		||||
            self.assertJSONEqual(
 | 
			
		||||
                mocker.request_history[1].body,
 | 
			
		||||
                mocker.request_history[3].body,
 | 
			
		||||
                {
 | 
			
		||||
                    "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
 | 
			
		||||
                    "emails": [],
 | 
			
		||||
@ -303,7 +309,7 @@ class SCIMMembershipTests(TestCase):
 | 
			
		||||
                },
 | 
			
		||||
            )
 | 
			
		||||
            self.assertJSONEqual(
 | 
			
		||||
                mocker.request_history[3].body,
 | 
			
		||||
                mocker.request_history[5].body,
 | 
			
		||||
                {
 | 
			
		||||
                    "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
 | 
			
		||||
                    "externalId": str(group.pk),
 | 
			
		||||
 | 
			
		||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user