diff --git a/authentik/core/api/applications.py b/authentik/core/api/applications.py index 784497bd0..1c6437d7f 100644 --- a/authentik/core/api/applications.py +++ b/authentik/core/api/applications.py @@ -52,6 +52,7 @@ class ApplicationSerializer(ModelSerializer): "meta_icon", "meta_description", "meta_publisher", + "policy_engine_mode", ] diff --git a/authentik/core/api/sources.py b/authentik/core/api/sources.py index 3fa14ad3e..94b31c92b 100644 --- a/authentik/core/api/sources.py +++ b/authentik/core/api/sources.py @@ -43,6 +43,7 @@ class SourceSerializer(ModelSerializer, MetaNameSerializer): "object_type", "verbose_name", "verbose_name_plural", + "policy_engine_mode", ] diff --git a/authentik/events/tasks.py b/authentik/events/tasks.py index aa9793bf5..9822bf02c 100644 --- a/authentik/events/tasks.py +++ b/authentik/events/tasks.py @@ -11,8 +11,8 @@ from authentik.events.models import ( NotificationTransportError, ) from authentik.events.monitored_tasks import MonitoredTask, TaskResult, TaskResultStatus -from authentik.policies.engine import PolicyEngine, PolicyEngineMode -from authentik.policies.models import PolicyBinding +from authentik.policies.engine import PolicyEngine +from authentik.policies.models import PolicyBinding, PolicyEngineMode from authentik.root.celery import CELERY_APP LOGGER = get_logger() @@ -60,7 +60,7 @@ def event_trigger_handler(event_uuid: str, trigger_name: str): LOGGER.debug("e(trigger): checking if trigger applies", trigger=trigger) user = User.objects.filter(pk=event.user.get("pk")).first() or get_anonymous_user() policy_engine = PolicyEngine(trigger, user) - policy_engine.mode = PolicyEngineMode.MODE_OR + policy_engine.mode = PolicyEngineMode.MODE_ANY policy_engine.empty_result = False policy_engine.use_cache = False policy_engine.request.context["event"] = event diff --git a/authentik/flows/api/bindings.py b/authentik/flows/api/bindings.py index 32d127ffb..4a1f93a49 100644 --- a/authentik/flows/api/bindings.py +++ b/authentik/flows/api/bindings.py @@ -23,7 +23,7 @@ class FlowStageBindingSerializer(ModelSerializer): "evaluate_on_plan", "re_evaluate_policies", "order", - "policies", + "policy_engine_mode", ] diff --git a/authentik/flows/api/flows.py b/authentik/flows/api/flows.py index 0923d6c88..b2b6d84e4 100644 --- a/authentik/flows/api/flows.py +++ b/authentik/flows/api/flows.py @@ -59,6 +59,7 @@ class FlowSerializer(ModelSerializer): "stages", "policies", "cache_count", + "policy_engine_mode", ] diff --git a/authentik/flows/forms.py b/authentik/flows/forms.py index f8082aa67..79b4a2b31 100644 --- a/authentik/flows/forms.py +++ b/authentik/flows/forms.py @@ -27,6 +27,7 @@ class FlowStageBindingForm(forms.ModelForm): "evaluate_on_plan", "re_evaluate_policies", "order", + "policy_engine_mode", ] widgets = { "name": forms.TextInput(), diff --git a/authentik/policies/engine.py b/authentik/policies/engine.py index 37dc27385..7019af158 100644 --- a/authentik/policies/engine.py +++ b/authentik/policies/engine.py @@ -1,5 +1,4 @@ """authentik policy engine""" -from enum import Enum from multiprocessing import Pipe, current_process from multiprocessing.connection import Connection from typing import Iterator, Optional @@ -11,7 +10,12 @@ from sentry_sdk.tracing import Span from structlog.stdlib import BoundLogger, get_logger from authentik.core.models import User -from authentik.policies.models import Policy, PolicyBinding, PolicyBindingModel +from authentik.policies.models import ( + Policy, + PolicyBinding, + PolicyBindingModel, + PolicyEngineMode, +) from authentik.policies.process import PolicyProcess, cache_key from authentik.policies.types import PolicyRequest, PolicyResult @@ -35,13 +39,6 @@ class PolicyProcessInfo: self.result = None -class PolicyEngineMode(Enum): - """Decide how results of multiple policies should be combined.""" - - MODE_AND = "and" - MODE_OR = "or" - - class PolicyEngine: """Orchestrate policy checking, launch tasks and return result""" @@ -63,7 +60,7 @@ class PolicyEngine: self, pbm: PolicyBindingModel, user: User, request: HttpRequest = None ): self.logger = get_logger().bind() - self.mode = PolicyEngineMode.MODE_AND + self.mode = pbm.policy_engine_mode # For backwards compatibility, set empty_result to true # objects with no policies attached will pass. self.empty_result = True @@ -147,9 +144,9 @@ class PolicyEngine: if len(all_results) == 0: return PolicyResult(self.empty_result) passing = False - if self.mode == PolicyEngineMode.MODE_AND: + if self.mode == PolicyEngineMode.MODE_ALL: passing = all(x.passing for x in all_results) - if self.mode == PolicyEngineMode.MODE_OR: + if self.mode == PolicyEngineMode.MODE_ANY: passing = any(x.passing for x in all_results) result = PolicyResult(passing) result.messages = tuple(y for x in all_results for y in x.messages) diff --git a/authentik/policies/migrations/0007_policybindingmodel_policy_engine_mode.py b/authentik/policies/migrations/0007_policybindingmodel_policy_engine_mode.py new file mode 100644 index 000000000..dd9438b22 --- /dev/null +++ b/authentik/policies/migrations/0007_policybindingmodel_policy_engine_mode.py @@ -0,0 +1,37 @@ +# Generated by Django 3.1.7 on 2021-03-31 08:19 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("authentik_policies", "0006_auto_20210329_1334"), + ] + + operations = [ + # Create field with default as all for backwards compat + migrations.AddField( + model_name="policybindingmodel", + name="policy_engine_mode", + field=models.TextField( + choices=[ + ("all", "ALL, all policies must pass"), + ("any", "ANY, any policy must pass"), + ], + default="all", + ), + ), + # Set default for new objects to any + migrations.AlterField( + model_name="policybindingmodel", + name="policy_engine_mode", + field=models.TextField( + choices=[ + ("all", "ALL, all policies must pass"), + ("any", "ANY, any policy must pass"), + ], + default="any", + ), + ), + ] diff --git a/authentik/policies/models.py b/authentik/policies/models.py index 9964a1dcd..19045ae78 100644 --- a/authentik/policies/models.py +++ b/authentik/policies/models.py @@ -18,6 +18,15 @@ from authentik.policies.exceptions import PolicyException from authentik.policies.types import PolicyRequest, PolicyResult +class PolicyEngineMode(models.TextChoices): + """Decide how results of multiple policies should be combined.""" + + # pyright: reportGeneralTypeIssues=false + MODE_ALL = "all", _("ALL, all policies must pass") # type: "PolicyEngineMode" + # pyright: reportGeneralTypeIssues=false + MODE_ANY = "any", _("ANY, any policy must pass") # type: "PolicyEngineMode" + + class PolicyBindingModel(models.Model): """Base Model for objects that have policies applied to them.""" @@ -27,6 +36,11 @@ class PolicyBindingModel(models.Model): "Policy", through="PolicyBinding", related_name="bindings", blank=True ) + policy_engine_mode = models.TextField( + choices=PolicyEngineMode.choices, + default=PolicyEngineMode.MODE_ANY, + ) + objects = InheritanceManager() class Meta: diff --git a/authentik/policies/tests/test_engine.py b/authentik/policies/tests/test_engine.py index e58dc27ca..0223718ca 100644 --- a/authentik/policies/tests/test_engine.py +++ b/authentik/policies/tests/test_engine.py @@ -4,9 +4,14 @@ from django.test import TestCase from authentik.core.models import User from authentik.policies.dummy.models import DummyPolicy -from authentik.policies.engine import PolicyEngine, PolicyEngineMode +from authentik.policies.engine import PolicyEngine from authentik.policies.expression.models import ExpressionPolicy -from authentik.policies.models import Policy, PolicyBinding, PolicyBindingModel +from authentik.policies.models import ( + Policy, + PolicyBinding, + PolicyBindingModel, + PolicyEngineMode, +) from authentik.policies.tests.test_process import clear_policy_cache @@ -44,9 +49,11 @@ class TestPolicyEngine(TestCase): self.assertEqual(result.passing, True) self.assertEqual(result.messages, ("dummy",)) - def test_engine_mode_and(self): + def test_engine_mode_all(self): """Ensure all policies passes with AND mode (false and true -> false)""" - pbm = PolicyBindingModel.objects.create() + pbm = PolicyBindingModel.objects.create( + policy_engine_mode=PolicyEngineMode.MODE_ALL + ) PolicyBinding.objects.create(target=pbm, policy=self.policy_false, order=0) PolicyBinding.objects.create(target=pbm, policy=self.policy_true, order=1) engine = PolicyEngine(pbm, self.user) @@ -60,13 +67,14 @@ class TestPolicyEngine(TestCase): ), ) - def test_engine_mode_or(self): + def test_engine_mode_any(self): """Ensure all policies passes with OR mode (false and true -> true)""" - pbm = PolicyBindingModel.objects.create() + pbm = PolicyBindingModel.objects.create( + policy_engine_mode=PolicyEngineMode.MODE_ANY + ) PolicyBinding.objects.create(target=pbm, policy=self.policy_false, order=0) PolicyBinding.objects.create(target=pbm, policy=self.policy_true, order=1) engine = PolicyEngine(pbm, self.user) - engine.mode = PolicyEngineMode.MODE_OR result = engine.build().result self.assertEqual(result.passing, True) self.assertEqual( diff --git a/authentik/sources/ldap/forms.py b/authentik/sources/ldap/forms.py index c9bfb58a4..ae2088421 100644 --- a/authentik/sources/ldap/forms.py +++ b/authentik/sources/ldap/forms.py @@ -26,6 +26,7 @@ class LDAPSourceForm(forms.ModelForm): "name", "slug", "enabled", + "policy_engine_mode", # -- start of our custom fields "server_uri", "start_tls", diff --git a/authentik/sources/oauth/forms.py b/authentik/sources/oauth/forms.py index c8190882b..bdcd17ab1 100644 --- a/authentik/sources/oauth/forms.py +++ b/authentik/sources/oauth/forms.py @@ -32,6 +32,7 @@ class OAuthSourceForm(forms.ModelForm): "name", "slug", "enabled", + "policy_engine_mode", "authentication_flow", "enrollment_flow", "provider_type", diff --git a/authentik/sources/saml/forms.py b/authentik/sources/saml/forms.py index 5763e9c87..0a4773b3a 100644 --- a/authentik/sources/saml/forms.py +++ b/authentik/sources/saml/forms.py @@ -35,6 +35,7 @@ class SAMLSourceForm(forms.ModelForm): "name", "slug", "enabled", + "policy_engine_mode", "pre_authentication_flow", "authentication_flow", "enrollment_flow", diff --git a/swagger.yaml b/swagger.yaml index 553ac6459..378ac2615 100755 --- a/swagger.yaml +++ b/swagger.yaml @@ -3421,6 +3421,11 @@ paths: description: '' required: false type: string + - name: policy_engine_mode + in: query + description: '' + required: false + type: string - name: fsb_uuid in: query description: '' @@ -14729,6 +14734,12 @@ definitions: meta_publisher: title: Meta publisher type: string + policy_engine_mode: + title: Policy engine mode + type: string + enum: + - all + - any Group: required: - name @@ -15288,6 +15299,12 @@ definitions: title: Cache count type: string readOnly: true + policy_engine_mode: + title: Policy engine mode + type: string + enum: + - all + - any Stage: required: - name @@ -15358,13 +15375,12 @@ definitions: type: integer maximum: 2147483647 minimum: -2147483648 - policies: - type: array - items: - type: string - format: uuid - readOnly: true - uniqueItems: true + policy_engine_mode: + title: Policy engine mode + type: string + enum: + - all + - any ErrorDetail: required: - string @@ -16151,6 +16167,12 @@ definitions: type: string format: uuid readOnly: true + policy_engine_mode: + title: Policy engine mode + type: string + enum: + - all + - any name: title: Name description: Source's display Name. @@ -17172,6 +17194,12 @@ definitions: title: Verbose name plural type: string readOnly: true + policy_engine_mode: + title: Policy engine mode + type: string + enum: + - all + - any UserSetting: required: - object_uid @@ -17246,6 +17274,12 @@ definitions: title: Verbose name plural type: string readOnly: true + policy_engine_mode: + title: Policy engine mode + type: string + enum: + - all + - any server_uri: title: Server URI type: string @@ -17388,6 +17422,12 @@ definitions: title: Verbose name plural type: string readOnly: true + policy_engine_mode: + title: Policy engine mode + type: string + enum: + - all + - any provider_type: title: Provider type type: string @@ -17504,6 +17544,12 @@ definitions: title: Verbose name plural type: string readOnly: true + policy_engine_mode: + title: Policy engine mode + type: string + enum: + - all + - any pre_authentication_flow: title: Pre authentication flow description: Flow used before authentication. @@ -18196,6 +18242,12 @@ definitions: type: string format: uuid readOnly: true + policy_engine_mode: + title: Policy engine mode + type: string + enum: + - all + - any name: title: Name description: Source's display Name. diff --git a/web/src/pages/applications/ApplicationForm.ts b/web/src/pages/applications/ApplicationForm.ts index e194566ea..c079517a5 100644 --- a/web/src/pages/applications/ApplicationForm.ts +++ b/web/src/pages/applications/ApplicationForm.ts @@ -1,4 +1,4 @@ -import { CoreApi, Application, ProvidersApi, Provider } from "authentik-api"; +import { CoreApi, Application, ProvidersApi, Provider, ApplicationPolicyEngineModeEnum } from "authentik-api"; import { gettext } from "django"; import { customElement, property } from "lit-element"; import { html, TemplateResult } from "lit-html"; @@ -97,6 +97,19 @@ export class ApplicationForm extends Form { }), html``)} + + + diff --git a/web/src/pages/flows/FlowForm.ts b/web/src/pages/flows/FlowForm.ts index ba32c6af1..9163ca6f1 100644 --- a/web/src/pages/flows/FlowForm.ts +++ b/web/src/pages/flows/FlowForm.ts @@ -1,4 +1,4 @@ -import { Flow, FlowDesignationEnum, FlowsApi } from "authentik-api"; +import { Flow, FlowDesignationEnum, FlowPolicyEngineModeEnum, FlowsApi } from "authentik-api"; import { gettext } from "django"; import { customElement, property } from "lit-element"; import { html, TemplateResult } from "lit-html"; @@ -93,6 +93,19 @@ export class FlowForm extends Form {

${gettext("Visible in the URL.")}

+ + +