diff --git a/authentik/events/tasks.py b/authentik/events/tasks.py index ecd5c6926..813803c11 100644 --- a/authentik/events/tasks.py +++ b/authentik/events/tasks.py @@ -9,7 +9,7 @@ from authentik.events.models import ( NotificationTrigger, ) from authentik.lib.tasks import MonitoredTask, TaskResult, TaskResultStatus -from authentik.policies.engine import PolicyEngine +from authentik.policies.engine import PolicyEngine, PolicyEngineMode from authentik.root.celery import CELERY_APP LOGGER = get_logger() @@ -43,6 +43,8 @@ def event_trigger_handler(event_uuid: str, trigger_name: str): return policy_engine = PolicyEngine(trigger, get_anonymous_user()) + policy_engine.mode = PolicyEngineMode.MODE_OR + policy_engine.empty_result = False policy_engine.request.context["event"] = event policy_engine.build() result = policy_engine.result diff --git a/authentik/events/tests/test_notifications.py b/authentik/events/tests/test_notifications.py index 2dbc93f6a..de4e29f60 100644 --- a/authentik/events/tests/test_notifications.py +++ b/authentik/events/tests/test_notifications.py @@ -25,6 +25,18 @@ class TestEventsNotifications(TestCase): self.group.users.add(self.user) self.group.save() + def test_trigger_empty(self): + """Test trigger without any policies attached""" + transport = NotificationTransport.objects.create(name="transport") + trigger = NotificationTrigger.objects.create(name="trigger", group=self.group) + trigger.transports.add(transport) + trigger.save() + + execute_mock = MagicMock() + with patch("authentik.events.models.NotificationTransport.send", execute_mock): + Event.new(EventAction.CUSTOM_PREFIX).save() + self.assertEqual(execute_mock.call_count, 0) + def test_trigger_single(self): """Test simple transport triggering""" transport = NotificationTransport.objects.create(name="transport") diff --git a/authentik/policies/engine.py b/authentik/policies/engine.py index d1b79c6b8..a1c83bb83 100644 --- a/authentik/policies/engine.py +++ b/authentik/policies/engine.py @@ -1,4 +1,5 @@ """authentik policy engine""" +from enum import Enum from multiprocessing import Pipe, set_start_method from multiprocessing.connection import Connection from typing import Iterator, List, Optional @@ -37,12 +38,23 @@ class PolicyProcessInfo: self.result = None +class PolicyEngineMode(Enum): + """Decide how results of multiple policies should be combined.""" + + MODE_AND = "and" + MODE_OR = "or" + + class PolicyEngine: """Orchestrate policy checking, launch tasks and return result""" use_cache: bool request: PolicyRequest + mode: PolicyEngineMode + # Allow objects with no policies attached to pass + empty_result: bool + __pbm: PolicyBindingModel __cached_policies: List[PolicyResult] __processes: List[PolicyProcessInfo] @@ -52,6 +64,10 @@ class PolicyEngine: def __init__( self, pbm: PolicyBindingModel, user: User, request: HttpRequest = None ): + self.mode = PolicyEngineMode.MODE_AND + # For backwards compatibility, set empty_result to true + # objects with no policies attached will pass. + self.empty_result = True if not isinstance(pbm, PolicyBindingModel): # pragma: no cover raise ValueError(f"{pbm} is not instance of PolicyBindingModel") self.__pbm = pbm @@ -119,24 +135,19 @@ class PolicyEngine: x.result for x in self.__processes if x.result ] all_results = list(process_results + self.__cached_policies) - final_result = PolicyResult(False) - final_result.messages = [] - final_result.source_results = all_results if len(all_results) < self.__expected_result_count: # pragma: no cover raise AssertionError("Got less results than polices") - for result in all_results: - LOGGER.debug( - "P_ENG: result", passing=result.passing, messages=result.messages - ) - if result.messages: - final_result.messages.extend(result.messages) - if not result.passing: - final_result.messages = tuple(final_result.messages) - final_result.passing = False - return final_result - final_result.messages = tuple(final_result.messages) - final_result.passing = True - return final_result + # No results, no policies attached -> passing + if len(all_results) == 0: + return PolicyResult(self.empty_result) + passing = False + if self.mode == PolicyEngineMode.MODE_AND: + passing = all([x.passing for x in all_results]) + if self.mode == PolicyEngineMode.MODE_OR: + passing = any([x.passing for x in all_results]) + result = PolicyResult(passing) + result.messages = tuple([y for x in all_results for y in x.messages]) + return result @property def passing(self) -> bool: diff --git a/authentik/policies/tests/test_engine.py b/authentik/policies/tests/test_engine.py index d59825c4e..e58dc27ca 100644 --- a/authentik/policies/tests/test_engine.py +++ b/authentik/policies/tests/test_engine.py @@ -4,7 +4,7 @@ from django.test import TestCase from authentik.core.models import User from authentik.policies.dummy.models import DummyPolicy -from authentik.policies.engine import PolicyEngine +from authentik.policies.engine import PolicyEngine, PolicyEngineMode from authentik.policies.expression.models import ExpressionPolicy from authentik.policies.models import Policy, PolicyBinding, PolicyBindingModel from authentik.policies.tests.test_process import clear_policy_cache @@ -44,15 +44,38 @@ class TestPolicyEngine(TestCase): self.assertEqual(result.passing, True) self.assertEqual(result.messages, ("dummy",)) - def test_engine(self): - """Ensure all policies passes (Mix of false and true -> false)""" + def test_engine_mode_and(self): + """Ensure all policies passes with AND mode (false and true -> false)""" pbm = PolicyBindingModel.objects.create() PolicyBinding.objects.create(target=pbm, policy=self.policy_false, order=0) PolicyBinding.objects.create(target=pbm, policy=self.policy_true, order=1) engine = PolicyEngine(pbm, self.user) result = engine.build().result self.assertEqual(result.passing, False) - self.assertEqual(result.messages, ("dummy",)) + self.assertEqual( + result.messages, + ( + "dummy", + "dummy", + ), + ) + + def test_engine_mode_or(self): + """Ensure all policies passes with OR mode (false and true -> true)""" + pbm = PolicyBindingModel.objects.create() + PolicyBinding.objects.create(target=pbm, policy=self.policy_false, order=0) + PolicyBinding.objects.create(target=pbm, policy=self.policy_true, order=1) + engine = PolicyEngine(pbm, self.user) + engine.mode = PolicyEngineMode.MODE_OR + result = engine.build().result + self.assertEqual(result.passing, True) + self.assertEqual( + result.messages, + ( + "dummy", + "dummy", + ), + ) def test_engine_negate(self): """Test negate flag"""