policies: add and/or mode (#463)
* policies: add mode to PolicyEngine for AND and OR modes * events: use PolicyEngine in OR mode
This commit is contained in:
parent
b2b737e59e
commit
c727c845df
|
@ -9,7 +9,7 @@ from authentik.events.models import (
|
||||||
NotificationTrigger,
|
NotificationTrigger,
|
||||||
)
|
)
|
||||||
from authentik.lib.tasks import MonitoredTask, TaskResult, TaskResultStatus
|
from authentik.lib.tasks import MonitoredTask, TaskResult, TaskResultStatus
|
||||||
from authentik.policies.engine import PolicyEngine
|
from authentik.policies.engine import PolicyEngine, PolicyEngineMode
|
||||||
from authentik.root.celery import CELERY_APP
|
from authentik.root.celery import CELERY_APP
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
|
@ -43,6 +43,8 @@ def event_trigger_handler(event_uuid: str, trigger_name: str):
|
||||||
return
|
return
|
||||||
|
|
||||||
policy_engine = PolicyEngine(trigger, get_anonymous_user())
|
policy_engine = PolicyEngine(trigger, get_anonymous_user())
|
||||||
|
policy_engine.mode = PolicyEngineMode.MODE_OR
|
||||||
|
policy_engine.empty_result = False
|
||||||
policy_engine.request.context["event"] = event
|
policy_engine.request.context["event"] = event
|
||||||
policy_engine.build()
|
policy_engine.build()
|
||||||
result = policy_engine.result
|
result = policy_engine.result
|
||||||
|
|
|
@ -25,6 +25,18 @@ class TestEventsNotifications(TestCase):
|
||||||
self.group.users.add(self.user)
|
self.group.users.add(self.user)
|
||||||
self.group.save()
|
self.group.save()
|
||||||
|
|
||||||
|
def test_trigger_empty(self):
|
||||||
|
"""Test trigger without any policies attached"""
|
||||||
|
transport = NotificationTransport.objects.create(name="transport")
|
||||||
|
trigger = NotificationTrigger.objects.create(name="trigger", group=self.group)
|
||||||
|
trigger.transports.add(transport)
|
||||||
|
trigger.save()
|
||||||
|
|
||||||
|
execute_mock = MagicMock()
|
||||||
|
with patch("authentik.events.models.NotificationTransport.send", execute_mock):
|
||||||
|
Event.new(EventAction.CUSTOM_PREFIX).save()
|
||||||
|
self.assertEqual(execute_mock.call_count, 0)
|
||||||
|
|
||||||
def test_trigger_single(self):
|
def test_trigger_single(self):
|
||||||
"""Test simple transport triggering"""
|
"""Test simple transport triggering"""
|
||||||
transport = NotificationTransport.objects.create(name="transport")
|
transport = NotificationTransport.objects.create(name="transport")
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
"""authentik policy engine"""
|
"""authentik policy engine"""
|
||||||
|
from enum import Enum
|
||||||
from multiprocessing import Pipe, set_start_method
|
from multiprocessing import Pipe, set_start_method
|
||||||
from multiprocessing.connection import Connection
|
from multiprocessing.connection import Connection
|
||||||
from typing import Iterator, List, Optional
|
from typing import Iterator, List, Optional
|
||||||
|
@ -37,12 +38,23 @@ class PolicyProcessInfo:
|
||||||
self.result = None
|
self.result = None
|
||||||
|
|
||||||
|
|
||||||
|
class PolicyEngineMode(Enum):
|
||||||
|
"""Decide how results of multiple policies should be combined."""
|
||||||
|
|
||||||
|
MODE_AND = "and"
|
||||||
|
MODE_OR = "or"
|
||||||
|
|
||||||
|
|
||||||
class PolicyEngine:
|
class PolicyEngine:
|
||||||
"""Orchestrate policy checking, launch tasks and return result"""
|
"""Orchestrate policy checking, launch tasks and return result"""
|
||||||
|
|
||||||
use_cache: bool
|
use_cache: bool
|
||||||
request: PolicyRequest
|
request: PolicyRequest
|
||||||
|
|
||||||
|
mode: PolicyEngineMode
|
||||||
|
# Allow objects with no policies attached to pass
|
||||||
|
empty_result: bool
|
||||||
|
|
||||||
__pbm: PolicyBindingModel
|
__pbm: PolicyBindingModel
|
||||||
__cached_policies: List[PolicyResult]
|
__cached_policies: List[PolicyResult]
|
||||||
__processes: List[PolicyProcessInfo]
|
__processes: List[PolicyProcessInfo]
|
||||||
|
@ -52,6 +64,10 @@ class PolicyEngine:
|
||||||
def __init__(
|
def __init__(
|
||||||
self, pbm: PolicyBindingModel, user: User, request: HttpRequest = None
|
self, pbm: PolicyBindingModel, user: User, request: HttpRequest = None
|
||||||
):
|
):
|
||||||
|
self.mode = PolicyEngineMode.MODE_AND
|
||||||
|
# For backwards compatibility, set empty_result to true
|
||||||
|
# objects with no policies attached will pass.
|
||||||
|
self.empty_result = True
|
||||||
if not isinstance(pbm, PolicyBindingModel): # pragma: no cover
|
if not isinstance(pbm, PolicyBindingModel): # pragma: no cover
|
||||||
raise ValueError(f"{pbm} is not instance of PolicyBindingModel")
|
raise ValueError(f"{pbm} is not instance of PolicyBindingModel")
|
||||||
self.__pbm = pbm
|
self.__pbm = pbm
|
||||||
|
@ -119,24 +135,19 @@ class PolicyEngine:
|
||||||
x.result for x in self.__processes if x.result
|
x.result for x in self.__processes if x.result
|
||||||
]
|
]
|
||||||
all_results = list(process_results + self.__cached_policies)
|
all_results = list(process_results + self.__cached_policies)
|
||||||
final_result = PolicyResult(False)
|
|
||||||
final_result.messages = []
|
|
||||||
final_result.source_results = all_results
|
|
||||||
if len(all_results) < self.__expected_result_count: # pragma: no cover
|
if len(all_results) < self.__expected_result_count: # pragma: no cover
|
||||||
raise AssertionError("Got less results than polices")
|
raise AssertionError("Got less results than polices")
|
||||||
for result in all_results:
|
# No results, no policies attached -> passing
|
||||||
LOGGER.debug(
|
if len(all_results) == 0:
|
||||||
"P_ENG: result", passing=result.passing, messages=result.messages
|
return PolicyResult(self.empty_result)
|
||||||
)
|
passing = False
|
||||||
if result.messages:
|
if self.mode == PolicyEngineMode.MODE_AND:
|
||||||
final_result.messages.extend(result.messages)
|
passing = all([x.passing for x in all_results])
|
||||||
if not result.passing:
|
if self.mode == PolicyEngineMode.MODE_OR:
|
||||||
final_result.messages = tuple(final_result.messages)
|
passing = any([x.passing for x in all_results])
|
||||||
final_result.passing = False
|
result = PolicyResult(passing)
|
||||||
return final_result
|
result.messages = tuple([y for x in all_results for y in x.messages])
|
||||||
final_result.messages = tuple(final_result.messages)
|
return result
|
||||||
final_result.passing = True
|
|
||||||
return final_result
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def passing(self) -> bool:
|
def passing(self) -> bool:
|
||||||
|
|
|
@ -4,7 +4,7 @@ from django.test import TestCase
|
||||||
|
|
||||||
from authentik.core.models import User
|
from authentik.core.models import User
|
||||||
from authentik.policies.dummy.models import DummyPolicy
|
from authentik.policies.dummy.models import DummyPolicy
|
||||||
from authentik.policies.engine import PolicyEngine
|
from authentik.policies.engine import PolicyEngine, PolicyEngineMode
|
||||||
from authentik.policies.expression.models import ExpressionPolicy
|
from authentik.policies.expression.models import ExpressionPolicy
|
||||||
from authentik.policies.models import Policy, PolicyBinding, PolicyBindingModel
|
from authentik.policies.models import Policy, PolicyBinding, PolicyBindingModel
|
||||||
from authentik.policies.tests.test_process import clear_policy_cache
|
from authentik.policies.tests.test_process import clear_policy_cache
|
||||||
|
@ -44,15 +44,38 @@ class TestPolicyEngine(TestCase):
|
||||||
self.assertEqual(result.passing, True)
|
self.assertEqual(result.passing, True)
|
||||||
self.assertEqual(result.messages, ("dummy",))
|
self.assertEqual(result.messages, ("dummy",))
|
||||||
|
|
||||||
def test_engine(self):
|
def test_engine_mode_and(self):
|
||||||
"""Ensure all policies passes (Mix of false and true -> false)"""
|
"""Ensure all policies passes with AND mode (false and true -> false)"""
|
||||||
pbm = PolicyBindingModel.objects.create()
|
pbm = PolicyBindingModel.objects.create()
|
||||||
PolicyBinding.objects.create(target=pbm, policy=self.policy_false, order=0)
|
PolicyBinding.objects.create(target=pbm, policy=self.policy_false, order=0)
|
||||||
PolicyBinding.objects.create(target=pbm, policy=self.policy_true, order=1)
|
PolicyBinding.objects.create(target=pbm, policy=self.policy_true, order=1)
|
||||||
engine = PolicyEngine(pbm, self.user)
|
engine = PolicyEngine(pbm, self.user)
|
||||||
result = engine.build().result
|
result = engine.build().result
|
||||||
self.assertEqual(result.passing, False)
|
self.assertEqual(result.passing, False)
|
||||||
self.assertEqual(result.messages, ("dummy",))
|
self.assertEqual(
|
||||||
|
result.messages,
|
||||||
|
(
|
||||||
|
"dummy",
|
||||||
|
"dummy",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_engine_mode_or(self):
|
||||||
|
"""Ensure all policies passes with OR mode (false and true -> true)"""
|
||||||
|
pbm = PolicyBindingModel.objects.create()
|
||||||
|
PolicyBinding.objects.create(target=pbm, policy=self.policy_false, order=0)
|
||||||
|
PolicyBinding.objects.create(target=pbm, policy=self.policy_true, order=1)
|
||||||
|
engine = PolicyEngine(pbm, self.user)
|
||||||
|
engine.mode = PolicyEngineMode.MODE_OR
|
||||||
|
result = engine.build().result
|
||||||
|
self.assertEqual(result.passing, True)
|
||||||
|
self.assertEqual(
|
||||||
|
result.messages,
|
||||||
|
(
|
||||||
|
"dummy",
|
||||||
|
"dummy",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
def test_engine_negate(self):
|
def test_engine_negate(self):
|
||||||
"""Test negate flag"""
|
"""Test negate flag"""
|
||||||
|
|
Reference in New Issue