diff --git a/authentik/stages/prompt/stage.py b/authentik/stages/prompt/stage.py index b5b2d978b..6b650f2d3 100644 --- a/authentik/stages/prompt/stage.py +++ b/authentik/stages/prompt/stage.py @@ -53,9 +53,11 @@ class PromptChallengeResponse(ChallengeResponse): def __init__(self, *args, **kwargs): stage: PromptStage = kwargs.pop("stage", None) plan: FlowPlan = kwargs.pop("plan", None) + request: HttpRequest = kwargs.pop("request", None) super().__init__(*args, **kwargs) self.stage = stage self.plan = plan + self.request = request if not self.stage: return # list() is called so we only load the fields once @@ -104,8 +106,9 @@ class PromptChallengeResponse(ChallengeResponse): self._validate_password_fields(*[field.field_key for field in password_fields]) user = self.plan.context.get(PLAN_CONTEXT_PENDING_USER, get_anonymous_user()) - engine = ListPolicyEngine(self.stage.validation_policies.all(), user) - engine.request.context = attrs + engine = ListPolicyEngine(self.stage.validation_policies.all(), user, self.request) + engine.request.context[PLAN_CONTEXT_PROMPT] = attrs + engine.request.context.update(attrs) engine.build() result = engine.result if not result.passing: @@ -173,6 +176,7 @@ class PromptStageView(ChallengeStageView): return PromptChallengeResponse( instance=None, data=data, + request=self.request, stage=self.executor.current_stage, plan=self.executor.plan, )