Skip to content

Commit

Permalink
tbb for kdtree
Browse files Browse the repository at this point in the history
  • Loading branch information
koide3 committed Sep 22, 2024
1 parent 178defd commit de843d2
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 8 deletions.
12 changes: 11 additions & 1 deletion include/gtsam_points/ann/kdtree2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <gtsam_points/ann/small_kdtree.hpp>
#include <gtsam_points/types/frame_traits.hpp>
#include <gtsam_points/ann/nearest_neighbor_search.hpp>
#include <gtsam_points/util/parallelism.hpp>

namespace gtsam_points {

Expand All @@ -24,7 +25,16 @@ struct KdTree2 : public NearestNeighborSearch {
KdTree2(const std::shared_ptr<const Frame>& frame, int build_num_threads = 1)
: frame(frame),
search_eps(-1.0),
index(new Index(*this->frame, KdTreeBuilderOMP(build_num_threads))) {
index(
is_omp_default() || build_num_threads == 1 ? //
new Index(*this->frame, KdTreeBuilderOMP(build_num_threads)) //
: //
#ifdef GTSAM_POINTS_TBB //
new Index(*this->frame, KdTreeBuilderTBB()) //
#else //
new Index(*this->frame, KdTreeBuilder())
#endif
) {
if (frame::size(*frame) == 0) {
std::cerr << "error: empty frame is given for KdTree2" << std::endl;
std::cerr << " : frame::size() may not be implemented" << std::endl;
Expand Down
75 changes: 75 additions & 0 deletions include/gtsam_points/ann/small_kdtree.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@
#include <gtsam_points/ann/knn_result.hpp>
#include <gtsam_points/types/frame_traits.hpp>

#ifdef GTSAM_POINTS_TBB
#include <tbb/parallel_invoke.h>
#endif

namespace gtsam_points {

/// @brief Parameters to control the projection axis search.
Expand Down Expand Up @@ -266,6 +270,77 @@ struct KdTreeBuilderOMP {
ProjectionSetting projection_setting; ///< Projection setting.
};

#ifdef GTSAM_POINTS_TBB
/// @brief Kd-tree builder with TBB.
struct KdTreeBuilderTBB {
public:
/// @brief Build KdTree
template <typename KdTree, typename PointCloud>
void build_tree(KdTree& kdtree, const PointCloud& points) const {
kdtree.indices.resize(frame::size(points));
std::iota(kdtree.indices.begin(), kdtree.indices.end(), 0);

std::atomic_uint64_t node_count = 0;
kdtree.nodes.resize(frame::size(points));
kdtree.root = create_node(kdtree, node_count, points, kdtree.indices.begin(), kdtree.indices.begin(), kdtree.indices.end());
kdtree.nodes.resize(node_count);
}

/// @brief Create a Kd-tree node from the given point indices.
/// @param global_first Global first point index iterator (i.e., this->indices.begin()).
/// @param first First point index iterator to be scanned.
/// @param last Last point index iterator to be scanned.
/// @return Index of the created node.
template <typename PointCloud, typename KdTree, typename IndexConstIterator>
NodeIndexType create_node(
KdTree& kdtree,
std::atomic_uint64_t& node_count,
const PointCloud& points,
IndexConstIterator global_first,
IndexConstIterator first,
IndexConstIterator last) const {
const size_t N = std::distance(first, last);
const NodeIndexType node_index = node_count++;
auto& node = kdtree.nodes[node_index];

// Create a leaf node.
if (N <= max_leaf_size) {
// std::sort(first, last);
node.node_type.lr.first = std::distance(global_first, first);
node.node_type.lr.last = std::distance(global_first, last);

return node_index;
}

// Find the best axis to split the input points.
using Projection = typename KdTree::Projection;
const auto proj = Projection::find_axis(points, first, last, projection_setting);
const auto median_itr = first + N / 2;
std::nth_element(first, median_itr, last, [&](size_t i, size_t j) { return proj(frame::point(points, i)) < proj(frame::point(points, j)); });

// Create a non-leaf node.
node.node_type.sub.proj = proj;
node.node_type.sub.thresh = proj(frame::point(points, *median_itr));

// Create left and right child nodes.
if (N > 512) {
tbb::parallel_invoke(
[&] { node.left = create_node(kdtree, node_count, points, global_first, first, median_itr); },
[&] { node.right = create_node(kdtree, node_count, points, global_first, median_itr, last); });
} else {
node.left = create_node(kdtree, node_count, points, global_first, first, median_itr);
node.right = create_node(kdtree, node_count, points, global_first, median_itr, last);
}

return node_index;
}

public:
int max_leaf_size = 20; ///< Maximum number of points in a leaf node.
ProjectionSetting projection_setting; ///< Projection setting.
};
#endif

/// @brief "Unsafe" KdTree.
/// @note This class does not hold the ownership of the input points.
/// You must keep the input points along with this class.
Expand Down
13 changes: 12 additions & 1 deletion src/gtsam_points/ann/kdtree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include <gtsam_points/types/frame_traits.hpp>
#include <gtsam_points/ann/small_kdtree.hpp>
#include <gtsam_points/util/parallelism.hpp>

namespace gtsam_points {

Expand All @@ -22,7 +23,17 @@ KdTree::KdTree(const Eigen::Vector4d* points, int num_points, int build_num_thre
: num_points(num_points),
points(points),
search_eps(-1.0),
index(new Index(*this, KdTreeBuilderOMP(build_num_threads))) {}
index(
is_omp_default() || build_num_threads == 1 ? //
new Index(*this, KdTreeBuilderOMP(build_num_threads)) //
: //
#ifdef GTSAM_POINTS_TBB //
new Index(*this, KdTreeBuilderTBB()) //
#else //
new Index(*this, KdTreeBuilder())
#endif
) {
}

KdTree::~KdTree() {}

Expand Down
28 changes: 25 additions & 3 deletions src/gtsam_points/factors/experimental/intensity_gradients_ivox.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
#include <gtsam_points/factors/intensity_gradients_ivox.hpp>

#include <Eigen/Eigen>
#include <gtsam_points/util/parallelism.hpp>

#ifdef GTSAM_POINTS_TBB
#include <tbb/parallel_for.h>
#endif

namespace gtsam_points {

Expand Down Expand Up @@ -44,9 +49,7 @@ void IntensityGradientsiVox::insert(const PointCloud& frame) {
grads[voxel.second->serial_id] = found->second;
}

// Add new points and estimate their normal and gradients
#pragma omp parallel for num_threads(num_threads) schedule(guided, 8)
for (int i = 0; i < flat_voxels.size(); i++) {
const auto pervoxel_task = [&](int i) {
const auto& voxel = flat_voxels[i];
const auto& gradients = flat_grads[i];
gradients->points.reserve(voxel.second->size());
Expand All @@ -66,6 +69,25 @@ void IntensityGradientsiVox::insert(const PointCloud& frame) {
gradients->normals.push_back(normal);
gradients->points.push_back(gradient);
}
};

if (is_omp_default() || num_threads == 1) {
// Add new points and estimate their normal and gradients
#pragma omp parallel for num_threads(num_threads) schedule(guided, 8)
for (int i = 0; i < flat_voxels.size(); i++) {
pervoxel_task(i);
}
} else {
#ifdef GTSAM_POINTS_TBB
tbb::parallel_for(tbb::blocked_range<int>(0, flat_voxels.size(), 8), [&](const tbb::blocked_range<int>& range) {
for (int i = range.begin(); i < range.end(); i++) {
pervoxel_task(i);
}
});
#else
std::cerr << "error: TBB is not available" << std::endl;
abort();
#endif
}
}

Expand Down
20 changes: 17 additions & 3 deletions src/test/test_kdtree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <gtsam_points/ann/kdtree.hpp>
#include <gtsam_points/ann/kdtree2.hpp>
#include <gtsam_points/types/point_cloud_cpu.hpp>
#include <gtsam_points/util/parallelism.hpp>

class KdTreeTest : public testing::Test, public testing::WithParamInterface<std::string> {
virtual void SetUp() {
Expand Down Expand Up @@ -65,23 +66,36 @@ TEST_F(KdTreeTest, LoadCheck) {
ASSERT_EQ(gt_sq_dists.size(), queries.size());
}

INSTANTIATE_TEST_SUITE_P(gtsam_points, KdTreeTest, testing::Values("KdTree", "KdTreeMT", "KdTree2", "KdTree2MT"), [](const auto& info) {
return info.param;
});
INSTANTIATE_TEST_SUITE_P(
gtsam_points,
KdTreeTest,
testing::Values("KdTree", "KdTreeMT", "KdTreeTBB", "KdTree2", "KdTree2MT", "KdTree2TBB"),
[](const auto& info) { return info.param; });

TEST_P(KdTreeTest, KdTreeTest) {
gtsam_points::NearestNeighborSearch::ConstPtr kdtree;

if (GetParam().find("TBB") != std::string::npos) {
gtsam_points::set_tbb_as_default();
} else {
gtsam_points::set_omp_as_default();
}

if (GetParam() == "KdTree") {
kdtree = std::make_shared<gtsam_points::KdTree>(points.data(), points.size());
} else if (GetParam() == "KdTreeMT") {
kdtree = std::make_shared<gtsam_points::KdTree>(points.data(), points.size(), 2);
} else if (GetParam() == "KdTreeTBB") {
kdtree = std::make_shared<gtsam_points::KdTree>(points.data(), points.size(), 2);
} else if (GetParam() == "KdTree2") {
auto pts = std::make_shared<gtsam_points::PointCloudCPU>(points);
kdtree = std::make_shared<gtsam_points::KdTree2<gtsam_points::PointCloud>>(pts);
} else if (GetParam() == "KdTree2MT") {
auto pts = std::make_shared<gtsam_points::PointCloudCPU>(points);
kdtree = std::make_shared<gtsam_points::KdTree2<gtsam_points::PointCloud>>(pts, 2);
} else if (GetParam() == "KdTree2TBB") {
auto pts = std::make_shared<gtsam_points::PointCloudCPU>(points);
kdtree = std::make_shared<gtsam_points::KdTree2<gtsam_points::PointCloud>>(pts, 2);
} else {
FAIL() << "Unknown KdTree type: " << GetParam();
}
Expand Down

0 comments on commit de843d2

Please sign in to comment.