Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt
2025-06-03 18:25:49 +02:00
parent b0af20b0d5
commit 4302f91028
9 changed files with 130 additions and 49 deletions

View File

@ -14,7 +14,6 @@ from authentik.events.models import Event, EventAction, Notification
from authentik.lib.config import CONFIG
from authentik.lib.utils.http import get_http_session
from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import Task, TaskStatus
LOGGER = get_logger()
VERSION_NULL = "0.0.0"
@ -49,10 +48,10 @@ def clear_update_notifications():
@actor
def update_latest_version():
"""Update latest version info"""
self: Task = CurrentTask.get_task()
self = CurrentTask.get_task()
if CONFIG.get_bool("disable_update_check"):
cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT)
self.set_status(TaskStatus.WARNING, "Version check disabled.")
self.warning("Version check disabled.")
return
try:
response = get_http_session().get(
@ -62,7 +61,7 @@ def update_latest_version():
data = response.json()
upstream_version = data.get("stable", {}).get("version")
cache.set(VERSION_CACHE_KEY, upstream_version, VERSION_CACHE_TIMEOUT)
self.set_status(TaskStatus.SUCCESSFUL, "Successfully updated latest Version")
self.info("Successfully updated latest Version")
_set_prom_info()
# Check if upstream version is newer than what we're running,
# and if no event exists yet, create one.
@ -85,7 +84,7 @@ def update_latest_version():
).save()
except (RequestException, IndexError) as exc:
cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT)
self.set_error(exc)
raise exc
_set_prom_info()

View File

@ -154,9 +154,7 @@ def blueprints_discovery(path: str | None = None):
continue
check_blueprint_v1_file(blueprint)
count += 1
self.set_status(
TaskStatus.SUCCESSFUL, _("Successfully imported {count} files.".format(count=count))
)
self.info(_("Successfully imported {count} files.".format(count=count)))
def check_blueprint_v1_file(blueprint: BlueprintFile):
@ -189,15 +187,13 @@ def check_blueprint_v1_file(blueprint: BlueprintFile):
@actor
def apply_blueprint(instance_pk: str):
"""Apply single blueprint"""
self: Task = CurrentTask.get_task()
# TODO: fixme
# self.save_on_success = False
self = CurrentTask.get_task()
instance: BlueprintInstance | None = None
try:
instance: BlueprintInstance = BlueprintInstance.objects.filter(pk=instance_pk).first()
if not instance or not instance.enabled:
return
self.set_uid(slugify(instance.name))
self.uid = slugify(instance.name)
blueprint_content = instance.retrieve()
file_hash = sha512(blueprint_content.encode()).hexdigest()
importer = Importer.from_string(blueprint_content, instance.context)
@ -207,19 +203,18 @@ def apply_blueprint(instance_pk: str):
if not valid:
instance.status = BlueprintInstanceStatus.ERROR
instance.save()
self.set_status(TaskStatus.ERROR, *logs)
self.error(*logs)
return
with capture_logs() as logs:
applied = importer.apply()
if not applied:
instance.status = BlueprintInstanceStatus.ERROR
instance.save()
self.set_status(TaskStatus.ERROR, *logs)
self.error(*logs)
return
instance.status = BlueprintInstanceStatus.SUCCESSFUL
instance.last_applied_hash = file_hash
instance.last_applied = now()
self.set_status(TaskStatus.SUCCESSFUL)
except (
OSError,
DatabaseError,
@ -230,7 +225,7 @@ def apply_blueprint(instance_pk: str):
) as exc:
if instance:
instance.status = BlueprintInstanceStatus.ERROR
self.set_error(exc)
self.error(exc)
finally:
if instance:
instance.save()

View File

@ -23,7 +23,6 @@ class TaskSerializer(ModelSerializer):
"actor_name",
"state",
"mtime",
"schedule_uid",
"uid",
"messages",
]
@ -41,13 +40,11 @@ class TaskViewSet(
"queue_name",
"actor_name",
"state",
"schedule_uid",
)
filterset_fields = (
"queue_name",
"actor_name",
"state",
"schedule_uid",
)
ordering = (
"actor_name",

View File

@ -98,7 +98,7 @@ class PostgresBroker(Broker):
self.queues = set()
self.actor_options = {
"schedule_uid",
"rel_obj",
}
self.db_alias = db_alias
@ -192,7 +192,7 @@ class PostgresBroker(Broker):
"actor_name": message.actor_name,
"state": TaskState.QUEUED,
"message": message.encode(),
"schedule_uid": message.options.get("schedule_uid", ""),
"rel_obj": message.options.get("rel_obj", None),
}
create_defaults = {
**query,

View File

@ -0,0 +1,52 @@
# Generated by Django 5.1.9 on 2025-06-03 15:54
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_tasks", "0001_initial"),
("authentik_tenants", "0005_tenant_reputation_lower_limit_and_more"),
("contenttypes", "0002_remove_content_type_name"),
]
operations = [
migrations.RemoveField(
model_name="task",
name="description",
),
migrations.RemoveField(
model_name="task",
name="schedule_uid",
),
migrations.RemoveField(
model_name="task",
name="status",
),
migrations.RemoveField(
model_name="task",
name="uid",
),
migrations.AddField(
model_name="task",
name="rel_obj_content_type",
field=models.ForeignKey(
null=True,
on_delete=django.db.models.deletion.CASCADE,
to="contenttypes.contenttype",
),
),
migrations.AddField(
model_name="task",
name="rel_obj_id",
field=models.TextField(null=True),
),
migrations.AddIndex(
model_name="task",
index=models.Index(
fields=["rel_obj_content_type", "rel_obj_id"], name="authentik_t_rel_obj_3a177a_idx"
),
),
]

View File

@ -0,0 +1,18 @@
# Generated by Django 5.1.9 on 2025-06-03 16:02
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_tasks", "0002_remove_task_description_remove_task_schedule_uid_and_more"),
]
operations = [
migrations.AddField(
model_name="task",
name="uid",
field=models.TextField(blank=True),
),
]

View File

@ -0,0 +1,18 @@
# Generated by Django 5.1.9 on 2025-06-03 16:03
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_tasks", "0003_task_uid"),
]
operations = [
migrations.AlterField(
model_name="task",
name="uid",
field=models.TextField(blank=True, null=True),
),
]

View File

@ -1,6 +1,7 @@
from enum import StrEnum, auto
from uuid import uuid4
from django.contrib.contenttypes.fields import ContentType, GenericForeignKey
import pgtrigger
from django.db import models
from django.utils import timezone
@ -32,8 +33,7 @@ class TaskState(models.TextChoices):
class TaskStatus(models.TextChoices):
"""Task soft-state. Self-reported by the task"""
UNKNOWN = "unknown"
SUCCESSFUL = "successful"
INFO = "info"
WARNING = "warning"
ERROR = "error"
@ -59,18 +59,21 @@ class Task(SerializerModel):
result = models.BinaryField(null=True, help_text=_("Task result"))
result_expiry = models.DateTimeField(null=True, help_text=_("Result expiry time"))
schedule_uid = models.TextField(blank=True)
uid = models.TextField(blank=True)
# Probably only have one `logs` field
description = models.TextField(blank=True)
status = models.TextField(blank=True, choices=TaskStatus.choices)
rel_obj_content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE, null=True)
rel_obj_id = models.TextField(null=True)
rel_obj = GenericForeignKey("rel_obj_content_type", "rel_obj_id")
uid = models.TextField(blank=True, null=True)
messages = models.JSONField(default=list)
class Meta:
verbose_name = _("Task")
verbose_name_plural = _("Tasks")
default_permissions = ("view",)
indexes = (models.Index(fields=("state", "mtime")),)
indexes = (
models.Index(fields=("state", "mtime")),
models.Index(fields=("rel_obj_content_type", "rel_obj_id")),
)
triggers = (
pgtrigger.Trigger(
name="notify_enqueueing",
@ -97,24 +100,23 @@ class Task(SerializerModel):
return TaskSerializer
def set_uid(self, uid: str):
"""Set UID, so in the case of an unexpected error its saved correctly"""
self.uid = uid
def log(self, status: TaskStatus, *messages: str | LogEvent | Exception, save=False):
self.messages: list
for msg in messages:
message = msg
if isinstance(message, Exception):
message = exception_to_string(message)
if not isinstance(message, LogEvent):
message = LogEvent(message, logger=self.actor_name, log_level=status.value)
self.messages.append(sanitize_item(message))
if save:
self.save()
def set_status(self, status: TaskStatus, *messages: LogEvent | str):
"""Set result for current run, will overwrite previous result."""
self.status = status
self.messages = list(messages)
for idx, msg in enumerate(self.messages):
if not isinstance(msg, LogEvent):
self.messages[idx] = LogEvent(msg, logger=str(self), log_level="info")
self.messages = sanitize_item(self.messages)
def info(self, *messages: str | LogEvent | Exception, save=False):
self.log(TaskStatus.INFO, *messages, save=save)
def set_error(self, exception: Exception, *messages: LogEvent | str):
"""Set result to error and save exception"""
self.status = TaskStatus.ERROR
self.messages = list(messages)
self.messages.extend(
[LogEvent(exception_to_string(exception), logger=str(self), log_level="error")]
)
self.messages = sanitize_item(self.messages)
def warning(self, *messages: str | LogEvent | Exception, save=False):
self.log(TaskStatus.WARNING, *messages, save=save)
def error(self, *messages: str | LogEvent | Exception, save=False):
self.log(TaskStatus.ERROR, *messages, save=save)

View File

@ -71,7 +71,7 @@ class Schedule(SerializerModel):
return actor.send_with_options(
args=pickle.loads(self.args), # nosec
kwargs=pickle.loads(self.kwargs), # nosec
schedule_uid=self.uid,
rel_obj=self,
**pickle.loads(self.options), # nosec
)