diff --git a/cpp/daal/src/algorithms/linear_regression/linear_regression_train_dense_normeq_batch_fpt_cpu.cpp b/cpp/daal/src/algorithms/linear_regression/linear_regression_train_dense_normeq_batch_fpt_cpu.cpp index ef9ab58c256..bb0a22d089e 100644 --- a/cpp/daal/src/algorithms/linear_regression/linear_regression_train_dense_normeq_batch_fpt_cpu.cpp +++ b/cpp/daal/src/algorithms/linear_regression/linear_regression_train_dense_normeq_batch_fpt_cpu.cpp @@ -39,7 +39,7 @@ template class BatchContainer; } namespace internal { -template class BatchKernel; +template class DAAL_EXPORT BatchKernel; } } // namespace training } // namespace linear_regression diff --git a/cpp/daal/src/algorithms/ridge_regression/ridge_regression_train_dense_normeq_batch_fpt_cpu.cpp b/cpp/daal/src/algorithms/ridge_regression/ridge_regression_train_dense_normeq_batch_fpt_cpu.cpp index 2c71f4d64a0..e1ed3085861 100644 --- a/cpp/daal/src/algorithms/ridge_regression/ridge_regression_train_dense_normeq_batch_fpt_cpu.cpp +++ b/cpp/daal/src/algorithms/ridge_regression/ridge_regression_train_dense_normeq_batch_fpt_cpu.cpp @@ -41,7 +41,7 @@ template class BatchContainer; namespace internal { -template class BatchKernel; +template class DAAL_EXPORT BatchKernel; } // namespace internal } // namespace training diff --git a/cpp/daal/src/algorithms/ridge_regression/ridge_regression_train_dense_normeq_online_fpt_cpu.cpp b/cpp/daal/src/algorithms/ridge_regression/ridge_regression_train_dense_normeq_online_fpt_cpu.cpp index 867f3a23b56..c82553c834a 100644 --- a/cpp/daal/src/algorithms/ridge_regression/ridge_regression_train_dense_normeq_online_fpt_cpu.cpp +++ b/cpp/daal/src/algorithms/ridge_regression/ridge_regression_train_dense_normeq_online_fpt_cpu.cpp @@ -40,7 +40,7 @@ template class OnlineContainer; namespace internal { -template class OnlineKernel; +template class DAAL_EXPORT OnlineKernel; } // namespace internal } // namespace training diff --git a/cpp/oneapi/dal/algo/linear_regression/backend/cpu/BUILD 
b/cpp/oneapi/dal/algo/linear_regression/backend/cpu/BUILD index 55adfee47a9..7bd3d6e679d 100644 --- a/cpp/oneapi/dal/algo/linear_regression/backend/cpu/BUILD +++ b/cpp/oneapi/dal/algo/linear_regression/backend/cpu/BUILD @@ -15,5 +15,6 @@ dal_module( "@onedal//cpp/daal:core", "@onedal//cpp/daal/src/algorithms/linear_model:kernel", "@onedal//cpp/daal/src/algorithms/linear_regression:kernel", + "@onedal//cpp/daal/src/algorithms/ridge_regression:kernel" ], ) diff --git a/cpp/oneapi/dal/algo/linear_regression/backend/cpu/finalize_train_kernel_norm_eq.cpp b/cpp/oneapi/dal/algo/linear_regression/backend/cpu/finalize_train_kernel_norm_eq.cpp index 5540641d8fd..88b1c58ccc4 100644 --- a/cpp/oneapi/dal/algo/linear_regression/backend/cpu/finalize_train_kernel_norm_eq.cpp +++ b/cpp/oneapi/dal/algo/linear_regression/backend/cpu/finalize_train_kernel_norm_eq.cpp @@ -16,6 +16,7 @@ #include #include +#include #include "oneapi/dal/backend/interop/common.hpp" #include "oneapi/dal/backend/interop/error_converter.hpp" @@ -37,21 +38,26 @@ namespace be = dal::backend; namespace pr = be::primitives; namespace interop = dal::backend::interop; namespace daal_lr = daal::algorithms::linear_regression; +namespace daal_rr = daal::algorithms::ridge_regression; -using daal_hyperparameters_t = daal_lr::internal::Hyperparameter; +using daal_lr_hyperparameters_t = daal_lr::internal::Hyperparameter; -constexpr auto daal_method = daal_lr::training::normEqDense; +constexpr auto daal_lr_method = daal_lr::training::normEqDense; +constexpr auto daal_rr_method = daal_rr::training::normEqDense; template -using online_kernel_t = daal_lr::training::internal::OnlineKernel; +using online_lr_kernel_t = daal_lr::training::internal::OnlineKernel; + +template +using online_rr_kernel_t = daal_rr::training::internal::OnlineKernel; template -static daal_hyperparameters_t convert_parameters(const detail::train_parameters& params) { +static daal_lr_hyperparameters_t convert_parameters(const detail::train_parameters& 
params) { using daal_lr::internal::HyperparameterId; const std::int64_t block = params.get_cpu_macro_block(); - daal_hyperparameters_t daal_hyperparameter; + daal_lr_hyperparameters_t daal_hyperparameter; auto status = daal_hyperparameter.set(HyperparameterId::denseUpdateStepBlockSize, block); interop::status_to_exception(status); @@ -68,36 +74,58 @@ static train_result call_daal_kernel(const context_cpu& ctx, using model_t = model; using model_impl_t = detail::model_impl; - const bool beta = desc.get_compute_intercept(); + const bool compute_intercept = desc.get_compute_intercept(); const auto response_count = input.get_partial_xty().get_row_count(); const auto ext_feature_count = input.get_partial_xty().get_column_count(); - const auto feature_count = ext_feature_count - beta; + const auto feature_count = ext_feature_count - compute_intercept; const auto betas_size = check_mul_overflow(response_count, feature_count + 1); auto betas_arr = array::zeros(betas_size); - const daal_hyperparameters_t& hp = convert_parameters(params); - auto xtx_daal_table = interop::convert_to_daal_table(input.get_partial_xtx()); auto xty_daal_table = interop::convert_to_daal_table(input.get_partial_xty()); auto betas_daal_table = interop::convert_to_daal_homogen_table(betas_arr, response_count, feature_count + 1); - { - const auto status = dal::backend::dispatch_by_cpu(ctx, [&](auto cpu) { - constexpr auto cpu_type = interop::to_daal_cpu_type::value; - return online_kernel_t().finalizeCompute(*xtx_daal_table, - *xty_daal_table, - *xtx_daal_table, - *xty_daal_table, - *betas_daal_table, - beta, - &hp); - }); - - interop::status_to_exception(status); + double alpha = desc.get_alpha(); + if (alpha != 0.0) { + auto ridge_matrix_array = array::full(1, static_cast(alpha)); + auto ridge_matrix = interop::convert_to_daal_homogen_table(ridge_matrix_array, 1, 1); + + { + const auto status = dal::backend::dispatch_by_cpu(ctx, [&](auto cpu) { + constexpr auto cpu_type = 
interop::to_daal_cpu_type::value; + return online_rr_kernel_t().finalizeCompute(*xtx_daal_table, + *xty_daal_table, + *xtx_daal_table, + *xty_daal_table, + *betas_daal_table, + compute_intercept, + *ridge_matrix); + }); + + interop::status_to_exception(status); + } + } + else { + const daal_lr_hyperparameters_t& hp = convert_parameters(params); + + { + const auto status = dal::backend::dispatch_by_cpu(ctx, [&](auto cpu) { + constexpr auto cpu_type = interop::to_daal_cpu_type::value; + return online_lr_kernel_t().finalizeCompute(*xtx_daal_table, + *xty_daal_table, + *xtx_daal_table, + *xty_daal_table, + *betas_daal_table, + compute_intercept, + &hp); + }); + + interop::status_to_exception(status); + } } auto betas_table = homogen_table::wrap(betas_arr, response_count, feature_count + 1); diff --git a/cpp/oneapi/dal/algo/linear_regression/backend/cpu/partial_train_kernel_norm_eq.cpp b/cpp/oneapi/dal/algo/linear_regression/backend/cpu/partial_train_kernel_norm_eq.cpp index 7cac1aa47b7..d5d9f61003c 100644 --- a/cpp/oneapi/dal/algo/linear_regression/backend/cpu/partial_train_kernel_norm_eq.cpp +++ b/cpp/oneapi/dal/algo/linear_regression/backend/cpu/partial_train_kernel_norm_eq.cpp @@ -62,14 +62,14 @@ static partial_train_result call_daal_kernel(const context_cpu& ctx, const partial_train_input& input) { using dal::detail::check_mul_overflow; - const bool beta = desc.get_compute_intercept(); + const bool compute_intercept = desc.get_compute_intercept(); const auto feature_count = input.get_data().get_column_count(); const auto response_count = input.get_responses().get_column_count(); const daal_hyperparameters_t& hp = convert_parameters(params); - const auto ext_feature_count = feature_count + beta; + const auto ext_feature_count = feature_count + compute_intercept; const bool has_xtx_data = input.get_prev().get_partial_xtx().has_data(); if (has_xtx_data) { @@ -85,7 +85,7 @@ static partial_train_result call_daal_kernel(const context_cpu& ctx, *y_daal_table, *daal_xtx, 
*daal_xty, - beta, + compute_intercept, &hp); interop::status_to_exception(status); @@ -117,7 +117,7 @@ static partial_train_result call_daal_kernel(const context_cpu& ctx, *y_daal_table, *xtx_daal_table, *xty_daal_table, - beta, + compute_intercept, &hp); interop::status_to_exception(status); diff --git a/cpp/oneapi/dal/algo/linear_regression/backend/cpu/train_kernel_norm_eq.cpp b/cpp/oneapi/dal/algo/linear_regression/backend/cpu/train_kernel_norm_eq.cpp index dbea53a33f6..0e6e1f8cd10 100644 --- a/cpp/oneapi/dal/algo/linear_regression/backend/cpu/train_kernel_norm_eq.cpp +++ b/cpp/oneapi/dal/algo/linear_regression/backend/cpu/train_kernel_norm_eq.cpp @@ -16,6 +16,7 @@ #include #include +#include #include "oneapi/dal/backend/interop/common.hpp" #include "oneapi/dal/backend/interop/error_converter.hpp" @@ -39,21 +40,26 @@ namespace be = dal::backend; namespace pr = be::primitives; namespace interop = dal::backend::interop; namespace daal_lr = daal::algorithms::linear_regression; +namespace daal_rr = daal::algorithms::ridge_regression; -using daal_hyperparameters_t = daal_lr::internal::Hyperparameter; +using daal_lr_hyperparameters_t = daal_lr::internal::Hyperparameter; -constexpr auto daal_method = daal_lr::training::normEqDense; +constexpr auto daal_lr_method = daal_lr::training::normEqDense; +constexpr auto daal_rr_method = daal_rr::training::normEqDense; template -using online_kernel_t = daal_lr::training::internal::OnlineKernel; +using batch_lr_kernel_t = daal_lr::training::internal::BatchKernel; + +template +using batch_rr_kernel_t = daal_rr::training::internal::BatchKernel; template -static daal_hyperparameters_t convert_parameters(const detail::train_parameters& params) { +static daal_lr_hyperparameters_t convert_parameters(const detail::train_parameters& params) { using daal_lr::internal::HyperparameterId; const std::int64_t block = params.get_cpu_macro_block(); - daal_hyperparameters_t daal_hyperparameter; + daal_lr_hyperparameters_t daal_hyperparameter; 
auto status = daal_hyperparameter.set(HyperparameterId::denseUpdateStepBlockSize, block); interop::status_to_exception(status); @@ -97,33 +103,41 @@ static train_result call_daal_kernel(const context_cpu& ctx, auto x_daal_table = interop::convert_to_daal_table(data); auto y_daal_table = interop::convert_to_daal_table(resp); - const daal_hyperparameters_t& hp = convert_parameters(params); - - { - const auto status = interop::call_daal_kernel(ctx, - *x_daal_table, - *y_daal_table, - *xtx_daal_table, - *xty_daal_table, - intp, - &hp); - - interop::status_to_exception(status); + double alpha = desc.get_alpha(); + if (alpha != 0.0) { + auto ridge_matrix_array = array::full(1, static_cast(alpha)); + auto ridge_matrix = interop::convert_to_daal_homogen_table(ridge_matrix_array, 1, 1); + + { + const auto status = + interop::call_daal_kernel(ctx, + *x_daal_table, + *y_daal_table, + *xtx_daal_table, + *xty_daal_table, + *betas_daal_table, + intp, + *ridge_matrix); + + interop::status_to_exception(status); + } } - - { - const auto status = dal::backend::dispatch_by_cpu(ctx, [&](auto cpu) { - constexpr auto cpu_type = interop::to_daal_cpu_type::value; - return online_kernel_t().finalizeCompute(*xtx_daal_table, - *xty_daal_table, - *xtx_daal_table, - *xty_daal_table, - *betas_daal_table, - intp, - &hp); - }); - - interop::status_to_exception(status); + else { + const daal_lr_hyperparameters_t& hp = convert_parameters(params); + + { + const auto status = + interop::call_daal_kernel(ctx, + *x_daal_table, + *y_daal_table, + *xtx_daal_table, + *xty_daal_table, + *betas_daal_table, + intp, + &hp); + + interop::status_to_exception(status); + } } auto betas_table = homogen_table::wrap(betas_arr, response_count, feature_count + 1); diff --git a/cpp/oneapi/dal/algo/linear_regression/backend/gpu/finalize_train_kernel_norm_eq_dpc.cpp b/cpp/oneapi/dal/algo/linear_regression/backend/gpu/finalize_train_kernel_norm_eq_dpc.cpp index 733bb46b0b3..d3431663249 100644 --- 
a/cpp/oneapi/dal/algo/linear_regression/backend/gpu/finalize_train_kernel_norm_eq_dpc.cpp +++ b/cpp/oneapi/dal/algo/linear_regression/backend/gpu/finalize_train_kernel_norm_eq_dpc.cpp @@ -27,6 +27,7 @@ #include "oneapi/dal/algo/linear_regression/backend/model_impl.hpp" #include "oneapi/dal/algo/linear_regression/backend/gpu/finalize_train_kernel.hpp" #include "oneapi/dal/algo/linear_regression/backend/gpu/update_kernel.hpp" +#include "oneapi/dal/algo/linear_regression/backend/gpu/misc.hpp" namespace oneapi::dal::linear_regression::backend { @@ -47,14 +48,14 @@ static train_result call_dal_kernel(const context_gpu& ctx, auto& queue = ctx.get_queue(); - const bool beta = desc.get_compute_intercept(); + const bool compute_intercept = desc.get_compute_intercept(); constexpr auto uplo = pr::mkl::uplo::upper; constexpr auto alloc = sycl::usm::alloc::device; const auto response_count = input.get_partial_xty().get_row_count(); const auto ext_feature_count = input.get_partial_xty().get_column_count(); - const auto feature_count = ext_feature_count - beta; + const auto feature_count = ext_feature_count - compute_intercept; const pr::ndshape<2> xtx_shape{ ext_feature_count, ext_feature_count }; @@ -69,9 +70,21 @@ static train_result call_dal_kernel(const context_gpu& ctx, const auto betas_size = check_mul_overflow(response_count, feature_count + 1); auto betas_arr = array::zeros(queue, betas_size, alloc); + double alpha = desc.get_alpha(); + sycl::event ridge_event; + if (alpha != 0.0) { + ridge_event = add_ridge_penalty(queue, xtx_nd, compute_intercept, alpha); + } + auto nxtx = pr::ndarray::empty(queue, xtx_shape, alloc); auto nxty = pr::ndview::wrap_mutable(betas_arr, betas_shape); - auto solve_event = pr::solve_system(queue, beta, xtx_nd, xty_nd, nxtx, nxty, {}); + auto solve_event = pr::solve_system(queue, + compute_intercept, + xtx_nd, + xty_nd, + nxtx, + nxty, + { ridge_event }); sycl::event::wait_and_throw({ solve_event }); auto betas = homogen_table::wrap(betas_arr, 
response_count, feature_count + 1); diff --git a/cpp/oneapi/dal/algo/linear_regression/backend/gpu/misc.hpp b/cpp/oneapi/dal/algo/linear_regression/backend/gpu/misc.hpp new file mode 100644 index 00000000000..5ad5ba647ec --- /dev/null +++ b/cpp/oneapi/dal/algo/linear_regression/backend/gpu/misc.hpp @@ -0,0 +1,68 @@ +/******************************************************************************* +* Copyright contributors to the oneDAL project +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#pragma once + +#include "oneapi/dal/detail/profiler.hpp" +#include "oneapi/dal/backend/primitives/ndarray.hpp" + +namespace oneapi::dal::linear_regression::backend { + +#ifdef ONEDAL_DATA_PARALLEL + +using alloc = sycl::usm::alloc; +namespace bk = dal::backend; +namespace pr = dal::backend::primitives; + +/// Adds ridge penalty to the diagonal elements of the xtx matrix + +/// +/// @tparam Float Floating-point type used to perform computations +/// +/// @param[in] q The SYCL queue +/// @param[in] xtx The input matrix to which the ridge penalty is added +/// @param[in] compute_intercept Flag indicating whether the intercept term is used in the matrix, extending it with extra dimension if true +/// @param[in] alpha The regularization parameter +/// @param[in] deps Events indicating the availability of the `xtx` for reading or writing +/// +/// @return A SYCL event indicating the availability of the matrix for reading 
and writing +template <typename Float> +sycl::event add_ridge_penalty(sycl::queue& q, + const pr::ndarray<Float, 2>& xtx, + bool compute_intercept, + Float alpha, + const bk::event_vector& deps = {}) { + ONEDAL_ASSERT(xtx.has_mutable_data()); + ONEDAL_ASSERT(bk::is_known_usm(q, xtx.get_mutable_data())); + ONEDAL_ASSERT(xtx.get_dimension(0) == xtx.get_dimension(1)); + + Float* xtx_ptr = xtx.get_mutable_data(); + std::int64_t feature_count = xtx.get_dimension(0); + std::int64_t original_feature_count = feature_count - compute_intercept; + + return q.submit([&](sycl::handler& cgh) { + const auto range = bk::make_range_1d(original_feature_count); + cgh.depends_on(deps); + std::int64_t step = feature_count + 1; + cgh.parallel_for(range, [=](sycl::id<1> idx) { + xtx_ptr[idx * step] += alpha; + }); + }); +} + +#endif // ONEDAL_DATA_PARALLEL + +} // namespace oneapi::dal::linear_regression::backend diff --git a/cpp/oneapi/dal/algo/linear_regression/backend/gpu/partial_train_kernel_norm_eq_dpc.cpp b/cpp/oneapi/dal/algo/linear_regression/backend/gpu/partial_train_kernel_norm_eq_dpc.cpp index dff0548afe4..a9aa7c373e4 100644 --- a/cpp/oneapi/dal/algo/linear_regression/backend/gpu/partial_train_kernel_norm_eq_dpc.cpp +++ b/cpp/oneapi/dal/algo/linear_regression/backend/gpu/partial_train_kernel_norm_eq_dpc.cpp @@ -45,11 +45,11 @@ static partial_train_result call_dal_kernel(const context_gpu& ctx, constexpr auto alloc = sycl::usm::alloc::device; - const bool beta = desc.get_compute_intercept(); + const bool compute_intercept = desc.get_compute_intercept(); const auto feature_count = input.get_data().get_column_count(); const auto response_count = input.get_responses().get_column_count(); - const std::int64_t ext_feature_count = feature_count + beta; + const std::int64_t ext_feature_count = feature_count + compute_intercept; const pr::ndshape<2> xty_shape{ response_count, ext_feature_count }; const pr::ndshape<2> xtx_shape{ ext_feature_count, ext_feature_count }; @@ -74,8 +74,10 @@ static 
partial_train_result call_dal_kernel(const context_gpu& ctx, input_.get_partial_xty(), sycl::usm::alloc::device); auto copy_xty_event = copy(queue, xty, xty_nd, { fill_xty_event }); - auto last_xtx_event = update_xtx(queue, beta, data_nd, xtx, { copy_xtx_event }); - auto last_xty_event = update_xty(queue, beta, data_nd, res_nd, xty, { copy_xty_event }); + auto last_xtx_event = + update_xtx(queue, compute_intercept, data_nd, xtx, { copy_xtx_event }); + auto last_xty_event = + update_xty(queue, compute_intercept, data_nd, res_nd, xty, { copy_xty_event }); result.set_partial_xtx(homogen_table::wrap(xtx.flatten(queue, { last_xtx_event }), ext_feature_count, @@ -97,8 +99,10 @@ static partial_train_result call_dal_kernel(const context_gpu& ctx, auto [xtx, fill_xtx_event] = pr::ndarray::zeros(queue, xtx_shape, alloc); - auto last_xty_event = update_xty(queue, beta, data_nd, res_nd, xty, { fill_xty_event }); - auto last_xtx_event = update_xtx(queue, beta, data_nd, xtx, { fill_xtx_event }); + auto last_xty_event = + update_xty(queue, compute_intercept, data_nd, res_nd, xty, { fill_xty_event }); + auto last_xtx_event = + update_xtx(queue, compute_intercept, data_nd, xtx, { fill_xtx_event }); result.set_partial_xtx(homogen_table::wrap(xtx.flatten(queue, { last_xtx_event }), ext_feature_count, diff --git a/cpp/oneapi/dal/algo/linear_regression/backend/gpu/train_kernel_norm_eq_dpc.cpp b/cpp/oneapi/dal/algo/linear_regression/backend/gpu/train_kernel_norm_eq_dpc.cpp index bf0cd04c00e..25b08aa7710 100644 --- a/cpp/oneapi/dal/algo/linear_regression/backend/gpu/train_kernel_norm_eq_dpc.cpp +++ b/cpp/oneapi/dal/algo/linear_regression/backend/gpu/train_kernel_norm_eq_dpc.cpp @@ -29,6 +29,7 @@ #include "oneapi/dal/algo/linear_regression/backend/model_impl.hpp" #include "oneapi/dal/algo/linear_regression/backend/gpu/train_kernel.hpp" #include "oneapi/dal/algo/linear_regression/backend/gpu/update_kernel.hpp" +#include "oneapi/dal/algo/linear_regression/backend/gpu/misc.hpp" namespace 
oneapi::dal::linear_regression::backend { @@ -62,8 +63,8 @@ static train_result call_dal_kernel(const context_gpu& ctx, const auto feature_count = data.get_column_count(); const auto response_count = resp.get_column_count(); ONEDAL_ASSERT(sample_count == resp.get_row_count()); - const bool beta = desc.get_compute_intercept(); - const std::int64_t ext_feature_count = feature_count + beta; + const bool compute_intercept = desc.get_compute_intercept(); + const std::int64_t ext_feature_count = feature_count + compute_intercept; const auto betas_size = check_mul_overflow(response_count, feature_count + 1); auto betas_arr = array::zeros(queue, betas_size, alloc); @@ -95,8 +96,8 @@ static train_result call_dal_kernel(const context_gpu& ctx, auto y_arr = y_accessor.pull(queue, { first, last }, alloc); auto y = pr::ndview::wrap(y_arr.get_data(), { length, response_count }); - last_xty_event = update_xty(queue, beta, x, y, xty, { last_xty_event }); - last_xtx_event = update_xtx(queue, beta, x, xtx, { last_xtx_event }); + last_xty_event = update_xty(queue, compute_intercept, x, y, xty, { last_xty_event }); + last_xtx_event = update_xtx(queue, compute_intercept, x, xtx, { last_xtx_event }); // We keep the latest slice of data up to date because of pimpl - // it virtually extend lifetime of pulled arrays @@ -105,6 +106,12 @@ static train_result call_dal_kernel(const context_gpu& ctx, const be::event_vector solve_deps{ last_xty_event, last_xtx_event }; + double alpha = desc.get_alpha(); + if (alpha != 0.0) { + last_xtx_event = + add_ridge_penalty(queue, xtx, compute_intercept, alpha, { last_xtx_event }); + } + auto& comm = ctx.get_communicator(); if (comm.get_rank_count() > 1) { sycl::event::wait_and_throw(solve_deps); @@ -122,7 +129,8 @@ static train_result call_dal_kernel(const context_gpu& ctx, auto nxtx = pr::ndarray::empty(queue, xtx_shape, alloc); auto nxty = pr::ndview::wrap_mutable(betas_arr, betas_shape); - auto solve_event = pr::solve_system(queue, beta, xtx, xty, 
nxtx, nxty, solve_deps); + auto solve_event = + pr::solve_system(queue, compute_intercept, xtx, xty, nxtx, nxty, solve_deps); sycl::event::wait_and_throw({ solve_event }); auto betas = homogen_table::wrap(betas_arr, response_count, feature_count + 1); diff --git a/cpp/oneapi/dal/algo/linear_regression/common.cpp b/cpp/oneapi/dal/algo/linear_regression/common.cpp index 70fd04f221e..949898f3524 100644 --- a/cpp/oneapi/dal/algo/linear_regression/common.cpp +++ b/cpp/oneapi/dal/algo/linear_regression/common.cpp @@ -42,6 +42,8 @@ class descriptor_impl : public base { explicit descriptor_impl() = default; bool compute_intercept = true; + double alpha = 0; + result_option_id result_options = get_default_result_options(); }; @@ -81,6 +83,16 @@ void descriptor_base::set_compute_intercept_impl(bool compute_intercept) { impl_->compute_intercept = compute_intercept; } +template +double descriptor_base::get_alpha() const { + return impl_->alpha; +} + +template +void descriptor_base::set_alpha_impl(double value) { + impl_->alpha = value; +} + template class ONEDAL_EXPORT descriptor_base; } // namespace v1 diff --git a/cpp/oneapi/dal/algo/linear_regression/common.hpp b/cpp/oneapi/dal/algo/linear_regression/common.hpp index 633e919f1bb..57d597a984d 100644 --- a/cpp/oneapi/dal/algo/linear_regression/common.hpp +++ b/cpp/oneapi/dal/algo/linear_regression/common.hpp @@ -112,10 +112,12 @@ class descriptor_base : public base { descriptor_base(bool compute_intercept); bool get_compute_intercept() const; + double get_alpha() const; result_option_id get_result_options() const; protected: void set_compute_intercept_impl(bool compute_intercept); + void set_alpha_impl(double alpha); void set_result_options_impl(const result_option_id& value); private: @@ -165,6 +167,14 @@ class descriptor : public detail::descriptor_base { /// Creates a new instance of the class with default parameters explicit descriptor() : base_t(true) {} + explicit descriptor(bool compute_intercept, double alpha) : 
base_t(compute_intercept) { + set_alpha(alpha); + } + + explicit descriptor(double alpha) : base_t(true) { + set_alpha(alpha); + } + /// Defines should intercept be taken into consideration. bool get_compute_intercept() const { return base_t::get_compute_intercept(); @@ -175,6 +185,16 @@ class descriptor : public detail::descriptor_base { return *this; } + /// Defines regularization term alpha used in Ridge Regression + double get_alpha() const { + return base_t::get_alpha(); + } + + auto& set_alpha(double value) { + base_t::set_alpha_impl(value); + return *this; + } + /// Choose which results should be computed and returned. result_option_id get_result_options() const { return base_t::get_result_options(); diff --git a/cpp/oneapi/dal/algo/linear_regression/test/batch.cpp b/cpp/oneapi/dal/algo/linear_regression/test/batch.cpp index 270b34b9ddc..00ec7babbb9 100644 --- a/cpp/oneapi/dal/algo/linear_regression/test/batch.cpp +++ b/cpp/oneapi/dal/algo/linear_regression/test/batch.cpp @@ -47,7 +47,15 @@ TEMPLATE_LIST_TEST_M(lr_batch_test, "LR common flow", "[lr][batch]", lr_types) { this->generate(777); - this->run_and_check(); + this->run_and_check_linear(); +} + +TEMPLATE_LIST_TEST_M(lr_batch_test, "RR common flow", "[rr][batch]", lr_types) { + SKIP_IF(this->not_float64_friendly()); + + this->generate(777); + + this->run_and_check_ridge(); } } // namespace oneapi::dal::linear_regression::test diff --git a/cpp/oneapi/dal/algo/linear_regression/test/fixture.hpp b/cpp/oneapi/dal/algo/linear_regression/test/fixture.hpp index a8994a7c704..aedf0165454 100644 --- a/cpp/oneapi/dal/algo/linear_regression/test/fixture.hpp +++ b/cpp/oneapi/dal/algo/linear_regression/test/fixture.hpp @@ -123,6 +123,17 @@ class lr_test : public te::crtp_algo_fixture { return result; } + double generate_alpha(std::int64_t seed) const { + std::mt19937 gen(seed); + + double alpha_min = 1; + double alpha_max = 5; + + std::uniform_real_distribution dist(alpha_min, alpha_max); + + return dist(gen); + } + 
void check_table_dimensions(const table& x_train, const table& y_train, const table& x_test, @@ -144,13 +155,14 @@ class lr_test : public te::crtp_algo_fixture { this->bias_ = std::move(bias); this->beta_ = std::move(beta); + this->alpha_ = generate_alpha(seed); } - auto get_descriptor() const { + auto get_descriptor(double alpha = 0.0) const { result_option_id resopts = result_options::coefficients; if (this->intercept_) resopts = resopts | result_options::intercept; - return linear_regression::descriptor(intercept_) + return linear_regression::descriptor(intercept_, alpha) .set_result_options(resopts); } @@ -191,7 +203,25 @@ class lr_test : public te::crtp_algo_fixture { } } - void run_and_check(std::int64_t seed = 888, double tol = 1e-2) { + void check_coefficient_shrinkage(const table& lr_coeffs, + const table& rr_coeffs, + double tol = 1e-3) { + row_accessor lr_acc(lr_coeffs); + row_accessor rr_acc(rr_coeffs); + const auto lr_arr = lr_acc.pull({ 0, -1 }); + const auto rr_arr = rr_acc.pull({ 0, -1 }); + + double lr_norm_squared = 0, rr_norm_squared = 0; + for (std::int64_t i = 0; i < lr_arr.get_count(); ++i) { + lr_norm_squared += lr_arr[i] * lr_arr[i]; + rr_norm_squared += rr_arr[i] * rr_arr[i]; + } + + REQUIRE(rr_norm_squared <= lr_norm_squared + tol); + } + + std::tuple prepare_inputs(std::int64_t seed = 888, + double tol = 1e-2) { using namespace ::oneapi::dal::detail; std::mt19937 meta_gen(seed); @@ -214,6 +244,29 @@ class lr_test : public te::crtp_algo_fixture { auto y_test = compute_responses(this->beta_, this->bias_, x_test); check_table_dimensions(x_train, y_train, x_test, y_test); + return { x_train, y_train, x_test, y_test }; + } + + void run_and_check_ridge(std::int64_t seed = 888, double tol = 1e-2) { + table x_train, y_train, x_test, y_test; + std::tie(x_train, y_train, x_test, y_test) = prepare_inputs(seed, tol); + + const auto linear_desc = this->get_descriptor(); + const auto linear_train_res = this->train(linear_desc, x_train, y_train); + + 
const auto ridge_desc = this->get_descriptor(this->alpha_); + const auto ridge_train_res = this->train(ridge_desc, x_train, y_train); + + SECTION("Checking coefficient shrinkage") { + this->check_coefficient_shrinkage(linear_train_res.get_coefficients(), + ridge_train_res.get_coefficients(), + tol); + } + } + + void run_and_check_linear(std::int64_t seed = 888, double tol = 1e-2) { + table x_train, y_train, x_test, y_test; + std::tie(x_train, y_train, x_test, y_test) = prepare_inputs(seed, tol); const auto desc = this->get_descriptor(); const auto train_res = this->train(desc, x_train, y_train); @@ -234,6 +287,7 @@ class lr_test : public te::crtp_algo_fixture { check_if_close(infer_res.get_responses(), y_test, tol); } } + template std::vector split_table_by_rows(const dal::table& t, std::int64_t split_count) { ONEDAL_ASSERT(0l < split_count); @@ -259,31 +313,12 @@ class lr_test : public te::crtp_algo_fixture { return result; } - void run_and_check_online(std::int64_t nBlocks) { - using namespace ::oneapi::dal::detail; + void run_and_check_linear_online(std::int64_t nBlocks) { std::int64_t seed = 888; double tol = 1e-2; - - std::mt19937 meta_gen(seed); - const std::int64_t train_seed = meta_gen(); - const auto train_dataframe = GENERATE_DATAFRAME( - te::dataframe_builder{ this->s_count_, this->f_count_ }.fill_uniform(-5.5, - 3.5, - train_seed)); - auto x_train = train_dataframe.get_table(this->get_homogen_table_id()); - - const std::int64_t test_seed = meta_gen(); - const auto test_dataframe = GENERATE_DATAFRAME( - te::dataframe_builder{ this->t_count_, this->f_count_ }.fill_uniform(-3.5, - 5.5, - test_seed)); - auto x_test = test_dataframe.get_table(this->get_homogen_table_id()); - - auto y_train = compute_responses(this->beta_, this->bias_, x_train); - auto y_test = compute_responses(this->beta_, this->bias_, x_test); - - check_table_dimensions(x_train, y_train, x_test, y_test); + table x_train, y_train, x_test, y_test; + std::tie(x_train, y_train, x_test, y_test) 
= prepare_inputs(seed, tol); const auto desc = this->get_descriptor(); dal::linear_regression::partial_train_result<> partial_result; @@ -312,8 +347,45 @@ class lr_test : public te::crtp_algo_fixture { } } + void run_and_check_ridge_online(std::int64_t nBlocks) { + std::int64_t seed = 888; + double tol = 1e-2; + table x_train, y_train, x_test, y_test; + std::tie(x_train, y_train, x_test, y_test) = prepare_inputs(seed, tol); + + auto input_table_x = split_table_by_rows(x_train, nBlocks); + auto input_table_y = split_table_by_rows(y_train, nBlocks); + + const auto linear_desc = this->get_descriptor(); + dal::linear_regression::partial_train_result<> linear_partial_result; + for (std::int64_t i = 0; i < nBlocks; i++) { + linear_partial_result = this->partial_train(linear_desc, + linear_partial_result, + input_table_x[i], + input_table_y[i]); + } + auto linear_train_res = this->finalize_train(linear_desc, linear_partial_result); + + const auto ridge_desc = this->get_descriptor(this->alpha_); + dal::linear_regression::partial_train_result<> ridge_partial_result; + for (std::int64_t i = 0; i < nBlocks; i++) { + ridge_partial_result = this->partial_train(ridge_desc, + ridge_partial_result, + input_table_x[i], + input_table_y[i]); + } + auto ridge_train_res = this->finalize_train(ridge_desc, ridge_partial_result); + + SECTION("Checking coefficient shrinkage") { + this->check_coefficient_shrinkage(linear_train_res.get_coefficients(), + ridge_train_res.get_coefficients(), + tol); + } + } + protected: bool intercept_ = true; + float_t alpha_; std::int64_t t_count_; std::int64_t s_count_; std::int64_t f_count_; diff --git a/cpp/oneapi/dal/algo/linear_regression/test/online.cpp b/cpp/oneapi/dal/algo/linear_regression/test/online.cpp index 2724768491b..c16e1c06f26 100644 --- a/cpp/oneapi/dal/algo/linear_regression/test/online.cpp +++ b/cpp/oneapi/dal/algo/linear_regression/test/online.cpp @@ -47,7 +47,15 @@ TEMPLATE_LIST_TEST_M(lr_online_test, "LR common flow", "[lr][online]", 
lr_types) this->generate(777); const int64_t nBlocks = GENERATE(1, 3, 5, 8); - this->run_and_check_online(nBlocks); + this->run_and_check_linear_online(nBlocks); +} + +TEMPLATE_LIST_TEST_M(lr_online_test, "RR common flow", "[rr][online]", lr_types) { + SKIP_IF(this->not_float64_friendly()); + this->generate(777); + const int64_t nBlocks = GENERATE(1, 3, 5, 8); + + this->run_and_check_ridge_online(nBlocks); } } // namespace oneapi::dal::linear_regression::test diff --git a/cpp/oneapi/dal/algo/linear_regression/test/spmd.cpp b/cpp/oneapi/dal/algo/linear_regression/test/spmd.cpp index d0cca4e943c..62223f03fdd 100644 --- a/cpp/oneapi/dal/algo/linear_regression/test/spmd.cpp +++ b/cpp/oneapi/dal/algo/linear_regression/test/spmd.cpp @@ -25,7 +25,7 @@ TEMPLATE_LIST_TEST_M(lr_spmd_test, "LR common flow", "[lr][spmd]", lr_types) { this->generate(777); this->set_rank_count(GENERATE(2, 3)); - this->run_and_check(); + this->run_and_check_linear(); } } // namespace oneapi::dal::linear_regression::test diff --git a/cpp/oneapi/dal/algo/linear_regression/test/train_parameters.cpp b/cpp/oneapi/dal/algo/linear_regression/test/train_parameters.cpp index 48f9ead5d3a..835b8ecc1b4 100644 --- a/cpp/oneapi/dal/algo/linear_regression/test/train_parameters.cpp +++ b/cpp/oneapi/dal/algo/linear_regression/test/train_parameters.cpp @@ -89,7 +89,7 @@ TEMPLATE_LIST_TEST_M(lr_train_params_test, "LR train params", "[lr][train][param this->generate(999); this->generate_parameters(); - this->run_and_check(); + this->run_and_check_linear(); } } // namespace oneapi::dal::linear_regression::test