diff --git a/authentik/flows/challenge.py b/authentik/flows/challenge.py index af5030276..f19859749 100644 --- a/authentik/flows/challenge.py +++ b/authentik/flows/challenge.py @@ -1,5 +1,6 @@ """Challenge helpers""" from enum import Enum +from typing import TYPE_CHECKING, Optional from django.db.models.base import Model from django.http import JsonResponse @@ -8,6 +9,9 @@ from rest_framework.serializers import CharField, Serializer from authentik.flows.transfer.common import DataclassEncoder +if TYPE_CHECKING: + from authentik.flows.stage import StageView + class ChallengeTypes(Enum): """Currently defined challenge types""" @@ -36,6 +40,12 @@ class Challenge(Serializer): class ChallengeResponse(Serializer): """Base class for all challenge responses""" + stage: Optional["StageView"] + + def __init__(self, instance, data, **kwargs): + self.stage = kwargs.pop("stage", None) + super().__init__(instance=instance, data=data, **kwargs) + def create(self, validated_data: dict) -> Model: return Model() diff --git a/authentik/flows/stage.py b/authentik/flows/stage.py index d5ba4e40d..22a3162d2 100644 --- a/authentik/flows/stage.py +++ b/authentik/flows/stage.py @@ -3,6 +3,7 @@ from collections import namedtuple from typing import Any, Type from django.http import HttpRequest +from django.http.request import QueryDict from django.http.response import HttpResponse, JsonResponse from django.utils.translation import gettext_lazy as _ from django.views.generic import TemplateView @@ -56,9 +57,9 @@ class ChallengeStageView(StageView): response_class = ChallengeResponse - def get_response_class(self) -> Type[ChallengeResponse]: + def get_response_instance(self, data: QueryDict) -> ChallengeResponse: """Return the response class type""" - return self.response_class + return self.response_class(None, data=data, stage=self) def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: challenge = self.get_challenge() @@ -69,7 +70,7 @@ class ChallengeStageView(StageView): # pylint: disable=unused-argument def post(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: """Handle challenge response""" - challenge: ChallengeResponse = self.get_response_class()(data=request.POST) + challenge: ChallengeResponse = self.get_response_instance(data=request.POST) if not challenge.is_valid(): return self.challenge_invalid(challenge) return self.challenge_valid(challenge) diff --git a/authentik/stages/identification/stage.py b/authentik/stages/identification/stage.py index 21989be1d..2da6aa102 100644 --- a/authentik/stages/identification/stage.py +++ b/authentik/stages/identification/stage.py @@ -7,6 +7,7 @@ from django.http import HttpResponse from django.urls import reverse from django.utils.translation import gettext as _ from rest_framework.fields import CharField +from rest_framework.serializers import ValidationError from structlog.stdlib import get_logger from authentik.core.models import Source, User @@ -27,8 +28,16 @@ class IdentificationChallengeResponse(ChallengeResponse): """Identification challenge""" uid_field = CharField() + pre_user: Optional[User] = None - # TODO: Validate here instead of challenge_valid() + def validate_uid_field(self, value: str) -> str: + """Validate that user exists""" + pre_user = self.stage.get_user(value) + if not pre_user: + LOGGER.debug("invalid_login", identifier=value) + raise ValidationError("Failed to authenticate.") + self.pre_user = pre_user + return value class IdentificationStageView(ChallengeStageView): @@ -96,18 +105,10 @@ class IdentificationStageView(ChallengeStageView): def challenge_valid( self, challenge: IdentificationChallengeResponse ) -> HttpResponse: - user_identifier = challenge.data.get("uid_field") - pre_user = self.get_user(user_identifier) - if not pre_user: - LOGGER.debug("invalid_login", identifier=user_identifier) - messages.error(self.request, _("Failed to authenticate.")) - return self.challenge_invalid(challenge) - self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] = pre_user - + self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] = challenge.pre_user current_stage: IdentificationStage = self.executor.current_stage if not current_stage.show_matched_user: self.executor.plan.context[ PLAN_CONTEXT_PENDING_USER_IDENTIFIER - ] = user_identifier - + ] = challenge.validated_data.get("uid_field") return self.executor.stage_ok() diff --git a/tests/e2e/test_flows_enroll.py b/tests/e2e/test_flows_enroll.py index 99e6ca26f..76bbf4832 100644 --- a/tests/e2e/test_flows_enroll.py +++ b/tests/e2e/test_flows_enroll.py @@ -87,9 +87,7 @@ class TestFlowsEnroll(SeleniumTestCase): FlowStageBinding.objects.create(target=flow, stage=user_login, order=3) self.driver.get(self.live_server_url) - self.wait.until( - ec.presence_of_element_located((By.CSS_SELECTOR, "#enroll")) - ) + self.wait.until(ec.presence_of_element_located((By.CSS_SELECTOR, "#enroll"))) self.driver.find_element(By.CSS_SELECTOR, "#enroll").click() self.wait.until(ec.presence_of_element_located((By.ID, "id_username")))