diff --git a/authentik/events/monitored_tasks.py b/authentik/events/monitored_tasks.py index 772e6ccf9..9acee40b6 100644 --- a/authentik/events/monitored_tasks.py +++ b/authentik/events/monitored_tasks.py @@ -76,9 +76,20 @@ class TaskInfo: return cache.get_many(cache.keys(CACHE_KEY_PREFIX + name)).values() return cache.get(CACHE_KEY_PREFIX + name, None) + @property + def full_name(self) -> str: + """Get the full cache key with task name and UID""" + key = CACHE_KEY_PREFIX + self.task_name + if self.result.uid: + uid_suffix = f":{self.result.uid}" + key += uid_suffix + if not self.task_name.endswith(uid_suffix): + self.task_name += uid_suffix + return key + def delete(self): """Delete task info from cache""" - return cache.delete(CACHE_KEY_PREFIX + self.task_name) + return cache.delete(self.full_name) def update_metrics(self): """Update prometheus metrics""" @@ -97,12 +108,8 @@ class TaskInfo: def save(self, timeout_hours=6): """Save task into cache""" - key = CACHE_KEY_PREFIX + self.task_name - if self.result.uid: - key += f":{self.result.uid}" - self.task_name += f":{self.result.uid}" self.update_metrics() - cache.set(key, self, timeout=timeout_hours * 60 * 60) + cache.set(self.full_name, self, timeout=timeout_hours * 60 * 60) class MonitoredTask(Task): diff --git a/authentik/events/tests/test_tasks.py b/authentik/events/tests/test_tasks.py new file mode 100644 index 000000000..58dad6556 --- /dev/null +++ b/authentik/events/tests/test_tasks.py @@ -0,0 +1,43 @@ +"""Test Monitored tasks""" +from django.test import TestCase + +from authentik.events.monitored_tasks import MonitoredTask, TaskInfo, TaskResult, TaskResultStatus +from authentik.lib.generators import generate_id +from authentik.root.celery import CELERY_APP + + +class TestMonitoredTasks(TestCase): + """Test Monitored tasks""" + + def test_failed_successful_remove_state(self): + """Test that a task with `save_on_success` set to `False` that failed saves + a state, and upon successful completion will delete the state""" + should_fail = True + uid = generate_id() + + @CELERY_APP.task( + bind=True, + base=MonitoredTask, + ) + def test_task(self: MonitoredTask): + self.save_on_success = False + self.set_uid(uid) + self.set_status( + TaskResult(TaskResultStatus.ERROR if should_fail else TaskResultStatus.SUCCESSFUL) + ) + + # First test successful run + should_fail = False + test_task.delay().get() + self.assertIsNone(TaskInfo.by_name(f"test_task:{uid}")) + + # Then test failed + should_fail = True + test_task.delay().get() + info = TaskInfo.by_name(f"test_task:{uid}") + self.assertEqual(info.result.status, TaskResultStatus.ERROR) + + # Then after that, the state should be removed + should_fail = False + test_task.delay().get() + self.assertIsNone(TaskInfo.by_name(f"test_task:{uid}"))