diff --git a/passbook/policies/engine.py b/passbook/policies/engine.py index f76570757..25709f166 100644 --- a/passbook/policies/engine.py +++ b/passbook/policies/engine.py @@ -16,6 +16,7 @@ LOGGER = get_logger() # spawn causes issues with objects that aren't picklable, and also the django setup set_start_method("fork") + class PolicyProcessInfo: """Dataclass to hold all information and communication channels to a process""" @@ -38,13 +39,15 @@ class PolicyEngine: policies: List[Policy] = [] request: PolicyRequest - __processes: List[PolicyProcessInfo] = [] + __cached_policies: List[PolicyResult] + __processes: List[PolicyProcessInfo] def __init__(self, policies, user: User, request: HttpRequest = None): self.policies = policies self.request = PolicyRequest(user) if request: self.request.http_request = request + self.__cached_policies = [] self.__processes = [] def _select_subclasses(self) -> List[Policy]: @@ -57,21 +60,20 @@ class PolicyEngine: def build(self) -> "PolicyEngine": """Build task group""" - cached_policies = [] for policy in self._select_subclasses(): cached_policy = cache.get(cache_key(policy, self.request.user), None) if cached_policy and self.use_cache: LOGGER.debug("Taking result from cache", policy=policy) - cached_policies.append(cached_policy) - else: - LOGGER.debug("Evaluating policy", policy=policy) - our_end, task_end = Pipe(False) - task = PolicyProcess(policy, self.request, task_end) - LOGGER.debug("Starting Process", policy=policy) - task.start() - self.__processes.append( - PolicyProcessInfo(process=task, connection=our_end, policy=policy) - ) + self.__cached_policies.append(cached_policy) + continue + LOGGER.debug("Evaluating policy", policy=policy) + our_end, task_end = Pipe(False) + task = PolicyProcess(policy, self.request, task_end) + LOGGER.debug("Starting Process", policy=policy) + task.start() + self.__processes.append( + PolicyProcessInfo(process=task, connection=our_end, policy=policy) + ) # If all policies are cached, we have an empty list here. for proc_info in self.__processes: proc_info.process.join(proc_info.policy.timeout) @@ -84,13 +86,14 @@ class PolicyEngine: def result(self) -> Tuple[bool, List[str]]: """Get policy-checking result""" messages: List[str] = [] - for proc_info in self.__processes: - LOGGER.debug( - "Result", policy=proc_info.policy, passing=proc_info.result.passing - ) - if proc_info.result.messages: - messages += proc_info.result.messages - if not proc_info.result.passing: + process_results: List[PolicyResult] = [ + x.result for x in self.__processes if x.result + ] + for result in process_results + self.__cached_policies: + LOGGER.debug("result", passing=result.passing) + if result.messages: + messages += result.messages + if not result.passing: return False, messages return True, messages