From b6157ecaf1425460dbea312e6075fdd4eefdc6ab Mon Sep 17 00:00:00 2001 From: "gcp-cherry-pick-bot[bot]" <98988430+gcp-cherry-pick-bot[bot]@users.noreply.github.com> Date: Sun, 16 Jun 2024 19:52:04 +0200 Subject: [PATCH] policies/reputation: fix existing reputation update (cherry-pick #10124) (#10125) policies/reputation: fix existing reputation update (#10124) * add failing test case * fix reputation update * lint --------- Signed-off-by: Marc 'risson' Schmitt Co-authored-by: Marc 'risson' Schmitt --- authentik/policies/reputation/signals.py | 27 +++++++++++++++--------- authentik/policies/reputation/tests.py | 9 ++++++++ 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/authentik/policies/reputation/signals.py b/authentik/policies/reputation/signals.py index 31042f8406..a3969c4d9e 100644 --- a/authentik/policies/reputation/signals.py +++ b/authentik/policies/reputation/signals.py @@ -1,6 +1,8 @@ """authentik reputation request signals""" from django.contrib.auth.signals import user_logged_in +from django.db import transaction +from django.db.models import F from django.dispatch import receiver from django.http import HttpRequest from structlog.stdlib import get_logger @@ -19,16 +21,21 @@ def update_score(request: HttpRequest, identifier: str, amount: int): """Update score for IP and User""" remote_ip = ClientIPMiddleware.get_client_ip(request) - Reputation.objects.update_or_create( - ip=remote_ip, - identifier=identifier, - defaults={ - "score": amount, - "ip_geo_data": GEOIP_CONTEXT_PROCESSOR.city_dict(remote_ip) or {}, - "ip_asn_data": ASN_CONTEXT_PROCESSOR.asn_dict(remote_ip) or {}, - "expires": reputation_expiry(), - }, - ) + with transaction.atomic(): + reputation, created = Reputation.objects.select_for_update().get_or_create( + ip=remote_ip, + identifier=identifier, + defaults={ + "score": amount, + "ip_geo_data": GEOIP_CONTEXT_PROCESSOR.city_dict(remote_ip) or {}, + "ip_asn_data": ASN_CONTEXT_PROCESSOR.asn_dict(remote_ip) or {}, + "expires": reputation_expiry(), + }, + ) + + if not created: + reputation.score = F("score") + amount + reputation.save() LOGGER.debug("Updated score", amount=amount, for_user=identifier, for_ip=remote_ip) diff --git a/authentik/policies/reputation/tests.py b/authentik/policies/reputation/tests.py index 7d4e33fb79..50b7b5a196 100644 --- a/authentik/policies/reputation/tests.py +++ b/authentik/policies/reputation/tests.py @@ -39,6 +39,15 @@ class TestReputationPolicy(TestCase): ) self.assertEqual(Reputation.objects.get(identifier=self.test_username).score, -1) + def test_update_reputation(self): + """test reputation update""" + Reputation.objects.create(identifier=self.test_username, ip=self.test_ip, score=43) + # Trigger negative reputation + authenticate( + self.request, self.backends, username=self.test_username, password=self.test_username + ) + self.assertEqual(Reputation.objects.get(identifier=self.test_username).score, 42) + def test_policy(self): """Test Policy""" request = PolicyRequest(user=self.user)