Skip to content

Commit

Permalink

Browse files — browse the repository at this point in the history
  • Loading branch information
ksagiyam committed Dec 5, 2024
1 parent 95bf46f commit 0a67bc6
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 10 deletions.
3 changes: 2 additions & 1 deletion firedrake/functionspaceimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,7 +910,8 @@ def set_shared_data(self):
self.node_set = sdata.node_set
r"""A :class:`pyop2.types.set.Set` representing the function space nodes."""
self.dof_dset = op2.DataSet(self.node_set, self.shape or 1,
name="%s_nodes_dset" % self.name)
name="%s_nodes_dset" % self.name,
apply_local_global_filter=sdata.extruded)
r"""A :class:`pyop2.types.dataset.DataSet` representing the function space
degrees of freedom."""

Expand Down
29 changes: 28 additions & 1 deletion pyop2/types/dat.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,12 +774,39 @@ def _vec(self):
data = self._data[:size[0]]
return PETSc.Vec().createWithArray(data, size=size, bsize=self.cdim, comm=self.comm)

@utils.cached_property
def _data_filtered(self):
    """Scratch buffer matching the locally-owned slice of ``_data``.

    Allocated once (cached) and used as the backing array for
    :attr:`_vec_filtered`; entries are copied in and out by
    :meth:`vec_context` when the data set applies a local/global filter.
    """
    local_size, _ = self.dataset.layout_vec.getSizes()
    return np.empty_like(self._data[:local_size])

@utils.cached_property
def _data_filter(self):
    """Boolean mask over the owned DoFs.

    True where the local-to-global map assigns a real global index;
    negative lgmap entries mark DoFs that are filtered out of the
    global (PETSc Vec) view.
    """
    owned_indices = self.dataset.lgmap.indices[:self.dataset.size]
    return owned_indices >= 0

@utils.cached_property
def _vec_filtered(self):
    # A PETSc Vec wrapping ``_data_filtered`` (the owned, filtered DoFs only).
    # Cached: the same Vec object is reused across vec_context invocations.
    assert self.dtype == PETSc.ScalarType, \
        "Can't create Vec with type %s, must be %s" % (self.dtype, PETSc.ScalarType)
    # Sizes taken from the layout vec so local/global dimensions agree with PETSc.
    size = self.dataset.layout_vec.getSizes()
    return PETSc.Vec().createWithArray(self._data_filtered, size=size, bsize=self.cdim, comm=self.comm)

@contextlib.contextmanager
def vec_context(self, access):
    r"""A context manager for a :class:`PETSc.Vec` from a :class:`Dat`.

    :param access: Access descriptor: READ, WRITE, or RW.
    """
    # NOTE: the scraped diff left the pre-change ``yield self._vec`` line in
    # place before the filtered/unfiltered branch; it is removed here so the
    # generator yields exactly once, as a contextmanager must.
    size = self.dataset.size
    if self.dataset._apply_local_global_filter:
        # Copy the owned entries that survive the local/global filter into
        # the dedicated buffer, and expose that buffer as the Vec.
        self._data_filtered[:] = self._data[:size][self._data_filter]
        yield self._vec_filtered
    else:
        yield self._vec
    if self.dataset._apply_local_global_filter:
        # Propagate any Vec-side modifications back into the full array.
        self._data[:size][self._data_filter] = self._data_filtered[:]
    if access is not Access.READ:
        # The Vec may have been written to, so cached halo data is stale.
        self.halo_valid = False

Expand Down
16 changes: 9 additions & 7 deletions pyop2/types/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ class DataSet(caching.ObjectCached):

@utils.validate_type(('iter_set', Set, ex.SetTypeError),
('dim', (numbers.Integral, tuple, list), ex.DimTypeError),
('name', str, ex.NameTypeError))
def __init__(self, iter_set, dim=1, name=None):
('name', str, ex.NameTypeError),
('apply_local_global_filter', bool, ex.DataTypeError))
def __init__(self, iter_set, dim=1, name=None, apply_local_global_filter=False):
if isinstance(iter_set, ExtrudedSet):
raise NotImplementedError("Not allowed!")
if self._initialized:
Expand All @@ -35,18 +36,19 @@ def __init__(self, iter_set, dim=1, name=None):
self._cdim = np.prod(self._dim).item()
self._name = name or "dset_#x%x" % id(self)
self._initialized = True
self._apply_local_global_filter = apply_local_global_filter

@classmethod
def _process_args(cls, *args, **kwargs):
    """Caching-protocol hook: prepend the iteration set to the args tuple."""
    iter_set = args[0]
    return (iter_set, *args), kwargs

@classmethod
def _cache_key(cls, iter_set, dim=1, name=None, apply_local_global_filter=False):
    """Key used to cache :class:`DataSet` instances.

    Only the iteration set and the (normalised) dim participate in the key.
    NOTE(review): ``apply_local_global_filter`` is deliberately absent from
    the key, so two DataSets differing only in that flag share one cached
    instance — confirm this is intended, since ``__init__`` returns early
    for already-initialised cached objects.
    """
    # The scraped diff retained the old signature line alongside the new
    # one; only the post-change signature is kept here.
    return (iter_set, utils.as_tuple(dim, numbers.Integral))

@utils.cached_property
def _wrapper_cache_key_(self):
    """Key identifying this DataSet for generated-code (wrapper) caching.

    Includes ``_apply_local_global_filter`` so that filtered and
    unfiltered data sets never share generated wrappers.
    """
    # The scraped diff kept both the old and new return statements; the
    # old (filter-less) one made the new return unreachable and is dropped.
    return (type(self), self.dim, self._set._wrapper_cache_key_, self._apply_local_global_filter)

def __getstate__(self):
"""Extract state to pickle."""
Expand Down Expand Up @@ -97,11 +99,11 @@ def __len__(self):
return 1

def __str__(self):
    """Human-readable description, including the local/global filter flag."""
    # The scraped diff left the pre-change return statement first, which
    # made the updated (flag-carrying) return unreachable; only the
    # post-change version is kept.
    return "OP2 DataSet: %s on set %s, with dim %s, %s" % \
        (self._name, self._set, self._dim, self._apply_local_global_filter)

def __repr__(self):
    """Unambiguous representation mirroring the constructor arguments."""
    # The scraped diff kept the old three-argument repr ahead of the new
    # one, leaving the new return dead; only the post-change version stays.
    return "DataSet(%r, %r, %r, %r)" % (self._set, self._dim, self._name, self._apply_local_global_filter)

def __contains__(self, dat):
"""Indicate whether a given Dat is compatible with this DataSet."""
Expand Down
48 changes: 47 additions & 1 deletion tests/firedrake/regression/test_restricted_function_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,42 @@ def test_restricted_mixed_spaces(i, j):

@pytest.mark.parallel(nprocs=2)
def test_restricted_function_space_extrusion():
#
# rank 0 rank 1
#
# plex points:
#
# +-------+-------+ +-------+-------+
# | | | | | |
# | | | | | |
# | | | | | |
# +-------+-------+ +-------+-------+
# 2 0 (3) (1) (4) (4) (1) 2 0 3 () = ghost
#
# mesh._dm_renumbering:
#
# [0, 2, 3, 1, 4] [0, 3, 2, 1, 4]
#
# Local DoFs:
#
# 5---2--(8)(11)(14) (14)(11)--8---2---5
# | | | | | |
# 4 1 (7)(10)(13) (13)(10) 7 1 4
# | | | | | |
# 3---0--(6)-(9)(12) (12)-(9)--6---0---3 () = ghost
#
# Global DoFs:
#
# 3---1---9---5---7
# | | |
# 2 0 8 4 6
# | | |
# x---x---x---x---x
#
# LGMap:
#
# rank 0 : [-1, 0, 1, -1, 2, 3, -1, 8, 9, -1, 4, 5, -1, 6, 7]
# rank 1 : [-1, 4, 5, -1, 6, 7, -1, 8, 9, -1, 0, 1, -1, 2, 3]
mesh = UnitIntervalMesh(2)
extm = ExtrudedMesh(mesh, 1)
V = FunctionSpace(extm, "CG", 2)
Expand All @@ -223,4 +259,14 @@ def test_restricted_function_space_extrusion():
lgmap_expected = [-1, 0, 1, -1, 2, 3, -1, 8, 9, -1, 4, 5, -1, 6, 7]
else:
lgmap_expected = [-1, 4, 5, -1, 6, 7, -1, 8, 9, -1, 0, 1, -1, 2, 3]
assert np.allclose(lgmap, lgmap_expected)
assert np.allclose(lgmap.indices, lgmap_expected)
f = Function(V_res)
n = V_res.dof_dset.size
lgmap_owned = lgmap.indices[:n]
local_global_filter = lgmap_owned >= 0
local_array = 1.0 * np.arange(V_res.dof_dset.total_size)
f.dat.data_wo_with_halos[:] = local_array
with f.dat.vec as v:
assert np.allclose(v.getArray(), local_array[:n][local_global_filter])
v *= 2.
assert np.allclose(f.dat.data_ro_with_halos[:n][local_global_filter], 2. * local_array[:n][local_global_filter])

0 comments on commit 0a67bc6

Please sign in to comment.