Files
authentik/authentik/enterprise/audit/middleware.py
2025-02-17 13:43:18 +01:00

143 lines
5.2 KiB
Python

"""Enterprise audit middleware"""
from copy import deepcopy
from functools import partial
from typing import Any
from django.apps.registry import apps
from django.core.files import File
from django.db import connection
from django.db.models import ManyToManyRel, Model
from django.db.models.expressions import BaseExpression, Combinable
from django.db.models.signals import post_init
from django.http import HttpRequest
from authentik.events.middleware import AuditMiddleware, should_log_model
from authentik.events.utils import cleanse_dict, sanitize_item
class EnterpriseAuditMiddleware(AuditMiddleware):
    """Enterprise audit middleware

    Extends the base ``AuditMiddleware``. When the ``authentik_enterprise`` app reports itself
    as enabled, model saves and many-to-many changes are passed to the base handlers together
    with a ``diff`` describing what changed.
    """

    @property
    def enabled(self):
        """Check if audit logging is enabled"""
        return apps.get_app_config("authentik_enterprise").enabled()

    def connect(self, request: HttpRequest):
        """Connect the base handlers and, when enabled, a per-request ``post_init`` handler.

        The ``post_init`` handler snapshots each model instance as it is initialised, so a
        later save can be diffed against it.
        """
        super().connect(request)
        if not self.enabled:
            return
        if not hasattr(request, "request_id"):
            return
        post_init.connect(
            partial(self.post_init_handler, request=request),
            dispatch_uid=request.request_id,
            # The receiver is a throwaway ``partial``; with a weak reference it would be
            # garbage-collected right away, so it must be held strongly.
            weak=False,
        )

    def disconnect(self, request: HttpRequest):
        """Disconnect the handlers registered by ``connect``."""
        super().disconnect(request)
        if not hasattr(request, "request_id"):
            return
        # Deliberately not gated on ``self.enabled``: that state may change between
        # ``connect`` and ``disconnect``. Because the receiver is held strongly
        # (``weak=False``), skipping this would leak it. Disconnecting a ``dispatch_uid``
        # that was never connected is a no-op.
        post_init.disconnect(dispatch_uid=request.request_id)

    def serialize_simple(self, model: Model) -> dict:
        """Serialize a model in a very simple way. No ForeignKeys or other relationships are
        resolved

        Only concrete fields are included. The following are skipped:

        * deferred fields, since reading them would load them from the database;
        * fields whose current value is an unevaluated query expression.

        File values are reduced to their name. The result is passed through ``cleanse_dict``.
        """
        data = {}
        deferred_fields = model.get_deferred_fields()
        for field in model._meta.concrete_fields:
            if field.get_attname() in deferred_fields:
                continue
            field_value = getattr(model, field.attname)
            # Record files by name instead of deep-copying the file object (and, through it,
            # the model instance it is attached to)
            if isinstance(field_value, File):
                field_value = field_value.name
            # If current field value is an expression, we are not evaluating it
            if isinstance(field_value, BaseExpression | Combinable):
                continue
            field_value = field.to_python(field_value)
            data[field.name] = deepcopy(field_value)
        return cleanse_dict(data)

    def diff(self, before: dict, after: dict) -> dict:
        """Generate diff between dicts

        Returns ``{key: {"previous_value": ..., "new_value": ...}}`` for every key whose value
        differs between ``before`` and ``after``. A key missing on one side is treated as
        having the value ``None``. The result is passed through ``sanitize_item``.
        """
        changes = {}
        for key, value in before.items():
            if after.get(key) != value:
                changes[key] = {"previous_value": value, "new_value": after.get(key)}
        for key, value in after.items():
            # Keys only present in ``after``; keys shared with ``before`` were handled above
            if key not in before and before.get(key) != value:
                changes[key] = {"previous_value": before.get(key), "new_value": value}
        return sanitize_item(changes)

    def post_init_handler(self, request: HttpRequest, sender, instance: Model, **_):
        """post_init django model handler

        Stores ``serialize_simple(instance)`` as ``instance._previous_state`` the first time a
        loggable instance is initialised. ``post_save_handler`` diffs against that state.
        """
        if not should_log_model(instance):
            return
        if hasattr(instance, "_previous_state"):
            return
        # Safety net: serializing must not hit the database. ``connection.queries`` is only
        # populated while Django query logging is active (e.g. DEBUG).
        before = len(connection.queries)
        instance._previous_state = self.serialize_simple(instance)
        after = len(connection.queries)
        if after > before:
            raise AssertionError("More queries generated by serialize_simple")

    def post_save_handler(
        self,
        request: HttpRequest,
        sender,
        instance: Model,
        created: bool,
        thread_kwargs: dict | None = None,
        **_,
    ):
        """post_save django model handler

        When enabled, diffs the instance's current state against the state captured at init
        time, or against an empty state for newly created objects. The diff is passed to the
        base handler as ``thread_kwargs["diff"]``.
        """
        if not self.enabled:
            return super().post_save_handler(request, sender, instance, created, thread_kwargs, **_)
        if not should_log_model(instance):
            return None
        # Copy instead of overwriting, so caller-supplied values survive and the caller's
        # dict is not mutated
        thread_kwargs = dict(thread_kwargs or {})
        if hasattr(instance, "_previous_state") or created:
            # Newly created objects have no meaningful previous state
            prev_state = {} if created else instance._previous_state
            # Get current state
            new_state = self.serialize_simple(instance)
            thread_kwargs["diff"] = self.diff(prev_state, new_state)
            # Re-baseline, so a later save of this same instance only reports changes made
            # after this save
            instance._previous_state = new_state
        return super().post_save_handler(request, sender, instance, created, thread_kwargs, **_)

    def m2m_changed_handler(  # noqa: PLR0913
        self,
        request: HttpRequest,
        sender,
        instance: Model,
        action: str,
        pk_set: set[Any],
        thread_kwargs: dict | None = None,
        **_,
    ):
        """m2m_changed django handler

        When enabled, resolves the ``through`` model (``sender``) to a reverse many-to-many
        relation on ``instance``. It then records the affected primary keys under the action
        direction (``add``/``remove``), or ``True`` for ``clear``, as ``thread_kwargs["diff"]``.
        """
        # Copy instead of overwriting, so caller-supplied values survive
        thread_kwargs = dict(thread_kwargs or {})
        if not self.enabled:
            return super().m2m_changed_handler(request, sender, instance, action, thread_kwargs)
        # For the audit log we don't care about `pre_` or `post_` so we trim that part off
        _, _, action_direction = action.partition("_")
        # resolve the "through" model to an actual field (the last match wins)
        m2m_field = None
        for field in instance._meta.get_fields():
            if isinstance(field, ManyToManyRel) and field.through == sender:
                m2m_field = field
        if m2m_field:
            # If we're clearing we just set the "flag" to True
            changed = True if action_direction == "clear" else pk_set
            # ``related_name`` is None unless set explicitly; fall back to the accessor name
            # so the diff key is never None
            relation_name = m2m_field.related_name or m2m_field.get_accessor_name()
            thread_kwargs["diff"] = {relation_name: {action_direction: changed}}
        return super().m2m_changed_handler(request, sender, instance, action, thread_kwargs)