Compare commits
	
		
			1 Commits
		
	
	
		
			imports-fo
			...
			root/confi
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 3fa987f443 | 
| @ -5,13 +5,20 @@ from contextlib import contextmanager | ||||
| from glob import glob | ||||
| from json import dumps, loads | ||||
| from json.decoder import JSONDecodeError | ||||
| from pathlib import Path | ||||
| from sys import argv, stderr | ||||
| from time import time | ||||
| from typing import Any | ||||
| from typing import Any, Optional | ||||
| from urllib.parse import urlparse | ||||
|  | ||||
| import yaml | ||||
| from django.conf import ImproperlyConfigured | ||||
| from watchdog.events import ( | ||||
|     FileModifiedEvent, | ||||
|     FileSystemEvent, | ||||
|     FileSystemEventHandler, | ||||
| ) | ||||
| from watchdog.observers import Observer | ||||
|  | ||||
| SEARCH_PATHS = ["authentik/lib/default.yml", "/etc/authentik/config.yml", ""] + glob( | ||||
|     "/etc/authentik/config.d/*.yml", recursive=True | ||||
| @ -38,9 +45,47 @@ class ConfigLoader: | ||||
|     A variable like AUTHENTIK_POSTGRESQL__HOST would translate to postgresql.host""" | ||||
|  | ||||
|     loaded_file = [] | ||||
|     observer: Observer | ||||
|  | ||||
|     class FSObserver(FileSystemEventHandler): | ||||
|         """File system observer""" | ||||
|  | ||||
|         loader: "ConfigLoader" | ||||
|         path: str | ||||
|         container: Optional[dict] = None | ||||
|         key: Optional[str] = None | ||||
|  | ||||
|         def __init__( | ||||
|             self, | ||||
|             loader: "ConfigLoader", | ||||
|             path: str, | ||||
|             container: Optional[dict] = None, | ||||
|             key: Optional[str] = None, | ||||
|         ) -> None: | ||||
|             super().__init__() | ||||
|             self.loader = loader | ||||
|             self.path = path | ||||
|             self.container = container | ||||
|             self.key = key | ||||
|  | ||||
|         def on_any_event(self, event: FileSystemEvent): | ||||
|             if not isinstance(event, FileModifiedEvent): | ||||
|                 return | ||||
|             if event.is_directory: | ||||
|                 return | ||||
|             if event.src_path != self.path: | ||||
|                 return | ||||
|             if self.container and self.key: | ||||
|                 with open(self.path, "r", encoding="utf8") as _file: | ||||
|                     self.container[self.key] = _file.read() | ||||
|             else: | ||||
|                 self.loader.log("info", "Updating from changed file", file=self.path) | ||||
|                 self.loader.update_from_file(self.path, watch=False) | ||||
|  | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
|         self.observer = Observer() | ||||
|         self.observer.start() | ||||
|         self.__config = {} | ||||
|         base_dir = os.path.realpath(os.path.join(os.path.dirname(__file__), "../..")) | ||||
|         for path in SEARCH_PATHS: | ||||
| @ -81,11 +126,11 @@ class ConfigLoader: | ||||
|                 root[key] = self.update(root.get(key, {}), value) | ||||
|             else: | ||||
|                 if isinstance(value, str): | ||||
|                     value = self.parse_uri(value) | ||||
|                     value = self.parse_uri(value, root, key) | ||||
|                 root[key] = value | ||||
|         return root | ||||
|  | ||||
|     def parse_uri(self, value: str) -> str: | ||||
|     def parse_uri(self, value: str, container: dict[str, Any], key: Optional[str] = None, ) -> str: | ||||
|         """Parse string values which start with a URI""" | ||||
|         url = urlparse(value) | ||||
|         if url.scheme == "env": | ||||
| @ -93,13 +138,23 @@ class ConfigLoader: | ||||
|         if url.scheme == "file": | ||||
|             try: | ||||
|                 with open(url.path, "r", encoding="utf8") as _file: | ||||
|                     value = _file.read().strip() | ||||
|                     value = _file.read() | ||||
|                 if key: | ||||
|                     self.observer.schedule( | ||||
|                         ConfigLoader.FSObserver( | ||||
|                             self, | ||||
|                             url.path, | ||||
|                             container, | ||||
|                             key, | ||||
|                         ), | ||||
|                         Path(url.path).parent, | ||||
|                     ) | ||||
|             except OSError as exc: | ||||
|                 self.log("error", f"Failed to read config value from {url.path}: {exc}") | ||||
|                 value = url.query | ||||
|         return value | ||||
|  | ||||
|     def update_from_file(self, path: str): | ||||
|     def update_from_file(self, path: str, watch=True): | ||||
|         """Update config from file contents""" | ||||
|         try: | ||||
|             with open(path, encoding="utf8") as file: | ||||
| @ -107,6 +162,8 @@ class ConfigLoader: | ||||
|                     self.update(self.__config, yaml.safe_load(file)) | ||||
|                     self.log("debug", "Loaded config", file=path) | ||||
|                     self.loaded_file.append(path) | ||||
|                     if watch: | ||||
|                         self.observer.schedule(ConfigLoader.FSObserver(self, path), Path(path).parent) | ||||
|                 except yaml.YAMLError as exc: | ||||
|                     raise ImproperlyConfigured from exc | ||||
|         except PermissionError as exc: | ||||
| @ -181,13 +238,12 @@ class ConfigLoader: | ||||
|             if comp not in root: | ||||
|                 root[comp] = {} | ||||
|             root = root.get(comp, {}) | ||||
|         root[path_parts[-1]] = value | ||||
|         self.parse_uri(value, root, path_parts[-1]) | ||||
|  | ||||
|     def y_bool(self, path: str, default=False) -> bool: | ||||
|         """Wrapper for y that converts value into boolean""" | ||||
|         return str(self.y(path, default)).lower() == "true" | ||||
|  | ||||
|  | ||||
| CONFIG = ConfigLoader() | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|  | ||||
| @ -5,7 +5,7 @@ from tempfile import mkstemp | ||||
| from django.conf import ImproperlyConfigured | ||||
| from django.test import TestCase | ||||
|  | ||||
| from authentik.lib.config import ENV_PREFIX, ConfigLoader | ||||
| from authentik.lib.config import CONFIG, ENV_PREFIX, ConfigLoader | ||||
|  | ||||
|  | ||||
| class TestConfig(TestCase): | ||||
| @ -31,8 +31,8 @@ class TestConfig(TestCase): | ||||
|         """Test URI parsing (environment)""" | ||||
|         config = ConfigLoader() | ||||
|         environ["foo"] = "bar" | ||||
|         self.assertEqual(config.parse_uri("env://foo"), "bar") | ||||
|         self.assertEqual(config.parse_uri("env://foo?bar"), "bar") | ||||
|         self.assertEqual(config.parse_uri("env://foo", {}), "bar") | ||||
|         self.assertEqual(config.parse_uri("env://foo?bar", {}), "bar") | ||||
|  | ||||
|     def test_uri_file(self): | ||||
|         """Test URI parsing (file load)""" | ||||
| @ -41,8 +41,8 @@ class TestConfig(TestCase): | ||||
|         write(file, "foo".encode()) | ||||
|         _, file2_name = mkstemp() | ||||
|         chmod(file2_name, 0o000)  # Remove all permissions so we can't read the file | ||||
|         self.assertEqual(config.parse_uri(f"file://{file_name}"), "foo") | ||||
|         self.assertEqual(config.parse_uri(f"file://{file2_name}?def"), "def") | ||||
|         self.assertEqual(config.parse_uri(f"file://{file_name}", {}), "foo") | ||||
|         self.assertEqual(config.parse_uri(f"file://{file2_name}?def", {}), "def") | ||||
|         unlink(file_name) | ||||
|         unlink(file2_name) | ||||
|  | ||||
| @ -59,3 +59,13 @@ class TestConfig(TestCase): | ||||
|         config.update_from_file(file2_name) | ||||
|         unlink(file_name) | ||||
|         unlink(file2_name) | ||||
|  | ||||
|     def test_update(self): | ||||
|         """Test change to file""" | ||||
|         file, file_name = mkstemp() | ||||
|         write(file, b"test") | ||||
|         CONFIG.y_set("test.file", f"file://{file_name}") | ||||
|         self.assertEqual(CONFIG.y("test.file"), "test") | ||||
|         write(file, "test2") | ||||
|         self.assertEqual(CONFIG.y("test.file"), "test2") | ||||
|         unlink(file_name) | ||||
|  | ||||
		Reference in New Issue
	
	Block a user
	