events: migrate SystemTasks to DB (#8159)
* events: migrate system tasks to save in DB Signed-off-by: Jens Langhammer <jens@goauthentik.io> * prefill in app startup Signed-off-by: Jens Langhammer <jens@goauthentik.io> * cleanup api Signed-off-by: Jens Langhammer <jens@goauthentik.io> * update web Signed-off-by: Jens Langhammer <jens@goauthentik.io> * use string for status Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix enum Signed-off-by: Jens Langhammer <jens@goauthentik.io> * save start and end directly in timestamp from default_timer() Signed-off-by: Jens Langhammer <jens@goauthentik.io> * improve metrics Signed-off-by: Jens Langhammer <jens@goauthentik.io> * lint Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix tests Signed-off-by: Jens Langhammer <jens@goauthentik.io> * rename globally to system task Signed-off-by: Jens Langhammer <jens@goauthentik.io> * recreate migrations, better denote anonymous user Signed-off-by: Jens Langhammer <jens@goauthentik.io> * events: lookup actual django app instead of using module path, fallback to module path Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix logger call Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
@ -1,134 +0,0 @@
|
||||
"""Tasks API"""
|
||||
from importlib import import_module
|
||||
|
||||
from django.contrib import messages
|
||||
from django.http.response import Http404
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.fields import (
|
||||
CharField,
|
||||
ChoiceField,
|
||||
DateTimeField,
|
||||
ListField,
|
||||
SerializerMethodField,
|
||||
)
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.viewsets import ViewSet
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.api.decorators import permission_required
|
||||
from authentik.core.api.utils import PassiveSerializer
|
||||
from authentik.events.monitored_tasks import TaskInfo, TaskResultStatus
|
||||
from authentik.rbac.permissions import HasPermission
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
class TaskSerializer(PassiveSerializer):
|
||||
"""Serialize TaskInfo and TaskResult"""
|
||||
|
||||
task_name = CharField()
|
||||
task_description = CharField()
|
||||
task_finish_timestamp = DateTimeField(source="finish_time")
|
||||
task_duration = SerializerMethodField()
|
||||
|
||||
status = ChoiceField(
|
||||
source="result.status.name",
|
||||
choices=[(x.name, x.name) for x in TaskResultStatus],
|
||||
)
|
||||
messages = ListField(source="result.messages")
|
||||
|
||||
def get_task_duration(self, instance: TaskInfo) -> int:
|
||||
"""Get the duration a task took to run"""
|
||||
return max(instance.finish_timestamp - instance.start_timestamp, 0)
|
||||
|
||||
def to_representation(self, instance: TaskInfo):
|
||||
"""When a new version of authentik adds fields to TaskInfo,
|
||||
the API will fail with an AttributeError, as the classes
|
||||
are pickled in cache. In that case, just delete the info"""
|
||||
try:
|
||||
return super().to_representation(instance)
|
||||
# pylint: disable=broad-except
|
||||
except Exception: # pragma: no cover
|
||||
if isinstance(self.instance, list):
|
||||
for inst in self.instance:
|
||||
inst.delete()
|
||||
else:
|
||||
self.instance.delete()
|
||||
return {}
|
||||
|
||||
|
||||
class TaskViewSet(ViewSet):
|
||||
"""Read-only view set that returns all background tasks"""
|
||||
|
||||
permission_classes = [HasPermission("authentik_rbac.view_system_tasks")]
|
||||
serializer_class = TaskSerializer
|
||||
|
||||
@extend_schema(
|
||||
responses={
|
||||
200: TaskSerializer(many=False),
|
||||
404: OpenApiResponse(description="Task not found"),
|
||||
},
|
||||
parameters=[
|
||||
OpenApiParameter(
|
||||
"id",
|
||||
type=OpenApiTypes.STR,
|
||||
location=OpenApiParameter.PATH,
|
||||
required=True,
|
||||
),
|
||||
],
|
||||
)
|
||||
def retrieve(self, request: Request, pk=None) -> Response:
|
||||
"""Get a single system task"""
|
||||
task = TaskInfo.by_name(pk)
|
||||
if not task:
|
||||
raise Http404
|
||||
return Response(TaskSerializer(task, many=False).data)
|
||||
|
||||
@extend_schema(responses={200: TaskSerializer(many=True)})
|
||||
def list(self, request: Request) -> Response:
|
||||
"""List system tasks"""
|
||||
tasks = sorted(TaskInfo.all().values(), key=lambda task: task.task_name)
|
||||
return Response(TaskSerializer(tasks, many=True).data)
|
||||
|
||||
@permission_required(None, ["authentik_rbac.run_system_tasks"])
|
||||
@extend_schema(
|
||||
request=OpenApiTypes.NONE,
|
||||
responses={
|
||||
204: OpenApiResponse(description="Task retried successfully"),
|
||||
404: OpenApiResponse(description="Task not found"),
|
||||
500: OpenApiResponse(description="Failed to retry task"),
|
||||
},
|
||||
parameters=[
|
||||
OpenApiParameter(
|
||||
"id",
|
||||
type=OpenApiTypes.STR,
|
||||
location=OpenApiParameter.PATH,
|
||||
required=True,
|
||||
),
|
||||
],
|
||||
)
|
||||
@action(detail=True, methods=["post"])
|
||||
def retry(self, request: Request, pk=None) -> Response:
|
||||
"""Retry task"""
|
||||
task = TaskInfo.by_name(pk)
|
||||
if not task:
|
||||
raise Http404
|
||||
try:
|
||||
task_module = import_module(task.task_call_module)
|
||||
task_func = getattr(task_module, task.task_call_func)
|
||||
LOGGER.debug("Running task", task=task_func)
|
||||
task_func.delay(*task.task_call_args, **task.task_call_kwargs)
|
||||
messages.success(
|
||||
self.request,
|
||||
_("Successfully re-scheduled Task %(name)s!" % {"name": task.task_name}),
|
||||
)
|
||||
return Response(status=204)
|
||||
except (ImportError, AttributeError): # pragma: no cover
|
||||
LOGGER.warning("Failed to run task, remove state", task=task)
|
||||
# if we get an import error, the module path has probably changed
|
||||
task.delete()
|
||||
return Response(status=500)
|
||||
@ -1,7 +1,6 @@
|
||||
"""admin signals"""
|
||||
from django.dispatch import receiver
|
||||
|
||||
from authentik.admin.api.tasks import TaskInfo
|
||||
from authentik.admin.apps import GAUGE_WORKERS
|
||||
from authentik.root.celery import CELERY_APP
|
||||
from authentik.root.monitoring import monitoring_set
|
||||
@ -12,10 +11,3 @@ def monitoring_set_workers(sender, **kwargs):
|
||||
"""Set worker gauge"""
|
||||
count = len(CELERY_APP.control.ping(timeout=0.5))
|
||||
GAUGE_WORKERS.set(count)
|
||||
|
||||
|
||||
@receiver(monitoring_set)
|
||||
def monitoring_set_tasks(sender, **kwargs):
|
||||
"""Set task gauges"""
|
||||
for task in TaskInfo.all().values():
|
||||
task.update_metrics()
|
||||
|
||||
@ -11,12 +11,7 @@ from structlog.stdlib import get_logger
|
||||
from authentik import __version__, get_build_hash
|
||||
from authentik.admin.apps import PROM_INFO
|
||||
from authentik.events.models import Event, EventAction, Notification
|
||||
from authentik.events.monitored_tasks import (
|
||||
MonitoredTask,
|
||||
TaskResult,
|
||||
TaskResultStatus,
|
||||
prefill_task,
|
||||
)
|
||||
from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.utils.http import get_http_session
|
||||
from authentik.root.celery import CELERY_APP
|
||||
@ -54,13 +49,13 @@ def clear_update_notifications():
|
||||
notification.delete()
|
||||
|
||||
|
||||
@CELERY_APP.task(bind=True, base=MonitoredTask)
|
||||
@CELERY_APP.task(bind=True, base=SystemTask)
|
||||
@prefill_task
|
||||
def update_latest_version(self: MonitoredTask):
|
||||
def update_latest_version(self: SystemTask):
|
||||
"""Update latest version info"""
|
||||
if CONFIG.get_bool("disable_update_check"):
|
||||
cache.set(VERSION_CACHE_KEY, "0.0.0", VERSION_CACHE_TIMEOUT)
|
||||
self.set_status(TaskResult(TaskResultStatus.WARNING, messages=["Version check disabled."]))
|
||||
self.set_status(TaskStatus.WARNING, "Version check disabled.")
|
||||
return
|
||||
try:
|
||||
response = get_http_session().get(
|
||||
@ -70,9 +65,7 @@ def update_latest_version(self: MonitoredTask):
|
||||
data = response.json()
|
||||
upstream_version = data.get("stable", {}).get("version")
|
||||
cache.set(VERSION_CACHE_KEY, upstream_version, VERSION_CACHE_TIMEOUT)
|
||||
self.set_status(
|
||||
TaskResult(TaskResultStatus.SUCCESSFUL, ["Successfully updated latest Version"])
|
||||
)
|
||||
self.set_status(TaskStatus.SUCCESSFUL, "Successfully updated latest Version")
|
||||
_set_prom_info()
|
||||
# Check if upstream version is newer than what we're running,
|
||||
# and if no event exists yet, create one.
|
||||
@ -89,7 +82,7 @@ def update_latest_version(self: MonitoredTask):
|
||||
Event.new(EventAction.UPDATE_AVAILABLE, **event_dict).save()
|
||||
except (RequestException, IndexError) as exc:
|
||||
cache.set(VERSION_CACHE_KEY, "0.0.0", VERSION_CACHE_TIMEOUT)
|
||||
self.set_status(TaskResult(TaskResultStatus.ERROR).with_error(exc))
|
||||
self.set_error(exc)
|
||||
|
||||
|
||||
_set_prom_info()
|
||||
|
||||
@ -7,8 +7,6 @@ from django.urls import reverse
|
||||
from authentik import __version__
|
||||
from authentik.blueprints.tests import reconcile_app
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.core.tasks import clean_expired_models
|
||||
from authentik.events.monitored_tasks import TaskResultStatus
|
||||
from authentik.lib.generators import generate_id
|
||||
|
||||
|
||||
@ -23,53 +21,6 @@ class TestAdminAPI(TestCase):
|
||||
self.group.save()
|
||||
self.client.force_login(self.user)
|
||||
|
||||
def test_tasks(self):
|
||||
"""Test Task API"""
|
||||
clean_expired_models.delay()
|
||||
response = self.client.get(reverse("authentik_api:admin_system_tasks-list"))
|
||||
self.assertEqual(response.status_code, 200)
|
||||
body = loads(response.content)
|
||||
self.assertTrue(any(task["task_name"] == "clean_expired_models" for task in body))
|
||||
|
||||
def test_tasks_single(self):
|
||||
"""Test Task API (read single)"""
|
||||
clean_expired_models.delay()
|
||||
response = self.client.get(
|
||||
reverse(
|
||||
"authentik_api:admin_system_tasks-detail",
|
||||
kwargs={"pk": "clean_expired_models"},
|
||||
)
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
body = loads(response.content)
|
||||
self.assertEqual(body["status"], TaskResultStatus.SUCCESSFUL.name)
|
||||
self.assertEqual(body["task_name"], "clean_expired_models")
|
||||
response = self.client.get(
|
||||
reverse("authentik_api:admin_system_tasks-detail", kwargs={"pk": "qwerqwer"})
|
||||
)
|
||||
self.assertEqual(response.status_code, 404)
|
||||
|
||||
def test_tasks_retry(self):
|
||||
"""Test Task API (retry)"""
|
||||
clean_expired_models.delay()
|
||||
response = self.client.post(
|
||||
reverse(
|
||||
"authentik_api:admin_system_tasks-retry",
|
||||
kwargs={"pk": "clean_expired_models"},
|
||||
)
|
||||
)
|
||||
self.assertEqual(response.status_code, 204)
|
||||
|
||||
def test_tasks_retry_404(self):
|
||||
"""Test Task API (retry, 404)"""
|
||||
response = self.client.post(
|
||||
reverse(
|
||||
"authentik_api:admin_system_tasks-retry",
|
||||
kwargs={"pk": "qwerqewrqrqewrqewr"},
|
||||
)
|
||||
)
|
||||
self.assertEqual(response.status_code, 404)
|
||||
|
||||
def test_version(self):
|
||||
"""Test Version API"""
|
||||
response = self.client.get(reverse("authentik_api:admin_version"))
|
||||
|
||||
@ -4,12 +4,10 @@ from django.urls import path
|
||||
from authentik.admin.api.meta import AppsViewSet, ModelViewSet
|
||||
from authentik.admin.api.metrics import AdministrationMetricsViewSet
|
||||
from authentik.admin.api.system import SystemView
|
||||
from authentik.admin.api.tasks import TaskViewSet
|
||||
from authentik.admin.api.version import VersionView
|
||||
from authentik.admin.api.workers import WorkerView
|
||||
|
||||
api_urlpatterns = [
|
||||
("admin/system_tasks", TaskViewSet, "admin_system_tasks"),
|
||||
("admin/apps", AppsViewSet, "apps"),
|
||||
("admin/models", ModelViewSet, "models"),
|
||||
path(
|
||||
|
||||
@ -11,14 +11,14 @@ from structlog.stdlib import BoundLogger, get_logger
|
||||
class ManagedAppConfig(AppConfig):
|
||||
"""Basic reconciliation logic for apps"""
|
||||
|
||||
_logger: BoundLogger
|
||||
logger: BoundLogger
|
||||
|
||||
RECONCILE_GLOBAL_PREFIX: str = "reconcile_global_"
|
||||
RECONCILE_TENANT_PREFIX: str = "reconcile_tenant_"
|
||||
|
||||
def __init__(self, app_name: str, *args, **kwargs) -> None:
|
||||
super().__init__(app_name, *args, **kwargs)
|
||||
self._logger = get_logger().bind(app_name=app_name)
|
||||
self.logger = get_logger().bind(app_name=app_name)
|
||||
|
||||
def ready(self) -> None:
|
||||
self.reconcile_global()
|
||||
@ -38,11 +38,11 @@ class ManagedAppConfig(AppConfig):
|
||||
continue
|
||||
name = meth_name.replace(prefix, "")
|
||||
try:
|
||||
self._logger.debug("Starting reconciler", name=name)
|
||||
self.logger.debug("Starting reconciler", name=name)
|
||||
meth()
|
||||
self._logger.debug("Successfully reconciled", name=name)
|
||||
self.logger.debug("Successfully reconciled", name=name)
|
||||
except (DatabaseError, ProgrammingError, InternalError) as exc:
|
||||
self._logger.warning("Failed to run reconcile", name=name, exc=exc)
|
||||
self.logger.warning("Failed to run reconcile", name=name, exc=exc)
|
||||
|
||||
def reconcile_tenant(self) -> None:
|
||||
"""reconcile ourselves for tenanted methods"""
|
||||
@ -51,7 +51,7 @@ class ManagedAppConfig(AppConfig):
|
||||
try:
|
||||
tenants = list(Tenant.objects.filter(ready=True))
|
||||
except (DatabaseError, ProgrammingError, InternalError) as exc:
|
||||
self._logger.debug("Failed to get tenants to run reconcile", exc=exc)
|
||||
self.logger.debug("Failed to get tenants to run reconcile", exc=exc)
|
||||
return
|
||||
for tenant in tenants:
|
||||
with tenant:
|
||||
|
||||
@ -14,6 +14,7 @@ from django.db.models import Model
|
||||
from django.db.models.query_utils import Q
|
||||
from django.db.transaction import atomic
|
||||
from django.db.utils import IntegrityError
|
||||
from guardian.models import UserObjectPermission
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.serializers import BaseSerializer, Serializer
|
||||
from structlog.stdlib import BoundLogger, get_logger
|
||||
@ -38,12 +39,16 @@ from authentik.core.models import (
|
||||
UserSourceConnection,
|
||||
)
|
||||
from authentik.enterprise.models import LicenseKey, LicenseUsage
|
||||
from authentik.enterprise.providers.rac.models import ConnectionToken
|
||||
from authentik.events.models import SystemTask
|
||||
from authentik.events.utils import cleanse_dict
|
||||
from authentik.flows.models import FlowToken, Stage
|
||||
from authentik.lib.models import SerializerModel
|
||||
from authentik.lib.sentry import SentryIgnoredException
|
||||
from authentik.outposts.models import OutpostServiceConnection
|
||||
from authentik.policies.models import Policy, PolicyBindingModel
|
||||
from authentik.policies.reputation.models import Reputation
|
||||
from authentik.providers.oauth2.models import AccessToken, AuthorizationCode, RefreshToken
|
||||
from authentik.providers.scim.models import SCIMGroup, SCIMUser
|
||||
from authentik.tenants.models import Tenant
|
||||
|
||||
@ -65,6 +70,7 @@ def excluded_models() -> list[type[Model]]:
|
||||
DjangoGroup,
|
||||
ContentType,
|
||||
Permission,
|
||||
UserObjectPermission,
|
||||
# Base classes
|
||||
Provider,
|
||||
Source,
|
||||
@ -82,6 +88,12 @@ def excluded_models() -> list[type[Model]]:
|
||||
SCIMGroup,
|
||||
SCIMUser,
|
||||
Tenant,
|
||||
SystemTask,
|
||||
ConnectionToken,
|
||||
AuthorizationCode,
|
||||
AccessToken,
|
||||
RefreshToken,
|
||||
Reputation,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -29,12 +29,8 @@ from authentik.blueprints.v1.common import BlueprintLoader, BlueprintMetadata, E
|
||||
from authentik.blueprints.v1.importer import Importer
|
||||
from authentik.blueprints.v1.labels import LABEL_AUTHENTIK_INSTANTIATE
|
||||
from authentik.blueprints.v1.oci import OCI_PREFIX
|
||||
from authentik.events.monitored_tasks import (
|
||||
MonitoredTask,
|
||||
TaskResult,
|
||||
TaskResultStatus,
|
||||
prefill_task,
|
||||
)
|
||||
from authentik.events.models import TaskStatus
|
||||
from authentik.events.system_tasks import SystemTask, prefill_task
|
||||
from authentik.events.utils import sanitize_dict
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.root.celery import CELERY_APP
|
||||
@ -134,10 +130,10 @@ def blueprints_find() -> list[BlueprintFile]:
|
||||
|
||||
|
||||
@CELERY_APP.task(
|
||||
throws=(DatabaseError, ProgrammingError, InternalError), base=MonitoredTask, bind=True
|
||||
throws=(DatabaseError, ProgrammingError, InternalError), base=SystemTask, bind=True
|
||||
)
|
||||
@prefill_task
|
||||
def blueprints_discovery(self: MonitoredTask, path: Optional[str] = None):
|
||||
def blueprints_discovery(self: SystemTask, path: Optional[str] = None):
|
||||
"""Find blueprints and check if they need to be created in the database"""
|
||||
count = 0
|
||||
for blueprint in blueprints_find():
|
||||
@ -146,10 +142,7 @@ def blueprints_discovery(self: MonitoredTask, path: Optional[str] = None):
|
||||
check_blueprint_v1_file(blueprint)
|
||||
count += 1
|
||||
self.set_status(
|
||||
TaskResult(
|
||||
TaskResultStatus.SUCCESSFUL,
|
||||
messages=[_("Successfully imported %(count)d files." % {"count": count})],
|
||||
)
|
||||
TaskStatus.SUCCESSFUL, _("Successfully imported %(count)d files." % {"count": count})
|
||||
)
|
||||
|
||||
|
||||
@ -182,9 +175,9 @@ def check_blueprint_v1_file(blueprint: BlueprintFile):
|
||||
|
||||
@CELERY_APP.task(
|
||||
bind=True,
|
||||
base=MonitoredTask,
|
||||
base=SystemTask,
|
||||
)
|
||||
def apply_blueprint(self: MonitoredTask, instance_pk: str):
|
||||
def apply_blueprint(self: SystemTask, instance_pk: str):
|
||||
"""Apply single blueprint"""
|
||||
self.save_on_success = False
|
||||
instance: Optional[BlueprintInstance] = None
|
||||
@ -202,18 +195,18 @@ def apply_blueprint(self: MonitoredTask, instance_pk: str):
|
||||
if not valid:
|
||||
instance.status = BlueprintInstanceStatus.ERROR
|
||||
instance.save()
|
||||
self.set_status(TaskResult(TaskResultStatus.ERROR, [x["event"] for x in logs]))
|
||||
self.set_status(TaskStatus.ERROR, *[x["event"] for x in logs])
|
||||
return
|
||||
applied = importer.apply()
|
||||
if not applied:
|
||||
instance.status = BlueprintInstanceStatus.ERROR
|
||||
instance.save()
|
||||
self.set_status(TaskResult(TaskResultStatus.ERROR, "Failed to apply"))
|
||||
self.set_status(TaskStatus.ERROR, "Failed to apply")
|
||||
return
|
||||
instance.status = BlueprintInstanceStatus.SUCCESSFUL
|
||||
instance.last_applied_hash = file_hash
|
||||
instance.last_applied = now()
|
||||
self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL))
|
||||
self.set_status(TaskStatus.SUCCESSFUL)
|
||||
except (
|
||||
DatabaseError,
|
||||
ProgrammingError,
|
||||
@ -224,7 +217,7 @@ def apply_blueprint(self: MonitoredTask, instance_pk: str):
|
||||
) as exc:
|
||||
if instance:
|
||||
instance.status = BlueprintInstanceStatus.ERROR
|
||||
self.set_status(TaskResult(TaskResultStatus.ERROR).with_error(exc))
|
||||
self.set_error(exc)
|
||||
finally:
|
||||
if instance:
|
||||
instance.save()
|
||||
|
||||
@ -31,7 +31,7 @@ class UsedBySerializer(PassiveSerializer):
|
||||
model_name = CharField()
|
||||
pk = CharField()
|
||||
name = CharField()
|
||||
action = ChoiceField(choices=[(x.name, x.name) for x in DeleteAction])
|
||||
action = ChoiceField(choices=[(x.value, x.name) for x in DeleteAction])
|
||||
|
||||
|
||||
def get_delete_action(manager: Manager) -> str:
|
||||
|
||||
@ -13,20 +13,15 @@ from authentik.core.models import (
|
||||
ExpiringModel,
|
||||
User,
|
||||
)
|
||||
from authentik.events.monitored_tasks import (
|
||||
MonitoredTask,
|
||||
TaskResult,
|
||||
TaskResultStatus,
|
||||
prefill_task,
|
||||
)
|
||||
from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task
|
||||
from authentik.root.celery import CELERY_APP
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
@CELERY_APP.task(bind=True, base=MonitoredTask)
|
||||
@CELERY_APP.task(bind=True, base=SystemTask)
|
||||
@prefill_task
|
||||
def clean_expired_models(self: MonitoredTask):
|
||||
def clean_expired_models(self: SystemTask):
|
||||
"""Remove expired objects"""
|
||||
messages = []
|
||||
for cls in ExpiringModel.__subclasses__():
|
||||
@ -54,12 +49,12 @@ def clean_expired_models(self: MonitoredTask):
|
||||
amount += 1
|
||||
LOGGER.debug("Expired sessions", model=AuthenticatedSession, amount=amount)
|
||||
messages.append(f"Expired {amount} {AuthenticatedSession._meta.verbose_name_plural}")
|
||||
self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL, messages))
|
||||
self.set_status(TaskStatus.SUCCESSFUL, *messages)
|
||||
|
||||
|
||||
@CELERY_APP.task(bind=True, base=MonitoredTask)
|
||||
@CELERY_APP.task(bind=True, base=SystemTask)
|
||||
@prefill_task
|
||||
def clean_temporary_users(self: MonitoredTask):
|
||||
def clean_temporary_users(self: SystemTask):
|
||||
"""Remove temporary users created by SAML Sources"""
|
||||
_now = datetime.now()
|
||||
messages = []
|
||||
@ -75,4 +70,4 @@ def clean_temporary_users(self: MonitoredTask):
|
||||
user.delete()
|
||||
deleted_users += 1
|
||||
messages.append(f"Successfully deleted {deleted_users} users.")
|
||||
self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL, messages))
|
||||
self.set_status(TaskStatus.SUCCESSFUL, *messages)
|
||||
|
||||
@ -9,12 +9,8 @@ from django.utils.translation import gettext_lazy as _
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.crypto.models import CertificateKeyPair
|
||||
from authentik.events.monitored_tasks import (
|
||||
MonitoredTask,
|
||||
TaskResult,
|
||||
TaskResultStatus,
|
||||
prefill_task,
|
||||
)
|
||||
from authentik.events.models import TaskStatus
|
||||
from authentik.events.system_tasks import SystemTask, prefill_task
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.root.celery import CELERY_APP
|
||||
|
||||
@ -39,9 +35,9 @@ def ensure_certificate_valid(body: str):
|
||||
return body
|
||||
|
||||
|
||||
@CELERY_APP.task(bind=True, base=MonitoredTask)
|
||||
@CELERY_APP.task(bind=True, base=SystemTask)
|
||||
@prefill_task
|
||||
def certificate_discovery(self: MonitoredTask):
|
||||
def certificate_discovery(self: SystemTask):
|
||||
"""Discover, import and update certificates from the filesystem"""
|
||||
certs = {}
|
||||
private_keys = {}
|
||||
@ -88,8 +84,5 @@ def certificate_discovery(self: MonitoredTask):
|
||||
if dirty:
|
||||
cert.save()
|
||||
self.set_status(
|
||||
TaskResult(
|
||||
TaskResultStatus.SUCCESSFUL,
|
||||
messages=[_("Successfully imported %(count)d files." % {"count": discovered})],
|
||||
)
|
||||
TaskStatus.SUCCESSFUL, _("Successfully imported %(count)d files." % {"count": discovered})
|
||||
)
|
||||
|
||||
107
authentik/events/api/tasks.py
Normal file
107
authentik/events/api/tasks.py
Normal file
@ -0,0 +1,107 @@
|
||||
"""Tasks API"""
|
||||
from datetime import datetime, timezone
|
||||
from importlib import import_module
|
||||
|
||||
from django.contrib import messages
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiResponse, extend_schema
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.fields import CharField, ChoiceField, ListField, SerializerMethodField
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.serializers import ModelSerializer
|
||||
from rest_framework.viewsets import ReadOnlyModelViewSet
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.api.decorators import permission_required
|
||||
from authentik.events.models import SystemTask, TaskStatus
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
class SystemTaskSerializer(ModelSerializer):
|
||||
"""Serialize TaskInfo and TaskResult"""
|
||||
|
||||
name = CharField()
|
||||
full_name = SerializerMethodField()
|
||||
uid = CharField(required=False)
|
||||
description = CharField()
|
||||
start_timestamp = SerializerMethodField()
|
||||
finish_timestamp = SerializerMethodField()
|
||||
duration = SerializerMethodField()
|
||||
|
||||
status = ChoiceField(choices=[(x.value, x.name) for x in TaskStatus])
|
||||
messages = ListField(child=CharField())
|
||||
|
||||
def get_full_name(self, instance: SystemTask) -> str:
|
||||
"""Get full name with UID"""
|
||||
if instance.uid:
|
||||
return f"{instance.name}:{instance.uid}"
|
||||
return instance.name
|
||||
|
||||
def get_start_timestamp(self, instance: SystemTask) -> datetime:
|
||||
"""Timestamp when the task started"""
|
||||
return datetime.fromtimestamp(instance.start_timestamp, tz=timezone.utc)
|
||||
|
||||
def get_finish_timestamp(self, instance: SystemTask) -> datetime:
|
||||
"""Timestamp when the task finished"""
|
||||
return datetime.fromtimestamp(instance.finish_timestamp, tz=timezone.utc)
|
||||
|
||||
def get_duration(self, instance: SystemTask) -> float:
|
||||
"""Get the duration a task took to run"""
|
||||
return max(instance.finish_timestamp - instance.start_timestamp, 0)
|
||||
|
||||
class Meta:
|
||||
model = SystemTask
|
||||
fields = [
|
||||
"uuid",
|
||||
"name",
|
||||
"full_name",
|
||||
"uid",
|
||||
"description",
|
||||
"start_timestamp",
|
||||
"finish_timestamp",
|
||||
"duration",
|
||||
"status",
|
||||
"messages",
|
||||
]
|
||||
|
||||
|
||||
class SystemTaskViewSet(ReadOnlyModelViewSet):
|
||||
"""Read-only view set that returns all background tasks"""
|
||||
|
||||
queryset = SystemTask.objects.all()
|
||||
serializer_class = SystemTaskSerializer
|
||||
filterset_fields = ["name", "uid", "status"]
|
||||
ordering = ["name", "uid", "status"]
|
||||
search_fields = ["name", "description", "uid", "status"]
|
||||
|
||||
@permission_required(None, ["authentik_events.run_task"])
|
||||
@extend_schema(
|
||||
request=OpenApiTypes.NONE,
|
||||
responses={
|
||||
204: OpenApiResponse(description="Task retried successfully"),
|
||||
404: OpenApiResponse(description="Task not found"),
|
||||
500: OpenApiResponse(description="Failed to retry task"),
|
||||
},
|
||||
)
|
||||
@action(detail=True, methods=["post"])
|
||||
def run(self, request: Request, pk=None) -> Response:
|
||||
"""Run task"""
|
||||
task: SystemTask = self.get_object()
|
||||
try:
|
||||
task_module = import_module(task.task_call_module)
|
||||
task_func = getattr(task_module, task.task_call_func)
|
||||
LOGGER.info("Running task", task=task_func)
|
||||
task_func.delay(*task.task_call_args, **task.task_call_kwargs)
|
||||
messages.success(
|
||||
self.request,
|
||||
_("Successfully started task %(name)s." % {"name": task.name}),
|
||||
)
|
||||
return Response(status=204)
|
||||
except (ImportError, AttributeError) as exc: # pragma: no cover
|
||||
LOGGER.warning("Failed to run task, remove state", task=task.name, exc=exc)
|
||||
# if we get an import error, the module path has probably changed
|
||||
task.delete()
|
||||
return Response(status=500)
|
||||
@ -1,15 +1,26 @@
|
||||
"""authentik events app"""
|
||||
from prometheus_client import Gauge
|
||||
from prometheus_client import Gauge, Histogram
|
||||
|
||||
from authentik.blueprints.apps import ManagedAppConfig
|
||||
from authentik.lib.config import CONFIG, ENV_PREFIX
|
||||
|
||||
# TODO: Deprecated metric - remove in 2024.2 or later
|
||||
GAUGE_TASKS = Gauge(
|
||||
"authentik_system_tasks",
|
||||
"System tasks and their status",
|
||||
["tenant", "task_name", "task_uid", "status"],
|
||||
)
|
||||
|
||||
SYSTEM_TASK_TIME = Histogram(
|
||||
"authentik_system_tasks_time_seconds",
|
||||
"Runtime of system tasks",
|
||||
)
|
||||
SYSTEM_TASK_STATUS = Gauge(
|
||||
"authentik_system_tasks_status",
|
||||
"System task status",
|
||||
["task_name", "task_uid", "status"],
|
||||
)
|
||||
|
||||
|
||||
class AuthentikEventsConfig(ManagedAppConfig):
|
||||
"""authentik events app"""
|
||||
@ -43,3 +54,14 @@ class AuthentikEventsConfig(ManagedAppConfig):
|
||||
replacement_env=replace_env,
|
||||
message=msg,
|
||||
).save()
|
||||
|
||||
def reconcile_prefill_tasks(self):
|
||||
"""Prefill tasks"""
|
||||
from authentik.events.models import SystemTask
|
||||
from authentik.events.system_tasks import _prefill_tasks
|
||||
|
||||
for task in _prefill_tasks:
|
||||
if SystemTask.objects.filter(name=task.name).exists():
|
||||
continue
|
||||
task.save()
|
||||
self.logger.debug("prefilled task", task_name=task.name)
|
||||
|
||||
@ -9,19 +9,14 @@ from django.core.exceptions import SuspiciousOperation
|
||||
from django.db.models import Model
|
||||
from django.db.models.signals import m2m_changed, post_save, pre_delete
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
from guardian.models import UserObjectPermission
|
||||
from structlog.stdlib import BoundLogger, get_logger
|
||||
|
||||
from authentik.blueprints.v1.importer import excluded_models
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.enterprise.providers.rac.models import ConnectionToken
|
||||
from authentik.events.models import Event, EventAction, Notification
|
||||
from authentik.events.utils import model_to_dict
|
||||
from authentik.lib.sentry import before_send
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.policies.reputation.models import Reputation
|
||||
from authentik.providers.oauth2.models import AccessToken, AuthorizationCode, RefreshToken
|
||||
from authentik.providers.scim.models import SCIMGroup, SCIMUser
|
||||
from authentik.stages.authenticator_static.models import StaticToken
|
||||
|
||||
IGNORED_MODELS = tuple(
|
||||
@ -29,16 +24,8 @@ IGNORED_MODELS = tuple(
|
||||
+ (
|
||||
Event,
|
||||
Notification,
|
||||
UserObjectPermission,
|
||||
StaticToken,
|
||||
Session,
|
||||
AuthorizationCode,
|
||||
AccessToken,
|
||||
RefreshToken,
|
||||
SCIMUser,
|
||||
SCIMGroup,
|
||||
Reputation,
|
||||
ConnectionToken,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
60
authentik/events/migrations/0004_systemtask.py
Normal file
60
authentik/events/migrations/0004_systemtask.py
Normal file
@ -0,0 +1,60 @@
|
||||
# Generated by Django 5.0.1 on 2024-01-24 12:48
|
||||
|
||||
import uuid
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
import authentik.core.models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("authentik_events", "0003_rename_tenant_event_brand"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="SystemTask",
|
||||
fields=[
|
||||
(
|
||||
"expires",
|
||||
models.DateTimeField(default=authentik.core.models.default_token_duration),
|
||||
),
|
||||
("expiring", models.BooleanField(default=True)),
|
||||
(
|
||||
"uuid",
|
||||
models.UUIDField(
|
||||
default=uuid.uuid4, editable=False, primary_key=True, serialize=False
|
||||
),
|
||||
),
|
||||
("name", models.TextField()),
|
||||
("uid", models.TextField(null=True)),
|
||||
("start_timestamp", models.FloatField()),
|
||||
("finish_timestamp", models.FloatField()),
|
||||
(
|
||||
"status",
|
||||
models.TextField(
|
||||
choices=[
|
||||
("unknown", "Unknown"),
|
||||
("successful", "Successful"),
|
||||
("warning", "Warning"),
|
||||
("error", "Error"),
|
||||
]
|
||||
),
|
||||
),
|
||||
("description", models.TextField(null=True)),
|
||||
("messages", models.JSONField()),
|
||||
("task_call_module", models.TextField()),
|
||||
("task_call_func", models.TextField()),
|
||||
("task_call_args", models.JSONField(default=list)),
|
||||
("task_call_kwargs", models.JSONField(default=dict)),
|
||||
],
|
||||
options={
|
||||
"verbose_name": "System Task",
|
||||
"verbose_name_plural": "System Tasks",
|
||||
"permissions": [("run_task", "Run task")],
|
||||
"default_permissions": ["view"],
|
||||
"unique_together": {("name", "uid")},
|
||||
},
|
||||
),
|
||||
]
|
||||
@ -2,11 +2,14 @@
|
||||
import time
|
||||
from collections import Counter
|
||||
from datetime import timedelta
|
||||
from difflib import get_close_matches
|
||||
from functools import lru_cache
|
||||
from inspect import currentframe
|
||||
from smtplib import SMTPException
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from django.apps import apps
|
||||
from django.db import models
|
||||
from django.db.models import Count, ExpressionWrapper, F
|
||||
from django.db.models.fields import DurationField
|
||||
@ -18,6 +21,7 @@ from django.http.request import QueryDict
|
||||
from django.utils.timezone import now
|
||||
from django.utils.translation import gettext as _
|
||||
from requests import RequestException
|
||||
from rest_framework.serializers import Serializer
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik import get_full_version
|
||||
@ -28,6 +32,7 @@ from authentik.core.middleware import (
|
||||
SESSION_KEY_IMPERSONATE_USER,
|
||||
)
|
||||
from authentik.core.models import ExpiringModel, Group, PropertyMapping, User
|
||||
from authentik.events.apps import GAUGE_TASKS, SYSTEM_TASK_STATUS, SYSTEM_TASK_TIME
|
||||
from authentik.events.context_processors.base import get_context_processors
|
||||
from authentik.events.utils import (
|
||||
cleanse_dict,
|
||||
@ -46,8 +51,6 @@ from authentik.stages.email.utils import TemplateEmailMessage
|
||||
from authentik.tenants.models import Tenant
|
||||
|
||||
LOGGER = get_logger()
|
||||
if TYPE_CHECKING:
|
||||
from rest_framework.serializers import Serializer
|
||||
|
||||
|
||||
def default_event_duration():
|
||||
@ -61,6 +64,12 @@ def default_brand():
|
||||
return sanitize_dict(model_to_dict(DEFAULT_BRAND))
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def django_app_names() -> list[str]:
|
||||
"""Get a cached list of all django apps' names (not labels)"""
|
||||
return [x.name for x in apps.app_configs.values()]
|
||||
|
||||
|
||||
class NotificationTransportError(SentryIgnoredException):
|
||||
"""Error raised when a notification fails to be delivered"""
|
||||
|
||||
@ -198,6 +207,11 @@ class Event(SerializerModel, ExpiringModel):
|
||||
current = currentframe()
|
||||
parent = current.f_back
|
||||
app = parent.f_globals["__name__"]
|
||||
# Attempt to match the calling module to the django app it belongs to
|
||||
# if we can't find a match, keep the module name
|
||||
django_apps = get_close_matches(app, django_app_names(), n=1)
|
||||
if len(django_apps) > 0:
|
||||
app = django_apps[0]
|
||||
cleaned_kwargs = cleanse_dict(sanitize_dict(kwargs))
|
||||
event = Event(action=action, app=app, context=cleaned_kwargs)
|
||||
return event
|
||||
@ -270,7 +284,7 @@ class Event(SerializerModel, ExpiringModel):
|
||||
super().save(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def serializer(self) -> "Serializer":
|
||||
def serializer(self) -> type[Serializer]:
|
||||
from authentik.events.api.events import EventSerializer
|
||||
|
||||
return EventSerializer
|
||||
@ -478,7 +492,7 @@ class NotificationTransport(SerializerModel):
|
||||
raise NotificationTransportError(exc) from exc
|
||||
|
||||
@property
|
||||
def serializer(self) -> "Serializer":
|
||||
def serializer(self) -> type[Serializer]:
|
||||
from authentik.events.api.notification_transports import NotificationTransportSerializer
|
||||
|
||||
return NotificationTransportSerializer
|
||||
@ -511,7 +525,7 @@ class Notification(SerializerModel):
|
||||
user = models.ForeignKey(User, on_delete=models.CASCADE)
|
||||
|
||||
@property
|
||||
def serializer(self) -> "Serializer":
|
||||
def serializer(self) -> type[Serializer]:
|
||||
from authentik.events.api.notifications import NotificationSerializer
|
||||
|
||||
return NotificationSerializer
|
||||
@ -554,7 +568,7 @@ class NotificationRule(SerializerModel, PolicyBindingModel):
|
||||
)
|
||||
|
||||
@property
|
||||
def serializer(self) -> "Serializer":
|
||||
def serializer(self) -> type[Serializer]:
|
||||
from authentik.events.api.notification_rules import NotificationRuleSerializer
|
||||
|
||||
return NotificationRuleSerializer
|
||||
@ -575,7 +589,7 @@ class NotificationWebhookMapping(PropertyMapping):
|
||||
return "ak-property-mapping-notification-form"
|
||||
|
||||
@property
|
||||
def serializer(self) -> type["Serializer"]:
|
||||
def serializer(self) -> type[type[Serializer]]:
|
||||
from authentik.events.api.notification_mappings import NotificationWebhookMappingSerializer
|
||||
|
||||
return NotificationWebhookMappingSerializer
|
||||
@ -586,3 +600,66 @@ class NotificationWebhookMapping(PropertyMapping):
|
||||
class Meta:
|
||||
verbose_name = _("Webhook Mapping")
|
||||
verbose_name_plural = _("Webhook Mappings")
|
||||
|
||||
|
||||
class TaskStatus(models.TextChoices):
|
||||
"""Possible states of tasks"""
|
||||
|
||||
UNKNOWN = "unknown"
|
||||
SUCCESSFUL = "successful"
|
||||
WARNING = "warning"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class SystemTask(SerializerModel, ExpiringModel):
|
||||
"""Info about a system task running in the background along with details to restart the task"""
|
||||
|
||||
uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
|
||||
name = models.TextField()
|
||||
uid = models.TextField(null=True)
|
||||
|
||||
start_timestamp = models.FloatField()
|
||||
finish_timestamp = models.FloatField()
|
||||
|
||||
status = models.TextField(choices=TaskStatus.choices)
|
||||
|
||||
description = models.TextField(null=True)
|
||||
messages = models.JSONField()
|
||||
|
||||
task_call_module = models.TextField()
|
||||
task_call_func = models.TextField()
|
||||
task_call_args = models.JSONField(default=list)
|
||||
task_call_kwargs = models.JSONField(default=dict)
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
from authentik.events.api.tasks import SystemTaskSerializer
|
||||
|
||||
return SystemTaskSerializer
|
||||
|
||||
def update_metrics(self):
|
||||
"""Update prometheus metrics"""
|
||||
duration = max(self.finish_timestamp - self.start_timestamp, 0)
|
||||
# TODO: Deprecated metric - remove in 2024.2 or later
|
||||
GAUGE_TASKS.labels(
|
||||
task_name=self.name,
|
||||
task_uid=self.uid or "",
|
||||
status=self.status.lower(),
|
||||
).set(duration)
|
||||
SYSTEM_TASK_TIME.observe(duration)
|
||||
SYSTEM_TASK_STATUS.labels(
|
||||
task_name=self.name,
|
||||
task_uid=self.uid or "",
|
||||
status=self.status.lower(),
|
||||
).inc()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"System Task {self.name}"
|
||||
|
||||
class Meta:
|
||||
unique_together = (("name", "uid"),)
|
||||
# Remove "add", "change" and "delete" permissions as those are not used
|
||||
default_permissions = ["view"]
|
||||
permissions = [("run_task", _("Run task"))]
|
||||
verbose_name = _("System Task")
|
||||
verbose_name_plural = _("System Tasks")
|
||||
|
||||
@ -1,216 +0,0 @@
|
||||
"""Monitored tasks"""
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from timeit import default_timer
|
||||
from typing import Any, Optional
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.db import connection
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from structlog.stdlib import get_logger
|
||||
from tenant_schemas_celery.task import TenantTask
|
||||
|
||||
from authentik.events.apps import GAUGE_TASKS
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
|
||||
LOGGER = get_logger()
|
||||
CACHE_KEY_PREFIX = "goauthentik.io/events/tasks/"
|
||||
|
||||
|
||||
class TaskResultStatus(Enum):
|
||||
"""Possible states of tasks"""
|
||||
|
||||
SUCCESSFUL = 1
|
||||
WARNING = 2
|
||||
ERROR = 4
|
||||
UNKNOWN = 8
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskResult:
|
||||
"""Result of a task run, this class is created by the task itself
|
||||
and used by self.set_status"""
|
||||
|
||||
status: TaskResultStatus
|
||||
|
||||
messages: list[str] = field(default_factory=list)
|
||||
|
||||
# Optional UID used in cache for tasks that run in different instances
|
||||
uid: Optional[str] = field(default=None)
|
||||
|
||||
def with_error(self, exc: Exception) -> "TaskResult":
|
||||
"""Since errors might not always be pickle-able, set the traceback"""
|
||||
# TODO: Mark exception somehow so that is rendered as <pre> in frontend
|
||||
self.messages.append(exception_to_string(exc))
|
||||
return self
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskInfo:
|
||||
"""Info about a task run"""
|
||||
|
||||
task_name: str
|
||||
start_timestamp: float
|
||||
finish_timestamp: float
|
||||
finish_time: datetime
|
||||
|
||||
result: TaskResult
|
||||
|
||||
task_call_module: str
|
||||
task_call_func: str
|
||||
task_call_args: list[Any] = field(default_factory=list)
|
||||
task_call_kwargs: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
task_description: Optional[str] = field(default=None)
|
||||
|
||||
@staticmethod
|
||||
def all() -> dict[str, "TaskInfo"]:
|
||||
"""Get all TaskInfo objects"""
|
||||
return cache.get_many(cache.keys(CACHE_KEY_PREFIX + "*"))
|
||||
|
||||
@staticmethod
|
||||
def by_name(name: str) -> Optional["TaskInfo"] | Optional[list["TaskInfo"]]:
|
||||
"""Get TaskInfo Object by name"""
|
||||
if "*" in name:
|
||||
return cache.get_many(cache.keys(CACHE_KEY_PREFIX + name)).values()
|
||||
return cache.get(CACHE_KEY_PREFIX + name, None)
|
||||
|
||||
@property
|
||||
def full_name(self) -> str:
|
||||
"""Get the full cache key with task name and UID"""
|
||||
key = CACHE_KEY_PREFIX + self.task_name
|
||||
if self.result.uid:
|
||||
uid_suffix = f":{self.result.uid}"
|
||||
key += uid_suffix
|
||||
if not self.task_name.endswith(uid_suffix):
|
||||
self.task_name += uid_suffix
|
||||
return key
|
||||
|
||||
def delete(self):
|
||||
"""Delete task info from cache"""
|
||||
return cache.delete(self.full_name)
|
||||
|
||||
def update_metrics(self):
|
||||
"""Update prometheus metrics"""
|
||||
start = default_timer()
|
||||
if hasattr(self, "start_timestamp"):
|
||||
start = self.start_timestamp
|
||||
try:
|
||||
duration = max(self.finish_timestamp - start, 0)
|
||||
except TypeError:
|
||||
duration = 0
|
||||
GAUGE_TASKS.labels(
|
||||
tenant=connection.schema_name,
|
||||
task_name=self.task_name.split(":")[0],
|
||||
task_uid=self.result.uid or "",
|
||||
status=self.result.status.name.lower(),
|
||||
).set(duration)
|
||||
|
||||
def save(self, timeout_hours=6):
|
||||
"""Save task into cache"""
|
||||
self.update_metrics()
|
||||
cache.set(self.full_name, self, timeout=timeout_hours * 60 * 60)
|
||||
|
||||
|
||||
class MonitoredTask(TenantTask):
|
||||
"""Task which can save its state to the cache"""
|
||||
|
||||
# For tasks that should only be listed if they failed, set this to False
|
||||
save_on_success: bool
|
||||
|
||||
_result: Optional[TaskResult]
|
||||
|
||||
_uid: Optional[str]
|
||||
start: Optional[float] = None
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.save_on_success = True
|
||||
self._uid = None
|
||||
self._result = None
|
||||
self.result_timeout_hours = 6
|
||||
|
||||
def set_uid(self, uid: str):
|
||||
"""Set UID, so in the case of an unexpected error its saved correctly"""
|
||||
self._uid = uid
|
||||
|
||||
def set_status(self, result: TaskResult):
|
||||
"""Set result for current run, will overwrite previous result."""
|
||||
self._result = result
|
||||
|
||||
def before_start(self, task_id, args, kwargs):
|
||||
self.start = default_timer()
|
||||
return super().before_start(task_id, args, kwargs)
|
||||
|
||||
# pylint: disable=too-many-arguments
|
||||
def after_return(self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo):
|
||||
super().after_return(status, retval, task_id, args, kwargs, einfo=einfo)
|
||||
if not self._result:
|
||||
return
|
||||
if not self._result.uid:
|
||||
self._result.uid = self._uid
|
||||
info = TaskInfo(
|
||||
task_name=self.__name__,
|
||||
task_description=self.__doc__,
|
||||
start_timestamp=self.start or default_timer(),
|
||||
finish_timestamp=default_timer(),
|
||||
finish_time=datetime.now(),
|
||||
result=self._result,
|
||||
task_call_module=self.__module__,
|
||||
task_call_func=self.__name__,
|
||||
task_call_args=args,
|
||||
task_call_kwargs=kwargs,
|
||||
)
|
||||
if self._result.status == TaskResultStatus.SUCCESSFUL and not self.save_on_success:
|
||||
info.delete()
|
||||
return
|
||||
info.save(self.result_timeout_hours)
|
||||
|
||||
# pylint: disable=too-many-arguments
|
||||
def on_failure(self, exc, task_id, args, kwargs, einfo):
|
||||
super().on_failure(exc, task_id, args, kwargs, einfo=einfo)
|
||||
if not self._result:
|
||||
self._result = TaskResult(status=TaskResultStatus.ERROR, messages=[str(exc)])
|
||||
if not self._result.uid:
|
||||
self._result.uid = self._uid
|
||||
TaskInfo(
|
||||
task_name=self.__name__,
|
||||
task_description=self.__doc__,
|
||||
start_timestamp=self.start or default_timer(),
|
||||
finish_timestamp=default_timer(),
|
||||
finish_time=datetime.now(),
|
||||
result=self._result,
|
||||
task_call_module=self.__module__,
|
||||
task_call_func=self.__name__,
|
||||
task_call_args=args,
|
||||
task_call_kwargs=kwargs,
|
||||
).save(self.result_timeout_hours)
|
||||
Event.new(
|
||||
EventAction.SYSTEM_TASK_EXCEPTION,
|
||||
message=f"Task {self.__name__} encountered an error: {exception_to_string(exc)}",
|
||||
).save()
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def prefill_task(func):
|
||||
"""Ensure a task's details are always in cache, so it can always be triggered via API"""
|
||||
status = TaskInfo.by_name(func.__name__)
|
||||
if status:
|
||||
return func
|
||||
TaskInfo(
|
||||
task_name=func.__name__,
|
||||
task_description=func.__doc__,
|
||||
result=TaskResult(TaskResultStatus.UNKNOWN, messages=[_("Task has not been run yet.")]),
|
||||
task_call_module=func.__module__,
|
||||
task_call_func=func.__name__,
|
||||
# We don't have real values for these attributes but they cannot be null
|
||||
start_timestamp=0,
|
||||
finish_timestamp=0,
|
||||
finish_time=datetime.now(),
|
||||
).save(86400)
|
||||
LOGGER.debug("prefilled task", task_name=func.__name__)
|
||||
return func
|
||||
@ -8,11 +8,13 @@ from django.http import HttpRequest
|
||||
|
||||
from authentik.core.models import User
|
||||
from authentik.core.signals import login_failed, password_changed
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.events.apps import SYSTEM_TASK_STATUS
|
||||
from authentik.events.models import Event, EventAction, SystemTask
|
||||
from authentik.events.tasks import event_notification_handler, gdpr_cleanup
|
||||
from authentik.flows.models import Stage
|
||||
from authentik.flows.planner import PLAN_CONTEXT_SOURCE, FlowPlan
|
||||
from authentik.flows.views.executor import SESSION_KEY_PLAN
|
||||
from authentik.root.monitoring import monitoring_set
|
||||
from authentik.stages.invitation.models import Invitation
|
||||
from authentik.stages.invitation.signals import invitation_used
|
||||
from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_METHOD_ARGS
|
||||
@ -100,3 +102,11 @@ def event_user_pre_delete_cleanup(sender, instance: User, **_):
|
||||
"""If gdpr_compliance is enabled, remove all the user's events"""
|
||||
if get_current_tenant().gdpr_compliance:
|
||||
gdpr_cleanup.delay(instance.pk)
|
||||
|
||||
|
||||
@receiver(monitoring_set)
|
||||
def monitoring_system_task(sender, **_):
|
||||
"""Update metrics when task is saved"""
|
||||
SYSTEM_TASK_STATUS.clear()
|
||||
for task in SystemTask.objects.all():
|
||||
task.update_metrics()
|
||||
|
||||
135
authentik/events/system_tasks.py
Normal file
135
authentik/events/system_tasks.py
Normal file
@ -0,0 +1,135 @@
|
||||
"""Monitored tasks"""
|
||||
from datetime import timedelta
|
||||
from timeit import default_timer
|
||||
from typing import Any, Optional
|
||||
|
||||
from django.utils.timezone import now
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from structlog.stdlib import get_logger
|
||||
from tenant_schemas_celery.task import TenantTask
|
||||
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.events.models import SystemTask as DBSystemTask
|
||||
from authentik.events.models import TaskStatus
|
||||
from authentik.events.utils import sanitize_item
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
class SystemTask(TenantTask):
|
||||
"""Task which can save its state to the cache"""
|
||||
|
||||
# For tasks that should only be listed if they failed, set this to False
|
||||
save_on_success: bool
|
||||
|
||||
_status: Optional[TaskStatus]
|
||||
_messages: list[str]
|
||||
|
||||
_uid: Optional[str]
|
||||
_start: Optional[float] = None
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.save_on_success = True
|
||||
self._uid = None
|
||||
self._status = None
|
||||
self._messages = []
|
||||
self.result_timeout_hours = 6
|
||||
|
||||
def set_uid(self, uid: str):
|
||||
"""Set UID, so in the case of an unexpected error its saved correctly"""
|
||||
self._uid = uid
|
||||
|
||||
def set_status(self, status: TaskStatus, *messages: str):
|
||||
"""Set result for current run, will overwrite previous result."""
|
||||
self._status = status
|
||||
self._messages = messages
|
||||
|
||||
def set_error(self, exception: Exception):
|
||||
"""Set result to error and save exception"""
|
||||
self._status = TaskStatus.ERROR
|
||||
self._messages = [exception_to_string(exception)]
|
||||
|
||||
def before_start(self, task_id, args, kwargs):
|
||||
self._start = default_timer()
|
||||
return super().before_start(task_id, args, kwargs)
|
||||
|
||||
# pylint: disable=too-many-arguments
|
||||
def after_return(self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo):
|
||||
super().after_return(status, retval, task_id, args, kwargs, einfo=einfo)
|
||||
if not self._status:
|
||||
return
|
||||
if self._status == TaskStatus.SUCCESSFUL and not self.save_on_success:
|
||||
DBSystemTask.objects.filter(
|
||||
name=self.__name__,
|
||||
uid=self._uid,
|
||||
).delete()
|
||||
return
|
||||
DBSystemTask.objects.update_or_create(
|
||||
name=self.__name__,
|
||||
uid=self._uid,
|
||||
defaults={
|
||||
"description": self.__doc__,
|
||||
"start_timestamp": self._start or default_timer(),
|
||||
"finish_timestamp": default_timer(),
|
||||
"task_call_module": self.__module__,
|
||||
"task_call_func": self.__name__,
|
||||
"task_call_args": args,
|
||||
"task_call_kwargs": kwargs,
|
||||
"status": self._status,
|
||||
"messages": sanitize_item(self._messages),
|
||||
"expires": now() + timedelta(hours=self.result_timeout_hours),
|
||||
"expiring": True,
|
||||
},
|
||||
)
|
||||
|
||||
# pylint: disable=too-many-arguments
|
||||
def on_failure(self, exc, task_id, args, kwargs, einfo):
|
||||
super().on_failure(exc, task_id, args, kwargs, einfo=einfo)
|
||||
if not self._status:
|
||||
self._status = TaskStatus.ERROR
|
||||
self._messages = exception_to_string(exc)
|
||||
DBSystemTask.objects.update_or_create(
|
||||
name=self.__name__,
|
||||
uid=self._uid,
|
||||
defaults={
|
||||
"description": self.__doc__,
|
||||
"start_timestamp": self._start or default_timer(),
|
||||
"finish_timestamp": default_timer(),
|
||||
"task_call_module": self.__module__,
|
||||
"task_call_func": self.__name__,
|
||||
"task_call_args": args,
|
||||
"task_call_kwargs": kwargs,
|
||||
"status": self._status,
|
||||
"messages": sanitize_item(self._messages),
|
||||
"expires": now() + timedelta(hours=self.result_timeout_hours),
|
||||
"expiring": True,
|
||||
},
|
||||
)
|
||||
Event.new(
|
||||
EventAction.SYSTEM_TASK_EXCEPTION,
|
||||
message=f"Task {self.__name__} encountered an error: {exception_to_string(exc)}",
|
||||
).save()
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def prefill_task(func):
|
||||
"""Ensure a task's details are always in cache, so it can always be triggered via API"""
|
||||
_prefill_tasks.append(
|
||||
DBSystemTask(
|
||||
name=func.__name__,
|
||||
description=func.__doc__,
|
||||
status=TaskStatus.UNKNOWN,
|
||||
messages=sanitize_item([_("Task has not been run yet.")]),
|
||||
task_call_module=func.__module__,
|
||||
task_call_func=func.__name__,
|
||||
expiring=False,
|
||||
)
|
||||
)
|
||||
return func
|
||||
|
||||
|
||||
_prefill_tasks = []
|
||||
@ -13,13 +13,9 @@ from authentik.events.models import (
|
||||
NotificationRule,
|
||||
NotificationTransport,
|
||||
NotificationTransportError,
|
||||
TaskStatus,
|
||||
)
|
||||
from authentik.events.monitored_tasks import (
|
||||
MonitoredTask,
|
||||
TaskResult,
|
||||
TaskResultStatus,
|
||||
prefill_task,
|
||||
)
|
||||
from authentik.events.system_tasks import SystemTask, prefill_task
|
||||
from authentik.policies.engine import PolicyEngine
|
||||
from authentik.policies.models import PolicyBinding, PolicyEngineMode
|
||||
from authentik.root.celery import CELERY_APP
|
||||
@ -99,10 +95,10 @@ def event_trigger_handler(event_uuid: str, trigger_name: str):
|
||||
bind=True,
|
||||
autoretry_for=(NotificationTransportError,),
|
||||
retry_backoff=True,
|
||||
base=MonitoredTask,
|
||||
base=SystemTask,
|
||||
)
|
||||
def notification_transport(
|
||||
self: MonitoredTask, transport_pk: int, event_pk: str, user_pk: int, trigger_pk: str
|
||||
self: SystemTask, transport_pk: int, event_pk: str, user_pk: int, trigger_pk: str
|
||||
):
|
||||
"""Send notification over specified transport"""
|
||||
self.save_on_success = False
|
||||
@ -123,9 +119,9 @@ def notification_transport(
|
||||
if not transport:
|
||||
return
|
||||
transport.send(notification)
|
||||
self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL))
|
||||
self.set_status(TaskStatus.SUCCESSFUL)
|
||||
except (NotificationTransportError, PropertyMappingExpressionException) as exc:
|
||||
self.set_status(TaskResult(TaskResultStatus.ERROR).with_error(exc))
|
||||
self.set_error(exc)
|
||||
raise exc
|
||||
|
||||
|
||||
@ -137,13 +133,13 @@ def gdpr_cleanup(user_pk: int):
|
||||
events.delete()
|
||||
|
||||
|
||||
@CELERY_APP.task(bind=True, base=MonitoredTask)
|
||||
@CELERY_APP.task(bind=True, base=SystemTask)
|
||||
@prefill_task
|
||||
def notification_cleanup(self: MonitoredTask):
|
||||
def notification_cleanup(self: SystemTask):
|
||||
"""Cleanup seen notifications and notifications whose event expired."""
|
||||
notifications = Notification.objects.filter(Q(event=None) | Q(seen=True))
|
||||
amount = notifications.count()
|
||||
for notification in notifications:
|
||||
notification.delete()
|
||||
LOGGER.debug("Expired notifications", amount=amount)
|
||||
self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL, [f"Expired {amount} Notifications"]))
|
||||
self.set_status(TaskStatus.SUCCESSFUL, f"Expired {amount} Notifications")
|
||||
|
||||
@ -1,14 +1,26 @@
|
||||
"""Test Monitored tasks"""
|
||||
from django.test import TestCase
|
||||
from json import loads
|
||||
|
||||
from authentik.events.monitored_tasks import MonitoredTask, TaskInfo, TaskResult, TaskResultStatus
|
||||
from django.urls import reverse
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.tasks import clean_expired_models
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.events.models import SystemTask as DBSystemTask
|
||||
from authentik.events.models import TaskStatus
|
||||
from authentik.events.system_tasks import SystemTask
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.root.celery import CELERY_APP
|
||||
|
||||
|
||||
class TestMonitoredTasks(TestCase):
|
||||
class TestSystemTasks(APITestCase):
|
||||
"""Test Monitored tasks"""
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.user = create_test_admin_user()
|
||||
self.client.force_login(self.user)
|
||||
|
||||
def test_failed_successful_remove_state(self):
|
||||
"""Test that a task with `save_on_success` set to `False` that failed saves
|
||||
a state, and upon successful completion will delete the state"""
|
||||
@ -17,27 +29,74 @@ class TestMonitoredTasks(TestCase):
|
||||
|
||||
@CELERY_APP.task(
|
||||
bind=True,
|
||||
base=MonitoredTask,
|
||||
base=SystemTask,
|
||||
)
|
||||
def test_task(self: MonitoredTask):
|
||||
def test_task(self: SystemTask):
|
||||
self.save_on_success = False
|
||||
self.set_uid(uid)
|
||||
self.set_status(
|
||||
TaskResult(TaskResultStatus.ERROR if should_fail else TaskResultStatus.SUCCESSFUL)
|
||||
)
|
||||
self.set_status(TaskStatus.ERROR if should_fail else TaskStatus.SUCCESSFUL)
|
||||
|
||||
# First test successful run
|
||||
should_fail = False
|
||||
test_task.delay().get()
|
||||
self.assertIsNone(TaskInfo.by_name(f"test_task:{uid}"))
|
||||
self.assertIsNone(DBSystemTask.objects.filter(name="test_task", uid=uid).first())
|
||||
|
||||
# Then test failed
|
||||
should_fail = True
|
||||
test_task.delay().get()
|
||||
info = TaskInfo.by_name(f"test_task:{uid}")
|
||||
self.assertEqual(info.result.status, TaskResultStatus.ERROR)
|
||||
task = DBSystemTask.objects.filter(name="test_task", uid=uid).first()
|
||||
self.assertEqual(task.status, TaskStatus.ERROR)
|
||||
|
||||
# Then after that, the state should be removed
|
||||
should_fail = False
|
||||
test_task.delay().get()
|
||||
self.assertIsNone(TaskInfo.by_name(f"test_task:{uid}"))
|
||||
self.assertIsNone(DBSystemTask.objects.filter(name="test_task", uid=uid).first())
|
||||
|
||||
def test_tasks(self):
|
||||
"""Test Task API"""
|
||||
clean_expired_models.delay().get()
|
||||
response = self.client.get(reverse("authentik_api:systemtask-list"))
|
||||
self.assertEqual(response.status_code, 200)
|
||||
body = loads(response.content)
|
||||
self.assertTrue(any(task["name"] == "clean_expired_models" for task in body["results"]))
|
||||
|
||||
def test_tasks_single(self):
|
||||
"""Test Task API (read single)"""
|
||||
clean_expired_models.delay().get()
|
||||
task = DBSystemTask.objects.filter(name="clean_expired_models").first()
|
||||
response = self.client.get(
|
||||
reverse(
|
||||
"authentik_api:systemtask-detail",
|
||||
kwargs={"pk": str(task.pk)},
|
||||
)
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
body = loads(response.content)
|
||||
self.assertEqual(body["status"], TaskStatus.SUCCESSFUL.value)
|
||||
self.assertEqual(body["name"], "clean_expired_models")
|
||||
response = self.client.get(
|
||||
reverse("authentik_api:systemtask-detail", kwargs={"pk": "qwerqwer"})
|
||||
)
|
||||
self.assertEqual(response.status_code, 404)
|
||||
|
||||
def test_tasks_run(self):
|
||||
"""Test Task API (run)"""
|
||||
clean_expired_models.delay().get()
|
||||
task = DBSystemTask.objects.filter(name="clean_expired_models").first()
|
||||
response = self.client.post(
|
||||
reverse(
|
||||
"authentik_api:systemtask-run",
|
||||
kwargs={"pk": str(task.pk)},
|
||||
)
|
||||
)
|
||||
self.assertEqual(response.status_code, 204)
|
||||
|
||||
def test_tasks_run_404(self):
|
||||
"""Test Task API (run, 404)"""
|
||||
response = self.client.post(
|
||||
reverse(
|
||||
"authentik_api:systemtask-run",
|
||||
kwargs={"pk": "qwerqewrqrqewrqewr"},
|
||||
)
|
||||
)
|
||||
self.assertEqual(response.status_code, 404)
|
||||
|
||||
@ -4,11 +4,13 @@ from authentik.events.api.notification_mappings import NotificationWebhookMappin
|
||||
from authentik.events.api.notification_rules import NotificationRuleViewSet
|
||||
from authentik.events.api.notification_transports import NotificationTransportViewSet
|
||||
from authentik.events.api.notifications import NotificationViewSet
|
||||
from authentik.events.api.tasks import SystemTaskViewSet
|
||||
|
||||
api_urlpatterns = [
|
||||
("events/events", EventViewSet),
|
||||
("events/notifications", NotificationViewSet),
|
||||
("events/transports", NotificationTransportViewSet),
|
||||
("events/rules", NotificationRuleViewSet),
|
||||
("events/system_tasks", SystemTaskViewSet),
|
||||
("propertymappings/notification", NotificationWebhookMappingViewSet),
|
||||
]
|
||||
|
||||
@ -18,6 +18,7 @@ from django.http.request import HttpRequest
|
||||
from django.utils import timezone
|
||||
from django.views.debug import SafeExceptionReporterFilter
|
||||
from geoip2.models import ASN, City
|
||||
from guardian.conf import settings
|
||||
from guardian.utils import get_anonymous_user
|
||||
|
||||
from authentik.blueprints.v1.common import YAMLTag
|
||||
@ -84,6 +85,8 @@ def get_user(user: User | AnonymousUser, original_user: Optional[User] = None) -
|
||||
"pk": user.pk,
|
||||
"email": user.email,
|
||||
}
|
||||
if user.username == settings.ANONYMOUS_USER_NAME:
|
||||
user_data["is_anonymous"] = True
|
||||
if original_user:
|
||||
original_data = get_user(original_user)
|
||||
original_data["on_behalf_of"] = user_data
|
||||
|
||||
@ -19,12 +19,8 @@ from yaml import safe_load
|
||||
|
||||
from authentik.enterprise.providers.rac.controllers.docker import RACDockerController
|
||||
from authentik.enterprise.providers.rac.controllers.kubernetes import RACKubernetesController
|
||||
from authentik.events.monitored_tasks import (
|
||||
MonitoredTask,
|
||||
TaskResult,
|
||||
TaskResultStatus,
|
||||
prefill_task,
|
||||
)
|
||||
from authentik.events.models import TaskStatus
|
||||
from authentik.events.system_tasks import SystemTask, prefill_task
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.utils.reflection import path_to_class
|
||||
from authentik.outposts.consumer import OUTPOST_GROUP
|
||||
@ -108,20 +104,18 @@ def outpost_service_connection_state(connection_pk: Any):
|
||||
|
||||
@CELERY_APP.task(
|
||||
bind=True,
|
||||
base=MonitoredTask,
|
||||
base=SystemTask,
|
||||
throws=(DatabaseError, ProgrammingError, InternalError),
|
||||
)
|
||||
@prefill_task
|
||||
def outpost_service_connection_monitor(self: MonitoredTask):
|
||||
def outpost_service_connection_monitor(self: SystemTask):
|
||||
"""Regularly check the state of Outpost Service Connections"""
|
||||
connections = OutpostServiceConnection.objects.all()
|
||||
for connection in connections.iterator():
|
||||
outpost_service_connection_state.delay(connection.pk)
|
||||
self.set_status(
|
||||
TaskResult(
|
||||
TaskResultStatus.SUCCESSFUL,
|
||||
[f"Successfully updated {len(connections)} connections."],
|
||||
)
|
||||
TaskStatus.SUCCESSFUL,
|
||||
f"Successfully updated {len(connections)} connections.",
|
||||
)
|
||||
|
||||
|
||||
@ -134,9 +128,9 @@ def outpost_controller_all():
|
||||
outpost_controller.delay(outpost.pk.hex, "up", from_cache=False)
|
||||
|
||||
|
||||
@CELERY_APP.task(bind=True, base=MonitoredTask)
|
||||
@CELERY_APP.task(bind=True, base=SystemTask)
|
||||
def outpost_controller(
|
||||
self: MonitoredTask, outpost_pk: str, action: str = "up", from_cache: bool = False
|
||||
self: SystemTask, outpost_pk: str, action: str = "up", from_cache: bool = False
|
||||
):
|
||||
"""Create/update/monitor/delete the deployment of an Outpost"""
|
||||
logs = []
|
||||
@ -161,16 +155,16 @@ def outpost_controller(
|
||||
LOGGER.debug(log)
|
||||
LOGGER.debug("-----------------Outpost Controller logs end-------------------")
|
||||
except (ControllerException, ServiceConnectionInvalid) as exc:
|
||||
self.set_status(TaskResult(TaskResultStatus.ERROR).with_error(exc))
|
||||
self.set_error(exc)
|
||||
else:
|
||||
if from_cache:
|
||||
cache.delete(CACHE_KEY_OUTPOST_DOWN % outpost_pk)
|
||||
self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL, logs))
|
||||
self.set_status(TaskStatus.SUCCESSFUL, *logs)
|
||||
|
||||
|
||||
@CELERY_APP.task(bind=True, base=MonitoredTask)
|
||||
@CELERY_APP.task(bind=True, base=SystemTask)
|
||||
@prefill_task
|
||||
def outpost_token_ensurer(self: MonitoredTask):
|
||||
def outpost_token_ensurer(self: SystemTask):
|
||||
"""Periodically ensure that all Outposts have valid Service Accounts
|
||||
and Tokens"""
|
||||
all_outposts = Outpost.objects.all()
|
||||
@ -178,10 +172,8 @@ def outpost_token_ensurer(self: MonitoredTask):
|
||||
_ = outpost.token
|
||||
outpost.build_user_permissions(outpost.user)
|
||||
self.set_status(
|
||||
TaskResult(
|
||||
TaskResultStatus.SUCCESSFUL,
|
||||
[f"Successfully checked {len(all_outposts)} Outposts."],
|
||||
)
|
||||
TaskStatus.SUCCESSFUL,
|
||||
f"Successfully checked {len(all_outposts)} Outposts.",
|
||||
)
|
||||
|
||||
|
||||
@ -256,32 +248,32 @@ def _outpost_single_update(outpost: Outpost, layer=None):
|
||||
|
||||
|
||||
@CELERY_APP.task(
|
||||
base=MonitoredTask,
|
||||
base=SystemTask,
|
||||
bind=True,
|
||||
)
|
||||
def outpost_connection_discovery(self: MonitoredTask):
|
||||
def outpost_connection_discovery(self: SystemTask):
|
||||
"""Checks the local environment and create Service connections."""
|
||||
status = TaskResult(TaskResultStatus.SUCCESSFUL)
|
||||
messages = []
|
||||
if not CONFIG.get_bool("outposts.discover"):
|
||||
status.messages.append("Outpost integration discovery is disabled")
|
||||
self.set_status(status)
|
||||
messages.append("Outpost integration discovery is disabled")
|
||||
self.set_status(TaskStatus.SUCCESSFUL, *messages)
|
||||
return
|
||||
# Explicitly check against token filename, as that's
|
||||
# only present when the integration is enabled
|
||||
if Path(SERVICE_TOKEN_FILENAME).exists():
|
||||
status.messages.append("Detected in-cluster Kubernetes Config")
|
||||
messages.append("Detected in-cluster Kubernetes Config")
|
||||
if not KubernetesServiceConnection.objects.filter(local=True).exists():
|
||||
status.messages.append("Created Service Connection for in-cluster")
|
||||
messages.append("Created Service Connection for in-cluster")
|
||||
KubernetesServiceConnection.objects.create(
|
||||
name="Local Kubernetes Cluster", local=True, kubeconfig={}
|
||||
)
|
||||
# For development, check for the existence of a kubeconfig file
|
||||
kubeconfig_path = Path(KUBE_CONFIG_DEFAULT_LOCATION).expanduser()
|
||||
if kubeconfig_path.exists():
|
||||
status.messages.append("Detected kubeconfig")
|
||||
messages.append("Detected kubeconfig")
|
||||
kubeconfig_local_name = f"k8s-{gethostname()}"
|
||||
if not KubernetesServiceConnection.objects.filter(name=kubeconfig_local_name).exists():
|
||||
status.messages.append("Creating kubeconfig Service Connection")
|
||||
messages.append("Creating kubeconfig Service Connection")
|
||||
with kubeconfig_path.open("r", encoding="utf8") as _kubeconfig:
|
||||
KubernetesServiceConnection.objects.create(
|
||||
name=kubeconfig_local_name,
|
||||
@ -290,12 +282,12 @@ def outpost_connection_discovery(self: MonitoredTask):
|
||||
unix_socket_path = urlparse(DEFAULT_UNIX_SOCKET).path
|
||||
socket = Path(unix_socket_path)
|
||||
if socket.exists() and access(socket, R_OK):
|
||||
status.messages.append("Detected local docker socket")
|
||||
messages.append("Detected local docker socket")
|
||||
if len(DockerServiceConnection.objects.filter(local=True)) == 0:
|
||||
status.messages.append("Created Service Connection for docker")
|
||||
messages.append("Created Service Connection for docker")
|
||||
DockerServiceConnection.objects.create(
|
||||
name="Local Docker connection",
|
||||
local=True,
|
||||
url=unix_socket_path,
|
||||
)
|
||||
self.set_status(status)
|
||||
self.set_status(TaskStatus.SUCCESSFUL, *messages)
|
||||
|
||||
@ -4,12 +4,8 @@ from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.events.context_processors.asn import ASN_CONTEXT_PROCESSOR
|
||||
from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR
|
||||
from authentik.events.monitored_tasks import (
|
||||
MonitoredTask,
|
||||
TaskResult,
|
||||
TaskResultStatus,
|
||||
prefill_task,
|
||||
)
|
||||
from authentik.events.models import TaskStatus
|
||||
from authentik.events.system_tasks import SystemTask, prefill_task
|
||||
from authentik.policies.reputation.models import Reputation
|
||||
from authentik.policies.reputation.signals import CACHE_KEY_PREFIX
|
||||
from authentik.root.celery import CELERY_APP
|
||||
@ -17,9 +13,9 @@ from authentik.root.celery import CELERY_APP
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
@CELERY_APP.task(bind=True, base=MonitoredTask)
|
||||
@CELERY_APP.task(bind=True, base=SystemTask)
|
||||
@prefill_task
|
||||
def save_reputation(self: MonitoredTask):
|
||||
def save_reputation(self: SystemTask):
|
||||
"""Save currently cached reputation to database"""
|
||||
objects_to_update = []
|
||||
for _, score in cache.get_many(cache.keys(CACHE_KEY_PREFIX + "*")).items():
|
||||
@ -32,4 +28,4 @@ def save_reputation(self: MonitoredTask):
|
||||
rep.score = score["score"]
|
||||
objects_to_update.append(rep)
|
||||
Reputation.objects.bulk_update(objects_to_update, ["score", "ip_geo_data"])
|
||||
self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL, ["Successfully updated Reputation"]))
|
||||
self.set_status(TaskStatus.SUCCESSFUL, "Successfully updated Reputation")
|
||||
|
||||
@ -1,17 +1,17 @@
|
||||
"""SCIM Provider API Views"""
|
||||
from django.utils.text import slugify
|
||||
from drf_spectacular.utils import OpenApiResponse, extend_schema
|
||||
from guardian.shortcuts import get_objects_for_user
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.fields import BooleanField
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from authentik.admin.api.tasks import TaskSerializer
|
||||
from authentik.core.api.providers import ProviderSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import PassiveSerializer
|
||||
from authentik.events.monitored_tasks import TaskInfo
|
||||
from authentik.events.api.tasks import SystemTaskSerializer
|
||||
from authentik.providers.scim.models import SCIMProvider
|
||||
|
||||
|
||||
@ -43,7 +43,7 @@ class SCIMSyncStatusSerializer(PassiveSerializer):
|
||||
"""SCIM Provider sync status"""
|
||||
|
||||
is_running = BooleanField(read_only=True)
|
||||
tasks = TaskSerializer(many=True, read_only=True)
|
||||
tasks = SystemTaskSerializer(many=True, read_only=True)
|
||||
|
||||
|
||||
class SCIMProviderViewSet(UsedByMixin, ModelViewSet):
|
||||
@ -65,8 +65,12 @@ class SCIMProviderViewSet(UsedByMixin, ModelViewSet):
|
||||
def sync_status(self, request: Request, pk: int) -> Response:
|
||||
"""Get provider's sync status"""
|
||||
provider: SCIMProvider = self.get_object()
|
||||
task = TaskInfo.by_name(f"scim_sync:{slugify(provider.name)}")
|
||||
tasks = [task] if task else []
|
||||
tasks = list(
|
||||
get_objects_for_user(request.user, "authentik_events.view_systemtask").filter(
|
||||
name="scim_sync",
|
||||
uid=slugify(provider.name),
|
||||
)
|
||||
)
|
||||
status = {
|
||||
"tasks": tasks,
|
||||
"is_running": provider.sync_lock.locked(),
|
||||
|
||||
@ -10,7 +10,8 @@ from pydanticscim.responses import PatchOp
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.events.monitored_tasks import MonitoredTask, TaskResult, TaskResultStatus
|
||||
from authentik.events.models import TaskStatus
|
||||
from authentik.events.system_tasks import SystemTask
|
||||
from authentik.lib.utils.reflection import path_to_class
|
||||
from authentik.providers.scim.clients import PAGE_SIZE, PAGE_TIMEOUT
|
||||
from authentik.providers.scim.clients.base import SCIMClient
|
||||
@ -39,8 +40,8 @@ def scim_sync_all():
|
||||
scim_sync.delay(provider.pk)
|
||||
|
||||
|
||||
@CELERY_APP.task(bind=True, base=MonitoredTask)
|
||||
def scim_sync(self: MonitoredTask, provider_pk: int) -> None:
|
||||
@CELERY_APP.task(bind=True, base=SystemTask)
|
||||
def scim_sync(self: SystemTask, provider_pk: int) -> None:
|
||||
"""Run SCIM full sync for provider"""
|
||||
provider: SCIMProvider = SCIMProvider.objects.filter(
|
||||
pk=provider_pk, backchannel_application__isnull=False
|
||||
@ -52,8 +53,8 @@ def scim_sync(self: MonitoredTask, provider_pk: int) -> None:
|
||||
LOGGER.debug("SCIM sync locked, skipping task", source=provider.name)
|
||||
return
|
||||
self.set_uid(slugify(provider.name))
|
||||
result = TaskResult(TaskResultStatus.SUCCESSFUL, [])
|
||||
result.messages.append(_("Starting full SCIM sync"))
|
||||
messages = []
|
||||
messages.append(_("Starting full SCIM sync"))
|
||||
LOGGER.debug("Starting SCIM sync")
|
||||
users_paginator = Paginator(provider.get_user_qs(), PAGE_SIZE)
|
||||
groups_paginator = Paginator(provider.get_group_qs(), PAGE_SIZE)
|
||||
@ -63,17 +64,17 @@ def scim_sync(self: MonitoredTask, provider_pk: int) -> None:
|
||||
with allow_join_result():
|
||||
try:
|
||||
for page in users_paginator.page_range:
|
||||
result.messages.append(_("Syncing page %(page)d of users" % {"page": page}))
|
||||
messages.append(_("Syncing page %(page)d of users" % {"page": page}))
|
||||
for msg in scim_sync_users.delay(page, provider_pk).get():
|
||||
result.messages.append(msg)
|
||||
messages.append(msg)
|
||||
for page in groups_paginator.page_range:
|
||||
result.messages.append(_("Syncing page %(page)d of groups" % {"page": page}))
|
||||
messages.append(_("Syncing page %(page)d of groups" % {"page": page}))
|
||||
for msg in scim_sync_group.delay(page, provider_pk).get():
|
||||
result.messages.append(msg)
|
||||
messages.append(msg)
|
||||
except StopSync as exc:
|
||||
self.set_status(TaskResult(TaskResultStatus.ERROR).with_error(exc))
|
||||
self.set_error(exc)
|
||||
return
|
||||
self.set_status(result)
|
||||
self.set_status(TaskStatus.SUCCESSFUL, *messages)
|
||||
|
||||
|
||||
@CELERY_APP.task(
|
||||
|
||||
@ -6,6 +6,7 @@ from django_filters.filters import AllValuesMultipleFilter
|
||||
from django_filters.filterset import FilterSet
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import extend_schema, extend_schema_field, inline_serializer
|
||||
from guardian.shortcuts import get_objects_for_user
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.fields import BooleanField, DictField, ListField, SerializerMethodField
|
||||
@ -14,13 +15,12 @@ from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from authentik.admin.api.tasks import TaskSerializer
|
||||
from authentik.core.api.propertymappings import PropertyMappingSerializer
|
||||
from authentik.core.api.sources import SourceSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import PassiveSerializer
|
||||
from authentik.crypto.models import CertificateKeyPair
|
||||
from authentik.events.monitored_tasks import TaskInfo
|
||||
from authentik.events.api.tasks import SystemTaskSerializer
|
||||
from authentik.sources.ldap.models import LDAPPropertyMapping, LDAPSource
|
||||
from authentik.sources.ldap.tasks import CACHE_KEY_STATUS, SYNC_CLASSES
|
||||
|
||||
@ -91,7 +91,7 @@ class LDAPSyncStatusSerializer(PassiveSerializer):
|
||||
"""LDAP Source sync status"""
|
||||
|
||||
is_running = BooleanField(read_only=True)
|
||||
tasks = TaskSerializer(many=True, read_only=True)
|
||||
tasks = SystemTaskSerializer(many=True, read_only=True)
|
||||
|
||||
|
||||
class LDAPSourceViewSet(UsedByMixin, ModelViewSet):
|
||||
@ -136,7 +136,12 @@ class LDAPSourceViewSet(UsedByMixin, ModelViewSet):
|
||||
def sync_status(self, request: Request, slug: str) -> Response:
|
||||
"""Get source's sync status"""
|
||||
source: LDAPSource = self.get_object()
|
||||
tasks = TaskInfo.by_name(f"ldap_sync:{source.slug}:*") or []
|
||||
tasks = list(
|
||||
get_objects_for_user(request.user, "authentik_events.view_systemtask").filter(
|
||||
name="ldap_sync",
|
||||
uid__startswith=source.slug,
|
||||
)
|
||||
)
|
||||
status = {
|
||||
"tasks": tasks,
|
||||
"is_running": source.sync_lock.locked(),
|
||||
|
||||
@ -1,9 +1,7 @@
|
||||
"""FreeIPA specific"""
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Generator
|
||||
|
||||
from pytz import UTC
|
||||
|
||||
from authentik.core.models import User
|
||||
from authentik.sources.ldap.sync.base import BaseLDAPSynchronizer, flatten
|
||||
|
||||
@ -27,7 +25,7 @@ class FreeIPA(BaseLDAPSynchronizer):
|
||||
if "krbLastPwdChange" not in attributes:
|
||||
return
|
||||
pwd_last_set: datetime = attributes.get("krbLastPwdChange", datetime.now())
|
||||
pwd_last_set = pwd_last_set.replace(tzinfo=UTC)
|
||||
pwd_last_set = pwd_last_set.replace(tzinfo=timezone.utc)
|
||||
if created or pwd_last_set >= user.password_change_date:
|
||||
self.message(f"'{user.username}': Reset user's password")
|
||||
self._logger.debug(
|
||||
|
||||
6
authentik/sources/ldap/sync/vendor/ms_ad.py
vendored
6
authentik/sources/ldap/sync/vendor/ms_ad.py
vendored
@ -1,10 +1,8 @@
|
||||
"""Active Directory specific"""
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from enum import IntFlag
|
||||
from typing import Any, Generator
|
||||
|
||||
from pytz import UTC
|
||||
|
||||
from authentik.core.models import User
|
||||
from authentik.sources.ldap.sync.base import BaseLDAPSynchronizer
|
||||
|
||||
@ -58,7 +56,7 @@ class MicrosoftActiveDirectory(BaseLDAPSynchronizer):
|
||||
if "pwdLastSet" not in attributes:
|
||||
return
|
||||
pwd_last_set: datetime = attributes.get("pwdLastSet", datetime.now())
|
||||
pwd_last_set = pwd_last_set.replace(tzinfo=UTC)
|
||||
pwd_last_set = pwd_last_set.replace(tzinfo=timezone.utc)
|
||||
if created or pwd_last_set >= user.password_change_date:
|
||||
self.message(f"'{user.username}': Reset user's password")
|
||||
self._logger.debug(
|
||||
|
||||
@ -8,8 +8,9 @@ from ldap3.core.exceptions import LDAPException
|
||||
from redis.exceptions import LockError
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.events.monitored_tasks import CACHE_KEY_PREFIX as CACHE_KEY_PREFIX_TASKS
|
||||
from authentik.events.monitored_tasks import MonitoredTask, TaskResult, TaskResultStatus
|
||||
from authentik.events.models import SystemTask as DBSystemTask
|
||||
from authentik.events.models import TaskStatus
|
||||
from authentik.events.system_tasks import SystemTask
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.lib.utils.reflection import class_to_path, path_to_class
|
||||
@ -34,7 +35,7 @@ CACHE_KEY_STATUS = "goauthentik.io/sources/ldap/status/"
|
||||
def ldap_sync_all():
|
||||
"""Sync all sources"""
|
||||
for source in LDAPSource.objects.filter(enabled=True):
|
||||
ldap_sync_single.apply_async(args=[source.pk])
|
||||
ldap_sync_single.apply_async(args=[str(source.pk)])
|
||||
|
||||
|
||||
@CELERY_APP.task()
|
||||
@ -69,8 +70,7 @@ def ldap_sync_single(source_pk: str):
|
||||
try:
|
||||
with lock:
|
||||
# Delete all sync tasks from the cache
|
||||
keys = cache.keys(f"{CACHE_KEY_PREFIX_TASKS}ldap_sync:{source.slug}*")
|
||||
cache.delete_many(keys)
|
||||
DBSystemTask.objects.filter(name="ldap_sync", uid__startswith=source.slug).delete()
|
||||
task = chain(
|
||||
# User and group sync can happen at once, they have no dependencies on each other
|
||||
group(
|
||||
@ -96,18 +96,18 @@ def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) ->
|
||||
for page in sync_inst.get_objects():
|
||||
page_cache_key = CACHE_KEY_PREFIX + str(uuid4())
|
||||
cache.set(page_cache_key, page, 60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"))
|
||||
page_sync = ldap_sync.si(source.pk, class_to_path(sync), page_cache_key)
|
||||
page_sync = ldap_sync.si(str(source.pk), class_to_path(sync), page_cache_key)
|
||||
signatures.append(page_sync)
|
||||
return signatures
|
||||
|
||||
|
||||
@CELERY_APP.task(
|
||||
bind=True,
|
||||
base=MonitoredTask,
|
||||
base=SystemTask,
|
||||
soft_time_limit=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"),
|
||||
task_time_limit=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"),
|
||||
)
|
||||
def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str, page_cache_key: str):
|
||||
def ldap_sync(self: SystemTask, source_pk: str, sync_class: str, page_cache_key: str):
|
||||
"""Synchronization of an LDAP Source"""
|
||||
self.result_timeout_hours = CONFIG.get_int("ldap.task_timeout_hours")
|
||||
source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first()
|
||||
@ -127,20 +127,18 @@ def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str, page_cache_k
|
||||
+ "Try increasing ldap.task_timeout_hours"
|
||||
)
|
||||
LOGGER.warning(error_message)
|
||||
self.set_status(TaskResult(TaskResultStatus.ERROR, [error_message]))
|
||||
self.set_status(TaskStatus.ERROR, error_message)
|
||||
return
|
||||
cache.touch(page_cache_key)
|
||||
count = sync_inst.sync(page)
|
||||
messages = sync_inst.messages
|
||||
messages.append(f"Synced {count} objects.")
|
||||
self.set_status(
|
||||
TaskResult(
|
||||
TaskResultStatus.SUCCESSFUL,
|
||||
messages,
|
||||
)
|
||||
TaskStatus.SUCCESSFUL,
|
||||
*messages,
|
||||
)
|
||||
cache.delete(page_cache_key)
|
||||
except LDAPException as exc:
|
||||
# No explicit event is created here as .set_status with an error will do that
|
||||
LOGGER.warning(exception_to_string(exc))
|
||||
self.set_status(TaskResult(TaskResultStatus.ERROR).with_error(exc))
|
||||
self.set_error(exc)
|
||||
|
||||
@ -7,8 +7,8 @@ from django.test import TestCase
|
||||
from authentik.blueprints.tests import apply_blueprint
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.events.monitored_tasks import TaskInfo, TaskResultStatus
|
||||
from authentik.events.models import Event, EventAction, SystemTask
|
||||
from authentik.events.system_tasks import TaskStatus
|
||||
from authentik.lib.generators import generate_id, generate_key
|
||||
from authentik.lib.utils.reflection import class_to_path
|
||||
from authentik.sources.ldap.models import LDAPPropertyMapping, LDAPSource
|
||||
@ -40,9 +40,9 @@ class LDAPSyncTests(TestCase):
|
||||
"""Test sync with missing page"""
|
||||
connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD))
|
||||
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
|
||||
ldap_sync.delay(self.source.pk, class_to_path(UserLDAPSynchronizer), "foo").get()
|
||||
status = TaskInfo.by_name("ldap_sync:ldap:users:foo")
|
||||
self.assertEqual(status.result.status, TaskResultStatus.ERROR)
|
||||
ldap_sync.delay(str(self.source.pk), class_to_path(UserLDAPSynchronizer), "foo").get()
|
||||
task = SystemTask.objects.filter(name="ldap_sync", uid="ldap:users:foo").first()
|
||||
self.assertEqual(task.status, TaskStatus.ERROR)
|
||||
|
||||
def test_sync_error(self):
|
||||
"""Test user sync"""
|
||||
|
||||
@ -4,7 +4,8 @@ from json import dumps
|
||||
from requests import RequestException
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.events.monitored_tasks import MonitoredTask, TaskResult, TaskResultStatus
|
||||
from authentik.events.models import TaskStatus
|
||||
from authentik.events.system_tasks import SystemTask
|
||||
from authentik.lib.utils.http import get_http_session
|
||||
from authentik.root.celery import CELERY_APP
|
||||
from authentik.sources.oauth.models import OAuthSource
|
||||
@ -12,11 +13,11 @@ from authentik.sources.oauth.models import OAuthSource
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
@CELERY_APP.task(bind=True, base=MonitoredTask)
|
||||
def update_well_known_jwks(self: MonitoredTask):
|
||||
@CELERY_APP.task(bind=True, base=SystemTask)
|
||||
def update_well_known_jwks(self: SystemTask):
|
||||
"""Update OAuth sources' config from well_known, and JWKS info from the configured URL"""
|
||||
session = get_http_session()
|
||||
result = TaskResult(TaskResultStatus.SUCCESSFUL, [])
|
||||
messages = []
|
||||
for source in OAuthSource.objects.all().exclude(oidc_well_known_url=""):
|
||||
try:
|
||||
well_known_config = session.get(source.oidc_well_known_url)
|
||||
@ -24,7 +25,7 @@ def update_well_known_jwks(self: MonitoredTask):
|
||||
except RequestException as exc:
|
||||
text = exc.response.text if exc.response else str(exc)
|
||||
LOGGER.warning("Failed to update well_known", source=source, exc=exc, text=text)
|
||||
result.messages.append(f"Failed to update OIDC configuration for {source.slug}")
|
||||
messages.append(f"Failed to update OIDC configuration for {source.slug}")
|
||||
continue
|
||||
config = well_known_config.json()
|
||||
try:
|
||||
@ -47,7 +48,7 @@ def update_well_known_jwks(self: MonitoredTask):
|
||||
source=source,
|
||||
exc=exc,
|
||||
)
|
||||
result.messages.append(f"Failed to update OIDC configuration for {source.slug}")
|
||||
messages.append(f"Failed to update OIDC configuration for {source.slug}")
|
||||
continue
|
||||
if dirty:
|
||||
LOGGER.info("Updating sources' OpenID Configuration", source=source)
|
||||
@ -60,11 +61,11 @@ def update_well_known_jwks(self: MonitoredTask):
|
||||
except RequestException as exc:
|
||||
text = exc.response.text if exc.response else str(exc)
|
||||
LOGGER.warning("Failed to update JWKS", source=source, exc=exc, text=text)
|
||||
result.messages.append(f"Failed to update JWKS for {source.slug}")
|
||||
messages.append(f"Failed to update JWKS for {source.slug}")
|
||||
continue
|
||||
config = jwks_config.json()
|
||||
if dumps(source.oidc_jwks, sort_keys=True) != dumps(config, sort_keys=True):
|
||||
source.oidc_jwks = config
|
||||
LOGGER.info("Updating sources' JWKS", source=source)
|
||||
source.save()
|
||||
self.set_status(result)
|
||||
self.set_status(TaskStatus.SUCCESSFUL, *messages)
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
"""Plex tasks"""
|
||||
from requests import RequestException
|
||||
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.events.monitored_tasks import MonitoredTask, TaskResult, TaskResultStatus
|
||||
from authentik.events.models import Event, EventAction, TaskStatus
|
||||
from authentik.events.system_tasks import SystemTask
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.root.celery import CELERY_APP
|
||||
from authentik.sources.plex.models import PlexSource
|
||||
@ -16,8 +16,8 @@ def check_plex_token_all():
|
||||
check_plex_token.delay(source.slug)
|
||||
|
||||
|
||||
@CELERY_APP.task(bind=True, base=MonitoredTask)
|
||||
def check_plex_token(self: MonitoredTask, source_slug: int):
|
||||
@CELERY_APP.task(bind=True, base=SystemTask)
|
||||
def check_plex_token(self: SystemTask, source_slug: int):
|
||||
"""Check the validity of a Plex source."""
|
||||
sources = PlexSource.objects.filter(slug=source_slug)
|
||||
if not sources.exists():
|
||||
@ -27,16 +27,15 @@ def check_plex_token(self: MonitoredTask, source_slug: int):
|
||||
auth = PlexAuth(source, source.plex_token)
|
||||
try:
|
||||
auth.get_user_info()
|
||||
self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL, ["Plex token is valid."]))
|
||||
self.set_status(TaskStatus.SUCCESSFUL, "Plex token is valid.")
|
||||
except RequestException as exc:
|
||||
error = exception_to_string(exc)
|
||||
if len(source.plex_token) > 0:
|
||||
error = error.replace(source.plex_token, "$PLEX_TOKEN")
|
||||
self.set_status(
|
||||
TaskResult(
|
||||
TaskResultStatus.ERROR,
|
||||
["Plex token is invalid/an error occurred:", error],
|
||||
)
|
||||
TaskStatus.ERROR,
|
||||
"Plex token is invalid/an error occurred:",
|
||||
error,
|
||||
)
|
||||
Event.new(
|
||||
EventAction.CONFIGURATION_ERROR,
|
||||
|
||||
@ -9,8 +9,8 @@ from django.core.mail.utils import DNS_NAME
|
||||
from django.utils.text import slugify
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.events.monitored_tasks import MonitoredTask, TaskResult, TaskResultStatus
|
||||
from authentik.events.models import Event, EventAction, TaskStatus
|
||||
from authentik.events.system_tasks import SystemTask
|
||||
from authentik.root.celery import CELERY_APP
|
||||
from authentik.stages.email.models import EmailStage
|
||||
from authentik.stages.email.utils import logo_data
|
||||
@ -22,7 +22,7 @@ def send_mails(stage: EmailStage, *messages: list[EmailMultiAlternatives]):
|
||||
"""Wrapper to convert EmailMessage to dict and send it from worker"""
|
||||
tasks = []
|
||||
for message in messages:
|
||||
tasks.append(send_mail.s(message.__dict__, stage.pk))
|
||||
tasks.append(send_mail.s(message.__dict__, str(stage.pk)))
|
||||
lazy_group = group(*tasks)
|
||||
promise = lazy_group()
|
||||
return promise
|
||||
@ -44,9 +44,9 @@ def get_email_body(email: EmailMultiAlternatives) -> str:
|
||||
OSError,
|
||||
),
|
||||
retry_backoff=True,
|
||||
base=MonitoredTask,
|
||||
base=SystemTask,
|
||||
)
|
||||
def send_mail(self: MonitoredTask, message: dict[Any, Any], email_stage_pk: Optional[int] = None):
|
||||
def send_mail(self: SystemTask, message: dict[Any, Any], email_stage_pk: Optional[str] = None):
|
||||
"""Send Email for Email Stage. Retries are scheduled automatically."""
|
||||
self.save_on_success = False
|
||||
message_id = make_msgid(domain=DNS_NAME)
|
||||
@ -58,10 +58,8 @@ def send_mail(self: MonitoredTask, message: dict[Any, Any], email_stage_pk: Opti
|
||||
stages = EmailStage.objects.filter(pk=email_stage_pk)
|
||||
if not stages.exists():
|
||||
self.set_status(
|
||||
TaskResult(
|
||||
TaskResultStatus.WARNING,
|
||||
messages=["Email stage does not exist anymore. Discarding message."],
|
||||
)
|
||||
TaskStatus.WARNING,
|
||||
"Email stage does not exist anymore. Discarding message.",
|
||||
)
|
||||
return
|
||||
stage: EmailStage = stages.first()
|
||||
@ -69,7 +67,7 @@ def send_mail(self: MonitoredTask, message: dict[Any, Any], email_stage_pk: Opti
|
||||
backend = stage.backend
|
||||
except ValueError as exc:
|
||||
LOGGER.warning("failed to get email backend", exc=exc)
|
||||
self.set_status(TaskResult(TaskResultStatus.ERROR).with_error(exc))
|
||||
self.set_error(exc)
|
||||
return
|
||||
backend.open()
|
||||
# Since django's EmailMessage objects are not JSON serialisable,
|
||||
@ -97,12 +95,10 @@ def send_mail(self: MonitoredTask, message: dict[Any, Any], email_stage_pk: Opti
|
||||
to_email=message_object.to,
|
||||
).save()
|
||||
self.set_status(
|
||||
TaskResult(
|
||||
TaskResultStatus.SUCCESSFUL,
|
||||
messages=["Successfully sent Mail."],
|
||||
)
|
||||
TaskStatus.SUCCESSFUL,
|
||||
"Successfully sent Mail.",
|
||||
)
|
||||
except (SMTPException, ConnectionError, OSError) as exc:
|
||||
LOGGER.debug("Error sending email, retrying...", exc=exc)
|
||||
self.set_status(TaskResult(TaskResultStatus.ERROR).with_error(exc))
|
||||
self.set_error(exc)
|
||||
raise exc
|
||||
|
||||
Reference in New Issue
Block a user