From 217063ef7b9fce217d279b53a2848cdb623975c9 Mon Sep 17 00:00:00 2001 From: Marc 'risson' Schmitt Date: Thu, 12 Jun 2025 12:29:13 +0200 Subject: [PATCH] fix db connection middleware Signed-off-by: Marc 'risson' Schmitt --- authentik/providers/scim/tests/test_user.py | 15 ++++++++------- authentik/tasks/broker.py | 7 +++++-- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/authentik/providers/scim/tests/test_user.py b/authentik/providers/scim/tests/test_user.py index d4e2698aa1..1daa2e9dd2 100644 --- a/authentik/providers/scim/tests/test_user.py +++ b/authentik/providers/scim/tests/test_user.py @@ -3,6 +3,7 @@ from json import loads from django.test import TestCase +from django.utils.text import slugify from jsonschema import validate from requests_mock import Mocker @@ -11,7 +12,8 @@ from authentik.core.models import Application, Group, User from authentik.lib.generators import generate_id from authentik.lib.sync.outgoing.base import SAFE_METHODS from authentik.providers.scim.models import SCIMMapping, SCIMProvider -from authentik.providers.scim.tasks import scim_sync, sync_tasks +from authentik.providers.scim.tasks import scim_sync, scim_sync_objects +from authentik.tasks.models import Task from authentik.tenants.models import Tenant @@ -354,7 +356,7 @@ class SCIMUserTests(TestCase): email=f"{uid}@goauthentik.io", ) - sync_tasks.trigger_single_task(self.provider, scim_sync).get() + scim_sync.send(self.provider.pk) self.assertEqual(mock.call_count, 5) self.assertEqual(mock.request_history[0].method, "GET") @@ -426,15 +428,14 @@ class SCIMUserTests(TestCase): email=f"{uid}@goauthentik.io", ) - sync_tasks.trigger_single_task(self.provider, scim_sync).get() + scim_sync.send(self.provider.pk) self.assertEqual(mock.call_count, 3) for request in mock.request_history: self.assertIn(request.method, SAFE_METHODS) - drop_msg = {} - # task = Task.objects.filter(uid=slugify(self.provider.name)).first() - # self.assertIsNotNone(task) - # drop_msg = task.messages[3] + task = Task.objects.filter(actor_name=scim_sync_objects.actor_name).first() + self.assertIsNotNone(task) + drop_msg = task._messages[2] self.assertEqual(drop_msg["event"], "Dropping mutating request due to dry run") self.assertIsNotNone(drop_msg["attributes"]["url"]) self.assertIsNotNone(drop_msg["attributes"]["body"]) diff --git a/authentik/tasks/broker.py b/authentik/tasks/broker.py index 21f5df249c..39cfc3e4e2 100644 --- a/authentik/tasks/broker.py +++ b/authentik/tasks/broker.py @@ -6,6 +6,7 @@ from queue import Empty, Queue from random import randint import tenacity +from django.conf import settings from django.db import ( DEFAULT_DB_ALIAS, DatabaseError, @@ -61,13 +62,15 @@ def raise_connection_error(func): class DbConnectionMiddleware(Middleware): def _close_old_connections(self, *args, **kwargs): + if settings.TEST: + return close_old_connections() # TODO: figure out if we really need this, it seems a bit excessive to close connections after # each message and if fails in tests - # before_process_message = _close_old_connections - # after_process_message = _close_old_connections + before_process_message = _close_old_connections + after_process_message = _close_old_connections def _close_connections(self, *args, **kwargs): connections.close_all()