Compare commits
	
		
			1 Commits
		
	
	
		
			version/20
			...
			root/confi
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 3fa987f443 | 
| @ -5,13 +5,20 @@ from contextlib import contextmanager | |||||||
| from glob import glob | from glob import glob | ||||||
| from json import dumps, loads | from json import dumps, loads | ||||||
| from json.decoder import JSONDecodeError | from json.decoder import JSONDecodeError | ||||||
|  | from pathlib import Path | ||||||
| from sys import argv, stderr | from sys import argv, stderr | ||||||
| from time import time | from time import time | ||||||
| from typing import Any | from typing import Any, Optional | ||||||
| from urllib.parse import urlparse | from urllib.parse import urlparse | ||||||
|  |  | ||||||
| import yaml | import yaml | ||||||
| from django.conf import ImproperlyConfigured | from django.conf import ImproperlyConfigured | ||||||
|  | from watchdog.events import ( | ||||||
|  |     FileModifiedEvent, | ||||||
|  |     FileSystemEvent, | ||||||
|  |     FileSystemEventHandler, | ||||||
|  | ) | ||||||
|  | from watchdog.observers import Observer | ||||||
|  |  | ||||||
| SEARCH_PATHS = ["authentik/lib/default.yml", "/etc/authentik/config.yml", ""] + glob( | SEARCH_PATHS = ["authentik/lib/default.yml", "/etc/authentik/config.yml", ""] + glob( | ||||||
|     "/etc/authentik/config.d/*.yml", recursive=True |     "/etc/authentik/config.d/*.yml", recursive=True | ||||||
| @ -38,9 +45,47 @@ class ConfigLoader: | |||||||
|     A variable like AUTHENTIK_POSTGRESQL__HOST would translate to postgresql.host""" |     A variable like AUTHENTIK_POSTGRESQL__HOST would translate to postgresql.host""" | ||||||
|  |  | ||||||
|     loaded_file = [] |     loaded_file = [] | ||||||
|  |     observer: Observer | ||||||
|  |  | ||||||
|  |     class FSObserver(FileSystemEventHandler): | ||||||
|  |         """File system observer""" | ||||||
|  |  | ||||||
|  |         loader: "ConfigLoader" | ||||||
|  |         path: str | ||||||
|  |         container: Optional[dict] = None | ||||||
|  |         key: Optional[str] = None | ||||||
|  |  | ||||||
|  |         def __init__( | ||||||
|  |             self, | ||||||
|  |             loader: "ConfigLoader", | ||||||
|  |             path: str, | ||||||
|  |             container: Optional[dict] = None, | ||||||
|  |             key: Optional[str] = None, | ||||||
|  |         ) -> None: | ||||||
|  |             super().__init__() | ||||||
|  |             self.loader = loader | ||||||
|  |             self.path = path | ||||||
|  |             self.container = container | ||||||
|  |             self.key = key | ||||||
|  |  | ||||||
|  |         def on_any_event(self, event: FileSystemEvent): | ||||||
|  |             if not isinstance(event, FileModifiedEvent): | ||||||
|  |                 return | ||||||
|  |             if event.is_directory: | ||||||
|  |                 return | ||||||
|  |             if event.src_path != self.path: | ||||||
|  |                 return | ||||||
|  |             if self.container and self.key: | ||||||
|  |                 with open(self.path, "r", encoding="utf8") as _file: | ||||||
|  |                     self.container[self.key] = _file.read() | ||||||
|  |             else: | ||||||
|  |                 self.loader.log("info", "Updating from changed file", file=self.path) | ||||||
|  |                 self.loader.update_from_file(self.path, watch=False) | ||||||
|  |  | ||||||
|     def __init__(self): |     def __init__(self): | ||||||
|         super().__init__() |         super().__init__() | ||||||
|  |         self.observer = Observer() | ||||||
|  |         self.observer.start() | ||||||
|         self.__config = {} |         self.__config = {} | ||||||
|         base_dir = os.path.realpath(os.path.join(os.path.dirname(__file__), "../..")) |         base_dir = os.path.realpath(os.path.join(os.path.dirname(__file__), "../..")) | ||||||
|         for path in SEARCH_PATHS: |         for path in SEARCH_PATHS: | ||||||
| @ -81,11 +126,11 @@ class ConfigLoader: | |||||||
|                 root[key] = self.update(root.get(key, {}), value) |                 root[key] = self.update(root.get(key, {}), value) | ||||||
|             else: |             else: | ||||||
|                 if isinstance(value, str): |                 if isinstance(value, str): | ||||||
|                     value = self.parse_uri(value) |                     value = self.parse_uri(value, root, key) | ||||||
|                 root[key] = value |                 root[key] = value | ||||||
|         return root |         return root | ||||||
|  |  | ||||||
|     def parse_uri(self, value: str) -> str: |     def parse_uri(self, value: str, container: dict[str, Any], key: Optional[str] = None, ) -> str: | ||||||
|         """Parse string values which start with a URI""" |         """Parse string values which start with a URI""" | ||||||
|         url = urlparse(value) |         url = urlparse(value) | ||||||
|         if url.scheme == "env": |         if url.scheme == "env": | ||||||
| @ -93,13 +138,23 @@ class ConfigLoader: | |||||||
|         if url.scheme == "file": |         if url.scheme == "file": | ||||||
|             try: |             try: | ||||||
|                 with open(url.path, "r", encoding="utf8") as _file: |                 with open(url.path, "r", encoding="utf8") as _file: | ||||||
|                     value = _file.read().strip() |                     value = _file.read() | ||||||
|  |                 if key: | ||||||
|  |                     self.observer.schedule( | ||||||
|  |                         ConfigLoader.FSObserver( | ||||||
|  |                             self, | ||||||
|  |                             url.path, | ||||||
|  |                             container, | ||||||
|  |                             key, | ||||||
|  |                         ), | ||||||
|  |                         Path(url.path).parent, | ||||||
|  |                     ) | ||||||
|             except OSError as exc: |             except OSError as exc: | ||||||
|                 self.log("error", f"Failed to read config value from {url.path}: {exc}") |                 self.log("error", f"Failed to read config value from {url.path}: {exc}") | ||||||
|                 value = url.query |                 value = url.query | ||||||
|         return value |         return value | ||||||
|  |  | ||||||
|     def update_from_file(self, path: str): |     def update_from_file(self, path: str, watch=True): | ||||||
|         """Update config from file contents""" |         """Update config from file contents""" | ||||||
|         try: |         try: | ||||||
|             with open(path, encoding="utf8") as file: |             with open(path, encoding="utf8") as file: | ||||||
| @ -107,6 +162,8 @@ class ConfigLoader: | |||||||
|                     self.update(self.__config, yaml.safe_load(file)) |                     self.update(self.__config, yaml.safe_load(file)) | ||||||
|                     self.log("debug", "Loaded config", file=path) |                     self.log("debug", "Loaded config", file=path) | ||||||
|                     self.loaded_file.append(path) |                     self.loaded_file.append(path) | ||||||
|  |                     if watch: | ||||||
|  |                         self.observer.schedule(ConfigLoader.FSObserver(self, path), Path(path).parent) | ||||||
|                 except yaml.YAMLError as exc: |                 except yaml.YAMLError as exc: | ||||||
|                     raise ImproperlyConfigured from exc |                     raise ImproperlyConfigured from exc | ||||||
|         except PermissionError as exc: |         except PermissionError as exc: | ||||||
| @ -181,13 +238,12 @@ class ConfigLoader: | |||||||
|             if comp not in root: |             if comp not in root: | ||||||
|                 root[comp] = {} |                 root[comp] = {} | ||||||
|             root = root.get(comp, {}) |             root = root.get(comp, {}) | ||||||
|         root[path_parts[-1]] = value |         self.parse_uri(value, root, path_parts[-1]) | ||||||
|  |  | ||||||
|     def y_bool(self, path: str, default=False) -> bool: |     def y_bool(self, path: str, default=False) -> bool: | ||||||
|         """Wrapper for y that converts value into boolean""" |         """Wrapper for y that converts value into boolean""" | ||||||
|         return str(self.y(path, default)).lower() == "true" |         return str(self.y(path, default)).lower() == "true" | ||||||
|  |  | ||||||
|  |  | ||||||
| CONFIG = ConfigLoader() | CONFIG = ConfigLoader() | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|  | |||||||
| @ -5,7 +5,7 @@ from tempfile import mkstemp | |||||||
| from django.conf import ImproperlyConfigured | from django.conf import ImproperlyConfigured | ||||||
| from django.test import TestCase | from django.test import TestCase | ||||||
|  |  | ||||||
| from authentik.lib.config import ENV_PREFIX, ConfigLoader | from authentik.lib.config import CONFIG, ENV_PREFIX, ConfigLoader | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestConfig(TestCase): | class TestConfig(TestCase): | ||||||
| @ -31,8 +31,8 @@ class TestConfig(TestCase): | |||||||
|         """Test URI parsing (environment)""" |         """Test URI parsing (environment)""" | ||||||
|         config = ConfigLoader() |         config = ConfigLoader() | ||||||
|         environ["foo"] = "bar" |         environ["foo"] = "bar" | ||||||
|         self.assertEqual(config.parse_uri("env://foo"), "bar") |         self.assertEqual(config.parse_uri("env://foo", {}), "bar") | ||||||
|         self.assertEqual(config.parse_uri("env://foo?bar"), "bar") |         self.assertEqual(config.parse_uri("env://foo?bar", {}), "bar") | ||||||
|  |  | ||||||
|     def test_uri_file(self): |     def test_uri_file(self): | ||||||
|         """Test URI parsing (file load)""" |         """Test URI parsing (file load)""" | ||||||
| @ -41,8 +41,8 @@ class TestConfig(TestCase): | |||||||
|         write(file, "foo".encode()) |         write(file, "foo".encode()) | ||||||
|         _, file2_name = mkstemp() |         _, file2_name = mkstemp() | ||||||
|         chmod(file2_name, 0o000)  # Remove all permissions so we can't read the file |         chmod(file2_name, 0o000)  # Remove all permissions so we can't read the file | ||||||
|         self.assertEqual(config.parse_uri(f"file://{file_name}"), "foo") |         self.assertEqual(config.parse_uri(f"file://{file_name}", {}), "foo") | ||||||
|         self.assertEqual(config.parse_uri(f"file://{file2_name}?def"), "def") |         self.assertEqual(config.parse_uri(f"file://{file2_name}?def", {}), "def") | ||||||
|         unlink(file_name) |         unlink(file_name) | ||||||
|         unlink(file2_name) |         unlink(file2_name) | ||||||
|  |  | ||||||
| @ -59,3 +59,13 @@ class TestConfig(TestCase): | |||||||
|         config.update_from_file(file2_name) |         config.update_from_file(file2_name) | ||||||
|         unlink(file_name) |         unlink(file_name) | ||||||
|         unlink(file2_name) |         unlink(file2_name) | ||||||
|  |  | ||||||
|  |     def test_update(self): | ||||||
|  |         """Test change to file""" | ||||||
|  |         file, file_name = mkstemp() | ||||||
|  |         write(file, b"test") | ||||||
|  |         CONFIG.y_set("test.file", f"file://{file_name}") | ||||||
|  |         self.assertEqual(CONFIG.y("test.file"), "test") | ||||||
|  |         write(file, "test2") | ||||||
|  |         self.assertEqual(CONFIG.y("test.file"), "test2") | ||||||
|  |         unlink(file_name) | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user
	