root: add get_int to config loader instead of casting to int everywhere (#6436)
* root: add get_int to config loader instead of casting to int everywhere Signed-off-by: Jens Langhammer <jens@goauthentik.io> * improve error handling, add test Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
		| @ -93,10 +93,10 @@ class ConfigView(APIView): | |||||||
|                     "traces_sample_rate": float(CONFIG.get("error_reporting.sample_rate", 0.4)), |                     "traces_sample_rate": float(CONFIG.get("error_reporting.sample_rate", 0.4)), | ||||||
|                 }, |                 }, | ||||||
|                 "capabilities": self.get_capabilities(), |                 "capabilities": self.get_capabilities(), | ||||||
|                 "cache_timeout": int(CONFIG.get("redis.cache_timeout")), |                 "cache_timeout": CONFIG.get_int("redis.cache_timeout"), | ||||||
|                 "cache_timeout_flows": int(CONFIG.get("redis.cache_timeout_flows")), |                 "cache_timeout_flows": CONFIG.get_int("redis.cache_timeout_flows"), | ||||||
|                 "cache_timeout_policies": int(CONFIG.get("redis.cache_timeout_policies")), |                 "cache_timeout_policies": CONFIG.get_int("redis.cache_timeout_policies"), | ||||||
|                 "cache_timeout_reputation": int(CONFIG.get("redis.cache_timeout_reputation")), |                 "cache_timeout_reputation": CONFIG.get_int("redis.cache_timeout_reputation"), | ||||||
|             } |             } | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  | |||||||
| @ -60,7 +60,7 @@ def default_token_key(): | |||||||
|     """Default token key""" |     """Default token key""" | ||||||
|     # We use generate_id since the chars in the key should be easy |     # We use generate_id since the chars in the key should be easy | ||||||
|     # to use in Emails (for verification) and URLs (for recovery) |     # to use in Emails (for verification) and URLs (for recovery) | ||||||
|     return generate_id(int(CONFIG.get("default_token_length"))) |     return generate_id(CONFIG.get_int("default_token_length")) | ||||||
|  |  | ||||||
|  |  | ||||||
| class UserTypes(models.TextChoices): | class UserTypes(models.TextChoices): | ||||||
|  | |||||||
| @ -33,7 +33,7 @@ PLAN_CONTEXT_SOURCE = "source" | |||||||
| # Is set by the Flow Planner when a FlowToken was used, and the currently active flow plan | # Is set by the Flow Planner when a FlowToken was used, and the currently active flow plan | ||||||
| # was restored. | # was restored. | ||||||
| PLAN_CONTEXT_IS_RESTORED = "is_restored" | PLAN_CONTEXT_IS_RESTORED = "is_restored" | ||||||
| CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_flows")) | CACHE_TIMEOUT = CONFIG.get_int("redis.cache_timeout_flows") | ||||||
| CACHE_PREFIX = "goauthentik.io/flows/planner/" | CACHE_PREFIX = "goauthentik.io/flows/planner/" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| @ -213,6 +213,14 @@ class ConfigLoader: | |||||||
|         attr: Attr = get_path_from_dict(root, path, sep=sep, default=Attr(default)) |         attr: Attr = get_path_from_dict(root, path, sep=sep, default=Attr(default)) | ||||||
|         return attr.value |         return attr.value | ||||||
|  |  | ||||||
|  |     def get_int(self, path: str, default=0) -> int: | ||||||
|  |         """Wrapper for get that converts value into int""" | ||||||
|  |         try: | ||||||
|  |             return int(self.get(path, default)) | ||||||
|  |         except ValueError as exc: | ||||||
|  |             self.log("warning", "Failed to parse config as int", path=path, exc=str(exc)) | ||||||
|  |             return default | ||||||
|  |  | ||||||
|     def get_bool(self, path: str, default=False) -> bool: |     def get_bool(self, path: str, default=False) -> bool: | ||||||
|         """Wrapper for get that converts value into boolean""" |         """Wrapper for get that converts value into boolean""" | ||||||
|         return str(self.get(path, default)).lower() == "true" |         return str(self.get(path, default)).lower() == "true" | ||||||
|  | |||||||
| @ -79,3 +79,15 @@ class TestConfig(TestCase): | |||||||
|         config.update_from_file(file2_name) |         config.update_from_file(file2_name) | ||||||
|         unlink(file_name) |         unlink(file_name) | ||||||
|         unlink(file2_name) |         unlink(file2_name) | ||||||
|  |  | ||||||
|  |     def test_get_int(self): | ||||||
|  |         """Test get_int""" | ||||||
|  |         config = ConfigLoader() | ||||||
|  |         config.set("foo", 1234) | ||||||
|  |         self.assertEqual(config.get_int("foo"), 1234) | ||||||
|  |  | ||||||
|  |     def test_get_int_invalid(self): | ||||||
|  |         """Test get_int""" | ||||||
|  |         config = ConfigLoader() | ||||||
|  |         config.set("foo", "bar") | ||||||
|  |         self.assertEqual(config.get_int("foo", 1234), 1234) | ||||||
|  | |||||||
| @ -19,7 +19,7 @@ from authentik.policies.types import CACHE_PREFIX, PolicyRequest, PolicyResult | |||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
| FORK_CTX = get_context("fork") | FORK_CTX = get_context("fork") | ||||||
| CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_policies")) | CACHE_TIMEOUT = CONFIG.get_int("redis.cache_timeout_policies") | ||||||
| PROCESS_CLASS = FORK_CTX.Process | PROCESS_CLASS = FORK_CTX.Process | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| @ -13,7 +13,7 @@ from authentik.policies.reputation.tasks import save_reputation | |||||||
| from authentik.stages.identification.signals import identification_failed | from authentik.stages.identification.signals import identification_failed | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
| CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_reputation")) | CACHE_TIMEOUT = CONFIG.get_int("redis.cache_timeout_reputation") | ||||||
|  |  | ||||||
|  |  | ||||||
| def update_score(request: HttpRequest, identifier: str, amount: int): | def update_score(request: HttpRequest, identifier: str, amount: int): | ||||||
|  | |||||||
| @ -30,7 +30,7 @@ def get_install_id_raw(): | |||||||
|         user=CONFIG.get("postgresql.user"), |         user=CONFIG.get("postgresql.user"), | ||||||
|         password=CONFIG.get("postgresql.password"), |         password=CONFIG.get("postgresql.password"), | ||||||
|         host=CONFIG.get("postgresql.host"), |         host=CONFIG.get("postgresql.host"), | ||||||
|         port=int(CONFIG.get("postgresql.port")), |         port=CONFIG.get_int("postgresql.port"), | ||||||
|         sslmode=CONFIG.get("postgresql.sslmode"), |         sslmode=CONFIG.get("postgresql.sslmode"), | ||||||
|         sslrootcert=CONFIG.get("postgresql.sslrootcert"), |         sslrootcert=CONFIG.get("postgresql.sslrootcert"), | ||||||
|         sslcert=CONFIG.get("postgresql.sslcert"), |         sslcert=CONFIG.get("postgresql.sslcert"), | ||||||
|  | |||||||
| @ -190,14 +190,14 @@ if CONFIG.get_bool("redis.tls", False): | |||||||
| _redis_url = ( | _redis_url = ( | ||||||
|     f"{_redis_protocol_prefix}:" |     f"{_redis_protocol_prefix}:" | ||||||
|     f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:" |     f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:" | ||||||
|     f"{int(CONFIG.get('redis.port'))}" |     f"{CONFIG.get_int('redis.port')}" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| CACHES = { | CACHES = { | ||||||
|     "default": { |     "default": { | ||||||
|         "BACKEND": "django_redis.cache.RedisCache", |         "BACKEND": "django_redis.cache.RedisCache", | ||||||
|         "LOCATION": f"{_redis_url}/{CONFIG.get('redis.db')}", |         "LOCATION": f"{_redis_url}/{CONFIG.get('redis.db')}", | ||||||
|         "TIMEOUT": int(CONFIG.get("redis.cache_timeout", 300)), |         "TIMEOUT": CONFIG.get_int("redis.cache_timeout", 300), | ||||||
|         "OPTIONS": {"CLIENT_CLASS": "django_redis.client.DefaultClient"}, |         "OPTIONS": {"CLIENT_CLASS": "django_redis.client.DefaultClient"}, | ||||||
|         "KEY_PREFIX": "authentik_cache", |         "KEY_PREFIX": "authentik_cache", | ||||||
|     } |     } | ||||||
| @ -274,7 +274,7 @@ DATABASES = { | |||||||
|         "NAME": CONFIG.get("postgresql.name"), |         "NAME": CONFIG.get("postgresql.name"), | ||||||
|         "USER": CONFIG.get("postgresql.user"), |         "USER": CONFIG.get("postgresql.user"), | ||||||
|         "PASSWORD": CONFIG.get("postgresql.password"), |         "PASSWORD": CONFIG.get("postgresql.password"), | ||||||
|         "PORT": int(CONFIG.get("postgresql.port")), |         "PORT": CONFIG.get_int("postgresql.port"), | ||||||
|         "SSLMODE": CONFIG.get("postgresql.sslmode"), |         "SSLMODE": CONFIG.get("postgresql.sslmode"), | ||||||
|         "SSLROOTCERT": CONFIG.get("postgresql.sslrootcert"), |         "SSLROOTCERT": CONFIG.get("postgresql.sslrootcert"), | ||||||
|         "SSLCERT": CONFIG.get("postgresql.sslcert"), |         "SSLCERT": CONFIG.get("postgresql.sslcert"), | ||||||
| @ -293,12 +293,12 @@ if CONFIG.get_bool("postgresql.use_pgbouncer", False): | |||||||
| # loads the config directly from CONFIG | # loads the config directly from CONFIG | ||||||
| # See authentik/stages/email/models.py, line 105 | # See authentik/stages/email/models.py, line 105 | ||||||
| EMAIL_HOST = CONFIG.get("email.host") | EMAIL_HOST = CONFIG.get("email.host") | ||||||
| EMAIL_PORT = int(CONFIG.get("email.port")) | EMAIL_PORT = CONFIG.get_int("email.port") | ||||||
| EMAIL_HOST_USER = CONFIG.get("email.username") | EMAIL_HOST_USER = CONFIG.get("email.username") | ||||||
| EMAIL_HOST_PASSWORD = CONFIG.get("email.password") | EMAIL_HOST_PASSWORD = CONFIG.get("email.password") | ||||||
| EMAIL_USE_TLS = CONFIG.get_bool("email.use_tls", False) | EMAIL_USE_TLS = CONFIG.get_bool("email.use_tls", False) | ||||||
| EMAIL_USE_SSL = CONFIG.get_bool("email.use_ssl", False) | EMAIL_USE_SSL = CONFIG.get_bool("email.use_ssl", False) | ||||||
| EMAIL_TIMEOUT = int(CONFIG.get("email.timeout")) | EMAIL_TIMEOUT = CONFIG.get_int("email.timeout") | ||||||
| DEFAULT_FROM_EMAIL = CONFIG.get("email.from") | DEFAULT_FROM_EMAIL = CONFIG.get("email.from") | ||||||
| SERVER_EMAIL = DEFAULT_FROM_EMAIL | SERVER_EMAIL = DEFAULT_FROM_EMAIL | ||||||
| EMAIL_SUBJECT_PREFIX = "[authentik] " | EMAIL_SUBJECT_PREFIX = "[authentik] " | ||||||
|  | |||||||
| @ -93,7 +93,7 @@ class BaseLDAPSynchronizer: | |||||||
|         types_only=False, |         types_only=False, | ||||||
|         get_operational_attributes=False, |         get_operational_attributes=False, | ||||||
|         controls=None, |         controls=None, | ||||||
|         paged_size=int(CONFIG.get("ldap.page_size", 50)), |         paged_size=CONFIG.get_int("ldap.page_size", 50), | ||||||
|         paged_criticality=False, |         paged_criticality=False, | ||||||
|     ): |     ): | ||||||
|         """Search in pages, returns each page""" |         """Search in pages, returns each page""" | ||||||
|  | |||||||
| @ -59,7 +59,7 @@ def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) -> | |||||||
|     signatures = [] |     signatures = [] | ||||||
|     for page in sync_inst.get_objects(): |     for page in sync_inst.get_objects(): | ||||||
|         page_cache_key = CACHE_KEY_PREFIX + str(uuid4()) |         page_cache_key = CACHE_KEY_PREFIX + str(uuid4()) | ||||||
|         cache.set(page_cache_key, page, 60 * 60 * int(CONFIG.get("ldap.task_timeout_hours"))) |         cache.set(page_cache_key, page, 60 * 60 * CONFIG.get_int("ldap.task_timeout_hours")) | ||||||
|         page_sync = ldap_sync.si(source.pk, class_to_path(sync), page_cache_key) |         page_sync = ldap_sync.si(source.pk, class_to_path(sync), page_cache_key) | ||||||
|         signatures.append(page_sync) |         signatures.append(page_sync) | ||||||
|     return signatures |     return signatures | ||||||
| @ -68,12 +68,12 @@ def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) -> | |||||||
| @CELERY_APP.task( | @CELERY_APP.task( | ||||||
|     bind=True, |     bind=True, | ||||||
|     base=MonitoredTask, |     base=MonitoredTask, | ||||||
|     soft_time_limit=60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")), |     soft_time_limit=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"), | ||||||
|     task_time_limit=60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")), |     task_time_limit=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"), | ||||||
| ) | ) | ||||||
| def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str, page_cache_key: str): | def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str, page_cache_key: str): | ||||||
|     """Synchronization of an LDAP Source""" |     """Synchronization of an LDAP Source""" | ||||||
|     self.result_timeout_hours = int(CONFIG.get("ldap.task_timeout_hours")) |     self.result_timeout_hours = CONFIG.get_int("ldap.task_timeout_hours") | ||||||
|     source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first() |     source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first() | ||||||
|     if not source: |     if not source: | ||||||
|         # Because the source couldn't be found, we don't have a UID |         # Because the source couldn't be found, we don't have a UID | ||||||
|  | |||||||
| @ -108,12 +108,12 @@ class EmailStage(Stage): | |||||||
|             CONFIG.refresh("email.password") |             CONFIG.refresh("email.password") | ||||||
|             return self.backend_class( |             return self.backend_class( | ||||||
|                 host=CONFIG.get("email.host"), |                 host=CONFIG.get("email.host"), | ||||||
|                 port=int(CONFIG.get("email.port")), |                 port=CONFIG.get_int("email.port"), | ||||||
|                 username=CONFIG.get("email.username"), |                 username=CONFIG.get("email.username"), | ||||||
|                 password=CONFIG.get("email.password"), |                 password=CONFIG.get("email.password"), | ||||||
|                 use_tls=CONFIG.get_bool("email.use_tls", False), |                 use_tls=CONFIG.get_bool("email.use_tls", False), | ||||||
|                 use_ssl=CONFIG.get_bool("email.use_ssl", False), |                 use_ssl=CONFIG.get_bool("email.use_ssl", False), | ||||||
|                 timeout=int(CONFIG.get("email.timeout")), |                 timeout=CONFIG.get_int("email.timeout"), | ||||||
|             ) |             ) | ||||||
|         return self.backend_class( |         return self.backend_class( | ||||||
|             host=self.host, |             host=self.host, | ||||||
|  | |||||||
| @ -80,8 +80,8 @@ if SERVICE_HOST_ENV_NAME in os.environ: | |||||||
| else: | else: | ||||||
|     default_workers = max(cpu_count() * 0.25, 1) + 1  # Minimum of 2 workers |     default_workers = max(cpu_count() * 0.25, 1) + 1  # Minimum of 2 workers | ||||||
|  |  | ||||||
| workers = int(CONFIG.get("web.workers", default_workers)) | workers = CONFIG.get_int("web.workers", default_workers) | ||||||
| threads = int(CONFIG.get("web.threads", 4)) | threads = CONFIG.get_int("web.threads", 4) | ||||||
|  |  | ||||||
|  |  | ||||||
| def post_fork(server: "Arbiter", worker: DjangoUvicornWorker): | def post_fork(server: "Arbiter", worker: DjangoUvicornWorker): | ||||||
|  | |||||||
| @ -56,7 +56,7 @@ if __name__ == "__main__": | |||||||
|         user=CONFIG.get("postgresql.user"), |         user=CONFIG.get("postgresql.user"), | ||||||
|         password=CONFIG.get("postgresql.password"), |         password=CONFIG.get("postgresql.password"), | ||||||
|         host=CONFIG.get("postgresql.host"), |         host=CONFIG.get("postgresql.host"), | ||||||
|         port=int(CONFIG.get("postgresql.port")), |         port=CONFIG.get_int("postgresql.port"), | ||||||
|         sslmode=CONFIG.get("postgresql.sslmode"), |         sslmode=CONFIG.get("postgresql.sslmode"), | ||||||
|         sslrootcert=CONFIG.get("postgresql.sslrootcert"), |         sslrootcert=CONFIG.get("postgresql.sslrootcert"), | ||||||
|         sslcert=CONFIG.get("postgresql.sslcert"), |         sslcert=CONFIG.get("postgresql.sslcert"), | ||||||
|  | |||||||
| @ -28,7 +28,7 @@ while True: | |||||||
|             user=CONFIG.get("postgresql.user"), |             user=CONFIG.get("postgresql.user"), | ||||||
|             password=CONFIG.get("postgresql.password"), |             password=CONFIG.get("postgresql.password"), | ||||||
|             host=CONFIG.get("postgresql.host"), |             host=CONFIG.get("postgresql.host"), | ||||||
|             port=int(CONFIG.get("postgresql.port")), |             port=CONFIG.get_int("postgresql.port"), | ||||||
|             sslmode=CONFIG.get("postgresql.sslmode"), |             sslmode=CONFIG.get("postgresql.sslmode"), | ||||||
|             sslrootcert=CONFIG.get("postgresql.sslrootcert"), |             sslrootcert=CONFIG.get("postgresql.sslrootcert"), | ||||||
|             sslcert=CONFIG.get("postgresql.sslcert"), |             sslcert=CONFIG.get("postgresql.sslcert"), | ||||||
| @ -47,7 +47,7 @@ if CONFIG.get_bool("redis.tls", False): | |||||||
| REDIS_URL = ( | REDIS_URL = ( | ||||||
|     f"{REDIS_PROTOCOL_PREFIX}:" |     f"{REDIS_PROTOCOL_PREFIX}:" | ||||||
|     f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:" |     f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:" | ||||||
|     f"{int(CONFIG.get('redis.port'))}/{CONFIG.get('redis.db')}" |     f"{CONFIG.get_int('redis.port')}/{CONFIG.get('redis.db')}" | ||||||
| ) | ) | ||||||
| while True: | while True: | ||||||
|     try: |     try: | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user
	 Jens L
					Jens L