From 80ab7cafd8c468442be4f989378133a2d88a0e03 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Sun, 29 Dec 2024 21:36:19 -0600 Subject: [PATCH] Simplify indexed ListTensor objects --- ufl/algorithms/analysis.py | 38 +++++++++++++++----------------------- ufl/checks.py | 16 ++-------------- ufl/indexed.py | 10 +++++++++- 3 files changed, 26 insertions(+), 38 deletions(-) diff --git a/ufl/algorithms/analysis.py b/ufl/algorithms/analysis.py index 1d0464a95..4376adae9 100644 --- a/ufl/algorithms/analysis.py +++ b/ufl/algorithms/analysis.py @@ -69,33 +69,25 @@ def extract_type(a, ufl_types): objects = set() arg_types = tuple(t for t in ufl_types if issubclass(t, BaseArgument)) if arg_types: - objects.update([e for e in a.arguments() if isinstance(e, arg_types)]) + objects.update(e for e in a.arguments() if isinstance(e, arg_types)) coeff_types = tuple(t for t in ufl_types if issubclass(t, BaseCoefficient)) if coeff_types: - objects.update([e for e in a.coefficients() if isinstance(e, coeff_types)]) + objects.update(e for e in a.coefficients() if isinstance(e, coeff_types)) return objects if all(issubclass(t, Terminal) for t in ufl_types): # Optimization - objects = set( - o - for e in iter_expressions(a) - for o in traverse_unique_terminals(e) - if any(isinstance(o, t) for t in ufl_types) - ) + traversal = traverse_unique_terminals else: - objects = set( - o - for e in iter_expressions(a) - for o in unique_pre_traversal(e) - if any(isinstance(o, t) for t in ufl_types) - ) + traversal = unique_pre_traversal + + objects = set(o for e in iter_expressions(a) for o in traversal(e) if isinstance(o, ufl_types)) # Need to extract objects contained in base form operators whose # type is in ufl_types base_form_ops = set(e for e in objects if isinstance(e, BaseFormOperator)) ufl_types_no_args = tuple(t for t in ufl_types if not issubclass(t, BaseArgument)) - base_form_objects = () + base_form_objects = [] for o in base_form_ops: # This accounts for having BaseFormOperator in Forms: if N is a BaseFormOperator # `N(u; v*) * v * dx` <=> `action(v1 * v * dx, N(...; v*))` @@ -106,9 +98,9 @@ def extract_type(a, ufl_types): # argument of the Coargument and not its primal argument. if isinstance(ai, Coargument): new_types = tuple(Coargument if t is BaseArgument else t for t in ufl_types) - base_form_objects += tuple(extract_type(ai, new_types)) + base_form_objects.extend(extract_type(ai, new_types)) else: - base_form_objects += tuple(extract_type(ai, ufl_types)) + base_form_objects.extend(extract_type(ai, ufl_types)) # Look for BaseArguments in BaseFormOperator's argument slots # only since that's where they are by definition. Don't look # into operands, which is convenient for external operator @@ -116,7 +108,7 @@ def extract_type(a, ufl_types): # and not a form. slots = o.ufl_operands for ai in slots: - base_form_objects += tuple(extract_type(ai, ufl_types_no_args)) + base_form_objects.extend(extract_type(ai, ufl_types_no_args)) objects.update(base_form_objects) # `Remove BaseFormOperator` objects if there were initially not in `ufl_types` @@ -213,7 +205,7 @@ def extract_arguments_and_coefficients(a): coefficients = [f for f in base_coeff_and_args if isinstance(f, BaseCoefficient)] # Build number,part: instance mappings, should be one to one - bfnp = dict((f, (f.number(), f.part())) for f in arguments) + bfnp = {f: (f.number(), f.part()) for f in arguments} if len(bfnp) != len(set(bfnp.values())): raise ValueError( "Found different Arguments with same number and part.\n" @@ -222,7 +214,7 @@ def extract_arguments_and_coefficients(a): ) # Build count: instance mappings, should be one to one - fcounts = dict((f, f.count()) for f in coefficients) + fcounts = {f: f.count() for f in coefficients} if len(fcounts) != len(set(fcounts.values())): raise ValueError( "Found different coefficients with same counts.\n" @@ -249,10 +241,10 @@ def extract_unique_elements(form): def extract_sub_elements(elements): """Build sorted tuple of all sub elements (including parent element).""" - sub_elements = tuple(chain(*[e.sub_elements for e in elements])) + sub_elements = tuple(chain(*(e.sub_elements for e in elements))) if not sub_elements: return tuple(elements) - return tuple(elements) + extract_sub_elements(sub_elements) + return (*elements, *extract_sub_elements(sub_elements)) def sort_elements(elements): @@ -268,7 +260,7 @@ def sort_elements(elements): nodes = list(elements) # Set edges - edges = dict((node, []) for node in nodes) + edges = {node: [] for node in nodes} for element in elements: for sub_element in element.sub_elements: edges[element].append(sub_element) diff --git a/ufl/checks.py b/ufl/checks.py index b01b92e6b..09ae4453b 100644 --- a/ufl/checks.py +++ b/ufl/checks.py @@ -11,7 +11,7 @@ from ufl.core.expr import Expr from ufl.core.terminal import FormArgument from ufl.corealg.traversal import traverse_unique_terminals -from ufl.geometry import GeometricQuantity, SpatialCoordinate +from ufl.geometry import GeometricQuantity from ufl.sobolevspace import H1 @@ -34,19 +34,7 @@ def is_true_ufl_scalar(expression): def is_cellwise_constant(expr): """Return whether expression is constant over a single cell.""" - from ufl.coefficient import Coefficient - from ufl.differentiation import ReferenceGrad - - if isinstance(expr, ReferenceGrad): - (expr,) = expr.ufl_operands - if is_cellwise_constant(expr): - return True - elif isinstance(expr, SpatialCoordinate): - return expr.ufl_domain().is_piecewise_linear_simplex_domain() - elif isinstance(expr, Coefficient): - element = expr.ufl_element() - return element.embedded_superdegree <= 1 - + # TODO: Implement more accurately considering e.g. derivatives? return all(e.is_cellwise_constant() for e in traverse_unique_terminals(expr)) diff --git a/ufl/indexed.py b/ufl/indexed.py index 9f6525863..9505a92aa 100644 --- a/ufl/indexed.py +++ b/ufl/indexed.py @@ -26,13 +26,21 @@ class Indexed(Operator): def __new__(cls, expression, multiindex): """Create a new Indexed.""" + from ufl.tensors import ListTensor + + indices = multiindex._indices + while isinstance(expression, ListTensor) and isinstance(indices[0], FixedIndex): + # Simplify indexed ListTensor objects + expression = expression.ufl_operands[int(indices[0])] + indices = indices[1:] + if isinstance(expression, Zero): # Zero-simplify indexed Zero objects shape = expression.ufl_shape efi = expression.ufl_free_indices efid = expression.ufl_index_dimensions fi = list(zip(efi, efid)) - for pos, ind in enumerate(multiindex._indices): + for pos, ind in enumerate(indices): if isinstance(ind, Index): fi.append((ind.count(), shape[pos])) fi = unique_sorted_indices(sorted(fi))