diff --git a/authentik/enterprise/audit/middleware.py b/authentik/enterprise/audit/middleware.py index 9a621721eb..542e04d871 100644 --- a/authentik/enterprise/audit/middleware.py +++ b/authentik/enterprise/audit/middleware.py @@ -11,7 +11,6 @@ from django.db.models.expressions import BaseExpression, Combinable from django.db.models.signals import post_init from django.http import HttpRequest -from authentik.core.models import User from authentik.events.middleware import AuditMiddleware, should_log_model from authentik.events.utils import cleanse_dict, sanitize_item @@ -28,13 +27,10 @@ class EnterpriseAuditMiddleware(AuditMiddleware): super().connect(request) if not self.enabled: return - user = getattr(request, "user", self.anonymous_user) - if not user.is_authenticated: - user = self.anonymous_user if not hasattr(request, "request_id"): return post_init.connect( - partial(self.post_init_handler, user=user, request=request), + partial(self.post_init_handler, request=request), dispatch_uid=request.request_id, weak=False, ) @@ -76,7 +72,7 @@ class EnterpriseAuditMiddleware(AuditMiddleware): diff[key] = {"previous_value": value, "new_value": after.get(key)} return sanitize_item(diff) - def post_init_handler(self, user: User, request: HttpRequest, sender, instance: Model, **_): + def post_init_handler(self, request: HttpRequest, sender, instance: Model, **_): """post_init django model handler""" if not should_log_model(instance): return @@ -90,7 +86,6 @@ class EnterpriseAuditMiddleware(AuditMiddleware): def post_save_handler( self, - user: User, request: HttpRequest, sender, instance: Model, @@ -112,6 +107,4 @@ class EnterpriseAuditMiddleware(AuditMiddleware): for field_set in ignored_field_sets: if set(diff.keys()) == set(field_set): return None - return super().post_save_handler( - user, request, sender, instance, created, thread_kwargs, **_ - ) + return super().post_save_handler(request, sender, instance, created, thread_kwargs, **_) diff --git a/authentik/events/middleware.py b/authentik/events/middleware.py index eec473202a..4f7ebbf37e 100644 --- a/authentik/events/middleware.py +++ b/authentik/events/middleware.py @@ -110,26 +110,32 @@ class AuditMiddleware: self.anonymous_user = get_anonymous_user() + def get_user(self, request: HttpRequest) -> User: + user = _CTX_OVERWRITE_USER.get() + if user: + return user + user = getattr(request, "user", self.anonymous_user) + if not user.is_authenticated: + return self.anonymous_user + return user + def connect(self, request: HttpRequest): """Connect signal for automatic logging""" self._ensure_fallback_user() - user = getattr(request, "user", self.anonymous_user) - if not user.is_authenticated: - user = self.anonymous_user if not hasattr(request, "request_id"): return post_save.connect( - partial(self.post_save_handler, user=user, request=request), + partial(self.post_save_handler, request=request), dispatch_uid=request.request_id, weak=False, ) pre_delete.connect( - partial(self.pre_delete_handler, user=user, request=request), + partial(self.pre_delete_handler, request=request), dispatch_uid=request.request_id, weak=False, ) m2m_changed.connect( - partial(self.m2m_changed_handler, user=user, request=request), + partial(self.m2m_changed_handler, request=request), dispatch_uid=request.request_id, weak=False, ) @@ -174,7 +180,6 @@ class AuditMiddleware: def post_save_handler( self, - user: User, request: HttpRequest, sender, instance: Model, @@ -187,22 +192,20 @@ class AuditMiddleware: return if _CTX_IGNORE.get(): return - if _new_user := _CTX_OVERWRITE_USER.get(): - user = _new_user + user = self.get_user(request) action = EventAction.MODEL_CREATED if created else EventAction.MODEL_UPDATED thread = EventNewThread(action, request, user=user, model=model_to_dict(instance)) thread.kwargs.update(thread_kwargs or {}) thread.run() - def pre_delete_handler(self, user: User, request: HttpRequest, sender, instance: Model, **_): + def pre_delete_handler(self, request: HttpRequest, sender, instance: Model, **_): """Signal handler for all object's pre_delete""" if not should_log_model(instance): # pragma: no cover return if _CTX_IGNORE.get(): return - if _new_user := _CTX_OVERWRITE_USER.get(): - user = _new_user + user = self.get_user(request) EventNewThread( EventAction.MODEL_DELETED, @@ -211,9 +214,7 @@ class AuditMiddleware: model=model_to_dict(instance), ).run() - def m2m_changed_handler( - self, user: User, request: HttpRequest, sender, instance: Model, action: str, **_ - ): + def m2m_changed_handler(self, request: HttpRequest, sender, instance: Model, action: str, **_): """Signal handler for all object's m2m_changed""" if action not in ["pre_add", "pre_remove", "post_clear"]: return @@ -221,8 +222,7 @@ class AuditMiddleware: return if _CTX_IGNORE.get(): return - if _new_user := _CTX_OVERWRITE_USER.get(): - user = _new_user + user = self.get_user(request) EventNewThread( EventAction.MODEL_UPDATED, diff --git a/authentik/events/tests/test_middleware.py b/authentik/events/tests/test_middleware.py index 906deb030d..a5d721a4f0 100644 --- a/authentik/events/tests/test_middleware.py +++ b/authentik/events/tests/test_middleware.py @@ -3,7 +3,7 @@ from django.urls import reverse from rest_framework.test import APITestCase -from authentik.core.models import Application +from authentik.core.models import Application, Token, TokenIntents from authentik.core.tests.utils import create_test_admin_user from authentik.events.middleware import audit_ignore, audit_overwrite_user from authentik.events.models import Event, EventAction @@ -27,14 +27,13 @@ class TestEventsMiddleware(APITestCase): data={"name": uid, "slug": uid}, ) self.assertTrue(Application.objects.filter(name=uid).exists()) - self.assertTrue( - Event.objects.filter( - action=EventAction.MODEL_CREATED, - context__model__model_name="application", - context__model__app="authentik_core", - context__model__name=uid, - ).exists() - ) + event = Event.objects.filter( + action=EventAction.MODEL_CREATED, + context__model__model_name="application", + context__model__app="authentik_core", + context__model__name=uid, + ).first() + self.assertIsNotNone(event) def test_delete(self): """Test model creation event""" @@ -88,3 +87,30 @@ class TestEventsMiddleware(APITestCase): user__username=new_user.username, ).exists() ) + + def test_create_with_api(self): + """Test model creation event (with API token auth)""" + self.client.logout() + token = Token.objects.create(user=self.user, intent=TokenIntents.INTENT_API, expiring=False) + uid = generate_id() + self.client.post( + reverse("authentik_api:application-list"), + data={"name": uid, "slug": uid}, + HTTP_AUTHORIZATION=f"Bearer {token.key}", + ) + self.assertTrue(Application.objects.filter(name=uid).exists()) + event = Event.objects.filter( + action=EventAction.MODEL_CREATED, + context__model__model_name="application", + context__model__app="authentik_core", + context__model__name=uid, + ).first() + self.assertIsNotNone(event) + self.assertEqual( + event.user, + { + "pk": self.user.pk, + "email": self.user.email, + "username": self.user.username, + }, + )