"""Blueprint importer""" from contextlib import contextmanager from copy import deepcopy from typing import Any from dacite.config import Config from dacite.core import from_dict from dacite.exceptions import DaciteError from deepmerge import always_merger from django.contrib.auth.models import Permission from django.contrib.contenttypes.models import ContentType from django.core.exceptions import FieldError from django.db.models import Model from django.db.models.query_utils import Q from django.db.transaction import atomic from django.db.utils import IntegrityError from guardian.models import UserObjectPermission from guardian.shortcuts import assign_perm from rest_framework.exceptions import ValidationError from rest_framework.serializers import BaseSerializer, Serializer from structlog.stdlib import BoundLogger, get_logger from yaml import load from authentik.blueprints.v1.common import ( Blueprint, BlueprintEntry, BlueprintEntryDesiredState, BlueprintEntryState, BlueprintLoader, EntryInvalidError, ) from authentik.blueprints.v1.meta.registry import BaseMetaModel, registry from authentik.core.models import ( AuthenticatedSession, GroupSourceConnection, PropertyMapping, Provider, Session, Source, User, UserSourceConnection, ) from authentik.enterprise.license import LicenseKey from authentik.enterprise.models import LicenseUsage from authentik.enterprise.providers.google_workspace.models import ( GoogleWorkspaceProviderGroup, GoogleWorkspaceProviderUser, ) from authentik.enterprise.providers.microsoft_entra.models import ( MicrosoftEntraProviderGroup, MicrosoftEntraProviderUser, ) from authentik.enterprise.providers.ssf.models import StreamEvent from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import ( EndpointDevice, EndpointDeviceConnection, ) from authentik.events.logs import LogEvent, capture_logs from authentik.events.models import SystemTask from authentik.events.utils import cleanse_dict from authentik.flows.models import FlowToken, Stage from authentik.lib.models import SerializerModel from authentik.lib.sentry import SentryIgnoredException from authentik.lib.utils.reflection import get_apps from authentik.outposts.models import OutpostServiceConnection from authentik.policies.models import Policy, PolicyBindingModel from authentik.policies.reputation.models import Reputation from authentik.providers.oauth2.models import ( AccessToken, AuthorizationCode, DeviceToken, RefreshToken, ) from authentik.providers.rac.models import ConnectionToken from authentik.providers.scim.models import SCIMProviderGroup, SCIMProviderUser from authentik.rbac.models import Role from authentik.sources.scim.models import SCIMSourceGroup, SCIMSourceUser from authentik.stages.authenticator_webauthn.models import WebAuthnDeviceType from authentik.tenants.models import Tenant # Context set when the serializer is created in a blueprint context # Update website/docs/customize/blueprints/v1/models.md when used SERIALIZER_CONTEXT_BLUEPRINT = "blueprint_entry" def excluded_models() -> list[type[Model]]: """Return a list of all excluded models that shouldn't be exposed via API or other means (internal only, base classes, non-used objects, etc)""" from django.contrib.auth.models import Group as DjangoGroup from django.contrib.auth.models import User as DjangoUser return ( # Django only classes DjangoUser, DjangoGroup, ContentType, Permission, UserObjectPermission, # Base classes Provider, Source, PropertyMapping, UserSourceConnection, GroupSourceConnection, Stage, OutpostServiceConnection, 
        Policy,
        PolicyBindingModel,
        # Classes that have other dependencies
        Session,
        AuthenticatedSession,
        # Classes which are only internally managed
        # FIXME: these shouldn't need to be explicitly listed, but rather based off of a mixin
        FlowToken,
        LicenseUsage,
        SCIMProviderGroup,
        SCIMProviderUser,
        Tenant,
        SystemTask,
        ConnectionToken,
        AuthorizationCode,
        AccessToken,
        RefreshToken,
        Reputation,
        WebAuthnDeviceType,
        SCIMSourceUser,
        SCIMSourceGroup,
        GoogleWorkspaceProviderUser,
        GoogleWorkspaceProviderGroup,
        MicrosoftEntraProviderUser,
        MicrosoftEntraProviderGroup,
        EndpointDevice,
        EndpointDeviceConnection,
        DeviceToken,
        StreamEvent,
    )


def is_model_allowed(model: type[Model]) -> bool:
    """Check if model is allowed"""
    return model not in excluded_models() and issubclass(model, SerializerModel | BaseMetaModel)


class DoRollback(SentryIgnoredException):
    """Exception to trigger a rollback"""


@contextmanager
def transaction_rollback():
    """Enters an atomic transaction and always triggers a rollback at the end of the block."""
    try:
        with atomic():
            yield
            raise DoRollback()
    except DoRollback:
        pass


def rbac_models() -> dict:
    """Map all allowed model names to the app label they belong to"""
    models = {}
    for app in get_apps():
        for model in app.get_models():
            if not is_model_allowed(model):
                continue
            models[model._meta.model_name] = app.label
    return models


class Importer:
    """Import Blueprint from raw dict or YAML/JSON"""

    logger: BoundLogger
    _import: Blueprint

    def __init__(self, blueprint: Blueprint, context: dict | None = None):
        self.__pk_map: dict[Any, Model] = {}
        self._import = blueprint
        self.logger = get_logger()
        ctx = self.default_context()
        always_merger.merge(ctx, self._import.context)
        if context:
            always_merger.merge(ctx, context)
        self._import.context = ctx

    def default_context(self):
        """Default context"""
        return {
            "goauthentik.io/enterprise/licensed": LicenseKey.get_total().status().is_valid,
            "goauthentik.io/rbac/models": rbac_models(),
        }

    @staticmethod
    def from_string(yaml_input: str, context: dict | None = None) -> "Importer":
        """Parse YAML string and create blueprint importer from it"""
        import_dict = load(yaml_input, BlueprintLoader)
        try:
            _import = from_dict(
                Blueprint, import_dict, config=Config(cast=[BlueprintEntryDesiredState])
            )
        except DaciteError as exc:
            raise EntryInvalidError from exc
        return Importer(_import, context)

    @property
    def blueprint(self) -> Blueprint:
        """Get imported blueprint"""
        return self._import

    def __update_pks_for_attrs(self, attrs: dict[str, Any]) -> dict[str, Any]:
        """Replace any value if it is a known primary key of another object"""

        def updater(value) -> Any:
            if value in self.__pk_map:
                self.logger.debug("Updating reference in entry", value=value)
                return self.__pk_map[value]
            return value

        for key, value in attrs.items():
            try:
                if isinstance(value, dict):
                    for _, _inner_key in enumerate(value):
                        value[_inner_key] = updater(value[_inner_key])
                elif isinstance(value, list):
                    for idx, _inner_value in enumerate(value):
                        attrs[key][idx] = updater(_inner_value)
                else:
                    attrs[key] = updater(value)
            except TypeError:
                continue
        return attrs

    def __query_from_identifier(self, attrs: dict[str, Any]) -> Q:
        """Generate an OR'd query from all identifiers in an entry"""
        # Since identifiers can also be pk-references to other objects (see FlowStageBinding)
        # we have to ensure those references are also replaced
        main_query = Q()
        if "pk" in attrs:
            main_query = Q(pk=attrs["pk"])
        sub_query = Q()
        for identifier, value in attrs.items():
            if identifier == "pk":
                continue
            if isinstance(value, dict):
                sub_query &= Q(**{f"{identifier}__contains": value})
            else:
                sub_query &= Q(**{identifier: value})
        return main_query | sub_query
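    # Illustration of the identifier-to-query mapping (hypothetical values, not taken
    # from a real blueprint): identifiers {"pk": 5, "slug": "my-flow", "designation": "authentication"}
    # resolve to Q(pk=5) | (Q(slug="my-flow") & Q(designation="authentication")), i.e. an
    # existing object matches either by its stated primary key or by all remaining
    # identifier fields together; dict-valued identifiers are matched with `__contains`.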
    def _validate_single(self, entry: BlueprintEntry) -> BaseSerializer | None:  # noqa: PLR0915
        """Validate a single entry"""
        if not entry.check_all_conditions_match(self._import):
            self.logger.debug("One or more conditions of this entry are not fulfilled, skipping")
            return None

        model_app_label, model_name = entry.get_model(self._import).split(".")
        try:
            model: type[SerializerModel] = registry.get_model(model_app_label, model_name)
        except LookupError as exc:
            raise EntryInvalidError.from_entry(exc, entry) from exc
        # Don't use isinstance since we don't want to check for inheritance
        if not is_model_allowed(model):
            raise EntryInvalidError.from_entry(f"Model {model} not allowed", entry)
        if issubclass(model, BaseMetaModel):
            serializer_class: type[Serializer] = model.serializer()
            serializer = serializer_class(
                data=entry.get_attrs(self._import),
                context={
                    SERIALIZER_CONTEXT_BLUEPRINT: entry,
                },
            )
            try:
                serializer.is_valid(raise_exception=True)
            except ValidationError as exc:
                raise EntryInvalidError.from_entry(
                    f"Serializer errors {serializer.errors}",
                    validation_error=exc,
                    entry=entry,
                ) from exc
            return serializer

        # If we try to validate without referencing a possible instance
        # we'll get a duplicate error, hence we load the model here and return
        # the full serializer for later usage
        # Because a model might have multiple unique columns, we chain all identifiers together
        # to create an OR query.
        updated_identifiers = self.__update_pks_for_attrs(entry.get_identifiers(self._import))
        for key, value in list(updated_identifiers.items()):
            if isinstance(value, dict) and "pk" in value:
                del updated_identifiers[key]
                updated_identifiers[key] = value["pk"]

        query = self.__query_from_identifier(updated_identifiers)
        if not query:
            raise EntryInvalidError.from_entry("No or invalid identifiers", entry)

        try:
            existing_models = model.objects.filter(query)
        except FieldError as exc:
            raise EntryInvalidError.from_entry(f"Invalid identifier field: {exc}", entry) from exc

        serializer_kwargs = {}
        model_instance = existing_models.first()
        if (
            not isinstance(model(), BaseMetaModel)
            and model_instance
            and entry.state != BlueprintEntryDesiredState.MUST_CREATED
        ):
            self.logger.debug(
                "Initialise serializer with instance",
                model=model,
                instance=model_instance,
                pk=model_instance.pk,
            )
            serializer_kwargs["instance"] = model_instance
            serializer_kwargs["partial"] = True
        elif model_instance and entry.state == BlueprintEntryDesiredState.MUST_CREATED:
            msg = (
                f"State is set to {BlueprintEntryDesiredState.MUST_CREATED.value} "
                "and object exists already"
            )
            raise EntryInvalidError.from_entry(
                ValidationError({k: msg for k in entry.identifiers.keys()}, "unique"),
                entry,
            )
        else:
            self.logger.debug(
                "Initialised new serializer instance",
                model=model,
                **cleanse_dict(updated_identifiers),
            )
            model_instance = model()
            # pk needs to be set on the model instance otherwise a new one will be generated
            if "pk" in updated_identifiers:
                model_instance.pk = updated_identifiers["pk"]
            serializer_kwargs["instance"] = model_instance
        try:
            full_data = self.__update_pks_for_attrs(entry.get_attrs(self._import))
        except ValueError as exc:
            raise EntryInvalidError.from_entry(exc, entry) from exc
        always_merger.merge(full_data, updated_identifiers)
        serializer_kwargs["data"] = full_data

        serializer: Serializer = model().serializer(
            context={
                SERIALIZER_CONTEXT_BLUEPRINT: entry,
            },
            **serializer_kwargs,
        )
        try:
            serializer.is_valid(raise_exception=True)
        except ValidationError as exc:
            raise EntryInvalidError.from_entry(
                f"Serializer errors {serializer.errors}",
                validation_error=exc,
                entry=entry,
                serializer=serializer,
            ) from exc
        return serializer
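    # Permission entries are applied after an instance is saved; a hypothetical entry
    # (shape assumed for illustration, not a verified example) could carry:
    #   permissions:
    #     - permission: authentik_flows.view_flow
    #       user: 3
    # which grants the object-level permission on the imported instance via assign_perm,
    # or assigns it to a role when `role` is given instead of `user`.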
{serializer.errors}", validation_error=exc, entry=entry, serializer=serializer, ) from exc return serializer def _apply_permissions(self, instance: Model, entry: BlueprintEntry): """Apply object-level permissions for an entry""" for perm in entry.get_permissions(self._import): if perm.user is not None: assign_perm(perm.permission, User.objects.get(pk=perm.user), instance) if perm.role is not None: role = Role.objects.get(pk=perm.role) role.assign_permission(perm.permission, obj=instance) def apply(self) -> bool: """Apply (create/update) models yaml, in database transaction""" try: with atomic(): if not self._apply_models(): self.logger.debug("Reverting changes due to error") raise IntegrityError except IntegrityError: return False self.logger.debug("Committing changes") return True def _apply_models(self, raise_errors=False) -> bool: """Apply (create/update) models yaml""" self.__pk_map = {} for entry in self._import.entries: model_app_label, model_name = entry.get_model(self._import).split(".") try: model: type[SerializerModel] = registry.get_model(model_app_label, model_name) except LookupError: self.logger.warning( "App or Model does not exist", app=model_app_label, model=model_name ) return False # Validate each single entry serializer = None try: serializer = self._validate_single(entry) except EntryInvalidError as exc: # For deleting objects we don't need the serializer to be valid if entry.get_state(self._import) == BlueprintEntryDesiredState.ABSENT: serializer = exc.serializer else: self.logger.warning(f"Entry invalid: {exc}", entry=entry, error=exc) if raise_errors: raise exc return False if not serializer: continue state = entry.get_state(self._import) if state in [ BlueprintEntryDesiredState.PRESENT, BlueprintEntryDesiredState.CREATED, BlueprintEntryDesiredState.MUST_CREATED, ]: instance = serializer.instance if ( instance and not instance._state.adding and state == BlueprintEntryDesiredState.CREATED ): self.logger.debug( "Instance exists, skipping", model=model, instance=instance, pk=instance.pk, ) else: instance = serializer.save() self.logger.debug("Updated model", model=instance) if "pk" in entry.identifiers: self.__pk_map[entry.identifiers["pk"]] = instance.pk entry._state = BlueprintEntryState(instance) self._apply_permissions(instance, entry) elif state == BlueprintEntryDesiredState.ABSENT: instance: Model | None = serializer.instance if instance.pk: instance.delete() self.logger.debug("Deleted model", mode=instance) continue self.logger.debug("Entry to delete with no instance, skipping") return True def validate(self, raise_validation_errors=False) -> tuple[bool, list[LogEvent]]: """Validate loaded blueprint export, ensure all models are allowed and serializers have no errors""" self.logger.debug("Starting blueprint import validation") orig_import = deepcopy(self._import) if self._import.version != 1: self.logger.warning("Invalid blueprint version") return False, [LogEvent("Invalid blueprint version", log_level="warning", logger=None)] with ( transaction_rollback(), capture_logs() as logs, ): successful = self._apply_models(raise_errors=raise_validation_errors) if not successful: self.logger.warning("Blueprint validation failed") self.logger.debug("Finished blueprint import validation") self._import = orig_import return successful, logs
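# Usage sketch (illustrative only; the inline blueprint below is a hypothetical minimal
# example, not one shipped with authentik):
#
#   importer = Importer.from_string(
#       "version: 1\n"
#       "metadata:\n"
#       "  name: example\n"
#       "entries:\n"
#       "  - model: authentik_stages_dummy.dummystage\n"
#       "    identifiers:\n"
#       "      name: example-stage\n"
#   )
#   valid, logs = importer.validate()   # dry-run inside a rolled-back transaction
#   if valid:
#       importer.apply()                # commit the changes atomically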