From 7664e40a5295f1973f5b1f38713e4beee337d906 Mon Sep 17 00:00:00 2001 From: Aleksandr Solovev Date: Wed, 11 Oct 2023 19:34:21 +0200 Subject: [PATCH] Fix online lom performance (#2541) --- .../backend/basic_statistics_interop.hpp | 29 ++ .../cpu/finalize_compute_kernel_dense.cpp | 20 +- .../cpu/partial_compute_kernel_dense.cpp | 134 +++--- .../gpu/finalize_compute_kernel_dense_dpc.cpp | 140 +++--- .../gpu/partial_compute_kernel_dense_dpc.cpp | 405 ++++++++---------- .../detail/finalize_compute_ops.hpp | 38 +- .../algo/basic_statistics/test/fixture.hpp | 33 +- .../dal/algo/basic_statistics/test/online.cpp | 3 +- 8 files changed, 432 insertions(+), 370 deletions(-) diff --git a/cpp/oneapi/dal/algo/basic_statistics/backend/basic_statistics_interop.hpp b/cpp/oneapi/dal/algo/basic_statistics/backend/basic_statistics_interop.hpp index 9518dd8ae20..7adbefb6399 100644 --- a/cpp/oneapi/dal/algo/basic_statistics/backend/basic_statistics_interop.hpp +++ b/cpp/oneapi/dal/algo/basic_statistics/backend/basic_statistics_interop.hpp @@ -31,6 +31,35 @@ namespace bk = dal::backend; using task_t = task::compute; using descriptor_t = detail::descriptor_base; +template +inline auto get_desc_to_compute(const descriptor_t& desc) { + const auto res_op = desc.get_result_options(); + bool has_min_max = res_op.test(result_options::min) || res_op.test(result_options::max); + bool has_other_stat = + res_op.test(result_options::mean) || res_op.test(result_options::variance) || + res_op.test(result_options::second_order_raw_moment) || + res_op.test(result_options::variation) || res_op.test(result_options::standard_deviation); + bool has_sums = res_op.test(result_options::sum) || + res_op.test(result_options::sum_squares_centered) || has_other_stat; + bool has_sums2 = res_op.test(result_options::sum_squares_centered) || + res_op.test(result_options::sum_squares_centered) || has_other_stat; + auto local_desc = + basic_statistics::descriptor(); + if (has_min_max && has_sums && has_sums2) { + local_desc.set_result_options(result_options::min | result_options::max | + result_options::sum | result_options::sum_squares | + result_options::sum_squares_centered); + } + else if (!has_min_max || has_sums || has_sums2) { + local_desc.set_result_options(result_options::sum | result_options::sum_squares | + result_options::sum_squares_centered); + } + else if (has_min_max && !has_sums && !has_sums2) { + local_desc.set_result_options(result_options::min | result_options::max); + } + return local_desc; +} + inline auto get_daal_estimates_to_compute(const descriptor_t& desc) { const auto res_op = desc.get_result_options(); const auto res_min_max = result_options::min | result_options::max; diff --git a/cpp/oneapi/dal/algo/basic_statistics/backend/cpu/finalize_compute_kernel_dense.cpp b/cpp/oneapi/dal/algo/basic_statistics/backend/cpu/finalize_compute_kernel_dense.cpp index 5807ec85573..c8cef97dc16 100644 --- a/cpp/oneapi/dal/algo/basic_statistics/backend/cpu/finalize_compute_kernel_dense.cpp +++ b/cpp/oneapi/dal/algo/basic_statistics/backend/cpu/finalize_compute_kernel_dense.cpp @@ -50,11 +50,21 @@ static compute_result call_daal_kernel_finalize_compute( const context_cpu& ctx, const descriptor_t& desc, const partial_compute_result& input) { - const auto result_ids = daal_lom::estimatesAll; + const auto result_ids = get_daal_estimates_to_compute(desc); const auto daal_parameter = daal_lom::Parameter(result_ids); - auto column_count = input.get_partial_min().get_column_count(); + std::int64_t column_count = 0; + const auto res_op = desc.get_result_options(); + + if (result_ids == daal_lom::estimatesMeanVariance || result_ids == daal_lom::estimatesAll) { + column_count = input.get_partial_sum().get_column_count(); + } + if (result_ids == daal_lom::estimatesMinMax) { + column_count = input.get_partial_min().get_column_count(); + } + ONEDAL_ASSERT(column_count > 0); + auto daal_partial = daal_lom::PartialResult(); auto daal_partial_obs = interop::copy_to_daal_homogen_table(input.get_partial_n_rows()); auto daal_partial_min = interop::copy_to_daal_homogen_table(input.get_partial_min()); auto daal_partial_max = interop::copy_to_daal_homogen_table(input.get_partial_max()); @@ -66,11 +76,10 @@ static compute_result call_daal_kernel_finalize_compute( auto daal_means = interop::allocate_daal_homogen_table(1, column_count); auto daal_rawt = interop::allocate_daal_homogen_table(1, column_count); - auto daal_variance = interop::allocate_daal_homogen_table(1, column_count); auto daal_stdev = interop::allocate_daal_homogen_table(1, column_count); auto daal_variation = interop::allocate_daal_homogen_table(1, column_count); - { + if (result_ids == daal_lom::estimatesMeanVariance || result_ids == daal_lom::estimatesAll) { interop::status_to_exception( interop::call_daal_kernel_finalize_compute( ctx, @@ -85,11 +94,8 @@ static compute_result call_daal_kernel_finalize_compute( daal_variation.get(), &daal_parameter)); } - compute_result res; - const auto res_op = desc.get_result_options(); res.set_result_options(desc.get_result_options()); - if (res_op.test(result_options::min)) { res.set_min(interop::convert_from_daal_homogen_table(daal_partial_min)); } diff --git a/cpp/oneapi/dal/algo/basic_statistics/backend/cpu/partial_compute_kernel_dense.cpp b/cpp/oneapi/dal/algo/basic_statistics/backend/cpu/partial_compute_kernel_dense.cpp index 0b25ebf4bcd..a32732edd45 100644 --- a/cpp/oneapi/dal/algo/basic_statistics/backend/cpu/partial_compute_kernel_dense.cpp +++ b/cpp/oneapi/dal/algo/basic_statistics/backend/cpu/partial_compute_kernel_dense.cpp @@ -46,22 +46,28 @@ using daal_lom_online_kernel_t = daal_lom::internal::LowOrderMomentsOnlineKernel; template -inline auto get_partial_result(daal_lom::PartialResult daal_partial_result) { +inline auto get_partial_result(daal_lom::PartialResult daal_partial_result, + const descriptor_t& desc) { auto result = partial_compute_result(); - + const auto res_op = desc.get_result_options(); + const auto result_ids = get_daal_estimates_to_compute(desc); result.set_partial_n_rows(interop::convert_from_daal_homogen_table( daal_partial_result.get(daal_lom::PartialResultId::nObservations))); - result.set_partial_min(interop::convert_from_daal_homogen_table( - daal_partial_result.get(daal_lom::PartialResultId::partialMinimum))); - result.set_partial_max(interop::convert_from_daal_homogen_table( - daal_partial_result.get(daal_lom::PartialResultId::partialMaximum))); - result.set_partial_sum(interop::convert_from_daal_homogen_table( - daal_partial_result.get(daal_lom::PartialResultId::partialSum))); - result.set_partial_sum_squares(interop::convert_from_daal_homogen_table( - daal_partial_result.get(daal_lom::PartialResultId::partialSumSquares))); - result.set_partial_sum_squares_centered(interop::convert_from_daal_homogen_table( - daal_partial_result.get(daal_lom::PartialResultId::partialSumSquaresCentered))); - + if (result_ids == daal_lom::estimatesMinMax || res_op.test(result_options::min) || + res_op.test(result_options::max)) { + result.set_partial_min(interop::convert_from_daal_homogen_table( + daal_partial_result.get(daal_lom::PartialResultId::partialMinimum))); + result.set_partial_max(interop::convert_from_daal_homogen_table( + daal_partial_result.get(daal_lom::PartialResultId::partialMaximum))); + } + if (result_ids == daal_lom::estimatesMeanVariance || result_ids == daal_lom::estimatesAll) { + result.set_partial_sum(interop::convert_from_daal_homogen_table( + daal_partial_result.get(daal_lom::PartialResultId::partialSum))); + result.set_partial_sum_squares(interop::convert_from_daal_homogen_table( + daal_partial_result.get(daal_lom::PartialResultId::partialSumSquares))); + result.set_partial_sum_squares_centered(interop::convert_from_daal_homogen_table( + daal_partial_result.get(daal_lom::PartialResultId::partialSumSquaresCentered))); + } return result; } @@ -82,10 +88,12 @@ result_t call_daal_kernel_with_weights(const context_cpu& ctx, auto daal_input = daal_lom::Input(); auto daal_partial = daal_lom::PartialResult(); + const auto res_op = desc.get_result_options(); + const auto input_ = input.get_prev(); row_accessor data_accessor(data); row_accessor weights_accessor(weights); - const auto result_ids = daal_lom::estimatesAll; + const auto result_ids = get_daal_estimates_to_compute(desc); const auto daal_parameter = daal_lom::Parameter(result_ids); auto weights_arr = weights_accessor.pull(); auto gen_data_block = data_accessor.pull(); @@ -110,26 +118,32 @@ result_t call_daal_kernel_with_weights(const context_cpu& ctx, } const bool has_nobs_data = input_.get_partial_n_rows().has_data(); if (has_nobs_data) { - auto daal_partial_max = - interop::copy_to_daal_homogen_table(input_.get_partial_max()); - auto daal_partial_min = - interop::copy_to_daal_homogen_table(input_.get_partial_min()); - auto daal_partial_sums = - interop::copy_to_daal_homogen_table(input_.get_partial_sum()); - auto daal_partial_sum_squares = - interop::copy_to_daal_homogen_table(input_.get_partial_sum_squares()); - auto daal_partial_sum_squares_centered = - interop::copy_to_daal_homogen_table(input_.get_partial_sum_squares_centered()); auto daal_nobs = interop::copy_to_daal_homogen_table(input_.get_partial_n_rows()); daal_partial.set(daal_lom::PartialResultId::nObservations, daal_nobs); - - daal_partial.set(daal_lom::PartialResultId::partialMaximum, daal_partial_max); - daal_partial.set(daal_lom::PartialResultId::partialMinimum, daal_partial_min); - daal_partial.set(daal_lom::PartialResultId::partialSum, daal_partial_sums); - daal_partial.set(daal_lom::PartialResultId::partialSumSquaresCentered, - daal_partial_sum_squares_centered); - - daal_partial.set(daal_lom::PartialResultId::partialSumSquares, daal_partial_sum_squares); + if (result_ids == daal_lom::estimatesMinMax || res_op.test(result_options::min) || + res_op.test(result_options::max)) { + auto daal_partial_max = + interop::copy_to_daal_homogen_table(input_.get_partial_max()); + auto daal_partial_min = + interop::copy_to_daal_homogen_table(input_.get_partial_min()); + daal_partial.set(daal_lom::PartialResultId::partialMaximum, daal_partial_max); + daal_partial.set(daal_lom::PartialResultId::partialMinimum, daal_partial_min); + } + if (result_ids == daal_lom::estimatesMeanVariance || result_ids == daal_lom::estimatesAll) { + auto daal_partial_sums = + interop::copy_to_daal_homogen_table(input_.get_partial_sum()); + auto daal_partial_sum_squares = + interop::copy_to_daal_homogen_table(input_.get_partial_sum_squares()); + auto daal_partial_sum_squares_centered = interop::copy_to_daal_homogen_table( + input_.get_partial_sum_squares_centered()); + + daal_partial.set(daal_lom::PartialResultId::partialSum, daal_partial_sums); + daal_partial.set(daal_lom::PartialResultId::partialSumSquaresCentered, + daal_partial_sum_squares_centered); + + daal_partial.set(daal_lom::PartialResultId::partialSumSquares, + daal_partial_sum_squares); + } { interop::status_to_exception( interop::call_daal_kernel(ctx, @@ -138,7 +152,7 @@ result_t call_daal_kernel_with_weights(const context_cpu& ctx, &daal_parameter, is_online)); } - auto result = get_partial_result(daal_partial); + auto result = get_partial_result(daal_partial, desc); return result; } @@ -151,7 +165,7 @@ result_t call_daal_kernel_with_weights(const context_cpu& ctx, &daal_parameter, is_online)); } - auto result = get_partial_result(daal_partial); + auto result = get_partial_result(daal_partial, desc); return result; } } @@ -169,11 +183,13 @@ result_t call_daal_kernel_without_weights(const context_cpu& ctx, const auto input_ = input.get_prev(); - const auto result_ids = daal_lom::estimatesAll; + const auto result_ids = get_daal_estimates_to_compute(desc); const auto daal_parameter = daal_lom::Parameter(result_ids); const auto daal_data = interop::convert_to_daal_table(data); + const auto res_op = desc.get_result_options(); + daal_input.set(daal_lom::InputId::data, daal_data); const bool has_nobs_data = input_.get_partial_n_rows().has_data(); { @@ -181,35 +197,39 @@ result_t call_daal_kernel_without_weights(const context_cpu& ctx, initialize_result(daal_partial, &daal_input, &daal_parameter, result_ids); } if (has_nobs_data) { - auto daal_partial_max = - interop::copy_to_daal_homogen_table(input_.get_partial_max()); - auto daal_partial_min = - interop::copy_to_daal_homogen_table(input_.get_partial_min()); - auto daal_partial_sums = - interop::copy_to_daal_homogen_table(input_.get_partial_sum()); - auto daal_partial_sum_squares = - interop::copy_to_daal_homogen_table(input_.get_partial_sum_squares()); - auto daal_partial_sum_squares_centered = - interop::copy_to_daal_homogen_table(input_.get_partial_sum_squares_centered()); auto daal_nobs = interop::copy_to_daal_homogen_table(input_.get_partial_n_rows()); - daal_partial.set(daal_lom::PartialResultId::nObservations, daal_nobs); - - daal_partial.set(daal_lom::PartialResultId::partialMaximum, daal_partial_max); - daal_partial.set(daal_lom::PartialResultId::partialMinimum, daal_partial_min); - daal_partial.set(daal_lom::PartialResultId::partialSum, daal_partial_sums); - daal_partial.set(daal_lom::PartialResultId::partialSumSquaresCentered, - daal_partial_sum_squares_centered); - - daal_partial.set(daal_lom::PartialResultId::partialSumSquares, daal_partial_sum_squares); - + if (result_ids == daal_lom::estimatesMinMax || res_op.test(result_options::min) || + res_op.test(result_options::max)) { + auto daal_partial_max = + interop::copy_to_daal_homogen_table(input_.get_partial_max()); + auto daal_partial_min = + interop::copy_to_daal_homogen_table(input_.get_partial_min()); + daal_partial.set(daal_lom::PartialResultId::partialMaximum, daal_partial_max); + daal_partial.set(daal_lom::PartialResultId::partialMinimum, daal_partial_min); + } + if (result_ids == daal_lom::estimatesMeanVariance || result_ids == daal_lom::estimatesAll) { + auto daal_partial_sums = + interop::copy_to_daal_homogen_table(input_.get_partial_sum()); + auto daal_partial_sum_squares = + interop::copy_to_daal_homogen_table(input_.get_partial_sum_squares()); + auto daal_partial_sum_squares_centered = interop::copy_to_daal_homogen_table( + input_.get_partial_sum_squares_centered()); + + daal_partial.set(daal_lom::PartialResultId::partialSum, daal_partial_sums); + daal_partial.set(daal_lom::PartialResultId::partialSumSquaresCentered, + daal_partial_sum_squares_centered); + + daal_partial.set(daal_lom::PartialResultId::partialSumSquares, + daal_partial_sum_squares); + } interop::status_to_exception( interop::call_daal_kernel(ctx, daal_data.get(), &daal_partial, &daal_parameter, is_online)); - auto result = get_partial_result(daal_partial); + auto result = get_partial_result(daal_partial, desc); return result; } else { @@ -221,7 +241,7 @@ result_t call_daal_kernel_without_weights(const context_cpu& ctx, &daal_parameter, is_online)); } - auto result = get_partial_result(daal_partial); + auto result = get_partial_result(daal_partial, desc); return result; } } diff --git a/cpp/oneapi/dal/algo/basic_statistics/backend/gpu/finalize_compute_kernel_dense_dpc.cpp b/cpp/oneapi/dal/algo/basic_statistics/backend/gpu/finalize_compute_kernel_dense_dpc.cpp index 86b8a0f540a..9c87a5b8818 100644 --- a/cpp/oneapi/dal/algo/basic_statistics/backend/gpu/finalize_compute_kernel_dense_dpc.cpp +++ b/cpp/oneapi/dal/algo/basic_statistics/backend/gpu/finalize_compute_kernel_dense_dpc.cpp @@ -22,6 +22,8 @@ #include "oneapi/dal/detail/policy.hpp" #include "oneapi/dal/table/row_accessor.hpp" +#include "oneapi/dal/algo/basic_statistics/backend/basic_statistics_interop.hpp" + namespace oneapi::dal::basic_statistics::backend { namespace bk = dal::backend; @@ -56,10 +58,10 @@ auto compute_all_metrics(sycl::queue& q, auto result_stddev_ptr = result_stddev.get_mutable_data(); auto nobs_ptr = nobs.get_data(); - auto sums_data = sums.get_data(); auto sums2_data = sums2.get_data(); auto sums2cent_data = sums2cent.get_data(); + const Float inv_n = Float(1.0 / double(nobs_ptr[0])); auto update_event = q.submit([&](sycl::handler& cgh) { const auto range = sycl::range<1>(column_count); @@ -67,6 +69,7 @@ auto compute_all_metrics(sycl::queue& q, cgh.depends_on(deps); cgh.parallel_for(range, [=](sycl::item<1> id) { result_means_ptr[id] = sums_data[id] / nobs_ptr[0]; + result_variance_ptr[id] = sums2cent_data[id] / (nobs_ptr[0] - 1); result_raw_moment_ptr[id] = sums2_data[id] * inv_n; @@ -91,75 +94,106 @@ static compute_result finalize_compute(const context_gpu& ctx, const partial_compute_result& input) { auto& q_ = ctx.get_queue(); result_t res; + auto local_desc = get_desc_to_compute(desc); + const auto res_op_partial = local_desc.get_result_options(); + auto column_count = 0; + + if (res_op_partial.test(result_options::min)) { + column_count = input.get_partial_min().get_column_count(); + } + if (res_op_partial.test(result_options::sum)) { + column_count = input.get_partial_sum().get_column_count(); + } - auto column_count = input.get_partial_sum_squares().get_column_count(); ONEDAL_ASSERT(column_count > 0); const auto res_op = desc.get_result_options(); res.set_result_options(desc.get_result_options()); - const auto sums_nd = - pr::table2ndarray_1d(q_, input.get_partial_sum(), sycl::usm::alloc::device); const auto nobs_nd = pr::table2ndarray_1d(q_, input.get_partial_n_rows()); - const auto sums2_nd = - pr::table2ndarray_1d(q_, input.get_partial_sum_squares(), sycl::usm::alloc::device); - const auto sums2cent_nd = pr::table2ndarray_1d(q_, - input.get_partial_sum_squares_centered(), - sycl::usm::alloc::device); if (res_op.test(result_options::min)) { ONEDAL_ASSERT(input.get_partial_min().get_column_count() == column_count); res.set_min(input.get_partial_min()); } + if (res_op.test(result_options::max)) { ONEDAL_ASSERT(input.get_partial_max().get_column_count() == column_count); res.set_max(input.get_partial_max()); } - if (res_op.test(result_options::sum)) { - ONEDAL_ASSERT(input.get_partial_sum().get_column_count() == column_count); - res.set_sum(input.get_partial_sum()); - } - if (res_op.test(result_options::sum_squares)) { - ONEDAL_ASSERT(input.get_partial_sum_squares().get_column_count() == column_count); - res.set_sum_squares(input.get_partial_sum_squares()); - } - if (res_op.test(result_options::sum_squares_centered)) { - ONEDAL_ASSERT(input.get_partial_sum_squares_centered().get_column_count() == column_count); - res.set_sum_squares_centered(input.get_partial_sum_squares_centered()); - } - - auto [result_means, - result_variance, - result_raw_moment, - result_variation, - result_stddev, - update_event] = - compute_all_metrics(q_, sums_nd, sums2_nd, sums2cent_nd, nobs_nd, column_count, {}); - if (res_op.test(result_options::mean)) { - ONEDAL_ASSERT(result_means.get_dimension(0) == column_count); - res.set_mean( - homogen_table::wrap(result_means.flatten(q_, { update_event }), 1, column_count)); - } - if (res_op.test(result_options::second_order_raw_moment)) { - ONEDAL_ASSERT(result_raw_moment.get_dimension(0) == column_count); - res.set_second_order_raw_moment( - homogen_table::wrap(result_raw_moment.flatten(q_, { update_event }), 1, column_count)); - } - if (res_op.test(result_options::variance)) { - ONEDAL_ASSERT(result_variance.get_dimension(0) == column_count); - res.set_variance( - homogen_table::wrap(result_variance.flatten(q_, { update_event }), 1, column_count)); - } - if (res_op.test(result_options::standard_deviation)) { - ONEDAL_ASSERT(result_stddev.get_dimension(0) == column_count); - res.set_standard_deviation( - homogen_table::wrap(result_stddev.flatten(q_, { update_event }), 1, column_count)); - } - if (res_op.test(result_options::variation)) { - ONEDAL_ASSERT(result_variation.get_dimension(0) == column_count); - res.set_variation( - homogen_table::wrap(result_variation.flatten(q_, { update_event }), 1, column_count)); + if (res_op_partial.test(result_options::sum)) { + const auto sums_nd = + pr::table2ndarray_1d(q_, input.get_partial_sum(), sycl::usm::alloc::device); + const auto sums2_nd = pr::table2ndarray_1d(q_, + input.get_partial_sum_squares(), + sycl::usm::alloc::device); + const auto sums2cent_nd = + pr::table2ndarray_1d(q_, + input.get_partial_sum_squares_centered(), + sycl::usm::alloc::device); + auto [result_means, + result_variance, + result_raw_moment, + result_variation, + result_stddev, + update_event] = compute_all_metrics(q_, + sums_nd, + sums2_nd, + sums2cent_nd, + nobs_nd, + column_count, + {}); + + if (res_op.test(result_options::sum)) { + ONEDAL_ASSERT(input.get_partial_sum().get_column_count() == column_count); + res.set_sum(input.get_partial_sum()); + } + + if (res_op.test(result_options::sum_squares)) { + ONEDAL_ASSERT(input.get_partial_sum_squares().get_column_count() == column_count); + res.set_sum_squares(input.get_partial_sum_squares()); + } + + if (res_op.test(result_options::sum_squares_centered)) { + ONEDAL_ASSERT(input.get_partial_sum_squares_centered().get_column_count() == + column_count); + res.set_sum_squares_centered(input.get_partial_sum_squares_centered()); + } + + if (res_op.test(result_options::mean)) { + ONEDAL_ASSERT(result_means.get_dimension(0) == column_count); + res.set_mean( + homogen_table::wrap(result_means.flatten(q_, { update_event }), 1, column_count)); + } + + if (res_op.test(result_options::second_order_raw_moment)) { + ONEDAL_ASSERT(result_raw_moment.get_dimension(0) == column_count); + res.set_second_order_raw_moment( + homogen_table::wrap(result_raw_moment.flatten(q_, { update_event }), + 1, + column_count)); + } + + if (res_op.test(result_options::variance)) { + ONEDAL_ASSERT(result_variance.get_dimension(0) == column_count); + res.set_variance(homogen_table::wrap(result_variance.flatten(q_, { update_event }), + 1, + column_count)); + } + + if (res_op.test(result_options::standard_deviation)) { + ONEDAL_ASSERT(result_stddev.get_dimension(0) == column_count); + res.set_standard_deviation( + homogen_table::wrap(result_stddev.flatten(q_, { update_event }), 1, column_count)); + } + + if (res_op.test(result_options::variation)) { + ONEDAL_ASSERT(result_variation.get_dimension(0) == column_count); + res.set_variation(homogen_table::wrap(result_variation.flatten(q_, { update_event }), + 1, + column_count)); + } } return res; } diff --git a/cpp/oneapi/dal/algo/basic_statistics/backend/gpu/partial_compute_kernel_dense_dpc.cpp b/cpp/oneapi/dal/algo/basic_statistics/backend/gpu/partial_compute_kernel_dense_dpc.cpp index 4712e5531df..e646c02cfa2 100644 --- a/cpp/oneapi/dal/algo/basic_statistics/backend/gpu/partial_compute_kernel_dense_dpc.cpp +++ b/cpp/oneapi/dal/algo/basic_statistics/backend/gpu/partial_compute_kernel_dense_dpc.cpp @@ -15,6 +15,7 @@ *******************************************************************************/ #include "oneapi/dal/algo/basic_statistics/backend/gpu/partial_compute_kernel.hpp" +#include "oneapi/dal/algo/basic_statistics/backend/gpu/compute_kernel.hpp" #include "oneapi/dal/backend/common.hpp" #include "oneapi/dal/detail/common.hpp" @@ -25,6 +26,7 @@ #include "oneapi/dal/util/common.hpp" #include "oneapi/dal/backend/primitives/reduction.hpp" +#include "oneapi/dal/algo/basic_statistics/backend/basic_statistics_interop.hpp" namespace oneapi::dal::basic_statistics::backend { namespace bk = dal::backend; @@ -39,167 +41,106 @@ using result_t = partial_compute_result; using descriptor_t = detail::descriptor_base; template -auto update_partial_results(sycl::queue& q, +auto update_partial_n_rows_results(sycl::queue& q, + const std::int64_t row_count, + const pr::ndview& nobs, + const dal::backend::event_vector& deps = {}) { + ONEDAL_PROFILER_TASK(update_partial_n_rows_results, q); + + auto result_nobs = pr::ndarray::empty(q, 1, alloc::device); + auto result_nobs_ptr = result_nobs.get_mutable_data(); + auto nobs_ptr = nobs.get_data(); + + auto nobs_update_event = q.submit([&](sycl::handler& cgh) { + const auto range = sycl::range(1); + + cgh.depends_on(deps); + cgh.parallel_for(range, [=](sycl::item<1> id) { + result_nobs_ptr[0] = nobs_ptr[0] + row_count; + }); + }); + + return std::make_tuple(result_nobs, nobs_update_event); +} + +template +auto update_min_max_results(sycl::queue& q, const pr::ndview& min, - const pr::ndview& current_min, + const table current_min, const pr::ndview& max, - const pr::ndview& current_max, - const pr::ndview& sums, - const pr::ndview& current_sums, - const pr::ndview& sums2, - const pr::ndview& current_sums2, - const pr::ndview& sums2cent, - const pr::ndview& current_sums2cent, + const table current_max, const std::int64_t column_count, - const std::int64_t row_count, - const pr::ndview& nobs, const dal::backend::event_vector& deps = {}) { - ONEDAL_PROFILER_TASK(update_partial_results, q); + ONEDAL_PROFILER_TASK(update_min_max_results, q); auto result_min = pr::ndarray::empty(q, column_count, alloc::device); auto result_max = pr::ndarray::empty(q, column_count, alloc::device); - auto result_sums = pr::ndarray::empty(q, column_count, alloc::device); - auto result_sums2 = pr::ndarray::empty(q, column_count, alloc::device); - auto result_sums2cent = pr::ndarray::empty(q, column_count, alloc::device); auto result_min_ptr = result_min.get_mutable_data(); auto result_max_ptr = result_max.get_mutable_data(); - auto result_sums_ptr = result_sums.get_mutable_data(); - auto result_sums2_ptr = result_sums2.get_mutable_data(); - auto result_sums2cent_ptr = result_sums2cent.get_mutable_data(); - auto current_min_ptr = current_min.get_mutable_data(); - auto current_max_ptr = current_max.get_mutable_data(); - auto current_sums_ptr = current_sums.get_mutable_data(); - auto current_sums2_ptr = current_sums2.get_mutable_data(); + auto current_min_ptr = + pr::table2ndarray_1d(q, current_min, sycl::usm::alloc::device).get_data(); + auto current_max_ptr = + pr::table2ndarray_1d(q, current_max, sycl::usm::alloc::device).get_data(); - auto nobs_ptr = nobs.get_data(); auto min_data = min.get_data(); auto max_data = max.get_data(); - auto sums_data = sums.get_data(); - auto sums2_data = sums2.get_data(); - auto update_event = q.submit([&](sycl::handler& cgh) { + auto merge_min_max_event = q.submit([&](sycl::handler& cgh) { const auto range = sycl::range<1>(column_count); cgh.depends_on(deps); cgh.parallel_for(range, [=](sycl::item<1> id) { result_min_ptr[id] = sycl::fmin(current_min_ptr[id], min_data[id]); result_max_ptr[id] = sycl::fmax(current_max_ptr[id], max_data[id]); - - result_sums_ptr[id] = current_sums_ptr[id] + sums_data[id]; - - result_sums2_ptr[id] = current_sums2_ptr[id] + sums2_data[id]; - - result_sums2cent_ptr[id] = - result_sums2_ptr[id] - result_sums_ptr[id] * result_sums_ptr[id] / nobs_ptr[0]; - }); - }); - return std::make_tuple(result_min, - result_max, - result_sums, - result_sums2, - result_sums2cent, - update_event); -} - -template -auto apply_weights(sycl::queue& q, - const pr::ndview& data, - std::int64_t row_count, - std::int64_t column_count, - const pr::ndview& weights, - const dal::backend::event_vector& deps = {}) { - ONEDAL_PROFILER_TASK(apply_weights, q); - auto data_to_compute = - pr::ndarray::empty(q, { row_count, column_count }, alloc::device); - - auto weights_ptr = weights.get_data(); - - auto data_to_compute_ptr = data_to_compute.get_mutable_data(); - - auto input_data = data.get_data(); - - auto apply_weights_event = q.submit([&](sycl::handler& cgh) { - const auto range = sycl::range<2>(row_count, column_count); - - cgh.depends_on(deps); - cgh.parallel_for(range, [=](sycl::item<2> id) { - data_to_compute_ptr[id[0] * column_count + id[1]] = - input_data[id[0] * column_count + id[1]] * weights_ptr[id[0]]; }); }); - - return std::make_tuple(data_to_compute, apply_weights_event); + return std::make_tuple(result_min, result_max, merge_min_max_event); } template -auto init_computation(sycl::queue& q, - const pr::ndview& data, - const pr::ndview& nobs, - std::int64_t column_count, - std::int64_t row_count, - const dal::backend::event_vector& deps = {}) { - ONEDAL_PROFILER_TASK(init_partial_results, q); - - auto component_count = column_count; - auto current_nobs_ptr = nobs.get_data(); - auto result_nobs = pr::ndarray::empty(q, 1); - auto result_nobs_ptr = result_nobs.get_mutable_data(); - auto result_max = pr::ndarray::empty(q, component_count, alloc::device); +auto update_partial_sums(sycl::queue& q, + const pr::ndview& sums, + const table current_sums, + const pr::ndview& sums2, + const table current_sums2, + const std::int64_t column_count, + const pr::ndview& nobs, + const dal::backend::event_vector& deps = {}) { + ONEDAL_PROFILER_TASK(update_partial_results, q); - auto result_min = pr::ndarray::empty(q, component_count, alloc::device); + auto result_sums = pr::ndarray::empty(q, column_count, alloc::device); + auto result_sums2 = pr::ndarray::empty(q, column_count, alloc::device); + auto result_sums2cent = pr::ndarray::empty(q, column_count, alloc::device); - auto result_sums = pr::ndarray::empty(q, component_count, alloc::device); + auto result_sums_ptr = result_sums.get_mutable_data(); + auto result_sums2_ptr = result_sums2.get_mutable_data(); + auto result_sums2cent_ptr = result_sums2cent.get_mutable_data(); - auto result_sums2 = pr::ndarray::empty(q, component_count, alloc::device); + auto current_sums_ptr = + pr::table2ndarray_1d(q, current_sums, sycl::usm::alloc::device).get_data(); + auto current_sums2_ptr = + pr::table2ndarray_1d(q, current_sums2, sycl::usm::alloc::device).get_data(); - auto result_sums2cent = pr::ndarray::empty(q, component_count, alloc::device); + auto nobs_ptr = nobs.get_data(); + auto sums_data = sums.get_data(); + auto sums2_data = sums2.get_data(); - auto nobs_update_event = q.submit([&](sycl::handler& cgh) { - const auto range = sycl::range(1); + auto update_sums_event = q.submit([&](sycl::handler& cgh) { + const auto range = sycl::range<1>(column_count); cgh.depends_on(deps); cgh.parallel_for(range, [=](sycl::item<1> id) { - result_nobs_ptr[0] = current_nobs_ptr[0] + row_count; + result_sums_ptr[id] = current_sums_ptr[id] + sums_data[id]; + + result_sums2_ptr[id] = current_sums2_ptr[id] + sums2_data[id]; + + result_sums2cent_ptr[id] = + result_sums2_ptr[id] - result_sums_ptr[id] * result_sums_ptr[id] / nobs_ptr[0]; }); }); - auto reduce_event_min = pr::reduce_by_columns(q, - data, - result_min, - pr::min{}, - pr::identity{}, - { nobs_update_event }); - reduce_event_min.wait_and_throw(); - auto reduce_event_max = pr::reduce_by_columns(q, - data, - result_max, - pr::max{}, - pr::identity{}, - { reduce_event_min }); - reduce_event_max.wait_and_throw(); - auto reduce_event_sums = pr::reduce_by_columns(q, - data, - result_sums, - pr::sum{}, - pr::identity{}, - { reduce_event_min }); - reduce_event_sums.wait_and_throw(); - auto reduce_event_sumssquares = pr::reduce_by_columns(q, - data, - result_sums2, - pr::sum{}, - pr::square{}, - { reduce_event_min }); - reduce_event_sumssquares.wait_and_throw(); - - return std::make_tuple(result_min, - result_max, - result_sums, - result_sums2, - result_sums2cent, - result_nobs, - reduce_event_sumssquares); + return std::make_tuple(result_sums, result_sums2, result_sums2cent, update_sums_event); } template @@ -210,132 +151,126 @@ static partial_compute_result partial_compute(const context_gpu& ctx, const auto data = input.get_data(); const bool weights_enabling = input.get_weights().has_data(); const auto weights = input.get_weights(); + + auto kernel = compute_kernel_gpu{}; + auto compute_result_ = compute_result(); + auto local_desc = get_desc_to_compute(desc); + const auto res_op = local_desc.get_result_options(); + auto result = partial_compute_result(); const auto input_ = input.get_prev(); + const std::int64_t row_count = data.get_row_count(); const std::int64_t column_count = data.get_column_count(); const std::int64_t component_count = data.get_column_count(); + dal::detail::check_mul_overflow(row_count, column_count); dal::detail::check_mul_overflow(column_count, column_count); dal::detail::check_mul_overflow(component_count, column_count); - const auto data_nd = pr::table2ndarray(q, data, sycl::usm::alloc::device); - - auto data_to_compute = data_nd; - sycl::event apply_weights_event; - if (weights_enabling) { - auto weights_nd = pr::table2ndarray_1d(q, weights, sycl::usm::alloc::device); - std::tie(data_to_compute, apply_weights_event) = - apply_weights(q, data_nd, row_count, column_count, weights_nd); - } - const bool has_nobs_data = input_.get_partial_n_rows().has_data(); - if (has_nobs_data) { - const auto sums_nd = - pr::table2ndarray_1d(q, input_.get_partial_sum(), sycl::usm::alloc::device); + if (weights_enabling) { + compute_result_ = kernel(ctx, local_desc, { data, weights }); + } + else { + compute_result_ = kernel(ctx, local_desc, { data }); + } const auto nobs_nd = pr::table2ndarray_1d(q, input_.get_partial_n_rows()); + auto [result_nobs, nobs_update_event] = + update_partial_n_rows_results(q, row_count, nobs_nd); + + if (res_op.test(result_options::min) || res_op.test(result_options::max)) { + const auto min_nd = + pr::table2ndarray_1d(q, input_.get_partial_min(), sycl::usm::alloc::device); + const auto max_nd = pr::table2ndarray_1d(q, input_.get_partial_max()); + auto [result_min, result_max, update_min_max_event] = + update_min_max_results(q, + min_nd, + compute_result_.get_min(), + max_nd, + compute_result_.get_max(), + column_count, + { nobs_update_event }); + + result.set_partial_min( + (homogen_table::wrap(result_min.flatten(q, { update_min_max_event }), + 1, + column_count))); + + result.set_partial_max( + (homogen_table::wrap(result_max.flatten(q, { update_min_max_event }), + 1, + column_count))); + } + + if (res_op.test(result_options::sum)) { + const auto sums_nd = + pr::table2ndarray_1d(q, input_.get_partial_sum(), sycl::usm::alloc::device); + const auto sums2_nd = pr::table2ndarray_1d(q, + input_.get_partial_sum_squares(), + sycl::usm::alloc::device); + auto [result_sums, result_sums2, result_sums2cent, merge_sums_event] = + update_partial_sums(q, + sums_nd, + compute_result_.get_sum(), + sums2_nd, + compute_result_.get_sum_squares(), + column_count, + result_nobs, + { nobs_update_event }); + + result.set_partial_sum( + (homogen_table::wrap(result_sums.flatten(q, { merge_sums_event }), + 1, + column_count))); + + result.set_partial_sum_squares( + (homogen_table::wrap(result_sums2.flatten(q, { merge_sums_event }), + 1, + column_count))); + + result.set_partial_sum_squares_centered( + (homogen_table::wrap(result_sums2cent.flatten(q, { merge_sums_event }), + 1, + column_count))); + } - const auto min_nd = - pr::table2ndarray_1d(q, input_.get_partial_min(), sycl::usm::alloc::device); - const auto max_nd = pr::table2ndarray_1d(q, input_.get_partial_max()); - - const auto sums2_nd = pr::table2ndarray_1d(q, - input_.get_partial_sum_squares(), - sycl::usm::alloc::device); - const auto sums2cent_nd = - pr::table2ndarray_1d(q, - input_.get_partial_sum_squares_centered(), - sycl::usm::alloc::device); - auto [partial_min, - partial_max, - partial_sums, - partial_sums2, - partial_sums2cent, - partial_nobs, - init_computation_event] = init_computation(q, - data_to_compute, - nobs_nd, - column_count, - row_count, - { apply_weights_event }); - - auto [result_min, - result_max, - result_sums, - result_sums2, - result_sums2cent, - merge_results_event] = update_partial_results(q, - min_nd, - partial_min, - max_nd, - partial_max, - sums_nd, - partial_sums, - sums2_nd, - partial_sums2, - sums2cent_nd, - partial_sums2cent, - column_count, - row_count, - partial_nobs, - { init_computation_event }); - result.set_partial_min( - (homogen_table::wrap(result_min.flatten(q, { merge_results_event }), 1, column_count))); - result.set_partial_max( - (homogen_table::wrap(result_max.flatten(q, { merge_results_event }), 1, column_count))); - - result.set_partial_sum(( - homogen_table::wrap(result_sums.flatten(q, { merge_results_event }), 1, column_count))); - result.set_partial_sum_squares( - (homogen_table::wrap(result_sums2.flatten(q, { merge_results_event }), - 1, - column_count))); - result.set_partial_sum_squares_centered( - (homogen_table::wrap(result_sums2cent.flatten(q, { merge_results_event }), - 1, - column_count))); result.set_partial_n_rows( - (homogen_table::wrap(partial_nobs.flatten(q, { merge_results_event }), 1, 1))); + (homogen_table::wrap(result_nobs.flatten(q, { nobs_update_event }), 1, 1))); } else { - auto init_nobs = pr::ndarray::empty(q, 1); - - auto [result_min, - result_max, - result_sums, - result_sums2, - result_sums2cent, - result_nobs, - init_computation_event] = init_computation(q, - data_to_compute, - init_nobs, - column_count, - row_count, - { apply_weights_event }); - - result.set_partial_min( - (homogen_table::wrap(result_min.flatten(q, { init_computation_event }), - 1, - column_count))); - result.set_partial_max( - (homogen_table::wrap(result_max.flatten(q, { init_computation_event }), - 1, - column_count))); - result.set_partial_sum( - (homogen_table::wrap(result_sums.flatten(q, { init_computation_event }), - 1, - column_count))); - result.set_partial_sum_squares( - (homogen_table::wrap(result_sums2.flatten(q, { init_computation_event }), - 1, - column_count))); - result.set_partial_sum_squares_centered( - (homogen_table::wrap(result_sums2cent.flatten(q, { init_computation_event }), - 1, - column_count))); - result.set_partial_n_rows( - (homogen_table::wrap(result_nobs.flatten(q, { init_computation_event }), 1, 1))); + auto [init_nobs, init_event] = + pr::ndarray::full(q, { 1 }, row_count, sycl::usm::alloc::device); + init_event.wait_and_throw(); + + if (weights_enabling) { + compute_result_ = kernel(ctx, local_desc, { data, weights }); + } + else { + compute_result_ = kernel(ctx, local_desc, { data }); + } + + if (res_op.test(result_options::min)) { + result.set_partial_min(compute_result_.get_min()); + } + + if (res_op.test(result_options::max)) { + result.set_partial_max(compute_result_.get_max()); + } + + if (res_op.test(result_options::sum)) { + result.set_partial_sum(compute_result_.get_sum()); + } + + if (res_op.test(result_options::sum_squares)) { + result.set_partial_sum_squares(compute_result_.get_sum_squares()); + } + + if (res_op.test(result_options::sum_squares_centered)) + result.set_partial_sum_squares_centered(compute_result_.get_sum_squares_centered()); + + result.set_partial_n_rows((homogen_table::wrap(init_nobs.flatten(q, {}), 1, 1))); } return result; diff --git a/cpp/oneapi/dal/algo/basic_statistics/detail/finalize_compute_ops.hpp b/cpp/oneapi/dal/algo/basic_statistics/detail/finalize_compute_ops.hpp index 563f8afe848..243d6b4c130 100644 --- a/cpp/oneapi/dal/algo/basic_statistics/detail/finalize_compute_ops.hpp +++ b/cpp/oneapi/dal/algo/basic_statistics/detail/finalize_compute_ops.hpp @@ -38,15 +38,41 @@ struct finalize_compute_ops { using result_t = compute_result; using descriptor_base_t = descriptor_base; - void check_preconditions(const Descriptor& params, const input_t& input) const { + void check_preconditions(const Descriptor& desc, const input_t& input) const { + const auto compute_mode = desc.get_result_options(); ONEDAL_ASSERT(input.get_partial_n_rows().has_data()); ONEDAL_ASSERT(input.get_partial_n_rows().get_column_count() == 1); ONEDAL_ASSERT(input.get_partial_n_rows().get_row_count() == 1); - ONEDAL_ASSERT(input.get_partial_max().has_data()); - ONEDAL_ASSERT(input.get_partial_min().has_data()); - ONEDAL_ASSERT(input.get_partial_sum().has_data()); - ONEDAL_ASSERT(input.get_partial_sum_squares().has_data()); - ONEDAL_ASSERT(input.get_partial_sum_squares_centered().has_data()); + if (compute_mode.test(result_options::min)) { + ONEDAL_ASSERT(input.get_partial_min().has_data()); + } + if (compute_mode.test(result_options::max)) { + ONEDAL_ASSERT(input.get_partial_max().has_data()); + } + if (compute_mode.test(result_options::sum)) { + ONEDAL_ASSERT(input.get_partial_sum().has_data()); + } + if (compute_mode.test(result_options::sum_squares)) { + ONEDAL_ASSERT(input.get_partial_sum_squares().has_data()); + } + if (compute_mode.test(result_options::sum_squares_centered)) { + ONEDAL_ASSERT(input.get_partial_sum_squares_centered().has_data()); + } + if (compute_mode.test(result_options::mean)) { + ONEDAL_ASSERT(input.get_partial_sum().has_data()); + } + if (compute_mode.test(result_options::second_order_raw_moment)) { + ONEDAL_ASSERT(input.get_partial_sum().has_data()); + } + if (compute_mode.test(result_options::variance)) { + ONEDAL_ASSERT(input.get_partial_sum().has_data()); + } + if (compute_mode.test(result_options::standard_deviation)) { + ONEDAL_ASSERT(input.get_partial_sum().has_data()); + } + if (compute_mode.test(result_options::variation)) { + ONEDAL_ASSERT(input.get_partial_sum().has_data()); + } } void check_postconditions(const Descriptor& params, diff --git a/cpp/oneapi/dal/algo/basic_statistics/test/fixture.hpp b/cpp/oneapi/dal/algo/basic_statistics/test/fixture.hpp index 6f911a9aff9..9f3672b1e8f 100644 --- a/cpp/oneapi/dal/algo/basic_statistics/test/fixture.hpp +++ b/cpp/oneapi/dal/algo/basic_statistics/test/fixture.hpp @@ -102,10 +102,10 @@ class basic_statistics_test : public te::crtp_algo_fixture { void online_general_checks(const te::dataframe& data_fr, std::shared_ptr weights_fr, - bs::result_option_id compute_mode) { + bs::result_option_id compute_mode, + std::int64_t nBlocks) { const auto use_weights = bool(weights_fr); CAPTURE(use_weights, compute_mode); - const std::int64_t nBlocks = 10; const auto bs_desc = get_descriptor(compute_mode); const auto data_table_id = this->get_homogen_table_id(); @@ -154,25 +154,36 @@ class basic_statistics_test : public te::crtp_algo_fixture { const result_t& result) { CAPTURE(data.get_row_count()); CAPTURE(data.get_column_count()); - if (compute_mode.test(res_min_max)) { + if (compute_mode.test(result_options::min)) { REQUIRE(result.get_min().get_column_count() == data.get_column_count()); - REQUIRE(result.get_max().get_column_count() == data.get_column_count()); } - - if (compute_mode.test(res_mean_varc)) { - REQUIRE(result.get_mean().get_column_count() == data.get_column_count()); - REQUIRE(result.get_variance().get_column_count() == data.get_column_count()); + if (compute_mode.test(result_options::max)) { + REQUIRE(result.get_max().get_column_count() == data.get_column_count()); } - - if ((compute_mode.test(res_min_max) && compute_mode.test(~res_min_max)) || - (compute_mode.test(res_mean_varc) && compute_mode.test(~res_mean_varc))) { + if (compute_mode.test(result_options::sum)) { REQUIRE(result.get_sum().get_column_count() == data.get_column_count()); + } + if (compute_mode.test(result_options::sum_squares)) { REQUIRE(result.get_sum_squares().get_column_count() == data.get_column_count()); + } + if (compute_mode.test(result_options::sum_squares_centered)) { REQUIRE(result.get_sum_squares_centered().get_column_count() == data.get_column_count()); + } + if (compute_mode.test(result_options::mean)) { + REQUIRE(result.get_mean().get_column_count() == data.get_column_count()); + } + if (compute_mode.test(result_options::second_order_raw_moment)) { REQUIRE(result.get_second_order_raw_moment().get_column_count() == data.get_column_count()); + } + if (compute_mode.test(result_options::variance)) { + REQUIRE(result.get_variance().get_column_count() == data.get_column_count()); + } + if (compute_mode.test(result_options::standard_deviation)) { REQUIRE(result.get_standard_deviation().get_column_count() == data.get_column_count()); + } + if (compute_mode.test(result_options::variation)) { REQUIRE(result.get_variation().get_column_count() == data.get_column_count()); } } diff --git a/cpp/oneapi/dal/algo/basic_statistics/test/online.cpp b/cpp/oneapi/dal/algo/basic_statistics/test/online.cpp index e974ac4fe26..1cb9330ac06 100644 --- a/cpp/oneapi/dal/algo/basic_statistics/test/online.cpp +++ b/cpp/oneapi/dal/algo/basic_statistics/test/online.cpp @@ -45,6 +45,7 @@ TEMPLATE_LIST_TEST_M(basic_statistics_online_test, std::shared_ptr weights; const bool use_weights = GENERATE(0, 1); + const int64_t nBlocks = GENERATE(1, 3, 10); if (use_weights) { const auto row_count = data.get_row_count(); @@ -59,7 +60,7 @@ TEMPLATE_LIST_TEST_M(basic_statistics_online_test, const bs::result_option_id compute_mode = GENERATE_COPY(res_min_max, res_mean_varc, res_all); - this->online_general_checks(data, weights, compute_mode); + this->online_general_checks(data, weights, compute_mode, nBlocks); } } // namespace oneapi::dal::basic_statistics::test