Ridge Regression support in oneapi #2743

Merged: 18 commits, May 21, 2024

Changes from 8 commits

Commits (18):
4a03d23: added alpha parameter support to linear regression descriptor (DDJHB, Apr 24, 2024)
410da8d: Merge branch 'oneapi-src:main' into dev/ridge_reg_alpha_param (DDJHB, Apr 25, 2024)
3ef36b2: added support for ridge regression penalty in linear regression on ba… (DDJHB, May 2, 2024)
64c8a82: fixed gpu implementation, added switch to daal ridge on cpu (DDJHB, May 8, 2024)
cd3fb2e: added gpu/cpu implementations of ridge regression for online mode (DDJHB, May 8, 2024)
f6706a6: added dll linking for BatchKernel on ridge regression from daal (DDJHB, May 10, 2024)
0f39e21: dll linking for online kernel (DDJHB, May 10, 2024)
3630121: modified tests to support ridge regression in both batch and online m… (DDJHB, May 13, 2024)
872c7dd: refactored code per suggestions and adjusted tests to reduce floating… (DDJHB, May 14, 2024)
a49208c: generalized adding ridge penalty for both batch and online modes (DDJHB, May 14, 2024)
2b1f1e9: fixed alpha float type for gpus not supporting fp64 (DDJHB, May 15, 2024)
1027417: adjusted alpha floating type from double to supported float for gpus … (DDJHB, May 15, 2024)
8fe5cdf: Merge branch 'oneapi-src:main' into dev/ridge_reg_alpha_param (DDJHB, May 15, 2024)
ed6feaa: switched linreg batch cpu switch to use daal batch instead of online … (DDJHB, May 16, 2024)
f6bddb4: Merge branch 'dev/ridge_reg_alpha_param' of https://github.com/DDJHB/… (DDJHB, May 16, 2024)
d99b79e: added dependency description of xtx on compute_intercept (DDJHB, May 17, 2024)
a1f030d: fixed dll for daal lin reg batch kernel (DDJHB, May 17, 2024)
43f06e7: updated copyright message for misc.hpp (DDJHB, May 21, 2024)
@@ -41,7 +41,7 @@ template class BatchContainer<DAAL_FPTYPE, normEqDense, DAAL_CPU>;

namespace internal
{
template class BatchKernel<DAAL_FPTYPE, normEqDense, DAAL_CPU>;
template class DAAL_EXPORT BatchKernel<DAAL_FPTYPE, normEqDense, DAAL_CPU>;

} // namespace internal
} // namespace training
@@ -40,7 +40,7 @@ template class OnlineContainer<DAAL_FPTYPE, normEqDense, DAAL_CPU>;

namespace internal
{
template class OnlineKernel<DAAL_FPTYPE, normEqDense, DAAL_CPU>;
template class DAAL_EXPORT OnlineKernel<DAAL_FPTYPE, normEqDense, DAAL_CPU>;

} // namespace internal
} // namespace training
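
These first two hunks only add DAAL_EXPORT to the explicit instantiations of the DAAL BatchKernel and OnlineKernel, which is what the "dll linking" commits refer to: without an export attribute the instantiated symbols are not exported from the DAAL shared library on platforms where symbols are hidden by default, so the oneAPI backend cannot link against them. A generic sketch of the mechanism, with MY_EXPORT standing in for the real DAAL_EXPORT macro (whose exact definition is not shown in this diff):

// Sketch of symbol export on an explicit template instantiation; MY_EXPORT is a
// stand-in for DAAL_EXPORT and resolves to a platform-specific attribute.
#if defined(_WIN32)
#define MY_EXPORT __declspec(dllexport)
#else
#define MY_EXPORT __attribute__((visibility("default")))
#endif

template <typename Float>
class KernelSketch {
public:
    Float compute(Float x) const {
        return x + Float(1);
    }
};

// Without MY_EXPORT these instantiations may not be visible outside the shared
// library; with it, a consumer such as the oneAPI interop layer can link to them.
template class MY_EXPORT KernelSketch<float>;
template class MY_EXPORT KernelSketch<double>;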
1 change: 1 addition & 0 deletions cpp/oneapi/dal/algo/linear_regression/backend/cpu/BUILD
@@ -15,5 +15,6 @@ dal_module(
"@onedal//cpp/daal:core",
"@onedal//cpp/daal/src/algorithms/linear_model:kernel",
"@onedal//cpp/daal/src/algorithms/linear_regression:kernel",
"@onedal//cpp/daal/src/algorithms/ridge_regression:kernel"
],
)
@@ -16,6 +16,7 @@

#include <daal/src/algorithms/linear_regression/linear_regression_train_kernel.h>
#include <daal/src/algorithms/linear_regression/linear_regression_hyperparameter_impl.h>
#include <daal/src/algorithms/ridge_regression/ridge_regression_train_kernel.h>

#include "oneapi/dal/backend/interop/common.hpp"
#include "oneapi/dal/backend/interop/error_converter.hpp"
@@ -37,21 +38,26 @@ namespace be = dal::backend;
namespace pr = be::primitives;
namespace interop = dal::backend::interop;
namespace daal_lr = daal::algorithms::linear_regression;
namespace daal_rr = daal::algorithms::ridge_regression;

using daal_hyperparameters_t = daal_lr::internal::Hyperparameter;
using daal_lr_hyperparameters_t = daal_lr::internal::Hyperparameter;

constexpr auto daal_method = daal_lr::training::normEqDense;
constexpr auto daal_lr_method = daal_lr::training::normEqDense;
constexpr auto daal_rr_method = daal_rr::training::normEqDense;

template <typename Float, daal::CpuType Cpu>
using online_kernel_t = daal_lr::training::internal::OnlineKernel<Float, daal_method, Cpu>;
using online_lr_kernel_t = daal_lr::training::internal::OnlineKernel<Float, daal_lr_method, Cpu>;

template <typename Float, daal::CpuType Cpu>
using online_rr_kernel_t = daal_rr::training::internal::OnlineKernel<Float, daal_rr_method, Cpu>;

template <typename Float, typename Task>
static daal_hyperparameters_t convert_parameters(const detail::train_parameters<Task>& params) {
static daal_lr_hyperparameters_t convert_parameters(const detail::train_parameters<Task>& params) {
using daal_lr::internal::HyperparameterId;

const std::int64_t block = params.get_cpu_macro_block();

daal_hyperparameters_t daal_hyperparameter;
daal_lr_hyperparameters_t daal_hyperparameter;
auto status = daal_hyperparameter.set(HyperparameterId::denseUpdateStepBlockSize, block);
interop::status_to_exception(status);

@@ -68,36 +74,58 @@ static train_result<Task> call_daal_kernel(const context_cpu& ctx,
using model_t = model<Task>;
using model_impl_t = detail::model_impl<Task>;

const bool beta = desc.get_compute_intercept();
const bool compute_intercept = desc.get_compute_intercept();

const auto response_count = input.get_partial_xty().get_row_count();
const auto ext_feature_count = input.get_partial_xty().get_column_count();

const auto feature_count = ext_feature_count - beta;
const auto feature_count = ext_feature_count - compute_intercept;

const auto betas_size = check_mul_overflow(response_count, feature_count + 1);
auto betas_arr = array<Float>::zeros(betas_size);

const daal_hyperparameters_t& hp = convert_parameters<Float>(params);
const daal_lr_hyperparameters_t& hp = convert_parameters<Float>(params);

auto xtx_daal_table = interop::convert_to_daal_table<Float>(input.get_partial_xtx());
auto xty_daal_table = interop::convert_to_daal_table<Float>(input.get_partial_xty());
auto betas_daal_table =
interop::convert_to_daal_homogen_table(betas_arr, response_count, feature_count + 1);

{
const auto status = dal::backend::dispatch_by_cpu(ctx, [&](auto cpu) {
constexpr auto cpu_type = interop::to_daal_cpu_type<decltype(cpu)>::value;
return online_kernel_t<Float, cpu_type>().finalizeCompute(*xtx_daal_table,
*xty_daal_table,
*xtx_daal_table,
*xty_daal_table,
*betas_daal_table,
beta,
&hp);
});

interop::status_to_exception(status);
double alpha = desc.get_alpha();
if (alpha != 0.0) {
auto ridge_matrix_array = array<Float>::full(1, static_cast<Float>(alpha));
auto ridge_matrix = interop::convert_to_daal_homogen_table<Float>(ridge_matrix_array, 1, 1);

{
const auto status = dal::backend::dispatch_by_cpu(ctx, [&](auto cpu) {
constexpr auto cpu_type = interop::to_daal_cpu_type<decltype(cpu)>::value;
return online_rr_kernel_t<Float, cpu_type>().finalizeCompute(*xtx_daal_table,
*xty_daal_table,
*xtx_daal_table,
*xty_daal_table,
*betas_daal_table,
compute_intercept,
*ridge_matrix);
});

interop::status_to_exception(status);
}
}
else {
{
const auto status = dal::backend::dispatch_by_cpu(ctx, [&](auto cpu) {
constexpr auto cpu_type = interop::to_daal_cpu_type<decltype(cpu)>::value;
return online_lr_kernel_t<Float, cpu_type>().finalizeCompute(*xtx_daal_table,
*xty_daal_table,
*xtx_daal_table,
*xty_daal_table,
*betas_daal_table,
compute_intercept,
&hp);
});

interop::status_to_exception(status);
}
}

auto betas_table = homogen_table::wrap(betas_arr, response_count, feature_count + 1);
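
For reference, when alpha is non-zero both the CPU and GPU paths solve the ridge normal equations assembled from the accumulated cross-products. With the intercept term left unpenalized (as the GPU add_ridge_penalty kernel later in the diff arranges), the estimator is

    \hat{\beta} = (X^\top X + \alpha I)^{-1} X^\top y

and alpha = 0 recovers ordinary least squares, which is why the code above falls back to the plain linear regression kernels in that case.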
@@ -62,14 +62,14 @@ static partial_train_result<Task> call_daal_kernel(const context_cpu& ctx,
const partial_train_input<Task>& input) {
using dal::detail::check_mul_overflow;

const bool beta = desc.get_compute_intercept();
const bool compute_intercept = desc.get_compute_intercept();

const auto feature_count = input.get_data().get_column_count();
const auto response_count = input.get_responses().get_column_count();

const daal_hyperparameters_t& hp = convert_parameters<Float>(params);

const auto ext_feature_count = feature_count + beta;
const auto ext_feature_count = feature_count + compute_intercept;

const bool has_xtx_data = input.get_prev().get_partial_xtx().has_data();
if (has_xtx_data) {
@@ -85,7 +85,7 @@ static partial_train_result<Task> call_daal_kernel(const context_cpu& ctx,
*y_daal_table,
*daal_xtx,
*daal_xty,
beta,
compute_intercept,
&hp);

interop::status_to_exception(status);
@@ -117,7 +117,7 @@ static partial_train_result<Task> call_daal_kernel(const context_cpu& ctx,
*y_daal_table,
*xtx_daal_table,
*xty_daal_table,
beta,
compute_intercept,
&hp);

interop::status_to_exception(status);
@@ -16,6 +16,7 @@

#include <daal/src/algorithms/linear_regression/linear_regression_train_kernel.h>
#include <daal/src/algorithms/linear_regression/linear_regression_hyperparameter_impl.h>
#include <daal/src/algorithms/ridge_regression/ridge_regression_train_kernel.h>

#include "oneapi/dal/backend/interop/common.hpp"
#include "oneapi/dal/backend/interop/error_converter.hpp"
@@ -39,21 +40,26 @@ namespace be = dal::backend;
namespace pr = be::primitives;
namespace interop = dal::backend::interop;
namespace daal_lr = daal::algorithms::linear_regression;
namespace daal_rr = daal::algorithms::ridge_regression;

using daal_hyperparameters_t = daal_lr::internal::Hyperparameter;
using daal_lr_hyperparameters_t = daal_lr::internal::Hyperparameter;

constexpr auto daal_method = daal_lr::training::normEqDense;
constexpr auto daal_lr_method = daal_lr::training::normEqDense;
constexpr auto daal_rr_method = daal_rr::training::normEqDense;

template <typename Float, daal::CpuType Cpu>
using online_kernel_t = daal_lr::training::internal::OnlineKernel<Float, daal_method, Cpu>;
using online_lr_kernel_t = daal_lr::training::internal::OnlineKernel<Float, daal_lr_method, Cpu>;

template <typename Float, daal::CpuType Cpu>
using batch_rr_kernel_t = daal_rr::training::internal::BatchKernel<Float, daal_rr_method, Cpu>;

template <typename Float, typename Task>
static daal_hyperparameters_t convert_parameters(const detail::train_parameters<Task>& params) {
static daal_lr_hyperparameters_t convert_parameters(const detail::train_parameters<Task>& params) {
using daal_lr::internal::HyperparameterId;

const std::int64_t block = params.get_cpu_macro_block();

daal_hyperparameters_t daal_hyperparameter;
daal_lr_hyperparameters_t daal_hyperparameter;
auto status = daal_hyperparameter.set(HyperparameterId::denseUpdateStepBlockSize, block);
interop::status_to_exception(status);

@@ -97,33 +103,55 @@ static train_result<Task> call_daal_kernel(const context_cpu& ctx,
auto x_daal_table = interop::convert_to_daal_table<Float>(data);
auto y_daal_table = interop::convert_to_daal_table<Float>(resp);

const daal_hyperparameters_t& hp = convert_parameters<Float>(params);

{
const auto status = interop::call_daal_kernel<Float, online_kernel_t>(ctx,
*x_daal_table,
*y_daal_table,
*xtx_daal_table,
*xty_daal_table,
intp,
&hp);

interop::status_to_exception(status);
const daal_lr_hyperparameters_t& hp = convert_parameters<Float>(params);

double alpha = desc.get_alpha();
if (alpha != 0.0) {
auto ridge_matrix_array = array<Float>::full(1, static_cast<Float>(alpha));
auto ridge_matrix = interop::convert_to_daal_homogen_table<Float>(ridge_matrix_array, 1, 1);

{
const auto status =
interop::call_daal_kernel<Float, batch_rr_kernel_t>(ctx,
*x_daal_table,
*y_daal_table,
*xtx_daal_table,
*xty_daal_table,
*betas_daal_table,
intp,
*ridge_matrix);

interop::status_to_exception(status);
}
}

{
const auto status = dal::backend::dispatch_by_cpu(ctx, [&](auto cpu) {
constexpr auto cpu_type = interop::to_daal_cpu_type<decltype(cpu)>::value;
return online_kernel_t<Float, cpu_type>().finalizeCompute(*xtx_daal_table,
*xty_daal_table,
*xtx_daal_table,
*xty_daal_table,
*betas_daal_table,
intp,
&hp);
});

interop::status_to_exception(status);
else {
{
const auto status =
interop::call_daal_kernel<Float, online_lr_kernel_t>(ctx,
*x_daal_table,
*y_daal_table,
*xtx_daal_table,
*xty_daal_table,
intp,
&hp);

interop::status_to_exception(status);
}

{
const auto status = dal::backend::dispatch_by_cpu(ctx, [&](auto cpu) {
constexpr auto cpu_type = interop::to_daal_cpu_type<decltype(cpu)>::value;
return online_lr_kernel_t<Float, cpu_type>().finalizeCompute(*xtx_daal_table,
*xty_daal_table,
*xtx_daal_table,
*xty_daal_table,
*betas_daal_table,
intp,
&hp);
});

interop::status_to_exception(status);
}
}

auto betas_table = homogen_table::wrap(betas_arr, response_count, feature_count + 1);
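
To make concrete what the dispatched kernels compute, here is a self-contained reference sketch (not oneDAL code) of the same normal-equation ridge solve for a tiny two-feature problem without an intercept; a closed-form 2x2 solve stands in for the library's Cholesky-based solver.

#include <cstdio>

// Reference sketch (not oneDAL code): ridge regression via the normal equations,
// beta = (X^T X + alpha I)^{-1} X^T y, for a tiny two-feature problem.
int main() {
    constexpr int n = 4, p = 2;
    const double X[n][p] = { { 1.0, 2.0 }, { 2.0, 1.0 }, { 3.0, 4.0 }, { 4.0, 3.0 } };
    const double y[n] = { 5.0, 4.0, 11.0, 10.0 };
    const double alpha = 0.5;

    // Accumulate the cross-products, as the partial/online kernels do.
    double xtx[p][p] = {};
    double xty[p] = {};
    for (int i = 0; i < n; ++i) {
        for (int a = 0; a < p; ++a) {
            xty[a] += X[i][a] * y[i];
            for (int b = 0; b < p; ++b)
                xtx[a][b] += X[i][a] * X[i][b];
        }
    }

    // Ridge penalty: add alpha to the diagonal of XtX.
    for (int a = 0; a < p; ++a)
        xtx[a][a] += alpha;

    // Closed-form 2x2 solve (Cramer's rule) in place of the Cholesky-based solver.
    const double det = xtx[0][0] * xtx[1][1] - xtx[0][1] * xtx[1][0];
    const double beta0 = (xty[0] * xtx[1][1] - xty[1] * xtx[0][1]) / det;
    const double beta1 = (xtx[0][0] * xty[1] - xtx[1][0] * xty[0]) / det;
    std::printf("beta = [%f, %f]\n", beta0, beta1);
    return 0;
}

With alpha = 0 the added diagonal term vanishes and the result matches ordinary least squares, mirroring the else branch above that keeps calling the linear regression kernels.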
@@ -35,6 +35,27 @@ using dal::backend::context_gpu;
namespace be = dal::backend;
namespace pr = be::primitives;

template <typename Float>
sycl::event add_ridge_penalty(sycl::queue& q,
const pr::ndarray<Float, 2, pr::ndorder::c>& xtx,
bool compute_intercept,
double alpha) {
ONEDAL_ASSERT(xtx.has_mutable_data());
ONEDAL_ASSERT(be::is_known_usm(q, xtx.get_mutable_data()));
ONEDAL_ASSERT(xtx.get_dimension(0) == xtx.get_dimension(1));

Float* xtx_ptr = xtx.get_mutable_data();
std::int64_t feature_count = xtx.get_dimension(0);
std::int64_t original_feature_count = feature_count - compute_intercept;

return q.submit([&](sycl::handler& cgh) {
const auto range = be::make_range_1d(original_feature_count);
cgh.parallel_for(range, [=](sycl::id<1> idx) {
xtx_ptr[idx * (feature_count + 1)] += alpha;
});
});
}

template <typename Float, typename Task>
static train_result<Task> call_dal_kernel(const context_gpu& ctx,
const detail::descriptor_base<Task>& desc,
@@ -47,14 +68,14 @@ static train_result<Task> call_dal_kernel(const context_gpu& ctx,

auto& queue = ctx.get_queue();

const bool beta = desc.get_compute_intercept();
const bool compute_intercept = desc.get_compute_intercept();

constexpr auto uplo = pr::mkl::uplo::upper;
constexpr auto alloc = sycl::usm::alloc::device;

const auto response_count = input.get_partial_xty().get_row_count();
const auto ext_feature_count = input.get_partial_xty().get_column_count();
const auto feature_count = ext_feature_count - beta;
const auto feature_count = ext_feature_count - compute_intercept;

const pr::ndshape<2> xtx_shape{ ext_feature_count, ext_feature_count };

@@ -69,9 +90,21 @@ static train_result<Task> call_dal_kernel(const context_gpu& ctx,
const auto betas_size = check_mul_overflow(response_count, feature_count + 1);
auto betas_arr = array<Float>::zeros(queue, betas_size, alloc);

double alpha = desc.get_alpha();
sycl::event ridge_event;
if (alpha != 0.0) {
ridge_event = add_ridge_penalty<Float>(queue, xtx_nd, compute_intercept, alpha);
}

auto nxtx = pr::ndarray<Float, 2>::empty(queue, xtx_shape, alloc);
auto nxty = pr::ndview<Float, 2>::wrap_mutable(betas_arr, betas_shape);
auto solve_event = pr::solve_system<uplo>(queue, beta, xtx_nd, xty_nd, nxtx, nxty, {});
auto solve_event = pr::solve_system<uplo>(queue,
compute_intercept,
xtx_nd,
xty_nd,
nxtx,
nxty,
{ ridge_event });
sycl::event::wait_and_throw({ solve_event });

auto betas = homogen_table::wrap(betas_arr, response_count, feature_count + 1);
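
Two details of the GPU path are worth spelling out. In a row-major d x d matrix the diagonal element (i, i) sits at flat offset i * (d + 1), which is why add_ridge_penalty indexes xtx_ptr[idx * (feature_count + 1)]; and iterating only over original_feature_count entries leaves the last diagonal term (the intercept under this layout) unpenalized. When alpha is zero, the default-constructed sycl::event handed to solve_system should already be complete per SYCL's default-event semantics, so it adds no real dependency. A host-side analogue of the kernel body (a sketch, not part of the PR):

#include <cassert>
#include <cstdint>
#include <vector>

// Host-side analogue of add_ridge_penalty (a sketch, not oneDAL code): add alpha to
// the diagonal of a row-major d x d XtX matrix, skipping the last diagonal entry
// (the intercept term) when compute_intercept is true.
void add_ridge_penalty_host(std::vector<double>& xtx,
                            std::int64_t d,
                            bool compute_intercept,
                            double alpha) {
    assert(static_cast<std::int64_t>(xtx.size()) == d * d);
    const std::int64_t original_feature_count = d - (compute_intercept ? 1 : 0);
    for (std::int64_t i = 0; i < original_feature_count; ++i) {
        // Element (i, i) of a row-major d x d matrix lives at flat index i * (d + 1).
        xtx[i * (d + 1)] += alpha;
    }
}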