diff --git a/authentik/core/tests/test_views_overview.py b/authentik/core/tests/test_views_overview.py index 84e214d51..e6eafdc75 100644 --- a/authentik/core/tests/test_views_overview.py +++ b/authentik/core/tests/test_views_overview.py @@ -28,9 +28,3 @@ class TestOverviewViews(TestCase): self.assertEqual( self.client.get(reverse("authentik_core:shell")).status_code, 200 ) - - def test_overview(self): - """Test overview""" - self.assertEqual( - self.client.get(reverse("authentik_core:overview")).status_code, 200 - ) diff --git a/authentik/stages/prompt/forms.py b/authentik/stages/prompt/forms.py index 9dcaee300..abdcf0ed8 100644 --- a/authentik/stages/prompt/forms.py +++ b/authentik/stages/prompt/forms.py @@ -1,20 +1,7 @@ """Prompt forms""" -from email.policy import Policy -from types import MethodType -from typing import Any, Callable, Iterator - from django import forms -from django.db.models.query import QuerySet -from django.http import HttpRequest -from django.utils.translation import gettext_lazy as _ -from guardian.shortcuts import get_anonymous_user -from authentik.core.models import User -from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan -from authentik.policies.engine import PolicyEngine -from authentik.policies.models import PolicyBinding, PolicyBindingModel -from authentik.stages.prompt.models import FieldTypes, Prompt, PromptStage -from authentik.stages.prompt.signals import password_validate +from authentik.stages.prompt.models import Prompt, PromptStage class PromptStageForm(forms.ModelForm): @@ -47,111 +34,3 @@ class PromptAdminForm(forms.ModelForm): "label": forms.TextInput(), "placeholder": forms.TextInput(), } - - -class ListPolicyEngine(PolicyEngine): - """Slightly modified policy engine, which uses a list instead of a PolicyBindingModel""" - - __list: list[Policy] - - def __init__( - self, policies: list[Policy], user: User, request: HttpRequest = None - ) -> None: - super().__init__(PolicyBindingModel(), user, request) - self.__list = policies - self.use_cache = False - - def _iter_bindings(self) -> Iterator[PolicyBinding]: - for policy in self.__list: - yield PolicyBinding( - policy=policy, - ) - - -class PromptForm(forms.Form): - """Dynamically created form based on PromptStage""" - - stage: PromptStage - plan: FlowPlan - - def __init__(self, stage: PromptStage, plan: FlowPlan, *args, **kwargs): - self.stage = stage - self.plan = plan - super().__init__(*args, **kwargs) - # list() is called so we only load the fields once - fields = list(self.stage.fields.all()) - for field in fields: - field: Prompt - self.fields[field.field_key] = field.field - # Special handling for fields with username type - # these check for existing users with the same username - if field.type == FieldTypes.USERNAME: - setattr( - self, - f"clean_{field.field_key}", - MethodType(username_field_cleaner_factory(field), self), - ) - # Check if we have a password field, add a handler that sends a signal - # to validate it - if field.type == FieldTypes.PASSWORD: - setattr( - self, - f"clean_{field.field_key}", - MethodType(password_single_cleaner_factory(field), self), - ) - - self.field_order = sorted(fields, key=lambda x: x.order) - - def _clean_password_fields(self, *field_names): - """Check if the value of all password fields match by merging them into a set - and checking the length""" - all_passwords = {self.cleaned_data[x] for x in field_names} - if len(all_passwords) > 1: - raise forms.ValidationError(_("Passwords don't match.")) - - def clean(self): - cleaned_data = super().clean() - if cleaned_data == {}: - return {} - # Check if we have two password fields, and make sure they are the same - password_fields: QuerySet[Prompt] = self.stage.fields.filter( - type=FieldTypes.PASSWORD - ) - if password_fields.exists() and password_fields.count() == 2: - self._clean_password_fields(*[field.field_key for field in password_fields]) - - user = self.plan.context.get(PLAN_CONTEXT_PENDING_USER, get_anonymous_user()) - engine = ListPolicyEngine(self.stage.validation_policies.all(), user) - engine.request.context = cleaned_data - engine.build() - result = engine.result - if not result.passing: - raise forms.ValidationError(list(result.messages)) - return cleaned_data - - -def username_field_cleaner_factory(field: Prompt) -> Callable: - """Return a `clean_` method for `field`. Clean method checks if username is taken already.""" - - def username_field_cleaner(self: PromptForm) -> Any: - """Check for duplicate usernames""" - username = self.cleaned_data.get(field.field_key) - if User.objects.filter(username=username).exists(): - raise forms.ValidationError("Username is already taken.") - return username - - return username_field_cleaner - - -def password_single_cleaner_factory(field: Prompt) -> Callable[[PromptForm], Any]: - """Return a `clean_` method for `field`. Clean method checks if username is taken already.""" - - def password_single_clean(self: PromptForm) -> Any: - """Send password validation signals for e.g. LDAP Source""" - password = self.cleaned_data[field.field_key] - password_validate.send( - sender=self, password=password, plan_context=self.plan.context - ) - return password - - return password_single_clean diff --git a/authentik/stages/prompt/models.py b/authentik/stages/prompt/models.py index 602e8b64e..46d55b870 100644 --- a/authentik/stages/prompt/models.py +++ b/authentik/stages/prompt/models.py @@ -2,17 +2,23 @@ from typing import Type from uuid import uuid4 -from django import forms from django.db import models from django.forms import ModelForm from django.utils.translation import gettext_lazy as _ from django.views import View +from rest_framework.fields import ( + BooleanField, + CharField, + DateField, + DateTimeField, + EmailField, + IntegerField, +) from rest_framework.serializers import BaseSerializer from authentik.flows.models import Stage from authentik.lib.models import SerializerModel from authentik.policies.models import Policy -from authentik.stages.prompt.widgets import HorizontalRuleWidget, StaticTextWidget class FieldTypes(models.TextChoices): @@ -43,8 +49,8 @@ class FieldTypes(models.TextChoices): ) NUMBER = "number" CHECKBOX = "checkbox" - DATE = "data" - DATE_TIME = "data-time" + DATE = "date" + DATE_TIME = "date-time" SEPARATOR = "separator", _("Separator: Static Separator Line") HIDDEN = "hidden", _("Hidden: Hidden field, can be used to insert data into form.") @@ -73,49 +79,34 @@ class Prompt(SerializerModel): return PromptSerializer @property - def field(self): - """Return instantiated form input field""" - attrs = {"placeholder": _(self.placeholder)} - field_class = forms.CharField - widget = forms.TextInput(attrs=attrs) + def field(self) -> CharField: + """Get field type for Challenge and response""" + field_class = CharField kwargs = { - "label": _(self.label), "required": self.required, } if self.type == FieldTypes.EMAIL: - field_class = forms.EmailField - if self.type == FieldTypes.USERNAME: - attrs["autocomplete"] = "username" - if self.type == FieldTypes.PASSWORD: - widget = forms.PasswordInput(attrs=attrs) - attrs["autocomplete"] = "new-password" + field_class = EmailField if self.type == FieldTypes.NUMBER: - field_class = forms.IntegerField - widget = forms.NumberInput(attrs=attrs) + field_class = IntegerField + # TODO: Hidden? if self.type == FieldTypes.HIDDEN: - widget = forms.HiddenInput(attrs=attrs) kwargs["required"] = False kwargs["initial"] = self.placeholder if self.type == FieldTypes.CHECKBOX: - field_class = forms.BooleanField + field_class = BooleanField kwargs["required"] = False if self.type == FieldTypes.DATE: - attrs["type"] = "date" - widget = forms.DateInput(attrs=attrs) + field_class = DateField if self.type == FieldTypes.DATE_TIME: - attrs["type"] = "datetime-local" - widget = forms.DateTimeInput(attrs=attrs) + field_class = DateTimeField if self.type == FieldTypes.STATIC: - widget = StaticTextWidget(attrs=attrs) kwargs["initial"] = self.placeholder kwargs["required"] = False kwargs["label"] = "" if self.type == FieldTypes.SEPARATOR: - widget = HorizontalRuleWidget(attrs=attrs) kwargs["required"] = False kwargs["label"] = "" - - kwargs["widget"] = widget return field_class(**kwargs) def save(self, *args, **kwargs): diff --git a/authentik/stages/prompt/stage.py b/authentik/stages/prompt/stage.py index 93ecf8b21..02965dd21 100644 --- a/authentik/stages/prompt/stage.py +++ b/authentik/stages/prompt/stage.py @@ -1,36 +1,189 @@ """Prompt Stage Logic""" -from django.http import HttpResponse +from email.policy import Policy +from types import MethodType +from typing import Any, Callable, Iterator + +from django.db.models.base import Model +from django.db.models.query import QuerySet +from django.http import HttpRequest, HttpResponse +from django.http.request import QueryDict from django.utils.translation import gettext_lazy as _ -from django.views.generic import FormView +from guardian.shortcuts import get_anonymous_user +from rest_framework.fields import BooleanField, CharField, IntegerField +from rest_framework.serializers import Serializer, ValidationError from structlog.stdlib import get_logger -from authentik.flows.stage import StageView -from authentik.stages.prompt.forms import PromptForm +from authentik.core.models import User +from authentik.flows.challenge import Challenge, ChallengeResponse, ChallengeTypes +from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan +from authentik.flows.stage import ChallengeStageView +from authentik.policies.engine import PolicyEngine +from authentik.policies.models import PolicyBinding, PolicyBindingModel +from authentik.stages.prompt.models import FieldTypes, Prompt, PromptStage +from authentik.stages.prompt.signals import password_validate LOGGER = get_logger() PLAN_CONTEXT_PROMPT = "prompt_data" -class PromptStageView(FormView, StageView): +class PromptSerializer(Serializer): + """Serializer for a single Prompt field""" + + field_key = CharField() + label = CharField() + type = CharField() + required = BooleanField() + placeholder = CharField() + order = IntegerField() + + def create(self, validated_data: dict) -> Model: + return Model() + + def update(self, instance: Model, validated_data: dict) -> Model: + return Model() + + +class PromptChallenge(Challenge): + """Initial challenge being sent, define fields""" + + fields = PromptSerializer(many=True) + + +class PromptResponseChallenge(ChallengeResponse): + """Validate response, fields are dynamically created based + on the stage""" + + def __init__(self, *args, stage: PromptStage, plan: FlowPlan, **kwargs): + super().__init__(*args, **kwargs) + self.stage = stage + self.plan = plan + # list() is called so we only load the fields once + fields = list(self.stage.fields.all()) + for field in fields: + field: Prompt + self.fields[field.field_key] = field.field + # Special handling for fields with username type + # these check for existing users with the same username + if field.type == FieldTypes.USERNAME: + setattr( + self, + f"validate_{field.field_key}", + MethodType(username_field_validator_factory(), self), + ) + # Check if we have a password field, add a handler that sends a signal + # to validate it + if field.type == FieldTypes.PASSWORD: + setattr( + self, + f"validate_{field.field_key}", + MethodType(password_single_validator_factory(), self), + ) + + self.field_order = sorted(fields, key=lambda x: x.order) + + def _validate_password_fields(self, *field_names): + """Check if the value of all password fields match by merging them into a set + and checking the length""" + all_passwords = {self.initial_data[x] for x in field_names} + if len(all_passwords) > 1: + raise ValidationError(_("Passwords don't match.")) + + def validate(self, attrs): + if attrs == {}: + return {} + # Check if we have two password fields, and make sure they are the same + password_fields: QuerySet[Prompt] = self.stage.fields.filter( + type=FieldTypes.PASSWORD + ) + if password_fields.exists() and password_fields.count() == 2: + self._validate_password_fields( + *[field.field_key for field in password_fields] + ) + + user = self.plan.context.get(PLAN_CONTEXT_PENDING_USER, get_anonymous_user()) + engine = ListPolicyEngine(self.stage.validation_policies.all(), user) + engine.request.context = attrs + engine.build() + result = engine.result + if not result.passing: + raise ValidationError(list(result.messages)) + return attrs + + +def username_field_validator_factory() -> Callable[[PromptChallenge, str], Any]: + """Return a `clean_` method for `field`. Clean method checks if username is taken already.""" + + # pylint: disable=unused-argument + def username_field_validator(self: PromptChallenge, value: str) -> Any: + """Check for duplicate usernames""" + if User.objects.filter(username=value).exists(): + raise ValidationError("Username is already taken.") + return value + + return username_field_validator + + +def password_single_validator_factory() -> Callable[[PromptChallenge, str], Any]: + """Return a `clean_` method for `field`. Clean method checks if username is taken already.""" + + def password_single_clean(self: PromptChallenge, value: str) -> Any: + """Send password validation signals for e.g. LDAP Source""" + password_validate.send( + sender=self, password=value, plan_context=self.plan.context + ) + return value + + return password_single_clean + + +class ListPolicyEngine(PolicyEngine): + """Slightly modified policy engine, which uses a list instead of a PolicyBindingModel""" + + __list: list[Policy] + + def __init__( + self, policies: list[Policy], user: User, request: HttpRequest = None + ) -> None: + super().__init__(PolicyBindingModel(), user, request) + self.__list = policies + self.use_cache = False + + def _iter_bindings(self) -> Iterator[PolicyBinding]: + for policy in self.__list: + yield PolicyBinding( + policy=policy, + ) + + +class PromptStageView(ChallengeStageView): """Prompt Stage, save form data in plan context.""" - template_name = "login/form.html" - form_class = PromptForm + response_class = PromptResponseChallenge - def get_context_data(self, **kwargs): - ctx = super().get_context_data(**kwargs) - ctx["title"] = _(self.executor.current_stage.name) - return ctx + def get_challenge(self, *args, **kwargs) -> Challenge: + fields = list(self.executor.current_stage.fields.all()) + challenge = PromptChallenge( + data={ + "type": ChallengeTypes.native, + "component": "ak-stage-prompt", + "fields": [PromptSerializer(field).data for field in fields], + }, + ) + return challenge - def get_form_kwargs(self): - kwargs = super().get_form_kwargs() - kwargs["stage"] = self.executor.current_stage - kwargs["plan"] = self.executor.plan - return kwargs + def get_response_instance(self, data: QueryDict) -> ChallengeResponse: + if not self.executor.plan: + raise ValueError + return PromptResponseChallenge( + instance=None, + data=data, + stage=self.executor.current_stage, + plan=self.executor.plan, + ) - def form_valid(self, form: PromptForm) -> HttpResponse: - """Form data is valid""" + def challenge_valid(self, response: ChallengeResponse) -> HttpResponse: if PLAN_CONTEXT_PROMPT not in self.executor.plan.context: self.executor.plan.context[PLAN_CONTEXT_PROMPT] = {} - self.executor.plan.context[PLAN_CONTEXT_PROMPT].update(form.cleaned_data) + self.executor.plan.context[PLAN_CONTEXT_PROMPT].update(response.validated_data) + print(self.executor.plan.context[PLAN_CONTEXT_PROMPT]) return self.executor.stage_ok() diff --git a/authentik/stages/prompt/tests.py b/authentik/stages/prompt/tests.py index 11bbac0ba..58be327cd 100644 --- a/authentik/stages/prompt/tests.py +++ b/authentik/stages/prompt/tests.py @@ -11,9 +11,8 @@ from authentik.flows.models import Flow, FlowDesignation, FlowStageBinding from authentik.flows.planner import FlowPlan from authentik.flows.views import SESSION_KEY_PLAN from authentik.policies.expression.models import ExpressionPolicy -from authentik.stages.prompt.forms import PromptForm from authentik.stages.prompt.models import FieldTypes, Prompt, PromptStage -from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT +from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT, PromptResponseChallenge class TestPromptStage(TestCase): @@ -112,8 +111,8 @@ class TestPromptStage(TestCase): self.assertIn(prompt.label, force_str(response.content)) self.assertIn(prompt.placeholder, force_str(response.content)) - def test_valid_form_with_policy(self) -> PromptForm: - """Test form validation""" + def test_valid_challenge_with_policy(self) -> PromptResponseChallenge: + """Test challenge_response validation""" plan = FlowPlan( flow_pk=self.flow.pk.hex, stages=[self.stage], markers=[StageMarker()] ) @@ -123,12 +122,14 @@ class TestPromptStage(TestCase): ) self.stage.validation_policies.set([expr_policy]) self.stage.save() - form = PromptForm(stage=self.stage, plan=plan, data=self.prompt_data) - self.assertEqual(form.is_valid(), True) - return form + challenge_response = PromptResponseChallenge( + None, stage=self.stage, plan=plan, data=self.prompt_data + ) + self.assertEqual(challenge_response.is_valid(), True) + return challenge_response - def test_invalid_form(self) -> PromptForm: - """Test form validation""" + def test_invalid_challenge(self) -> PromptResponseChallenge: + """Test challenge_response validation""" plan = FlowPlan( flow_pk=self.flow.pk.hex, stages=[self.stage], markers=[StageMarker()] ) @@ -138,12 +139,14 @@ class TestPromptStage(TestCase): ) self.stage.validation_policies.set([expr_policy]) self.stage.save() - form = PromptForm(stage=self.stage, plan=plan, data=self.prompt_data) - self.assertEqual(form.is_valid(), False) - return form + challenge_response = PromptResponseChallenge( + None, stage=self.stage, plan=plan, data=self.prompt_data + ) + self.assertEqual(challenge_response.is_valid(), False) + return challenge_response - def test_valid_form_request(self): - """Test a request with valid form data""" + def test_valid_challenge_request(self): + """Test a request with valid challenge_response data""" plan = FlowPlan( flow_pk=self.flow.pk.hex, stages=[self.stage], markers=[StageMarker()] ) @@ -151,7 +154,7 @@ class TestPromptStage(TestCase): session[SESSION_KEY_PLAN] = plan session.save() - form = self.test_valid_form_with_policy() + challenge_response = self.test_valid_challenge_with_policy() with patch("authentik.flows.views.FlowExecutorView.cancel", MagicMock()): response = self.client.post( @@ -159,7 +162,7 @@ class TestPromptStage(TestCase): "authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}, ), - form.cleaned_data, + challenge_response.validated_data, ) self.assertEqual(response.status_code, 200) self.assertJSONEqual( diff --git a/authentik/stages/prompt/widgets.py b/authentik/stages/prompt/widgets.py deleted file mode 100644 index 0ddb26a10..000000000 --- a/authentik/stages/prompt/widgets.py +++ /dev/null @@ -1,17 +0,0 @@ -"""Prompt Widgets""" -from django import forms -from django.utils.safestring import mark_safe - - -class StaticTextWidget(forms.widgets.Widget): - """Widget to render static text""" - - def render(self, name, value, attrs=None, renderer=None): - return mark_safe(f"
{value}
") # nosec - - -class HorizontalRuleWidget(forms.widgets.Widget): - """Widget, which renders an${prompt.placeholder} +
`; + } + return html``; + } + + render(): TemplateResult { + if (!this.challenge) { + return html`