Skip to content

Commit

Permalink
Add two pass CPU ArgReduction Engine to speed up argreductions with s…
Browse files Browse the repository at this point in the history
…ingle outputs
  • Loading branch information
manuelvogel12 committed Sep 26, 2024
1 parent 10b85f7 commit e961d19
Showing 1 changed file with 59 additions and 2 deletions.
61 changes: 59 additions & 2 deletions cpp/open3d/core/kernel/ReductionCPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,14 +191,25 @@ class CPUArgReductionEngine {
// elements. We need to keep track of the indices within each
// sub-iteration.
int64_t num_output_elements = indexer_.NumOutputElements();
if (num_output_elements <= 1) {
LaunchArgReductionKernelTwoPass(indexer_, reduce_func, identity);
} else {
LaunchArgReductionParallelDim(indexer_, reduce_func, identity);
}
}

template <typename scalar_t, typename func_t>
static void LaunchArgReductionParallelDim(const Indexer& indexer,
func_t reduce_func,
scalar_t identity) {
int64_t num_output_elements = indexer.NumOutputElements();
#pragma omp parallel for schedule(static) \
num_threads(utility::EstimateMaxThreads())
for (int64_t output_idx = 0; output_idx < num_output_elements;
output_idx++) {
// sub_indexer.NumWorkloads() == ipo.
// sub_indexer's workload_idx is indexer_'s ipo_idx.
Indexer sub_indexer = indexer_.GetPerOutputIndexer(output_idx);
// sub_indexer's workload_idx is indexer's ipo_idx.
Indexer sub_indexer = indexer.GetPerOutputIndexer(output_idx);
scalar_t dst_val = identity;
for (int64_t workload_idx = 0;
workload_idx < sub_indexer.NumWorkloads(); workload_idx++) {
Expand All @@ -213,6 +224,52 @@ class CPUArgReductionEngine {
}
}

/// Create num_threads workers to compute partial arg reductions
/// and then reduce to the final results.
/// This only applies to arg reduction op with one output.
template <typename scalar_t, typename func_t>
static void LaunchArgReductionKernelTwoPass(const Indexer& indexer,
func_t reduce_func,
scalar_t identity) {
if (indexer.NumOutputElements() > 1) {
utility::LogError(
"Internal error: two-pass arg reduction only works for "
"single-output arg reduction ops.");
}
int64_t num_workloads = indexer.NumWorkloads();
int64_t num_threads = utility::EstimateMaxThreads();
int64_t workload_per_thread =
(num_workloads + num_threads - 1) / num_threads;
std::vector<int64_t> thread_results_idx(num_threads, 0);
std::vector<scalar_t> thread_results_val(num_threads, identity);

#pragma omp parallel for schedule(static) \
num_threads(utility::EstimateMaxThreads())
for (int64_t thread_idx = 0; thread_idx < num_threads; ++thread_idx) {
int64_t start = thread_idx * workload_per_thread;
int64_t end = std::min(start + workload_per_thread, num_workloads);
scalar_t local_result_val = identity;
int64_t local_result_idx = 0;
for (int64_t workload_idx = start; workload_idx < end;
++workload_idx) {
int64_t src_idx = workload_idx;
scalar_t* src_val = reinterpret_cast<scalar_t*>(
indexer.GetInputPtr(0, workload_idx));
std::tie(local_result_idx, local_result_val) = reduce_func(
src_idx, *src_val, local_result_idx, local_result_val);
}
thread_results_val[thread_idx] = local_result_val;
thread_results_idx[thread_idx] = local_result_idx;
}
scalar_t dst_val = identity;
int64_t* dst_idx = reinterpret_cast<int64_t*>(indexer.GetOutputPtr(0));
for (int64_t thread_idx = 0; thread_idx < num_threads; ++thread_idx) {
std::tie(*dst_idx, dst_val) = reduce_func(
thread_results_idx[thread_idx],
thread_results_val[thread_idx], *dst_idx, dst_val);
}
}

private:
Indexer indexer_;
};
Expand Down

0 comments on commit e961d19

Please sign in to comment.