Skip to content

Commit

Permalink
Fix test for mixed-ness in Function.sub, and fix Cofunction.sub to ma…
Browse files Browse the repository at this point in the history
…tch (#3961)

* fix test for mixed-ness in Function.sub, and fix Cofunction.sub to match

* allow FunctionSpace.sub to take negative indices

* cofunc docstring
  • Loading branch information
JHopeCollins authored Jan 9, 2025
1 parent 39742a6 commit 4109283
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 23 deletions.
28 changes: 22 additions & 6 deletions firedrake/cofunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ufl.form import BaseForm
from pyop2 import op2, mpi
from pyadjoint.tape import stop_annotating, annotate_tape, get_working_tape
from finat.ufl import MixedElement
import firedrake.assemble
import firedrake.functionspaceimpl as functionspaceimpl
from firedrake import utils, vector, ufl_expr
Expand Down Expand Up @@ -119,11 +120,14 @@ def split(self):

@utils.cached_property
def _components(self):
if self.function_space().value_size == 1:
if self.function_space().rank == 0:
return (self, )
else:
return tuple(type(self)(self.function_space().sub(i), val=op2.DatView(self.dat, i))
for i in range(self.function_space().value_size))
if self.dof_dset.cdim == 1:
return (type(self)(self.function_space().sub(0), val=self.dat),)
else:
return tuple(type(self)(self.function_space().sub(i), val=op2.DatView(self.dat, j))
for i, j in enumerate(np.ndindex(self.dof_dset.dim)))

@PETSc.Log.EventDecorator()
def sub(self, i):
Expand All @@ -137,9 +141,9 @@ def sub(self, i):
:func:`~.VectorFunctionSpace` or :func:`~.TensorFunctionSpace`
this returns a proxy object indexing the ith component of the space,
suitable for use in boundary condition application."""
if len(self.function_space()) == 1:
return self._components[i]
return self.subfunctions[i]
mixed = type(self.function_space().ufl_element()) is MixedElement
data = self.subfunctions if mixed else self._components
return data[i]

def function_space(self):
r"""Return the :class:`.FunctionSpace`, or :class:`.MixedFunctionSpace`
Expand Down Expand Up @@ -321,6 +325,12 @@ def vector(self):
:class:`Cofunction`"""
return vector.Vector(self)

@property
def cell_set(self):
r"""The :class:`pyop2.types.set.Set` of cells for the mesh on which this
:class:`Cofunction` is defined."""
return self.function_space()._mesh.cell_set

@property
def node_set(self):
r"""A :class:`pyop2.types.set.Set` containing the nodes of this
Expand All @@ -330,6 +340,12 @@ def node_set(self):
"""
return self.function_space().node_set

@property
def dof_dset(self):
r"""A :class:`pyop2.types.dataset.DataSet` containing the degrees of freedom of
this :class:`Cofunction`."""
return self.function_space().dof_dset

def ufl_id(self):
return self.uid

Expand Down
13 changes: 4 additions & 9 deletions firedrake/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pyop2 import op2, mpi
from pyop2.exceptions import DataTypeError, DataValueError

from finat.ufl import MixedElement
from firedrake.utils import ScalarType, IntType, as_ctypes

from firedrake import functionspaceimpl
Expand Down Expand Up @@ -147,11 +148,8 @@ def sub(self, i):
rank-n :class:`~.FunctionSpace`, this returns a proxy object
indexing the ith component of the space, suitable for use in
boundary condition application."""
mixed = len(self.function_space()) != 1
mixed = type(self.function_space().ufl_element()) is MixedElement
data = self.subfunctions if mixed else self._components
bound = len(data)
if i < 0 or i >= bound:
raise IndexError(f"Invalid component {i}, not in [0, {bound})")
return data[i]

@property
Expand Down Expand Up @@ -352,11 +350,8 @@ def sub(self, i):
:func:`~.VectorFunctionSpace` or :func:`~.TensorFunctionSpace` this returns a proxy object
indexing the ith component of the space, suitable for use in
boundary condition application."""
mixed = len(self.function_space()) != 1
mixed = type(self.function_space().ufl_element()) is MixedElement
data = self.subfunctions if mixed else self._components
bound = len(data)
if i < 0 or i >= bound:
raise IndexError(f"Invalid component {i}, not in [0, {bound})")
return data[i]

@PETSc.Log.EventDecorator()
Expand Down Expand Up @@ -672,7 +667,7 @@ def single_eval(x, buf):
value_shape = self.ufl_shape

subfunctions = self.subfunctions
mixed = len(subfunctions) != 1
mixed = type(self.function_space().ufl_element()) is MixedElement

# Local evaluation
l_result = []
Expand Down
8 changes: 1 addition & 7 deletions firedrake/functionspaceimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,8 @@ def _components(self):

@PETSc.Log.EventDecorator()
def sub(self, i):
mixed = len(self) != 1
mixed = type(self.ufl_element()) is finat.ufl.MixedElement
data = self.subfunctions if mixed else self._components
bound = len(data)
if i < 0 or i >= bound:
raise IndexError(f"Invalid component {i}, not in [0, {bound})")
return data[i]

@utils.cached_property
Expand Down Expand Up @@ -664,9 +661,6 @@ def _components(self):

def sub(self, i):
r"""Return a view into the ith component."""
bound = len(self._components)
if i < 0 or i >= bound:
raise IndexError(f"Invalid component {i}, not in [0, {bound})")
return self._components[i]

def __mul__(self, other):
Expand Down
2 changes: 1 addition & 1 deletion tests/firedrake/regression/test_vfs_component_bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def test_cant_integrate_subscripted_VFS(V):


@pytest.mark.parametrize("cmpt",
[-1, 2])
[-3, 2])
def test_cant_subscript_outside_components(V, cmpt):
with pytest.raises(IndexError):
return V.sub(cmpt)
Expand Down

0 comments on commit 4109283

Please sign in to comment.