diff --git a/src/Base/Iterator.H b/src/Base/Iterator.H new file mode 100644 index 00000000..f79fb892 --- /dev/null +++ b/src/Base/Iterator.H @@ -0,0 +1,65 @@ +/* Copyright 2021-2022 The AMReX Community + * + * Authors: Axel Huebl + * License: BSD-3-Clause-LBNL + */ +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace py = pybind11; +using namespace amrex; + +namespace pyAMReX +{ + /** This is a helper function for the C++ equivalent of void operator++() + * + * In Python, iterators always are called with __next__, even for the + * first access. This means we need to handle the first iterator element + * explicitly, otherwise we will jump directly to the 2nd element. We do + * this the same way as pybind11 does this, via a little state: + * https://github.com/AMReX-Codes/pyamrex/pull/50 + * https://github.com/pybind/pybind11/blob/v2.10.0/include/pybind11/pybind11.h#L2269-L2282 + * + * To avoid unnecessary (and expensive) copies, remember to only call this + * helper always with py::return_value_policy::reference_internal! + * + * + * @tparam T_Iterator This is usally MFIter or Par(Const)Iter or derived classes + * @param it the current iterator + * @return the updated iterator + */ + template< typename T_Iterator > + T_Iterator & + iterator_next( T_Iterator & it ) + { + py::object self = py::cast(it); + if (!py::hasattr(self, "first_or_done")) + self.attr("first_or_done") = true; + + bool first_or_done = self.attr("first_or_done").cast(); + if (first_or_done) { + first_or_done = false; + self.attr("first_or_done") = first_or_done; + } + else + ++it; + if( !it.isValid() ) + { + first_or_done = true; + it.Finalize(); + throw py::stop_iteration(); + } + return it; + } +} diff --git a/src/Base/MultiFab.cpp b/src/Base/MultiFab.cpp index 1dbb1798..efa09265 100644 --- a/src/Base/MultiFab.cpp +++ b/src/Base/MultiFab.cpp @@ -6,6 +6,8 @@ #include #include +#include "Base/Iterator.H" + #include #include #include @@ -79,26 +81,7 @@ void init_MultiFab(py::module &m) { // eq. to void operator++() .def("__next__", - [](MFIter & mfi) -> MFIter & { - py::object self = py::cast(mfi); - if (!py::hasattr(self, "first_or_done")) - self.attr("first_or_done") = true; - - bool first_or_done = self.attr("first_or_done").cast(); - if (first_or_done) { - first_or_done = false; - self.attr("first_or_done") = first_or_done; - } - else - ++mfi; - if( !mfi.isValid() ) - { - first_or_done = true; - mfi.Finalize(); - throw py::stop_iteration(); - } - return mfi; - }, + &pyAMReX::iterator_next, py::return_value_policy::reference_internal ) diff --git a/src/Particle/ParticleContainer.cpp b/src/Particle/ParticleContainer.cpp index 9d54c6fd..1bacb77b 100644 --- a/src/Particle/ParticleContainer.cpp +++ b/src/Particle/ParticleContainer.cpp @@ -6,8 +6,11 @@ #include #include +#include "Base/Iterator.H" + #include #include +#include #include #include #include @@ -19,9 +22,88 @@ namespace py = pybind11; using namespace amrex; + +template +void make_Base_Iterators (py::module &m) +{ + using iterator_base = T_ParIterBase; + using container = typename iterator_base::ContainerType; + constexpr int NStructReal = container::NStructReal; + constexpr int NStructInt = container::NStructInt; + constexpr int NArrayReal = container::NArrayReal; + constexpr int NArrayInt = container::NArrayInt; + + std::string particle_it_base_name = std::string("ParIterBase_").append(std::to_string(NStructReal) + "_" + std::to_string(NStructInt) + "_" + std::to_string(NArrayReal) + "_" + std::to_string(NArrayInt)); + if (is_const) particle_it_base_name = "Const" + particle_it_base_name; + py::class_(m, particle_it_base_name.c_str()) + .def(py::init(), + py::arg("particle_container"), py::arg("level")) + .def(py::init(), + py::arg("particle_container"), py::arg("level"), py::arg("info")) + + .def("particle_tile", &iterator_base::GetParticleTile, + py::return_value_policy::reference_internal) + .def("aos", &iterator_base::GetArrayOfStructs, + py::return_value_policy::reference_internal) + .def("soa", &iterator_base::GetStructOfArrays, + py::return_value_policy::reference_internal) + + .def_property_readonly("num_particles", &iterator_base::numParticles) + .def_property_readonly("num_real_particles", &iterator_base::numRealParticles) + .def_property_readonly("num_neighbor_particles", &iterator_base::numNeighborParticles) + .def_property_readonly("level", &iterator_base::GetLevel) + .def_property_readonly("pair_index", &iterator_base::GetPairIndex) + .def("geom", &iterator_base::Geom, py::arg("level")) + + // eq. to void operator++() + .def("__next__", + &pyAMReX::iterator_next, + py::return_value_policy::reference_internal + ) + .def("__iter__", + [](iterator_base & it) -> iterator_base & { + return it; + }, + py::return_value_policy::reference_internal + ) + ; +} + +template class Allocator=DefaultAllocator> +void make_Iterators (py::module &m) +{ + using iterator = T_ParIter; + using container = typename iterator::ContainerType; + constexpr int NStructReal = container::NStructReal; + constexpr int NStructInt = container::NStructInt; + constexpr int NArrayReal = container::NArrayReal; + constexpr int NArrayInt = container::NArrayInt; + + using iterator_base = amrex::ParIterBase; + make_Base_Iterators< is_const, iterator_base >(m); + + auto particle_it_name = std::string("Par"); + if (is_const) particle_it_name += "Const"; + particle_it_name += std::string("Iter_").append(std::to_string(NStructReal) + "_" + std::to_string(NStructInt) + "_" + std::to_string(NArrayReal) + "_" + std::to_string(NArrayInt)); + py::class_(m, particle_it_name.c_str()) + .def("__repr__", + [particle_it_name](iterator const & pti) { + std::string r = ""); + return r; + } + ) + .def(py::init(), + py::arg("particle_container"), py::arg("level")) + .def(py::init(), + py::arg("particle_container"), py::arg("level"), py::arg("info")) + ; +} + template class Allocator=DefaultAllocator> -void make_ParticleContainer(py::module &m) +void make_ParticleContainer_and_Iterators (py::module &m) { using ParticleContainerType = ParticleContainer< T_NStructReal, T_NStructInt, T_NArrayReal, T_NArrayInt, @@ -237,17 +319,21 @@ void make_ParticleContainer(py::module &m) // m_particles[lev][index].define(NumRuntimeRealComps(), NumRuntimeIntComps()); // return ParticlesAt(lev, iter); // } - ; + + using iterator = amrex::ParIter; + make_Iterators< false, iterator, Allocator >(m); + using const_iterator = amrex::ParConstIter; + make_Iterators< true, const_iterator, Allocator >(m); } void init_ParticleContainer(py::module& m) { // TODO: we might need to move all or most of the defines in here into a // test/example submodule, so they do not collide with downstream projects - make_ParticleContainer< 1, 1, 2, 1> (m); - make_ParticleContainer< 0, 0, 4, 0> (m); // HiPACE++ 22.07 - make_ParticleContainer< 0, 0, 5, 0> (m); // ImpactX 22.07 - make_ParticleContainer< 0, 0, 7, 0> (m); - make_ParticleContainer< 0, 0, 37, 1> (m); // HiPACE++ 22.07 + make_ParticleContainer_and_Iterators< 1, 1, 2, 1> (m); + make_ParticleContainer_and_Iterators< 0, 0, 4, 0> (m); // HiPACE++ 22.07 + make_ParticleContainer_and_Iterators< 0, 0, 5, 0> (m); // ImpactX 22.07 + make_ParticleContainer_and_Iterators< 0, 0, 7, 0> (m); + make_ParticleContainer_and_Iterators< 0, 0, 37, 1> (m); // HiPACE++ 22.07 } diff --git a/tests/test_particleContainer.py b/tests/test_particleContainer.py index 7d82af66..bdd7d6cd 100644 --- a/tests/test_particleContainer.py +++ b/tests/test_particleContainer.py @@ -121,6 +121,58 @@ def test_pc_init(): assert pc.TotalNumberOfParticles() == pc.NumberOfParticlesAtLevel(0) == npart assert pc.OK() + print("Iterate particle boxes & set values") + lvl = 0 + for pti in amrex.ParIter_1_1_2_1(pc, level=lvl): + print("...") + assert pti.num_particles == 1 + assert pti.num_real_particles == 1 + assert pti.num_neighbor_particles == 0 + assert pti.level == lvl + print(pti.pair_index) + print(pti.geom(level=lvl)) + + aos = pti.aos() + aos_arr = np.array(aos, copy=False) + aos_arr[0]["x"] = 0.30 + aos_arr[0]["y"] = 0.35 + aos_arr[0]["z"] = 0.40 + + # TODO: this seems to write into a copy of the data + soa = pti.soa() + real_arrays = soa.GetRealData() + int_arrays = soa.GetIntData() + real_arrays[0] = [0.55] + real_arrays[1] = [0.22] + int_arrays[0] = [2] + + assert np.allclose(real_arrays[0], np.array([0.55])) + assert np.allclose(real_arrays[1], np.array([0.22])) + assert np.allclose(int_arrays[0], np.array([2])) + + # read-only + for pti in amrex.ParConstIter_1_1_2_1(pc, level=lvl): + assert pti.num_particles == 1 + assert pti.num_real_particles == 1 + assert pti.num_neighbor_particles == 0 + assert pti.level == lvl + + aos = pti.aos() + aos_arr = np.array(aos, copy=False) + assert aos[0].x == 0.30 + assert aos[0].y == 0.35 + assert aos[0].z == 0.40 + assert aos_arr[0]["z"] == 0.40 + + soa = pti.soa() + real_arrays = soa.GetRealData() + int_arrays = soa.GetIntData() + print(real_arrays[0]) + # TODO: this does not work yet and is still the original data + # assert np.allclose(real_arrays[0], np.array([0.55])) + # assert np.allclose(real_arrays[1], np.array([0.22])) + # assert np.allclose(int_arrays[0], np.array([2])) + def test_particle_init(Npart, particle_container): pc = particle_container