events: rewrite GeoIP to a wrapper, reload file every 8 hours
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
		| @ -1,7 +1,12 @@ | ||||
| """events GeoIP Reader""" | ||||
| from typing import Optional | ||||
| from datetime import datetime | ||||
| from os import stat | ||||
| from time import time | ||||
| from typing import Optional, TypedDict | ||||
|  | ||||
| from geoip2.database import Reader | ||||
| from geoip2.errors import GeoIP2Error | ||||
| from geoip2.models import City | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.lib.config import CONFIG | ||||
| @ -9,17 +14,78 @@ from authentik.lib.config import CONFIG | ||||
| LOGGER = get_logger() | ||||
|  | ||||
|  | ||||
| def get_geoip_reader() -> Optional[Reader]: | ||||
|     """Get GeoIP Reader, if configured, otherwise none""" | ||||
|     path = CONFIG.y("authentik.geoip") | ||||
|     if path == "" or not path: | ||||
|         return None | ||||
|     try: | ||||
|         reader = Reader(path) | ||||
|         LOGGER.info("Enabled GeoIP support") | ||||
|         return reader | ||||
|     except OSError: | ||||
|         return None | ||||
| class GeoIPDict(TypedDict): | ||||
|     """GeoIP Details""" | ||||
|  | ||||
|     continent: str | ||||
|     country: str | ||||
|     lat: float | ||||
|     long: float | ||||
|     city: str | ||||
|  | ||||
|  | ||||
| GEOIP_READER = get_geoip_reader() | ||||
| class GeoIPReader: | ||||
|     """Slim wrapper around GeoIP API""" | ||||
|  | ||||
|     __reader: Optional[Reader] = None | ||||
|     __last_mtime: float = 0.0 | ||||
|  | ||||
|     def __init__(self): | ||||
|         self.__open() | ||||
|  | ||||
|     def __open(self): | ||||
|         """Get GeoIP Reader, if configured, otherwise none""" | ||||
|         path = CONFIG.y("authentik.geoip") | ||||
|         if path == "" or not path: | ||||
|             return | ||||
|         try: | ||||
|             reader = Reader(path) | ||||
|             LOGGER.info("Loaded GeoIP database") | ||||
|             self.__reader = reader | ||||
|             self.__last_mtime = stat(path).st_mtime | ||||
|         except OSError as exc: | ||||
|             LOGGER.warning("Failed to load GeoIP database", exc=exc) | ||||
|  | ||||
|     def __check_expired(self): | ||||
|         """Check if the geoip database has been opened longer than 8 hours, | ||||
|         and re-open it, as it will probably will have been re-downloaded""" | ||||
|         now = time() | ||||
|         diff = datetime.fromtimestamp(now) - datetime.fromtimestamp(self.__last_mtime) | ||||
|         diff_hours = diff.total_seconds() // 3600 | ||||
|         if diff_hours >= 8: | ||||
|             LOGGER.info("GeoIP databased loaded too long, re-opening", diff=diff) | ||||
|             self.__open() | ||||
|  | ||||
|     @property | ||||
|     def enabled(self) -> bool: | ||||
|         """Check if GeoIP is enabled""" | ||||
|         return bool(self.__reader) | ||||
|  | ||||
|     def city(self, ip_address: str) -> Optional[City]: | ||||
|         """Wrapper for Reader.city""" | ||||
|         if not self.enabled: | ||||
|             return None | ||||
|         self.__check_expired() | ||||
|         try: | ||||
|             return self.__reader.city(ip_address) | ||||
|         except (GeoIP2Error, ValueError): | ||||
|             return None | ||||
|  | ||||
|     def city_dict(self, ip_address: str) -> Optional[GeoIPDict]: | ||||
|         """Wrapper for self.city that returns a dict""" | ||||
|         city = self.city(ip_address) | ||||
|         if not city: | ||||
|             return None | ||||
|         city_dict: GeoIPDict = { | ||||
|             "continent": city.continent.code, | ||||
|             "country": city.country.iso_code, | ||||
|             "lat": city.location.latitude, | ||||
|             "long": city.location.longitude, | ||||
|             "city": "", | ||||
|         } | ||||
|         if city.city.name: | ||||
|             city_dict["city"] = city.city.name | ||||
|         return city_dict | ||||
|  | ||||
|  | ||||
| GEOIP_READER = GeoIPReader() | ||||
|  | ||||
| @ -10,7 +10,6 @@ from django.db import models | ||||
| from django.http import HttpRequest | ||||
| from django.utils.timezone import now | ||||
| from django.utils.translation import gettext as _ | ||||
| from geoip2.errors import GeoIP2Error | ||||
| from prometheus_client import Gauge | ||||
| from requests import RequestException, post | ||||
| from structlog.stdlib import get_logger | ||||
| @ -160,20 +159,10 @@ class Event(ExpiringModel): | ||||
|  | ||||
|     def with_geoip(self):  # pragma: no cover | ||||
|         """Apply GeoIP Data, when enabled""" | ||||
|         if not GEOIP_READER: | ||||
|         city = GEOIP_READER.city_dict(self.client_ip) | ||||
|         if not city: | ||||
|             return | ||||
|         try: | ||||
|             response = GEOIP_READER.city(self.client_ip) | ||||
|             self.context["geo"] = { | ||||
|                 "continent": response.continent.code, | ||||
|                 "country": response.country.iso_code, | ||||
|                 "lat": response.location.latitude, | ||||
|                 "long": response.location.longitude, | ||||
|             } | ||||
|             if response.city.name: | ||||
|                 self.context["geo"]["city"] = response.city.name | ||||
|         except (GeoIP2Error, ValueError) as exc: | ||||
|             LOGGER.warning("Failed to add geoIP Data to event", exc=exc) | ||||
|         self.context["geo"] = city | ||||
|  | ||||
|     def _set_prom_metrics(self): | ||||
|         GAUGE_EVENTS.labels( | ||||
|  | ||||
							
								
								
									
										26
									
								
								authentik/events/tests/test_geoip.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								authentik/events/tests/test_geoip.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,26 @@ | ||||
| """Test GeoIP Wrapper""" | ||||
| from django.test import TestCase | ||||
|  | ||||
| from authentik.events.geo import GeoIPReader | ||||
|  | ||||
|  | ||||
| class TestGeoIP(TestCase): | ||||
|     """Test GeoIP Wrapper""" | ||||
|  | ||||
|     def setUp(self) -> None: | ||||
|         self.reader = GeoIPReader() | ||||
|  | ||||
|     def test_simple(self): | ||||
|         """Test simple city wrapper""" | ||||
|         # IPs from | ||||
|         # https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json | ||||
|         self.assertEqual( | ||||
|             self.reader.city_dict("2.125.160.216"), | ||||
|             { | ||||
|                 "city": "Boxford", | ||||
|                 "continent": "EU", | ||||
|                 "country": "GB", | ||||
|                 "lat": 51.75, | ||||
|                 "long": -1.25, | ||||
|             }, | ||||
|         ) | ||||
		Reference in New Issue
	
	Block a user
	 Jens Langhammer
					Jens Langhammer