Compare commits
	
		
			1 Commits
		
	
	
		
			version/20
			...
			root/confi
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 3fa987f443 | 
@ -5,13 +5,20 @@ from contextlib import contextmanager
 | 
			
		||||
from glob import glob
 | 
			
		||||
from json import dumps, loads
 | 
			
		||||
from json.decoder import JSONDecodeError
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from sys import argv, stderr
 | 
			
		||||
from time import time
 | 
			
		||||
from typing import Any
 | 
			
		||||
from typing import Any, Optional
 | 
			
		||||
from urllib.parse import urlparse
 | 
			
		||||
 | 
			
		||||
import yaml
 | 
			
		||||
from django.conf import ImproperlyConfigured
 | 
			
		||||
from watchdog.events import (
 | 
			
		||||
    FileModifiedEvent,
 | 
			
		||||
    FileSystemEvent,
 | 
			
		||||
    FileSystemEventHandler,
 | 
			
		||||
)
 | 
			
		||||
from watchdog.observers import Observer
 | 
			
		||||
 | 
			
		||||
SEARCH_PATHS = ["authentik/lib/default.yml", "/etc/authentik/config.yml", ""] + glob(
 | 
			
		||||
    "/etc/authentik/config.d/*.yml", recursive=True
 | 
			
		||||
@ -38,9 +45,47 @@ class ConfigLoader:
 | 
			
		||||
    A variable like AUTHENTIK_POSTGRESQL__HOST would translate to postgresql.host"""
 | 
			
		||||
 | 
			
		||||
    loaded_file = []
 | 
			
		||||
    observer: Observer
 | 
			
		||||
 | 
			
		||||
    class FSObserver(FileSystemEventHandler):
 | 
			
		||||
        """File system observer"""
 | 
			
		||||
 | 
			
		||||
        loader: "ConfigLoader"
 | 
			
		||||
        path: str
 | 
			
		||||
        container: Optional[dict] = None
 | 
			
		||||
        key: Optional[str] = None
 | 
			
		||||
 | 
			
		||||
        def __init__(
 | 
			
		||||
            self,
 | 
			
		||||
            loader: "ConfigLoader",
 | 
			
		||||
            path: str,
 | 
			
		||||
            container: Optional[dict] = None,
 | 
			
		||||
            key: Optional[str] = None,
 | 
			
		||||
        ) -> None:
 | 
			
		||||
            super().__init__()
 | 
			
		||||
            self.loader = loader
 | 
			
		||||
            self.path = path
 | 
			
		||||
            self.container = container
 | 
			
		||||
            self.key = key
 | 
			
		||||
 | 
			
		||||
        def on_any_event(self, event: FileSystemEvent):
 | 
			
		||||
            if not isinstance(event, FileModifiedEvent):
 | 
			
		||||
                return
 | 
			
		||||
            if event.is_directory:
 | 
			
		||||
                return
 | 
			
		||||
            if event.src_path != self.path:
 | 
			
		||||
                return
 | 
			
		||||
            if self.container and self.key:
 | 
			
		||||
                with open(self.path, "r", encoding="utf8") as _file:
 | 
			
		||||
                    self.container[self.key] = _file.read()
 | 
			
		||||
            else:
 | 
			
		||||
                self.loader.log("info", "Updating from changed file", file=self.path)
 | 
			
		||||
                self.loader.update_from_file(self.path, watch=False)
 | 
			
		||||
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.observer = Observer()
 | 
			
		||||
        self.observer.start()
 | 
			
		||||
        self.__config = {}
 | 
			
		||||
        base_dir = os.path.realpath(os.path.join(os.path.dirname(__file__), "../.."))
 | 
			
		||||
        for path in SEARCH_PATHS:
 | 
			
		||||
@ -81,11 +126,11 @@ class ConfigLoader:
 | 
			
		||||
                root[key] = self.update(root.get(key, {}), value)
 | 
			
		||||
            else:
 | 
			
		||||
                if isinstance(value, str):
 | 
			
		||||
                    value = self.parse_uri(value)
 | 
			
		||||
                    value = self.parse_uri(value, root, key)
 | 
			
		||||
                root[key] = value
 | 
			
		||||
        return root
 | 
			
		||||
 | 
			
		||||
    def parse_uri(self, value: str) -> str:
 | 
			
		||||
    def parse_uri(self, value: str, container: dict[str, Any], key: Optional[str] = None, ) -> str:
 | 
			
		||||
        """Parse string values which start with a URI"""
 | 
			
		||||
        url = urlparse(value)
 | 
			
		||||
        if url.scheme == "env":
 | 
			
		||||
@ -93,13 +138,23 @@ class ConfigLoader:
 | 
			
		||||
        if url.scheme == "file":
 | 
			
		||||
            try:
 | 
			
		||||
                with open(url.path, "r", encoding="utf8") as _file:
 | 
			
		||||
                    value = _file.read().strip()
 | 
			
		||||
                    value = _file.read()
 | 
			
		||||
                if key:
 | 
			
		||||
                    self.observer.schedule(
 | 
			
		||||
                        ConfigLoader.FSObserver(
 | 
			
		||||
                            self,
 | 
			
		||||
                            url.path,
 | 
			
		||||
                            container,
 | 
			
		||||
                            key,
 | 
			
		||||
                        ),
 | 
			
		||||
                        Path(url.path).parent,
 | 
			
		||||
                    )
 | 
			
		||||
            except OSError as exc:
 | 
			
		||||
                self.log("error", f"Failed to read config value from {url.path}: {exc}")
 | 
			
		||||
                value = url.query
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
    def update_from_file(self, path: str):
 | 
			
		||||
    def update_from_file(self, path: str, watch=True):
 | 
			
		||||
        """Update config from file contents"""
 | 
			
		||||
        try:
 | 
			
		||||
            with open(path, encoding="utf8") as file:
 | 
			
		||||
@ -107,6 +162,8 @@ class ConfigLoader:
 | 
			
		||||
                    self.update(self.__config, yaml.safe_load(file))
 | 
			
		||||
                    self.log("debug", "Loaded config", file=path)
 | 
			
		||||
                    self.loaded_file.append(path)
 | 
			
		||||
                    if watch:
 | 
			
		||||
                        self.observer.schedule(ConfigLoader.FSObserver(self, path), Path(path).parent)
 | 
			
		||||
                except yaml.YAMLError as exc:
 | 
			
		||||
                    raise ImproperlyConfigured from exc
 | 
			
		||||
        except PermissionError as exc:
 | 
			
		||||
@ -181,13 +238,12 @@ class ConfigLoader:
 | 
			
		||||
            if comp not in root:
 | 
			
		||||
                root[comp] = {}
 | 
			
		||||
            root = root.get(comp, {})
 | 
			
		||||
        root[path_parts[-1]] = value
 | 
			
		||||
        self.parse_uri(value, root, path_parts[-1])
 | 
			
		||||
 | 
			
		||||
    def y_bool(self, path: str, default=False) -> bool:
 | 
			
		||||
        """Wrapper for y that converts value into boolean"""
 | 
			
		||||
        return str(self.y(path, default)).lower() == "true"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
CONFIG = ConfigLoader()
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
 | 
			
		||||
@ -5,7 +5,7 @@ from tempfile import mkstemp
 | 
			
		||||
from django.conf import ImproperlyConfigured
 | 
			
		||||
from django.test import TestCase
 | 
			
		||||
 | 
			
		||||
from authentik.lib.config import ENV_PREFIX, ConfigLoader
 | 
			
		||||
from authentik.lib.config import CONFIG, ENV_PREFIX, ConfigLoader
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestConfig(TestCase):
 | 
			
		||||
@ -31,8 +31,8 @@ class TestConfig(TestCase):
 | 
			
		||||
        """Test URI parsing (environment)"""
 | 
			
		||||
        config = ConfigLoader()
 | 
			
		||||
        environ["foo"] = "bar"
 | 
			
		||||
        self.assertEqual(config.parse_uri("env://foo"), "bar")
 | 
			
		||||
        self.assertEqual(config.parse_uri("env://foo?bar"), "bar")
 | 
			
		||||
        self.assertEqual(config.parse_uri("env://foo", {}), "bar")
 | 
			
		||||
        self.assertEqual(config.parse_uri("env://foo?bar", {}), "bar")
 | 
			
		||||
 | 
			
		||||
    def test_uri_file(self):
 | 
			
		||||
        """Test URI parsing (file load)"""
 | 
			
		||||
@ -41,8 +41,8 @@ class TestConfig(TestCase):
 | 
			
		||||
        write(file, "foo".encode())
 | 
			
		||||
        _, file2_name = mkstemp()
 | 
			
		||||
        chmod(file2_name, 0o000)  # Remove all permissions so we can't read the file
 | 
			
		||||
        self.assertEqual(config.parse_uri(f"file://{file_name}"), "foo")
 | 
			
		||||
        self.assertEqual(config.parse_uri(f"file://{file2_name}?def"), "def")
 | 
			
		||||
        self.assertEqual(config.parse_uri(f"file://{file_name}", {}), "foo")
 | 
			
		||||
        self.assertEqual(config.parse_uri(f"file://{file2_name}?def", {}), "def")
 | 
			
		||||
        unlink(file_name)
 | 
			
		||||
        unlink(file2_name)
 | 
			
		||||
 | 
			
		||||
@ -59,3 +59,13 @@ class TestConfig(TestCase):
 | 
			
		||||
        config.update_from_file(file2_name)
 | 
			
		||||
        unlink(file_name)
 | 
			
		||||
        unlink(file2_name)
 | 
			
		||||
 | 
			
		||||
    def test_update(self):
 | 
			
		||||
        """Test change to file"""
 | 
			
		||||
        file, file_name = mkstemp()
 | 
			
		||||
        write(file, b"test")
 | 
			
		||||
        CONFIG.y_set("test.file", f"file://{file_name}")
 | 
			
		||||
        self.assertEqual(CONFIG.y("test.file"), "test")
 | 
			
		||||
        write(file, "test2")
 | 
			
		||||
        self.assertEqual(CONFIG.y("test.file"), "test2")
 | 
			
		||||
        unlink(file_name)
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user