From 0a67bc69ad9f8ef7e0451bbb8c1a91baa893051b Mon Sep 17 00:00:00 2001 From: ksagiyam Date: Thu, 5 Dec 2024 00:02:44 +0000 Subject: [PATCH] k --- firedrake/functionspaceimpl.py | 3 +- pyop2/types/dat.py | 29 ++++++++++- pyop2/types/dataset.py | 16 ++++--- .../test_restricted_function_space.py | 48 ++++++++++++++++++- 4 files changed, 86 insertions(+), 10 deletions(-) diff --git a/firedrake/functionspaceimpl.py b/firedrake/functionspaceimpl.py index 28e56515e1..6ee00ea610 100644 --- a/firedrake/functionspaceimpl.py +++ b/firedrake/functionspaceimpl.py @@ -910,7 +910,8 @@ def set_shared_data(self): self.node_set = sdata.node_set r"""A :class:`pyop2.types.set.Set` representing the function space nodes.""" self.dof_dset = op2.DataSet(self.node_set, self.shape or 1, - name="%s_nodes_dset" % self.name) + name="%s_nodes_dset" % self.name, + apply_local_global_filter=sdata.extruded) r"""A :class:`pyop2.types.dataset.DataSet` representing the function space degrees of freedom.""" diff --git a/pyop2/types/dat.py b/pyop2/types/dat.py index fb877c1a88..fca72d8fca 100644 --- a/pyop2/types/dat.py +++ b/pyop2/types/dat.py @@ -774,12 +774,39 @@ def _vec(self): data = self._data[:size[0]] return PETSc.Vec().createWithArray(data, size=size, bsize=self.cdim, comm=self.comm) + @utils.cached_property + def _data_filtered(self): + size = self.dataset.layout_vec.getSizes() + data = self._data[:size[0]] + return np.empty_like(data) + + @utils.cached_property + def _data_filter(self): + lgmap = self.dataset.lgmap + n = self.dataset.size + lgmap_owned = lgmap.indices[:n] + return lgmap_owned >= 0 + + @utils.cached_property + def _vec_filtered(self): + assert self.dtype == PETSc.ScalarType, \ + "Can't create Vec with type %s, must be %s" % (self.dtype, PETSc.ScalarType) + size = self.dataset.layout_vec.getSizes() + return PETSc.Vec().createWithArray(self._data_filtered, size=size, bsize=self.cdim, comm=self.comm) + @contextlib.contextmanager def vec_context(self, access): r"""A context manager for a :class:`PETSc.Vec` from a :class:`Dat`. :param access: Access descriptor: READ, WRITE, or RW.""" - yield self._vec + size = self.dataset.size + if self.dataset._apply_local_global_filter: + self._data_filtered[:] = self._data[:size][self._data_filter] + yield self._vec_filtered + else: + yield self._vec + if self.dataset._apply_local_global_filter: + self._data[:size][self._data_filter] = self._data_filtered[:] if access is not Access.READ: self.halo_valid = False diff --git a/pyop2/types/dataset.py b/pyop2/types/dataset.py index 3b4f4bfd8a..a2aa5f98ee 100644 --- a/pyop2/types/dataset.py +++ b/pyop2/types/dataset.py @@ -21,8 +21,9 @@ class DataSet(caching.ObjectCached): @utils.validate_type(('iter_set', Set, ex.SetTypeError), ('dim', (numbers.Integral, tuple, list), ex.DimTypeError), - ('name', str, ex.NameTypeError)) - def __init__(self, iter_set, dim=1, name=None): + ('name', str, ex.NameTypeError), + ('apply_local_global_filter', bool, ex.DataTypeError)) + def __init__(self, iter_set, dim=1, name=None, apply_local_global_filter=False): if isinstance(iter_set, ExtrudedSet): raise NotImplementedError("Not allowed!") if self._initialized: @@ -35,18 +36,19 @@ def __init__(self, iter_set, dim=1, name=None): self._cdim = np.prod(self._dim).item() self._name = name or "dset_#x%x" % id(self) self._initialized = True + self._apply_local_global_filter = apply_local_global_filter @classmethod def _process_args(cls, *args, **kwargs): return (args[0], ) + args, kwargs @classmethod - def _cache_key(cls, iter_set, dim=1, name=None): + def _cache_key(cls, iter_set, dim=1, name=None, apply_local_global_filter=False): return (iter_set, utils.as_tuple(dim, numbers.Integral)) @utils.cached_property def _wrapper_cache_key_(self): - return (type(self), self.dim, self._set._wrapper_cache_key_) + return (type(self), self.dim, self._set._wrapper_cache_key_, self._apply_local_global_filter) def __getstate__(self): """Extract state to pickle.""" @@ -97,11 +99,11 @@ def __len__(self): return 1 def __str__(self): - return "OP2 DataSet: %s on set %s, with dim %s" % \ - (self._name, self._set, self._dim) + return "OP2 DataSet: %s on set %s, with dim %s, %s" % \ + (self._name, self._set, self._dim, self._apply_local_global_filter) def __repr__(self): - return "DataSet(%r, %r, %r)" % (self._set, self._dim, self._name) + return "DataSet(%r, %r, %r, %r)" % (self._set, self._dim, self._name, self._apply_local_global_filter) def __contains__(self, dat): """Indicate whether a given Dat is compatible with this DataSet.""" diff --git a/tests/firedrake/regression/test_restricted_function_space.py b/tests/firedrake/regression/test_restricted_function_space.py index 190e597ab2..20ff838ea9 100644 --- a/tests/firedrake/regression/test_restricted_function_space.py +++ b/tests/firedrake/regression/test_restricted_function_space.py @@ -209,6 +209,42 @@ def test_restricted_mixed_spaces(i, j): @pytest.mark.parallel(nprocs=2) def test_restricted_function_space_extrusion(): + # + # rank 0 rank 1 + # + # plex points: + # + # +-------+-------+ +-------+-------+ + # | | | | | | + # | | | | | | + # | | | | | | + # +-------+-------+ +-------+-------+ + # 2 0 (3) (1) (4) (4) (1) 2 0 3 () = ghost + # + # mesh._dm_renumbering: + # + # [0, 2, 3, 1, 4] [0, 3, 2, 1, 4] + # + # Local DoFs: + # + # 5---2--(8)(11)(14) (14)(11)--8---2---5 + # | | | | | | + # 4 1 (7)(10)(13) (13)(10) 7 1 4 + # | | | | | | + # 3---0--(6)-(9)(12) (12)-(9)--6---0---3 () = ghost + # + # Global DoFs: + # + # 3---1---9---5---7 + # | | | + # 2 0 8 4 6 + # | | | + # x---x---x---x---x + # + # LGMap: + # + # rank 0 : [-1, 0, 1, -1, 2, 3, -1, 8, 9, -1, 4, 5, -1, 6, 7] + # rank 1 : [-1, 4, 5, -1, 6, 7, -1, 8, 9, -1, 0, 1, -1, 2, 3] mesh = UnitIntervalMesh(2) extm = ExtrudedMesh(mesh, 1) V = FunctionSpace(extm, "CG", 2) @@ -223,4 +259,14 @@ def test_restricted_function_space_extrusion(): lgmap_expected = [-1, 0, 1, -1, 2, 3, -1, 8, 9, -1, 4, 5, -1, 6, 7] else: lgmap_expected = [-1, 4, 5, -1, 6, 7, -1, 8, 9, -1, 0, 1, -1, 2, 3] - assert np.allclose(lgmap, lgmap_expected) + assert np.allclose(lgmap.indices, lgmap_expected) + f = Function(V_res) + n = V_res.dof_dset.size + lgmap_owned = lgmap.indices[:n] + local_global_filter = lgmap_owned >= 0 + local_array = 1.0 * np.arange(V_res.dof_dset.total_size) + f.dat.data_wo_with_halos[:] = local_array + with f.dat.vec as v: + assert np.allclose(v.getArray(), local_array[:n][local_global_filter]) + v *= 2. + assert np.allclose(f.dat.data_ro_with_halos[:n][local_global_filter], 2. * local_array[:n][local_global_filter])