feature: Linear Regression online spmd support (#2846)
DDJHB authored Jul 18, 2024
1 parent 50584aa commit 7b7f61e
Showing 8 changed files with 336 additions and 125 deletions.
@@ -14,129 +14,32 @@
  * limitations under the License.
  *******************************************************************************/
 
-#include "oneapi/dal/detail/common.hpp"
-#include "oneapi/dal/backend/dispatcher.hpp"
-#include "oneapi/dal/backend/primitives/ndarray.hpp"
-#include "oneapi/dal/backend/primitives/lapack.hpp"
-#include "oneapi/dal/backend/primitives/utils.hpp"
-
-#include "oneapi/dal/table/row_accessor.hpp"
-
-#include "oneapi/dal/algo/linear_regression/common.hpp"
-#include "oneapi/dal/algo/linear_regression/train_types.hpp"
-#include "oneapi/dal/algo/linear_regression/backend/model_impl.hpp"
-#include "oneapi/dal/algo/linear_regression/backend/gpu/finalize_train_kernel.hpp"
-#include "oneapi/dal/algo/linear_regression/backend/gpu/update_kernel.hpp"
-#include "oneapi/dal/algo/linear_regression/backend/gpu/misc.hpp"
+#include "oneapi/dal/algo/linear_regression/backend/gpu/finalize_train_kernel_norm_eq_impl.hpp"
+
+#include "oneapi/dal/detail/common.hpp"
+
+#include "oneapi/dal/backend/dispatcher.hpp"
 
 namespace oneapi::dal::linear_regression::backend {
 
-using dal::backend::context_gpu;
-
-namespace be = dal::backend;
-namespace pr = be::primitives;
-
-template <typename Float, typename Task>
-static train_result<Task> call_dal_kernel(const context_gpu& ctx,
-                                          const detail::descriptor_base<Task>& desc,
-                                          const detail::train_parameters<Task>& params,
-                                          const partial_train_result<Task>& input) {
-    using dal::detail::check_mul_overflow;
-
-    using model_t = model<Task>;
-    using model_impl_t = detail::model_impl<Task>;
-
-    auto& queue = ctx.get_queue();
-
-    const bool compute_intercept = desc.get_compute_intercept();
-
-    constexpr auto uplo = pr::mkl::uplo::upper;
-    constexpr auto alloc = sycl::usm::alloc::device;
-
-    const auto response_count = input.get_partial_xty().get_row_count();
-    const auto ext_feature_count = input.get_partial_xty().get_column_count();
-    const auto feature_count = ext_feature_count - compute_intercept;
-
-    const pr::ndshape<2> xtx_shape{ ext_feature_count, ext_feature_count };
-
-    const auto xtx_nd =
-        pr::table2ndarray<Float>(queue, input.get_partial_xtx(), sycl::usm::alloc::device);
-    const auto xty_nd = pr::table2ndarray<Float, pr::ndorder::f>(queue,
-                                                                 input.get_partial_xty(),
-                                                                 sycl::usm::alloc::device);
-
-    const pr::ndshape<2> betas_shape{ response_count, feature_count + 1 };
-
-    const auto betas_size = check_mul_overflow(response_count, feature_count + 1);
-    auto betas_arr = array<Float>::zeros(queue, betas_size, alloc);
-
-    double alpha = desc.get_alpha();
-    sycl::event ridge_event;
-    if (alpha != 0.0) {
-        ridge_event = add_ridge_penalty<Float>(queue, xtx_nd, compute_intercept, alpha);
-    }
-
-    auto nxtx = pr::ndarray<Float, 2>::empty(queue, xtx_shape, alloc);
-    auto nxty = pr::ndview<Float, 2>::wrap_mutable(betas_arr, betas_shape);
-    auto solve_event = pr::solve_system<uplo>(queue,
-                                              compute_intercept,
-                                              xtx_nd,
-                                              xty_nd,
-                                              nxtx,
-                                              nxty,
-                                              { ridge_event });
-    sycl::event::wait_and_throw({ solve_event });
-
-    auto betas = homogen_table::wrap(betas_arr, response_count, feature_count + 1);
-
-    const auto model_impl = std::make_shared<model_impl_t>(betas);
-    const auto model = dal::detail::make_private<model_t>(model_impl);
-
-    const auto options = desc.get_result_options();
-    auto result = train_result<Task>().set_model(model).set_result_options(options);
-
-    if (options.test(result_options::intercept)) {
-        auto arr = array<Float>::zeros(queue, response_count, alloc);
-        auto dst = pr::ndview<Float, 2>::wrap_mutable(arr, { 1l, response_count });
-        const auto src = nxty.get_col_slice(0l, 1l).t();
-
-        pr::copy(queue, dst, src).wait_and_throw();
-
-        auto intercept = homogen_table::wrap(arr, 1l, response_count);
-        result.set_intercept(intercept);
-    }
-
-    if (options.test(result_options::coefficients)) {
-        const auto size = check_mul_overflow(response_count, feature_count);
-
-        auto arr = array<Float>::zeros(queue, size, alloc);
-        const auto src = nxty.get_col_slice(1l, feature_count + 1);
-        auto dst = pr::ndview<Float, 2>::wrap_mutable(arr, { response_count, feature_count });
-
-        pr::copy(queue, dst, src).wait_and_throw();
-
-        auto coefficients = homogen_table::wrap(arr, response_count, feature_count);
-        result.set_coefficients(coefficients);
-    }
-
-    return result;
-}
+namespace bk = dal::backend;
 
 template <typename Float, typename Task>
-static train_result<Task> train(const context_gpu& ctx,
-                                const detail::descriptor_base<Task>& desc,
-                                const detail::train_parameters<Task>& params,
-                                const partial_train_result<Task>& input) {
-    return call_dal_kernel<Float, Task>(ctx, desc, params, input);
+static train_result<Task> finalize_train(const bk::context_gpu& ctx,
+                                         const detail::descriptor_base<Task>& desc,
+                                         const detail::train_parameters<Task>& params,
+                                         const partial_train_result<Task>& input) {
+    return finalize_train_kernel_norm_eq_impl<Float, Task>(ctx)(desc, params, input);
 }
 
 template <typename Float, typename Task>
 struct finalize_train_kernel_gpu<Float, method::norm_eq, Task> {
-    train_result<Task> operator()(const context_gpu& ctx,
+    train_result<Task> operator()(const bk::context_gpu& ctx,
                                   const detail::descriptor_base<Task>& desc,
                                   const detail::train_parameters<Task>& params,
                                   const partial_train_result<Task>& input) const {
-        return train<Float, Task>(ctx, desc, params, input);
+        return finalize_train<Float, Task>(ctx, desc, params, input);
     }
 };
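The hunk above reduces the dispatch file to a thin entry point: the public finalize_train_kernel_gpu functor now constructs an impl object from the GPU context and immediately invokes it, while the heavy body moves to its own translation unit. A minimal standalone sketch of that pattern (all names here are invented for illustration; the real impl class is declared in the new header below):

#include <iostream>

// Stand-in for dal::backend::context_gpu: owns the execution resources.
struct context {
    int queue_id; // placeholder for a sycl::queue
};

// Impl object: captures the context once, then runs the heavy body.
class finalize_impl {
public:
    explicit finalize_impl(const context& ctx) : queue_id_(ctx.queue_id) {}
    double operator()(double xtx, double xty) const {
        return xty / xtx; // solve the 1x1 "normal equation" xtx * beta = xty
    }

private:
    int queue_id_;
};

// Thin public entry point, mirroring finalize_train above.
double finalize_train(const context& ctx, double xtx, double xty) {
    return finalize_impl(ctx)(xtx, xty);
}

int main() {
    std::cout << finalize_train(context{ 0 }, 4.0, 2.0) << "\n"; // prints 0.5
}

This keeps the dispatch header free of SYCL and communicator details, so the SPMD-aware body can be compiled and instantiated separately.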
@@ -0,0 +1,51 @@
/*******************************************************************************
* Copyright contributors to the oneDAL project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#pragma once

#include "oneapi/dal/algo/linear_regression/backend/gpu/finalize_train_kernel.hpp"
#include "oneapi/dal/backend/primitives/utils.hpp"

#ifdef ONEDAL_DATA_PARALLEL

namespace oneapi::dal::linear_regression::backend {

namespace bk = dal::backend;

template <typename Float, typename Task>
class finalize_train_kernel_norm_eq_impl {
    using comm_t = bk::communicator<spmd::device_memory_access::usm>;
    using input_t = partial_train_result<Task>;
    using result_t = train_result<Task>;
    using descriptor_t = detail::descriptor_base<Task>;
    using train_parameters_t = detail::train_parameters<Task>;

public:
    finalize_train_kernel_norm_eq_impl(const bk::context_gpu& ctx)
            : q(ctx.get_queue()),
              comm_(ctx.get_communicator()) {}

    result_t operator()(const descriptor_t& desc,
                        const train_parameters_t& params,
                        const input_t& input);

private:
    sycl::queue q;
    comm_t comm_;
};

} // namespace oneapi::dal::linear_regression::backend

#endif // ONEDAL_DATA_PARALLEL
@@ -0,0 +1,127 @@
/*******************************************************************************
* Copyright contributors to the oneDAL project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "oneapi/dal/algo/linear_regression/backend/gpu/finalize_train_kernel_norm_eq_impl.hpp"
#include "oneapi/dal/algo/linear_regression/backend/gpu/misc.hpp"
#include "oneapi/dal/algo/linear_regression/backend/model_impl.hpp"

#include "oneapi/dal/backend/primitives/lapack.hpp"

namespace oneapi::dal::linear_regression::backend {

namespace be = dal::backend;
namespace pr = be::primitives;

using be::context_gpu;

template <typename Float, typename Task>
train_result<Task> finalize_train_kernel_norm_eq_impl<Float, Task>::operator()(
    const detail::descriptor_base<Task>& desc,
    const detail::train_parameters<Task>& params,
    const partial_train_result<Task>& input) {
    using dal::detail::check_mul_overflow;

    using model_t = model<Task>;
    using model_impl_t = detail::model_impl<Task>;

    const bool compute_intercept = desc.get_compute_intercept();

    constexpr auto uplo = pr::mkl::uplo::upper;
    constexpr auto alloc = sycl::usm::alloc::device;

    const auto response_count = input.get_partial_xty().get_row_count();
    const auto ext_feature_count = input.get_partial_xty().get_column_count();
    const auto feature_count = ext_feature_count - compute_intercept;

    const pr::ndshape<2> xtx_shape{ ext_feature_count, ext_feature_count };

    const auto xtx_nd =
        pr::table2ndarray<Float>(q, input.get_partial_xtx(), sycl::usm::alloc::device);
    const auto xty_nd = pr::table2ndarray<Float, pr::ndorder::f>(q,
                                                                 input.get_partial_xty(),
                                                                 sycl::usm::alloc::device);

    const pr::ndshape<2> betas_shape{ response_count, feature_count + 1 };

    const auto betas_size = check_mul_overflow(response_count, feature_count + 1);
    auto betas_arr = array<Float>::zeros(q, betas_size, alloc);

    if (comm_.get_rank_count() > 1) {
        {
            ONEDAL_PROFILER_TASK(xtx_allreduce);
            auto xtx_arr =
                dal::array<Float>::wrap(q, xtx_nd.get_mutable_data(), xtx_nd.get_count());
            comm_.allreduce(xtx_arr).wait();
        }
        {
            ONEDAL_PROFILER_TASK(xty_allreduce);
            auto xty_arr =
                dal::array<Float>::wrap(q, xty_nd.get_mutable_data(), xty_nd.get_count());
            comm_.allreduce(xty_arr).wait();
        }
    }

    double alpha = desc.get_alpha();
    sycl::event ridge_event;
    if (alpha != 0.0) {
        ridge_event = add_ridge_penalty<Float>(q, xtx_nd, compute_intercept, alpha);
    }

    auto nxtx = pr::ndarray<Float, 2>::empty(q, xtx_shape, alloc);
    auto nxty = pr::ndview<Float, 2>::wrap_mutable(betas_arr, betas_shape);
    auto solve_event =
        pr::solve_system<uplo>(q, compute_intercept, xtx_nd, xty_nd, nxtx, nxty, { ridge_event });
    sycl::event::wait_and_throw({ solve_event });

    auto betas = homogen_table::wrap(betas_arr, response_count, feature_count + 1);

    const auto model_impl = std::make_shared<model_impl_t>(betas);
    const auto model = dal::detail::make_private<model_t>(model_impl);

    const auto options = desc.get_result_options();
    auto result = train_result<Task>().set_model(model).set_result_options(options);

    if (options.test(result_options::intercept)) {
        auto arr = array<Float>::zeros(q, response_count, alloc);
        auto dst = pr::ndview<Float, 2>::wrap_mutable(arr, { 1l, response_count });
        const auto src = nxty.get_col_slice(0l, 1l).t();

        pr::copy(q, dst, src).wait_and_throw();

        auto intercept = homogen_table::wrap(arr, 1l, response_count);
        result.set_intercept(intercept);
    }

    if (options.test(result_options::coefficients)) {
        const auto size = check_mul_overflow(response_count, feature_count);

        auto arr = array<Float>::zeros(q, size, alloc);
        const auto src = nxty.get_col_slice(1l, feature_count + 1);
        auto dst = pr::ndview<Float, 2>::wrap_mutable(arr, { response_count, feature_count });

        pr::copy(q, dst, src).wait_and_throw();

        auto coefficients = homogen_table::wrap(arr, response_count, feature_count);
        result.set_coefficients(coefficients);
    }

    return result;
}

template class finalize_train_kernel_norm_eq_impl<float, task::regression>;
template class finalize_train_kernel_norm_eq_impl<double, task::regression>;

} // namespace oneapi::dal::linear_regression::backend
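The allreduce of the partial X^T X and X^T y accumulators is what makes this finalize step SPMD-correct: the cross-products are additive over row blocks, so summing per-rank partials before the solve is equivalent to training on the pooled data. A self-contained numeric sketch of that identity, using invented toy data and no oneDAL API:

#include <array>
#include <iostream>

int main() {
    // Two "ranks", each holding two samples of one feature x with response y = 3x.
    std::array<double, 2> x1{ 1.0, 2.0 }, y1{ 3.0, 6.0 };
    std::array<double, 2> x2{ 3.0, 4.0 }, y2{ 9.0, 12.0 };

    auto xtx = [](const auto& x) { return x[0] * x[0] + x[1] * x[1]; };
    auto xty = [](const auto& x, const auto& y) { return x[0] * y[0] + x[1] * y[1]; };

    // The "allreduce": sum the per-rank partial accumulators.
    double xtx_sum = xtx(x1) + xtx(x2); // 5 + 25 = 30
    double xty_sum = xty(x1, y1) + xty(x2, y2); // 15 + 75 = 90

    // Solving the 1x1 normal equation recovers the pooled-data coefficient.
    std::cout << "beta = " << xty_sum / xtx_sum << "\n"; // beta = 3
}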
4 changes: 2 additions & 2 deletions cpp/oneapi/dal/algo/linear_regression/backend/gpu/misc.hpp
@@ -44,15 +44,15 @@ sycl::event add_ridge_penalty(sycl::queue& q,
                               Float alpha,
                               const bk::event_vector& deps = {}) {
     ONEDAL_ASSERT(xtx.has_mutable_data());
-    ONEDAL_ASSERT(be::is_known_usm(q, xtx.get_mutable_data()));
+    ONEDAL_ASSERT(bk::is_known_usm(q, xtx.get_mutable_data()));
     ONEDAL_ASSERT(xtx.get_dimension(0) == xtx.get_dimension(1));
 
     Float* xtx_ptr = xtx.get_mutable_data();
     std::int64_t feature_count = xtx.get_dimension(0);
     std::int64_t original_feature_count = feature_count - compute_intercept;
 
     return q.submit([&](sycl::handler& cgh) {
-        const auto range = be::make_range_1d(original_feature_count);
+        const auto range = bk::make_range_1d(original_feature_count);
         cgh.depends_on(deps);
         std::int64_t step = feature_count + 1;
         cgh.parallel_for(range, [=](sycl::id<1> idx) {
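For context on the two renamed asserts: the visible setup in add_ridge_penalty (a 1-D range over original_feature_count and step = feature_count + 1) implies the truncated kernel body adds alpha at offsets idx * step, i.e. along the diagonal of the row-major XtX buffer, leaving the intercept's diagonal entry unpenalized. A standalone sketch of that indexing, with invented sizes:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
    const std::int64_t feature_count = 3; // extended count; last slot is the intercept
    const bool compute_intercept = true;
    const double alpha = 0.5;

    std::vector<double> xtx(feature_count * feature_count, 1.0);
    const std::int64_t step = feature_count + 1; // diagonal stride in the flat buffer
    const std::int64_t original_feature_count = feature_count - compute_intercept;

    for (std::int64_t idx = 0; idx < original_feature_count; ++idx) {
        xtx[idx * step] += alpha; // penalizes (0,0) and (1,1); skips the intercept at (2,2)
    }

    for (std::int64_t i = 0; i < feature_count; ++i) {
        for (std::int64_t j = 0; j < feature_count; ++j) {
            std::cout << xtx[i * feature_count + j] << " ";
        }
        std::cout << "\n";
    }
}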
@@ -104,17 +104,9 @@ static train_result<Task> call_dal_kernel(const context_gpu& ctx,
         old_x_arr = std::move(x_arr), old_y_arr = std::move(y_arr);
     }
 
-    const be::event_vector solve_deps{ last_xty_event, last_xtx_event };
-
-    double alpha = desc.get_alpha();
-    if (alpha != 0.0) {
-        last_xtx_event =
-            add_ridge_penalty<Float>(queue, xtx, compute_intercept, alpha, { last_xtx_event });
-    }
-
     auto& comm = ctx.get_communicator();
     if (comm.get_rank_count() > 1) {
-        sycl::event::wait_and_throw(solve_deps);
+        sycl::event::wait_and_throw({ last_xty_event, last_xtx_event });
         {
             ONEDAL_PROFILER_TASK(xtx_allreduce);
             auto xtx_arr = dal::array<Float>::wrap(queue, xtx.get_mutable_data(), xtx.get_count());
@@ -127,6 +119,13 @@ static train_result<Task> call_dal_kernel(const context_gpu& ctx,
         }
     }
 
+    double alpha = desc.get_alpha();
+    if (alpha != 0.0) {
+        last_xtx_event =
+            add_ridge_penalty<Float>(queue, xtx, compute_intercept, alpha, { last_xtx_event });
+    }
+    const be::event_vector solve_deps{ last_xty_event, last_xtx_event };
+
     auto nxtx = pr::ndarray<Float, 2>::empty(queue, xtx_shape, alloc);
     auto nxty = pr::ndview<Float, 2>::wrap_mutable(betas_arr, betas_shape);
     auto solve_event =
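The reordering in this hunk matters for the multi-rank path: the allreduce sums XtX over ranks, so a ridge penalty applied before the reduction would be summed once per rank, which is presumably why the hunk moves it after. A tiny standalone illustration with invented numbers:

#include <iostream>

int main() {
    const int rank_count = 2;
    const double partial_xtx = 10.0; // each rank's partial X^T X (1x1 case)
    const double alpha = 1.0;

    // Penalize before the allreduce: every rank's alpha gets summed.
    double before = rank_count * (partial_xtx + alpha); // 22: over-penalized

    // Penalize after the allreduce: alpha enters exactly once.
    double after = rank_count * partial_xtx + alpha; // 21: intended X^T X + alpha*I

    std::cout << before << " vs " << after << "\n";
}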