diff --git a/cpp/open3d/core/kernel/ReductionCPU.cpp b/cpp/open3d/core/kernel/ReductionCPU.cpp index 62cbd482606..710f1aaea24 100644 --- a/cpp/open3d/core/kernel/ReductionCPU.cpp +++ b/cpp/open3d/core/kernel/ReductionCPU.cpp @@ -191,14 +191,25 @@ class CPUArgReductionEngine { // elements. We need to keep track of the indices within each // sub-iteration. int64_t num_output_elements = indexer_.NumOutputElements(); + if (num_output_elements <= 1) { + LaunchArgReductionKernelTwoPass(indexer_, reduce_func, identity); + } else { + LaunchArgReductionParallelDim(indexer_, reduce_func, identity); + } + } + template + static void LaunchArgReductionParallelDim(const Indexer& indexer, + func_t reduce_func, + scalar_t identity) { + int64_t num_output_elements = indexer.NumOutputElements(); #pragma omp parallel for schedule(static) \ num_threads(utility::EstimateMaxThreads()) for (int64_t output_idx = 0; output_idx < num_output_elements; output_idx++) { // sub_indexer.NumWorkloads() == ipo. - // sub_indexer's workload_idx is indexer_'s ipo_idx. - Indexer sub_indexer = indexer_.GetPerOutputIndexer(output_idx); + // sub_indexer's workload_idx is indexer's ipo_idx. + Indexer sub_indexer = indexer.GetPerOutputIndexer(output_idx); scalar_t dst_val = identity; for (int64_t workload_idx = 0; workload_idx < sub_indexer.NumWorkloads(); workload_idx++) { @@ -213,6 +224,52 @@ class CPUArgReductionEngine { } } + /// Create num_threads workers to compute partial arg reductions + /// and then reduce to the final results. + /// This only applies to arg reduction op with one output. + template + static void LaunchArgReductionKernelTwoPass(const Indexer& indexer, + func_t reduce_func, + scalar_t identity) { + if (indexer.NumOutputElements() > 1) { + utility::LogError( + "Internal error: two-pass arg reduction only works for " + "single-output arg reduction ops."); + } + int64_t num_workloads = indexer.NumWorkloads(); + int64_t num_threads = utility::EstimateMaxThreads(); + int64_t workload_per_thread = + (num_workloads + num_threads - 1) / num_threads; + std::vector thread_results_idx(num_threads, 0); + std::vector thread_results_val(num_threads, identity); + +#pragma omp parallel for schedule(static) \ + num_threads(utility::EstimateMaxThreads()) + for (int64_t thread_idx = 0; thread_idx < num_threads; ++thread_idx) { + int64_t start = thread_idx * workload_per_thread; + int64_t end = std::min(start + workload_per_thread, num_workloads); + scalar_t local_result_val = identity; + int64_t local_result_idx = 0; + for (int64_t workload_idx = start; workload_idx < end; + ++workload_idx) { + int64_t src_idx = workload_idx; + scalar_t* src_val = reinterpret_cast( + indexer.GetInputPtr(0, workload_idx)); + std::tie(local_result_idx, local_result_val) = reduce_func( + src_idx, *src_val, local_result_idx, local_result_val); + } + thread_results_val[thread_idx] = local_result_val; + thread_results_idx[thread_idx] = local_result_idx; + } + scalar_t dst_val = identity; + int64_t* dst_idx = reinterpret_cast(indexer.GetOutputPtr(0)); + for (int64_t thread_idx = 0; thread_idx < num_threads; ++thread_idx) { + std::tie(*dst_idx, dst_val) = reduce_func( + thread_results_idx[thread_idx], + thread_results_val[thread_idx], *dst_idx, dst_val); + } + } + private: Indexer indexer_; };