root: add primary-replica db router (#9479)
* root: add primary-replica db router Signed-off-by: Jens Langhammer <jens@goauthentik.io> * copy all settings for database replicas Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space> * refresh read replicas config, switch to using a dict instead of a list for easier refresh Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space> * add test for get_keys Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space> * fix getting override Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space> * lint Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space> * nosec Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space> * small fixes Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space> * fix replica settings Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space> * generate config: add a dummy read replica Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space> * add doc Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space> * add healthchecks for replicas Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space> * fix Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space> * add note about hot reloading Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io> Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space> Co-authored-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
		@ -304,6 +304,12 @@ class ConfigLoader:
 | 
			
		||||
        """Wrapper for get that converts value into boolean"""
 | 
			
		||||
        return str(self.get(path, default)).lower() == "true"
 | 
			
		||||
 | 
			
		||||
    def get_keys(self, path: str, sep=".") -> list[str]:
 | 
			
		||||
        """List attribute keys by using yaml path"""
 | 
			
		||||
        root = self.raw
 | 
			
		||||
        attr: Attr = get_path_from_dict(root, path, sep=sep, default=Attr({}))
 | 
			
		||||
        return attr.keys()
 | 
			
		||||
 | 
			
		||||
    def get_dict_from_b64_json(self, path: str, default=None) -> dict:
 | 
			
		||||
        """Wrapper for get that converts value from Base64 encoded string into dictionary"""
 | 
			
		||||
        config_value = self.get(path)
 | 
			
		||||
 | 
			
		||||
@ -10,6 +10,10 @@ postgresql:
 | 
			
		||||
  use_pgpool: false
 | 
			
		||||
  test:
 | 
			
		||||
    name: test_authentik
 | 
			
		||||
  read_replicas: {}
 | 
			
		||||
  # For example
 | 
			
		||||
  # 0:
 | 
			
		||||
  #   host: replica1.example.com
 | 
			
		||||
 | 
			
		||||
listen:
 | 
			
		||||
  listen_http: 0.0.0.0:9000
 | 
			
		||||
 | 
			
		||||
@ -169,3 +169,9 @@ class TestConfig(TestCase):
 | 
			
		||||
        self.assertEqual(config.get("cache.timeout_flows"), "32m")
 | 
			
		||||
        self.assertEqual(config.get("cache.timeout_policies"), "3920ns")
 | 
			
		||||
        self.assertEqual(config.get("cache.timeout_reputation"), "298382us")
 | 
			
		||||
 | 
			
		||||
    def test_get_keys(self):
 | 
			
		||||
        """Test get_keys"""
 | 
			
		||||
        config = ConfigLoader()
 | 
			
		||||
        config.set("foo.bar", "baz")
 | 
			
		||||
        self.assertEqual(list(config.get_keys("foo")), ["bar"])
 | 
			
		||||
 | 
			
		||||
@ -10,8 +10,15 @@ class DatabaseWrapper(BaseDatabaseWrapper):
 | 
			
		||||
 | 
			
		||||
    def get_connection_params(self):
 | 
			
		||||
        """Refresh DB credentials before getting connection params"""
 | 
			
		||||
        CONFIG.refresh("postgresql.password")
 | 
			
		||||
        conn_params = super().get_connection_params()
 | 
			
		||||
        conn_params["user"] = CONFIG.get("postgresql.user")
 | 
			
		||||
        conn_params["password"] = CONFIG.get("postgresql.password")
 | 
			
		||||
 | 
			
		||||
        prefix = "postgresql"
 | 
			
		||||
        if self.alias.startswith("replica_"):
 | 
			
		||||
            prefix = f"postgresql.read_replicas.{self.alias.removeprefix('replica_')}"
 | 
			
		||||
 | 
			
		||||
        for setting in ("host", "port", "user", "password"):
 | 
			
		||||
            conn_params[setting] = CONFIG.refresh(f"{prefix}.{setting}")
 | 
			
		||||
            if conn_params[setting] is None and self.alias.startswith("replica_"):
 | 
			
		||||
                conn_params[setting] = CONFIG.refresh(f"postgresql.{setting}")
 | 
			
		||||
 | 
			
		||||
        return conn_params
 | 
			
		||||
 | 
			
		||||
@ -47,8 +47,8 @@ class ReadyView(View):
 | 
			
		||||
 | 
			
		||||
    def dispatch(self, request: HttpRequest) -> HttpResponse:
 | 
			
		||||
        try:
 | 
			
		||||
            db_conn = connections["default"]
 | 
			
		||||
            _ = db_conn.cursor()
 | 
			
		||||
            for db_conn in connections.all():
 | 
			
		||||
                _ = db_conn.cursor()
 | 
			
		||||
        except OperationalError:  # pragma: no cover
 | 
			
		||||
            return HttpResponse(status=503)
 | 
			
		||||
        try:
 | 
			
		||||
 | 
			
		||||
@ -293,7 +293,7 @@ DATABASES = {
 | 
			
		||||
        "NAME": CONFIG.get("postgresql.name"),
 | 
			
		||||
        "USER": CONFIG.get("postgresql.user"),
 | 
			
		||||
        "PASSWORD": CONFIG.get("postgresql.password"),
 | 
			
		||||
        "PORT": CONFIG.get_int("postgresql.port"),
 | 
			
		||||
        "PORT": CONFIG.get("postgresql.port"),
 | 
			
		||||
        "SSLMODE": CONFIG.get("postgresql.sslmode"),
 | 
			
		||||
        "SSLROOTCERT": CONFIG.get("postgresql.sslrootcert"),
 | 
			
		||||
        "SSLCERT": CONFIG.get("postgresql.sslcert"),
 | 
			
		||||
@ -313,7 +313,23 @@ if CONFIG.get_bool("postgresql.use_pgbouncer", False):
 | 
			
		||||
    # https://docs.djangoproject.com/en/4.0/ref/databases/#persistent-connections
 | 
			
		||||
    DATABASES["default"]["CONN_MAX_AGE"] = None  # persistent
 | 
			
		||||
 | 
			
		||||
DATABASE_ROUTERS = ("django_tenants.routers.TenantSyncRouter",)
 | 
			
		||||
for replica in CONFIG.get_keys("postgresql.read_replicas"):
 | 
			
		||||
    _database = DATABASES["default"].copy()
 | 
			
		||||
    for setting in DATABASES["default"].keys():
 | 
			
		||||
        default = object()
 | 
			
		||||
        if setting in ("TEST",):
 | 
			
		||||
            continue
 | 
			
		||||
        override = CONFIG.get(
 | 
			
		||||
            f"postgresql.read_replicas.{replica}.{setting.lower()}", default=default
 | 
			
		||||
        )
 | 
			
		||||
        if override is not default:
 | 
			
		||||
            _database[setting] = override
 | 
			
		||||
    DATABASES[f"replica_{replica}"] = _database
 | 
			
		||||
 | 
			
		||||
DATABASE_ROUTERS = (
 | 
			
		||||
    "authentik.tenants.db.FailoverRouter",
 | 
			
		||||
    "django_tenants.routers.TenantSyncRouter",
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
# Email
 | 
			
		||||
# These values should never actually be used, emails are only sent from email stages, which
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										29
									
								
								authentik/tenants/db.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								authentik/tenants/db.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,29 @@
 | 
			
		||||
from random import choice
 | 
			
		||||
 | 
			
		||||
from django.conf import settings
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FailoverRouter:
 | 
			
		||||
    """Support an primary/read-replica PostgreSQL setup (reading from replicas
 | 
			
		||||
    and write to primary only)"""
 | 
			
		||||
 | 
			
		||||
    def __init__(self) -> None:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.database_aliases = set(settings.DATABASES.keys())
 | 
			
		||||
        self.read_replica_aliases = list(self.database_aliases - {"default"})
 | 
			
		||||
        self.replica_enabled = len(self.read_replica_aliases) > 0
 | 
			
		||||
 | 
			
		||||
    def db_for_read(self, model, **hints):
 | 
			
		||||
        if not self.replica_enabled:
 | 
			
		||||
            return "default"
 | 
			
		||||
        return choice(self.read_replica_aliases)  # nosec
 | 
			
		||||
 | 
			
		||||
    def db_for_write(self, model, **hints):
 | 
			
		||||
        return "default"
 | 
			
		||||
 | 
			
		||||
    def allow_relation(self, obj1, obj2, **hints):
 | 
			
		||||
        """Relations between objects are allowed if both objects are
 | 
			
		||||
        in the primary/replica pool."""
 | 
			
		||||
        if obj1._state.db in self.database_aliases and obj2._state.db in self.database_aliases:
 | 
			
		||||
            return True
 | 
			
		||||
        return None
 | 
			
		||||
		Reference in New Issue
	
	Block a user