Skip to content

Commit

Permalink
Simplify indexed ListTensor objects
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Dec 30, 2024
1 parent c3ce9fe commit a7a7922
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 24 deletions.
38 changes: 15 additions & 23 deletions ufl/algorithms/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,33 +69,25 @@ def extract_type(a, ufl_types):
objects = set()
arg_types = tuple(t for t in ufl_types if issubclass(t, BaseArgument))
if arg_types:
objects.update([e for e in a.arguments() if isinstance(e, arg_types)])
objects.update(e for e in a.arguments() if isinstance(e, arg_types))
coeff_types = tuple(t for t in ufl_types if issubclass(t, BaseCoefficient))
if coeff_types:
objects.update([e for e in a.coefficients() if isinstance(e, coeff_types)])
objects.update(e for e in a.coefficients() if isinstance(e, coeff_types))
return objects

if all(issubclass(t, Terminal) for t in ufl_types):
# Optimization
objects = set(
o
for e in iter_expressions(a)
for o in traverse_unique_terminals(e)
if any(isinstance(o, t) for t in ufl_types)
)
traversal = traverse_unique_terminals
else:
objects = set(
o
for e in iter_expressions(a)
for o in unique_pre_traversal(e)
if any(isinstance(o, t) for t in ufl_types)
)
traversal = unique_pre_traversal

objects = set(o for e in iter_expressions(a) for o in traversal(e) if isinstance(o, ufl_types))

# Need to extract objects contained in base form operators whose
# type is in ufl_types
base_form_ops = set(e for e in objects if isinstance(e, BaseFormOperator))
ufl_types_no_args = tuple(t for t in ufl_types if not issubclass(t, BaseArgument))
base_form_objects = ()
base_form_objects = []
for o in base_form_ops:
# This accounts for having BaseFormOperator in Forms: if N is a BaseFormOperator
# `N(u; v*) * v * dx` <=> `action(v1 * v * dx, N(...; v*))`
Expand All @@ -106,17 +98,17 @@ def extract_type(a, ufl_types):
# argument of the Coargument and not its primal argument.
if isinstance(ai, Coargument):
new_types = tuple(Coargument if t is BaseArgument else t for t in ufl_types)
base_form_objects += tuple(extract_type(ai, new_types))
base_form_objects.extend(extract_type(ai, new_types))
else:
base_form_objects += tuple(extract_type(ai, ufl_types))
base_form_objects.extend(extract_type(ai, ufl_types))
# Look for BaseArguments in BaseFormOperator's argument slots
# only since that's where they are by definition. Don't look
# into operands, which is convenient for external operator
# composition, e.g. N1(N2; v*) where N2 is seen as an operator
# and not a form.
slots = o.ufl_operands
for ai in slots:
base_form_objects += tuple(extract_type(ai, ufl_types_no_args))
base_form_objects.extend(extract_type(ai, ufl_types_no_args))
objects.update(base_form_objects)

# `Remove BaseFormOperator` objects if there were initially not in `ufl_types`
Expand Down Expand Up @@ -213,7 +205,7 @@ def extract_arguments_and_coefficients(a):
coefficients = [f for f in base_coeff_and_args if isinstance(f, BaseCoefficient)]

# Build number,part: instance mappings, should be one to one
bfnp = dict((f, (f.number(), f.part())) for f in arguments)
bfnp = {f: (f.number(), f.part()) for f in arguments}
if len(bfnp) != len(set(bfnp.values())):
raise ValueError(
"Found different Arguments with same number and part.\n"
Expand All @@ -222,7 +214,7 @@ def extract_arguments_and_coefficients(a):
)

# Build count: instance mappings, should be one to one
fcounts = dict((f, f.count()) for f in coefficients)
fcounts = {f: f.count() for f in coefficients}
if len(fcounts) != len(set(fcounts.values())):
raise ValueError(
"Found different coefficients with same counts.\n"
Expand All @@ -249,10 +241,10 @@ def extract_unique_elements(form):

def extract_sub_elements(elements):
"""Build sorted tuple of all sub elements (including parent element)."""
sub_elements = tuple(chain(*[e.sub_elements for e in elements]))
sub_elements = tuple(chain(*(e.sub_elements for e in elements)))
if not sub_elements:
return tuple(elements)
return tuple(elements) + extract_sub_elements(sub_elements)
return (*elements, *extract_sub_elements(sub_elements))


def sort_elements(elements):
Expand All @@ -268,7 +260,7 @@ def sort_elements(elements):
nodes = list(elements)

# Set edges
edges = dict((node, []) for node in nodes)
edges = {node: [] for node in nodes}
for element in elements:
for sub_element in element.sub_elements:
edges[element].append(sub_element)
Expand Down
10 changes: 9 additions & 1 deletion ufl/indexed.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,21 @@ class Indexed(Operator):

def __new__(cls, expression, multiindex):
"""Create a new Indexed."""
from ufl.tensors import ListTensor

indices = multiindex._indices
while isinstance(expression, ListTensor) and isinstance(indices[0], FixedIndex):
# Simplify indexed ListTensor objects
expression = expression.ufl_operands[int(indices[0])]
indices = indices[1:]

if isinstance(expression, Zero):
# Zero-simplify indexed Zero objects
shape = expression.ufl_shape
efi = expression.ufl_free_indices
efid = expression.ufl_index_dimensions
fi = list(zip(efi, efid))
for pos, ind in enumerate(multiindex._indices):
for pos, ind in enumerate(indices):
if isinstance(ind, Index):
fi.append((ind.count(), shape[pos]))
fi = unique_sorted_indices(sorted(fi))
Expand Down

0 comments on commit a7a7922

Please sign in to comment.