Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt
2025-06-04 15:57:31 +02:00
parent 4302f91028
commit eb87e30076
3 changed files with 32 additions and 13 deletions

View File

@ -135,12 +135,14 @@ class AuthentikBlueprintsConfig(ManagedAppConfig):
def blueprints_discovery(self):
"""Run blueprint discovery"""
from authentik.tasks.schedules.models import Schedule
from authentik.blueprints.v1.tasks import blueprints_discovery, clear_failed_blueprints
for schedule in Schedule.objects.filter(
actor_name__in=(
"authentik.blueprints.v1.tasks.blueprints_discovery",
"authentik.blueprints.v1.tasks.clear_failed_blueprints",
blueprints_discovery.actor_name,
clear_failed_blueprints.actor_name,
),
paused=False,
):
schedule.send()

View File

@ -28,6 +28,7 @@ from authentik.blueprints.models import (
BlueprintInstanceStatus,
BlueprintRetrievalFailed,
)
from authentik.tasks.schedules.models import Schedule
from authentik.blueprints.v1.common import BlueprintLoader, BlueprintMetadata, EntryInvalidError
from authentik.blueprints.v1.importer import Importer
from authentik.blueprints.v1.labels import LABEL_AUTHENTIK_INSTANTIATE
@ -90,7 +91,13 @@ class BlueprintEventHandler(FileSystemEventHandler):
LOGGER.debug("new blueprint file created, starting discovery")
for tenant in Tenant.objects.filter(ready=True):
with tenant:
blueprints_discovery.send()
schedule = Schedule.objects.filter(
actor_name=blueprints_discovery.actor_name,
paused=False,
).first()
if schedule:
schedule.send()
# Schedule was paused or doesn't exist, no dispatch
def on_modified(self, event: FileSystemEvent):
"""Process file modification"""
@ -101,7 +108,7 @@ class BlueprintEventHandler(FileSystemEventHandler):
with tenant:
for instance in BlueprintInstance.objects.filter(path=rel_path, enabled=True):
LOGGER.debug("modified blueprint file, starting apply", instance=instance)
apply_blueprint.send(instance.pk)
apply_blueprint.send_with_options(args=(instance.pk,), rel_obj=instance)
@actor(
@ -147,14 +154,14 @@ def blueprints_find() -> list[BlueprintFile]:
@actor(throws=(DatabaseError, ProgrammingError, InternalError))
def blueprints_discovery(path: str | None = None):
"""Find blueprints and check if they need to be created in the database"""
self: Task = CurrentTask.get_task()
self = CurrentTask.get_task()
count = 0
for blueprint in blueprints_find():
if path and blueprint.path != path:
continue
check_blueprint_v1_file(blueprint)
count += 1
self.info(_("Successfully imported {count} files.".format(count=count)))
self.info(f"Successfully imported {count} files.")
def check_blueprint_v1_file(blueprint: BlueprintFile):
@ -181,19 +188,24 @@ def check_blueprint_v1_file(blueprint: BlueprintFile):
)
if instance.last_applied_hash != blueprint.hash:
LOGGER.info("Applying blueprint due to changed file", instance=instance, path=instance.path)
apply_blueprint.send(instance.pk)
apply_blueprint.send_with_options(args=(instance.pk,), rel_obj=instance)
@actor
def apply_blueprint(instance_pk: str):
"""Apply single blueprint"""
self = CurrentTask.get_task()
self.set_uid(str(instance_pk))
instance: BlueprintInstance | None = None
try:
instance: BlueprintInstance = BlueprintInstance.objects.filter(pk=instance_pk).first()
if not instance or not instance.enabled:
if not instance:
self.warning(f"Could not find blueprint {instance_pk}, skipping")
return
self.set_uid(slugify(instance.name))
if not instance.enabled:
self.info(f"Blueprint {instance.name} is disabled, skipping")
return
self.uid = slugify(instance.name)
blueprint_content = instance.retrieve()
file_hash = sha512(blueprint_content.encode()).hexdigest()
importer = Importer.from_string(blueprint_content, instance.context)

View File

@ -100,7 +100,12 @@ class Task(SerializerModel):
return TaskSerializer
def log(self, status: TaskStatus, *messages: str | LogEvent | Exception, save=False):
def set_uid(self, uid: str, save: bool = False):
self.uid = uid
if save:
self.save()
def log(self, status: TaskStatus, *messages: str | LogEvent | Exception, save: bool = False):
self.messages: list
for msg in messages:
message = msg
@ -112,11 +117,11 @@ class Task(SerializerModel):
if save:
self.save()
def info(self, *messages: str | LogEvent | Exception, save=False):
def info(self, *messages: str | LogEvent | Exception, save: bool = False):
self.log(TaskStatus.INFO, *messages, save=save)
def warning(self, *messages: str | LogEvent | Exception, save=False):
def warning(self, *messages: str | LogEvent | Exception, save: bool = False):
self.log(TaskStatus.WARNING, *messages, save=save)
def error(self, *messages: str | LogEvent | Exception, save=False):
def error(self, *messages: str | LogEvent | Exception, save: bool = False):
self.log(TaskStatus.ERROR, *messages, save=save)