Skip to content

Commit

Permalink
[fix] Fix online SPMD algorithms finalize call (#2882)
Browse files Browse the repository at this point in the history
* Fix the finalize method for all four online SPMD algorithms. The call no longer affects partial results, because all necessary tables are now copied. This allows finalize to be called between calls of partial_fit to obtain intermediate results.
  • Loading branch information
olegkkruglov authored Aug 29, 2024
1 parent 1310739 commit ac5ea85
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@

#include "oneapi/dal/algo/basic_statistics/backend/basic_statistics_interop.hpp"

#ifdef ONEDAL_DATA_PARALLEL

namespace oneapi::dal::basic_statistics::backend {

namespace bk = dal::backend;
Expand Down Expand Up @@ -151,16 +153,21 @@ result_t finalize_compute_kernel_dense_impl<Float>::operator()(const descriptor_
const auto nobs_nd = pr::table2ndarray_1d<Float>(q, input.get_partial_n_rows());

auto rows_count_global = nobs_nd.get_data()[0];
auto is_distributed = (comm_.get_rank_count() > 1);
{
ONEDAL_PROFILER_TASK(allreduce_rows_count_global);
comm_.allreduce(rows_count_global, spmd::reduce_op::sum).wait();
if (is_distributed) {
comm_.allreduce(rows_count_global, spmd::reduce_op::sum).wait();
}
}
if (res_op.test(result_options::min)) {
ONEDAL_ASSERT(input.get_partial_min().get_column_count() == column_count);
const auto min =
pr::table2ndarray_1d<Float>(q, input.get_partial_min(), sycl::usm::alloc::device);

{ comm_.allreduce(min.flatten(q, {}), spmd::reduce_op::min).wait(); }
if (is_distributed) {
comm_.allreduce(min.flatten(q, {}), spmd::reduce_op::min).wait();
}
res.set_min(homogen_table::wrap(min.flatten(q, {}), 1, column_count));
}

Expand All @@ -174,27 +181,48 @@ result_t finalize_compute_kernel_dense_impl<Float>::operator()(const descriptor_
}

if (res_op_partial.test(result_options::sum)) {
const auto sums_nd =
auto sums_nd =
pr::table2ndarray_1d<Float>(q, input.get_partial_sum(), sycl::usm::alloc::device);
{
ONEDAL_PROFILER_TASK(allreduce_sums, q);
comm_.allreduce(sums_nd.flatten(q, {}), spmd::reduce_op::sum).wait();
}
const auto sums2_nd = pr::table2ndarray_1d<Float>(q,
input.get_partial_sum_squares(),
sycl::usm::alloc::device);
{
ONEDAL_PROFILER_TASK(allreduce_sums, q);
comm_.allreduce(sums2_nd.flatten(q, {}), spmd::reduce_op::sum).wait();
}
const auto sums2cent_nd =
pr::table2ndarray_1d<Float>(q,
input.get_partial_sum_squares_centered(),
sycl::usm::alloc::device);
{
ONEDAL_PROFILER_TASK(allreduce_sums, q);
comm_.allreduce(sums2cent_nd.flatten(q, {}), spmd::reduce_op::sum).wait();
auto sums2_nd = pr::table2ndarray_1d<Float>(q,
input.get_partial_sum_squares(),
sycl::usm::alloc::device);

auto sums2cent_nd = pr::table2ndarray_1d<Float>(q,
input.get_partial_sum_squares_centered(),
sycl::usm::alloc::device);
if (is_distributed) {
auto sums_nd_copy =
pr::ndarray<Float, 1>::empty(q, { column_count }, sycl::usm::alloc::device);
auto copy_event = copy(q, sums_nd_copy, sums_nd, {});
copy_event.wait_and_throw();
sums_nd = sums_nd_copy;

{
ONEDAL_PROFILER_TASK(allreduce_sums, q);
comm_.allreduce(sums_nd.flatten(q, {}), spmd::reduce_op::sum).wait();
}

auto sums2_nd_copy =
pr::ndarray<Float, 1>::empty(q, { column_count }, sycl::usm::alloc::device);
copy_event = copy(q, sums2_nd_copy, sums2_nd, {});
copy_event.wait_and_throw();
sums2_nd = sums2_nd_copy;

{
ONEDAL_PROFILER_TASK(allreduce_sums, q);
comm_.allreduce(sums2_nd.flatten(q, {}), spmd::reduce_op::sum).wait();
}
auto sums2cent_nd_copy =
pr::ndarray<Float, 1>::empty(q, { column_count }, sycl::usm::alloc::device);
copy_event = copy(q, sums2cent_nd_copy, sums2cent_nd, {});
copy_event.wait_and_throw();
sums2cent_nd = sums2cent_nd_copy;
{
ONEDAL_PROFILER_TASK(allreduce_sums, q);
comm_.allreduce(sums2cent_nd.flatten(q, {}), spmd::reduce_op::sum).wait();
}
}

auto [result_means,
result_variance,
result_raw_moment,
Expand All @@ -210,18 +238,20 @@ result_t finalize_compute_kernel_dense_impl<Float>::operator()(const descriptor_

if (res_op.test(result_options::sum)) {
ONEDAL_ASSERT(input.get_partial_sum().get_column_count() == column_count);
res.set_sum(input.get_partial_sum());
res.set_sum(homogen_table::wrap(sums_nd.flatten(q, { update_event }), 1, column_count));
}

if (res_op.test(result_options::sum_squares)) {
ONEDAL_ASSERT(input.get_partial_sum_squares().get_column_count() == column_count);
res.set_sum_squares(input.get_partial_sum_squares());
res.set_sum_squares(
homogen_table::wrap(sums2_nd.flatten(q, { update_event }), 1, column_count));
}

if (res_op.test(result_options::sum_squares_centered)) {
ONEDAL_ASSERT(input.get_partial_sum_squares_centered().get_column_count() ==
column_count);
res.set_sum_squares_centered(input.get_partial_sum_squares_centered());
res.set_sum_squares_centered(
homogen_table::wrap(sums2cent_nd.flatten(q, { update_event }), 1, column_count));
}

if (res_op.test(result_options::mean)) {
Expand Down Expand Up @@ -264,3 +294,5 @@ template class finalize_compute_kernel_dense_impl<float>;
template class finalize_compute_kernel_dense_impl<double>;

} // namespace oneapi::dal::basic_statistics::backend

#endif // ONEDAL_DATA_PARALLEL
10 changes: 6 additions & 4 deletions cpp/oneapi/dal/algo/basic_statistics/test/online_spmd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,9 @@ class basic_statistics_online_spmd_test
}
partial_results.push_back(partial_result);
}
const auto compute_result = this->finalize_compute_override(bs_desc, partial_results);

auto compute_result = this->finalize_compute_override(bs_desc, partial_results);
base_t::check_compute_result(compute_mode, data, weights, compute_result);
compute_result = this->finalize_compute_override(bs_desc, partial_results);
base_t::check_compute_result(compute_mode, data, weights, compute_result);
}
else {
Expand All @@ -103,8 +104,9 @@ class basic_statistics_online_spmd_test
}
partial_results.push_back(partial_result);
}
const auto compute_result = this->finalize_compute_override(bs_desc, partial_results);

auto compute_result = this->finalize_compute_override(bs_desc, partial_results);
base_t::check_compute_result(compute_mode, data, table{}, compute_result);
compute_result = this->finalize_compute_override(bs_desc, partial_results);
base_t::check_compute_result(compute_mode, data, table{}, compute_result);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,28 +66,38 @@ result_t finalize_compute_kernel_dense_impl<Float>::operator()(const descriptor_

const auto nobs_host = pr::table2ndarray<Float>(q, input.get_partial_n_rows());
auto rows_count_global = nobs_host.get_data()[0];
{
ONEDAL_PROFILER_TASK(allreduce_rows_count_global);
comm_.allreduce(rows_count_global, spmd::reduce_op::sum).wait();
}

ONEDAL_ASSERT(rows_count_global > 0);
auto sums = pr::table2ndarray_1d<Float>(q, input.get_partial_sum(), sycl::usm::alloc::device);
auto xtx =
pr::table2ndarray<Float>(q, input.get_partial_crossproduct(), sycl::usm::alloc::device);

const auto sums =
pr::table2ndarray_1d<Float>(q, input.get_partial_sum(), sycl::usm::alloc::device);
if (comm_.get_rank_count() > 1) {
{
ONEDAL_PROFILER_TASK(allreduce_rows_count_global);
comm_.allreduce(rows_count_global, spmd::reduce_op::sum).wait();
}
auto sums_copy =
pr::ndarray<Float, 1>::empty(q, { column_count }, sycl::usm::alloc::device);
auto copy_event = copy(q, sums_copy, sums, {});
copy_event.wait_and_throw();
sums = sums_copy;
{
ONEDAL_PROFILER_TASK(allreduce_sums, q);
comm_.allreduce(sums.flatten(q, {}), spmd::reduce_op::sum).wait();
}

{
ONEDAL_PROFILER_TASK(allreduce_sums, q);
comm_.allreduce(sums.flatten(q, {}), spmd::reduce_op::sum).wait();
auto xtx_copy = pr::ndarray<Float, 2>::empty(q,
{ column_count, column_count },
sycl::usm::alloc::device);
copy_event = copy(q, xtx_copy, xtx, {});
copy_event.wait_and_throw();
xtx = xtx_copy;
{
ONEDAL_PROFILER_TASK(allreduce_xtx, q);
comm_.allreduce(xtx.flatten(q, {}), spmd::reduce_op::sum).wait();
}
}

const auto xtx =
pr::table2ndarray<Float>(q, input.get_partial_crossproduct(), sycl::usm::alloc::device);

{
ONEDAL_PROFILER_TASK(allreduce_xtx, q);
comm_.allreduce(xtx.flatten(q, {}), spmd::reduce_op::sum).wait();
}
ONEDAL_ASSERT(rows_count_global > 0);

if (desc.get_result_options().test(result_options::cov_matrix)) {
auto [cov, cov_event] =
Expand Down
5 changes: 3 additions & 2 deletions cpp/oneapi/dal/algo/covariance/test/online_spmd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,9 @@ class covariance_online_spmd_test
}
partial_results.push_back(partial_result);
}
const auto compute_result = this->finalize_compute_override(cov_desc, partial_results);

auto compute_result = this->finalize_compute_override(cov_desc, partial_results);
base_t::check_compute_result(cov_desc, data, compute_result);
compute_result = this->finalize_compute_override(cov_desc, partial_results);
base_t::check_compute_result(cov_desc, data, compute_result);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

#include "oneapi/dal/backend/primitives/lapack.hpp"

#ifdef ONEDAL_DATA_PARALLEL

namespace oneapi::dal::linear_regression::backend {

namespace be = dal::backend;
Expand Down Expand Up @@ -47,25 +49,32 @@ train_result<Task> finalize_train_kernel_norm_eq_impl<Float, Task>::operator()(
const auto feature_count = ext_feature_count - compute_intercept;

const pr::ndshape<2> xtx_shape{ ext_feature_count, ext_feature_count };

const auto xtx_nd =
pr::table2ndarray<Float>(q, input.get_partial_xtx(), sycl::usm::alloc::device);
const auto xty_nd = pr::table2ndarray<Float, pr::ndorder::f>(q,
input.get_partial_xty(),
sycl::usm::alloc::device);

const pr::ndshape<2> betas_shape{ response_count, feature_count + 1 };

auto xtx_nd = pr::table2ndarray<Float>(q, input.get_partial_xtx(), sycl::usm::alloc::device);
auto xty_nd = pr::table2ndarray<Float, pr::ndorder::f>(q,
input.get_partial_xty(),
sycl::usm::alloc::device);

const auto betas_size = check_mul_overflow(response_count, feature_count + 1);
auto betas_arr = array<Float>::zeros(q, betas_size, alloc);

if (comm_.get_rank_count() > 1) {
auto xtx_nd_copy = pr::ndarray<Float, 2>::empty(q, xtx_shape, sycl::usm::alloc::device);
auto copy_event = copy(q, xtx_nd_copy, xtx_nd, {});
copy_event.wait_and_throw();
xtx_nd = xtx_nd_copy;
{
ONEDAL_PROFILER_TASK(xtx_allreduce);
auto xtx_arr =
dal::array<Float>::wrap(q, xtx_nd.get_mutable_data(), xtx_nd.get_count());
comm_.allreduce(xtx_arr).wait();
}
auto xty_nd_copy =
pr::ndarray<Float, 2, pr::ndorder::f>::empty(q, betas_shape, sycl::usm::alloc::device);
copy_event = copy(q, xty_nd_copy, xty_nd, {});
copy_event.wait_and_throw();
xty_nd = xty_nd_copy;
{
ONEDAL_PROFILER_TASK(xty_allreduce);
auto xty_arr =
Expand Down Expand Up @@ -125,3 +134,5 @@ template class finalize_train_kernel_norm_eq_impl<float, task::regression>;
template class finalize_train_kernel_norm_eq_impl<double, task::regression>;

} // namespace oneapi::dal::linear_regression::backend

#endif // ONEDAL_DATA_PARALLEL
14 changes: 13 additions & 1 deletion cpp/oneapi/dal/algo/linear_regression/test/online_spmd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class lr_online_spmd_test : public lr_test<TestType, lr_online_spmd_test<TestTyp
partial_results.push_back(partial_result);
}

const auto train_result = this->finalize_train_override(desc, partial_results);
auto train_result = this->finalize_train_override(desc, partial_results);

SECTION("Checking intercept values") {
if (desc.get_result_options().test(result_options::intercept))
Expand All @@ -105,6 +105,18 @@ class lr_online_spmd_test : public lr_test<TestType, lr_online_spmd_test<TestTyp
if (desc.get_result_options().test(result_options::coefficients))
base_t::check_if_close(train_result.get_coefficients(), base_t::beta_, tol);
}

train_result = this->finalize_train_override(desc, partial_results);

SECTION("Checking intercept values after double finalize") {
if (desc.get_result_options().test(result_options::intercept))
base_t::check_if_close(train_result.get_intercept(), base_t::bias_, tol);
}

SECTION("Checking coefficient values after double finalize") {
if (desc.get_result_options().test(result_options::coefficients))
base_t::check_if_close(train_result.get_coefficients(), base_t::beta_, tol);
}
}

private:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
#include "oneapi/dal/algo/pca/backend/sign_flip.hpp"
#include "oneapi/dal/table/row_accessor.hpp"

#ifdef ONEDAL_DATA_PARALLEL

namespace oneapi::dal::pca::backend {

namespace bk = dal::backend;
Expand Down Expand Up @@ -57,30 +59,42 @@ result_t finalize_train_kernel_cov_impl<Float>::operator()(const descriptor_t& d

const auto nobs_host = pr::table2ndarray<Float>(q, input.get_partial_n_rows());
auto rows_count_global = nobs_host.get_data()[0];
{
ONEDAL_PROFILER_TASK(allreduce_rows_count_global);
comm_.allreduce(rows_count_global, spmd::reduce_op::sum).wait();
}

const auto sums =
pr::table2ndarray_1d<Float>(q, input.get_partial_sum(), sycl::usm::alloc::device);

{
ONEDAL_PROFILER_TASK(allreduce_sums, q);
comm_.allreduce(sums.flatten(q, {}), spmd::reduce_op::sum).wait();
auto sums = pr::table2ndarray_1d<Float>(q, input.get_partial_sum(), sycl::usm::alloc::device);
auto xtx =
pr::table2ndarray<Float>(q, input.get_partial_crossproduct(), sycl::usm::alloc::device);
if (comm_.get_rank_count() > 1) {
{
ONEDAL_PROFILER_TASK(allreduce_rows_count_global);
comm_.allreduce(rows_count_global, spmd::reduce_op::sum).wait();
}
auto sums_copy =
pr::ndarray<Float, 1>::empty(q, { column_count }, sycl::usm::alloc::device);
auto copy_event = copy(q, sums_copy, sums, {});
copy_event.wait_and_throw();
sums = sums_copy;

auto xtx_copy = pr::ndarray<Float, 2>::empty(q,
{ column_count, column_count },
sycl::usm::alloc::device);
copy_event = copy(q, xtx_copy, xtx, {});
copy_event.wait_and_throw();
xtx = xtx_copy;

{
ONEDAL_PROFILER_TASK(allreduce_sums, q);
comm_.allreduce(sums.flatten(q, {}), spmd::reduce_op::sum).wait();
}

{
ONEDAL_PROFILER_TASK(allreduce_xtx, q);
comm_.allreduce(xtx.flatten(q, {}), spmd::reduce_op::sum).wait();
}
}

if (desc.get_result_options().test(result_options::means)) {
auto [means, means_event] = compute_means(q, sums, rows_count_global, {});
result.set_means(homogen_table::wrap(means.flatten(q, { means_event }), 1, column_count));
}

const auto xtx =
pr::table2ndarray<Float>(q, input.get_partial_crossproduct(), sycl::usm::alloc::device);
{
ONEDAL_PROFILER_TASK(allreduce_xtx, q);
comm_.allreduce(xtx.flatten(q, {}), spmd::reduce_op::sum).wait();
}
auto [cov, cov_event] = compute_covariance(q, rows_count_global, xtx, sums, {});

auto [vars, vars_event] = compute_variances(q, cov, { cov_event });
Expand Down Expand Up @@ -144,3 +158,5 @@ template class finalize_train_kernel_cov_impl<float>;
template class finalize_train_kernel_cov_impl<double>;

} // namespace oneapi::dal::pca::backend

#endif // ONEDAL_DATA_PARALLEL
4 changes: 3 additions & 1 deletion cpp/oneapi/dal/algo/pca/test/online_spmd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,10 @@ class pca_online_spmd_test : public pca_test<TestType, pca_online_spmd_test<Test
}
partial_results.push_back(partial_result);
}
const auto train_result = this->finalize_train_override(pca_desc, partial_results);
auto train_result = this->finalize_train_override(pca_desc, partial_results);
base_t::check_train_result(pca_desc, data_fr, train_result);

train_result = this->finalize_train_override(pca_desc, partial_results);
base_t::check_train_result(pca_desc, data_fr, train_result);
}

Expand Down

0 comments on commit ac5ea85

Please sign in to comment.