events: rewrite GeoIP to a wrapper, reload file every 8 hours
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
f5dbdbd48b
commit
17326615b7
|
@ -43,7 +43,7 @@ class ConfigView(APIView):
|
|||
deb_test = settings.DEBUG or settings.TEST
|
||||
if path.ismount(settings.MEDIA_ROOT) or deb_test:
|
||||
caps.append(Capabilities.CAN_SAVE_MEDIA)
|
||||
if GEOIP_READER:
|
||||
if GEOIP_READER.enabled:
|
||||
caps.append(Capabilities.CAN_GEO_IP)
|
||||
return caps
|
||||
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
from typing import Optional, TypedDict
|
||||
|
||||
from django_filters.rest_framework import DjangoFilterBackend
|
||||
from geoip2.errors import GeoIP2Error
|
||||
from guardian.utils import get_anonymous_user
|
||||
from rest_framework import mixins
|
||||
from rest_framework.fields import SerializerMethodField
|
||||
|
@ -13,7 +12,7 @@ from rest_framework.viewsets import GenericViewSet
|
|||
from ua_parser import user_agent_parser
|
||||
|
||||
from authentik.core.models import AuthenticatedSession
|
||||
from authentik.events.geo import GEOIP_READER
|
||||
from authentik.events.geo import GEOIP_READER, GeoIPDict
|
||||
|
||||
|
||||
class UserAgentDeviceDict(TypedDict):
|
||||
|
@ -52,15 +51,6 @@ class UserAgentDict(TypedDict):
|
|||
string: str
|
||||
|
||||
|
||||
class GeoIPDict(TypedDict):
|
||||
"""GeoIP Details"""
|
||||
|
||||
continent: str
|
||||
country: str
|
||||
lat: float
|
||||
long: float
|
||||
|
||||
|
||||
class AuthenticatedSessionSerializer(ModelSerializer):
|
||||
"""AuthenticatedSession Serializer"""
|
||||
|
||||
|
@ -81,18 +71,7 @@ class AuthenticatedSessionSerializer(ModelSerializer):
|
|||
self, instance: AuthenticatedSession
|
||||
) -> Optional[GeoIPDict]: # pragma: no cover
|
||||
"""Get parsed user agent"""
|
||||
if not GEOIP_READER:
|
||||
return None
|
||||
try:
|
||||
city = GEOIP_READER.city(instance.last_ip)
|
||||
return {
|
||||
"continent": city.continent.code,
|
||||
"country": city.country.iso_code,
|
||||
"lat": city.location.latitude,
|
||||
"long": city.location.longitude,
|
||||
}
|
||||
except (GeoIP2Error, ValueError):
|
||||
return None
|
||||
return GEOIP_READER.city_dict(instance.last_ip)
|
||||
|
||||
class Meta:
|
||||
|
||||
|
|
|
@ -1,7 +1,12 @@
|
|||
"""events GeoIP Reader"""
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from os import stat
|
||||
from time import time
|
||||
from typing import Optional, TypedDict
|
||||
|
||||
from geoip2.database import Reader
|
||||
from geoip2.errors import GeoIP2Error
|
||||
from geoip2.models import City
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.lib.config import CONFIG
|
||||
|
@ -9,17 +14,78 @@ from authentik.lib.config import CONFIG
|
|||
LOGGER = get_logger()
|
||||
|
||||
|
||||
def get_geoip_reader() -> Optional[Reader]:
|
||||
class GeoIPDict(TypedDict):
|
||||
"""GeoIP Details"""
|
||||
|
||||
continent: str
|
||||
country: str
|
||||
lat: float
|
||||
long: float
|
||||
city: str
|
||||
|
||||
|
||||
class GeoIPReader:
|
||||
"""Slim wrapper around GeoIP API"""
|
||||
|
||||
__reader: Optional[Reader] = None
|
||||
__last_mtime: float = 0.0
|
||||
|
||||
def __init__(self):
|
||||
self.__open()
|
||||
|
||||
def __open(self):
|
||||
"""Get GeoIP Reader, if configured, otherwise none"""
|
||||
path = CONFIG.y("authentik.geoip")
|
||||
if path == "" or not path:
|
||||
return None
|
||||
return
|
||||
try:
|
||||
reader = Reader(path)
|
||||
LOGGER.info("Enabled GeoIP support")
|
||||
return reader
|
||||
except OSError:
|
||||
LOGGER.info("Loaded GeoIP database")
|
||||
self.__reader = reader
|
||||
self.__last_mtime = stat(path).st_mtime
|
||||
except OSError as exc:
|
||||
LOGGER.warning("Failed to load GeoIP database", exc=exc)
|
||||
|
||||
def __check_expired(self):
|
||||
"""Check if the geoip database has been opened longer than 8 hours,
|
||||
and re-open it, as it will probably will have been re-downloaded"""
|
||||
now = time()
|
||||
diff = datetime.fromtimestamp(now) - datetime.fromtimestamp(self.__last_mtime)
|
||||
diff_hours = diff.total_seconds() // 3600
|
||||
if diff_hours >= 8:
|
||||
LOGGER.info("GeoIP databased loaded too long, re-opening", diff=diff)
|
||||
self.__open()
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
"""Check if GeoIP is enabled"""
|
||||
return bool(self.__reader)
|
||||
|
||||
def city(self, ip_address: str) -> Optional[City]:
|
||||
"""Wrapper for Reader.city"""
|
||||
if not self.enabled:
|
||||
return None
|
||||
self.__check_expired()
|
||||
try:
|
||||
return self.__reader.city(ip_address)
|
||||
except (GeoIP2Error, ValueError):
|
||||
return None
|
||||
|
||||
def city_dict(self, ip_address: str) -> Optional[GeoIPDict]:
|
||||
"""Wrapper for self.city that returns a dict"""
|
||||
city = self.city(ip_address)
|
||||
if not city:
|
||||
return None
|
||||
city_dict: GeoIPDict = {
|
||||
"continent": city.continent.code,
|
||||
"country": city.country.iso_code,
|
||||
"lat": city.location.latitude,
|
||||
"long": city.location.longitude,
|
||||
"city": "",
|
||||
}
|
||||
if city.city.name:
|
||||
city_dict["city"] = city.city.name
|
||||
return city_dict
|
||||
|
||||
GEOIP_READER = get_geoip_reader()
|
||||
|
||||
GEOIP_READER = GeoIPReader()
|
||||
|
|
|
@ -10,7 +10,6 @@ from django.db import models
|
|||
from django.http import HttpRequest
|
||||
from django.utils.timezone import now
|
||||
from django.utils.translation import gettext as _
|
||||
from geoip2.errors import GeoIP2Error
|
||||
from prometheus_client import Gauge
|
||||
from requests import RequestException, post
|
||||
from structlog.stdlib import get_logger
|
||||
|
@ -160,20 +159,10 @@ class Event(ExpiringModel):
|
|||
|
||||
def with_geoip(self): # pragma: no cover
|
||||
"""Apply GeoIP Data, when enabled"""
|
||||
if not GEOIP_READER:
|
||||
city = GEOIP_READER.city_dict(self.client_ip)
|
||||
if not city:
|
||||
return
|
||||
try:
|
||||
response = GEOIP_READER.city(self.client_ip)
|
||||
self.context["geo"] = {
|
||||
"continent": response.continent.code,
|
||||
"country": response.country.iso_code,
|
||||
"lat": response.location.latitude,
|
||||
"long": response.location.longitude,
|
||||
}
|
||||
if response.city.name:
|
||||
self.context["geo"]["city"] = response.city.name
|
||||
except (GeoIP2Error, ValueError) as exc:
|
||||
LOGGER.warning("Failed to add geoIP Data to event", exc=exc)
|
||||
self.context["geo"] = city
|
||||
|
||||
def _set_prom_metrics(self):
|
||||
GAUGE_EVENTS.labels(
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
"""Test GeoIP Wrapper"""
|
||||
from django.test import TestCase
|
||||
|
||||
from authentik.events.geo import GeoIPReader
|
||||
|
||||
|
||||
class TestGeoIP(TestCase):
|
||||
"""Test GeoIP Wrapper"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.reader = GeoIPReader()
|
||||
|
||||
def test_simple(self):
|
||||
"""Test simple city wrapper"""
|
||||
# IPs from
|
||||
# https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json
|
||||
self.assertEqual(
|
||||
self.reader.city_dict("2.125.160.216"),
|
||||
{
|
||||
"city": "Boxford",
|
||||
"continent": "EU",
|
||||
"country": "GB",
|
||||
"lat": 51.75,
|
||||
"long": -1.25,
|
||||
},
|
||||
)
|
|
@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Any, Optional
|
|||
|
||||
from django.db.models import Model
|
||||
from django.http import HttpRequest
|
||||
from geoip2.errors import GeoIP2Error
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.events.geo import GEOIP_READER
|
||||
|
@ -39,16 +38,12 @@ class PolicyRequest:
|
|||
def set_http_request(self, request: HttpRequest): # pragma: no cover
|
||||
"""Load data from HTTP request, including geoip when enabled"""
|
||||
self.http_request = request
|
||||
if not GEOIP_READER:
|
||||
if not GEOIP_READER.enabled:
|
||||
return
|
||||
try:
|
||||
client_ip = get_client_ip(request)
|
||||
if not client_ip:
|
||||
return
|
||||
response = GEOIP_READER.city(client_ip)
|
||||
self.context["geoip"] = response
|
||||
except (GeoIP2Error, ValueError) as exc:
|
||||
LOGGER.warning("failed to get geoip data", exc=exc)
|
||||
self.context["geoip"] = GEOIP_READER.city(client_ip)
|
||||
|
||||
def __str__(self):
|
||||
text = f"<PolicyRequest user={self.user}"
|
||||
|
|
|
@ -14,6 +14,7 @@ class PytestTestRunner: # pragma: no cover
|
|||
settings.TEST = True
|
||||
settings.CELERY_TASK_ALWAYS_EAGER = True
|
||||
CONFIG.y_set("authentik.avatars", "none")
|
||||
CONFIG.y_set("authentik.geoip", "tests/GeoLite2-City-Test.mmdb")
|
||||
|
||||
def run_tests(self, test_labels):
|
||||
"""Run pytest and return the exitcode.
|
||||
|
|
Reference in New Issue