root: add get_int to config loader instead of casting to int everywhere (#6436)
* root: add get_int to config loader instead of casting to int everywhere Signed-off-by: Jens Langhammer <jens@goauthentik.io> * improve error handling, add test Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
		| @ -93,10 +93,10 @@ class ConfigView(APIView): | ||||
|                     "traces_sample_rate": float(CONFIG.get("error_reporting.sample_rate", 0.4)), | ||||
|                 }, | ||||
|                 "capabilities": self.get_capabilities(), | ||||
|                 "cache_timeout": int(CONFIG.get("redis.cache_timeout")), | ||||
|                 "cache_timeout_flows": int(CONFIG.get("redis.cache_timeout_flows")), | ||||
|                 "cache_timeout_policies": int(CONFIG.get("redis.cache_timeout_policies")), | ||||
|                 "cache_timeout_reputation": int(CONFIG.get("redis.cache_timeout_reputation")), | ||||
|                 "cache_timeout": CONFIG.get_int("redis.cache_timeout"), | ||||
|                 "cache_timeout_flows": CONFIG.get_int("redis.cache_timeout_flows"), | ||||
|                 "cache_timeout_policies": CONFIG.get_int("redis.cache_timeout_policies"), | ||||
|                 "cache_timeout_reputation": CONFIG.get_int("redis.cache_timeout_reputation"), | ||||
|             } | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @ -60,7 +60,7 @@ def default_token_key(): | ||||
|     """Default token key""" | ||||
|     # We use generate_id since the chars in the key should be easy | ||||
|     # to use in Emails (for verification) and URLs (for recovery) | ||||
|     return generate_id(int(CONFIG.get("default_token_length"))) | ||||
|     return generate_id(CONFIG.get_int("default_token_length")) | ||||
|  | ||||
|  | ||||
| class UserTypes(models.TextChoices): | ||||
|  | ||||
| @ -33,7 +33,7 @@ PLAN_CONTEXT_SOURCE = "source" | ||||
| # Is set by the Flow Planner when a FlowToken was used, and the currently active flow plan | ||||
| # was restored. | ||||
| PLAN_CONTEXT_IS_RESTORED = "is_restored" | ||||
| CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_flows")) | ||||
| CACHE_TIMEOUT = CONFIG.get_int("redis.cache_timeout_flows") | ||||
| CACHE_PREFIX = "goauthentik.io/flows/planner/" | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -213,6 +213,14 @@ class ConfigLoader: | ||||
|         attr: Attr = get_path_from_dict(root, path, sep=sep, default=Attr(default)) | ||||
|         return attr.value | ||||
|  | ||||
|     def get_int(self, path: str, default=0) -> int: | ||||
|         """Wrapper for get that converts value into int""" | ||||
|         try: | ||||
|             return int(self.get(path, default)) | ||||
|         except ValueError as exc: | ||||
|             self.log("warning", "Failed to parse config as int", path=path, exc=str(exc)) | ||||
|             return default | ||||
|  | ||||
|     def get_bool(self, path: str, default=False) -> bool: | ||||
|         """Wrapper for get that converts value into boolean""" | ||||
|         return str(self.get(path, default)).lower() == "true" | ||||
|  | ||||
| @ -79,3 +79,15 @@ class TestConfig(TestCase): | ||||
|         config.update_from_file(file2_name) | ||||
|         unlink(file_name) | ||||
|         unlink(file2_name) | ||||
|  | ||||
|     def test_get_int(self): | ||||
|         """Test get_int""" | ||||
|         config = ConfigLoader() | ||||
|         config.set("foo", 1234) | ||||
|         self.assertEqual(config.get_int("foo"), 1234) | ||||
|  | ||||
|     def test_get_int_invalid(self): | ||||
|         """Test get_int""" | ||||
|         config = ConfigLoader() | ||||
|         config.set("foo", "bar") | ||||
|         self.assertEqual(config.get_int("foo", 1234), 1234) | ||||
|  | ||||
| @ -19,7 +19,7 @@ from authentik.policies.types import CACHE_PREFIX, PolicyRequest, PolicyResult | ||||
| LOGGER = get_logger() | ||||
|  | ||||
| FORK_CTX = get_context("fork") | ||||
| CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_policies")) | ||||
| CACHE_TIMEOUT = CONFIG.get_int("redis.cache_timeout_policies") | ||||
| PROCESS_CLASS = FORK_CTX.Process | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -13,7 +13,7 @@ from authentik.policies.reputation.tasks import save_reputation | ||||
| from authentik.stages.identification.signals import identification_failed | ||||
|  | ||||
| LOGGER = get_logger() | ||||
| CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_reputation")) | ||||
| CACHE_TIMEOUT = CONFIG.get_int("redis.cache_timeout_reputation") | ||||
|  | ||||
|  | ||||
| def update_score(request: HttpRequest, identifier: str, amount: int): | ||||
|  | ||||
| @ -30,7 +30,7 @@ def get_install_id_raw(): | ||||
|         user=CONFIG.get("postgresql.user"), | ||||
|         password=CONFIG.get("postgresql.password"), | ||||
|         host=CONFIG.get("postgresql.host"), | ||||
|         port=int(CONFIG.get("postgresql.port")), | ||||
|         port=CONFIG.get_int("postgresql.port"), | ||||
|         sslmode=CONFIG.get("postgresql.sslmode"), | ||||
|         sslrootcert=CONFIG.get("postgresql.sslrootcert"), | ||||
|         sslcert=CONFIG.get("postgresql.sslcert"), | ||||
|  | ||||
| @ -190,14 +190,14 @@ if CONFIG.get_bool("redis.tls", False): | ||||
| _redis_url = ( | ||||
|     f"{_redis_protocol_prefix}:" | ||||
|     f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:" | ||||
|     f"{int(CONFIG.get('redis.port'))}" | ||||
|     f"{CONFIG.get_int('redis.port')}" | ||||
| ) | ||||
|  | ||||
| CACHES = { | ||||
|     "default": { | ||||
|         "BACKEND": "django_redis.cache.RedisCache", | ||||
|         "LOCATION": f"{_redis_url}/{CONFIG.get('redis.db')}", | ||||
|         "TIMEOUT": int(CONFIG.get("redis.cache_timeout", 300)), | ||||
|         "TIMEOUT": CONFIG.get_int("redis.cache_timeout", 300), | ||||
|         "OPTIONS": {"CLIENT_CLASS": "django_redis.client.DefaultClient"}, | ||||
|         "KEY_PREFIX": "authentik_cache", | ||||
|     } | ||||
| @ -274,7 +274,7 @@ DATABASES = { | ||||
|         "NAME": CONFIG.get("postgresql.name"), | ||||
|         "USER": CONFIG.get("postgresql.user"), | ||||
|         "PASSWORD": CONFIG.get("postgresql.password"), | ||||
|         "PORT": int(CONFIG.get("postgresql.port")), | ||||
|         "PORT": CONFIG.get_int("postgresql.port"), | ||||
|         "SSLMODE": CONFIG.get("postgresql.sslmode"), | ||||
|         "SSLROOTCERT": CONFIG.get("postgresql.sslrootcert"), | ||||
|         "SSLCERT": CONFIG.get("postgresql.sslcert"), | ||||
| @ -293,12 +293,12 @@ if CONFIG.get_bool("postgresql.use_pgbouncer", False): | ||||
| # loads the config directly from CONFIG | ||||
| # See authentik/stages/email/models.py, line 105 | ||||
| EMAIL_HOST = CONFIG.get("email.host") | ||||
| EMAIL_PORT = int(CONFIG.get("email.port")) | ||||
| EMAIL_PORT = CONFIG.get_int("email.port") | ||||
| EMAIL_HOST_USER = CONFIG.get("email.username") | ||||
| EMAIL_HOST_PASSWORD = CONFIG.get("email.password") | ||||
| EMAIL_USE_TLS = CONFIG.get_bool("email.use_tls", False) | ||||
| EMAIL_USE_SSL = CONFIG.get_bool("email.use_ssl", False) | ||||
| EMAIL_TIMEOUT = int(CONFIG.get("email.timeout")) | ||||
| EMAIL_TIMEOUT = CONFIG.get_int("email.timeout") | ||||
| DEFAULT_FROM_EMAIL = CONFIG.get("email.from") | ||||
| SERVER_EMAIL = DEFAULT_FROM_EMAIL | ||||
| EMAIL_SUBJECT_PREFIX = "[authentik] " | ||||
|  | ||||
| @ -93,7 +93,7 @@ class BaseLDAPSynchronizer: | ||||
|         types_only=False, | ||||
|         get_operational_attributes=False, | ||||
|         controls=None, | ||||
|         paged_size=int(CONFIG.get("ldap.page_size", 50)), | ||||
|         paged_size=CONFIG.get_int("ldap.page_size", 50), | ||||
|         paged_criticality=False, | ||||
|     ): | ||||
|         """Search in pages, returns each page""" | ||||
|  | ||||
| @ -59,7 +59,7 @@ def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) -> | ||||
|     signatures = [] | ||||
|     for page in sync_inst.get_objects(): | ||||
|         page_cache_key = CACHE_KEY_PREFIX + str(uuid4()) | ||||
|         cache.set(page_cache_key, page, 60 * 60 * int(CONFIG.get("ldap.task_timeout_hours"))) | ||||
|         cache.set(page_cache_key, page, 60 * 60 * CONFIG.get_int("ldap.task_timeout_hours")) | ||||
|         page_sync = ldap_sync.si(source.pk, class_to_path(sync), page_cache_key) | ||||
|         signatures.append(page_sync) | ||||
|     return signatures | ||||
| @ -68,12 +68,12 @@ def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) -> | ||||
| @CELERY_APP.task( | ||||
|     bind=True, | ||||
|     base=MonitoredTask, | ||||
|     soft_time_limit=60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")), | ||||
|     task_time_limit=60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")), | ||||
|     soft_time_limit=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"), | ||||
|     task_time_limit=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"), | ||||
| ) | ||||
| def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str, page_cache_key: str): | ||||
|     """Synchronization of an LDAP Source""" | ||||
|     self.result_timeout_hours = int(CONFIG.get("ldap.task_timeout_hours")) | ||||
|     self.result_timeout_hours = CONFIG.get_int("ldap.task_timeout_hours") | ||||
|     source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first() | ||||
|     if not source: | ||||
|         # Because the source couldn't be found, we don't have a UID | ||||
|  | ||||
| @ -108,12 +108,12 @@ class EmailStage(Stage): | ||||
|             CONFIG.refresh("email.password") | ||||
|             return self.backend_class( | ||||
|                 host=CONFIG.get("email.host"), | ||||
|                 port=int(CONFIG.get("email.port")), | ||||
|                 port=CONFIG.get_int("email.port"), | ||||
|                 username=CONFIG.get("email.username"), | ||||
|                 password=CONFIG.get("email.password"), | ||||
|                 use_tls=CONFIG.get_bool("email.use_tls", False), | ||||
|                 use_ssl=CONFIG.get_bool("email.use_ssl", False), | ||||
|                 timeout=int(CONFIG.get("email.timeout")), | ||||
|                 timeout=CONFIG.get_int("email.timeout"), | ||||
|             ) | ||||
|         return self.backend_class( | ||||
|             host=self.host, | ||||
|  | ||||
| @ -80,8 +80,8 @@ if SERVICE_HOST_ENV_NAME in os.environ: | ||||
| else: | ||||
|     default_workers = max(cpu_count() * 0.25, 1) + 1  # Minimum of 2 workers | ||||
|  | ||||
| workers = int(CONFIG.get("web.workers", default_workers)) | ||||
| threads = int(CONFIG.get("web.threads", 4)) | ||||
| workers = CONFIG.get_int("web.workers", default_workers) | ||||
| threads = CONFIG.get_int("web.threads", 4) | ||||
|  | ||||
|  | ||||
| def post_fork(server: "Arbiter", worker: DjangoUvicornWorker): | ||||
|  | ||||
| @ -56,7 +56,7 @@ if __name__ == "__main__": | ||||
|         user=CONFIG.get("postgresql.user"), | ||||
|         password=CONFIG.get("postgresql.password"), | ||||
|         host=CONFIG.get("postgresql.host"), | ||||
|         port=int(CONFIG.get("postgresql.port")), | ||||
|         port=CONFIG.get_int("postgresql.port"), | ||||
|         sslmode=CONFIG.get("postgresql.sslmode"), | ||||
|         sslrootcert=CONFIG.get("postgresql.sslrootcert"), | ||||
|         sslcert=CONFIG.get("postgresql.sslcert"), | ||||
|  | ||||
| @ -28,7 +28,7 @@ while True: | ||||
|             user=CONFIG.get("postgresql.user"), | ||||
|             password=CONFIG.get("postgresql.password"), | ||||
|             host=CONFIG.get("postgresql.host"), | ||||
|             port=int(CONFIG.get("postgresql.port")), | ||||
|             port=CONFIG.get_int("postgresql.port"), | ||||
|             sslmode=CONFIG.get("postgresql.sslmode"), | ||||
|             sslrootcert=CONFIG.get("postgresql.sslrootcert"), | ||||
|             sslcert=CONFIG.get("postgresql.sslcert"), | ||||
| @ -47,7 +47,7 @@ if CONFIG.get_bool("redis.tls", False): | ||||
| REDIS_URL = ( | ||||
|     f"{REDIS_PROTOCOL_PREFIX}:" | ||||
|     f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:" | ||||
|     f"{int(CONFIG.get('redis.port'))}/{CONFIG.get('redis.db')}" | ||||
|     f"{CONFIG.get_int('redis.port')}/{CONFIG.get('redis.db')}" | ||||
| ) | ||||
| while True: | ||||
|     try: | ||||
|  | ||||
		Reference in New Issue
	
	Block a user
	 Jens L
					Jens L