maybe fix tenant tests
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
@ -8,4 +8,8 @@ LOGGER = get_logger()
|
||||
class Broker(PostgresBroker):
|
||||
@property
|
||||
def query_set(self) -> QuerySet:
|
||||
return self.model.objects.select_related("tenant").using(self.db_alias)
|
||||
return (
|
||||
self.model.objects.select_related("tenant")
|
||||
.using(self.db_alias)
|
||||
.filter(tenant__ready=True)
|
||||
)
|
||||
|
@ -37,6 +37,8 @@ class TenantMiddleware(Middleware):
|
||||
def after_process_message(self, *args, **kwargs):
|
||||
Tenant.deactivate()
|
||||
|
||||
after_skip_message = after_process_message
|
||||
|
||||
|
||||
class RelObjMiddleware(Middleware):
|
||||
@property
|
||||
|
@ -5,6 +5,7 @@ import uuid
|
||||
import django.db.models.deletion
|
||||
import django_tenants.postgresql_backend.base
|
||||
from django.db import migrations, models
|
||||
from django_tenants.utils import get_tenant_base_schema
|
||||
|
||||
import authentik.lib.utils.time
|
||||
import authentik.tenants.models
|
||||
@ -144,7 +145,7 @@ class Migration(migrations.Migration):
|
||||
),
|
||||
migrations.RunPython(code=create_default_tenant, reverse_code=migrations.RunPython.noop),
|
||||
migrations.RunSQL(
|
||||
sql="CREATE SCHEMA IF NOT EXISTS template;",
|
||||
reverse_sql="DROP SCHEMA IF EXISTS template;",
|
||||
sql=f"CREATE SCHEMA IF NOT EXISTS {get_tenant_base_schema()};",
|
||||
reverse_sql=f"DROP SCHEMA IF EXISTS {get_tenant_base_schema()};",
|
||||
),
|
||||
]
|
||||
|
@ -4,6 +4,7 @@ import re
|
||||
from uuid import uuid4
|
||||
|
||||
from django.apps import apps
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.core.validators import MaxValueValidator, MinValueValidator
|
||||
from django.db import models
|
||||
@ -11,6 +12,7 @@ from django.db.utils import IntegrityError
|
||||
from django.dispatch import receiver
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django_tenants.models import DomainMixin, TenantMixin, post_schema_sync
|
||||
from django_tenants.utils import get_tenant_base_schema
|
||||
from rest_framework.serializers import Serializer
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
@ -113,8 +115,8 @@ class Tenant(TenantMixin, SerializerModel):
|
||||
)
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
if self.schema_name == "template":
|
||||
raise IntegrityError("Cannot create schema named template")
|
||||
if self.schema_name == get_tenant_base_schema() and not settings.TEST:
|
||||
raise IntegrityError(f"Cannot create schema named {self.schema_name}")
|
||||
super().save(*args, **kwargs)
|
||||
|
||||
@property
|
||||
|
@ -1,7 +1,10 @@
|
||||
from django.core.management import call_command
|
||||
from django.db import connection, connections
|
||||
from django_tenants.utils import get_public_schema_name, get_tenant_base_schema, schema_context
|
||||
from rest_framework.test import APITransactionTestCase
|
||||
|
||||
from authentik.tenants.models import Tenant
|
||||
|
||||
|
||||
class TenantAPITestCase(APITransactionTestCase):
|
||||
# Overridden to also remove additional schemas we may have created
|
||||
@ -17,7 +20,12 @@ class TenantAPITestCase(APITransactionTestCase):
|
||||
super()._fixture_teardown()
|
||||
|
||||
def setUp(self):
|
||||
call_command("migrate_schemas", schema="template", tenant=True)
|
||||
with schema_context(get_public_schema_name()):
|
||||
Tenant.objects.update_or_create(
|
||||
defaults={"name": "Template", "ready": False},
|
||||
schema_name=get_tenant_base_schema(),
|
||||
)
|
||||
call_command("migrate_schemas", schema=get_tenant_base_schema(), tenant=True)
|
||||
|
||||
def assertSchemaExists(self, schema_name):
|
||||
with connection.cursor() as cursor:
|
||||
@ -28,7 +36,8 @@ class TenantAPITestCase(APITransactionTestCase):
|
||||
self.assertEqual(cursor.rowcount, 1)
|
||||
|
||||
cursor.execute(
|
||||
"SELECT * FROM information_schema.tables WHERE table_schema = 'template'"
|
||||
"SELECT * FROM information_schema.tables WHERE table_schema = %(schema_name)s",
|
||||
{"schema_name": get_tenant_base_schema()},
|
||||
)
|
||||
expected_tables = cursor.rowcount
|
||||
cursor.execute(
|
||||
|
Reference in New Issue
Block a user