diff --git a/passbook/flows/planner.py b/passbook/flows/planner.py index a2e654bf7..a2d0dc05f 100644 --- a/passbook/flows/planner.py +++ b/passbook/flows/planner.py @@ -56,7 +56,9 @@ class FlowPlanner: engine.build() return engine.result - def plan(self, request: HttpRequest) -> FlowPlan: + def plan( + self, request: HttpRequest, default_context: Optional[Dict[str, Any]] = None + ) -> FlowPlan: """Check each of the flows' policies, check policies for each stage with PolicyBinding and return ordered list""" LOGGER.debug("f(plan): Starting planning process", flow=self.flow) @@ -65,12 +67,38 @@ class FlowPlanner: root_passing, root_passing_messages = self._check_flow_root_policies(request) if not root_passing: raise FlowNonApplicableException(root_passing_messages) - cached_plan = cache.get(cache_key(self.flow, request.user), None) + # Bit of a workaround here, if there is a pending user set in the default context + # we use that user for our cache key + # to make sure they don't get the generic response + if default_context and PLAN_CONTEXT_PENDING_USER in default_context: + user = default_context[PLAN_CONTEXT_PENDING_USER] + else: + user = request.user + cached_plan_key = cache_key(self.flow, user) + cached_plan = cache.get(cached_plan_key, None) if cached_plan and self.use_cache: - LOGGER.debug("f(plan): Taking plan from cache", flow=self.flow) + LOGGER.debug( + "f(plan): Taking plan from cache", flow=self.flow, key=cached_plan_key + ) + LOGGER.debug(cached_plan) return cached_plan + plan = self._build_plan(user, request, default_context) + cache.set(cache_key(self.flow, user), plan) + if not plan.stages: + raise EmptyFlowException() + return plan + + def _build_plan( + self, + user: User, + request: HttpRequest, + default_context: Optional[Dict[str, Any]], + ) -> FlowPlan: + """Actually build flow plan""" start_time = time() plan = FlowPlan(flow_pk=self.flow.pk.hex) + if default_context: + plan.context = default_context # Check Flow policies for stage in ( self.flow.stages.order_by("flowstagebinding__order") @@ -78,7 +106,8 @@ class FlowPlanner: .select_related() ): binding = stage.flowstagebinding_set.get(flow__pk=self.flow.pk) - engine = PolicyEngine(binding.policies.all(), request.user, request) + engine = PolicyEngine(binding.policies.all(), user, request) + engine.request.context = plan.context engine.build() passing, _ = engine.result if passing: @@ -86,11 +115,8 @@ class FlowPlanner: plan.stages.append(stage) end_time = time() LOGGER.debug( - "f(plan): Finished planning", + "f(plan): Finished building", flow=self.flow, duration_s=end_time - start_time, ) - cache.set(cache_key(self.flow, request.user), plan) - if not plan.stages: - raise EmptyFlowException() return plan diff --git a/passbook/sources/oauth/views/core.py b/passbook/sources/oauth/views/core.py index 36e84d869..83526b855 100644 --- a/passbook/sources/oauth/views/core.py +++ b/passbook/sources/oauth/views/core.py @@ -176,10 +176,14 @@ class OAuthCallback(OAuthClientMixin, View): # We run the Flow planner here so we can pass the Pending user in the context flow = get_object_or_404(Flow, designation=FlowDesignation.AUTHENTICATION) planner = FlowPlanner(flow) - plan = planner.plan(self.request) - plan.context[PLAN_CONTEXT_PENDING_USER] = user - plan.context[PLAN_CONTEXT_AUTHENTICATION_BACKEND] = user.backend - plan.context[PLAN_CONTEXT_SSO] = True + plan = planner.plan( + self.request, + { + PLAN_CONTEXT_PENDING_USER: user, + PLAN_CONTEXT_AUTHENTICATION_BACKEND: user.backend, + PLAN_CONTEXT_SSO: True, + }, + ) self.request.session[SESSION_KEY_PLAN] = plan return redirect_with_qs( "passbook_flows:flow-executor", self.request.GET, flow_slug=flow.slug,