diff --git a/authentik/admin/tests/test_api.py b/authentik/admin/tests/test_api.py index 8ef016b88..f7bd03ff5 100644 --- a/authentik/admin/tests/test_api.py +++ b/authentik/admin/tests/test_api.py @@ -7,8 +7,6 @@ from django.urls import reverse from authentik import __version__ from authentik.blueprints.tests import reconcile_app from authentik.core.models import Group, User -from authentik.core.tasks import clean_expired_models -from authentik.events.monitored_tasks import TaskStatus from authentik.lib.generators import generate_id @@ -23,53 +21,6 @@ class TestAdminAPI(TestCase): self.group.save() self.client.force_login(self.user) - def test_tasks(self): - """Test Task API""" - clean_expired_models.delay() - response = self.client.get(reverse("authentik_api:admin_system_tasks-list")) - self.assertEqual(response.status_code, 200) - body = loads(response.content) - self.assertTrue(any(task["task_name"] == "clean_expired_models" for task in body)) - - def test_tasks_single(self): - """Test Task API (read single)""" - clean_expired_models.delay() - response = self.client.get( - reverse( - "authentik_api:admin_system_tasks-detail", - kwargs={"pk": "clean_expired_models"}, - ) - ) - self.assertEqual(response.status_code, 200) - body = loads(response.content) - self.assertEqual(body["status"], TaskStatus.SUCCESSFUL.name) - self.assertEqual(body["task_name"], "clean_expired_models") - response = self.client.get( - reverse("authentik_api:admin_system_tasks-detail", kwargs={"pk": "qwerqwer"}) - ) - self.assertEqual(response.status_code, 404) - - def test_tasks_retry(self): - """Test Task API (retry)""" - clean_expired_models.delay() - response = self.client.post( - reverse( - "authentik_api:admin_system_tasks-retry", - kwargs={"pk": "clean_expired_models"}, - ) - ) - self.assertEqual(response.status_code, 204) - - def test_tasks_retry_404(self): - """Test Task API (retry, 404)""" - response = self.client.post( - reverse( - "authentik_api:admin_system_tasks-retry", - kwargs={"pk": "qwerqewrqrqewrqewr"}, - ) - ) - self.assertEqual(response.status_code, 404) - def test_version(self): """Test Version API""" response = self.client.get(reverse("authentik_api:admin_version")) diff --git a/authentik/events/tests/test_tasks.py b/authentik/events/tests/test_tasks.py index d343ebdb1..8ba898753 100644 --- a/authentik/events/tests/test_tasks.py +++ b/authentik/events/tests/test_tasks.py @@ -1,15 +1,25 @@ """Test Monitored tasks""" -from django.test import TestCase +from json import loads +from django.urls import reverse +from rest_framework.test import APITestCase + +from authentik.core.tasks import clean_expired_models +from authentik.core.tests.utils import create_test_admin_user from authentik.events.models import SystemTask, TaskStatus from authentik.events.monitored_tasks import MonitoredTask from authentik.lib.generators import generate_id from authentik.root.celery import CELERY_APP -class TestMonitoredTasks(TestCase): +class TestSystemTasks(APITestCase): """Test Monitored tasks""" + def setUp(self): + super().setUp() + self.user = create_test_admin_user() + self.client.force_login(self.user) + def test_failed_successful_remove_state(self): """Test that a task with `save_on_success` set to `False` that failed saves a state, and upon successful completion will delete the state""" @@ -28,15 +38,64 @@ class TestMonitoredTasks(TestCase): # First test successful run should_fail = False test_task.delay().get() - self.assertIsNone(SystemTask.objects.filter(name="test_task", uid=uid)) + self.assertIsNone(SystemTask.objects.filter(name="test_task", uid=uid).first()) # Then test failed should_fail = True test_task.delay().get() - info = SystemTask.objects.filter(name="test_task", uid=uid) - self.assertEqual(info.status, TaskStatus.ERROR) + task = SystemTask.objects.filter(name="test_task", uid=uid).first() + self.assertEqual(task.status, TaskStatus.ERROR) # Then after that, the state should be removed should_fail = False test_task.delay().get() - self.assertIsNone(SystemTask.objects.filter(name="test_task", uid=uid)) + self.assertIsNone(SystemTask.objects.filter(name="test_task", uid=uid).first()) + + def test_tasks(self): + """Test Task API""" + clean_expired_models.delay().get() + response = self.client.get(reverse("authentik_api:systemtask-list")) + self.assertEqual(response.status_code, 200) + body = loads(response.content) + self.assertTrue(any(task["name"] == "clean_expired_models" for task in body["results"])) + + def test_tasks_single(self): + """Test Task API (read single)""" + clean_expired_models.delay().get() + task = SystemTask.objects.filter(name="clean_expired_models").first() + response = self.client.get( + reverse( + "authentik_api:systemtask-detail", + kwargs={"pk": str(task.pk)}, + ) + ) + self.assertEqual(response.status_code, 200) + body = loads(response.content) + self.assertEqual(body["status"], TaskStatus.SUCCESSFUL.value) + self.assertEqual(body["name"], "clean_expired_models") + response = self.client.get( + reverse("authentik_api:systemtask-detail", kwargs={"pk": "qwerqwer"}) + ) + self.assertEqual(response.status_code, 404) + + def test_tasks_run(self): + """Test Task API (run)""" + clean_expired_models.delay().get() + task = SystemTask.objects.filter(name="clean_expired_models").first() + response = self.client.post( + reverse( + "authentik_api:systemtask-run", + kwargs={"pk": str(task.pk)}, + ) + ) + self.assertEqual(response.status_code, 204) + + def test_tasks_run_404(self): + """Test Task API (run, 404)""" + response = self.client.post( + reverse( + "authentik_api:systemtask-run", + kwargs={"pk": "qwerqewrqrqewrqewr"}, + ) + ) + self.assertEqual(response.status_code, 404) diff --git a/authentik/sources/ldap/tasks.py b/authentik/sources/ldap/tasks.py index dcf6bff5d..6931affa2 100644 --- a/authentik/sources/ldap/tasks.py +++ b/authentik/sources/ldap/tasks.py @@ -34,7 +34,7 @@ CACHE_KEY_STATUS = "goauthentik.io/sources/ldap/status/" def ldap_sync_all(): """Sync all sources""" for source in LDAPSource.objects.filter(enabled=True): - ldap_sync_single.apply_async(args=[source.pk]) + ldap_sync_single.apply_async(args=[str(source.pk)]) @CELERY_APP.task() @@ -95,7 +95,7 @@ def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) -> for page in sync_inst.get_objects(): page_cache_key = CACHE_KEY_PREFIX + str(uuid4()) cache.set(page_cache_key, page, 60 * 60 * CONFIG.get_int("ldap.task_timeout_hours")) - page_sync = ldap_sync.si(source.pk, class_to_path(sync), page_cache_key) + page_sync = ldap_sync.si(str(source.pk), class_to_path(sync), page_cache_key) signatures.append(page_sync) return signatures diff --git a/authentik/sources/ldap/tests/test_sync.py b/authentik/sources/ldap/tests/test_sync.py index d5c8372c4..32c042be5 100644 --- a/authentik/sources/ldap/tests/test_sync.py +++ b/authentik/sources/ldap/tests/test_sync.py @@ -40,7 +40,7 @@ class LDAPSyncTests(TestCase): """Test sync with missing page""" connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD)) with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): - ldap_sync.delay(self.source.pk, class_to_path(UserLDAPSynchronizer), "foo").get() + ldap_sync.delay(str(self.source.pk), class_to_path(UserLDAPSynchronizer), "foo").get() task = SystemTask.objects.filter(name="ldap_sync", uid="ldap:users:foo").first() self.assertEqual(task.status, TaskStatus.ERROR) diff --git a/authentik/stages/email/tasks.py b/authentik/stages/email/tasks.py index 0d39e4758..51a0c6260 100644 --- a/authentik/stages/email/tasks.py +++ b/authentik/stages/email/tasks.py @@ -22,7 +22,7 @@ def send_mails(stage: EmailStage, *messages: list[EmailMultiAlternatives]): """Wrapper to convert EmailMessage to dict and send it from worker""" tasks = [] for message in messages: - tasks.append(send_mail.s(message.__dict__, stage.pk)) + tasks.append(send_mail.s(message.__dict__, str(stage.pk))) lazy_group = group(*tasks) promise = lazy_group() return promise @@ -46,7 +46,7 @@ def get_email_body(email: EmailMultiAlternatives) -> str: retry_backoff=True, base=MonitoredTask, ) -def send_mail(self: MonitoredTask, message: dict[Any, Any], email_stage_pk: Optional[int] = None): +def send_mail(self: MonitoredTask, message: dict[Any, Any], email_stage_pk: Optional[str] = None): """Send Email for Email Stage. Retries are scheduled automatically.""" self.save_on_success = False message_id = make_msgid(domain=DNS_NAME)