Files
authentik/authentik/blueprints/v1/common.py
2024-12-20 21:43:41 +01:00

697 lines
24 KiB
Python

"""transfer common classes"""
from collections import OrderedDict
from collections.abc import Generator, Iterable, Mapping
from copy import copy
from dataclasses import asdict, dataclass, field, is_dataclass
from enum import Enum
from functools import reduce
from operator import ixor
from os import getenv
from typing import Any, Literal, Union
from uuid import UUID
from deepmerge import always_merger
from django.apps import apps
from django.db.models import Model, Q
from rest_framework.exceptions import ValidationError
from rest_framework.fields import Field
from rest_framework.serializers import Serializer
from yaml import SafeDumper, SafeLoader, ScalarNode, SequenceNode
from authentik.lib.models import SerializerModel
from authentik.lib.sentry import SentryIgnoredException
from authentik.policies.models import PolicyBindingModel
class UNSET:
    """Sentinel marking a key or value that was never provided.

    The class object itself serves as the marker and is compared by identity,
    so ``None`` stays available as a legitimate explicit value.
    """
def get_attrs(obj: SerializerModel) -> dict[str, Any]:
    """Serialize ``obj`` through its own serializer and return the exportable attributes.

    The serializer output is copied into a plain dict, after which every declared
    serializer field present in that output is dropped when any of these holds:

    * the field is read-only,
    * its current value equals the field's initial value,
    * its name ends with ``_set``.

    Keys that are in the output but not among the declared fields are kept as-is.
    """
    serializer: Serializer = obj.serializer(obj)
    data = dict(serializer.data)

    for name, serializer_field in serializer.fields.items():
        if name not in data:
            continue
        # Evaluate all three criteria up-front, then decide once.
        is_read_only = serializer_field.read_only
        is_initial = serializer_field.get_initial() == data[name]
        has_set_suffix = name.endswith("_set")
        if is_read_only or is_initial or has_set_suffix:
            del data[name]
    return data
@dataclass
class BlueprintEntryState:
    """Runtime state tracked for a single blueprint entry"""

    # Model instance associated with the entry; None until one is set
    instance: Model | None = field(default=None)
class BlueprintEntryDesiredState(Enum):
    """Target state an entry is reconciled to.

    The member values are the literal strings used for the entry's ``state``.
    """

    ABSENT = "absent"
    PRESENT = "present"
    CREATED = "created"
    MUST_CREATED = "must_created"
@dataclass
class BlueprintEntryPermission:
    """Object-level permission declared on a blueprint entry.

    Each value may be given literally or as a ``YAMLTag`` that is resolved later.
    """

    permission: Union[str, "YAMLTag"]
    user: Union[int, "YAMLTag", None] = None
    role: Union[str, "YAMLTag", None] = None
@dataclass
class BlueprintEntry:
    """Single entry of a blueprint.

    Most fields may hold (or contain) ``YAMLTag`` instances. Use the ``get_*``
    accessors to obtain the values with all tags resolved.
    """

    model: Union[str, "YAMLTag"]
    state: Union[BlueprintEntryDesiredState, "YAMLTag"] = field(
        default=BlueprintEntryDesiredState.PRESENT
    )
    conditions: list[Any] = field(default_factory=list)
    identifiers: dict[str, Any] = field(default_factory=dict)
    attrs: dict[str, Any] | None = field(default_factory=dict)
    permissions: list[BlueprintEntryPermission] = field(default_factory=list)

    id: str | None = None

    # Internal state of this entry
    _state: BlueprintEntryState = field(default_factory=BlueprintEntryState)

    def __post_init__(self, *args, **kwargs) -> None:
        # Stack of YAMLTagContext objects currently being resolved (innermost last)
        self.__tag_contexts: list[YAMLTagContext] = []

    @staticmethod
    def from_model(model: SerializerModel, *extra_identifier_names: str) -> "BlueprintEntry":
        """Convert a SerializerModel instance to a blueprint Entry.

        The entry is always identified by ``pk``. Every name in
        ``extra_identifier_names`` is moved from the attributes to the identifiers
        (with a value of ``None`` when the attribute is not present).
        """
        identifiers = {
            "pk": model.pk,
        }
        all_attrs = get_attrs(model)

        for extra_identifier_name in extra_identifier_names:
            identifiers[extra_identifier_name] = all_attrs.pop(extra_identifier_name, None)
        return BlueprintEntry(
            identifiers=identifiers,
            model=f"{model._meta.app_label}.{model._meta.model_name}",
            attrs=all_attrs,
        )

    def get_tag_context(
        self,
        depth: int = 0,
        context_tag_type: type["YAMLTagContext"] | tuple[type["YAMLTagContext"], ...] | None = None,
    ) -> "YAMLTagContext":
        """Get a YAMLTagContext object located at a certain depth in the tag tree.

        Args:
            depth: 0 selects the most recently entered context, 1 the one before it,
                and so on.
            context_tag_type: when given (and truthy), only contexts that are
                instances of this type (or tuple of types) are counted.

        Raises:
            ValueError: if ``depth`` is negative or exceeds the number of
                matching contexts.
        """
        if depth < 0:
            raise ValueError("depth must be a positive number or zero")

        if context_tag_type:
            contexts = [x for x in self.__tag_contexts if isinstance(x, context_tag_type)]
        else:
            contexts = self.__tag_contexts

        try:
            return contexts[-(depth + 1)]
        except IndexError as exc:
            raise ValueError(f"invalid depth: {depth}. Max depth: {len(contexts) - 1}") from exc

    def tag_resolver(self, value: Any, blueprint: "Blueprint") -> Any:
        """Check if we have any special tags that need handling.

        Works on a shallow copy of ``value``: a ``YAMLTag`` is replaced by its
        resolved value, and dicts/lists are resolved recursively element by element.
        While a ``YAMLTagContext`` is being resolved it is kept on the context stack
        so nested tags can look it up via ``get_tag_context``.
        """
        val = copy(value)

        is_context = isinstance(value, YAMLTagContext)
        if is_context:
            self.__tag_contexts.append(value)

        try:
            if isinstance(value, YAMLTag):
                val = value.resolve(self, blueprint)

            if isinstance(value, dict):
                for key, inner_value in value.items():
                    val[key] = self.tag_resolver(inner_value, blueprint)
            if isinstance(value, list):
                for idx, inner_value in enumerate(value):
                    val[idx] = self.tag_resolver(inner_value, blueprint)
        finally:
            # Always unwind the stack, even when resolution raises; otherwise a
            # stale context would leak into later resolutions on this entry.
            if is_context:
                self.__tag_contexts.pop()

        return val

    def get_attrs(self, blueprint: "Blueprint") -> dict[str, Any]:
        """Get attributes of this entry, with all yaml tags resolved"""
        return self.tag_resolver(self.attrs, blueprint)

    def get_identifiers(self, blueprint: "Blueprint") -> dict[str, Any]:
        """Get identifiers of this entry, with all yaml tags resolved"""
        return self.tag_resolver(self.identifiers, blueprint)

    def get_state(self, blueprint: "Blueprint") -> BlueprintEntryDesiredState:
        """Get the blueprint state, with yaml tags resolved if present"""
        return BlueprintEntryDesiredState(self.tag_resolver(self.state, blueprint))

    def get_model(self, blueprint: "Blueprint") -> str:
        """Get the blueprint model, with yaml tags resolved if present"""
        return str(self.tag_resolver(self.model, blueprint))

    def get_permissions(
        self, blueprint: "Blueprint"
    ) -> Generator[BlueprintEntryPermission, None, None]:
        """Get permissions of this entry, with all yaml tags resolved"""
        for perm in self.permissions:
            yield BlueprintEntryPermission(
                permission=self.tag_resolver(perm.permission, blueprint),
                user=self.tag_resolver(perm.user, blueprint),
                role=self.tag_resolver(perm.role, blueprint),
            )

    def check_all_conditions_match(self, blueprint: "Blueprint") -> bool:
        """Check all conditions of this entry match (evaluate to True)"""
        return all(self.tag_resolver(self.conditions, blueprint))
@dataclass
class BlueprintMetadata:
    """Optional metadata block of a blueprint"""

    # Name of the blueprint (required)
    name: str
    # Free-form string labels; every instance gets its own empty dict by default
    labels: dict[str, str] = field(default_factory=dict)
@dataclass
class Blueprint:
    """Container dataclass describing one complete blueprint (used for a full export)"""

    version: int = 1
    entries: list[BlueprintEntry] = field(default_factory=list)
    context: dict = field(default_factory=dict)

    metadata: BlueprintMetadata | None = None
class YAMLTag:
    """Common base of every custom YAML tag"""

    def __repr__(self) -> str:
        # Resolve against an empty placeholder entry and blueprint and show the result
        placeholder_entry = BlueprintEntry("")
        resolved = self.resolve(placeholder_entry, Blueprint())
        return str(resolved)

    def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any:
        """Compute the tag's value; subclasses must override this"""
        raise NotImplementedError
class YAMLTagContext:
    """Common base of every YAML tag that provides a context to nested tags"""

    def get_context(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any:
        """Return the tag's current context; subclasses must override this"""
        raise NotImplementedError
class KeyOf(YAMLTag):
    """Reference another object by their ID"""

    id_from: str

    def __init__(self, loader: "BlueprintLoader", node: ScalarNode) -> None:
        super().__init__()
        self.id_from = node.value

    def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any:
        """Return the primary key of the blueprint entry whose ``id`` is ``id_from``.

        Only entries that already have a model instance in their state are
        considered.

        Raises:
            EntryInvalidError: if no such entry exists.
        """
        for _entry in blueprint.entries:
            if _entry.id == self.id_from and _entry._state.instance:
                # Special handling for PolicyBindingModels, as they'll have a different PK
                # which is used when creating policy bindings.
                # The referencing entry's model may itself still be a YAMLTag, so it is
                # resolved through get_model() instead of calling .lower() on it directly.
                if (
                    isinstance(_entry._state.instance, PolicyBindingModel)
                    and entry.get_model(blueprint).lower() == "authentik_policies.policybinding"
                ):
                    return _entry._state.instance.pbm_uuid
                return _entry._state.instance.pk
        raise EntryInvalidError.from_entry(
            f"KeyOf: failed to find entry with `id` of `{self.id_from}` and a model instance", entry
        )
class Env(YAMLTag):
    """Look up an environment variable, optionally falling back to a default"""

    key: str
    default: Any | None

    def __init__(self, loader: "BlueprintLoader", node: ScalarNode | SequenceNode) -> None:
        super().__init__()
        self.default = None
        if isinstance(node, ScalarNode):
            # Scalar form: the scalar is the variable name, no default
            self.key = node.value
        if isinstance(node, SequenceNode):
            # Sequence form: [variable name, default]
            children = node.value
            self.key = loader.construct_object(children[0])
            self.default = loader.construct_object(children[1])

    def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any:
        # An unset *or empty* variable yields the default
        env_value = getenv(self.key)
        return env_value if env_value else self.default
class Context(YAMLTag):
    """Look up a key in the blueprint's context, optionally falling back to a default"""

    key: str
    default: Any | None

    def __init__(self, loader: "BlueprintLoader", node: ScalarNode | SequenceNode) -> None:
        super().__init__()
        self.default = None
        if isinstance(node, ScalarNode):
            # Scalar form: the scalar is the context key, no default
            self.key = node.value
        if isinstance(node, SequenceNode):
            # Sequence form: [context key, default]
            children = node.value
            self.key = loader.construct_object(children[0])
            self.default = loader.construct_object(children[1])

    def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any:
        found = self.key in blueprint.context
        value = blueprint.context[self.key] if found else self.default
        if not isinstance(value, YAMLTag):
            return value
        # The looked-up value (or the default) may itself be a tag
        return value.resolve(entry, blueprint)
class Format(YAMLTag):
    """Format a string using printf-style (``%``) formatting"""

    format_string: str
    args: list[Any]

    def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None:
        super().__init__()
        # First element is the format string, the rest are its arguments
        self.format_string = loader.construct_object(node.value[0])
        self.args = [loader.construct_object(raw_node) for raw_node in node.value[1:]]

    def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any:
        """Resolve any tag arguments and apply them to the format string.

        Raises:
            EntryInvalidError: if the arguments do not fit the format string
                (wrong count/type) or the format string itself is malformed.
        """
        args = [
            arg.resolve(entry, blueprint) if isinstance(arg, YAMLTag) else arg
            for arg in self.args
        ]

        try:
            return self.format_string % tuple(args)
        except (TypeError, ValueError) as exc:
            # TypeError: argument count/type mismatch.
            # ValueError: unsupported or incomplete format specifier.
            raise EntryInvalidError.from_entry(exc, entry) from exc
class Find(YAMLTag):
    """Find any object"""

    model_name: str | YAMLTag
    conditions: list[list]

    def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None:
        super().__init__()
        # First element is the model name, every following element is one
        # condition given as a sequence (its first two items are used as
        # lookup key and lookup value).
        self.model_name = loader.construct_object(node.value[0])
        self.conditions = [
            [loader.construct_object(node_values) for node_values in raw_node.value]
            for raw_node in node.value[1:]
        ]

    def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any:
        """Return the primary key of the first object matching all conditions.

        Returns ``None`` when nothing matches.

        Raises:
            EntryInvalidError: if the model name is malformed or does not refer
                to an installed model.
        """
        if isinstance(self.model_name, YAMLTag):
            model_name = self.model_name.resolve(entry, blueprint)
        else:
            model_name = self.model_name

        try:
            model_class = apps.get_model(*model_name.split("."))
        except (LookupError, ValueError) as exc:
            # LookupError: unknown app or model.
            # ValueError: name without a dot (not in "app_label.model_name" form).
            raise EntryInvalidError.from_entry(exc, entry) from exc

        query = Q()
        for cond in self.conditions:
            if isinstance(cond[0], YAMLTag):
                query_key = cond[0].resolve(entry, blueprint)
            else:
                query_key = cond[0]
            if isinstance(cond[1], YAMLTag):
                query_value = cond[1].resolve(entry, blueprint)
            else:
                query_value = cond[1]
            # All conditions are AND-ed together
            query &= Q(**{query_key: query_value})
        instance = model_class.objects.filter(query).first()
        if instance:
            return instance.pk
        return None
class Condition(YAMLTag):
    """Reduce a list of values to one boolean using a logical mode"""

    mode: Literal["AND", "NAND", "OR", "NOR", "XOR", "XNOR"]
    args: list[Any]

    # Maps the (upper-cased) mode to a callable taking a tuple of booleans
    _COMPARATORS = {
        # Using all and any here instead of from operator import iand, ior
        # to improve performance
        "AND": all,
        "NAND": lambda values: not all(values),
        "OR": any,
        "NOR": lambda values: not any(values),
        "XOR": lambda values: reduce(ixor, values) if len(values) > 1 else values[0],
        "XNOR": lambda values: not (reduce(ixor, values) if len(values) > 1 else values[0]),
    }

    def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None:
        super().__init__()
        # First element selects the mode, the remaining ones are the operands
        mode_node, *operand_nodes = node.value[0], *node.value[1:]
        self.mode = loader.construct_object(mode_node)
        self.args = []
        for operand_node in operand_nodes:
            self.args.append(loader.construct_object(operand_node))

    def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any:
        operands = [
            arg.resolve(entry, blueprint) if isinstance(arg, YAMLTag) else arg
            for arg in self.args
        ]

        if not operands:
            raise EntryInvalidError.from_entry(
                "At least one value is required after mode selection.", entry
            )

        try:
            # Mode lookup is case-insensitive; an unknown mode raises KeyError
            comparator = self._COMPARATORS[self.mode.upper()]
            return comparator(tuple(map(bool, operands)))
        except (TypeError, KeyError) as exc:
            raise EntryInvalidError.from_entry(exc, entry) from exc
class If(YAMLTag):
    """Pick one of two YAML values depending on a condition"""

    condition: Any
    when_true: Any
    when_false: Any

    def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None:
        super().__init__()
        self.condition = loader.construct_object(node.value[0])
        if len(node.value) != 1:
            # Explicit form: [condition, value if true, value if false]
            self.when_true = loader.construct_object(node.value[1])
            self.when_false = loader.construct_object(node.value[2])
        else:
            # Short form: only a condition, which then yields a plain boolean
            self.when_true = True
            self.when_false = False

    def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any:
        condition = (
            self.condition.resolve(entry, blueprint)
            if isinstance(self.condition, YAMLTag)
            else self.condition
        )

        try:
            # The selected branch may contain tags itself, so resolve it as well
            return entry.tag_resolver(
                self.when_true if condition else self.when_false,
                blueprint,
            )
        except TypeError as exc:
            raise EntryInvalidError.from_entry(exc, entry) from exc
class Enumerate(YAMLTag, YAMLTagContext):
    """Iterate over an iterable.

    The item body is resolved once per item of the iterable and the results are
    collected into a list ("SEQ") or merged into a dict ("MAP"). While an item is
    being resolved, this tag exposes ``(index_or_key, value)`` of that item as its
    context.
    """

    iterable: YAMLTag | Iterable
    item_body: Any
    output_body: Literal["SEQ", "MAP"]

    # Maps the (upper-cased) output mode to (result type, accumulate function).
    # The accumulate function takes the result so far and one resolved item body.
    _OUTPUT_BODIES = {
        "SEQ": (list, lambda a, b: [*a, b]),
        "MAP": (
            dict,
            # A tuple/list item is treated as a (key, value) pair; anything else is
            # handed to the merger unchanged.
            lambda a, b: always_merger.merge(a, {b[0]: b[1]} if isinstance(b, tuple | list) else b),
        ),
    }

    def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None:
        super().__init__()
        # Sequence layout: [iterable, output mode, item body]
        self.iterable = loader.construct_object(node.value[0])
        self.output_body = loader.construct_object(node.value[1])
        self.item_body = loader.construct_object(node.value[2])
        # Item currently being resolved; empty tuple when not iterating
        self.__current_context: tuple[Any, Any] = tuple()

    def get_context(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any:
        """Return the item currently being resolved (empty tuple outside of iteration)"""
        return self.__current_context

    def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any:
        """Resolve the item body for every item and return the accumulated result.

        Raises:
            EntryInvalidError: if the iterable refers to this tag's own context,
                is not iterable, the output mode is unknown, or accumulating an
                item yields a result of the wrong type.
        """
        if isinstance(self.iterable, EnumeratedItem) and self.iterable.depth == 0:
            raise EntryInvalidError.from_entry(
                f"{self.__class__.__name__} tag's iterable references this tag's context. "
                "This is a noop. Check you are setting depth bigger than 0.",
                entry,
            )

        if isinstance(self.iterable, YAMLTag):
            iterable = self.iterable.resolve(entry, blueprint)
        else:
            iterable = self.iterable

        if not isinstance(iterable, Iterable):
            raise EntryInvalidError.from_entry(
                f"{self.__class__.__name__}'s iterable must be an iterable "
                "such as a sequence or a mapping",
                entry,
            )

        # Normalise to a tuple of pairs: (key, value) for mappings,
        # (index, value) for everything else.
        if isinstance(iterable, Mapping):
            iterable = tuple(iterable.items())
        else:
            iterable = tuple(enumerate(iterable))

        try:
            output_class, add_fn = self._OUTPUT_BODIES[self.output_body.upper()]
        except KeyError as exc:
            raise EntryInvalidError.from_entry(exc, entry) from exc

        result = output_class()

        self.__current_context = tuple()

        try:
            for item in iterable:
                # Publish the current item before resolving the body so nested
                # tags can read it through get_context().
                self.__current_context = item
                resolved_body = entry.tag_resolver(self.item_body, blueprint)
                result = add_fn(result, resolved_body)
                if not isinstance(result, output_class):
                    raise EntryInvalidError.from_entry(
                        f"Invalid {self.__class__.__name__} item found: {resolved_body}", entry
                    )
        finally:
            # Clear the context whether iteration finished or raised
            self.__current_context = tuple()

        return result
class EnumeratedItem(YAMLTag):
    """Fetch the current (index, value) item exposed by an enclosing Enumerate tag"""

    depth: int

    # Context tag types this tag may read from
    _SUPPORTED_CONTEXT_TAGS = (Enumerate,)

    def __init__(self, _loader: "BlueprintLoader", node: ScalarNode) -> None:
        super().__init__()
        # The scalar is the depth of the context to read (0 = most recently entered)
        self.depth = int(node.value)

    def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any:
        try:
            context_tag: Enumerate = entry.get_tag_context(
                depth=self.depth,
                context_tag_type=EnumeratedItem._SUPPORTED_CONTEXT_TAGS,
            )
        except ValueError as exc:
            # Depth 0 failing means there is no supported context at all;
            # any other depth is reported with the underlying message.
            if self.depth == 0:
                message = (
                    f"{self.__class__.__name__} tags are only usable "
                    f"inside an {Enumerate.__name__} tag"
                )
            else:
                message = f"{self.__class__.__name__} tag: {exc}"
            raise EntryInvalidError.from_entry(message, entry) from exc

        return context_tag.get_context(entry, blueprint)
class Index(EnumeratedItem):
    """Fetch the index (first element) of the current Enumerate item"""

    def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any:
        current_item = super().resolve(entry, blueprint)

        try:
            index = current_item[0]
        except IndexError as exc:  # pragma: no cover
            raise EntryInvalidError.from_entry(
                f"Empty/invalid context: {current_item}", entry
            ) from exc
        return index
class Value(EnumeratedItem):
    """Fetch the value (second element) of the current Enumerate item"""

    def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any:
        current_item = super().resolve(entry, blueprint)

        try:
            value = current_item[1]
        except IndexError as exc:  # pragma: no cover
            raise EntryInvalidError.from_entry(
                f"Empty/invalid context: {current_item}", entry
            ) from exc
        return value
class AtIndex(YAMLTag):
    """Look up an index in a sequence or a key in a mapping, with optional default"""

    obj: YAMLTag | dict | list | tuple
    attribute: int | str | YAMLTag
    default: Any | UNSET

    def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None:
        super().__init__()
        # Sequence layout: [object, index/key] or [object, index/key, default]
        self.obj = loader.construct_object(node.value[0])
        self.attribute = loader.construct_object(node.value[1])
        if len(node.value) != 2:  # noqa: PLR2004
            self.default = loader.construct_object(node.value[2])
        else:
            # No default supplied: a failed lookup becomes an error
            self.default = UNSET

    def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any:
        obj = self.obj.resolve(entry, blueprint) if isinstance(self.obj, YAMLTag) else self.obj
        attribute = (
            self.attribute.resolve(entry, blueprint)
            if isinstance(self.attribute, YAMLTag)
            else self.attribute
        )

        # Lists and tuples: positional lookup
        if isinstance(obj, list | tuple):
            try:
                return obj[attribute]
            except TypeError as exc:
                raise EntryInvalidError.from_entry(
                    f"Invalid index for list: {attribute}", entry
                ) from exc
            except IndexError as exc:
                if self.default is UNSET:
                    raise EntryInvalidError.from_entry(
                        f"Index out of range: {attribute}", entry
                    ) from exc
                return self.default

        # Everything else: membership test followed by item access
        if attribute in obj:
            return obj[attribute]
        if self.default is UNSET:
            raise EntryInvalidError.from_entry(f"Key does not exist: {attribute}", entry)
        return self.default
class BlueprintDumper(SafeDumper):
    """YAML dumper that knows how to serialize blueprint dataclasses"""

    default_flow_style = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # UUIDs and enums are written as plain strings, OrderedDicts as plain mappings
        self.add_representer(UUID, lambda dumper, data: dumper.represent_str(str(data)))
        self.add_representer(
            OrderedDict, lambda dumper, data: dumper.represent_dict(dict(data))
        )
        self.add_representer(Enum, lambda dumper, data: dumper.represent_str(data.value))
        self.add_representer(
            BlueprintEntryDesiredState, lambda dumper, data: dumper.represent_str(data.value)
        )
        # Representer registered under None: stringify the value
        self.add_representer(None, lambda dumper, data: dumper.represent_str(str(data)))

    def ignore_aliases(self, data):
        """Never emit YAML anchors/aliases"""
        return True

    def represent(self, data) -> None:
        if is_dataclass(data):

            def factory(items):
                as_dict = dict(items)
                # Remove internal state variables
                as_dict.pop("_state", None)
                # Future-proof to only remove the ID if we don't set a value
                if "id" in as_dict and as_dict["id"] is None:
                    del as_dict["id"]
                return as_dict

            # Dataclasses are converted to (nested) plain dicts before dumping
            data = asdict(data, dict_factory=factory)
        return super().represent(data)
class BlueprintLoader(SafeLoader):
    """YAML loader that understands the custom blueprint tags"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # (YAML tag, constructor class) pairs, registered in this order
        tag_constructors = (
            ("!KeyOf", KeyOf),
            ("!Find", Find),
            ("!Context", Context),
            ("!Format", Format),
            ("!Condition", Condition),
            ("!If", If),
            ("!Env", Env),
            ("!Enumerate", Enumerate),
            ("!Value", Value),
            ("!Index", Index),
            ("!AtIndex", AtIndex),
        )
        for tag, constructor in tag_constructors:
            self.add_constructor(tag, constructor)
class EntryInvalidError(SentryIgnoredException):
    """Raised when a blueprint entry is invalid.

    Carries the model and id of the offending entry (as strings, when created
    through ``from_entry``) and, if applicable, the underlying validation error.
    """

    entry_model: str | None
    entry_id: str | None
    validation_error: ValidationError | None
    serializer: Serializer | None = None

    def __init__(
        self, *args: object, validation_error: ValidationError | None = None, **kwargs
    ) -> None:
        super().__init__(*args)
        self.entry_model = None
        self.entry_id = None
        self.validation_error = validation_error
        # Any additional keyword argument becomes an attribute of the error
        for attr_name, attr_value in kwargs.items():
            setattr(self, attr_name, attr_value)

    @staticmethod
    def from_entry(
        msg_or_exc: str | Exception, entry: BlueprintEntry, *args, **kwargs
    ) -> "EntryInvalidError":
        """Build an EntryInvalidError that records which entry it belongs to"""
        error = EntryInvalidError(msg_or_exc, *args, **kwargs)
        if isinstance(msg_or_exc, ValidationError):
            # Keep the original validation error accessible on the new error
            error.validation_error = msg_or_exc
        # Make sure the model and id are strings, depending where the error happens
        # they might still be YAMLTag instances
        error.entry_model, error.entry_id = str(entry.model), str(entry.id)
        return error