diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 978fdc41a5..7bb4f4dbb3 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -378,7 +378,8 @@ def visitor(e, *operands): return self.base_form_assembly_visitor(e, t, *operands) # DAG assembly: traverse the DAG in a post-order fashion and evaluate the node on the fly. - result = BaseFormAssembler.base_form_postorder_traversal(self._form, visitor) + visited = {} + result = BaseFormAssembler.base_form_postorder_traversal(self._form, visitor, visited=visited) if tensor: BaseFormAssembler.update_tensor(result, tensor) @@ -584,27 +585,36 @@ def update_tensor(assembled_base_form, tensor): @staticmethod def base_form_postorder_traversal(expr, visitor, visited=None): - visited = visited if visited is not None else {} + visited = visited or {} if expr in visited: return visited[expr] stack = [expr] + processing = set() + while stack: - e = stack.pop() - unvisited_children = [] + e = stack[-1] + if e in visited: + stack.pop() + continue + operands = BaseFormAssembler.base_form_operands(e) - for arg in operands: - if arg not in visited: - unvisited_children.append(arg) + unvisited_children = [arg for arg in operands if arg not in visited] if unvisited_children: - stack.append(e) stack.extend(unvisited_children) + processing.update(unvisited_children) else: + stack.pop() + # if not isinstance(e, firedrake.Cofunction): visited[e] = visitor(e, *(visited[arg] for arg in operands)) + # else: + # visited[e] = e + processing.discard(e) return visited[expr] + @staticmethod def base_form_preorder_traversal(expr, visitor, visited=None): visited = visited if visited is not None else {} @@ -658,9 +668,13 @@ def visitor(expr, *operands): # Need to reconstruct the expression with its visited operands! expr = BaseFormAssembler.reconstruct_node_from_operands(expr, operands) # Perform the DAG restructuring when needed - return BaseFormAssembler.restructure_base_form(expr, visited) - - return BaseFormAssembler.base_form_postorder_traversal(expression, visitor, visited) + if visited: + return BaseFormAssembler.restructure_base_form(expr, visited=visited) + return BaseFormAssembler.restructure_base_form(expr) + if visited: + return BaseFormAssembler.base_form_postorder_traversal( + expression, visitor, visited=visited) + return BaseFormAssembler.base_form_postorder_traversal(expression, visitor) @staticmethod def restructure_base_form_preorder(expression, visited=None): @@ -668,9 +682,13 @@ def restructure_base_form_preorder(expression, visited=None): def visitor(expr): # Perform the DAG restructuring when needed - return BaseFormAssembler.restructure_base_form(expr, visited) - - expression = BaseFormAssembler.base_form_preorder_traversal(expression, visitor, visited) + if visited: + return BaseFormAssembler.restructure_base_form(expr, visited=visited) + return BaseFormAssembler.restructure_base_form(expr) + if visited: + return BaseFormAssembler.base_form_preorder_traversal(expression, visitor, visited=visited) + else: + return BaseFormAssembler.base_form_preorder_traversal(expression, visitor) # Need to reconstruct the expression at the end when all its operands have been visited! operands = [visited.get(args, args) for args in BaseFormAssembler.base_form_operands(expression)] return BaseFormAssembler.reconstruct_node_from_operands(expression, operands) @@ -974,6 +992,7 @@ def assemble(self, tensor=None): ) if tensor is None: + # Creating a cofunction. tensor = self.allocate() else: self._check_tensor(tensor)