Skip to content

Commit

Permalink
Indexed: handle initialisation outside of __init__
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jan 15, 2025
1 parent cd1d21b commit 55cc653
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 4 deletions.
19 changes: 19 additions & 0 deletions test/test_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
triangle,
)
from ufl.algorithms import compute_form_data
from ufl.core.multiindex import FixedIndex, MultiIndex
from ufl.finiteelement import FiniteElement
from ufl.indexed import Indexed
from ufl.pullback import identity_pullback
from ufl.sobolevspace import H1

Expand Down Expand Up @@ -174,3 +176,20 @@ def test_tensor_from_indexed(self, shape):
space = FunctionSpace(domain, element)
f = Coefficient(space)
assert as_tensor(reshape([f[i] for i in ndindex(f.ufl_shape)], f.ufl_shape).tolist()) is f


def test_nested_indexed(self):
# Test that a nested Indexed expression simplifies to the existing Indexed object
shape = (2,)
element = FiniteElement("Lagrange", triangle, 1, shape, identity_pullback, H1)
domain = Mesh(FiniteElement("Lagrange", triangle, 1, (2,), identity_pullback, H1))
space = FunctionSpace(domain, element)
f = Coefficient(space)

comps = tuple(f[i] for i in range(2))
assert all(isinstance(c, Indexed) for c in comps)
expr = as_tensor(list(reversed(comps)))

multiindex = MultiIndex((FixedIndex(0),))
assert Indexed(expr, multiindex) is expr[0]
assert Indexed(expr, multiindex) is comps[1]
14 changes: 10 additions & 4 deletions ufl/indexed.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,14 @@ def __new__(cls, expression, multiindex):

try:
# Simplify indexed ListTensor
c = expression[multiindex]
return Indexed(*c.ufl_operands) if isinstance(c, Indexed) else c
return expression[multiindex]
except ValueError:
return Operator.__new__(cls)
# Construct and initialize a new Indexed object
self = Operator.__new__(cls)
self._init(expression, multiindex)
return self

def __init__(self, expression, multiindex):
def _init(self, expression, multiindex):
"""Initialise."""
# Store operands
Operator.__init__(self, (expression, multiindex))
Expand Down Expand Up @@ -94,6 +96,10 @@ def __init__(self, expression, multiindex):
self.ufl_free_indices = fi
self.ufl_index_dimensions = fid

def __init__(self, expression, multiindex):
"""Initialise."""
Operator.__init__(self)

ufl_shape = ()

def evaluate(self, x, mapping, component, index_values, derivatives=()):
Expand Down

0 comments on commit 55cc653

Please sign in to comment.