Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt
2025-06-03 18:25:49 +02:00
parent b0af20b0d5
commit 4302f91028
9 changed files with 130 additions and 49 deletions

View File

@ -14,7 +14,6 @@ from authentik.events.models import Event, EventAction, Notification
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
from authentik.lib.utils.http import get_http_session from authentik.lib.utils.http import get_http_session
from authentik.tasks.middleware import CurrentTask from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import Task, TaskStatus
LOGGER = get_logger() LOGGER = get_logger()
VERSION_NULL = "0.0.0" VERSION_NULL = "0.0.0"
@ -49,10 +48,10 @@ def clear_update_notifications():
@actor @actor
def update_latest_version(): def update_latest_version():
"""Update latest version info""" """Update latest version info"""
self: Task = CurrentTask.get_task() self = CurrentTask.get_task()
if CONFIG.get_bool("disable_update_check"): if CONFIG.get_bool("disable_update_check"):
cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT) cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT)
self.set_status(TaskStatus.WARNING, "Version check disabled.") self.warning("Version check disabled.")
return return
try: try:
response = get_http_session().get( response = get_http_session().get(
@ -62,7 +61,7 @@ def update_latest_version():
data = response.json() data = response.json()
upstream_version = data.get("stable", {}).get("version") upstream_version = data.get("stable", {}).get("version")
cache.set(VERSION_CACHE_KEY, upstream_version, VERSION_CACHE_TIMEOUT) cache.set(VERSION_CACHE_KEY, upstream_version, VERSION_CACHE_TIMEOUT)
self.set_status(TaskStatus.SUCCESSFUL, "Successfully updated latest Version") self.info("Successfully updated latest Version")
_set_prom_info() _set_prom_info()
# Check if upstream version is newer than what we're running, # Check if upstream version is newer than what we're running,
# and if no event exists yet, create one. # and if no event exists yet, create one.
@ -85,7 +84,7 @@ def update_latest_version():
).save() ).save()
except (RequestException, IndexError) as exc: except (RequestException, IndexError) as exc:
cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT) cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT)
self.set_error(exc) raise exc
_set_prom_info() _set_prom_info()

View File

@ -154,9 +154,7 @@ def blueprints_discovery(path: str | None = None):
continue continue
check_blueprint_v1_file(blueprint) check_blueprint_v1_file(blueprint)
count += 1 count += 1
self.set_status( self.info(_("Successfully imported {count} files.".format(count=count)))
TaskStatus.SUCCESSFUL, _("Successfully imported {count} files.".format(count=count))
)
def check_blueprint_v1_file(blueprint: BlueprintFile): def check_blueprint_v1_file(blueprint: BlueprintFile):
@ -189,15 +187,13 @@ def check_blueprint_v1_file(blueprint: BlueprintFile):
@actor @actor
def apply_blueprint(instance_pk: str): def apply_blueprint(instance_pk: str):
"""Apply single blueprint""" """Apply single blueprint"""
self: Task = CurrentTask.get_task() self = CurrentTask.get_task()
# TODO: fixme
# self.save_on_success = False
instance: BlueprintInstance | None = None instance: BlueprintInstance | None = None
try: try:
instance: BlueprintInstance = BlueprintInstance.objects.filter(pk=instance_pk).first() instance: BlueprintInstance = BlueprintInstance.objects.filter(pk=instance_pk).first()
if not instance or not instance.enabled: if not instance or not instance.enabled:
return return
self.set_uid(slugify(instance.name)) self.uid = slugify(instance.name)
blueprint_content = instance.retrieve() blueprint_content = instance.retrieve()
file_hash = sha512(blueprint_content.encode()).hexdigest() file_hash = sha512(blueprint_content.encode()).hexdigest()
importer = Importer.from_string(blueprint_content, instance.context) importer = Importer.from_string(blueprint_content, instance.context)
@ -207,19 +203,18 @@ def apply_blueprint(instance_pk: str):
if not valid: if not valid:
instance.status = BlueprintInstanceStatus.ERROR instance.status = BlueprintInstanceStatus.ERROR
instance.save() instance.save()
self.set_status(TaskStatus.ERROR, *logs) self.error(*logs)
return return
with capture_logs() as logs: with capture_logs() as logs:
applied = importer.apply() applied = importer.apply()
if not applied: if not applied:
instance.status = BlueprintInstanceStatus.ERROR instance.status = BlueprintInstanceStatus.ERROR
instance.save() instance.save()
self.set_status(TaskStatus.ERROR, *logs) self.error(*logs)
return return
instance.status = BlueprintInstanceStatus.SUCCESSFUL instance.status = BlueprintInstanceStatus.SUCCESSFUL
instance.last_applied_hash = file_hash instance.last_applied_hash = file_hash
instance.last_applied = now() instance.last_applied = now()
self.set_status(TaskStatus.SUCCESSFUL)
except ( except (
OSError, OSError,
DatabaseError, DatabaseError,
@ -230,7 +225,7 @@ def apply_blueprint(instance_pk: str):
) as exc: ) as exc:
if instance: if instance:
instance.status = BlueprintInstanceStatus.ERROR instance.status = BlueprintInstanceStatus.ERROR
self.set_error(exc) self.error(exc)
finally: finally:
if instance: if instance:
instance.save() instance.save()

View File

@ -23,7 +23,6 @@ class TaskSerializer(ModelSerializer):
"actor_name", "actor_name",
"state", "state",
"mtime", "mtime",
"schedule_uid",
"uid", "uid",
"messages", "messages",
] ]
@ -41,13 +40,11 @@ class TaskViewSet(
"queue_name", "queue_name",
"actor_name", "actor_name",
"state", "state",
"schedule_uid",
) )
filterset_fields = ( filterset_fields = (
"queue_name", "queue_name",
"actor_name", "actor_name",
"state", "state",
"schedule_uid",
) )
ordering = ( ordering = (
"actor_name", "actor_name",

View File

@ -98,7 +98,7 @@ class PostgresBroker(Broker):
self.queues = set() self.queues = set()
self.actor_options = { self.actor_options = {
"schedule_uid", "rel_obj",
} }
self.db_alias = db_alias self.db_alias = db_alias
@ -192,7 +192,7 @@ class PostgresBroker(Broker):
"actor_name": message.actor_name, "actor_name": message.actor_name,
"state": TaskState.QUEUED, "state": TaskState.QUEUED,
"message": message.encode(), "message": message.encode(),
"schedule_uid": message.options.get("schedule_uid", ""), "rel_obj": message.options.get("rel_obj", None),
} }
create_defaults = { create_defaults = {
**query, **query,

View File

@ -0,0 +1,52 @@
# Generated by Django 5.1.9 on 2025-06-03 15:54
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_tasks", "0001_initial"),
("authentik_tenants", "0005_tenant_reputation_lower_limit_and_more"),
("contenttypes", "0002_remove_content_type_name"),
]
operations = [
migrations.RemoveField(
model_name="task",
name="description",
),
migrations.RemoveField(
model_name="task",
name="schedule_uid",
),
migrations.RemoveField(
model_name="task",
name="status",
),
migrations.RemoveField(
model_name="task",
name="uid",
),
migrations.AddField(
model_name="task",
name="rel_obj_content_type",
field=models.ForeignKey(
null=True,
on_delete=django.db.models.deletion.CASCADE,
to="contenttypes.contenttype",
),
),
migrations.AddField(
model_name="task",
name="rel_obj_id",
field=models.TextField(null=True),
),
migrations.AddIndex(
model_name="task",
index=models.Index(
fields=["rel_obj_content_type", "rel_obj_id"], name="authentik_t_rel_obj_3a177a_idx"
),
),
]

View File

@ -0,0 +1,18 @@
# Generated by Django 5.1.9 on 2025-06-03 16:02
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_tasks", "0002_remove_task_description_remove_task_schedule_uid_and_more"),
]
operations = [
migrations.AddField(
model_name="task",
name="uid",
field=models.TextField(blank=True),
),
]

View File

@ -0,0 +1,18 @@
# Generated by Django 5.1.9 on 2025-06-03 16:03
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_tasks", "0003_task_uid"),
]
operations = [
migrations.AlterField(
model_name="task",
name="uid",
field=models.TextField(blank=True, null=True),
),
]

View File

@ -1,6 +1,7 @@
from enum import StrEnum, auto from enum import StrEnum, auto
from uuid import uuid4 from uuid import uuid4
from django.contrib.contenttypes.fields import ContentType, GenericForeignKey
import pgtrigger import pgtrigger
from django.db import models from django.db import models
from django.utils import timezone from django.utils import timezone
@ -32,8 +33,7 @@ class TaskState(models.TextChoices):
class TaskStatus(models.TextChoices): class TaskStatus(models.TextChoices):
"""Task soft-state. Self-reported by the task""" """Task soft-state. Self-reported by the task"""
UNKNOWN = "unknown" INFO = "info"
SUCCESSFUL = "successful"
WARNING = "warning" WARNING = "warning"
ERROR = "error" ERROR = "error"
@ -59,18 +59,21 @@ class Task(SerializerModel):
result = models.BinaryField(null=True, help_text=_("Task result")) result = models.BinaryField(null=True, help_text=_("Task result"))
result_expiry = models.DateTimeField(null=True, help_text=_("Result expiry time")) result_expiry = models.DateTimeField(null=True, help_text=_("Result expiry time"))
schedule_uid = models.TextField(blank=True) rel_obj_content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE, null=True)
uid = models.TextField(blank=True) rel_obj_id = models.TextField(null=True)
# Probably only have one `logs` field rel_obj = GenericForeignKey("rel_obj_content_type", "rel_obj_id")
description = models.TextField(blank=True)
status = models.TextField(blank=True, choices=TaskStatus.choices) uid = models.TextField(blank=True, null=True)
messages = models.JSONField(default=list) messages = models.JSONField(default=list)
class Meta: class Meta:
verbose_name = _("Task") verbose_name = _("Task")
verbose_name_plural = _("Tasks") verbose_name_plural = _("Tasks")
default_permissions = ("view",) default_permissions = ("view",)
indexes = (models.Index(fields=("state", "mtime")),) indexes = (
models.Index(fields=("state", "mtime")),
models.Index(fields=("rel_obj_content_type", "rel_obj_id")),
)
triggers = ( triggers = (
pgtrigger.Trigger( pgtrigger.Trigger(
name="notify_enqueueing", name="notify_enqueueing",
@ -97,24 +100,23 @@ class Task(SerializerModel):
return TaskSerializer return TaskSerializer
def set_uid(self, uid: str): def log(self, status: TaskStatus, *messages: str | LogEvent | Exception, save=False):
"""Set UID, so in the case of an unexpected error its saved correctly""" self.messages: list
self.uid = uid for msg in messages:
message = msg
if isinstance(message, Exception):
message = exception_to_string(message)
if not isinstance(message, LogEvent):
message = LogEvent(message, logger=self.actor_name, log_level=status.value)
self.messages.append(sanitize_item(message))
if save:
self.save()
def set_status(self, status: TaskStatus, *messages: LogEvent | str): def info(self, *messages: str | LogEvent | Exception, save=False):
"""Set result for current run, will overwrite previous result.""" self.log(TaskStatus.INFO, *messages, save=save)
self.status = status
self.messages = list(messages)
for idx, msg in enumerate(self.messages):
if not isinstance(msg, LogEvent):
self.messages[idx] = LogEvent(msg, logger=str(self), log_level="info")
self.messages = sanitize_item(self.messages)
def set_error(self, exception: Exception, *messages: LogEvent | str): def warning(self, *messages: str | LogEvent | Exception, save=False):
"""Set result to error and save exception""" self.log(TaskStatus.WARNING, *messages, save=save)
self.status = TaskStatus.ERROR
self.messages = list(messages) def error(self, *messages: str | LogEvent | Exception, save=False):
self.messages.extend( self.log(TaskStatus.ERROR, *messages, save=save)
[LogEvent(exception_to_string(exception), logger=str(self), log_level="error")]
)
self.messages = sanitize_item(self.messages)

View File

@ -71,7 +71,7 @@ class Schedule(SerializerModel):
return actor.send_with_options( return actor.send_with_options(
args=pickle.loads(self.args), # nosec args=pickle.loads(self.args), # nosec
kwargs=pickle.loads(self.kwargs), # nosec kwargs=pickle.loads(self.kwargs), # nosec
schedule_uid=self.uid, rel_obj=self,
**pickle.loads(self.options), # nosec **pickle.loads(self.options), # nosec
) )