Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
Ig-dolci committed Nov 29, 2024
1 parent c66d385 commit 7151998
Showing 1 changed file with 33 additions and 14 deletions.
47 changes: 33 additions & 14 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,8 @@ def visitor(e, *operands):
return self.base_form_assembly_visitor(e, t, *operands)

# DAG assembly: traverse the DAG in a post-order fashion and evaluate the node on the fly.
result = BaseFormAssembler.base_form_postorder_traversal(self._form, visitor)
visited = {}
result = BaseFormAssembler.base_form_postorder_traversal(self._form, visitor, visited=visited)

if tensor:
BaseFormAssembler.update_tensor(result, tensor)
Expand Down Expand Up @@ -584,27 +585,36 @@ def update_tensor(assembled_base_form, tensor):

@staticmethod
def base_form_postorder_traversal(expr, visitor, visited=None):
visited = visited if visited is not None else {}
visited = visited or {}
if expr in visited:
return visited[expr]

stack = [expr]
processing = set()

while stack:
e = stack.pop()
unvisited_children = []
e = stack[-1]
if e in visited:
stack.pop()
continue

operands = BaseFormAssembler.base_form_operands(e)
for arg in operands:
if arg not in visited:
unvisited_children.append(arg)
unvisited_children = [arg for arg in operands if arg not in visited]

if unvisited_children:
stack.append(e)
stack.extend(unvisited_children)
processing.update(unvisited_children)
else:
stack.pop()
# if not isinstance(e, firedrake.Cofunction):
visited[e] = visitor(e, *(visited[arg] for arg in operands))
# else:
# visited[e] = e
processing.discard(e)

return visited[expr]


@staticmethod
def base_form_preorder_traversal(expr, visitor, visited=None):
visited = visited if visited is not None else {}
Expand Down Expand Up @@ -658,19 +668,27 @@ def visitor(expr, *operands):
# Need to reconstruct the expression with its visited operands!
expr = BaseFormAssembler.reconstruct_node_from_operands(expr, operands)
# Perform the DAG restructuring when needed
return BaseFormAssembler.restructure_base_form(expr, visited)

return BaseFormAssembler.base_form_postorder_traversal(expression, visitor, visited)
if visited:
return BaseFormAssembler.restructure_base_form(expr, visited=visited)
return BaseFormAssembler.restructure_base_form(expr)
if visited:
return BaseFormAssembler.base_form_postorder_traversal(
expression, visitor, visited=visited)
return BaseFormAssembler.base_form_postorder_traversal(expression, visitor)

@staticmethod
def restructure_base_form_preorder(expression, visited=None):
visited = visited or {}

def visitor(expr):
# Perform the DAG restructuring when needed
return BaseFormAssembler.restructure_base_form(expr, visited)

expression = BaseFormAssembler.base_form_preorder_traversal(expression, visitor, visited)
if visited:
return BaseFormAssembler.restructure_base_form(expr, visited=visited)
return BaseFormAssembler.restructure_base_form(expr)
if visited:
return BaseFormAssembler.base_form_preorder_traversal(expression, visitor, visited=visited)
else:
return BaseFormAssembler.base_form_preorder_traversal(expression, visitor)
# Need to reconstruct the expression at the end when all its operands have been visited!
operands = [visited.get(args, args) for args in BaseFormAssembler.base_form_operands(expression)]
return BaseFormAssembler.reconstruct_node_from_operands(expression, operands)
Expand Down Expand Up @@ -974,6 +992,7 @@ def assemble(self, tensor=None):
)

if tensor is None:
# Creating a cofunction.
tensor = self.allocate()
else:
self._check_tensor(tensor)
Expand Down

0 comments on commit 7151998

Please sign in to comment.