Skip to content

Commit

Permalink
pyAMReX: Bind Particle Iterators (#87)
Browse files Browse the repository at this point in the history
* pyAMReX: Bind Particle Iterators

* First Tests
  • Loading branch information
ax3l authored Oct 20, 2022
1 parent 7eb82a9 commit 16802d8
Show file tree
Hide file tree
Showing 4 changed files with 213 additions and 27 deletions.
65 changes: 65 additions & 0 deletions src/Base/Iterator.H
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/* Copyright 2021-2022 The AMReX Community
*
* Authors: Axel Huebl
* License: BSD-3-Clause-LBNL
*/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <AMReX_Config.H>
#include <AMReX_BoxArray.H>
#include <AMReX_DistributionMapping.H>
#include <AMReX_FArrayBox.H>
#include <AMReX_FabArray.H>
#include <AMReX_FabArrayBase.H>
#include <AMReX_MultiFab.H>

#include <memory>
#include <string>

namespace py = pybind11;
using namespace amrex;

namespace pyAMReX
{
/** This is a helper function for the C++ equivalent of void operator++()
*
* In Python, iterators always are called with __next__, even for the
* first access. This means we need to handle the first iterator element
* explicitly, otherwise we will jump directly to the 2nd element. We do
* this the same way as pybind11 does this, via a little state:
* https://github.com/AMReX-Codes/pyamrex/pull/50
* https://github.com/pybind/pybind11/blob/v2.10.0/include/pybind11/pybind11.h#L2269-L2282
*
* To avoid unnecessary (and expensive) copies, remember to only call this
* helper always with py::return_value_policy::reference_internal!
*
*
* @tparam T_Iterator This is usally MFIter or Par(Const)Iter or derived classes
* @param it the current iterator
* @return the updated iterator
*/
template< typename T_Iterator >
T_Iterator &
iterator_next( T_Iterator & it )
{
py::object self = py::cast(it);
if (!py::hasattr(self, "first_or_done"))
self.attr("first_or_done") = true;

bool first_or_done = self.attr("first_or_done").cast<bool>();
if (first_or_done) {
first_or_done = false;
self.attr("first_or_done") = first_or_done;
}
else
++it;
if( !it.isValid() )
{
first_or_done = true;
it.Finalize();
throw py::stop_iteration();
}
return it;
}
}
23 changes: 3 additions & 20 deletions src/Base/MultiFab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "Base/Iterator.H"

#include <AMReX_Config.H>
#include <AMReX_BoxArray.H>
#include <AMReX_DistributionMapping.H>
Expand Down Expand Up @@ -79,26 +81,7 @@ void init_MultiFab(py::module &m) {

// eq. to void operator++()
.def("__next__",
[](MFIter & mfi) -> MFIter & {
py::object self = py::cast(mfi);
if (!py::hasattr(self, "first_or_done"))
self.attr("first_or_done") = true;

bool first_or_done = self.attr("first_or_done").cast<bool>();
if (first_or_done) {
first_or_done = false;
self.attr("first_or_done") = first_or_done;
}
else
++mfi;
if( !mfi.isValid() )
{
first_or_done = true;
mfi.Finalize();
throw py::stop_iteration();
}
return mfi;
},
&pyAMReX::iterator_next<MFIter>,
py::return_value_policy::reference_internal
)

Expand Down
100 changes: 93 additions & 7 deletions src/Particle/ParticleContainer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "Base/Iterator.H"

#include <AMReX_BoxArray.H>
#include <AMReX_IntVect.H>
#include <AMReX_ParIter.H>
#include <AMReX_Particles.H>
#include <AMReX_ParticleContainer.H>
#include <AMReX_ParticleTile.H>
Expand All @@ -19,9 +22,88 @@
namespace py = pybind11;
using namespace amrex;


template <bool is_const, typename T_ParIterBase>
void make_Base_Iterators (py::module &m)
{
using iterator_base = T_ParIterBase;
using container = typename iterator_base::ContainerType;
constexpr int NStructReal = container::NStructReal;
constexpr int NStructInt = container::NStructInt;
constexpr int NArrayReal = container::NArrayReal;
constexpr int NArrayInt = container::NArrayInt;

std::string particle_it_base_name = std::string("ParIterBase_").append(std::to_string(NStructReal) + "_" + std::to_string(NStructInt) + "_" + std::to_string(NArrayReal) + "_" + std::to_string(NArrayInt));
if (is_const) particle_it_base_name = "Const" + particle_it_base_name;
py::class_<iterator_base, MFIter>(m, particle_it_base_name.c_str())
.def(py::init<container&, int>(),
py::arg("particle_container"), py::arg("level"))
.def(py::init<container&, int, MFItInfo&>(),
py::arg("particle_container"), py::arg("level"), py::arg("info"))

.def("particle_tile", &iterator_base::GetParticleTile,
py::return_value_policy::reference_internal)
.def("aos", &iterator_base::GetArrayOfStructs,
py::return_value_policy::reference_internal)
.def("soa", &iterator_base::GetStructOfArrays,
py::return_value_policy::reference_internal)

.def_property_readonly("num_particles", &iterator_base::numParticles)
.def_property_readonly("num_real_particles", &iterator_base::numRealParticles)
.def_property_readonly("num_neighbor_particles", &iterator_base::numNeighborParticles)
.def_property_readonly("level", &iterator_base::GetLevel)
.def_property_readonly("pair_index", &iterator_base::GetPairIndex)
.def("geom", &iterator_base::Geom, py::arg("level"))

// eq. to void operator++()
.def("__next__",
&pyAMReX::iterator_next<iterator_base>,
py::return_value_policy::reference_internal
)
.def("__iter__",
[](iterator_base & it) -> iterator_base & {
return it;
},
py::return_value_policy::reference_internal
)
;
}

template <bool is_const, typename T_ParIter, template<class> class Allocator=DefaultAllocator>
void make_Iterators (py::module &m)
{
using iterator = T_ParIter;
using container = typename iterator::ContainerType;
constexpr int NStructReal = container::NStructReal;
constexpr int NStructInt = container::NStructInt;
constexpr int NArrayReal = container::NArrayReal;
constexpr int NArrayInt = container::NArrayInt;

using iterator_base = amrex::ParIterBase<is_const, NStructReal, NStructInt, NArrayReal, NArrayInt, Allocator>;
make_Base_Iterators< is_const, iterator_base >(m);

auto particle_it_name = std::string("Par");
if (is_const) particle_it_name += "Const";
particle_it_name += std::string("Iter_").append(std::to_string(NStructReal) + "_" + std::to_string(NStructInt) + "_" + std::to_string(NArrayReal) + "_" + std::to_string(NArrayInt));
py::class_<iterator, iterator_base>(m, particle_it_name.c_str())
.def("__repr__",
[particle_it_name](iterator const & pti) {
std::string r = "<amrex." + particle_it_name + " (";
if( !pti.isValid() ) { r.append("in"); }
r.append("valid)>");
return r;
}
)
.def(py::init<container&, int>(),
py::arg("particle_container"), py::arg("level"))
.def(py::init<container&, int, MFItInfo&>(),
py::arg("particle_container"), py::arg("level"), py::arg("info"))
;
}

template <int T_NStructReal, int T_NStructInt=0, int T_NArrayReal=0, int T_NArrayInt=0,
template<class> class Allocator=DefaultAllocator>
void make_ParticleContainer(py::module &m)
void make_ParticleContainer_and_Iterators (py::module &m)
{
using ParticleContainerType = ParticleContainer<
T_NStructReal, T_NStructInt, T_NArrayReal, T_NArrayInt,
Expand Down Expand Up @@ -237,17 +319,21 @@ void make_ParticleContainer(py::module &m)
// m_particles[lev][index].define(NumRuntimeRealComps(), NumRuntimeIntComps());
// return ParticlesAt(lev, iter);
// }

;

using iterator = amrex::ParIter<T_NStructReal, T_NStructInt, T_NArrayReal, T_NArrayInt, Allocator>;
make_Iterators< false, iterator, Allocator >(m);
using const_iterator = amrex::ParConstIter<T_NStructReal, T_NStructInt, T_NArrayReal, T_NArrayInt, Allocator>;
make_Iterators< true, const_iterator, Allocator >(m);
}


void init_ParticleContainer(py::module& m) {
// TODO: we might need to move all or most of the defines in here into a
// test/example submodule, so they do not collide with downstream projects
make_ParticleContainer< 1, 1, 2, 1> (m);
make_ParticleContainer< 0, 0, 4, 0> (m); // HiPACE++ 22.07
make_ParticleContainer< 0, 0, 5, 0> (m); // ImpactX 22.07
make_ParticleContainer< 0, 0, 7, 0> (m);
make_ParticleContainer< 0, 0, 37, 1> (m); // HiPACE++ 22.07
make_ParticleContainer_and_Iterators< 1, 1, 2, 1> (m);
make_ParticleContainer_and_Iterators< 0, 0, 4, 0> (m); // HiPACE++ 22.07
make_ParticleContainer_and_Iterators< 0, 0, 5, 0> (m); // ImpactX 22.07
make_ParticleContainer_and_Iterators< 0, 0, 7, 0> (m);
make_ParticleContainer_and_Iterators< 0, 0, 37, 1> (m); // HiPACE++ 22.07
}
52 changes: 52 additions & 0 deletions tests/test_particleContainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,58 @@ def test_pc_init():
assert pc.TotalNumberOfParticles() == pc.NumberOfParticlesAtLevel(0) == npart
assert pc.OK()

print("Iterate particle boxes & set values")
lvl = 0
for pti in amrex.ParIter_1_1_2_1(pc, level=lvl):
print("...")
assert pti.num_particles == 1
assert pti.num_real_particles == 1
assert pti.num_neighbor_particles == 0
assert pti.level == lvl
print(pti.pair_index)
print(pti.geom(level=lvl))

aos = pti.aos()
aos_arr = np.array(aos, copy=False)
aos_arr[0]["x"] = 0.30
aos_arr[0]["y"] = 0.35
aos_arr[0]["z"] = 0.40

# TODO: this seems to write into a copy of the data
soa = pti.soa()
real_arrays = soa.GetRealData()
int_arrays = soa.GetIntData()
real_arrays[0] = [0.55]
real_arrays[1] = [0.22]
int_arrays[0] = [2]

assert np.allclose(real_arrays[0], np.array([0.55]))
assert np.allclose(real_arrays[1], np.array([0.22]))
assert np.allclose(int_arrays[0], np.array([2]))

# read-only
for pti in amrex.ParConstIter_1_1_2_1(pc, level=lvl):
assert pti.num_particles == 1
assert pti.num_real_particles == 1
assert pti.num_neighbor_particles == 0
assert pti.level == lvl

aos = pti.aos()
aos_arr = np.array(aos, copy=False)
assert aos[0].x == 0.30
assert aos[0].y == 0.35
assert aos[0].z == 0.40
assert aos_arr[0]["z"] == 0.40

soa = pti.soa()
real_arrays = soa.GetRealData()
int_arrays = soa.GetIntData()
print(real_arrays[0])
# TODO: this does not work yet and is still the original data
# assert np.allclose(real_arrays[0], np.array([0.55]))
# assert np.allclose(real_arrays[1], np.array([0.22]))
# assert np.allclose(int_arrays[0], np.array([2]))


def test_particle_init(Npart, particle_container):
pc = particle_container
Expand Down

0 comments on commit 16802d8

Please sign in to comment.