events: migrate SystemTasks to DB (#8159)

* events: migrate system tasks to save in DB

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* prefill in app startup

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* cleanup api

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* update web

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* use string for status

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix enum

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* save start and end directly in timestamp from default_timer()

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* improve metrics

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* lint

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix tests

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* rename globally to system task

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* recreate migrations, better denote anonymous user

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* events: lookup actual django app instead of using module path, fallback to module path

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix logger call

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L
2024-01-24 17:23:03 +01:00
committed by GitHub
parent c0562bf860
commit 96b2a1a9ba
65 changed files with 11564 additions and 12080 deletions

View File

@ -1,134 +0,0 @@
"""Tasks API"""
from importlib import import_module
from django.contrib import messages
from django.http.response import Http404
from django.utils.translation import gettext_lazy as _
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
from rest_framework.decorators import action
from rest_framework.fields import (
CharField,
ChoiceField,
DateTimeField,
ListField,
SerializerMethodField,
)
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.viewsets import ViewSet
from structlog.stdlib import get_logger
from authentik.api.decorators import permission_required
from authentik.core.api.utils import PassiveSerializer
from authentik.events.monitored_tasks import TaskInfo, TaskResultStatus
from authentik.rbac.permissions import HasPermission
LOGGER = get_logger()
class TaskSerializer(PassiveSerializer):
"""Serialize TaskInfo and TaskResult"""
task_name = CharField()
task_description = CharField()
task_finish_timestamp = DateTimeField(source="finish_time")
task_duration = SerializerMethodField()
status = ChoiceField(
source="result.status.name",
choices=[(x.name, x.name) for x in TaskResultStatus],
)
messages = ListField(source="result.messages")
def get_task_duration(self, instance: TaskInfo) -> int:
"""Get the duration a task took to run"""
return max(instance.finish_timestamp - instance.start_timestamp, 0)
def to_representation(self, instance: TaskInfo):
"""When a new version of authentik adds fields to TaskInfo,
the API will fail with an AttributeError, as the classes
are pickled in cache. In that case, just delete the info"""
try:
return super().to_representation(instance)
# pylint: disable=broad-except
except Exception: # pragma: no cover
if isinstance(self.instance, list):
for inst in self.instance:
inst.delete()
else:
self.instance.delete()
return {}
class TaskViewSet(ViewSet):
"""Read-only view set that returns all background tasks"""
permission_classes = [HasPermission("authentik_rbac.view_system_tasks")]
serializer_class = TaskSerializer
@extend_schema(
responses={
200: TaskSerializer(many=False),
404: OpenApiResponse(description="Task not found"),
},
parameters=[
OpenApiParameter(
"id",
type=OpenApiTypes.STR,
location=OpenApiParameter.PATH,
required=True,
),
],
)
def retrieve(self, request: Request, pk=None) -> Response:
"""Get a single system task"""
task = TaskInfo.by_name(pk)
if not task:
raise Http404
return Response(TaskSerializer(task, many=False).data)
@extend_schema(responses={200: TaskSerializer(many=True)})
def list(self, request: Request) -> Response:
"""List system tasks"""
tasks = sorted(TaskInfo.all().values(), key=lambda task: task.task_name)
return Response(TaskSerializer(tasks, many=True).data)
@permission_required(None, ["authentik_rbac.run_system_tasks"])
@extend_schema(
request=OpenApiTypes.NONE,
responses={
204: OpenApiResponse(description="Task retried successfully"),
404: OpenApiResponse(description="Task not found"),
500: OpenApiResponse(description="Failed to retry task"),
},
parameters=[
OpenApiParameter(
"id",
type=OpenApiTypes.STR,
location=OpenApiParameter.PATH,
required=True,
),
],
)
@action(detail=True, methods=["post"])
def retry(self, request: Request, pk=None) -> Response:
"""Retry task"""
task = TaskInfo.by_name(pk)
if not task:
raise Http404
try:
task_module = import_module(task.task_call_module)
task_func = getattr(task_module, task.task_call_func)
LOGGER.debug("Running task", task=task_func)
task_func.delay(*task.task_call_args, **task.task_call_kwargs)
messages.success(
self.request,
_("Successfully re-scheduled Task %(name)s!" % {"name": task.task_name}),
)
return Response(status=204)
except (ImportError, AttributeError): # pragma: no cover
LOGGER.warning("Failed to run task, remove state", task=task)
# if we get an import error, the module path has probably changed
task.delete()
return Response(status=500)

View File

@ -1,7 +1,6 @@
"""admin signals"""
from django.dispatch import receiver
from authentik.admin.api.tasks import TaskInfo
from authentik.admin.apps import GAUGE_WORKERS
from authentik.root.celery import CELERY_APP
from authentik.root.monitoring import monitoring_set
@ -12,10 +11,3 @@ def monitoring_set_workers(sender, **kwargs):
"""Set worker gauge"""
count = len(CELERY_APP.control.ping(timeout=0.5))
GAUGE_WORKERS.set(count)
@receiver(monitoring_set)
def monitoring_set_tasks(sender, **kwargs):
"""Set task gauges"""
for task in TaskInfo.all().values():
task.update_metrics()

View File

@ -11,12 +11,7 @@ from structlog.stdlib import get_logger
from authentik import __version__, get_build_hash
from authentik.admin.apps import PROM_INFO
from authentik.events.models import Event, EventAction, Notification
from authentik.events.monitored_tasks import (
MonitoredTask,
TaskResult,
TaskResultStatus,
prefill_task,
)
from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task
from authentik.lib.config import CONFIG
from authentik.lib.utils.http import get_http_session
from authentik.root.celery import CELERY_APP
@ -54,13 +49,13 @@ def clear_update_notifications():
notification.delete()
@CELERY_APP.task(bind=True, base=MonitoredTask)
@CELERY_APP.task(bind=True, base=SystemTask)
@prefill_task
def update_latest_version(self: MonitoredTask):
def update_latest_version(self: SystemTask):
"""Update latest version info"""
if CONFIG.get_bool("disable_update_check"):
cache.set(VERSION_CACHE_KEY, "0.0.0", VERSION_CACHE_TIMEOUT)
self.set_status(TaskResult(TaskResultStatus.WARNING, messages=["Version check disabled."]))
self.set_status(TaskStatus.WARNING, "Version check disabled.")
return
try:
response = get_http_session().get(
@ -70,9 +65,7 @@ def update_latest_version(self: MonitoredTask):
data = response.json()
upstream_version = data.get("stable", {}).get("version")
cache.set(VERSION_CACHE_KEY, upstream_version, VERSION_CACHE_TIMEOUT)
self.set_status(
TaskResult(TaskResultStatus.SUCCESSFUL, ["Successfully updated latest Version"])
)
self.set_status(TaskStatus.SUCCESSFUL, "Successfully updated latest Version")
_set_prom_info()
# Check if upstream version is newer than what we're running,
# and if no event exists yet, create one.
@ -89,7 +82,7 @@ def update_latest_version(self: MonitoredTask):
Event.new(EventAction.UPDATE_AVAILABLE, **event_dict).save()
except (RequestException, IndexError) as exc:
cache.set(VERSION_CACHE_KEY, "0.0.0", VERSION_CACHE_TIMEOUT)
self.set_status(TaskResult(TaskResultStatus.ERROR).with_error(exc))
self.set_error(exc)
_set_prom_info()

View File

@ -7,8 +7,6 @@ from django.urls import reverse
from authentik import __version__
from authentik.blueprints.tests import reconcile_app
from authentik.core.models import Group, User
from authentik.core.tasks import clean_expired_models
from authentik.events.monitored_tasks import TaskResultStatus
from authentik.lib.generators import generate_id
@ -23,53 +21,6 @@ class TestAdminAPI(TestCase):
self.group.save()
self.client.force_login(self.user)
def test_tasks(self):
"""Test Task API"""
clean_expired_models.delay()
response = self.client.get(reverse("authentik_api:admin_system_tasks-list"))
self.assertEqual(response.status_code, 200)
body = loads(response.content)
self.assertTrue(any(task["task_name"] == "clean_expired_models" for task in body))
def test_tasks_single(self):
"""Test Task API (read single)"""
clean_expired_models.delay()
response = self.client.get(
reverse(
"authentik_api:admin_system_tasks-detail",
kwargs={"pk": "clean_expired_models"},
)
)
self.assertEqual(response.status_code, 200)
body = loads(response.content)
self.assertEqual(body["status"], TaskResultStatus.SUCCESSFUL.name)
self.assertEqual(body["task_name"], "clean_expired_models")
response = self.client.get(
reverse("authentik_api:admin_system_tasks-detail", kwargs={"pk": "qwerqwer"})
)
self.assertEqual(response.status_code, 404)
def test_tasks_retry(self):
"""Test Task API (retry)"""
clean_expired_models.delay()
response = self.client.post(
reverse(
"authentik_api:admin_system_tasks-retry",
kwargs={"pk": "clean_expired_models"},
)
)
self.assertEqual(response.status_code, 204)
def test_tasks_retry_404(self):
"""Test Task API (retry, 404)"""
response = self.client.post(
reverse(
"authentik_api:admin_system_tasks-retry",
kwargs={"pk": "qwerqewrqrqewrqewr"},
)
)
self.assertEqual(response.status_code, 404)
def test_version(self):
"""Test Version API"""
response = self.client.get(reverse("authentik_api:admin_version"))

View File

@ -4,12 +4,10 @@ from django.urls import path
from authentik.admin.api.meta import AppsViewSet, ModelViewSet
from authentik.admin.api.metrics import AdministrationMetricsViewSet
from authentik.admin.api.system import SystemView
from authentik.admin.api.tasks import TaskViewSet
from authentik.admin.api.version import VersionView
from authentik.admin.api.workers import WorkerView
api_urlpatterns = [
("admin/system_tasks", TaskViewSet, "admin_system_tasks"),
("admin/apps", AppsViewSet, "apps"),
("admin/models", ModelViewSet, "models"),
path(

View File

@ -11,14 +11,14 @@ from structlog.stdlib import BoundLogger, get_logger
class ManagedAppConfig(AppConfig):
"""Basic reconciliation logic for apps"""
_logger: BoundLogger
logger: BoundLogger
RECONCILE_GLOBAL_PREFIX: str = "reconcile_global_"
RECONCILE_TENANT_PREFIX: str = "reconcile_tenant_"
def __init__(self, app_name: str, *args, **kwargs) -> None:
super().__init__(app_name, *args, **kwargs)
self._logger = get_logger().bind(app_name=app_name)
self.logger = get_logger().bind(app_name=app_name)
def ready(self) -> None:
self.reconcile_global()
@ -38,11 +38,11 @@ class ManagedAppConfig(AppConfig):
continue
name = meth_name.replace(prefix, "")
try:
self._logger.debug("Starting reconciler", name=name)
self.logger.debug("Starting reconciler", name=name)
meth()
self._logger.debug("Successfully reconciled", name=name)
self.logger.debug("Successfully reconciled", name=name)
except (DatabaseError, ProgrammingError, InternalError) as exc:
self._logger.warning("Failed to run reconcile", name=name, exc=exc)
self.logger.warning("Failed to run reconcile", name=name, exc=exc)
def reconcile_tenant(self) -> None:
"""reconcile ourselves for tenanted methods"""
@ -51,7 +51,7 @@ class ManagedAppConfig(AppConfig):
try:
tenants = list(Tenant.objects.filter(ready=True))
except (DatabaseError, ProgrammingError, InternalError) as exc:
self._logger.debug("Failed to get tenants to run reconcile", exc=exc)
self.logger.debug("Failed to get tenants to run reconcile", exc=exc)
return
for tenant in tenants:
with tenant:

View File

@ -14,6 +14,7 @@ from django.db.models import Model
from django.db.models.query_utils import Q
from django.db.transaction import atomic
from django.db.utils import IntegrityError
from guardian.models import UserObjectPermission
from rest_framework.exceptions import ValidationError
from rest_framework.serializers import BaseSerializer, Serializer
from structlog.stdlib import BoundLogger, get_logger
@ -38,12 +39,16 @@ from authentik.core.models import (
UserSourceConnection,
)
from authentik.enterprise.models import LicenseKey, LicenseUsage
from authentik.enterprise.providers.rac.models import ConnectionToken
from authentik.events.models import SystemTask
from authentik.events.utils import cleanse_dict
from authentik.flows.models import FlowToken, Stage
from authentik.lib.models import SerializerModel
from authentik.lib.sentry import SentryIgnoredException
from authentik.outposts.models import OutpostServiceConnection
from authentik.policies.models import Policy, PolicyBindingModel
from authentik.policies.reputation.models import Reputation
from authentik.providers.oauth2.models import AccessToken, AuthorizationCode, RefreshToken
from authentik.providers.scim.models import SCIMGroup, SCIMUser
from authentik.tenants.models import Tenant
@ -65,6 +70,7 @@ def excluded_models() -> list[type[Model]]:
DjangoGroup,
ContentType,
Permission,
UserObjectPermission,
# Base classes
Provider,
Source,
@ -82,6 +88,12 @@ def excluded_models() -> list[type[Model]]:
SCIMGroup,
SCIMUser,
Tenant,
SystemTask,
ConnectionToken,
AuthorizationCode,
AccessToken,
RefreshToken,
Reputation,
)

View File

@ -29,12 +29,8 @@ from authentik.blueprints.v1.common import BlueprintLoader, BlueprintMetadata, E
from authentik.blueprints.v1.importer import Importer
from authentik.blueprints.v1.labels import LABEL_AUTHENTIK_INSTANTIATE
from authentik.blueprints.v1.oci import OCI_PREFIX
from authentik.events.monitored_tasks import (
MonitoredTask,
TaskResult,
TaskResultStatus,
prefill_task,
)
from authentik.events.models import TaskStatus
from authentik.events.system_tasks import SystemTask, prefill_task
from authentik.events.utils import sanitize_dict
from authentik.lib.config import CONFIG
from authentik.root.celery import CELERY_APP
@ -134,10 +130,10 @@ def blueprints_find() -> list[BlueprintFile]:
@CELERY_APP.task(
throws=(DatabaseError, ProgrammingError, InternalError), base=MonitoredTask, bind=True
throws=(DatabaseError, ProgrammingError, InternalError), base=SystemTask, bind=True
)
@prefill_task
def blueprints_discovery(self: MonitoredTask, path: Optional[str] = None):
def blueprints_discovery(self: SystemTask, path: Optional[str] = None):
"""Find blueprints and check if they need to be created in the database"""
count = 0
for blueprint in blueprints_find():
@ -146,10 +142,7 @@ def blueprints_discovery(self: MonitoredTask, path: Optional[str] = None):
check_blueprint_v1_file(blueprint)
count += 1
self.set_status(
TaskResult(
TaskResultStatus.SUCCESSFUL,
messages=[_("Successfully imported %(count)d files." % {"count": count})],
)
TaskStatus.SUCCESSFUL, _("Successfully imported %(count)d files." % {"count": count})
)
@ -182,9 +175,9 @@ def check_blueprint_v1_file(blueprint: BlueprintFile):
@CELERY_APP.task(
bind=True,
base=MonitoredTask,
base=SystemTask,
)
def apply_blueprint(self: MonitoredTask, instance_pk: str):
def apply_blueprint(self: SystemTask, instance_pk: str):
"""Apply single blueprint"""
self.save_on_success = False
instance: Optional[BlueprintInstance] = None
@ -202,18 +195,18 @@ def apply_blueprint(self: MonitoredTask, instance_pk: str):
if not valid:
instance.status = BlueprintInstanceStatus.ERROR
instance.save()
self.set_status(TaskResult(TaskResultStatus.ERROR, [x["event"] for x in logs]))
self.set_status(TaskStatus.ERROR, *[x["event"] for x in logs])
return
applied = importer.apply()
if not applied:
instance.status = BlueprintInstanceStatus.ERROR
instance.save()
self.set_status(TaskResult(TaskResultStatus.ERROR, "Failed to apply"))
self.set_status(TaskStatus.ERROR, "Failed to apply")
return
instance.status = BlueprintInstanceStatus.SUCCESSFUL
instance.last_applied_hash = file_hash
instance.last_applied = now()
self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL))
self.set_status(TaskStatus.SUCCESSFUL)
except (
DatabaseError,
ProgrammingError,
@ -224,7 +217,7 @@ def apply_blueprint(self: MonitoredTask, instance_pk: str):
) as exc:
if instance:
instance.status = BlueprintInstanceStatus.ERROR
self.set_status(TaskResult(TaskResultStatus.ERROR).with_error(exc))
self.set_error(exc)
finally:
if instance:
instance.save()

View File

@ -31,7 +31,7 @@ class UsedBySerializer(PassiveSerializer):
model_name = CharField()
pk = CharField()
name = CharField()
action = ChoiceField(choices=[(x.name, x.name) for x in DeleteAction])
action = ChoiceField(choices=[(x.value, x.name) for x in DeleteAction])
def get_delete_action(manager: Manager) -> str:

View File

@ -13,20 +13,15 @@ from authentik.core.models import (
ExpiringModel,
User,
)
from authentik.events.monitored_tasks import (
MonitoredTask,
TaskResult,
TaskResultStatus,
prefill_task,
)
from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task
from authentik.root.celery import CELERY_APP
LOGGER = get_logger()
@CELERY_APP.task(bind=True, base=MonitoredTask)
@CELERY_APP.task(bind=True, base=SystemTask)
@prefill_task
def clean_expired_models(self: MonitoredTask):
def clean_expired_models(self: SystemTask):
"""Remove expired objects"""
messages = []
for cls in ExpiringModel.__subclasses__():
@ -54,12 +49,12 @@ def clean_expired_models(self: MonitoredTask):
amount += 1
LOGGER.debug("Expired sessions", model=AuthenticatedSession, amount=amount)
messages.append(f"Expired {amount} {AuthenticatedSession._meta.verbose_name_plural}")
self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL, messages))
self.set_status(TaskStatus.SUCCESSFUL, *messages)
@CELERY_APP.task(bind=True, base=MonitoredTask)
@CELERY_APP.task(bind=True, base=SystemTask)
@prefill_task
def clean_temporary_users(self: MonitoredTask):
def clean_temporary_users(self: SystemTask):
"""Remove temporary users created by SAML Sources"""
_now = datetime.now()
messages = []
@ -75,4 +70,4 @@ def clean_temporary_users(self: MonitoredTask):
user.delete()
deleted_users += 1
messages.append(f"Successfully deleted {deleted_users} users.")
self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL, messages))
self.set_status(TaskStatus.SUCCESSFUL, *messages)

View File

@ -9,12 +9,8 @@ from django.utils.translation import gettext_lazy as _
from structlog.stdlib import get_logger
from authentik.crypto.models import CertificateKeyPair
from authentik.events.monitored_tasks import (
MonitoredTask,
TaskResult,
TaskResultStatus,
prefill_task,
)
from authentik.events.models import TaskStatus
from authentik.events.system_tasks import SystemTask, prefill_task
from authentik.lib.config import CONFIG
from authentik.root.celery import CELERY_APP
@ -39,9 +35,9 @@ def ensure_certificate_valid(body: str):
return body
@CELERY_APP.task(bind=True, base=MonitoredTask)
@CELERY_APP.task(bind=True, base=SystemTask)
@prefill_task
def certificate_discovery(self: MonitoredTask):
def certificate_discovery(self: SystemTask):
"""Discover, import and update certificates from the filesystem"""
certs = {}
private_keys = {}
@ -88,8 +84,5 @@ def certificate_discovery(self: MonitoredTask):
if dirty:
cert.save()
self.set_status(
TaskResult(
TaskResultStatus.SUCCESSFUL,
messages=[_("Successfully imported %(count)d files." % {"count": discovered})],
)
TaskStatus.SUCCESSFUL, _("Successfully imported %(count)d files." % {"count": discovered})
)

View File

@ -0,0 +1,107 @@
"""Tasks API"""
from datetime import datetime, timezone
from importlib import import_module
from django.contrib import messages
from django.utils.translation import gettext_lazy as _
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiResponse, extend_schema
from rest_framework.decorators import action
from rest_framework.fields import CharField, ChoiceField, ListField, SerializerMethodField
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.serializers import ModelSerializer
from rest_framework.viewsets import ReadOnlyModelViewSet
from structlog.stdlib import get_logger
from authentik.api.decorators import permission_required
from authentik.events.models import SystemTask, TaskStatus
LOGGER = get_logger()
class SystemTaskSerializer(ModelSerializer):
"""Serialize TaskInfo and TaskResult"""
name = CharField()
full_name = SerializerMethodField()
uid = CharField(required=False)
description = CharField()
start_timestamp = SerializerMethodField()
finish_timestamp = SerializerMethodField()
duration = SerializerMethodField()
status = ChoiceField(choices=[(x.value, x.name) for x in TaskStatus])
messages = ListField(child=CharField())
def get_full_name(self, instance: SystemTask) -> str:
"""Get full name with UID"""
if instance.uid:
return f"{instance.name}:{instance.uid}"
return instance.name
def get_start_timestamp(self, instance: SystemTask) -> datetime:
"""Timestamp when the task started"""
return datetime.fromtimestamp(instance.start_timestamp, tz=timezone.utc)
def get_finish_timestamp(self, instance: SystemTask) -> datetime:
"""Timestamp when the task finished"""
return datetime.fromtimestamp(instance.finish_timestamp, tz=timezone.utc)
def get_duration(self, instance: SystemTask) -> float:
"""Get the duration a task took to run"""
return max(instance.finish_timestamp - instance.start_timestamp, 0)
class Meta:
model = SystemTask
fields = [
"uuid",
"name",
"full_name",
"uid",
"description",
"start_timestamp",
"finish_timestamp",
"duration",
"status",
"messages",
]
class SystemTaskViewSet(ReadOnlyModelViewSet):
"""Read-only view set that returns all background tasks"""
queryset = SystemTask.objects.all()
serializer_class = SystemTaskSerializer
filterset_fields = ["name", "uid", "status"]
ordering = ["name", "uid", "status"]
search_fields = ["name", "description", "uid", "status"]
@permission_required(None, ["authentik_events.run_task"])
@extend_schema(
request=OpenApiTypes.NONE,
responses={
204: OpenApiResponse(description="Task retried successfully"),
404: OpenApiResponse(description="Task not found"),
500: OpenApiResponse(description="Failed to retry task"),
},
)
@action(detail=True, methods=["post"])
def run(self, request: Request, pk=None) -> Response:
"""Run task"""
task: SystemTask = self.get_object()
try:
task_module = import_module(task.task_call_module)
task_func = getattr(task_module, task.task_call_func)
LOGGER.info("Running task", task=task_func)
task_func.delay(*task.task_call_args, **task.task_call_kwargs)
messages.success(
self.request,
_("Successfully started task %(name)s." % {"name": task.name}),
)
return Response(status=204)
except (ImportError, AttributeError) as exc: # pragma: no cover
LOGGER.warning("Failed to run task, remove state", task=task.name, exc=exc)
# if we get an import error, the module path has probably changed
task.delete()
return Response(status=500)

View File

@ -1,15 +1,26 @@
"""authentik events app"""
from prometheus_client import Gauge
from prometheus_client import Gauge, Histogram
from authentik.blueprints.apps import ManagedAppConfig
from authentik.lib.config import CONFIG, ENV_PREFIX
# TODO: Deprecated metric - remove in 2024.2 or later
GAUGE_TASKS = Gauge(
"authentik_system_tasks",
"System tasks and their status",
["tenant", "task_name", "task_uid", "status"],
)
SYSTEM_TASK_TIME = Histogram(
"authentik_system_tasks_time_seconds",
"Runtime of system tasks",
)
SYSTEM_TASK_STATUS = Gauge(
"authentik_system_tasks_status",
"System task status",
["task_name", "task_uid", "status"],
)
class AuthentikEventsConfig(ManagedAppConfig):
"""authentik events app"""
@ -43,3 +54,14 @@ class AuthentikEventsConfig(ManagedAppConfig):
replacement_env=replace_env,
message=msg,
).save()
def reconcile_prefill_tasks(self):
"""Prefill tasks"""
from authentik.events.models import SystemTask
from authentik.events.system_tasks import _prefill_tasks
for task in _prefill_tasks:
if SystemTask.objects.filter(name=task.name).exists():
continue
task.save()
self.logger.debug("prefilled task", task_name=task.name)

View File

@ -9,19 +9,14 @@ from django.core.exceptions import SuspiciousOperation
from django.db.models import Model
from django.db.models.signals import m2m_changed, post_save, pre_delete
from django.http import HttpRequest, HttpResponse
from guardian.models import UserObjectPermission
from structlog.stdlib import BoundLogger, get_logger
from authentik.blueprints.v1.importer import excluded_models
from authentik.core.models import Group, User
from authentik.enterprise.providers.rac.models import ConnectionToken
from authentik.events.models import Event, EventAction, Notification
from authentik.events.utils import model_to_dict
from authentik.lib.sentry import before_send
from authentik.lib.utils.errors import exception_to_string
from authentik.policies.reputation.models import Reputation
from authentik.providers.oauth2.models import AccessToken, AuthorizationCode, RefreshToken
from authentik.providers.scim.models import SCIMGroup, SCIMUser
from authentik.stages.authenticator_static.models import StaticToken
IGNORED_MODELS = tuple(
@ -29,16 +24,8 @@ IGNORED_MODELS = tuple(
+ (
Event,
Notification,
UserObjectPermission,
StaticToken,
Session,
AuthorizationCode,
AccessToken,
RefreshToken,
SCIMUser,
SCIMGroup,
Reputation,
ConnectionToken,
)
)

View File

@ -0,0 +1,60 @@
# Generated by Django 5.0.1 on 2024-01-24 12:48
import uuid
from django.db import migrations, models
import authentik.core.models
class Migration(migrations.Migration):
dependencies = [
("authentik_events", "0003_rename_tenant_event_brand"),
]
operations = [
migrations.CreateModel(
name="SystemTask",
fields=[
(
"expires",
models.DateTimeField(default=authentik.core.models.default_token_duration),
),
("expiring", models.BooleanField(default=True)),
(
"uuid",
models.UUIDField(
default=uuid.uuid4, editable=False, primary_key=True, serialize=False
),
),
("name", models.TextField()),
("uid", models.TextField(null=True)),
("start_timestamp", models.FloatField()),
("finish_timestamp", models.FloatField()),
(
"status",
models.TextField(
choices=[
("unknown", "Unknown"),
("successful", "Successful"),
("warning", "Warning"),
("error", "Error"),
]
),
),
("description", models.TextField(null=True)),
("messages", models.JSONField()),
("task_call_module", models.TextField()),
("task_call_func", models.TextField()),
("task_call_args", models.JSONField(default=list)),
("task_call_kwargs", models.JSONField(default=dict)),
],
options={
"verbose_name": "System Task",
"verbose_name_plural": "System Tasks",
"permissions": [("run_task", "Run task")],
"default_permissions": ["view"],
"unique_together": {("name", "uid")},
},
),
]

View File

@ -2,11 +2,14 @@
import time
from collections import Counter
from datetime import timedelta
from difflib import get_close_matches
from functools import lru_cache
from inspect import currentframe
from smtplib import SMTPException
from typing import TYPE_CHECKING, Optional
from typing import Optional
from uuid import uuid4
from django.apps import apps
from django.db import models
from django.db.models import Count, ExpressionWrapper, F
from django.db.models.fields import DurationField
@ -18,6 +21,7 @@ from django.http.request import QueryDict
from django.utils.timezone import now
from django.utils.translation import gettext as _
from requests import RequestException
from rest_framework.serializers import Serializer
from structlog.stdlib import get_logger
from authentik import get_full_version
@ -28,6 +32,7 @@ from authentik.core.middleware import (
SESSION_KEY_IMPERSONATE_USER,
)
from authentik.core.models import ExpiringModel, Group, PropertyMapping, User
from authentik.events.apps import GAUGE_TASKS, SYSTEM_TASK_STATUS, SYSTEM_TASK_TIME
from authentik.events.context_processors.base import get_context_processors
from authentik.events.utils import (
cleanse_dict,
@ -46,8 +51,6 @@ from authentik.stages.email.utils import TemplateEmailMessage
from authentik.tenants.models import Tenant
LOGGER = get_logger()
if TYPE_CHECKING:
from rest_framework.serializers import Serializer
def default_event_duration():
@ -61,6 +64,12 @@ def default_brand():
return sanitize_dict(model_to_dict(DEFAULT_BRAND))
@lru_cache()
def django_app_names() -> list[str]:
"""Get a cached list of all django apps' names (not labels)"""
return [x.name for x in apps.app_configs.values()]
class NotificationTransportError(SentryIgnoredException):
"""Error raised when a notification fails to be delivered"""
@ -198,6 +207,11 @@ class Event(SerializerModel, ExpiringModel):
current = currentframe()
parent = current.f_back
app = parent.f_globals["__name__"]
# Attempt to match the calling module to the django app it belongs to
# if we can't find a match, keep the module name
django_apps = get_close_matches(app, django_app_names(), n=1)
if len(django_apps) > 0:
app = django_apps[0]
cleaned_kwargs = cleanse_dict(sanitize_dict(kwargs))
event = Event(action=action, app=app, context=cleaned_kwargs)
return event
@ -270,7 +284,7 @@ class Event(SerializerModel, ExpiringModel):
super().save(*args, **kwargs)
@property
def serializer(self) -> "Serializer":
def serializer(self) -> type[Serializer]:
from authentik.events.api.events import EventSerializer
return EventSerializer
@ -478,7 +492,7 @@ class NotificationTransport(SerializerModel):
raise NotificationTransportError(exc) from exc
@property
def serializer(self) -> "Serializer":
def serializer(self) -> type[Serializer]:
from authentik.events.api.notification_transports import NotificationTransportSerializer
return NotificationTransportSerializer
@ -511,7 +525,7 @@ class Notification(SerializerModel):
user = models.ForeignKey(User, on_delete=models.CASCADE)
@property
def serializer(self) -> "Serializer":
def serializer(self) -> type[Serializer]:
from authentik.events.api.notifications import NotificationSerializer
return NotificationSerializer
@ -554,7 +568,7 @@ class NotificationRule(SerializerModel, PolicyBindingModel):
)
@property
def serializer(self) -> "Serializer":
def serializer(self) -> type[Serializer]:
from authentik.events.api.notification_rules import NotificationRuleSerializer
return NotificationRuleSerializer
@ -575,7 +589,7 @@ class NotificationWebhookMapping(PropertyMapping):
return "ak-property-mapping-notification-form"
@property
def serializer(self) -> type["Serializer"]:
def serializer(self) -> type[type[Serializer]]:
from authentik.events.api.notification_mappings import NotificationWebhookMappingSerializer
return NotificationWebhookMappingSerializer
@ -586,3 +600,66 @@ class NotificationWebhookMapping(PropertyMapping):
class Meta:
verbose_name = _("Webhook Mapping")
verbose_name_plural = _("Webhook Mappings")
class TaskStatus(models.TextChoices):
"""Possible states of tasks"""
UNKNOWN = "unknown"
SUCCESSFUL = "successful"
WARNING = "warning"
ERROR = "error"
class SystemTask(SerializerModel, ExpiringModel):
"""Info about a system task running in the background along with details to restart the task"""
uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
name = models.TextField()
uid = models.TextField(null=True)
start_timestamp = models.FloatField()
finish_timestamp = models.FloatField()
status = models.TextField(choices=TaskStatus.choices)
description = models.TextField(null=True)
messages = models.JSONField()
task_call_module = models.TextField()
task_call_func = models.TextField()
task_call_args = models.JSONField(default=list)
task_call_kwargs = models.JSONField(default=dict)
@property
def serializer(self) -> type[Serializer]:
from authentik.events.api.tasks import SystemTaskSerializer
return SystemTaskSerializer
def update_metrics(self):
"""Update prometheus metrics"""
duration = max(self.finish_timestamp - self.start_timestamp, 0)
# TODO: Deprecated metric - remove in 2024.2 or later
GAUGE_TASKS.labels(
task_name=self.name,
task_uid=self.uid or "",
status=self.status.lower(),
).set(duration)
SYSTEM_TASK_TIME.observe(duration)
SYSTEM_TASK_STATUS.labels(
task_name=self.name,
task_uid=self.uid or "",
status=self.status.lower(),
).inc()
def __str__(self) -> str:
return f"System Task {self.name}"
class Meta:
unique_together = (("name", "uid"),)
# Remove "add", "change" and "delete" permissions as those are not used
default_permissions = ["view"]
permissions = [("run_task", _("Run task"))]
verbose_name = _("System Task")
verbose_name_plural = _("System Tasks")

View File

@ -1,216 +0,0 @@
"""Monitored tasks"""
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from timeit import default_timer
from typing import Any, Optional
from django.core.cache import cache
from django.db import connection
from django.utils.translation import gettext_lazy as _
from structlog.stdlib import get_logger
from tenant_schemas_celery.task import TenantTask
from authentik.events.apps import GAUGE_TASKS
from authentik.events.models import Event, EventAction
from authentik.lib.utils.errors import exception_to_string
LOGGER = get_logger()
CACHE_KEY_PREFIX = "goauthentik.io/events/tasks/"
class TaskResultStatus(Enum):
"""Possible states of tasks"""
SUCCESSFUL = 1
WARNING = 2
ERROR = 4
UNKNOWN = 8
@dataclass
class TaskResult:
"""Result of a task run, this class is created by the task itself
and used by self.set_status"""
status: TaskResultStatus
messages: list[str] = field(default_factory=list)
# Optional UID used in cache for tasks that run in different instances
uid: Optional[str] = field(default=None)
def with_error(self, exc: Exception) -> "TaskResult":
"""Since errors might not always be pickle-able, set the traceback"""
# TODO: Mark exception somehow so that is rendered as <pre> in frontend
self.messages.append(exception_to_string(exc))
return self
@dataclass
class TaskInfo:
"""Info about a task run"""
task_name: str
start_timestamp: float
finish_timestamp: float
finish_time: datetime
result: TaskResult
task_call_module: str
task_call_func: str
task_call_args: list[Any] = field(default_factory=list)
task_call_kwargs: dict[str, Any] = field(default_factory=dict)
task_description: Optional[str] = field(default=None)
@staticmethod
def all() -> dict[str, "TaskInfo"]:
"""Get all TaskInfo objects"""
return cache.get_many(cache.keys(CACHE_KEY_PREFIX + "*"))
@staticmethod
def by_name(name: str) -> Optional["TaskInfo"] | Optional[list["TaskInfo"]]:
"""Get TaskInfo Object by name"""
if "*" in name:
return cache.get_many(cache.keys(CACHE_KEY_PREFIX + name)).values()
return cache.get(CACHE_KEY_PREFIX + name, None)
@property
def full_name(self) -> str:
"""Get the full cache key with task name and UID"""
key = CACHE_KEY_PREFIX + self.task_name
if self.result.uid:
uid_suffix = f":{self.result.uid}"
key += uid_suffix
if not self.task_name.endswith(uid_suffix):
self.task_name += uid_suffix
return key
def delete(self):
"""Delete task info from cache"""
return cache.delete(self.full_name)
def update_metrics(self):
"""Update prometheus metrics"""
start = default_timer()
if hasattr(self, "start_timestamp"):
start = self.start_timestamp
try:
duration = max(self.finish_timestamp - start, 0)
except TypeError:
duration = 0
GAUGE_TASKS.labels(
tenant=connection.schema_name,
task_name=self.task_name.split(":")[0],
task_uid=self.result.uid or "",
status=self.result.status.name.lower(),
).set(duration)
def save(self, timeout_hours=6):
"""Save task into cache"""
self.update_metrics()
cache.set(self.full_name, self, timeout=timeout_hours * 60 * 60)
class MonitoredTask(TenantTask):
"""Task which can save its state to the cache"""
# For tasks that should only be listed if they failed, set this to False
save_on_success: bool
_result: Optional[TaskResult]
_uid: Optional[str]
start: Optional[float] = None
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.save_on_success = True
self._uid = None
self._result = None
self.result_timeout_hours = 6
def set_uid(self, uid: str):
"""Set UID, so in the case of an unexpected error its saved correctly"""
self._uid = uid
def set_status(self, result: TaskResult):
"""Set result for current run, will overwrite previous result."""
self._result = result
def before_start(self, task_id, args, kwargs):
self.start = default_timer()
return super().before_start(task_id, args, kwargs)
# pylint: disable=too-many-arguments
def after_return(self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo):
super().after_return(status, retval, task_id, args, kwargs, einfo=einfo)
if not self._result:
return
if not self._result.uid:
self._result.uid = self._uid
info = TaskInfo(
task_name=self.__name__,
task_description=self.__doc__,
start_timestamp=self.start or default_timer(),
finish_timestamp=default_timer(),
finish_time=datetime.now(),
result=self._result,
task_call_module=self.__module__,
task_call_func=self.__name__,
task_call_args=args,
task_call_kwargs=kwargs,
)
if self._result.status == TaskResultStatus.SUCCESSFUL and not self.save_on_success:
info.delete()
return
info.save(self.result_timeout_hours)
# pylint: disable=too-many-arguments
def on_failure(self, exc, task_id, args, kwargs, einfo):
super().on_failure(exc, task_id, args, kwargs, einfo=einfo)
if not self._result:
self._result = TaskResult(status=TaskResultStatus.ERROR, messages=[str(exc)])
if not self._result.uid:
self._result.uid = self._uid
TaskInfo(
task_name=self.__name__,
task_description=self.__doc__,
start_timestamp=self.start or default_timer(),
finish_timestamp=default_timer(),
finish_time=datetime.now(),
result=self._result,
task_call_module=self.__module__,
task_call_func=self.__name__,
task_call_args=args,
task_call_kwargs=kwargs,
).save(self.result_timeout_hours)
Event.new(
EventAction.SYSTEM_TASK_EXCEPTION,
message=f"Task {self.__name__} encountered an error: {exception_to_string(exc)}",
).save()
def run(self, *args, **kwargs):
raise NotImplementedError
def prefill_task(func):
"""Ensure a task's details are always in cache, so it can always be triggered via API"""
status = TaskInfo.by_name(func.__name__)
if status:
return func
TaskInfo(
task_name=func.__name__,
task_description=func.__doc__,
result=TaskResult(TaskResultStatus.UNKNOWN, messages=[_("Task has not been run yet.")]),
task_call_module=func.__module__,
task_call_func=func.__name__,
# We don't have real values for these attributes but they cannot be null
start_timestamp=0,
finish_timestamp=0,
finish_time=datetime.now(),
).save(86400)
LOGGER.debug("prefilled task", task_name=func.__name__)
return func

View File

@ -8,11 +8,13 @@ from django.http import HttpRequest
from authentik.core.models import User
from authentik.core.signals import login_failed, password_changed
from authentik.events.models import Event, EventAction
from authentik.events.apps import SYSTEM_TASK_STATUS
from authentik.events.models import Event, EventAction, SystemTask
from authentik.events.tasks import event_notification_handler, gdpr_cleanup
from authentik.flows.models import Stage
from authentik.flows.planner import PLAN_CONTEXT_SOURCE, FlowPlan
from authentik.flows.views.executor import SESSION_KEY_PLAN
from authentik.root.monitoring import monitoring_set
from authentik.stages.invitation.models import Invitation
from authentik.stages.invitation.signals import invitation_used
from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_METHOD_ARGS
@ -100,3 +102,11 @@ def event_user_pre_delete_cleanup(sender, instance: User, **_):
"""If gdpr_compliance is enabled, remove all the user's events"""
if get_current_tenant().gdpr_compliance:
gdpr_cleanup.delay(instance.pk)
@receiver(monitoring_set)
def monitoring_system_task(sender, **_):
"""Update metrics when task is saved"""
SYSTEM_TASK_STATUS.clear()
for task in SystemTask.objects.all():
task.update_metrics()

View File

@ -0,0 +1,135 @@
"""Monitored tasks"""
from datetime import timedelta
from timeit import default_timer
from typing import Any, Optional
from django.utils.timezone import now
from django.utils.translation import gettext_lazy as _
from structlog.stdlib import get_logger
from tenant_schemas_celery.task import TenantTask
from authentik.events.models import Event, EventAction
from authentik.events.models import SystemTask as DBSystemTask
from authentik.events.models import TaskStatus
from authentik.events.utils import sanitize_item
from authentik.lib.utils.errors import exception_to_string
LOGGER = get_logger()
class SystemTask(TenantTask):
"""Task which can save its state to the cache"""
# For tasks that should only be listed if they failed, set this to False
save_on_success: bool
_status: Optional[TaskStatus]
_messages: list[str]
_uid: Optional[str]
_start: Optional[float] = None
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.save_on_success = True
self._uid = None
self._status = None
self._messages = []
self.result_timeout_hours = 6
def set_uid(self, uid: str):
"""Set UID, so in the case of an unexpected error its saved correctly"""
self._uid = uid
def set_status(self, status: TaskStatus, *messages: str):
"""Set result for current run, will overwrite previous result."""
self._status = status
self._messages = messages
def set_error(self, exception: Exception):
"""Set result to error and save exception"""
self._status = TaskStatus.ERROR
self._messages = [exception_to_string(exception)]
def before_start(self, task_id, args, kwargs):
self._start = default_timer()
return super().before_start(task_id, args, kwargs)
# pylint: disable=too-many-arguments
def after_return(self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo):
super().after_return(status, retval, task_id, args, kwargs, einfo=einfo)
if not self._status:
return
if self._status == TaskStatus.SUCCESSFUL and not self.save_on_success:
DBSystemTask.objects.filter(
name=self.__name__,
uid=self._uid,
).delete()
return
DBSystemTask.objects.update_or_create(
name=self.__name__,
uid=self._uid,
defaults={
"description": self.__doc__,
"start_timestamp": self._start or default_timer(),
"finish_timestamp": default_timer(),
"task_call_module": self.__module__,
"task_call_func": self.__name__,
"task_call_args": args,
"task_call_kwargs": kwargs,
"status": self._status,
"messages": sanitize_item(self._messages),
"expires": now() + timedelta(hours=self.result_timeout_hours),
"expiring": True,
},
)
# pylint: disable=too-many-arguments
def on_failure(self, exc, task_id, args, kwargs, einfo):
super().on_failure(exc, task_id, args, kwargs, einfo=einfo)
if not self._status:
self._status = TaskStatus.ERROR
self._messages = exception_to_string(exc)
DBSystemTask.objects.update_or_create(
name=self.__name__,
uid=self._uid,
defaults={
"description": self.__doc__,
"start_timestamp": self._start or default_timer(),
"finish_timestamp": default_timer(),
"task_call_module": self.__module__,
"task_call_func": self.__name__,
"task_call_args": args,
"task_call_kwargs": kwargs,
"status": self._status,
"messages": sanitize_item(self._messages),
"expires": now() + timedelta(hours=self.result_timeout_hours),
"expiring": True,
},
)
Event.new(
EventAction.SYSTEM_TASK_EXCEPTION,
message=f"Task {self.__name__} encountered an error: {exception_to_string(exc)}",
).save()
def run(self, *args, **kwargs):
raise NotImplementedError
def prefill_task(func):
"""Ensure a task's details are always in cache, so it can always be triggered via API"""
_prefill_tasks.append(
DBSystemTask(
name=func.__name__,
description=func.__doc__,
status=TaskStatus.UNKNOWN,
messages=sanitize_item([_("Task has not been run yet.")]),
task_call_module=func.__module__,
task_call_func=func.__name__,
expiring=False,
)
)
return func
_prefill_tasks = []

View File

@ -13,13 +13,9 @@ from authentik.events.models import (
NotificationRule,
NotificationTransport,
NotificationTransportError,
TaskStatus,
)
from authentik.events.monitored_tasks import (
MonitoredTask,
TaskResult,
TaskResultStatus,
prefill_task,
)
from authentik.events.system_tasks import SystemTask, prefill_task
from authentik.policies.engine import PolicyEngine
from authentik.policies.models import PolicyBinding, PolicyEngineMode
from authentik.root.celery import CELERY_APP
@ -99,10 +95,10 @@ def event_trigger_handler(event_uuid: str, trigger_name: str):
bind=True,
autoretry_for=(NotificationTransportError,),
retry_backoff=True,
base=MonitoredTask,
base=SystemTask,
)
def notification_transport(
self: MonitoredTask, transport_pk: int, event_pk: str, user_pk: int, trigger_pk: str
self: SystemTask, transport_pk: int, event_pk: str, user_pk: int, trigger_pk: str
):
"""Send notification over specified transport"""
self.save_on_success = False
@ -123,9 +119,9 @@ def notification_transport(
if not transport:
return
transport.send(notification)
self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL))
self.set_status(TaskStatus.SUCCESSFUL)
except (NotificationTransportError, PropertyMappingExpressionException) as exc:
self.set_status(TaskResult(TaskResultStatus.ERROR).with_error(exc))
self.set_error(exc)
raise exc
@ -137,13 +133,13 @@ def gdpr_cleanup(user_pk: int):
events.delete()
@CELERY_APP.task(bind=True, base=MonitoredTask)
@CELERY_APP.task(bind=True, base=SystemTask)
@prefill_task
def notification_cleanup(self: MonitoredTask):
def notification_cleanup(self: SystemTask):
"""Cleanup seen notifications and notifications whose event expired."""
notifications = Notification.objects.filter(Q(event=None) | Q(seen=True))
amount = notifications.count()
for notification in notifications:
notification.delete()
LOGGER.debug("Expired notifications", amount=amount)
self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL, [f"Expired {amount} Notifications"]))
self.set_status(TaskStatus.SUCCESSFUL, f"Expired {amount} Notifications")

View File

@ -1,14 +1,26 @@
"""Test Monitored tasks"""
from django.test import TestCase
from json import loads
from authentik.events.monitored_tasks import MonitoredTask, TaskInfo, TaskResult, TaskResultStatus
from django.urls import reverse
from rest_framework.test import APITestCase
from authentik.core.tasks import clean_expired_models
from authentik.core.tests.utils import create_test_admin_user
from authentik.events.models import SystemTask as DBSystemTask
from authentik.events.models import TaskStatus
from authentik.events.system_tasks import SystemTask
from authentik.lib.generators import generate_id
from authentik.root.celery import CELERY_APP
class TestMonitoredTasks(TestCase):
class TestSystemTasks(APITestCase):
"""Test Monitored tasks"""
def setUp(self):
super().setUp()
self.user = create_test_admin_user()
self.client.force_login(self.user)
def test_failed_successful_remove_state(self):
"""Test that a task with `save_on_success` set to `False` that failed saves
a state, and upon successful completion will delete the state"""
@ -17,27 +29,74 @@ class TestMonitoredTasks(TestCase):
@CELERY_APP.task(
bind=True,
base=MonitoredTask,
base=SystemTask,
)
def test_task(self: MonitoredTask):
def test_task(self: SystemTask):
self.save_on_success = False
self.set_uid(uid)
self.set_status(
TaskResult(TaskResultStatus.ERROR if should_fail else TaskResultStatus.SUCCESSFUL)
)
self.set_status(TaskStatus.ERROR if should_fail else TaskStatus.SUCCESSFUL)
# First test successful run
should_fail = False
test_task.delay().get()
self.assertIsNone(TaskInfo.by_name(f"test_task:{uid}"))
self.assertIsNone(DBSystemTask.objects.filter(name="test_task", uid=uid).first())
# Then test failed
should_fail = True
test_task.delay().get()
info = TaskInfo.by_name(f"test_task:{uid}")
self.assertEqual(info.result.status, TaskResultStatus.ERROR)
task = DBSystemTask.objects.filter(name="test_task", uid=uid).first()
self.assertEqual(task.status, TaskStatus.ERROR)
# Then after that, the state should be removed
should_fail = False
test_task.delay().get()
self.assertIsNone(TaskInfo.by_name(f"test_task:{uid}"))
self.assertIsNone(DBSystemTask.objects.filter(name="test_task", uid=uid).first())
def test_tasks(self):
"""Test Task API"""
clean_expired_models.delay().get()
response = self.client.get(reverse("authentik_api:systemtask-list"))
self.assertEqual(response.status_code, 200)
body = loads(response.content)
self.assertTrue(any(task["name"] == "clean_expired_models" for task in body["results"]))
def test_tasks_single(self):
"""Test Task API (read single)"""
clean_expired_models.delay().get()
task = DBSystemTask.objects.filter(name="clean_expired_models").first()
response = self.client.get(
reverse(
"authentik_api:systemtask-detail",
kwargs={"pk": str(task.pk)},
)
)
self.assertEqual(response.status_code, 200)
body = loads(response.content)
self.assertEqual(body["status"], TaskStatus.SUCCESSFUL.value)
self.assertEqual(body["name"], "clean_expired_models")
response = self.client.get(
reverse("authentik_api:systemtask-detail", kwargs={"pk": "qwerqwer"})
)
self.assertEqual(response.status_code, 404)
def test_tasks_run(self):
"""Test Task API (run)"""
clean_expired_models.delay().get()
task = DBSystemTask.objects.filter(name="clean_expired_models").first()
response = self.client.post(
reverse(
"authentik_api:systemtask-run",
kwargs={"pk": str(task.pk)},
)
)
self.assertEqual(response.status_code, 204)
def test_tasks_run_404(self):
"""Test Task API (run, 404)"""
response = self.client.post(
reverse(
"authentik_api:systemtask-run",
kwargs={"pk": "qwerqewrqrqewrqewr"},
)
)
self.assertEqual(response.status_code, 404)

View File

@ -4,11 +4,13 @@ from authentik.events.api.notification_mappings import NotificationWebhookMappin
from authentik.events.api.notification_rules import NotificationRuleViewSet
from authentik.events.api.notification_transports import NotificationTransportViewSet
from authentik.events.api.notifications import NotificationViewSet
from authentik.events.api.tasks import SystemTaskViewSet
api_urlpatterns = [
("events/events", EventViewSet),
("events/notifications", NotificationViewSet),
("events/transports", NotificationTransportViewSet),
("events/rules", NotificationRuleViewSet),
("events/system_tasks", SystemTaskViewSet),
("propertymappings/notification", NotificationWebhookMappingViewSet),
]

View File

@ -18,6 +18,7 @@ from django.http.request import HttpRequest
from django.utils import timezone
from django.views.debug import SafeExceptionReporterFilter
from geoip2.models import ASN, City
from guardian.conf import settings
from guardian.utils import get_anonymous_user
from authentik.blueprints.v1.common import YAMLTag
@ -84,6 +85,8 @@ def get_user(user: User | AnonymousUser, original_user: Optional[User] = None) -
"pk": user.pk,
"email": user.email,
}
if user.username == settings.ANONYMOUS_USER_NAME:
user_data["is_anonymous"] = True
if original_user:
original_data = get_user(original_user)
original_data["on_behalf_of"] = user_data

View File

@ -19,12 +19,8 @@ from yaml import safe_load
from authentik.enterprise.providers.rac.controllers.docker import RACDockerController
from authentik.enterprise.providers.rac.controllers.kubernetes import RACKubernetesController
from authentik.events.monitored_tasks import (
MonitoredTask,
TaskResult,
TaskResultStatus,
prefill_task,
)
from authentik.events.models import TaskStatus
from authentik.events.system_tasks import SystemTask, prefill_task
from authentik.lib.config import CONFIG
from authentik.lib.utils.reflection import path_to_class
from authentik.outposts.consumer import OUTPOST_GROUP
@ -108,20 +104,18 @@ def outpost_service_connection_state(connection_pk: Any):
@CELERY_APP.task(
bind=True,
base=MonitoredTask,
base=SystemTask,
throws=(DatabaseError, ProgrammingError, InternalError),
)
@prefill_task
def outpost_service_connection_monitor(self: MonitoredTask):
def outpost_service_connection_monitor(self: SystemTask):
"""Regularly check the state of Outpost Service Connections"""
connections = OutpostServiceConnection.objects.all()
for connection in connections.iterator():
outpost_service_connection_state.delay(connection.pk)
self.set_status(
TaskResult(
TaskResultStatus.SUCCESSFUL,
[f"Successfully updated {len(connections)} connections."],
)
TaskStatus.SUCCESSFUL,
f"Successfully updated {len(connections)} connections.",
)
@ -134,9 +128,9 @@ def outpost_controller_all():
outpost_controller.delay(outpost.pk.hex, "up", from_cache=False)
@CELERY_APP.task(bind=True, base=MonitoredTask)
@CELERY_APP.task(bind=True, base=SystemTask)
def outpost_controller(
self: MonitoredTask, outpost_pk: str, action: str = "up", from_cache: bool = False
self: SystemTask, outpost_pk: str, action: str = "up", from_cache: bool = False
):
"""Create/update/monitor/delete the deployment of an Outpost"""
logs = []
@ -161,16 +155,16 @@ def outpost_controller(
LOGGER.debug(log)
LOGGER.debug("-----------------Outpost Controller logs end-------------------")
except (ControllerException, ServiceConnectionInvalid) as exc:
self.set_status(TaskResult(TaskResultStatus.ERROR).with_error(exc))
self.set_error(exc)
else:
if from_cache:
cache.delete(CACHE_KEY_OUTPOST_DOWN % outpost_pk)
self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL, logs))
self.set_status(TaskStatus.SUCCESSFUL, *logs)
@CELERY_APP.task(bind=True, base=MonitoredTask)
@CELERY_APP.task(bind=True, base=SystemTask)
@prefill_task
def outpost_token_ensurer(self: MonitoredTask):
def outpost_token_ensurer(self: SystemTask):
"""Periodically ensure that all Outposts have valid Service Accounts
and Tokens"""
all_outposts = Outpost.objects.all()
@ -178,10 +172,8 @@ def outpost_token_ensurer(self: MonitoredTask):
_ = outpost.token
outpost.build_user_permissions(outpost.user)
self.set_status(
TaskResult(
TaskResultStatus.SUCCESSFUL,
[f"Successfully checked {len(all_outposts)} Outposts."],
)
TaskStatus.SUCCESSFUL,
f"Successfully checked {len(all_outposts)} Outposts.",
)
@ -256,32 +248,32 @@ def _outpost_single_update(outpost: Outpost, layer=None):
@CELERY_APP.task(
base=MonitoredTask,
base=SystemTask,
bind=True,
)
def outpost_connection_discovery(self: MonitoredTask):
def outpost_connection_discovery(self: SystemTask):
"""Checks the local environment and create Service connections."""
status = TaskResult(TaskResultStatus.SUCCESSFUL)
messages = []
if not CONFIG.get_bool("outposts.discover"):
status.messages.append("Outpost integration discovery is disabled")
self.set_status(status)
messages.append("Outpost integration discovery is disabled")
self.set_status(TaskStatus.SUCCESSFUL, *messages)
return
# Explicitly check against token filename, as that's
# only present when the integration is enabled
if Path(SERVICE_TOKEN_FILENAME).exists():
status.messages.append("Detected in-cluster Kubernetes Config")
messages.append("Detected in-cluster Kubernetes Config")
if not KubernetesServiceConnection.objects.filter(local=True).exists():
status.messages.append("Created Service Connection for in-cluster")
messages.append("Created Service Connection for in-cluster")
KubernetesServiceConnection.objects.create(
name="Local Kubernetes Cluster", local=True, kubeconfig={}
)
# For development, check for the existence of a kubeconfig file
kubeconfig_path = Path(KUBE_CONFIG_DEFAULT_LOCATION).expanduser()
if kubeconfig_path.exists():
status.messages.append("Detected kubeconfig")
messages.append("Detected kubeconfig")
kubeconfig_local_name = f"k8s-{gethostname()}"
if not KubernetesServiceConnection.objects.filter(name=kubeconfig_local_name).exists():
status.messages.append("Creating kubeconfig Service Connection")
messages.append("Creating kubeconfig Service Connection")
with kubeconfig_path.open("r", encoding="utf8") as _kubeconfig:
KubernetesServiceConnection.objects.create(
name=kubeconfig_local_name,
@ -290,12 +282,12 @@ def outpost_connection_discovery(self: MonitoredTask):
unix_socket_path = urlparse(DEFAULT_UNIX_SOCKET).path
socket = Path(unix_socket_path)
if socket.exists() and access(socket, R_OK):
status.messages.append("Detected local docker socket")
messages.append("Detected local docker socket")
if len(DockerServiceConnection.objects.filter(local=True)) == 0:
status.messages.append("Created Service Connection for docker")
messages.append("Created Service Connection for docker")
DockerServiceConnection.objects.create(
name="Local Docker connection",
local=True,
url=unix_socket_path,
)
self.set_status(status)
self.set_status(TaskStatus.SUCCESSFUL, *messages)

View File

@ -4,12 +4,8 @@ from structlog.stdlib import get_logger
from authentik.events.context_processors.asn import ASN_CONTEXT_PROCESSOR
from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR
from authentik.events.monitored_tasks import (
MonitoredTask,
TaskResult,
TaskResultStatus,
prefill_task,
)
from authentik.events.models import TaskStatus
from authentik.events.system_tasks import SystemTask, prefill_task
from authentik.policies.reputation.models import Reputation
from authentik.policies.reputation.signals import CACHE_KEY_PREFIX
from authentik.root.celery import CELERY_APP
@ -17,9 +13,9 @@ from authentik.root.celery import CELERY_APP
LOGGER = get_logger()
@CELERY_APP.task(bind=True, base=MonitoredTask)
@CELERY_APP.task(bind=True, base=SystemTask)
@prefill_task
def save_reputation(self: MonitoredTask):
def save_reputation(self: SystemTask):
"""Save currently cached reputation to database"""
objects_to_update = []
for _, score in cache.get_many(cache.keys(CACHE_KEY_PREFIX + "*")).items():
@ -32,4 +28,4 @@ def save_reputation(self: MonitoredTask):
rep.score = score["score"]
objects_to_update.append(rep)
Reputation.objects.bulk_update(objects_to_update, ["score", "ip_geo_data"])
self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL, ["Successfully updated Reputation"]))
self.set_status(TaskStatus.SUCCESSFUL, "Successfully updated Reputation")

View File

@ -1,17 +1,17 @@
"""SCIM Provider API Views"""
from django.utils.text import slugify
from drf_spectacular.utils import OpenApiResponse, extend_schema
from guardian.shortcuts import get_objects_for_user
from rest_framework.decorators import action
from rest_framework.fields import BooleanField
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet
from authentik.admin.api.tasks import TaskSerializer
from authentik.core.api.providers import ProviderSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import PassiveSerializer
from authentik.events.monitored_tasks import TaskInfo
from authentik.events.api.tasks import SystemTaskSerializer
from authentik.providers.scim.models import SCIMProvider
@ -43,7 +43,7 @@ class SCIMSyncStatusSerializer(PassiveSerializer):
"""SCIM Provider sync status"""
is_running = BooleanField(read_only=True)
tasks = TaskSerializer(many=True, read_only=True)
tasks = SystemTaskSerializer(many=True, read_only=True)
class SCIMProviderViewSet(UsedByMixin, ModelViewSet):
@ -65,8 +65,12 @@ class SCIMProviderViewSet(UsedByMixin, ModelViewSet):
def sync_status(self, request: Request, pk: int) -> Response:
"""Get provider's sync status"""
provider: SCIMProvider = self.get_object()
task = TaskInfo.by_name(f"scim_sync:{slugify(provider.name)}")
tasks = [task] if task else []
tasks = list(
get_objects_for_user(request.user, "authentik_events.view_systemtask").filter(
name="scim_sync",
uid=slugify(provider.name),
)
)
status = {
"tasks": tasks,
"is_running": provider.sync_lock.locked(),

View File

@ -10,7 +10,8 @@ from pydanticscim.responses import PatchOp
from structlog.stdlib import get_logger
from authentik.core.models import Group, User
from authentik.events.monitored_tasks import MonitoredTask, TaskResult, TaskResultStatus
from authentik.events.models import TaskStatus
from authentik.events.system_tasks import SystemTask
from authentik.lib.utils.reflection import path_to_class
from authentik.providers.scim.clients import PAGE_SIZE, PAGE_TIMEOUT
from authentik.providers.scim.clients.base import SCIMClient
@ -39,8 +40,8 @@ def scim_sync_all():
scim_sync.delay(provider.pk)
@CELERY_APP.task(bind=True, base=MonitoredTask)
def scim_sync(self: MonitoredTask, provider_pk: int) -> None:
@CELERY_APP.task(bind=True, base=SystemTask)
def scim_sync(self: SystemTask, provider_pk: int) -> None:
"""Run SCIM full sync for provider"""
provider: SCIMProvider = SCIMProvider.objects.filter(
pk=provider_pk, backchannel_application__isnull=False
@ -52,8 +53,8 @@ def scim_sync(self: MonitoredTask, provider_pk: int) -> None:
LOGGER.debug("SCIM sync locked, skipping task", source=provider.name)
return
self.set_uid(slugify(provider.name))
result = TaskResult(TaskResultStatus.SUCCESSFUL, [])
result.messages.append(_("Starting full SCIM sync"))
messages = []
messages.append(_("Starting full SCIM sync"))
LOGGER.debug("Starting SCIM sync")
users_paginator = Paginator(provider.get_user_qs(), PAGE_SIZE)
groups_paginator = Paginator(provider.get_group_qs(), PAGE_SIZE)
@ -63,17 +64,17 @@ def scim_sync(self: MonitoredTask, provider_pk: int) -> None:
with allow_join_result():
try:
for page in users_paginator.page_range:
result.messages.append(_("Syncing page %(page)d of users" % {"page": page}))
messages.append(_("Syncing page %(page)d of users" % {"page": page}))
for msg in scim_sync_users.delay(page, provider_pk).get():
result.messages.append(msg)
messages.append(msg)
for page in groups_paginator.page_range:
result.messages.append(_("Syncing page %(page)d of groups" % {"page": page}))
messages.append(_("Syncing page %(page)d of groups" % {"page": page}))
for msg in scim_sync_group.delay(page, provider_pk).get():
result.messages.append(msg)
messages.append(msg)
except StopSync as exc:
self.set_status(TaskResult(TaskResultStatus.ERROR).with_error(exc))
self.set_error(exc)
return
self.set_status(result)
self.set_status(TaskStatus.SUCCESSFUL, *messages)
@CELERY_APP.task(

View File

@ -6,6 +6,7 @@ from django_filters.filters import AllValuesMultipleFilter
from django_filters.filterset import FilterSet
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema, extend_schema_field, inline_serializer
from guardian.shortcuts import get_objects_for_user
from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError
from rest_framework.fields import BooleanField, DictField, ListField, SerializerMethodField
@ -14,13 +15,12 @@ from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet
from authentik.admin.api.tasks import TaskSerializer
from authentik.core.api.propertymappings import PropertyMappingSerializer
from authentik.core.api.sources import SourceSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import PassiveSerializer
from authentik.crypto.models import CertificateKeyPair
from authentik.events.monitored_tasks import TaskInfo
from authentik.events.api.tasks import SystemTaskSerializer
from authentik.sources.ldap.models import LDAPPropertyMapping, LDAPSource
from authentik.sources.ldap.tasks import CACHE_KEY_STATUS, SYNC_CLASSES
@ -91,7 +91,7 @@ class LDAPSyncStatusSerializer(PassiveSerializer):
"""LDAP Source sync status"""
is_running = BooleanField(read_only=True)
tasks = TaskSerializer(many=True, read_only=True)
tasks = SystemTaskSerializer(many=True, read_only=True)
class LDAPSourceViewSet(UsedByMixin, ModelViewSet):
@ -136,7 +136,12 @@ class LDAPSourceViewSet(UsedByMixin, ModelViewSet):
def sync_status(self, request: Request, slug: str) -> Response:
"""Get source's sync status"""
source: LDAPSource = self.get_object()
tasks = TaskInfo.by_name(f"ldap_sync:{source.slug}:*") or []
tasks = list(
get_objects_for_user(request.user, "authentik_events.view_systemtask").filter(
name="ldap_sync",
uid__startswith=source.slug,
)
)
status = {
"tasks": tasks,
"is_running": source.sync_lock.locked(),

View File

@ -1,9 +1,7 @@
"""FreeIPA specific"""
from datetime import datetime
from datetime import datetime, timezone
from typing import Any, Generator
from pytz import UTC
from authentik.core.models import User
from authentik.sources.ldap.sync.base import BaseLDAPSynchronizer, flatten
@ -27,7 +25,7 @@ class FreeIPA(BaseLDAPSynchronizer):
if "krbLastPwdChange" not in attributes:
return
pwd_last_set: datetime = attributes.get("krbLastPwdChange", datetime.now())
pwd_last_set = pwd_last_set.replace(tzinfo=UTC)
pwd_last_set = pwd_last_set.replace(tzinfo=timezone.utc)
if created or pwd_last_set >= user.password_change_date:
self.message(f"'{user.username}': Reset user's password")
self._logger.debug(

View File

@ -1,10 +1,8 @@
"""Active Directory specific"""
from datetime import datetime
from datetime import datetime, timezone
from enum import IntFlag
from typing import Any, Generator
from pytz import UTC
from authentik.core.models import User
from authentik.sources.ldap.sync.base import BaseLDAPSynchronizer
@ -58,7 +56,7 @@ class MicrosoftActiveDirectory(BaseLDAPSynchronizer):
if "pwdLastSet" not in attributes:
return
pwd_last_set: datetime = attributes.get("pwdLastSet", datetime.now())
pwd_last_set = pwd_last_set.replace(tzinfo=UTC)
pwd_last_set = pwd_last_set.replace(tzinfo=timezone.utc)
if created or pwd_last_set >= user.password_change_date:
self.message(f"'{user.username}': Reset user's password")
self._logger.debug(

View File

@ -8,8 +8,9 @@ from ldap3.core.exceptions import LDAPException
from redis.exceptions import LockError
from structlog.stdlib import get_logger
from authentik.events.monitored_tasks import CACHE_KEY_PREFIX as CACHE_KEY_PREFIX_TASKS
from authentik.events.monitored_tasks import MonitoredTask, TaskResult, TaskResultStatus
from authentik.events.models import SystemTask as DBSystemTask
from authentik.events.models import TaskStatus
from authentik.events.system_tasks import SystemTask
from authentik.lib.config import CONFIG
from authentik.lib.utils.errors import exception_to_string
from authentik.lib.utils.reflection import class_to_path, path_to_class
@ -34,7 +35,7 @@ CACHE_KEY_STATUS = "goauthentik.io/sources/ldap/status/"
def ldap_sync_all():
"""Sync all sources"""
for source in LDAPSource.objects.filter(enabled=True):
ldap_sync_single.apply_async(args=[source.pk])
ldap_sync_single.apply_async(args=[str(source.pk)])
@CELERY_APP.task()
@ -69,8 +70,7 @@ def ldap_sync_single(source_pk: str):
try:
with lock:
# Delete all sync tasks from the cache
keys = cache.keys(f"{CACHE_KEY_PREFIX_TASKS}ldap_sync:{source.slug}*")
cache.delete_many(keys)
DBSystemTask.objects.filter(name="ldap_sync", uid__startswith=source.slug).delete()
task = chain(
# User and group sync can happen at once, they have no dependencies on each other
group(
@ -96,18 +96,18 @@ def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) ->
for page in sync_inst.get_objects():
page_cache_key = CACHE_KEY_PREFIX + str(uuid4())
cache.set(page_cache_key, page, 60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"))
page_sync = ldap_sync.si(source.pk, class_to_path(sync), page_cache_key)
page_sync = ldap_sync.si(str(source.pk), class_to_path(sync), page_cache_key)
signatures.append(page_sync)
return signatures
@CELERY_APP.task(
bind=True,
base=MonitoredTask,
base=SystemTask,
soft_time_limit=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"),
task_time_limit=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"),
)
def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str, page_cache_key: str):
def ldap_sync(self: SystemTask, source_pk: str, sync_class: str, page_cache_key: str):
"""Synchronization of an LDAP Source"""
self.result_timeout_hours = CONFIG.get_int("ldap.task_timeout_hours")
source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first()
@ -127,20 +127,18 @@ def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str, page_cache_k
+ "Try increasing ldap.task_timeout_hours"
)
LOGGER.warning(error_message)
self.set_status(TaskResult(TaskResultStatus.ERROR, [error_message]))
self.set_status(TaskStatus.ERROR, error_message)
return
cache.touch(page_cache_key)
count = sync_inst.sync(page)
messages = sync_inst.messages
messages.append(f"Synced {count} objects.")
self.set_status(
TaskResult(
TaskResultStatus.SUCCESSFUL,
messages,
)
TaskStatus.SUCCESSFUL,
*messages,
)
cache.delete(page_cache_key)
except LDAPException as exc:
# No explicit event is created here as .set_status with an error will do that
LOGGER.warning(exception_to_string(exc))
self.set_status(TaskResult(TaskResultStatus.ERROR).with_error(exc))
self.set_error(exc)

View File

@ -7,8 +7,8 @@ from django.test import TestCase
from authentik.blueprints.tests import apply_blueprint
from authentik.core.models import Group, User
from authentik.core.tests.utils import create_test_admin_user
from authentik.events.models import Event, EventAction
from authentik.events.monitored_tasks import TaskInfo, TaskResultStatus
from authentik.events.models import Event, EventAction, SystemTask
from authentik.events.system_tasks import TaskStatus
from authentik.lib.generators import generate_id, generate_key
from authentik.lib.utils.reflection import class_to_path
from authentik.sources.ldap.models import LDAPPropertyMapping, LDAPSource
@ -40,9 +40,9 @@ class LDAPSyncTests(TestCase):
"""Test sync with missing page"""
connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
ldap_sync.delay(self.source.pk, class_to_path(UserLDAPSynchronizer), "foo").get()
status = TaskInfo.by_name("ldap_sync:ldap:users:foo")
self.assertEqual(status.result.status, TaskResultStatus.ERROR)
ldap_sync.delay(str(self.source.pk), class_to_path(UserLDAPSynchronizer), "foo").get()
task = SystemTask.objects.filter(name="ldap_sync", uid="ldap:users:foo").first()
self.assertEqual(task.status, TaskStatus.ERROR)
def test_sync_error(self):
"""Test user sync"""

View File

@ -4,7 +4,8 @@ from json import dumps
from requests import RequestException
from structlog.stdlib import get_logger
from authentik.events.monitored_tasks import MonitoredTask, TaskResult, TaskResultStatus
from authentik.events.models import TaskStatus
from authentik.events.system_tasks import SystemTask
from authentik.lib.utils.http import get_http_session
from authentik.root.celery import CELERY_APP
from authentik.sources.oauth.models import OAuthSource
@ -12,11 +13,11 @@ from authentik.sources.oauth.models import OAuthSource
LOGGER = get_logger()
@CELERY_APP.task(bind=True, base=MonitoredTask)
def update_well_known_jwks(self: MonitoredTask):
@CELERY_APP.task(bind=True, base=SystemTask)
def update_well_known_jwks(self: SystemTask):
"""Update OAuth sources' config from well_known, and JWKS info from the configured URL"""
session = get_http_session()
result = TaskResult(TaskResultStatus.SUCCESSFUL, [])
messages = []
for source in OAuthSource.objects.all().exclude(oidc_well_known_url=""):
try:
well_known_config = session.get(source.oidc_well_known_url)
@ -24,7 +25,7 @@ def update_well_known_jwks(self: MonitoredTask):
except RequestException as exc:
text = exc.response.text if exc.response else str(exc)
LOGGER.warning("Failed to update well_known", source=source, exc=exc, text=text)
result.messages.append(f"Failed to update OIDC configuration for {source.slug}")
messages.append(f"Failed to update OIDC configuration for {source.slug}")
continue
config = well_known_config.json()
try:
@ -47,7 +48,7 @@ def update_well_known_jwks(self: MonitoredTask):
source=source,
exc=exc,
)
result.messages.append(f"Failed to update OIDC configuration for {source.slug}")
messages.append(f"Failed to update OIDC configuration for {source.slug}")
continue
if dirty:
LOGGER.info("Updating sources' OpenID Configuration", source=source)
@ -60,11 +61,11 @@ def update_well_known_jwks(self: MonitoredTask):
except RequestException as exc:
text = exc.response.text if exc.response else str(exc)
LOGGER.warning("Failed to update JWKS", source=source, exc=exc, text=text)
result.messages.append(f"Failed to update JWKS for {source.slug}")
messages.append(f"Failed to update JWKS for {source.slug}")
continue
config = jwks_config.json()
if dumps(source.oidc_jwks, sort_keys=True) != dumps(config, sort_keys=True):
source.oidc_jwks = config
LOGGER.info("Updating sources' JWKS", source=source)
source.save()
self.set_status(result)
self.set_status(TaskStatus.SUCCESSFUL, *messages)

View File

@ -1,8 +1,8 @@
"""Plex tasks"""
from requests import RequestException
from authentik.events.models import Event, EventAction
from authentik.events.monitored_tasks import MonitoredTask, TaskResult, TaskResultStatus
from authentik.events.models import Event, EventAction, TaskStatus
from authentik.events.system_tasks import SystemTask
from authentik.lib.utils.errors import exception_to_string
from authentik.root.celery import CELERY_APP
from authentik.sources.plex.models import PlexSource
@ -16,8 +16,8 @@ def check_plex_token_all():
check_plex_token.delay(source.slug)
@CELERY_APP.task(bind=True, base=MonitoredTask)
def check_plex_token(self: MonitoredTask, source_slug: int):
@CELERY_APP.task(bind=True, base=SystemTask)
def check_plex_token(self: SystemTask, source_slug: int):
"""Check the validity of a Plex source."""
sources = PlexSource.objects.filter(slug=source_slug)
if not sources.exists():
@ -27,16 +27,15 @@ def check_plex_token(self: MonitoredTask, source_slug: int):
auth = PlexAuth(source, source.plex_token)
try:
auth.get_user_info()
self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL, ["Plex token is valid."]))
self.set_status(TaskStatus.SUCCESSFUL, "Plex token is valid.")
except RequestException as exc:
error = exception_to_string(exc)
if len(source.plex_token) > 0:
error = error.replace(source.plex_token, "$PLEX_TOKEN")
self.set_status(
TaskResult(
TaskResultStatus.ERROR,
["Plex token is invalid/an error occurred:", error],
)
TaskStatus.ERROR,
"Plex token is invalid/an error occurred:",
error,
)
Event.new(
EventAction.CONFIGURATION_ERROR,

View File

@ -9,8 +9,8 @@ from django.core.mail.utils import DNS_NAME
from django.utils.text import slugify
from structlog.stdlib import get_logger
from authentik.events.models import Event, EventAction
from authentik.events.monitored_tasks import MonitoredTask, TaskResult, TaskResultStatus
from authentik.events.models import Event, EventAction, TaskStatus
from authentik.events.system_tasks import SystemTask
from authentik.root.celery import CELERY_APP
from authentik.stages.email.models import EmailStage
from authentik.stages.email.utils import logo_data
@ -22,7 +22,7 @@ def send_mails(stage: EmailStage, *messages: list[EmailMultiAlternatives]):
"""Wrapper to convert EmailMessage to dict and send it from worker"""
tasks = []
for message in messages:
tasks.append(send_mail.s(message.__dict__, stage.pk))
tasks.append(send_mail.s(message.__dict__, str(stage.pk)))
lazy_group = group(*tasks)
promise = lazy_group()
return promise
@ -44,9 +44,9 @@ def get_email_body(email: EmailMultiAlternatives) -> str:
OSError,
),
retry_backoff=True,
base=MonitoredTask,
base=SystemTask,
)
def send_mail(self: MonitoredTask, message: dict[Any, Any], email_stage_pk: Optional[int] = None):
def send_mail(self: SystemTask, message: dict[Any, Any], email_stage_pk: Optional[str] = None):
"""Send Email for Email Stage. Retries are scheduled automatically."""
self.save_on_success = False
message_id = make_msgid(domain=DNS_NAME)
@ -58,10 +58,8 @@ def send_mail(self: MonitoredTask, message: dict[Any, Any], email_stage_pk: Opti
stages = EmailStage.objects.filter(pk=email_stage_pk)
if not stages.exists():
self.set_status(
TaskResult(
TaskResultStatus.WARNING,
messages=["Email stage does not exist anymore. Discarding message."],
)
TaskStatus.WARNING,
"Email stage does not exist anymore. Discarding message.",
)
return
stage: EmailStage = stages.first()
@ -69,7 +67,7 @@ def send_mail(self: MonitoredTask, message: dict[Any, Any], email_stage_pk: Opti
backend = stage.backend
except ValueError as exc:
LOGGER.warning("failed to get email backend", exc=exc)
self.set_status(TaskResult(TaskResultStatus.ERROR).with_error(exc))
self.set_error(exc)
return
backend.open()
# Since django's EmailMessage objects are not JSON serialisable,
@ -97,12 +95,10 @@ def send_mail(self: MonitoredTask, message: dict[Any, Any], email_stage_pk: Opti
to_email=message_object.to,
).save()
self.set_status(
TaskResult(
TaskResultStatus.SUCCESSFUL,
messages=["Successfully sent Mail."],
)
TaskStatus.SUCCESSFUL,
"Successfully sent Mail.",
)
except (SMTPException, ConnectionError, OSError) as exc:
LOGGER.debug("Error sending email, retrying...", exc=exc)
self.set_status(TaskResult(TaskResultStatus.ERROR).with_error(exc))
self.set_error(exc)
raise exc