diff --git a/cpp/oneapi/dal/algo/basic_statistics/backend/cpu/compute_kernel_dense.cpp b/cpp/oneapi/dal/algo/basic_statistics/backend/cpu/compute_kernel.cpp similarity index 84% rename from cpp/oneapi/dal/algo/basic_statistics/backend/cpu/compute_kernel_dense.cpp rename to cpp/oneapi/dal/algo/basic_statistics/backend/cpu/compute_kernel.cpp index e0d5bb6dda0..4405a2168d6 100644 --- a/cpp/oneapi/dal/algo/basic_statistics/backend/cpu/compute_kernel_dense.cpp +++ b/cpp/oneapi/dal/algo/basic_statistics/backend/cpu/compute_kernel.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2021 Intel Corporation +* Copyright 2023 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,7 +32,6 @@ namespace oneapi::dal::basic_statistics::backend { using dal::backend::context_cpu; -using method_t = method::dense; using task_t = task::compute; using input_t = compute_input; using result_t = compute_result; @@ -41,19 +40,26 @@ using descriptor_t = detail::descriptor_base; namespace daal_lom = daal::algorithms::low_order_moments; namespace interop = dal::backend::interop; -template -using daal_lom_batch_kernel_t = - daal_lom::internal::LowOrderMomentsBatchKernel; +template +using daal_method_constant = std::integral_constant; + +template +struct to_daal_method; + +template <> +struct to_daal_method : daal_method_constant {}; + +template <> +struct to_daal_method : daal_method_constant {}; + +template +using batch_kernel_t = + daal_lom::internal::LowOrderMomentsBatchKernel::value, Cpu>; template using daal_lom_online_kernel_t = daal_lom::internal::LowOrderMomentsOnlineKernel; -template -constexpr daal_lom::Method get_daal_method() { - return daal_lom::defaultDense; -} - template std::int64_t propose_block_size(std::int64_t row_count, std::int64_t col_count) { using idx_t = std::int64_t; @@ -174,10 +180,12 @@ result_t call_daal_kernel_with_weights(const context_cpu& ctx, return result; } -template +template result_t call_daal_kernel_without_weights(const context_cpu& ctx, const descriptor_t& desc, const table& data) { + auto daal_method = + std::is_same_v ? 
daal_lom::defaultDense : daal_lom::fastCSR; const auto daal_data = interop::convert_to_daal_table(data); auto daal_parameter = daal_lom::Parameter(get_daal_estimates_to_compute(desc)); @@ -187,13 +195,14 @@ result_t call_daal_kernel_without_weights(const context_cpu& ctx, daal_input.set(daal_lom::InputId::data, daal_data); interop::status_to_exception( - daal_result.allocate(&daal_input, &daal_parameter, get_daal_method())); + daal_result.allocate(&daal_input, &daal_parameter, daal_method)); - interop::status_to_exception( - interop::call_daal_kernel(ctx, - daal_data.get(), - &daal_result, - &daal_parameter)); + interop::status_to_exception(dal::backend::dispatch_by_cpu(ctx, [&](auto cpu) { + return batch_kernel_t::value, + Method>() + .compute(daal_data.get(), &daal_result, &daal_parameter); + })); auto result = get_result(desc, daal_result).set_result_options(desc.get_result_options()); @@ -201,7 +210,7 @@ result_t call_daal_kernel_without_weights(const context_cpu& ctx, return result; } -template +template static result_t compute(const context_cpu& ctx, const descriptor_t& desc, const input_t& input) { if (input.get_weights().has_data()) { return call_daal_kernel_with_weights(ctx, @@ -210,20 +219,22 @@ static result_t compute(const context_cpu& ctx, const descriptor_t& desc, const input.get_weights()); } else { - return call_daal_kernel_without_weights(ctx, desc, input.get_data()); + return call_daal_kernel_without_weights(ctx, desc, input.get_data()); } } -template -struct compute_kernel_cpu { +template +struct compute_kernel_cpu { result_t operator()(const context_cpu& ctx, const descriptor_t& desc, const input_t& input) const { - return compute(ctx, desc, input); + return compute(ctx, desc, input); } }; -template struct compute_kernel_cpu; -template struct compute_kernel_cpu; +template struct compute_kernel_cpu; +template struct compute_kernel_cpu; +template struct compute_kernel_cpu; +template struct compute_kernel_cpu; } // namespace oneapi::dal::basic_statistics::backend diff --git a/cpp/oneapi/dal/algo/basic_statistics/backend/cpu/compute_kernel.hpp b/cpp/oneapi/dal/algo/basic_statistics/backend/cpu/compute_kernel.hpp index dc228b7e6af..5b2d7bc3267 100644 --- a/cpp/oneapi/dal/algo/basic_statistics/backend/cpu/compute_kernel.hpp +++ b/cpp/oneapi/dal/algo/basic_statistics/backend/cpu/compute_kernel.hpp @@ -17,15 +17,17 @@ #pragma once #include "oneapi/dal/algo/basic_statistics/compute_types.hpp" +#include "oneapi/dal/table/csr.hpp" #include "oneapi/dal/backend/dispatcher.hpp" namespace oneapi::dal::basic_statistics::backend { template struct compute_kernel_cpu { + using input_t = compute_input; compute_result operator()(const dal::backend::context_cpu& ctx, const detail::descriptor_base& params, - const compute_input& input) const; + const input_t& input) const; }; } // namespace oneapi::dal::basic_statistics::backend diff --git a/cpp/oneapi/dal/algo/basic_statistics/backend/gpu/compute_kernel.hpp b/cpp/oneapi/dal/algo/basic_statistics/backend/gpu/compute_kernel.hpp index 8dffef5a44f..b00b64282e9 100644 --- a/cpp/oneapi/dal/algo/basic_statistics/backend/gpu/compute_kernel.hpp +++ b/cpp/oneapi/dal/algo/basic_statistics/backend/gpu/compute_kernel.hpp @@ -17,15 +17,17 @@ #pragma once #include "oneapi/dal/algo/basic_statistics/compute_types.hpp" +#include "oneapi/dal/table/csr.hpp" #include "oneapi/dal/backend/dispatcher.hpp" namespace oneapi::dal::basic_statistics::backend { template struct compute_kernel_gpu { + using input_t = compute_input; compute_result operator()(const 
dal::backend::context_gpu& ctx, const detail::descriptor_base& params, - const compute_input& input) const; + const input_t& input) const; }; } // namespace oneapi::dal::basic_statistics::backend diff --git a/cpp/oneapi/dal/algo/basic_statistics/backend/gpu/compute_kernel_csr_impl.hpp b/cpp/oneapi/dal/algo/basic_statistics/backend/gpu/compute_kernel_csr_impl.hpp new file mode 100644 index 00000000000..8bbf1d25647 --- /dev/null +++ b/cpp/oneapi/dal/algo/basic_statistics/backend/gpu/compute_kernel_csr_impl.hpp @@ -0,0 +1,163 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#pragma once + +#include "oneapi/dal/algo/basic_statistics/backend/gpu/compute_kernel.hpp" +#include "oneapi/dal/backend/primitives/utils.hpp" +#include "oneapi/dal/table/csr.hpp" +#include "oneapi/dal/util/common.hpp" +#include "oneapi/dal/detail/policy.hpp" +#include "oneapi/dal/backend/communicator.hpp" + +#ifdef ONEDAL_DATA_PARALLEL + +namespace oneapi::dal::basic_statistics::backend { + +namespace de = dal::detail; +namespace bk = dal::backend; +namespace pr = dal::backend::primitives; + +enum stat { min, max, sum, sum2, sum2_cent, mean, moment2, variance, stddev, variation }; + +template +class compute_kernel_csr_impl { + using method_t = method::sparse; + using task_t = task::compute; + using comm_t = bk::communicator; + using input_t = compute_input; + using result_t = compute_result; + using descriptor_t = detail::descriptor_base; + +public: + result_t operator()(const bk::context_gpu& ctx, const descriptor_t& desc, const input_t& input); + +private: + // Number of different basic statistics + static constexpr std::int32_t res_opt_count_ = 10; + // An array of basic statistics + const result_option_id res_options_[res_opt_count_] = { result_options::min, + result_options::max, + result_options::sum, + result_options::sum_squares, + result_options::sum_squares_centered, + result_options::mean, + result_options::second_order_raw_moment, + result_options::variance, + result_options::standard_deviation, + result_options::variation }; + + result_t get_result(sycl::queue q, + const pr::ndarray computed_result, + result_option_id requested_results, + const std::vector& deps = {}) { + result_t res; + std::vector res_events; + res.set_result_options(requested_results); + if (requested_results.test(result_options::min)) { + auto index = get_result_option_index(result_options::min); + auto [res_table, event] = get_result_table(q, computed_result, index, deps); + res.set_min(res_table); + res_events.push_back(event); + } + if (requested_results.test(result_options::max)) { + auto index = get_result_option_index(result_options::max); + auto [res_table, event] = get_result_table(q, computed_result, index, deps); + res.set_max(res_table); + res_events.push_back(event); + } + if (requested_results.test(result_options::sum)) { + auto 
index = get_result_option_index(result_options::sum); + auto [res_table, event] = get_result_table(q, computed_result, index, deps); + res.set_sum(res_table); + res_events.push_back(event); + } + if (requested_results.test(result_options::sum_squares)) { + auto index = get_result_option_index(result_options::sum_squares); + auto [res_table, event] = get_result_table(q, computed_result, index, deps); + res.set_sum_squares(res_table); + res_events.push_back(event); + } + if (requested_results.test(result_options::sum_squares_centered)) { + auto index = get_result_option_index(result_options::sum_squares_centered); + auto [res_table, event] = get_result_table(q, computed_result, index, deps); + res.set_sum_squares_centered(res_table); + res_events.push_back(event); + } + if (requested_results.test(result_options::mean)) { + auto index = get_result_option_index(result_options::mean); + auto [res_table, event] = get_result_table(q, computed_result, index, deps); + res.set_mean(res_table); + res_events.push_back(event); + } + if (requested_results.test(result_options::second_order_raw_moment)) { + auto index = get_result_option_index(result_options::second_order_raw_moment); + auto [res_table, event] = get_result_table(q, computed_result, index, deps); + res.set_second_order_raw_moment(res_table); + res_events.push_back(event); + } + if (requested_results.test(result_options::variance)) { + auto index = get_result_option_index(result_options::variance); + auto [res_table, event] = get_result_table(q, computed_result, index, deps); + res.set_variance(res_table); + res_events.push_back(event); + } + if (requested_results.test(result_options::standard_deviation)) { + auto index = get_result_option_index(result_options::standard_deviation); + auto [res_table, event] = get_result_table(q, computed_result, index, deps); + res.set_standard_deviation(res_table); + res_events.push_back(event); + } + if (requested_results.test(result_options::variation)) { + auto index = get_result_option_index(result_options::variation); + auto [res_table, event] = get_result_table(q, computed_result, index, deps); + res.set_variation(res_table); + res_events.push_back(event); + } + sycl::event::wait_and_throw(res_events); + return res; + } + + std::tuple get_result_table(sycl::queue q, + const pr::ndarray computed_result, + std::int32_t index, + const std::vector& deps = {}) { + ONEDAL_ASSERT(computed_result.has_data()); + auto column_count = computed_result.get_dimension(1); + const auto arr = dal::array::empty(column_count); + const auto res_arr_ptr = arr.get_mutable_data(); + const auto computed_res_ptr = computed_result.get_data() + index * column_count; + auto event = + dal::backend::copy_usm2host(q, res_arr_ptr, computed_res_ptr, column_count, deps); + return std::make_tuple(homogen_table::wrap(arr, 1, column_count), event); + } + + std::int32_t get_result_option_index(result_option_id opt) { + std::int32_t index = 0; + while (!opt.test(res_options_[index])) + ++index; + return index; + } + + sycl::event finalize_for_distr(sycl::queue& q, + comm_t& communicator, + pr::ndarray& results, + const input_t& input, + const std::vector& deps); +}; + +} // namespace oneapi::dal::basic_statistics::backend +#endif // ONEDAL_DATA_PARALLEL diff --git a/cpp/oneapi/dal/algo/basic_statistics/backend/gpu/compute_kernel_csr_impl_dpc.cpp b/cpp/oneapi/dal/algo/basic_statistics/backend/gpu/compute_kernel_csr_impl_dpc.cpp new file mode 100644 index 00000000000..3367947d26f --- /dev/null +++ 
b/cpp/oneapi/dal/algo/basic_statistics/backend/gpu/compute_kernel_csr_impl_dpc.cpp @@ -0,0 +1,383 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "oneapi/dal/algo/basic_statistics/backend/gpu/compute_kernel.hpp" +#include "oneapi/dal/algo/basic_statistics/backend/gpu/compute_kernel_csr_impl.hpp" +#include "oneapi/dal/table/csr_accessor.hpp" +#include "oneapi/dal/backend/primitives/utils.hpp" +#include "oneapi/dal/util/common.hpp" +#include "oneapi/dal/detail/policy.hpp" +#include "oneapi/dal/backend/communicator.hpp" + +#ifdef ONEDAL_DATA_PARALLEL + +namespace oneapi::dal::basic_statistics::backend { + +namespace de = dal::detail; +namespace bk = dal::backend; +namespace pr = dal::backend::primitives; + +using method_t = method::sparse; +using task_t = task::compute; +using comm_t = bk::communicator; +using input_t = compute_input; +using result_t = compute_result; +using descriptor_t = detail::descriptor_base; + +template +sycl::event compute_kernel_csr_impl::finalize_for_distr( + sycl::queue& q, + comm_t& communicator, + pr::ndarray& results, + const input_t& input, + const std::vector& deps) { + auto result_ptr = results.get_mutable_data(); + const csr_table csr_tdata = static_cast(input.get_data()); + auto [csr_data, column_indices, row_offsets] = + csr_accessor(csr_tdata).pull(q, { 0, -1 }, sparse_indexing::zero_based); + const auto column_count = csr_tdata.get_column_count(); + const auto row_count = csr_tdata.get_row_count(); + const auto nonzero_count = csr_tdata.get_non_zero_count(); + + auto host_results = results.flatten(q, deps); + communicator + .allreduce(host_results.get_slice(stat::min * column_count, column_count), + spmd::reduce_op::min) + .wait(); + communicator + .allreduce(host_results.get_slice(stat::max * column_count, column_count), + spmd::reduce_op::max) + .wait(); + communicator + .allreduce(host_results.get_slice(stat::sum * column_count, column_count), + spmd::reduce_op::sum) + .wait(); + communicator + .allreduce(host_results.get_slice(stat::sum2 * column_count, column_count), + spmd::reduce_op::sum) + .wait(); + communicator + .allreduce(host_results.get_slice(stat::moment2 * column_count, column_count), + spmd::reduce_op::sum) + .wait(); + + results.assign_from_host(q, host_results.get_data(), res_opt_count_ * column_count) + .wait_and_throw(); + + auto csr_data_ptr = csr_data.get_data(); + auto column_indices_ptr = column_indices.get_data(); + auto distr_range = sycl::range<1>(column_count); + auto calc_s2c_event = q.submit([&](sycl::handler& cgh) { + cgh.depends_on(deps); + cgh.parallel_for(distr_range, [=](std::int64_t col_idx) { + auto mean_val = result_ptr[stat::sum * column_count + col_idx] / row_count; + result_ptr[stat::mean * column_count + col_idx] = mean_val; + Float sum2_cent = Float(0); + std::int32_t nnz_row_count = 0; + for 
(std::int32_t data_idx = 0; data_idx < nonzero_count; ++data_idx) { + if (col_idx == column_indices_ptr[data_idx]) { + auto val = csr_data_ptr[data_idx]; + sum2_cent += (val - mean_val) * (val - mean_val); + nnz_row_count += 1; + } + } + // For zero values, sum2_cent is just the square of the mean value + sum2_cent += (row_count - nnz_row_count) * mean_val * mean_val; + + result_ptr[stat::sum2_cent * column_count + col_idx] = sum2_cent; + }); + }); + + host_results = results.flatten(q, { calc_s2c_event }); + communicator + .allreduce(host_results.get_slice(stat::sum2_cent * column_count, column_count), + spmd::reduce_op::sum) + .wait(); + auto allreduce_event = + results.assign_from_host(q, host_results.get_data(), res_opt_count_ * column_count); + + auto final_event = q.submit([&](sycl::handler& cgh) { + cgh.depends_on({ allreduce_event }); + cgh.parallel_for(distr_range, [=](std::int64_t col_idx) { + auto mean_val = result_ptr[stat::mean * column_count + col_idx]; + result_ptr[stat::variance * column_count + col_idx] = + result_ptr[stat::sum2_cent * column_count + col_idx] / (row_count - 1); + result_ptr[stat::stddev * column_count + col_idx] = + sycl::sqrt(result_ptr[stat::variance * column_count + col_idx]); + result_ptr[stat::variation * column_count + col_idx] = + result_ptr[stat::stddev * column_count + col_idx] / mean_val; + }); + }); + + return final_event; +} + +template +result_t compute_kernel_csr_impl::operator()(const bk::context_gpu& ctx, + const descriptor_t& desc, + const input_t& input) { + auto queue = ctx.get_queue(); + const auto table = input.get_data(); + ONEDAL_ASSERT(table.get_kind() == csr_table::kind()); + const csr_table csr_tdata = static_cast(table); + comm_t comm = ctx.get_communicator(); + const bool distr_mode = comm.get_rank_count() > 1; + const auto column_count = csr_tdata.get_column_count(); + const auto row_count = csr_tdata.get_row_count(); + auto result_options = desc.get_result_options(); + const auto nonzero_count = csr_tdata.get_non_zero_count(); + auto [csr_data, column_indices, row_offsets] = + csr_accessor(csr_tdata).pull(queue, + { 0, -1 }, + sparse_indexing::zero_based, + sycl::usm::alloc::device); + auto csr_data_ptr = csr_data.get_data(); + auto column_indices_ptr = column_indices.get_data(); + + using limits_t = std::numeric_limits; + constexpr Float maximum = limits_t::max(); + + // number of columns processed in one work-group + const auto local_size = bk::device_max_wg_size(queue); + // number of data elements processed by one work item + constexpr std::int64_t n_items_per_work = 512; + const auto num_data_blocks = nonzero_count / (local_size * n_items_per_work) + + bool(nonzero_count % (local_size * n_items_per_work)); + const auto num_col_blocks = column_count / local_size + bool(column_count % local_size); + + auto result_data = + pr::ndarray::empty(queue, + { num_data_blocks * res_opt_count_, column_count }, + sycl::usm::alloc::device); + auto result_data_ptr = result_data.get_mutable_data(); + + const auto nd_range = + bk::make_multiple_nd_range_3d({ num_data_blocks, num_col_blocks, local_size }, + { 1, 1, local_size }); + const auto merge_range = bk::make_multiple_nd_range_1d(column_count, 1); + // First order kernel calculates basic statistics: min, max, sum, and sum squares. + // The computation is split into blocks over columns and data to achieve the best performance + // on GPU.
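For reference, a rough serial analogue of the per-column accumulation that this first order kernel performs (an illustrative helper, not part of the patch; implicit zeros are folded into min/max and the centered sums later, exactly as the merge kernels below do):

```cpp
#include <algorithm>
#include <cstdint>
#include <limits>
#include <vector>

// One accumulator set per column, updated from the stored CSR values only.
struct col_acc {
    float min, max, sum, sum2;
};

inline std::vector<col_acc> accumulate_first_order(const float* csr_data,
                                                   const std::int64_t* column_indices,
                                                   std::int64_t nonzero_count,
                                                   std::int64_t column_count) {
    const float maximum = std::numeric_limits<float>::max();
    std::vector<col_acc> acc(column_count, col_acc{ maximum, -maximum, 0.0f, 0.0f });
    for (std::int64_t i = 0; i < nonzero_count; ++i) {
        col_acc& a = acc[column_indices[i]]; // zero-based column index of this value
        const float val = csr_data[i];
        a.min = std::min(a.min, val);
        a.max = std::max(a.max, val);
        a.sum += val;
        a.sum2 += val * val;
    }
    return acc;
}
```

The SYCL version distributes the same update over `num_data_blocks x num_col_blocks` work-groups and resolves concurrent per-column updates with local-memory atomics, which is why a separate merge kernel over data blocks follows.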
+ auto first_order_event = queue.submit([&](sycl::handler& cgh) { + sycl::local_accessor local_res_buf(local_size * res_opt_count_, cgh); + cgh.parallel_for(nd_range, [=](auto item) { + std::int64_t block_id = item.get_global_id(0); + std::int64_t col_ofs = item.get_global_id(1) * local_size; + std::int64_t local_id = item.get_local_id(2); + std::int64_t data_ofs = block_id * local_size * n_items_per_work; + if (col_ofs >= column_count) { + return; + } + Float* work_group_buf = + local_res_buf.template get_multi_ptr().get_raw(); + auto local_buf = work_group_buf + local_id * res_opt_count_; + local_buf[stat::min] = maximum; + local_buf[stat::max] = -maximum; + local_buf[stat::sum] = Float(0); + local_buf[stat::sum2] = Float(0); + local_buf[stat::sum2_cent] = Float(0); + item.barrier(sycl::access::fence_space::local_space); + for (std::int64_t idx = 0; idx < n_items_per_work; ++idx) { + auto data_idx = data_ofs + local_id * n_items_per_work + idx; + if (data_idx >= nonzero_count) { + break; + } + auto col_idx = column_indices_ptr[data_idx] - col_ofs; + auto val = csr_data_ptr[data_idx]; + if (col_idx >= 0 && col_idx < local_size) { + sycl::atomic_ref + col_min(local_res_buf[col_idx * res_opt_count_ + stat::min]); + col_min.fetch_min(val); + sycl::atomic_ref + col_max(local_res_buf[col_idx * res_opt_count_ + stat::max]); + col_max.fetch_max(val); + sycl::atomic_ref + col_sum(local_res_buf[col_idx * res_opt_count_ + stat::sum]); + col_sum.fetch_add(val); + sycl::atomic_ref + col_sum2(local_res_buf[col_idx * res_opt_count_ + stat::sum2]); + col_sum2.fetch_add(val * val); + } + } + item.barrier(sycl::access::fence_space::local_space); + if ((local_id + col_ofs) >= column_count) { + return; + } + const auto col_idx = col_ofs + local_id; + const auto block_idx = block_id * res_opt_count_ * column_count; + result_data_ptr[stat::min * column_count + block_idx + col_idx] = local_buf[stat::min]; + result_data_ptr[stat::max * column_count + block_idx + col_idx] = local_buf[stat::max]; + result_data_ptr[stat::sum * column_count + block_idx + col_idx] = local_buf[stat::sum]; + result_data_ptr[stat::sum2 * column_count + block_idx + col_idx] = + local_buf[stat::sum2]; + }); + }); + + // First order merge kernel merges results for data blocks computed on previous step. 
+ auto merge_event = queue.submit([&](sycl::handler& cgh) { + cgh.depends_on({ first_order_event }); + cgh.parallel_for(merge_range, [=](auto item) { + const auto col_idx = item.get_global_id(); + auto cur_min = result_data_ptr[stat::min * column_count + col_idx]; + auto cur_max = result_data_ptr[stat::max * column_count + col_idx]; + auto cur_sum = result_data_ptr[stat::sum * column_count + col_idx]; + auto cur_sum2 = result_data_ptr[stat::sum2 * column_count + col_idx]; + for (std::int64_t block_id = 1; block_id < num_data_blocks; ++block_id) { + const auto block_idx = block_id * res_opt_count_ * column_count; + cur_min = + sycl::min(cur_min, + result_data_ptr[stat::min * column_count + block_idx + col_idx]); + cur_max = + sycl::max(cur_max, + result_data_ptr[stat::max * column_count + block_idx + col_idx]); + cur_sum += result_data_ptr[stat::sum * column_count + block_idx + col_idx]; + cur_sum2 += result_data_ptr[stat::sum2 * column_count + block_idx + col_idx]; + } + result_data_ptr[stat::min * column_count + col_idx] = cur_min; + result_data_ptr[stat::max * column_count + col_idx] = cur_max; + result_data_ptr[stat::sum * column_count + col_idx] = cur_sum; + result_data_ptr[stat::sum2 * column_count + col_idx] = cur_sum2; + }); + }); + + // Second order kernel computes sum squares centered. + // It additionally computes the number of non-zero rows per column + // in order to properly finalize the min, max, and sum squares centered statistics, + // since zero values are involved in their results. + auto second_order_event = queue.submit([&](sycl::handler& cgh) { + cgh.depends_on({ merge_event }); + sycl::local_accessor local_res_buf(local_size, cgh); + sycl::local_accessor mean_vals_buf(local_size, cgh); + sycl::local_accessor row_counter(local_size, cgh); + cgh.parallel_for(nd_range, [=](auto item) { + std::int64_t block_id = item.get_global_id(0); + std::int64_t col_ofs = item.get_global_id(1) * local_size; + std::int64_t local_id = item.get_local_id(2); + std::int64_t data_ofs = block_id * local_size * n_items_per_work; + if (col_ofs >= column_count) { + return; + } + Float* work_group_buf = + local_res_buf.template get_multi_ptr().get_raw(); + Float* mean_vals = + mean_vals_buf.template get_multi_ptr().get_raw(); + std::int32_t* row_counter_buf = + row_counter.template get_multi_ptr().get_raw(); + mean_vals[local_id] = + result_data_ptr[stat::sum * column_count + col_ofs + local_id] / row_count; + row_counter_buf[local_id] = 0; + work_group_buf[local_id] = Float(0); + item.barrier(sycl::access::fence_space::local_space); + // Accumulate the centered sum of squares for this data block + for (std::int64_t idx = 0; idx < n_items_per_work; ++idx) { + auto data_idx = data_ofs + local_id * n_items_per_work + idx; + if (data_idx >= nonzero_count) { + break; + } + auto col_idx = column_indices_ptr[data_idx] - col_ofs; + auto val = csr_data_ptr[data_idx]; + if (col_idx >= 0 && col_idx < local_size) { + sycl::atomic_ref + col_sum2_cent(local_res_buf[col_idx]); + auto mean_val = mean_vals[col_idx]; + col_sum2_cent.fetch_add((val - mean_val) * (val - mean_val)); + sycl::atomic_ref + row_counter_at(row_counter[col_idx]); + row_counter_at.fetch_add(1); + } + } + item.barrier(sycl::access::fence_space::local_space); + if ((local_id + col_ofs) >= column_count) { + return; + } + const auto col_idx = col_ofs + local_id; + const auto block_idx = block_id * res_opt_count_ * column_count; + result_data_ptr[stat::sum2_cent * column_count + block_idx + col_idx] = + work_group_buf[local_id]; + // Mean does not need to be merged; it is the same for
all blocks + result_data_ptr[stat::mean * column_count + block_idx + col_idx] = mean_vals[local_id]; + // Temporarily save row_counts into the variation placeholder in order to merge them + result_data_ptr[stat::variation * column_count + block_idx + col_idx] = + row_counter_buf[local_id]; + }); + }); + + // Second order merge kernel finalizes the basic statistics computations and + // merges the sum squares centered statistic across data blocks. + auto second_merge_event = queue.submit([&](sycl::handler& cgh) { + cgh.depends_on({ second_order_event }); + cgh.parallel_for(merge_range, [=](auto item) { + const auto col_idx = item.get_global_id(); + auto cur_sum2_cent = result_data_ptr[stat::sum2_cent * column_count + col_idx]; + auto mean_val = result_data_ptr[stat::mean * column_count + col_idx]; + auto cur_row_count = result_data_ptr[stat::variation * column_count + col_idx]; + for (std::int64_t block_id = 1; block_id < num_data_blocks; ++block_id) { + const auto block_idx = block_id * res_opt_count_ * column_count; + cur_sum2_cent += + result_data_ptr[stat::sum2_cent * column_count + block_idx + col_idx]; + cur_row_count += + result_data_ptr[stat::variation * column_count + block_idx + col_idx]; + } + // If the column contains implicit zeros, they must be compared with min and max + // and added to sum2_cent + if (row_count != cur_row_count) { + auto cur_min = result_data_ptr[stat::min * column_count + col_idx]; + auto cur_max = result_data_ptr[stat::max * column_count + col_idx]; + result_data_ptr[stat::min * column_count + col_idx] = sycl::min(cur_min, 0); + result_data_ptr[stat::max * column_count + col_idx] = sycl::max(cur_max, 0); + cur_sum2_cent += Float(row_count - cur_row_count) * mean_val * mean_val; + } + result_data_ptr[stat::sum2_cent * column_count + col_idx] = cur_sum2_cent; + result_data_ptr[stat::moment2 * column_count + col_idx] = + result_data_ptr[stat::sum2 * column_count + col_idx] / row_count; + result_data_ptr[stat::variance * column_count + col_idx] = + cur_sum2_cent / (row_count - 1); + result_data_ptr[stat::stddev * column_count + col_idx] = + sycl::sqrt(result_data_ptr[stat::variance * column_count + col_idx]); + result_data_ptr[stat::variation * column_count + col_idx] = + result_data_ptr[stat::stddev * column_count + col_idx] / mean_val; + }); + }); + + if (distr_mode) { + second_merge_event = finalize_for_distr(queue, comm, result_data, input, { merge_event }); + } + return get_result(queue, result_data, result_options, { second_merge_event }); +} + +template class compute_kernel_csr_impl; +template class compute_kernel_csr_impl; + +} // namespace oneapi::dal::basic_statistics::backend +#endif // ONEDAL_DATA_PARALLEL diff --git a/cpp/oneapi/dal/algo/basic_statistics/backend/gpu/compute_kernel_sparse_dpc.cpp b/cpp/oneapi/dal/algo/basic_statistics/backend/gpu/compute_kernel_sparse_dpc.cpp new file mode 100644 index 00000000000..84d7d169992 --- /dev/null +++ b/cpp/oneapi/dal/algo/basic_statistics/backend/gpu/compute_kernel_sparse_dpc.cpp @@ -0,0 +1,43 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "oneapi/dal/algo/basic_statistics/backend/gpu/compute_kernel.hpp" +#include "oneapi/dal/algo/basic_statistics/backend/gpu/compute_kernel_csr_impl.hpp" +#include "oneapi/dal/backend/primitives/utils.hpp" +#include "oneapi/dal/detail/policy.hpp" +#include "oneapi/dal/detail/common.hpp" + +namespace oneapi::dal::basic_statistics::backend { + +using method_t = method::sparse; +using task_t = task::compute; +using input_t = compute_input; +using result_t = compute_result; +using descriptor_t = detail::descriptor_base; + +template +struct compute_kernel_gpu { + result_t operator()(const bk::context_gpu& ctx, + const descriptor_t& desc, + const input_t& input) const { + return compute_kernel_csr_impl{}(ctx, desc, input); + } +}; + +template struct compute_kernel_gpu; +template struct compute_kernel_gpu; + +} // namespace oneapi::dal::basic_statistics::backend diff --git a/cpp/oneapi/dal/algo/basic_statistics/common.hpp b/cpp/oneapi/dal/algo/basic_statistics/common.hpp index 9542c278ef3..a0fb211e8a4 100644 --- a/cpp/oneapi/dal/algo/basic_statistics/common.hpp +++ b/cpp/oneapi/dal/algo/basic_statistics/common.hpp @@ -43,12 +43,16 @@ namespace v1 { /// Tag-type that denotes dense computational method. struct dense {}; +/// Tag-type that denotes sparse computational method. +struct sparse {}; + /// Alias tag-type for dense computational method. 
using by_default = dense; } // namespace v1 using v1::dense; +using v1::sparse; using v1::by_default; } // namespace method @@ -116,7 +120,7 @@ template constexpr bool is_valid_float_v = dal::detail::is_one_of_v; template -constexpr bool is_valid_method_v = dal::detail::is_one_of_v; +constexpr bool is_valid_method_v = dal::detail::is_one_of_v; template constexpr bool is_valid_task_v = dal::detail::is_one_of_v; diff --git a/cpp/oneapi/dal/algo/basic_statistics/compute_types.cpp b/cpp/oneapi/dal/algo/basic_statistics/compute_types.cpp index a8d89d55031..ef40d48f239 100644 --- a/cpp/oneapi/dal/algo/basic_statistics/compute_types.cpp +++ b/cpp/oneapi/dal/algo/basic_statistics/compute_types.cpp @@ -15,6 +15,7 @@ *******************************************************************************/ #include "oneapi/dal/algo/basic_statistics/compute_types.hpp" +#include "oneapi/dal/table/csr.hpp" #include "oneapi/dal/detail/common.hpp" namespace oneapi::dal::basic_statistics { diff --git a/cpp/oneapi/dal/algo/basic_statistics/detail/compute_ops.cpp b/cpp/oneapi/dal/algo/basic_statistics/detail/compute_ops.cpp index 08874b3c9fb..155ce01280a 100644 --- a/cpp/oneapi/dal/algo/basic_statistics/detail/compute_ops.cpp +++ b/cpp/oneapi/dal/algo/basic_statistics/detail/compute_ops.cpp @@ -23,9 +23,10 @@ namespace v1 { template struct compute_ops_dispatcher { + using input_t = compute_input; compute_result operator()(const Policy& policy, const descriptor_base& desc, - const compute_input& input) const { + const input_t& input) const { using kernel_dispatcher_t = dal::backend::kernel_dispatcher< // KERNEL_SINGLE_NODE_CPU(backend::compute_kernel_cpu)>; return kernel_dispatcher_t()(policy, desc, input); @@ -38,6 +39,8 @@ struct compute_ops_dispatcher { INSTANTIATE(float, method::dense, task::compute) INSTANTIATE(double, method::dense, task::compute) +INSTANTIATE(float, method::sparse, task::compute) +INSTANTIATE(double, method::sparse, task::compute) } // namespace v1 } // namespace oneapi::dal::basic_statistics::detail diff --git a/cpp/oneapi/dal/algo/basic_statistics/detail/compute_ops.hpp b/cpp/oneapi/dal/algo/basic_statistics/detail/compute_ops.hpp index 6b69d2f6b56..c673e1e04c0 100644 --- a/cpp/oneapi/dal/algo/basic_statistics/detail/compute_ops.hpp +++ b/cpp/oneapi/dal/algo/basic_statistics/detail/compute_ops.hpp @@ -17,6 +17,7 @@ #pragma once #include "oneapi/dal/algo/basic_statistics/compute_types.hpp" +#include "oneapi/dal/table/csr.hpp" #include "oneapi/dal/detail/error_messages.hpp" namespace oneapi::dal::basic_statistics::detail { @@ -24,9 +25,10 @@ namespace v1 { template struct compute_ops_dispatcher { + using input_t = compute_input; compute_result operator()(const Context&, const descriptor_base&, - const compute_input&) const; + const input_t&) const; }; template diff --git a/cpp/oneapi/dal/algo/basic_statistics/detail/compute_ops_dpc.cpp b/cpp/oneapi/dal/algo/basic_statistics/detail/compute_ops_dpc.cpp index 16bf295b504..883ee8db3fd 100644 --- a/cpp/oneapi/dal/algo/basic_statistics/detail/compute_ops_dpc.cpp +++ b/cpp/oneapi/dal/algo/basic_statistics/detail/compute_ops_dpc.cpp @@ -24,9 +24,10 @@ namespace v1 { template struct compute_ops_dispatcher { + using input_t = compute_input; compute_result operator()(const Policy& policy, const descriptor_base& params, - const compute_input& input) const { + const input_t& input) const { using kernel_dispatcher_t = dal::backend::kernel_dispatcher< KERNEL_SINGLE_NODE_CPU(backend::compute_kernel_cpu), KERNEL_UNIVERSAL_SPMD_GPU(backend::compute_kernel_gpu)>; 
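With method::sparse wired into the dispatchers, the new code path becomes reachable from the public API. A sketch of what a call could look like, mirroring the test fixture and the csr_table::wrap overload used elsewhere in this patch (the umbrella header name is assumed, not verified here):

```cpp
#include <algorithm>
#include <cstdint>

#include "oneapi/dal/algo/basic_statistics.hpp"
#include "oneapi/dal/table/csr.hpp"

namespace dal = oneapi::dal;
namespace bs = dal::basic_statistics;

// Build a small one-based CSR table:
// [ 1 0 2 ]
// [ 0 3 0 ]
inline dal::csr_table make_toy_csr() {
    auto data = dal::array<float>::empty(3);
    auto column_indices = dal::array<std::int64_t>::empty(3);
    auto row_offsets = dal::array<std::int64_t>::empty(3);
    const float values[] = { 1.0f, 2.0f, 3.0f };
    const std::int64_t cols[] = { 1, 3, 2 };
    const std::int64_t offs[] = { 1, 3, 4 };
    std::copy(values, values + 3, data.get_mutable_data());
    std::copy(cols, cols + 3, column_indices.get_mutable_data());
    std::copy(offs, offs + 3, row_offsets.get_mutable_data());
    return dal::csr_table::wrap(data,
                                column_indices,
                                row_offsets,
                                3, // column_count
                                dal::sparse_indexing::one_based);
}

inline bs::compute_result<> compute_sparse_stats(const dal::csr_table& x) {
    const auto desc = bs::descriptor<float, bs::method::sparse>{}.set_result_options(
        bs::result_options::min | bs::result_options::max | bs::result_options::mean |
        bs::result_options::variance);
    return dal::compute(desc, x);
}
```

On a host policy this request is dispatched to the DAAL fastCSR batch kernel via to_daal_method; with a SYCL queue it reaches compute_kernel_csr_impl above.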
@@ -42,6 +43,8 @@ struct compute_ops_dispatcher { INSTANTIATE(float, method::dense, task::compute) INSTANTIATE(double, method::dense, task::compute) +INSTANTIATE(float, method::sparse, task::compute) +INSTANTIATE(double, method::sparse, task::compute) } // namespace v1 } // namespace oneapi::dal::basic_statistics::detail diff --git a/cpp/oneapi/dal/algo/basic_statistics/test/batch.cpp b/cpp/oneapi/dal/algo/basic_statistics/test/batch.cpp index 021a84668a7..273b408ce8b 100644 --- a/cpp/oneapi/dal/algo/basic_statistics/test/batch.cpp +++ b/cpp/oneapi/dal/algo/basic_statistics/test/batch.cpp @@ -63,4 +63,24 @@ TEMPLATE_LIST_TEST_M(basic_statistics_batch_test, this->general_checks(data, weights, compute_mode); } +TEMPLATE_LIST_TEST_M(basic_statistics_batch_test, + "basic_statistics common CSR flow", + "[basic_statistics][integration][batch]", + basic_statistics_sparse_types) { + SKIP_IF(this->not_float64_friendly()); + const auto data = GENERATE_COPY(te::csr_table_builder(5, 5), + te::csr_table_builder(7, 10), + te::csr_table_builder(100, 100), + te::csr_table_builder(1000, 1000), + te::csr_table_builder(150000, 1000)); + SKIP_IF(this->not_cpu_friendly(data)); + const bs::result_option_id res_min_max = result_options::min | result_options::max; + const bs::result_option_id res_mean_varc = result_options::mean | result_options::variance; + const bs::result_option_id res_all = + bs::result_option_id(dal::result_option_id_base(mask_full)); + const bs::result_option_id compute_mode = GENERATE_COPY(res_min_max, res_mean_varc, res_all); + + this->csr_general_checks(data, compute_mode); +} + } // namespace oneapi::dal::basic_statistics::test diff --git a/cpp/oneapi/dal/algo/basic_statistics/test/fixture.hpp b/cpp/oneapi/dal/algo/basic_statistics/test/fixture.hpp index 9f3672b1e8f..4bf36cc11e0 100644 --- a/cpp/oneapi/dal/algo/basic_statistics/test/fixture.hpp +++ b/cpp/oneapi/dal/algo/basic_statistics/test/fixture.hpp @@ -18,19 +18,23 @@ #include +#include "oneapi/dal/table/csr.hpp" #include "oneapi/dal/algo/basic_statistics/compute.hpp" #include "oneapi/dal/algo/basic_statistics/partial_compute.hpp" #include "oneapi/dal/algo/basic_statistics/finalize_compute.hpp" #include "oneapi/dal/test/engine/common.hpp" #include "oneapi/dal/test/engine/fixtures.hpp" #include "oneapi/dal/test/engine/dataframe.hpp" +#include "oneapi/dal/test/engine/csr_table_builder.hpp" #include "oneapi/dal/test/engine/math.hpp" +#include "oneapi/dal/table/csr_accessor.hpp" namespace oneapi::dal::basic_statistics::test { namespace te = dal::test::engine; namespace la = te::linalg; namespace bs = oneapi::dal::basic_statistics; +namespace dal = oneapi::dal; constexpr inline std::uint64_t mask_full = 0xffffffffffffffff; @@ -42,6 +46,7 @@ class basic_statistics_test : public te::crtp_algo_fixture { using input_t = bs::compute_input<>; using result_t = bs::compute_result<>; using descriptor_t = bs::descriptor; + using csr_table = dal::csr_table; auto get_descriptor(bs::result_option_id compute_mode) const { return descriptor_t{}.set_result_options(compute_mode); @@ -100,6 +105,27 @@ class basic_statistics_test : public te::crtp_algo_fixture { check_for_exception_for_non_requested_results(compute_mode, compute_result); } + void csr_general_checks(const te::csr_table_builder& data, bs::result_option_id compute_mode) { + const auto desc = + bs::descriptor{}.set_result_options( + compute_mode); + const auto csr_table = data.build_csr_table(this->get_policy()); + const auto dense_table = data.build_dense_table(); + + auto compute_result = 
this->compute(desc, csr_table); + table weights; + check_compute_result(compute_mode, dense_table, weights, compute_result); + } + + // TODO: Fix DAAL code. On big datasets there is an error in computing. + // To reproduce it remove this check from test case in batch.cpp + bool not_cpu_friendly(const te::csr_table_builder& data) { + using host_policy = oneapi::dal::test::engine::host_test_policy; + auto policy = this->get_policy(); + return (data.row_count_ > 100 || data.column_count_ > 100) && + std::is_same_v; + } + void online_general_checks(const te::dataframe& data_fr, std::shared_ptr weights_fr, bs::result_option_id compute_mode, @@ -214,7 +240,7 @@ class basic_statistics_test : public te::crtp_algo_fixture { CAPTURE(name, r_count, c_count, r, c, lval, rval); const auto aerr = std::abs(lval - rval); - if (aerr < tol) + if (aerr < tol || (!std::isfinite(lval) && !std::isfinite(rval))) continue; const auto den = std::max({ eps, // @@ -288,14 +314,12 @@ class basic_statistics_test : public te::crtp_algo_fixture { (elem * weight - ref_mean.get(0, clmn)); } } - for (std::int64_t clmn = 0; clmn < column_count; clmn++) { ref_sorm.set(0, clmn) = ref_sum2.get(0, clmn) / float_t(row_count); ref_varc.set(0, clmn) = ref_sum2cent.get(0, clmn) / float_t(row_count - 1); ref_stdev.set(0, clmn) = std::sqrt(ref_varc.get(0, clmn)); ref_vart.set(0, clmn) = ref_stdev.get(0, clmn) / ref_mean.get(0, clmn); } - if (compute_mode.test(result_options::min)) { const table ref = homogen_table::wrap(ref_min.get_array(), 1l, column_count); check_if_close(result.get_min(), ref, "Min"); @@ -380,5 +404,7 @@ class basic_statistics_test : public te::crtp_algo_fixture { }; using basic_statistics_types = COMBINE_TYPES((float, double), (basic_statistics::method::dense)); +using basic_statistics_sparse_types = COMBINE_TYPES((float, double), + (basic_statistics::method::sparse)); } // namespace oneapi::dal::basic_statistics::test diff --git a/cpp/oneapi/dal/test/engine/BUILD b/cpp/oneapi/dal/test/engine/BUILD index 1a4d102eed3..6432a50021e 100644 --- a/cpp/oneapi/dal/test/engine/BUILD +++ b/cpp/oneapi/dal/test/engine/BUILD @@ -15,6 +15,7 @@ dal_test_module( "@onedal//cpp/oneapi/dal:common", "@onedal//cpp/oneapi/dal:table", "@onedal//cpp/oneapi/dal/io:csv", + "@onedal//cpp/oneapi/dal/backend/primitives:reduction", ], dal_test_deps = [ "@onedal//cpp/oneapi/dal/test/engine/linalg", diff --git a/cpp/oneapi/dal/test/engine/csr_table_builder.hpp b/cpp/oneapi/dal/test/engine/csr_table_builder.hpp new file mode 100644 index 00000000000..e8de4036bfe --- /dev/null +++ b/cpp/oneapi/dal/test/engine/csr_table_builder.hpp @@ -0,0 +1,176 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "oneapi/dal/table/csr.hpp" +#include "oneapi/dal/test/engine/math.hpp" + +namespace oneapi::dal::test::engine { + +/** +* Generates random CSR table based on inputs +*/ +struct csr_table_builder { + using Float = float; + std::int64_t row_count_, column_count_; + float nonzero_fraction_; + sparse_indexing indexing_; + const dal::array data_; + const dal::array column_indices_; + const dal::array row_offsets_; + + csr_table_builder(std::int64_t row_count, + std::int64_t column_count, + float nnz_fraction = 0.1, + sparse_indexing indexing = sparse_indexing::one_based, + float min_val = -10.0, + float max_val = 10.0, + int seed_in = 42) + : row_count_(row_count), + column_count_(column_count), + nonzero_fraction_(nnz_fraction), + indexing_(indexing), + data_(dal::array::empty(nnz_fraction * row_count * column_count)), + column_indices_( + dal::array::empty(nnz_fraction * row_count * column_count)), + row_offsets_(dal::array::empty(row_count + 1)) { + std::int64_t total_count = row_count_ * column_count_; + std::int64_t nonzero_count = total_count * nnz_fraction; + std::int64_t indexing_shift = bool(indexing == sparse_indexing::one_based); + + std::uint32_t seed = seed_in; + std::mt19937 rng(seed); + std::uniform_real_distribution uniform_data(min_val, max_val); + std::uniform_int_distribution uniform_indices( + 0, + column_count_ - 1 - indexing_shift); + std::uniform_int_distribution uniform_ind_count(1, column_count_ - 2); + + auto data_ptr = data_.get_mutable_data(); + auto col_indices_ptr = column_indices_.get_mutable_data(); + auto row_offsets_ptr = row_offsets_.get_mutable_data(); + // Generate data + for (std::int32_t i = 0; i < nonzero_count; ++i) { + data_ptr[i] = uniform_data(rng); + } + // Generate column indices and fill row offsets + std::int64_t row_idx = 0; + std::int64_t fill_count = 0; + row_offsets_ptr[0] = indexing_shift; + while (fill_count < nonzero_count && row_idx < row_count_) { + // Generate the number of non-zero columns for current row + std::int64_t nnz_col_count = uniform_ind_count(rng); + nnz_col_count = std::min(nnz_col_count, nonzero_count - fill_count); + for (std::int32_t i = 0; i < nnz_col_count; ++i) { + std::int64_t col_idx = uniform_indices(rng) + indexing_shift; + col_indices_ptr[fill_count + i] = col_idx + indexing_shift; + } + std::sort(col_indices_ptr + fill_count, col_indices_ptr + fill_count + nnz_col_count); + // Remove duplications + std::int64_t dup_count = 0; + for (std::int32_t i = 1; i < nnz_col_count; ++i) { + auto cur_ptr = col_indices_ptr + (fill_count - dup_count); + if (cur_ptr[i] == cur_ptr[i - 1]) { + ++dup_count; + // Shift the tail if there is duplication + for (std::int32_t j = i + 1; j < nnz_col_count; ++j) { + cur_ptr[j - 1] = cur_ptr[j]; + } + } + } + fill_count += (nnz_col_count - dup_count); + // Update row offsets + row_offsets_ptr[row_idx + 1] = fill_count + indexing_shift; + row_idx++; + } + if (row_idx < row_count_) { + for (std::int32_t i = row_idx; i <= row_count_; ++i) { + row_offsets_ptr[i] = fill_count + indexing_shift; + } + } + } + +#ifdef ONEDAL_DATA_PARALLEL + csr_table build_csr_table(device_test_policy& policy) const { + auto queue = policy.get_queue(); + auto row_offs_ptr = row_offsets_.get_data(); + auto nnz_count = row_offs_ptr[row_count_] - row_offs_ptr[0]; + const auto copied_data = + dal::array::empty(queue, nnz_count, sycl::usm::alloc::device); + const auto copied_col_indices = + dal::array::empty(queue, nnz_count, 
sycl::usm::alloc::device); + const auto copied_row_offsets = + dal::array::empty(queue, row_count_ + 1, sycl::usm::alloc::device); + auto data_event = + queue.copy(data_.get_data(), copied_data.get_mutable_data(), nnz_count); + auto col_indices_event = queue.copy(column_indices_.get_data(), + copied_col_indices.get_mutable_data(), + nnz_count); + auto row_offsets_event = queue.copy(row_offsets_.get_data(), + copied_row_offsets.get_mutable_data(), + row_count_ + 1); + sycl::event::wait_and_throw({ data_event, col_indices_event, row_offsets_event }); + return csr_table::wrap(copied_data, + copied_col_indices, + copied_row_offsets, + column_count_, + indexing_); + } +#endif // ONEDAL_DATA_PARALLEL + + csr_table build_csr_table(host_test_policy& policy) const { + auto row_offs_ptr = row_offsets_.get_data(); + auto nnz_count = row_offs_ptr[row_count_] - row_offs_ptr[0]; + const auto copied_data = dal::array::empty(nnz_count); + const auto copied_col_indices = dal::array::empty(nnz_count); + const auto copied_row_offsets = dal::array::empty(row_count_ + 1); + + auto copied_data_ptr = copied_data.get_mutable_data(); + auto copied_col_indices_ptr = copied_col_indices.get_mutable_data(); + auto copied_row_offsets_ptr = copied_row_offsets.get_mutable_data(); + for (std::int32_t i = 0; i < nnz_count; ++i) { + copied_data_ptr[i] = data_.get_data()[i]; + copied_col_indices_ptr[i] = column_indices_.get_data()[i]; + } + for (std::int32_t i = 0; i <= row_count_; ++i) { + copied_row_offsets_ptr[i] = row_offs_ptr[i]; + } + return csr_table::wrap(copied_data, + copied_col_indices, + copied_row_offsets, + column_count_, + indexing_); + } + + table build_dense_table() const { + const dal::array dense_data = dal::array::zeros(row_count_ * column_count_); + std::int64_t indexing_shift = bool(indexing_ == sparse_indexing::one_based); + auto data_ptr = dense_data.get_mutable_data(); + auto sparse_data_ptr = data_.get_data(); + auto row_offs_ptr = row_offsets_.get_data(); + auto col_indices_ptr = column_indices_.get_data(); + for (std::int32_t row_idx = 0; row_idx < row_count_; ++row_idx) { + for (std::int32_t data_idx = row_offs_ptr[row_idx] - indexing_shift; + data_idx < row_offs_ptr[row_idx + 1] - indexing_shift; + ++data_idx) { + data_ptr[row_idx * column_count_ + col_indices_ptr[data_idx] - indexing_shift] = + sparse_data_ptr[data_idx]; + } + } + return homogen_table::wrap(dense_data, row_count_, column_count_); + } +}; + +} //namespace oneapi::dal::test::engine
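Since the builder fills row offsets and de-duplicates column indices by hand, a small consistency check can be useful when extending it. The helper below is illustrative only (not part of the patch) and assumes the same CSR layout the builder produces:

```cpp
#include <cstdint>

// Returns true if the CSR structure is well formed: monotone row offsets,
// in-range column indices, and strictly increasing indices within each row.
inline bool is_consistent_csr(const std::int64_t* row_offsets,
                              const std::int64_t* column_indices,
                              std::int64_t row_count,
                              std::int64_t column_count,
                              bool one_based) {
    const std::int64_t shift = one_based ? 1 : 0;
    if (row_offsets[0] != shift)
        return false;
    for (std::int64_t row = 0; row < row_count; ++row) {
        if (row_offsets[row + 1] < row_offsets[row])
            return false;
        for (std::int64_t i = row_offsets[row] - shift; i < row_offsets[row + 1] - shift; ++i) {
            const std::int64_t col = column_indices[i];
            if (col < shift || col >= column_count + shift)
                return false;
            if (i > row_offsets[row] - shift && column_indices[i - 1] >= col)
                return false; // duplicate or unsorted index within the row
        }
    }
    return true;
}
```

Such a check can be asserted on both build_csr_table outputs before they are fed to csr_general_checks, catching generator regressions independently of the algorithm under test.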