Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
Ig-dolci committed Nov 29, 2024
1 parent 7151998 commit 243c6dc
Showing 1 changed file with 19 additions and 35 deletions.
54 changes: 19 additions & 35 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,8 +378,7 @@ def visitor(e, *operands):
return self.base_form_assembly_visitor(e, t, *operands)

# DAG assembly: traverse the DAG in a post-order fashion and evaluate the node on the fly.
visited = {}
result = BaseFormAssembler.base_form_postorder_traversal(self._form, visitor, visited=visited)
result = BaseFormAssembler.base_form_postorder_traversal(self._form, visitor)

if tensor:
BaseFormAssembler.update_tensor(result, tensor)
Expand Down Expand Up @@ -469,8 +468,11 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
return sum(weight * arg for weight, arg in zip(expr.weights(), args))
elif all(isinstance(op, firedrake.Cofunction) for op in args):
V, = set(a.function_space() for a in args)
res = sum([w*op.dat for (op, w) in zip(args, expr.weights())])
return firedrake.Cofunction(V, res)
# res = sum([w*op.dat for (op, w) in zip(args, expr.weights())])
# return firedrake.Cofunction(V, res)
result = firedrake.Cofunction(V)
result.dat.data[...] = sum(w * op.dat.data[...] for op, w in zip(args, expr.weights()))
return result
elif all(isinstance(op, ufl.Matrix) for op in args):
res = tensor.petscmat if tensor else PETSc.Mat()
is_set = False
Expand Down Expand Up @@ -585,36 +587,27 @@ def update_tensor(assembled_base_form, tensor):

@staticmethod
def base_form_postorder_traversal(expr, visitor, visited=None):
visited = visited or {}
visited = visited if visited is not None else {}
if expr in visited:
return visited[expr]

stack = [expr]
processing = set()

while stack:
e = stack[-1]
if e in visited:
stack.pop()
continue

e = stack.pop()
unvisited_children = []
operands = BaseFormAssembler.base_form_operands(e)
unvisited_children = [arg for arg in operands if arg not in visited]
for arg in operands:
if arg not in visited:
unvisited_children.append(arg)

if unvisited_children:
stack.append(e)
stack.extend(unvisited_children)
processing.update(unvisited_children)
else:
stack.pop()
# if not isinstance(e, firedrake.Cofunction):
visited[e] = visitor(e, *(visited[arg] for arg in operands))
# else:
# visited[e] = e
processing.discard(e)

return visited[expr]


@staticmethod
def base_form_preorder_traversal(expr, visitor, visited=None):
visited = visited if visited is not None else {}
Expand Down Expand Up @@ -668,27 +661,19 @@ def visitor(expr, *operands):
# Need to reconstruct the expression with its visited operands!
expr = BaseFormAssembler.reconstruct_node_from_operands(expr, operands)
# Perform the DAG restructuring when needed
if visited:
return BaseFormAssembler.restructure_base_form(expr, visited=visited)
return BaseFormAssembler.restructure_base_form(expr)
if visited:
return BaseFormAssembler.base_form_postorder_traversal(
expression, visitor, visited=visited)
return BaseFormAssembler.base_form_postorder_traversal(expression, visitor)
return BaseFormAssembler.restructure_base_form(expr, visited)

return BaseFormAssembler.base_form_postorder_traversal(expression, visitor, visited)

@staticmethod
def restructure_base_form_preorder(expression, visited=None):
visited = visited or {}

def visitor(expr):
# Perform the DAG restructuring when needed
if visited:
return BaseFormAssembler.restructure_base_form(expr, visited=visited)
return BaseFormAssembler.restructure_base_form(expr)
if visited:
return BaseFormAssembler.base_form_preorder_traversal(expression, visitor, visited=visited)
else:
return BaseFormAssembler.base_form_preorder_traversal(expression, visitor)
return BaseFormAssembler.restructure_base_form(expr, visited)

expression = BaseFormAssembler.base_form_preorder_traversal(expression, visitor, visited)
# Need to reconstruct the expression at the end when all its operands have been visited!
operands = [visited.get(args, args) for args in BaseFormAssembler.base_form_operands(expression)]
return BaseFormAssembler.reconstruct_node_from_operands(expression, operands)
Expand Down Expand Up @@ -992,7 +977,6 @@ def assemble(self, tensor=None):
)

if tensor is None:
# Creating a cofunction.
tensor = self.allocate()
else:
self._check_tensor(tensor)
Expand Down

0 comments on commit 243c6dc

Please sign in to comment.