diff --git a/cpp/oneapi/dal/algo/BUILD b/cpp/oneapi/dal/algo/BUILD index daf045d177f..78fd4a219f4 100644 --- a/cpp/oneapi/dal/algo/BUILD +++ b/cpp/oneapi/dal/algo/BUILD @@ -23,6 +23,7 @@ ALGOS = [ "kmeans_init", "knn", "linear_kernel", + "logistic_regression", "logloss_objective", "louvain", "minkowski_distance", diff --git a/cpp/oneapi/dal/algo/logistic_regression.hpp b/cpp/oneapi/dal/algo/logistic_regression.hpp new file mode 100644 index 00000000000..e3fa3d88cfb --- /dev/null +++ b/cpp/oneapi/dal/algo/logistic_regression.hpp @@ -0,0 +1,20 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#pragma once + +#include "oneapi/dal/algo/logistic_regression/infer.hpp" +#include "oneapi/dal/algo/logistic_regression/train.hpp" diff --git a/cpp/oneapi/dal/algo/logistic_regression/BUILD b/cpp/oneapi/dal/algo/logistic_regression/BUILD new file mode 100644 index 00000000000..f7f9ca38661 --- /dev/null +++ b/cpp/oneapi/dal/algo/logistic_regression/BUILD @@ -0,0 +1,61 @@ +package(default_visibility = ["//visibility:public"]) +load("@onedal//dev/bazel:dal.bzl", + "dal_module", + "dal_test_suite", +) + +dal_module( + name = "core", + auto = True, + dal_deps = [ + "@onedal//cpp/oneapi/dal:core", + "@onedal//cpp/oneapi/dal/algo/logistic_regression/backend:model_impl", + "@onedal//cpp/oneapi/dal/algo/logistic_regression/detail:optimizers", + ], +) + +dal_module( + name = "parameters", + dal_deps = [ + "@onedal//cpp/oneapi/dal/algo/logistic_regression/parameters", + ], +) + +dal_module( + name = "logistic_regression", + dal_deps = [ + ":core", + ":parameters", + "@onedal//cpp/oneapi/dal/algo/logistic_regression/detail", + "@onedal//cpp/oneapi/dal/algo/logistic_regression/backend", + + ] +) + +dal_test_suite( + name = "interface_tests", + framework = "catch2", + compile_as = [ "dpc++" ], + hdrs = glob([ + "test/*.hpp", + ], exclude=[ + "test/mpi_*.hpp", + "test/ccl_*.hpp" + ]), + srcs = glob([ + "test/*.cpp", + ], exclude=[ + "test/mpi_*.cpp", + "test/ccl_*.cpp" + ]), + dal_deps = [ + ":logistic_regression", + ], +) + +dal_test_suite( + name = "tests", + tests = [ + ":interface_tests", + ], +) diff --git a/cpp/oneapi/dal/algo/logistic_regression/backend/BUILD b/cpp/oneapi/dal/algo/logistic_regression/backend/BUILD new file mode 100644 index 00000000000..5c5784e66f0 --- /dev/null +++ b/cpp/oneapi/dal/algo/logistic_regression/backend/BUILD @@ -0,0 +1,23 @@ +package(default_visibility = ["//visibility:public"]) +load("@onedal//dev/bazel:dal.bzl", + "dal_module", + "dal_test_suite", +) + 
+dal_module( + name = "backend", + dal_deps = [ + "@onedal//cpp/oneapi/dal/algo/logistic_regression/backend/cpu", + "@onedal//cpp/oneapi/dal/algo/logistic_regression/backend/gpu", + ], +) + +dal_module( + name = "model_impl", + hdrs = glob(["model_impl.hpp"]), +) + +dal_module( + name = "optimizer_impl", + hdrs = glob(["optimizer_impl.hpp"]), +) diff --git a/cpp/oneapi/dal/algo/logistic_regression/backend/cpu/BUILD b/cpp/oneapi/dal/algo/logistic_regression/backend/cpu/BUILD new file mode 100644 index 00000000000..4f3ea1f1e7c --- /dev/null +++ b/cpp/oneapi/dal/algo/logistic_regression/backend/cpu/BUILD @@ -0,0 +1,14 @@ +package(default_visibility = ["//visibility:public"]) +load("@onedal//dev/bazel:dal.bzl", + "dal_module", + "dal_test_suite", +) + +dal_module( + name = "cpu", + auto = True, + dal_deps = [ + "@onedal//cpp/oneapi/dal/algo/logistic_regression:core", + "@onedal//cpp/oneapi/dal/backend/primitives:common", + ], +) diff --git a/cpp/oneapi/dal/algo/logistic_regression/backend/cpu/infer_kernel.hpp b/cpp/oneapi/dal/algo/logistic_regression/backend/cpu/infer_kernel.hpp new file mode 100644 index 00000000000..935148e3b0b --- /dev/null +++ b/cpp/oneapi/dal/algo/logistic_regression/backend/cpu/infer_kernel.hpp @@ -0,0 +1,31 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#pragma once + +#include "oneapi/dal/algo/logistic_regression/infer_types.hpp" +#include "oneapi/dal/backend/dispatcher.hpp" + +namespace oneapi::dal::logistic_regression::backend { + +template +struct infer_kernel_cpu { + infer_result operator()(const dal::backend::context_cpu& ctx, + const detail::descriptor_base& params, + const infer_input& input) const; +}; + +} // namespace oneapi::dal::logistic_regression::backend diff --git a/cpp/oneapi/dal/algo/logistic_regression/backend/cpu/infer_kernel_dense_batch.cpp b/cpp/oneapi/dal/algo/logistic_regression/backend/cpu/infer_kernel_dense_batch.cpp new file mode 100644 index 00000000000..0bc8ddb219d --- /dev/null +++ b/cpp/oneapi/dal/algo/logistic_regression/backend/cpu/infer_kernel_dense_batch.cpp @@ -0,0 +1,40 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "oneapi/dal/backend/interop/common.hpp" + +#include "oneapi/dal/algo/logistic_regression/common.hpp" +#include "oneapi/dal/algo/logistic_regression/infer_types.hpp" +#include "oneapi/dal/algo/logistic_regression/backend/cpu/infer_kernel.hpp" + +namespace oneapi::dal::logistic_regression::backend { + +using dal::backend::context_cpu; + +template +struct infer_kernel_cpu { + infer_result operator()(const context_cpu& ctx, + const detail::descriptor_base& desc, + const infer_input& input) const { + throw unimplemented( + dal::detail::error_messages::log_reg_dense_batch_method_is_not_implemented_for_cpu()); + } +}; + +template struct infer_kernel_cpu; +template struct infer_kernel_cpu; + +} // namespace oneapi::dal::logistic_regression::backend diff --git a/cpp/oneapi/dal/algo/logistic_regression/backend/cpu/train_kernel.hpp b/cpp/oneapi/dal/algo/logistic_regression/backend/cpu/train_kernel.hpp new file mode 100644 index 00000000000..96ff72b8bb4 --- /dev/null +++ b/cpp/oneapi/dal/algo/logistic_regression/backend/cpu/train_kernel.hpp @@ -0,0 +1,32 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#pragma once + +#include "oneapi/dal/algo/logistic_regression/train_types.hpp" +#include "oneapi/dal/backend/dispatcher.hpp" + +namespace oneapi::dal::logistic_regression::backend { + +template +struct train_kernel_cpu { + train_result operator()(const dal::backend::context_cpu& ctx, + const detail::descriptor_base& desc, + const detail::train_parameters& params, + const train_input& input) const; +}; + +} // namespace oneapi::dal::logistic_regression::backend diff --git a/cpp/oneapi/dal/algo/logistic_regression/backend/cpu/train_kernel_dense_batch.cpp b/cpp/oneapi/dal/algo/logistic_regression/backend/cpu/train_kernel_dense_batch.cpp new file mode 100644 index 00000000000..e5728534b1d --- /dev/null +++ b/cpp/oneapi/dal/algo/logistic_regression/backend/cpu/train_kernel_dense_batch.cpp @@ -0,0 +1,38 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "oneapi/dal/algo/logistic_regression/backend/cpu/train_kernel.hpp" +#include "oneapi/dal/exceptions.hpp" + +namespace oneapi::dal::logistic_regression::backend { + +using dal::backend::context_cpu; + +template +struct train_kernel_cpu { + train_result operator()(const context_cpu& ctx, + const detail::descriptor_base& desc, + const detail::train_parameters& params, + const train_input& input) const { + throw unimplemented( + dal::detail::error_messages::log_reg_dense_batch_method_is_not_implemented_for_cpu()); + } +}; + +template struct train_kernel_cpu; +template struct train_kernel_cpu; + +} // namespace oneapi::dal::logistic_regression::backend diff --git a/cpp/oneapi/dal/algo/logistic_regression/backend/gpu/BUILD b/cpp/oneapi/dal/algo/logistic_regression/backend/gpu/BUILD new file mode 100644 index 00000000000..c22cb27010c --- /dev/null +++ b/cpp/oneapi/dal/algo/logistic_regression/backend/gpu/BUILD @@ -0,0 +1,18 @@ +package(default_visibility = ["//visibility:public"]) +load("@onedal//dev/bazel:dal.bzl", + "dal_module", + "dal_test_suite", +) + +dal_module( + name = "gpu", + auto = True, + dal_deps = [ + "@onedal//cpp/oneapi/dal/backend/primitives:blas", + "@onedal//cpp/oneapi/dal/backend/primitives:common", + "@onedal//cpp/oneapi/dal/backend/primitives:lapack", + "@onedal//cpp/oneapi/dal/backend/primitives:objective_function", + "@onedal//cpp/oneapi/dal/backend/primitives:optimizers", + "@onedal//cpp/oneapi/dal/algo/logistic_regression:core", + ], +) diff --git a/cpp/oneapi/dal/algo/logistic_regression/backend/gpu/infer_kernel.hpp b/cpp/oneapi/dal/algo/logistic_regression/backend/gpu/infer_kernel.hpp new file mode 100644 index 00000000000..338a0a722a0 --- /dev/null +++ b/cpp/oneapi/dal/algo/logistic_regression/backend/gpu/infer_kernel.hpp @@ -0,0 +1,31 @@ +/******************************************************************************* +* Copyright 2023 Intel 
Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#pragma once + +#include "oneapi/dal/algo/logistic_regression/infer_types.hpp" +#include "oneapi/dal/backend/dispatcher.hpp" + +namespace oneapi::dal::logistic_regression::backend { + +template +struct infer_kernel_gpu { + infer_result operator()(const dal::backend::context_gpu& ctx, + const detail::descriptor_base& params, + const infer_input& input) const; +}; + +} // namespace oneapi::dal::logistic_regression::backend diff --git a/cpp/oneapi/dal/algo/logistic_regression/backend/gpu/infer_kernel_dense_batch_dpc.cpp b/cpp/oneapi/dal/algo/logistic_regression/backend/gpu/infer_kernel_dense_batch_dpc.cpp new file mode 100644 index 00000000000..f3ccfa3da8f --- /dev/null +++ b/cpp/oneapi/dal/algo/logistic_regression/backend/gpu/infer_kernel_dense_batch_dpc.cpp @@ -0,0 +1,144 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "oneapi/dal/detail/profiler.hpp" + +#include "oneapi/dal/backend/interop/common.hpp" +#include "oneapi/dal/backend/interop/common_dpc.hpp" +#include "oneapi/dal/backend/interop/error_converter.hpp" +#include "oneapi/dal/backend/interop/table_conversion.hpp" + +#include "oneapi/dal/backend/dispatcher.hpp" +#include "oneapi/dal/backend/primitives/blas.hpp" +#include "oneapi/dal/backend/primitives/ndarray.hpp" +#include "oneapi/dal/backend/primitives/objective_function.hpp" +#include "oneapi/dal/backend/primitives/ndindexer.hpp" + +#include "oneapi/dal/table/row_accessor.hpp" + +#include "oneapi/dal/algo/logistic_regression/common.hpp" +#include "oneapi/dal/algo/logistic_regression/train_types.hpp" +#include "oneapi/dal/algo/logistic_regression/backend/model_impl.hpp" +#include "oneapi/dal/algo/logistic_regression/backend/gpu/infer_kernel.hpp" + +namespace oneapi::dal::logistic_regression::backend { + +using daal::services::Status; +using dal::backend::context_gpu; + +namespace be = dal::backend; +namespace pr = be::primitives; +namespace interop = dal::backend::interop; + +template +std::int64_t propose_block_size(const sycl::queue& q, std::int64_t f, std::int64_t r) { + constexpr std::int64_t fsize = sizeof(Float); + return 0x10000l * (8 / fsize); +} + +template +static infer_result call_dal_kernel(const context_gpu& ctx, + const detail::descriptor_base& desc, + const table& infer, + const model& m) { + using dal::detail::check_mul_overflow; + + auto queue = 
ctx.get_queue(); + ONEDAL_PROFILER_TASK(logreg_infer_kernel, queue); + + constexpr auto alloc = sycl::usm::alloc::device; + + const auto& betas = m.get_packed_coefficients(); + + const auto sample_count = infer.get_row_count(); + const auto feature_count = infer.get_column_count(); + const bool fit_intercept = desc.get_compute_intercept(); + ONEDAL_ASSERT((feature_count + 1) == betas.get_column_count()); + ONEDAL_ASSERT(1 == betas.get_row_count()); + + pr::ndarray params = pr::table2ndarray_1d(queue, betas, alloc); + pr::ndview params_suf = fit_intercept ? params : params.slice(1, feature_count); + + pr::ndarray probs = pr::ndarray::empty(queue, { sample_count }, alloc); + pr::ndarray responses = + pr::ndarray::empty(queue, { sample_count }, alloc); + + const auto bsize = propose_block_size(queue, feature_count, 1); + const be::uniform_blocking blocking(sample_count, bsize); + const auto b_count = blocking.get_block_count(); + + row_accessor x_accessor(infer); + + be::event_vector all_deps; + all_deps.reserve(b_count); + + for (std::int64_t b = 0; b < b_count; ++b) { + const auto last = blocking.get_block_end_index(b); + const auto first = blocking.get_block_start_index(b); + + const auto length = last - first; + ONEDAL_ASSERT(0l < length); + + auto probs_slice = probs.slice(first, length); + auto resp_slice = responses.slice(first, length); + auto x_rows = x_accessor.pull(queue, { first, last }, alloc); + auto x_nd = pr::ndarray::wrap(x_rows, { length, feature_count }); + + auto prob_event = + pr::compute_probabilities(queue, params_suf, x_nd, probs_slice, fit_intercept, {}); + + const auto* const prob_ptr = probs_slice.get_data(); + auto* const resp_ptr = resp_slice.get_mutable_data(); + + auto fill_resp_event = queue.submit([&](sycl::handler& cgh) { + cgh.depends_on(prob_event); + const auto range = be::make_range_1d(length); + cgh.parallel_for(range, [=](sycl::id<1> idx) { + constexpr Float half = 0.5f; + resp_ptr[idx] = prob_ptr[idx] < half ? 
0 : 1; + }); + }); + all_deps.push_back(fill_resp_event); + } + + auto resp_table = homogen_table::wrap(responses.flatten(queue, all_deps), sample_count, 1); + auto prob_table = homogen_table::wrap(probs.flatten(queue, all_deps), sample_count, 1); + + auto result = infer_result().set_responses(resp_table).set_probabilities(prob_table); + + return result; +} + +template +static infer_result infer(const context_gpu& ctx, + const detail::descriptor_base& desc, + const infer_input& input) { + return call_dal_kernel(ctx, desc, input.get_data(), input.get_model()); +} + +template +struct infer_kernel_gpu { + infer_result operator()(const context_gpu& ctx, + const detail::descriptor_base& desc, + const infer_input& input) const { + return infer(ctx, desc, input); + } +}; + +template struct infer_kernel_gpu; +template struct infer_kernel_gpu; + +} // namespace oneapi::dal::logistic_regression::backend diff --git a/cpp/oneapi/dal/algo/logistic_regression/backend/gpu/train_kernel.hpp b/cpp/oneapi/dal/algo/logistic_regression/backend/gpu/train_kernel.hpp new file mode 100644 index 00000000000..79b399a2077 --- /dev/null +++ b/cpp/oneapi/dal/algo/logistic_regression/backend/gpu/train_kernel.hpp @@ -0,0 +1,32 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#pragma once + +#include "oneapi/dal/algo/logistic_regression/train_types.hpp" +#include "oneapi/dal/backend/dispatcher.hpp" + +namespace oneapi::dal::logistic_regression::backend { + +template +struct train_kernel_gpu { + train_result operator()(const dal::backend::context_gpu& ctx, + const detail::descriptor_base& desc, + const detail::train_parameters& params, + const train_input& input) const; +}; + +} // namespace oneapi::dal::logistic_regression::backend diff --git a/cpp/oneapi/dal/algo/logistic_regression/backend/gpu/train_kernel_dense_batch_dpc.cpp b/cpp/oneapi/dal/algo/logistic_regression/backend/gpu/train_kernel_dense_batch_dpc.cpp new file mode 100644 index 00000000000..1b6f2f75936 --- /dev/null +++ b/cpp/oneapi/dal/algo/logistic_regression/backend/gpu/train_kernel_dense_batch_dpc.cpp @@ -0,0 +1,138 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "oneapi/dal/detail/profiler.hpp" + +#include "oneapi/dal/detail/common.hpp" +#include "oneapi/dal/backend/dispatcher.hpp" +#include "oneapi/dal/backend/primitives/ndarray.hpp" +#include "oneapi/dal/backend/primitives/lapack.hpp" +#include "oneapi/dal/backend/primitives/utils.hpp" + +#include "oneapi/dal/table/row_accessor.hpp" + +#include "oneapi/dal/algo/logistic_regression/common.hpp" +#include "oneapi/dal/algo/logistic_regression/train_types.hpp" +#include "oneapi/dal/algo/logistic_regression/backend/model_impl.hpp" +#include "oneapi/dal/algo/logistic_regression/backend/gpu/train_kernel.hpp" +#include "oneapi/dal/backend/primitives/objective_function.hpp" +#include "oneapi/dal/backend/primitives/optimizers.hpp" +#include "oneapi/dal/algo/logistic_regression/backend/optimizer_impl.hpp" + +namespace oneapi::dal::logistic_regression::backend { + +using dal::backend::context_gpu; + +namespace be = dal::backend; +namespace pr = be::primitives; + +template +static train_result call_dal_kernel(const context_gpu& ctx, + const detail::descriptor_base& desc, + const detail::train_parameters& params, + const table& data, + const table& resp) { + using dal::detail::check_mul_overflow; + + auto queue = ctx.get_queue(); + + ONEDAL_PROFILER_TASK(log_reg_train_kernel, queue); + + using model_t = model; + using model_impl_t = detail::model_impl; + + auto opt_impl = detail::get_optimizer_impl(desc); + + if (!opt_impl) { + throw internal_error{ dal::detail::error_messages::unknown_optimizer() }; + } + + const auto sample_count = data.get_row_count(); + const auto feature_count = data.get_column_count(); + ONEDAL_ASSERT(sample_count == resp.get_row_count()); + const auto responses_nd = + pr::table2ndarray_1d(queue, resp, sycl::usm::alloc::device); + + const std::int64_t bsize = params.get_gpu_macro_block(); + + const Float l2 = Float(1.0) / desc.get_inverse_regularization(); + const bool 
fit_intercept = desc.get_compute_intercept(); + + // TODO: add check if the dataset can be moved to gpu + // Move data to gpu + pr::ndarray data_nd = pr::table2ndarray(queue, data, sycl::usm::alloc::device); + table data_gpu = homogen_table::wrap(data_nd.flatten(queue, {}), sample_count, feature_count); + + pr::logloss_function loss_func = + pr::logloss_function(queue, data_gpu, responses_nd, l2, fit_intercept, bsize); + + auto [x, fill_event] = + pr::ndarray::zeros(queue, { feature_count + 1 }, sycl::usm::alloc::device); + + pr::ndview x_suf = fit_intercept ? x : x.slice(1, feature_count); + + auto [train_event, iter_num] = opt_impl->minimize(queue, loss_func, x_suf, { fill_event }); + + auto all_coefs = homogen_table::wrap(x.flatten(queue, { train_event }), 1, feature_count + 1); + + const auto model_impl = std::make_shared(all_coefs); + const auto model = dal::detail::make_private(model_impl); + + const auto options = desc.get_result_options(); + auto result = train_result().set_model(model).set_result_options(options); + + if (options.test(result_options::intercept)) { + ONEDAL_ASSERT(fit_intercept); + table intercept_table = + homogen_table::wrap(x.slice(0, 1).flatten(queue, { train_event }), 1, 1); + result.set_intercept(intercept_table); + } + + if (options.test(result_options::coefficients)) { + auto coefs_array = x.slice(1, feature_count).flatten(queue, { train_event }); + auto coefs_table = homogen_table::wrap(coefs_array, 1, feature_count); + result.set_coefficients(coefs_table); + } + + if (options.test(result_options::iterations_count)) { + result.set_iterations_count(iter_num); + } + + return result; +} + +template +static train_result train(const context_gpu& ctx, + const detail::descriptor_base& desc, + const detail::train_parameters& params, + const train_input& input) { + return call_dal_kernel(ctx, desc, params, input.get_data(), input.get_responses()); +} + +template +struct train_kernel_gpu { + train_result operator()(const context_gpu& ctx, + 
const detail::descriptor_base& desc, + const detail::train_parameters& params, + const train_input& input) const { + return train(ctx, desc, params, input); + } +}; + +template struct train_kernel_gpu; +template struct train_kernel_gpu; + +} // namespace oneapi::dal::logistic_regression::backend diff --git a/cpp/oneapi/dal/algo/logistic_regression/backend/model_impl.hpp b/cpp/oneapi/dal/algo/logistic_regression/backend/model_impl.hpp new file mode 100644 index 00000000000..fcabcaf8ef5 --- /dev/null +++ b/cpp/oneapi/dal/algo/logistic_regression/backend/model_impl.hpp @@ -0,0 +1,54 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#pragma once + +#include "oneapi/dal/algo/logistic_regression/common.hpp" + +#include "oneapi/dal/backend/serialization.hpp" + +namespace oneapi::dal::logistic_regression { + +using dense_batch_proto = ONEDAL_SERIALIZABLE(logistic_regression_model_impl_id); + +template +class detail::v1::model_impl : public dense_batch_proto { +public: + model_impl() = default; + + model_impl(const table& packed_coefficients) : packed_coefficients_(packed_coefficients) {} + + void serialize(dal::detail::output_archive& ar) const { + ar(packed_coefficients_); + } + + void deserialize(dal::detail::input_archive& ar) { + ar(packed_coefficients_); + } + + const table& get_packed_coefficients() const { + return packed_coefficients_; + } + + void set_packed_coefficients(const table& v) { + this->packed_coefficients_ = v; + } + +private: + table packed_coefficients_; +}; + +} // namespace oneapi::dal::logistic_regression diff --git a/cpp/oneapi/dal/algo/logistic_regression/backend/optimizer_impl.hpp b/cpp/oneapi/dal/algo/logistic_regression/backend/optimizer_impl.hpp new file mode 100644 index 00000000000..449a0cac61b --- /dev/null +++ b/cpp/oneapi/dal/algo/logistic_regression/backend/optimizer_impl.hpp @@ -0,0 +1,56 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#pragma once +#include "oneapi/dal/backend/primitives/ndarray.hpp" + +#ifdef ONEDAL_DATA_PARALLEL +#include "oneapi/dal/backend/primitives/optimizers.hpp" +#endif + +namespace oneapi::dal::logistic_regression::detail { +namespace v1 { + +namespace be = dal::backend; +namespace pr = be::primitives; + +enum optimizer_type { newton_cg }; + +class optimizer_impl : public base { +public: + virtual ~optimizer_impl() = default; + virtual optimizer_type get_optimizer_type() = 0; + virtual double get_tol() = 0; + virtual std::int64_t get_max_iter() = 0; + +#ifdef ONEDAL_DATA_PARALLEL + virtual std::pair minimize(sycl::queue& q, + pr::base_function& f, + pr::ndview& x, + const be::event_vector& deps = {}) = 0; + virtual std::pair minimize(sycl::queue& q, + pr::base_function& f, + pr::ndview& x, + const be::event_vector& deps = {}) = 0; +#endif +}; + +} // namespace v1 + +using v1::optimizer_impl; +using v1::optimizer_type; + +} // namespace oneapi::dal::logistic_regression::detail diff --git a/cpp/oneapi/dal/algo/logistic_regression/common.cpp b/cpp/oneapi/dal/algo/logistic_regression/common.cpp new file mode 100644 index 00000000000..01a3a047187 --- /dev/null +++ b/cpp/oneapi/dal/algo/logistic_regression/common.cpp @@ -0,0 +1,169 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "oneapi/dal/algo/logistic_regression/backend/model_impl.hpp"
+#include "oneapi/dal/algo/logistic_regression/common.hpp"
+#include "oneapi/dal/exceptions.hpp"
+
+namespace oneapi::dal::logistic_regression {
+
+namespace detail {
+
+// Result option flags are identified by index:
+// 0 - intercept, 1 - coefficients, 2 - iterations count.
+result_option_id get_intercept_id() {
+    return result_option_id{ result_option_id::make_by_index(0) };
+}
+
+result_option_id get_coefficients_id() {
+    return result_option_id{ result_option_id::make_by_index(1) };
+}
+
+result_option_id get_iterations_count_id() {
+    return result_option_id{ result_option_id::make_by_index(2) };
+}
+
+// Returns a default-constructed option set.
+// NOTE(review): set_result_options_impl() below rejects a value for which
+// `!value` holds; if a default-constructed id is such a value, a fresh
+// descriptor starts in a state the setter would refuse - confirm intended.
+template
+result_option_id get_default_result_options() {
+    return result_option_id{};
+}
+
+namespace v1 {
+/// Storage behind `descriptor_base`: training settings plus the optimizer
+/// handle. Defaults: intercept enabled, C = 1.0, two classes.
+template
+class descriptor_impl : public base {
+public:
+    explicit descriptor_impl(const detail::optimizer_ptr& optimizer) : opt(optimizer) {}
+
+    bool compute_intercept = true;
+    double C = 1.0;
+    std::int64_t class_count = 2;
+    detail::optimizer_ptr opt;
+
+    result_option_id result_options = get_default_result_options();
+};
+
+// Default construction wraps a default-constructed `optimizer_t`.
+template
+descriptor_base::descriptor_base()
+        : impl_(new descriptor_impl{
+              std::make_shared>(optimizer_t()) }) {}
+
+template
+result_option_id descriptor_base::get_result_options() const {
+    return impl_->result_options;
+}
+
+// Stores `value` after validation. Throws domain_error when `!value` holds
+// or when the intercept is requested while compute_intercept is disabled.
+template
+void descriptor_base::set_result_options_impl(const result_option_id& value) {
+    using msg = dal::detail::error_messages;
+    if (!value) {
+        throw domain_error(msg::empty_set_of_result_options());
+    }
+    else if (!get_compute_intercept() && value.test(result_options::intercept)) {
+        throw domain_error(msg::intercept_result_option_requires_intercept_flag());
+    }
+    impl_->result_options = value;
+}
+
+// The optimizer handle is stored by the descriptor_impl constructor, so only
+// the remaining two fields are assigned in the body.
+template
+descriptor_base::descriptor_base(bool compute_intercept,
+                                 double C,
+                                 const detail::optimizer_ptr& optimizer)
+        : impl_(new descriptor_impl{ optimizer }) {
+    impl_->compute_intercept = compute_intercept;
+    impl_->C = C;
+}
+
+template
+bool descriptor_base::get_compute_intercept() const {
+    return impl_->compute_intercept;
+}
+
+template
+double descriptor_base::get_inverse_regularization() const {
+    return impl_->C;
+}
+
+template
+std::int64_t descriptor_base::get_class_count() const {
+    return impl_->class_count;
+}
+
+template
+const detail::optimizer_ptr& descriptor_base::get_optimizer_impl() const {
+    return impl_->opt;
+}
+
+template
+void descriptor_base::set_optimizer_impl(const detail::optimizer_ptr& opt) {
+    impl_->opt = opt;
+}
+
+template
+void descriptor_base::set_compute_intercept_impl(bool compute_intercept) {
+    impl_->compute_intercept = compute_intercept;
+}
+
+template
+void descriptor_base::set_inverse_regularization_impl(double C) {
+    impl_->C = C;
+}
+
+template
+void descriptor_base::set_class_count_impl(std::int64_t class_count) {
+    impl_->class_count = class_count;
+}
+
+template class ONEDAL_EXPORT descriptor_base;
+
+} // namespace v1
+} // namespace detail
+
+namespace v1 {
+
+using detail::v1::model_impl;
+
+template
+model::model() : impl_{ std::make_shared>() } {}
+
+template
+model::model(const std::shared_ptr>& impl) : impl_{ impl } {}
+
+// The accessors and (de)serialization below all forward to the shared
+// model_impl held in impl_.
+template
+const table& model::get_packed_coefficients() const {
+    return impl_->get_packed_coefficients();
+}
+
+template
+model& model::set_packed_coefficients(const table& t) {
+    impl_->set_packed_coefficients(t);
+    return *this;
+}
+
+template
+void model::serialize(dal::detail::output_archive& ar) const {
+    dal::detail::serialize_polymorphic_shared(impl_, ar);
+}
+
+template
+void model::deserialize(dal::detail::input_archive& ar) {
+    dal::detail::deserialize_polymorphic_shared(impl_, ar);
+}
+
+template class ONEDAL_EXPORT model;
+
+ONEDAL_REGISTER_SERIALIZABLE(detail::model_impl)
+
+} // namespace v1
+} // namespace oneapi::dal::logistic_regression
diff --git
a/cpp/oneapi/dal/algo/logistic_regression/common.hpp b/cpp/oneapi/dal/algo/logistic_regression/common.hpp
new file mode 100644
index 00000000000..8f883ee1467
--- /dev/null
+++ b/cpp/oneapi/dal/algo/logistic_regression/common.hpp
@@ -0,0 +1,279 @@
+/*******************************************************************************
+* Copyright 2023 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#pragma once
+
+#include "oneapi/dal/util/result_option_id.hpp"
+#include "oneapi/dal/detail/serialization.hpp"
+#include "oneapi/dal/detail/common.hpp"
+#include "oneapi/dal/table/common.hpp"
+#include "oneapi/dal/common.hpp"
+#include "oneapi/dal/algo/logistic_regression/detail/optimizer.hpp"
+
+namespace oneapi::dal::logistic_regression {
+
+namespace task {
+namespace v1 {
+/// Tag-type that parameterizes entities used for solving
+/// :capterm:`classification problem <classification>`.
+struct classification {};
+
+/// Alias tag-type for the default task, which is classification.
+using by_default = classification;
+} // namespace v1
+
+using v1::classification;
+using v1::by_default;
+
+} // namespace task
+
+namespace method {
+namespace v1 {
+/// Tag-type that denotes the dense batch computational method.
+struct dense_batch {};
+
+/// Alias tag-type for the default method, which is dense batch.
+using by_default = dense_batch;
+} // namespace v1
+
+using v1::dense_batch;
+using v1::by_default;
+
+} // namespace method
+
+/// Represents a result option flag.
+/// Behaves like a regular :expr:`enum`.
+class result_option_id : public result_option_id_base {
+public:
+    constexpr result_option_id() = default;
+    constexpr explicit result_option_id(const result_option_id_base& base)
+            : result_option_id_base{ base } {}
+};
+
+namespace detail {
+
+ONEDAL_EXPORT result_option_id get_intercept_id();
+ONEDAL_EXPORT result_option_id get_coefficients_id();
+ONEDAL_EXPORT result_option_id get_iterations_count_id();
+
+} // namespace detail
+
+/// Result options are used to define
+/// what the algorithm should return.
+namespace result_options {
+
+/// Return the intercept term of the logistic regression model
+const inline result_option_id intercept = detail::get_intercept_id();
+
+/// Return the coefficients of the logistic regression model
+const inline result_option_id coefficients = detail::get_coefficients_id();
+
+/// Presumably requests the number of iterations performed by the optimizer
+/// (inferred from the name - TODO confirm against the result types).
+const inline result_option_id iterations_count = detail::get_iterations_count_id();
+
+} // namespace result_options
+
+namespace detail {
+namespace v1 {
+
+struct descriptor_tag {};
+
+template
+class descriptor_impl;
+
+template
+class model_impl;
+
+// Compile-time predicates consumed by the static_asserts of descriptor_base
+// and descriptor. NOTE(review): the accepted type lists are not visible in
+// this copy of the patch (template arguments were lost) - see upstream.
+template
+constexpr bool is_valid_float_v = dal::detail::is_one_of_v;
+
+template
+constexpr bool is_valid_method_v = dal::detail::is_one_of_v;
+
+template
+constexpr bool is_valid_task_v = dal::detail::is_one_of_v;
+
+template
+constexpr bool is_valid_optimizer_v =
+    dal::detail::is_tag_one_of_v;
+
+/// Task-level base of the public descriptor. Public getters expose the
+/// stored settings; the mutating `*_impl` members and the optimizer handle
+/// are protected, and `detail::optimizer_accessor` is a friend.
+template
+class descriptor_base : public base {
+    static_assert(is_valid_task_v);
+    friend detail::optimizer_accessor;
+
+public:
+    using tag_t = descriptor_tag;
+    using float_t = float;
+    using optimizer_t = oneapi::dal::newton_cg::descriptor;
+    descriptor_base();
+
+    bool get_compute_intercept() const;
+    double get_inverse_regularization() const;
+    std::int64_t get_class_count() const;
+    result_option_id get_result_options() const;
+
+protected:
+    explicit descriptor_base(bool compute_intercept,
+                             double C,
+                             const detail::optimizer_ptr& optimizer);
+
+    void set_compute_intercept_impl(bool compute_intercept);
+    void set_inverse_regularization_impl(double C);
+    void set_class_count_impl(std::int64_t class_count);
+
+    void set_optimizer_impl(const detail::optimizer_ptr& opt);
+    void set_result_options_impl(const result_option_id& value);
+
+    const detail::optimizer_ptr& get_optimizer_impl() const;
+
+private:
+    dal::detail::pimpl> impl_;
+};
+
+} // namespace v1
+
+using v1::descriptor_tag;
+using v1::descriptor_impl;
+using v1::model_impl;
+using v1::descriptor_base;
+
+using v1::is_valid_float_v;
+using v1::is_valid_method_v;
+using v1::is_valid_task_v;
+using v1::is_valid_optimizer_v;
+
+} // namespace detail
+
+namespace v1 {
+
+/// @tparam Float     The floating-point type that the algorithm uses for
+///                   intermediate computations. Can be :expr:`float` or
+///                   :expr:`double`.
+/// @tparam Method    Tag-type that specifies an implementation of algorithm. Can
+///                   be :expr:`method::dense_batch`.
+/// @tparam Task      Tag-type that specifies type of the problem to solve. Can
+///                   be :expr:`task::classification`.
+/// @tparam Optimizer Tag-type that specifies type of the optimizer used by algorithm.
+///                   Can be :expr:`optimizer::newton_cg`.
+template >
+class descriptor : public detail::descriptor_base {
+    static_assert(detail::is_valid_float_v);
+    static_assert(detail::is_valid_method_v);
+    static_assert(detail::is_valid_task_v);
+    static_assert(detail::is_valid_optimizer_v);
+
+    using base_t = detail::descriptor_base;
+
+public:
+    using float_t = Float;
+    using method_t = Method;
+    using task_t = Task;
+    using optimizer_t = Optimizer;
+
+    /// Creates a new instance of the class with the given :literal:`compute_intercept`
+    /// flag and inverse regularization factor :literal:`C`, using a
+    /// default-constructed optimizer.
+    explicit descriptor(bool compute_intercept = true, double C = 1.0)
+            : base_t(compute_intercept,
+                     C,
+                     std::make_shared>(optimizer_t{})) {}
+
+    /// Creates a new instance of the class with the given :literal:`compute_intercept`
+    /// flag, inverse regularization factor :literal:`C` and :literal:`optimizer`.
+    explicit descriptor(bool compute_intercept, double C, const optimizer_t& optimizer)
+            : base_t(compute_intercept,
+                     C,
+                     std::make_shared>(optimizer)) {}
+
+    /// Defines whether the intercept should be taken into consideration.
+    bool get_compute_intercept() const {
+        return base_t::get_compute_intercept();
+    }
+
+    /// Defines the inverse regularization factor.
+    double get_inverse_regularization() const {
+        return base_t::get_inverse_regularization();
+    }
+
+    /// Defines number of classes.
+    std::int64_t get_class_count() const {
+        // The base class stores and returns the count as std::int64_t;
+        // keep the integral type instead of widening it to double.
+        return base_t::get_class_count();
+    }
+
+    /// Enables or disables fitting of the intercept term.
+    /// Returns the descriptor itself so that setters can be chained.
+    auto& set_compute_intercept(bool compute_intercept) {
+        base_t::set_compute_intercept_impl(compute_intercept);
+        return *this;
+    }
+
+    /// Sets the inverse regularization factor.
+    /// Returns the descriptor itself so that setters can be chained.
+    auto& set_inverse_regularization(double C) {
+        base_t::set_inverse_regularization_impl(C);
+        return *this;
+    }
+
+    /// Sets the number of classes.
+    /// Returns the descriptor itself so that setters can be chained.
+    auto& set_class_count(std::int64_t class_count) {
+        base_t::set_class_count_impl(class_count);
+        return *this;
+    }
+
+    // NOTE(review): `opt` is a shared pointer to the detail::optimizer
+    // wrapper, while the declared return type is a reference to the
+    // user-facing optimizer descriptor, so this body looks ill-formed once
+    // instantiated. Returning the wrapped descriptor needs an accessor on
+    // detail::optimizer; the newton_cg specialization declared in
+    // detail/optimizer.hpp does not declare one. Left unchanged here.
+    const optimizer_t& get_optimizer() const {
+        using optimizer_t = detail::optimizer;
+        const auto opt = std::static_pointer_cast(base_t::get_optimizer_impl());
+        return opt;
+    }
+
+    /// Replaces the optimizer with a copy of :literal:`opt`.
+    auto& set_optimizer(const optimizer_t& opt) {
+        base_t::set_optimizer_impl(std::make_shared>(opt));
+        return *this;
+    }
+
+    /// Choose which results should be computed and returned.
+    result_option_id get_result_options() const {
+        return base_t::get_result_options();
+    }
+
+    auto& set_result_options(const result_option_id& value) {
+        base_t::set_result_options_impl(value);
+        return *this;
+    }
+};
+
+/// @tparam Task Tag-type that specifies type of the problem to solve.
+template
+class model : public base {
+    static_assert(detail::is_valid_task_v);
+    friend dal::detail::pimpl_accessor;
+    friend dal::detail::serialization_accessor;
+
+public:
+    /// Creates a new instance of the class with the default property values.
+    model();
+
+    /// Returns the table of packed coefficients held by the model.
+    const table& get_packed_coefficients() const;
+    /// Stores the table of packed coefficients and returns the model itself.
+    model& set_packed_coefficients(const table& t);
+
+private:
+    void serialize(dal::detail::output_archive& ar) const;
+    void deserialize(dal::detail::input_archive& ar);
+
+    // Wraps an existing implementation object; private, so it is reachable
+    // only through the friends declared at the top of the class.
+    explicit model(const std::shared_ptr>& impl);
+    dal::detail::pimpl> impl_;
+};
+
+} // namespace v1
+
+using v1::descriptor;
+using v1::model;
+
+} // namespace oneapi::dal::logistic_regression
diff --git a/cpp/oneapi/dal/algo/logistic_regression/detail/BUILD b/cpp/oneapi/dal/algo/logistic_regression/detail/BUILD
new file mode 100644
index 00000000000..6b5e7c71962
--- /dev/null
+++ b/cpp/oneapi/dal/algo/logistic_regression/detail/BUILD
@@ -0,0 +1,27 @@
+package(default_visibility = ["//visibility:public"])
+load("@onedal//dev/bazel:dal.bzl",
+    "dal_module",
+    "dal_test_suite",
+)
+
+# Sources of this package, collected automatically (auto=True).
+dal_module(
+    name = "detail",
+    auto=True,
+    dal_deps = [
+        "@onedal//cpp/oneapi/dal/algo/logistic_regression:core",
+        "@onedal//cpp/oneapi/dal/algo/logistic_regression/backend",
+        "@onedal//cpp/oneapi/dal/algo/logistic_regression/parameters",
+        "@onedal//cpp/oneapi/dal/algo/newton_cg:newton_cg"
+    ]
+)
+
+
+# Header-only target exposing optimizer.hpp.
+dal_module(
+    name = "optimizers",
+    hdrs = glob(["optimizer.hpp"]),
+    dal_deps = [
+        "@onedal//cpp/oneapi/dal:core",
+        "@onedal//cpp/oneapi/dal/algo/logistic_regression/backend:optimizer_impl",
+        "@onedal//cpp/oneapi/dal/algo/newton_cg",
+    ],
+)
\ No newline at end of file
diff --git a/cpp/oneapi/dal/algo/logistic_regression/detail/infer_ops.cpp b/cpp/oneapi/dal/algo/logistic_regression/detail/infer_ops.cpp
new file mode 100644
index 00000000000..935bd6ab9af
--- /dev/null
+++ b/cpp/oneapi/dal/algo/logistic_regression/detail/infer_ops.cpp
@@ -0,0 +1,44 @@
+/*******************************************************************************
+* Copyright 2023 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "oneapi/dal/algo/logistic_regression/backend/cpu/infer_kernel.hpp"
+#include "oneapi/dal/algo/logistic_regression/detail/infer_ops.hpp"
+#include "oneapi/dal/backend/dispatcher.hpp"
+
+namespace oneapi::dal::logistic_regression::detail {
+namespace v1 {
+
+using dal::detail::host_policy;
+
+/// Host-policy inference dispatcher: forwards the call to a
+/// dal::backend::kernel_dispatcher and returns its result.
+/// NOTE(review): the dispatcher's kernel list is garbled in this copy of the
+/// patch (only `)>` survives) - see the upstream file for the exact kernel.
+template
+struct infer_ops_dispatcher {
+    infer_result operator()(const host_policy& ctx,
+                            const descriptor_base& desc,
+                            const infer_input& input) const {
+        using kernel_dispatcher_t = dal::backend::kernel_dispatcher)>;
+        return kernel_dispatcher_t()(ctx, desc, input);
+    }
+};
+
+// Explicit instantiations for float and double with the dense_batch method
+// and the classification task.
+#define INSTANTIATE(F, M, T) \
+    template struct ONEDAL_EXPORT infer_ops_dispatcher;
+
+INSTANTIATE(float, method::dense_batch, task::classification)
+INSTANTIATE(double, method::dense_batch, task::classification)
+
+} // namespace v1
+} // namespace oneapi::dal::logistic_regression::detail
diff --git a/cpp/oneapi/dal/algo/logistic_regression/detail/infer_ops.hpp b/cpp/oneapi/dal/algo/logistic_regression/detail/infer_ops.hpp
new file mode 100644
index 00000000000..6029ae09a65
--- /dev/null
+++ b/cpp/oneapi/dal/algo/logistic_regression/detail/infer_ops.hpp
@@ -0,0 +1,70 @@
+/*******************************************************************************
+* Copyright 2023 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#pragma once
+
+#include "oneapi/dal/algo/logistic_regression/infer_types.hpp"
+#include "oneapi/dal/detail/error_messages.hpp"
+
+namespace oneapi::dal::logistic_regression::detail {
+namespace v1 {
+
+/// Declaration of the inference dispatcher; operator() is only declared
+/// here and has no body in this header.
+template
+struct infer_ops_dispatcher {
+    infer_result operator()(const Context&,
+                            const descriptor_base&,
+                            const infer_input&) const;
+};
+
+/// Inference entry point for a descriptor: validates the input, runs the
+/// dispatcher and checks the shape of the result.
+template
+struct infer_ops {
+    using float_t = typename Descriptor::float_t;
+    using method_t = typename Descriptor::method_t;
+    using task_t = typename Descriptor::task_t;
+    using input_t = infer_input;
+    using result_t = infer_result;
+    using descriptor_base_t = descriptor_base;
+
+    /// Throws domain_error when the input data table has no data.
+    /// `params` is not inspected.
+    void check_preconditions(const Descriptor& params, const input_t& input) const {
+        using msg = dal::detail::error_messages;
+
+        if (!input.get_data().has_data()) {
+            throw domain_error(msg::input_data_is_empty());
+        }
+    }
+
+    /// Asserts that the responses table has one column and as many rows as
+    /// the input data.
+    void check_postconditions(const Descriptor& params,
+                              const input_t& input,
+                              const result_t& result) const {
+        ONEDAL_ASSERT(result.get_responses().get_column_count() == 1);
+        ONEDAL_ASSERT(result.get_responses().get_row_count() == input.get_data().get_row_count());
+    }
+
+    template
+    auto operator()(const Context& ctx, const Descriptor& desc, const input_t& input) const {
+        check_preconditions(desc, input);
+        const auto result =
+            infer_ops_dispatcher()(ctx, desc, input);
+        check_postconditions(desc, input, result);
+        return result;
+    }
+};
+
+} // namespace v1
+
+using v1::infer_ops;
+
+} // namespace oneapi::dal::logistic_regression::detail
diff --git a/cpp/oneapi/dal/algo/logistic_regression/detail/infer_ops_dpc.cpp b/cpp/oneapi/dal/algo/logistic_regression/detail/infer_ops_dpc.cpp
new file mode 100644
index 00000000000..4f35d112f10
--- /dev/null
+++ b/cpp/oneapi/dal/algo/logistic_regression/detail/infer_ops_dpc.cpp
@@ -0,0 +1,44 @@
+/*******************************************************************************
+* Copyright 2023 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "oneapi/dal/algo/logistic_regression/backend/cpu/infer_kernel.hpp"
+#include "oneapi/dal/algo/logistic_regression/backend/gpu/infer_kernel.hpp"
+#include "oneapi/dal/algo/logistic_regression/detail/infer_ops.hpp"
+#include "oneapi/dal/backend/dispatcher.hpp"
+
+namespace oneapi::dal::logistic_regression::detail {
+namespace v1 {
+
+/// Inference dispatcher for the DPC++ build: the kernel dispatcher is given
+/// both the single-node CPU kernel and the single-node GPU kernel.
+template
+struct infer_ops_dispatcher {
+    infer_result operator()(const Policy& ctx,
+                            const descriptor_base& params,
+                            const infer_input& input) const {
+        using kernel_dispatcher_t = dal::backend::kernel_dispatcher<
+            KERNEL_SINGLE_NODE_CPU(backend::infer_kernel_cpu),
+            KERNEL_SINGLE_NODE_GPU(backend::infer_kernel_gpu)>;
+        return kernel_dispatcher_t{}(ctx, params, input);
+    }
+};
+
+// Explicit instantiations for float and double with the dense_batch method
+// and the classification task.
+#define INSTANTIATE(F, M, T) \
+    template struct ONEDAL_EXPORT infer_ops_dispatcher;
+
+INSTANTIATE(float, method::dense_batch, task::classification)
+INSTANTIATE(double, method::dense_batch, task::classification)
+
+} // namespace v1
+} // namespace oneapi::dal::logistic_regression::detail
diff --git a/cpp/oneapi/dal/algo/logistic_regression/detail/optimizer.cpp b/cpp/oneapi/dal/algo/logistic_regression/detail/optimizer.cpp
new file mode 100644
index 00000000000..514498aaf8d
--- /dev/null
+++ b/cpp/oneapi/dal/algo/logistic_regression/detail/optimizer.cpp
@@ -0,0 +1,95 @@
+/*******************************************************************************
+* Copyright 2023 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "oneapi/dal/algo/logistic_regression/detail/optimizer.hpp"
+#include "oneapi/dal/algo/logistic_regression/backend/optimizer_impl.hpp"
+
+#ifdef ONEDAL_DATA_PARALLEL
+#include "oneapi/dal/backend/primitives/optimizers.hpp"
+#endif
+
+namespace oneapi::dal::logistic_regression::detail {
+namespace v1 {
+
+template
+using newton_cg_optimizer_t = newton_cg::descriptor;
+
+namespace be = dal::backend;
+namespace pr = be::primitives;
+
+/// optimizer_impl for Newton-CG: stores the iteration limit and tolerance
+/// and, in ONEDAL_DATA_PARALLEL builds, forwards minimize() to
+/// pr::newton_cg.
+class newton_cg_optimizer_impl : public optimizer_impl {
+public:
+    newton_cg_optimizer_impl(std::int64_t max_iter, double tol) : max_iter_(max_iter), tol_(tol) {}
+
+    optimizer_type get_optimizer_type() override {
+        return optimizer_type::newton_cg;
+    }
+
+    double get_tol() override {
+        return tol_;
+    }
+
+    std::int64_t get_max_iter() override {
+        return max_iter_;
+    }
+
+#ifdef ONEDAL_DATA_PARALLEL
+    // Shared body of the minimize() overrides: the tolerance is stored as
+    // double and converted to Float at the call site.
+    template
+    std::pair minimize_impl(sycl::queue& q,
+                            pr::base_function& f,
+                            pr::ndview& x,
+                            const be::event_vector& deps = {}) {
+        return pr::newton_cg(q, f, x, Float(tol_), max_iter_, deps);
+    }
+
+    // NOTE(review): the two overrides below are textually identical in this
+    // copy of the patch; presumably they differ in template arguments that
+    // were lost - confirm against the upstream file.
+    std::pair minimize(sycl::queue& q,
+                       pr::base_function& f,
+                       pr::ndview& x,
+                       const be::event_vector& deps = {}) final {
+        return minimize_impl(q, f, x, deps);
+    }
+
+    std::pair minimize(sycl::queue& q,
+                       pr::base_function& f,
+                       pr::ndview& x,
+                       const be::event_vector& deps = {}) final {
+        return minimize_impl(q, f, x, deps);
+    }
+#endif
+
+private:
+    std::int64_t max_iter_;
+    double tol_;
+};
+
+// Keeps a copy of the user-facing descriptor and builds the backend object
+// from its iteration limit and tolerance.
+template
+optimizer>::optimizer(const newton_cg_optimizer_t& opt)
+        : optimizer_(opt),
+          impl_(new newton_cg_optimizer_impl{ opt.get_max_iteration(), opt.get_tolerance() }) {}
+
+template
+optimizer_impl* optimizer>::get_impl() const {
+    return impl_.get();
+}
+
+#define INSTANTIATE_NEWTON_CG(F, M) \
+    template class ONEDAL_EXPORT optimizer>;
+
+INSTANTIATE_NEWTON_CG(float, newton_cg::method::dense)
+INSTANTIATE_NEWTON_CG(double, newton_cg::method::dense)
+
+} // namespace v1
+} // namespace oneapi::dal::logistic_regression::detail
diff --git a/cpp/oneapi/dal/algo/logistic_regression/detail/optimizer.hpp b/cpp/oneapi/dal/algo/logistic_regression/detail/optimizer.hpp
new file mode 100644
index 00000000000..baa7d28bea4
--- /dev/null
+++ b/cpp/oneapi/dal/algo/logistic_regression/detail/optimizer.hpp
@@ -0,0 +1,85 @@
+/*******************************************************************************
+* Copyright 2023 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#pragma once
+
+#include <memory>
+#include <utility>
+
+#include "oneapi/dal/algo/newton_cg/common.hpp"
+
+namespace oneapi::dal::logistic_regression::detail {
+namespace v1 {
+
+class optimizer_impl;
+
+/// Type-erased handle to an optimizer; gives access to the backend
+/// implementation object.
+class optimizer_iface {
+public:
+    virtual ~optimizer_iface() = default;
+    virtual optimizer_impl* get_impl() const = 0;
+};
+
+using optimizer_ptr = std::shared_ptr;
+
+/// Generic wrapper around a user-facing optimizer descriptor. It keeps a
+/// copy of the descriptor and has no backend implementation: get_impl()
+/// returns nullptr.
+template
+class optimizer : public base, public optimizer_iface {
+public:
+    explicit optimizer(const Optimizer& optimizer) : optimizer_(optimizer) {}
+
+    optimizer_impl* get_impl() const override {
+        return nullptr;
+    }
+
+    const Optimizer& get_optimizer() const {
+        return optimizer_;
+    }
+
+private:
+    Optimizer optimizer_;
+    dal::detail::pimpl impl_;
+};
+
+/// Specialization for the Newton-CG descriptor; the constructor and
+/// get_impl() are only declared here.
+template
+class optimizer> : public base, public optimizer_iface {
+public:
+    using optimizer_t = newton_cg::descriptor;
+    explicit optimizer(const optimizer_t& opt);
+    optimizer_impl* get_impl() const override;
+
+private:
+    optimizer_t optimizer_;
+    dal::detail::pimpl impl_;
+};
+
+/// Helper that calls get_optimizer_impl() on the object it is given.
+struct optimizer_accessor {
+    template
+    const optimizer_ptr& get_optimizer_impl(Optimizer&& desc) const {
+        return desc.get_optimizer_impl();
+    }
+};
+
+/// Returns the backend optimizer stored in `desc`, or nullptr when the
+/// descriptor holds no optimizer handle.
+template
+optimizer_impl* get_optimizer_impl(Descriptor&& desc) {
+    const auto& optimizer = optimizer_accessor{}.get_optimizer_impl(std::forward(desc));
+    return optimizer ? optimizer->get_impl() : nullptr;
+}
+
+} // namespace v1
+
+using v1::optimizer_iface;
+using v1::optimizer_ptr;
+using v1::optimizer;
+using v1::optimizer_accessor;
+using v1::get_optimizer_impl;
+
+} // namespace oneapi::dal::logistic_regression::detail
diff --git a/cpp/oneapi/dal/algo/logistic_regression/detail/train_ops.cpp b/cpp/oneapi/dal/algo/logistic_regression/detail/train_ops.cpp
new file mode 100644
index 00000000000..ab8b385154e
--- /dev/null
+++ b/cpp/oneapi/dal/algo/logistic_regression/detail/train_ops.cpp
@@ -0,0 +1,69 @@
+/*******************************************************************************
+* Copyright 2023 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "oneapi/dal/algo/logistic_regression/parameters/cpu/train_parameters.hpp"
+#include "oneapi/dal/algo/logistic_regression/backend/cpu/train_kernel.hpp"
+#include "oneapi/dal/algo/logistic_regression/detail/train_ops.hpp"
+#include "oneapi/dal/backend/dispatcher.hpp"
+
+namespace oneapi::dal::logistic_regression::detail {
+namespace v1 {
+
+using dal::detail::host_policy;
+
+/// Training dispatcher for the host build: both parameter selection and
+/// training go through a kernel dispatcher that lists a single-node CPU
+/// kernel only.
+template
+struct train_ops_dispatcher {
+    /// Trains with caller-supplied parameters.
+    train_result operator()(const Policy& ctx,
+                            const descriptor_base& desc,
+                            const train_parameters& params,
+                            const train_input& input) const {
+        return implementation(ctx, desc, params, input);
+    }
+
+    /// Asks the parameter kernel for training parameters.
+    train_parameters select_parameters(const Policy& ctx,
+                                       const descriptor_base& desc,
+                                       const train_input& input) const {
+        using kernel_dispatcher_t = dal::backend::kernel_dispatcher< //
+            KERNEL_SINGLE_NODE_CPU(parameters::train_parameters_cpu)>;
+        return kernel_dispatcher_t{}(ctx, desc, input);
+    }
+
+    /// Trains with parameters obtained from select_parameters().
+    train_result operator()(const Policy& ctx,
+                            const descriptor_base& desc,
+                            const train_input& input) const {
+        const auto params = select_parameters(ctx, desc, input);
+        return implementation(ctx, desc, params, input);
+    }
+
+private:
+    // Common tail of both call operators.
+    inline auto implementation(const Policy& ctx,
+                               const descriptor_base& desc,
+                               const train_parameters& params,
+                               const train_input& input) const {
+        using kernel_dispatcher_t = dal::backend::kernel_dispatcher< //
+            KERNEL_SINGLE_NODE_CPU(backend::train_kernel_cpu)>;
+        return kernel_dispatcher_t{}(ctx, desc, params, input);
+    }
+};
+
+// Explicit instantiations for float and double with the dense_batch method
+// and the classification task.
+#define INSTANTIATE(F, M, T) \
+    template struct ONEDAL_EXPORT train_ops_dispatcher;
+
+INSTANTIATE(float, method::dense_batch, task::classification)
+INSTANTIATE(double, method::dense_batch, task::classification)
+
+} // namespace v1
+} // namespace oneapi::dal::logistic_regression::detail
diff --git a/cpp/oneapi/dal/algo/logistic_regression/detail/train_ops.hpp
b/cpp/oneapi/dal/algo/logistic_regression/detail/train_ops.hpp
new file mode 100644
index 00000000000..7de0f498f26
--- /dev/null
+++ b/cpp/oneapi/dal/algo/logistic_regression/detail/train_ops.hpp
@@ -0,0 +1,147 @@
+/*******************************************************************************
+* Copyright 2023 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#pragma once
+
+#include "oneapi/dal/algo/logistic_regression/train_types.hpp"
+#include "oneapi/dal/detail/error_messages.hpp"
+
+namespace oneapi::dal::logistic_regression::detail {
+namespace v1 {
+
+/// Declaration of the training dispatcher; the three members are only
+/// declared here and have no bodies in this header.
+template
+struct train_ops_dispatcher {
+    train_result operator()(const Context&,
+                            const descriptor_base&,
+                            const train_parameters&,
+                            const train_input&) const;
+    train_parameters select_parameters(const Context&,
+                                       const descriptor_base&,
+                                       const train_input&) const;
+    train_result operator()(const Context&,
+                            const descriptor_base&,
+                            const train_input&) const;
+};
+
+/// Training entry point for a descriptor: validates descriptor and input,
+/// selects parameters, runs the dispatcher and checks the result shapes.
+template
+struct train_ops {
+    using float_t = typename Descriptor::float_t;
+    using method_t = typename Descriptor::method_t;
+    using task_t = typename Descriptor::task_t;
+
+    using input_t = train_input;
+    using result_t = train_result;
+    using param_t = train_parameters;
+    using descriptor_base_t = descriptor_base;
+
+    /// Throws domain_error when the class count is not 2, when the inverse
+    /// regularization is not positive, when the data or responses are empty,
+    /// when their row counts differ, or when the responses table does not
+    /// have exactly one column.
+    void check_preconditions(const Descriptor& params, const input_t& input) const {
+        using msg = dal::detail::error_messages;
+
+        const auto& data = input.get_data();
+        const auto& responses = input.get_responses();
+
+        if (params.get_class_count() != 2) {
+            throw domain_error(msg::class_count_neq_two());
+        }
+        if (params.get_inverse_regularization() <= 0.0) {
+            throw domain_error(msg::inverse_regularization_leq_zero());
+        }
+
+        if (!data.has_data()) {
+            throw domain_error(msg::input_data_is_empty());
+        }
+        if (!responses.has_data()) {
+            throw domain_error(msg::input_responses_are_empty());
+        }
+        if (data.get_row_count() != responses.get_row_count()) {
+            throw domain_error(msg::input_data_rc_neq_input_responses_rc());
+        }
+
+        if (responses.get_column_count() != 1) {
+            throw domain_error(msg::input_responses_table_has_wrong_cc_expect_one());
+        }
+    }
+
+    /// Asserts the shapes of the returned tables. Coefficients and intercept
+    /// are inspected only when the matching result option is set; the packed
+    /// coefficients are always inspected.
+    void check_postconditions(const Descriptor& params,
+                              const input_t& input,
+                              const result_t& result) const {
+        const auto& res = params.get_result_options();
+
+        [[maybe_unused]] const std::int64_t f_count = //
+            input.get_data().get_column_count();
+        [[maybe_unused]] const std::int64_t r_count = //
+            input.get_responses().get_column_count();
+
+        ONEDAL_ASSERT(r_count == 1);
+
+        if (res.test(result_options::coefficients)) {
+            [[maybe_unused]] const table& coefficients = //
+                result.get_coefficients();
+            ONEDAL_ASSERT(coefficients.has_data());
+            ONEDAL_ASSERT(coefficients.get_row_count() == r_count);
+            ONEDAL_ASSERT(coefficients.get_column_count() == f_count);
+        }
+
+        if (res.test(result_options::intercept)) {
+            [[maybe_unused]] const table& intercept = //
+                result.get_intercept();
+
+            ONEDAL_ASSERT(intercept.has_data());
+            ONEDAL_ASSERT(intercept.get_row_count() == r_count);
+            ONEDAL_ASSERT(intercept.get_column_count() == 1);
+        }
+
+        {
+            [[maybe_unused]] const table& betas = //
+                result.get_packed_coefficients();
+
+            // One column more than the feature count.
+            ONEDAL_ASSERT(betas.has_data());
+            ONEDAL_ASSERT(betas.get_row_count() == r_count);
+            ONEDAL_ASSERT(betas.get_column_count() == f_count + 1);
+        }
+    }
+
+    /// Validates the input, then asks the dispatcher for training parameters.
+    template
+    auto select_parameters(const Context& ctx, const Descriptor& desc, const input_t& input) const {
+        check_preconditions(desc, input);
+        return train_ops_dispatcher{}.select_parameters(ctx,
+                                                        desc,
+                                                        input);
+    }
+
+    /// Trains with caller-supplied parameters and checks the result.
+    // NOTE(review): unlike the three-argument overload below, this path does
+    // not call check_preconditions() - confirm that callers which pass their
+    // own parameters are expected to have validated the input already.
+    template
+    auto operator()(const Context& ctx,
+                    const Descriptor& desc,
+                    const param_t& params,
+                    const input_t& input) const {
+        const auto result =
+            train_ops_dispatcher{}(ctx, desc, params, input);
+        check_postconditions(desc, input, result);
+        return result;
+    }
+
+    /// Selects parameters (which also validates the input) and trains.
+    template
+    auto operator()(const Context& ctx, const Descriptor& desc, const input_t& input) const {
+        const auto params = select_parameters(ctx, desc, input);
+        return this->operator()(ctx, desc, params, input);
+    }
+};
+
+} // namespace v1
+
+using v1::train_ops;
+
+} // namespace oneapi::dal::logistic_regression::detail
diff --git a/cpp/oneapi/dal/algo/logistic_regression/detail/train_ops_dpc.cpp b/cpp/oneapi/dal/algo/logistic_regression/detail/train_ops_dpc.cpp
new file mode 100644
index 00000000000..52c426b00f2
--- /dev/null
+++ b/cpp/oneapi/dal/algo/logistic_regression/detail/train_ops_dpc.cpp
@@ -0,0 +1,71 @@
+/*******************************************************************************
+* Copyright 2023 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "oneapi/dal/algo/logistic_regression/parameters/cpu/train_parameters.hpp"
+#include "oneapi/dal/algo/logistic_regression/parameters/gpu/train_parameters.hpp"
+#include "oneapi/dal/algo/logistic_regression/backend/cpu/train_kernel.hpp"
+#include "oneapi/dal/algo/logistic_regression/backend/gpu/train_kernel.hpp"
+#include "oneapi/dal/algo/logistic_regression/detail/train_ops.hpp"
+#include "oneapi/dal/backend/dispatcher.hpp"
+
+namespace oneapi::dal::logistic_regression::detail {
+namespace v1 {
+
+/// Training dispatcher for the DPC++ build: parameter selection and training
+/// each go through a kernel dispatcher that lists a single-node CPU kernel
+/// and a single-node GPU kernel.
+template
+struct train_ops_dispatcher {
+    /// Trains with caller-supplied parameters.
+    train_result operator()(const Policy& ctx,
+                            const descriptor_base& desc,
+                            const train_parameters& params,
+                            const train_input& input) const {
+        return implementation(ctx, desc, params, input);
+    }
+
+    /// Asks the parameter kernels for training parameters.
+    train_parameters select_parameters(const Policy& ctx,
+                                       const descriptor_base& desc,
+                                       const train_input& input) const {
+        using kernel_dispatcher_t = dal::backend::kernel_dispatcher<
+            KERNEL_SINGLE_NODE_CPU(parameters::train_parameters_cpu),
+            KERNEL_SINGLE_NODE_GPU(parameters::train_parameters_gpu)>;
+        return kernel_dispatcher_t{}(ctx, desc, input);
+    }
+
+    /// Trains with parameters obtained from select_parameters().
+    train_result operator()(const Policy& ctx,
+                            const descriptor_base& desc,
+                            const train_input& input) const {
+        const auto params = select_parameters(ctx, desc, input);
+        return implementation(ctx, desc, params, input);
+    }
+
+private:
+    // Common tail of both call operators.
+    inline auto implementation(const Policy& ctx,
+                               const descriptor_base& desc,
+                               const train_parameters& params,
+                               const train_input& input) const {
+        using kernel_dispatcher_t = dal::backend::kernel_dispatcher<
+            KERNEL_SINGLE_NODE_CPU(backend::train_kernel_cpu),
+            KERNEL_SINGLE_NODE_GPU(backend::train_kernel_gpu)>;
+        return kernel_dispatcher_t{}(ctx, desc, params, input);
+    }
+};
+
+// Explicit instantiations for float and double with the dense_batch method
+// and the classification task.
+#define INSTANTIATE(F, M, T) \
+    template struct ONEDAL_EXPORT train_ops_dispatcher;
+
+INSTANTIATE(float, method::dense_batch, task::classification)
+INSTANTIATE(double, method::dense_batch, task::classification)
+
+} // namespace v1
+} // namespace oneapi::dal::logistic_regression::detail
diff --git a/cpp/oneapi/dal/algo/logistic_regression/infer.hpp b/cpp/oneapi/dal/algo/logistic_regression/infer.hpp
new file mode 100644
index 00000000000..e8352316f91
--- /dev/null
+++ b/cpp/oneapi/dal/algo/logistic_regression/infer.hpp
@@ -0,0 +1,32 @@
+/*******************************************************************************
+* Copyright 2023 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/ + +#pragma once + +#include "oneapi/dal/algo/logistic_regression/common.hpp" +#include "oneapi/dal/algo/logistic_regression/detail/infer_ops.hpp" +#include "oneapi/dal/algo/logistic_regression/infer_types.hpp" +#include "oneapi/dal/infer.hpp" + +namespace oneapi::dal::detail { +namespace v1 { + +template +struct infer_ops + : dal::logistic_regression::detail::infer_ops {}; + +} // namespace v1 +} // namespace oneapi::dal::detail diff --git a/cpp/oneapi/dal/algo/logistic_regression/infer_types.cpp b/cpp/oneapi/dal/algo/logistic_regression/infer_types.cpp new file mode 100644 index 00000000000..04f98c23872 --- /dev/null +++ b/cpp/oneapi/dal/algo/logistic_regression/infer_types.cpp @@ -0,0 +1,95 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "oneapi/dal/algo/logistic_regression/infer_types.hpp" +#include "oneapi/dal/detail/common.hpp" +#include "oneapi/dal/exceptions.hpp" + +namespace oneapi::dal::logistic_regression { + +template +class detail::v1::infer_input_impl : public base { +public: + infer_input_impl(const table& data, const model& m) : data(data), trained_model(m) {} + + table data; + model trained_model; +}; + +template +class detail::v1::infer_result_impl : public base { +public: + table responses; + table probabilities; +}; + +using detail::v1::infer_input_impl; +using detail::v1::infer_result_impl; + +namespace v1 { + +template +infer_input::infer_input(const table& data, const model& m) + : impl_(new infer_input_impl(data, m)) {} + +template +const table& infer_input::get_data() const { + return impl_->data; +} + +template +const model& infer_input::get_model() const { + return impl_->trained_model; +} + +template +void infer_input::set_data_impl(const table& value) { + impl_->data = value; +} + +template +void infer_input::set_model_impl(const model& value) { + impl_->trained_model = value; +} + +template +infer_result::infer_result() : impl_(new infer_result_impl{}) {} + +template +const table& infer_result::get_responses() const { + return impl_->responses; +} + +template +const table& infer_result::get_probabilities() const { + return impl_->probabilities; +} + +template +void infer_result::set_responses_impl(const table& value) { + impl_->responses = value; +} + +template +void infer_result::set_probabilities_impl(const table& value) { + impl_->probabilities = value; +} + +template class ONEDAL_EXPORT infer_input; +template class ONEDAL_EXPORT infer_result; + +} // namespace v1 +} // namespace oneapi::dal::logistic_regression diff --git a/cpp/oneapi/dal/algo/logistic_regression/infer_types.hpp b/cpp/oneapi/dal/algo/logistic_regression/infer_types.hpp new file mode 100644 index 
00000000000..eee2bebb651 --- /dev/null +++ b/cpp/oneapi/dal/algo/logistic_regression/infer_types.hpp @@ -0,0 +1,121 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#pragma once + +#include "oneapi/dal/algo/logistic_regression/common.hpp" + +namespace oneapi::dal::logistic_regression { + +namespace detail { +namespace v1 { +template +class infer_input_impl; + +template +class infer_result_impl; +} // namespace v1 + +using v1::infer_input_impl; +using v1::infer_result_impl; + +} // namespace detail + +namespace v1 { + +/// @tparam Task Tag-type that specifies type of the problem to solve. Can +/// be :expr:`task::classification`.
+template +class infer_input : public base { + static_assert(detail::is_valid_task_v); + +public: + using task_t = Task; + + /// Creates a new instance of the class with the given :literal:`model` + /// and :literal:`data` property values + infer_input(const table& data, const model& model); + + /// The dataset for inference $X'$ + /// @remark default = table{} + const table& get_data() const; + + auto& set_data(const table& data) { + set_data_impl(data); + return *this; + } + + /// The trained logistic regression model + /// @remark default = model{} + const model& get_model() const; + + auto& set_model(const model& m) { + set_model_impl(m); + return *this; + } + +protected: + void set_data_impl(const table& data); + void set_model_impl(const model& model); + +private: + dal::detail::pimpl> impl_; +}; + +/// @tparam Task Tag-type that specifies type of the problem to solve. Can +/// be :expr:`task::classification`. +template +class infer_result { + static_assert(detail::is_valid_task_v); + +public: + using task_t = Task; + + /// Creates a new instance of the class with the default property values.
+ infer_result(); + + /// The predicted responses + /// @remark default = table{} + const table& get_responses() const; + + /// The predicted probabilities + /// @remark default = table{} + const table& get_probabilities() const; + + auto& set_responses(const table& value) { + set_responses_impl(value); + return *this; + } + + auto& set_probabilities(const table& value) { + set_probabilities_impl(value); + return *this; + } + +protected: + void set_responses_impl(const table&); + void set_probabilities_impl(const table&); + +private: + dal::detail::pimpl> impl_; +}; + +} // namespace v1 + +using v1::infer_input; +using v1::infer_result; + +} // namespace oneapi::dal::logistic_regression diff --git a/cpp/oneapi/dal/algo/logistic_regression/parameters/BUILD b/cpp/oneapi/dal/algo/logistic_regression/parameters/BUILD new file mode 100644 index 00000000000..8c17a868e22 --- /dev/null +++ b/cpp/oneapi/dal/algo/logistic_regression/parameters/BUILD @@ -0,0 +1,13 @@ +package(default_visibility = ["//visibility:public"]) +load("@onedal//dev/bazel:dal.bzl", + "dal_module", + "dal_test_suite", +) + +dal_module( + name = "parameters", + dal_deps = [ + "@onedal//cpp/oneapi/dal/algo/logistic_regression/parameters/cpu", + "@onedal//cpp/oneapi/dal/algo/logistic_regression/parameters/gpu", + ], +) diff --git a/cpp/oneapi/dal/algo/logistic_regression/parameters/cpu/BUILD b/cpp/oneapi/dal/algo/logistic_regression/parameters/cpu/BUILD new file mode 100644 index 00000000000..252452637ce --- /dev/null +++ b/cpp/oneapi/dal/algo/logistic_regression/parameters/cpu/BUILD @@ -0,0 +1,13 @@ +package(default_visibility = ["//visibility:public"]) +load("@onedal//dev/bazel:dal.bzl", + "dal_module", + "dal_test_suite", +) + +dal_module( + name = "cpu", + auto = True, + dal_deps = [ + "@onedal//cpp/oneapi/dal/algo/logistic_regression:core", + ], +) diff --git a/cpp/oneapi/dal/algo/logistic_regression/parameters/cpu/train_parameters.cpp
b/cpp/oneapi/dal/algo/logistic_regression/parameters/cpu/train_parameters.cpp new file mode 100644 index 00000000000..96e7d8e0ddb --- /dev/null +++ b/cpp/oneapi/dal/algo/logistic_regression/parameters/cpu/train_parameters.cpp @@ -0,0 +1,63 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "oneapi/dal/detail/common.hpp" + +#include "oneapi/dal/backend/dispatcher.hpp" +#include "oneapi/dal/table/row_accessor.hpp" + +#include "oneapi/dal/algo/logistic_regression/common.hpp" +#include "oneapi/dal/algo/logistic_regression/train_types.hpp" + +#include "oneapi/dal/algo/logistic_regression/parameters/cpu/train_parameters.hpp" + +namespace oneapi::dal::logistic_regression::parameters { + +using dal::backend::context_cpu; + +template +std::int64_t propose_block_size(const std::int64_t f, const std::int64_t r) { + constexpr std::int64_t fsize = sizeof(Float); + std::int64_t proposal = 0x100l * (8 / fsize); + return std::max(128l, proposal); +} + +template +struct train_parameters_cpu { + using params_t = detail::train_parameters; + params_t operator()(const context_cpu& ctx, + const detail::descriptor_base& desc, + const train_input& input) const { + const auto& x_train = input.get_data(); + const auto& y_train = input.get_responses(); + + const auto f_count 
= x_train.get_column_count(); + const auto r_count = y_train.get_column_count(); + + const auto block = propose_block_size(f_count, r_count); + + return params_t{}.set_cpu_macro_block(block); + } +}; + +template struct ONEDAL_EXPORT + train_parameters_cpu; +template struct ONEDAL_EXPORT + train_parameters_cpu; + +} // namespace oneapi::dal::logistic_regression::parameters diff --git a/cpp/oneapi/dal/algo/logistic_regression/parameters/cpu/train_parameters.hpp b/cpp/oneapi/dal/algo/logistic_regression/parameters/cpu/train_parameters.hpp new file mode 100644 index 00000000000..84c3708f0f8 --- /dev/null +++ b/cpp/oneapi/dal/algo/logistic_regression/parameters/cpu/train_parameters.hpp @@ -0,0 +1,32 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#pragma once + +#include "oneapi/dal/algo/logistic_regression/train_types.hpp" +#include "oneapi/dal/backend/dispatcher.hpp" + +namespace oneapi::dal::logistic_regression::parameters { + +template +struct ONEDAL_EXPORT train_parameters_cpu { + using params_t = detail::train_parameters; + params_t operator()(const dal::backend::context_cpu& ctx, + const detail::descriptor_base& desc, + const train_input& input) const; +}; + +} // namespace oneapi::dal::logistic_regression::parameters diff --git a/cpp/oneapi/dal/algo/logistic_regression/parameters/gpu/BUILD b/cpp/oneapi/dal/algo/logistic_regression/parameters/gpu/BUILD new file mode 100644 index 00000000000..12109440cee --- /dev/null +++ b/cpp/oneapi/dal/algo/logistic_regression/parameters/gpu/BUILD @@ -0,0 +1,13 @@ +package(default_visibility = ["//visibility:public"]) +load("@onedal//dev/bazel:dal.bzl", + "dal_module", + "dal_test_suite", +) + +dal_module( + name = "gpu", + auto = True, + dal_deps = [ + "@onedal//cpp/oneapi/dal/algo/logistic_regression:core", + ], +) diff --git a/cpp/oneapi/dal/algo/logistic_regression/parameters/gpu/train_parameters.hpp b/cpp/oneapi/dal/algo/logistic_regression/parameters/gpu/train_parameters.hpp new file mode 100644 index 00000000000..28442ac0cdf --- /dev/null +++ b/cpp/oneapi/dal/algo/logistic_regression/parameters/gpu/train_parameters.hpp @@ -0,0 +1,32 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#pragma once + +#include "oneapi/dal/algo/logistic_regression/train_types.hpp" +#include "oneapi/dal/backend/dispatcher.hpp" + +namespace oneapi::dal::logistic_regression::parameters { + +template +struct ONEDAL_EXPORT train_parameters_gpu { + using params_t = detail::train_parameters; + params_t operator()(const dal::backend::context_gpu& ctx, + const detail::descriptor_base& desc, + const train_input& input) const; +}; + +} // namespace oneapi::dal::logistic_regression::parameters diff --git a/cpp/oneapi/dal/algo/logistic_regression/parameters/gpu/train_parameters_dpc.cpp b/cpp/oneapi/dal/algo/logistic_regression/parameters/gpu/train_parameters_dpc.cpp new file mode 100644 index 00000000000..f2537443c68 --- /dev/null +++ b/cpp/oneapi/dal/algo/logistic_regression/parameters/gpu/train_parameters_dpc.cpp @@ -0,0 +1,63 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "oneapi/dal/detail/common.hpp" + +#include "oneapi/dal/backend/common.hpp" +#include "oneapi/dal/backend/dispatcher.hpp" +#include "oneapi/dal/table/row_accessor.hpp" + +#include "oneapi/dal/algo/logistic_regression/common.hpp" +#include "oneapi/dal/algo/logistic_regression/train_types.hpp" + +#include "oneapi/dal/algo/logistic_regression/parameters/gpu/train_parameters.hpp" + +namespace oneapi::dal::logistic_regression::parameters { + +using dal::backend::context_gpu; + +template +std::int64_t propose_block_size(const sycl::queue& q, const std::int64_t f, const std::int64_t r) { + constexpr std::int64_t fsize = sizeof(Float); + return 0x10000l * (8 / fsize); +} + +template +struct train_parameters_gpu { + using params_t = detail::train_parameters; + params_t operator()(const context_gpu& ctx, + const detail::descriptor_base& desc, + const train_input& input) const { + const auto& queue = ctx.get_queue(); + + const auto& x_train = input.get_data(); + const auto& y_train = input.get_responses(); + + const auto f_count = x_train.get_column_count(); + const auto r_count = y_train.get_column_count(); + + const auto block = propose_block_size(queue, f_count, r_count); + + return params_t{}.set_gpu_macro_block(block); + } +}; + +template struct ONEDAL_EXPORT + train_parameters_gpu; +template struct ONEDAL_EXPORT + train_parameters_gpu; + +} // namespace oneapi::dal::logistic_regression::parameters diff --git a/cpp/oneapi/dal/algo/logistic_regression/test/batch_dpc.cpp b/cpp/oneapi/dal/algo/logistic_regression/test/batch_dpc.cpp new file mode 100644 index 00000000000..37880bc0197 --- /dev/null +++ b/cpp/oneapi/dal/algo/logistic_regression/test/batch_dpc.cpp @@ -0,0 +1,56 @@ +/******************************************************************************* +* Copyright 2023 
Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "oneapi/dal/algo/logistic_regression/test/fixture.hpp" + +#include "oneapi/dal/test/engine/tables.hpp" +#include "oneapi/dal/test/engine/io.hpp" + +namespace oneapi::dal::logistic_regression::test { + +namespace te = dal::test::engine; +namespace de = dal::detail; +namespace la = te::linalg; + +template +class log_reg_batch_test : public log_reg_test> { +public: + using base_t = log_reg_test>; + using float_t = typename base_t::float_t; + using train_input_t = typename base_t::train_input_t; + using train_result_t = typename base_t::train_result_t; +}; + +TEMPLATE_LIST_TEST_M(log_reg_batch_test, "LR common flow", "[lr][batch]", lr_types) { + SKIP_IF(this->not_float64_friendly()); + SKIP_IF(this->get_policy().is_cpu()); + this->gen_input(true, 0.5); + + this->run_test(); +} + +TEMPLATE_LIST_TEST_M(log_reg_batch_test, + "LR common flow - no fit intercept", + "[lr][batch]", + lr_types) { + SKIP_IF(this->not_float64_friendly()); + SKIP_IF(this->get_policy().is_cpu()); + this->gen_input(false, 0.5); + + this->run_test(); +} + +} // namespace oneapi::dal::logistic_regression::test diff --git a/cpp/oneapi/dal/algo/logistic_regression/test/fixture.hpp b/cpp/oneapi/dal/algo/logistic_regression/test/fixture.hpp new file mode 100644 index 00000000000..9a2a2c3e7bf --- /dev/null +++ 
b/cpp/oneapi/dal/algo/logistic_regression/test/fixture.hpp @@ -0,0 +1,208 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include + +#include "oneapi/dal/algo/logistic_regression/common.hpp" +#include "oneapi/dal/algo/logistic_regression/train.hpp" +#include "oneapi/dal/algo/logistic_regression/infer.hpp" + +#include "oneapi/dal/table/homogen.hpp" +#include "oneapi/dal/table/row_accessor.hpp" +#include "oneapi/dal/table/detail/table_builder.hpp" + +#include "oneapi/dal/test/engine/fixtures.hpp" +#include "oneapi/dal/test/engine/math.hpp" + +namespace oneapi::dal::logistic_regression::test { + +namespace te = dal::test::engine; +namespace de = dal::detail; +namespace la = te::linalg; + +template +class log_reg_test : public te::crtp_algo_fixture { +public: + using float_t = std::tuple_element_t<0, TestType>; + using method_t = std::tuple_element_t<1, TestType>; + using task_t = std::tuple_element_t<2, TestType>; + + using train_input_t = train_input; + using train_result_t = train_result; + using test_input_t = infer_input; + using test_result_t = infer_result; + + te::table_id get_homogen_table_id() const { + return te::table_id::homogen(); + } + + Derived* get_impl() { + return static_cast(this); + } + + auto get_descriptor() const { + result_option_id 
resopts = result_options::coefficients; + if (this->fit_intercept_) + resopts = resopts | result_options::intercept; + return logistic_regression::descriptor(fit_intercept_, C_) + .set_result_options(resopts); + } + + void gen_dimensions(std::int64_t n = -1, std::int64_t p = -1) { + if (n == -1 || p == -1) { + this->n_ = GENERATE(100, 200, 1000, 10000, 50000); + this->p_ = GENERATE(10, 20, 30); + } + else { + this->n_ = n; + this->p_ = p; + } + } + + float_t predict_proba(float_t* ptr, float_t* params_ptr, float_t intercept) { + float_t val = 0; + for (std::int64_t j = 0; j < p_; ++j) { + val += ptr[j] * params_ptr[j]; + } + val += intercept; + return float_t(1) / (1 + std::exp(-val)); + } + + void gen_input(bool fit_intercept = true, double C = 1.0, std::int64_t seed = 2007) { + this->get_impl()->gen_dimensions(); + + this->fit_intercept_ = fit_intercept; + this->C_ = C; + + std::int64_t dim = fit_intercept_ ? p_ + 1 : p_; + + X_host_ = array::zeros(n_ * p_); + auto* x_ptr = X_host_.get_mutable_data(); + + y_host_ = array::zeros(n_); + auto* y_ptr = y_host_.get_mutable_data(); + + params_host_ = array::zeros(dim); + auto* params_ptr = params_host_.get_mutable_data(); + + std::mt19937 rnd(seed + n_ + p_); + std::uniform_real_distribution<> dis_data(-10.0, 10.0); + std::uniform_real_distribution<> dis_params(-3.0, 3.0); + + for (std::int64_t i = 0; i < n_; ++i) { + for (std::int64_t j = 0; j < p_; ++j) { + *(x_ptr + i * p_ + j) = dis_data(rnd); + } + } + + for (std::int64_t i = 0; i < dim; ++i) { + *(params_ptr + i) = dis_params(rnd); + } + + constexpr float_t half = 0.5; + for (std::int64_t i = 0; i < n_; ++i) { + float_t val = predict_proba(x_ptr + i * p_, + params_ptr + (std::int64_t)fit_intercept_, + fit_intercept_ ? 
*params_ptr : 0); + y_ptr[i] = bool(val < half); + } + } + + void run_test() { + std::int64_t train_size = n_ * 0.7; + std::int64_t test_size = n_ - train_size; + + table X_train = homogen_table::wrap(X_host_.get_mutable_data(), train_size, p_); + table X_test = homogen_table::wrap(X_host_.get_mutable_data() + train_size * p_, + test_size, + p_); + table y_train = + homogen_table::wrap(y_host_.get_mutable_data(), train_size, 1); + + const auto desc = this->get_descriptor(); + const auto train_res = this->train(desc, X_train, y_train); + table intercept; + array bias_host; + if (fit_intercept_) { + intercept = train_res.get_intercept(); + bias_host = row_accessor(intercept).pull({ 0, -1 }); + } + table coefs = train_res.get_coefficients(); + auto coefs_host = row_accessor(coefs).pull({ 0, -1 }); + + std::int64_t train_acc = 0; + std::int64_t test_acc = 0; + + const auto infer_res = this->infer(desc, X_test, train_res.get_model()); + + table resp_table = infer_res.get_responses(); + auto resp_host = row_accessor(resp_table).pull({ 0, -1 }); + + table prob_table = infer_res.get_probabilities(); + auto prob_host = row_accessor(prob_table).pull({ 0, -1 }); + + for (std::int64_t i = 0; i < n_; ++i) { + float_t val = predict_proba(X_host_.get_mutable_data() + i * p_, + coefs_host.get_mutable_data(), + fit_intercept_ ? 
*bias_host.get_mutable_data() : 0); + std::int32_t resp = 0; + if (val >= 0.5) { + resp = 1; + } + if (resp == *(y_host_.get_mutable_data() + i)) { + bool is_train = i < train_size; + train_acc += std::int64_t(is_train); + test_acc += std::int64_t(!is_train); + } + if (i >= train_size) { + REQUIRE(abs(val - *(prob_host.get_mutable_data() + i - train_size)) < 1e-5); + REQUIRE(*(resp_host.get_mutable_data() + i - train_size) == resp); + } + } + std::int64_t acc_algo = 0; + for (std::int64_t i = 0; i < test_size; ++i) { + if (*(resp_host.get_mutable_data() + i) == + *(y_host_.get_mutable_data() + train_size + i)) { + acc_algo++; + } + } + + float_t min_train_acc = 0.95; + float_t min_test_acc = n_ < 500 ? 0.7 : 0.85; + + REQUIRE(train_size * min_train_acc < train_acc); + REQUIRE(test_size * min_test_acc < test_acc); + REQUIRE(test_size * min_test_acc < acc_algo); + REQUIRE(test_acc == acc_algo); + } + +protected: + bool fit_intercept_ = true; + double C_ = 1.0; + std::int64_t n_ = 0; + std::int64_t p_ = 0; + array X_host_; + array params_host_; + array y_host_; + array resp_; +}; + +using lr_types = COMBINE_TYPES((double), + (logistic_regression::method::dense_batch), + (logistic_regression::task::classification)); + +} // namespace oneapi::dal::logistic_regression::test diff --git a/cpp/oneapi/dal/algo/logistic_regression/train.hpp b/cpp/oneapi/dal/algo/logistic_regression/train.hpp new file mode 100644 index 00000000000..c71f22657f6 --- /dev/null +++ b/cpp/oneapi/dal/algo/logistic_regression/train.hpp @@ -0,0 +1,32 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#pragma once + +#include "oneapi/dal/algo/logistic_regression/common.hpp" +#include "oneapi/dal/algo/logistic_regression/detail/train_ops.hpp" +#include "oneapi/dal/algo/logistic_regression/train_types.hpp" +#include "oneapi/dal/train.hpp" + +namespace oneapi::dal::detail { +namespace v1 { + +template +struct train_ops + : dal::logistic_regression::detail::train_ops {}; + +} // namespace v1 +} // namespace oneapi::dal::detail diff --git a/cpp/oneapi/dal/algo/logistic_regression/train_types.cpp b/cpp/oneapi/dal/algo/logistic_regression/train_types.cpp new file mode 100644 index 00000000000..a0f88b94ad9 --- /dev/null +++ b/cpp/oneapi/dal/algo/logistic_regression/train_types.cpp @@ -0,0 +1,202 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/
+
+#include "oneapi/dal/algo/logistic_regression/train_types.hpp"
+#include "oneapi/dal/detail/common.hpp"
+#include "oneapi/dal/exceptions.hpp"
+
+namespace oneapi::dal::logistic_regression {
+
+namespace detail::v1 {
+
+template
+class train_input_impl : public base { // state behind train_input: the two input tables
+public:
+    train_input_impl(const table& data, const table& responses = table{})
+            : data(data),
+              responses(responses) {}
+
+    table data;
+    table responses;
+};
+
+template
+class train_result_impl : public base { // state behind train_result
+public:
+    table intercept;
+    table coefficients;
+    std::int64_t iter_cnt; // value reported by train_result::get_iterations_count()
+
+    result_option_id options; // checked by the intercept/coefficients/iterations accessors below
+
+    model trained_model; // packed coefficients are stored inside the model, not here
+};
+
+template
+struct train_parameters_impl : public base { // state behind train_parameters, with its defaults
+    std::int64_t cpu_macro_block = 8'192l; // NOTE(review): unit of the block size is not visible here
+    std::int64_t gpu_macro_block = 16'384l;
+};
+
+template
+train_parameters::train_parameters() : impl_(new train_parameters_impl{}) {}
+
+template
+std::int64_t train_parameters::get_cpu_macro_block() const {
+    return impl_->cpu_macro_block;
+}
+
+template
+void train_parameters::set_cpu_macro_block_impl(std::int64_t val) { // stores val as is, no range check
+    impl_->cpu_macro_block = val;
+}
+
+template
+std::int64_t train_parameters::get_gpu_macro_block() const {
+    return impl_->gpu_macro_block;
+}
+
+template
+void train_parameters::set_gpu_macro_block_impl(std::int64_t val) { // stores val as is, no range check
+    impl_->gpu_macro_block = val;
+}
+
+template class ONEDAL_EXPORT train_parameters; // explicit instantiation, exported via ONEDAL_EXPORT
+
+} // namespace detail::v1
+
+using detail::v1::train_input_impl;
+using detail::v1::train_result_impl;
+using detail::v1::train_parameters;
+
+namespace v1 {
+
+template
+train_input::train_input(const table& data, const table& responses)
+        : impl_(new train_input_impl(data, responses)) {}
+
+template
+const table& train_input::get_data() const {
+    return impl_->data;
+}
+
+template
+const table& train_input::get_responses() const {
+    return impl_->responses;
+}
+
+template
+void train_input::set_data_impl(const table& value) {
+    impl_->data = value;
+}
+
+template
+void train_input::set_responses_impl(const table& value) {
+    impl_->responses = value;
+}
+
+template
+train_result::train_result() : impl_(new train_result_impl{}) {}
+
+template
+const model& train_result::get_model() const { // not gated by result options
+    return impl_->trained_model;
+}
+
+template
+void train_result::set_model_impl(const model& value) {
+    impl_->trained_model = value;
+}
+
+template
+const table& train_result::get_intercept() const { // throws domain_error unless result_options::intercept is set
+    using msg = dal::detail::error_messages;
+    if (!get_result_options().test(result_options::intercept)) {
+        throw domain_error(msg::this_result_is_not_enabled_via_result_options());
+    }
+    return impl_->intercept;
+}
+
+template
+void train_result::set_intercept_impl(const table& value) { // same gate as the getter: set the options first
+    using msg = dal::detail::error_messages;
+    if (!get_result_options().test(result_options::intercept)) {
+        throw domain_error(msg::this_result_is_not_enabled_via_result_options());
+    }
+    impl_->intercept = value;
+}
+
+template
+std::int64_t train_result::get_iterations_count() const { // throws unless result_options::iterations_count is set
+    using msg = dal::detail::error_messages;
+    if (!get_result_options().test(result_options::iterations_count)) {
+        throw domain_error(msg::this_result_is_not_enabled_via_result_options());
+    }
+    return impl_->iter_cnt;
+}
+
+template
+void train_result::set_iterations_count_impl(std::int64_t value) { // same gate as the getter
+    using msg = dal::detail::error_messages;
+    if (!get_result_options().test(result_options::iterations_count)) {
+        throw domain_error(msg::this_result_is_not_enabled_via_result_options());
+    }
+    impl_->iter_cnt = value;
+}
+
+template
+const table& train_result::get_coefficients() const { // throws unless result_options::coefficients is set
+    using msg = dal::detail::error_messages;
+    if (!get_result_options().test(result_options::coefficients)) {
+        throw domain_error(msg::this_result_is_not_enabled_via_result_options());
+    }
+    return impl_->coefficients;
+}
+
+template
+void train_result::set_coefficients_impl(const table& value) { // same gate as the getter
+    using msg = dal::detail::error_messages;
+    if (!get_result_options().test(result_options::coefficients)) {
+        throw domain_error(msg::this_result_is_not_enabled_via_result_options());
+    }
+    impl_->coefficients = value;
+}
+
+template
+const table& train_result::get_packed_coefficients() const { // forwarded to the model; no result-option check
+    return impl_->trained_model.get_packed_coefficients();
+}
+
+template
+void train_result::set_packed_coefficients_impl(const table& value) { // forwarded to the model
+    impl_->trained_model.set_packed_coefficients(value);
+}
+
+template
+const result_option_id& train_result::get_result_options() const {
+    return impl_->options;
+}
+
+template
+void train_result::set_result_options_impl(const result_option_id& value) {
+    impl_->options = value;
+}
+
+template class ONEDAL_EXPORT train_result; // explicit instantiations, exported via ONEDAL_EXPORT
+template class ONEDAL_EXPORT train_input;
+
+} // namespace v1
+} // namespace oneapi::dal::logistic_regression
diff --git a/cpp/oneapi/dal/algo/logistic_regression/train_types.hpp b/cpp/oneapi/dal/algo/logistic_regression/train_types.hpp
new file mode 100644
index 00000000000..1fc289466b7
--- /dev/null
+++ b/cpp/oneapi/dal/algo/logistic_regression/train_types.hpp
@@ -0,0 +1,190 @@
+/*******************************************************************************
+* Copyright 2023 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#pragma once
+
+#include "oneapi/dal/algo/logistic_regression/common.hpp"
+
+namespace oneapi::dal::logistic_regression {
+
+namespace detail {
+namespace v1 {
+template
+class train_input_impl; // forward declarations of the pimpl payload types
+
+template
+class train_result_impl;
+
+template
+struct train_parameters_impl;
+
+template
+class train_parameters : public base { // CPU/GPU macro-block sizes, stored behind a pimpl
+public:
+    explicit train_parameters();
+    train_parameters(train_parameters&&) = default; // NOTE(review): declaring this deletes the implicit copy assignment
+    train_parameters(const train_parameters&) = default;
+
+    std::int64_t get_cpu_macro_block() const;
+    auto& set_cpu_macro_block(std::int64_t val) { // returns *this so calls can be chained
+        set_cpu_macro_block_impl(val);
+        return *this;
+    }
+
+    std::int64_t get_gpu_macro_block() const;
+    auto& set_gpu_macro_block(std::int64_t val) { // returns *this so calls can be chained
+        set_gpu_macro_block_impl(val);
+        return *this;
+    }
+
+private:
+    void set_cpu_macro_block_impl(std::int64_t val);
+    void set_gpu_macro_block_impl(std::int64_t val);
+    dal::detail::pimpl> impl_;
+};
+
+} // namespace v1
+
+using v1::train_parameters;
+using v1::train_input_impl;
+using v1::train_result_impl;
+
+} // namespace detail
+
+namespace v1 {
+
+/// @tparam Task Tag-type that specifies type of the problem to solve. Can
+///              be :expr:`task::classification`.
+template
+class train_input : public base {
+    static_assert(detail::is_valid_task_v);
+
+public:
+    using task_t = Task;
+
+    /// Creates a new instance of the class with the given :literal:`data`
+    /// and :literal:`responses` property values.
+    train_input(const table& data, const table& responses);
+
+    //train_input(const table& data);
+
+    /// The training set X
+    /// @remark default = table{}
+    const table& get_data() const;
+
+    auto& set_data(const table& data) { // returns *this so calls can be chained
+        set_data_impl(data);
+        return *this;
+    }
+
+    /// Vector of responses y for the training set X
+    /// @remark default = table{}
+    const table& get_responses() const;
+
+    auto& set_responses(const table& responses) { // must store into responses, not into data
+        set_responses_impl(responses);
+        return *this;
+    }
+
+protected:
+    void set_data_impl(const table& data);
+    void set_responses_impl(const table& responses);
+
+private:
+    dal::detail::pimpl> impl_;
+};
+
+/// @tparam Task Tag-type that specifies type of the problem to solve. Can
+///              be :expr:`task::classification`.
+template
+class train_result {
+    static_assert(detail::is_valid_task_v);
+
+public:
+    using task_t = Task;
+
+    /// Creates a new instance of the class with the default property values.
+    train_result();
+
+    /// The trained Logistic Regression model
+    /// @remark default = model{}
+    const model& get_model() const;
+
+    auto& set_model(const model& value) {
+        set_model_impl(value);
+        return *this;
+    }
+
+    /// Table of Logistic regression intercept (throws unless result_options::intercept is enabled)
+    const table& get_intercept() const;
+
+    auto& set_intercept(const table& value) { // throws unless the intercept option is already enabled
+        set_intercept_impl(value);
+        return *this;
+    }
+
+    /// Table of Logistic regression coefficients (throws unless result_options::coefficients is enabled)
+    const table& get_coefficients() const;
+
+    auto& set_coefficients(const table& value) { // throws unless the coefficients option is already enabled
+        set_coefficients_impl(value);
+        return *this;
+    }
+
+    /// Actual number of optimizer iterations (throws unless result_options::iterations_count is enabled)
+    std::int64_t get_iterations_count() const;
+
+    auto& set_iterations_count(std::int64_t value) { // throws unless the iterations_count option is enabled
+        set_iterations_count_impl(value);
+        return *this;
+    }
+
+    /// Table of Logistic regression coefficients with intercept, as stored in the trained model
+    const table& get_packed_coefficients() const;
+
+    auto& set_packed_coefficients(const table& value) { // written into the trained model
+        set_packed_coefficients_impl(value);
+        return *this;
+    }
+
+    /// Result options that indicate availability of the properties
+    const result_option_id& get_result_options() const;
+
+    auto& set_result_options(const result_option_id& value) { // call before the gated setters above
+        set_result_options_impl(value);
+        return *this;
+    }
+
+protected:
+    void set_model_impl(const model&);
+
+    void set_intercept_impl(const table&);
+    void set_coefficients_impl(const table&);
+    void set_packed_coefficients_impl(const table&);
+    void set_iterations_count_impl(std::int64_t);
+
+    void set_result_options_impl(const result_option_id&);
+
+private:
+    dal::detail::pimpl> impl_;
+};
+
+} // namespace v1
+
+using v1::train_input;
+using v1::train_result;
+
+} // namespace oneapi::dal::logistic_regression
diff --git a/cpp/oneapi/dal/algo/newton_cg/BUILD b/cpp/oneapi/dal/algo/newton_cg/BUILD
new file mode 100644
index 00000000000..5279414758c
--- /dev/null
+++ b/cpp/oneapi/dal/algo/newton_cg/BUILD
@@ -0,0 +1,13 @@
+package(default_visibility = ["//visibility:public"])
+load("@onedal//dev/bazel:dal.bzl",
+    "dal_module",
+    "dal_test_suite", # NOTE(review): loaded but not used in this BUILD file
+)
+
+dal_module( # the newton_cg algorithm module; its only declared dependency is the oneDAL core target
+    name = "newton_cg",
+    auto = True,
+    dal_deps = [
+        "@onedal//cpp/oneapi/dal:core",
+    ]
+)
diff --git a/cpp/oneapi/dal/algo/newton_cg/common.cpp b/cpp/oneapi/dal/algo/newton_cg/common.cpp
new file mode 100644
index 00000000000..248245fcbc1
--- /dev/null
+++ b/cpp/oneapi/dal/algo/newton_cg/common.cpp
@@ -0,0 +1,69 @@
+/*******************************************************************************
+* Copyright 2023 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "oneapi/dal/algo/newton_cg/common.hpp"
+#include "oneapi/dal/exceptions.hpp"
+
+namespace oneapi::dal::newton_cg::detail {
+
+namespace v1 {
+
+template
+class descriptor_impl : public base { // state behind descriptor_base: tolerance and iteration cap
+public:
+    explicit descriptor_impl(double tol = 1e-4, std::int64_t maxiter = 100)
+            : tol(tol),
+              maxiter(maxiter) {}
+    double tol;
+    std::int64_t maxiter;
+};
+
+template
+descriptor_base::descriptor_base() : impl_(new descriptor_impl{}) {} // starts from tol = 1e-4, maxiter = 100
+
+template
+double descriptor_base::get_tolerance() const {
+    return impl_->tol;
+}
+
+template
+std::int64_t descriptor_base::get_max_iteration() const {
+    return impl_->maxiter;
+}
+
+template
+void descriptor_base::set_tolerance_impl(double tol) { // throws domain_error for a negative tolerance
+    using msg = dal::detail::error_messages;
+    if (tol < 0) { // NOTE(review): a NaN tolerance is not rejected by this comparison
+        throw domain_error(msg::conv_tol_lt_zero());
+    }
+    impl_->tol = tol;
+}
+
+template
+void descriptor_base::set_max_iteration_impl(std::int64_t maxiter) { // throws domain_error when maxiter < 0; zero is accepted
+    using msg = dal::detail::error_messages;
+    if (maxiter < 0) {
+        throw domain_error(msg::max_iteration_count_lt_zero());
+    }
+    impl_->maxiter = maxiter;
+}
+
+template class ONEDAL_EXPORT descriptor_base; // explicit instantiation, exported via ONEDAL_EXPORT
+
+} // namespace v1
+
+} // namespace oneapi::dal::newton_cg::detail
diff --git a/cpp/oneapi/dal/algo/newton_cg/common.hpp b/cpp/oneapi/dal/algo/newton_cg/common.hpp
new file mode 100644
index 00000000000..ec7f1156644
--- /dev/null
+++ b/cpp/oneapi/dal/algo/newton_cg/common.hpp
@@ -0,0 +1,162 @@
+/*******************************************************************************
+* Copyright 2023 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#pragma once
+
+#include "oneapi/dal/detail/common.hpp"
+#include "oneapi/dal/detail/serialization.hpp"
+#include "oneapi/dal/table/common.hpp"
+#include "oneapi/dal/common.hpp"
+
+namespace oneapi::dal::newton_cg {
+
+namespace task {
+
+namespace v1 {
+
+struct compute {}; // the only task tag declared for this optimizer
+using by_default = compute;
+
+} // namespace v1
+
+using v1::compute;
+using v1::by_default;
+
+} // namespace task
+
+namespace method {
+namespace v1 {
+struct dense {}; // the only method tag declared for this optimizer
+using by_default = dense;
+
+} // namespace v1
+
+using v1::dense;
+using v1::by_default;
+} // namespace method
+
+namespace detail {
+
+namespace v1 {
+
+struct descriptor_tag {};
+
+template
+class descriptor_impl; // pimpl payload of descriptor_base, forward-declared only
+
+template
+constexpr bool is_valid_float_v = dal::detail::is_one_of_v;
+
+template
+constexpr bool is_valid_method_v = dal::detail::is_one_of_v;
+
+template
+constexpr bool is_valid_task_v = dal::detail::is_one_of_v;
+
+template
+class descriptor_base : public base { // tolerance / max-iteration storage shared by all descriptors
+    static_assert(is_valid_task_v);
+
+public:
+    using tag_t = descriptor_tag;
+    using float_t = float; // base-level default; descriptor re-declares float_t
+    using method_t = method::by_default;
+    using task_t = Task;
+
+    descriptor_base();
+
+    double get_tolerance() const;
+    std::int64_t get_max_iteration() const;
+
+protected:
+    void set_tolerance_impl(double tol); // throws domain_error if tol < 0
+    void set_max_iteration_impl(std::int64_t maxiter); // throws domain_error if maxiter < 0
+
+private:
+    dal::detail::pimpl> impl_;
+};
+
+} // namespace v1
+
+using v1::descriptor_tag;
+using v1::descriptor_impl;
+using v1::descriptor_base;
+
+using v1::is_valid_float_v;
+using v1::is_valid_method_v;
+using v1::is_valid_task_v;
+
+} // namespace detail
+
+namespace v1 {
+
+/// @tparam Float  The floating-point type that the algorithm uses for
+///                intermediate computations. Can be :expr:`float` or
+///                :expr:`double`.
+/// @tparam Method Tag-type that specifies an implementation of algorithm. Can
+///                be :expr:`method::dense`.
+/// @tparam Task   Tag-type that specifies the type of the problem to solve. Can
+///                be :expr:`task::compute`.
+template
+class descriptor : public detail::descriptor_base {
+    static_assert(detail::is_valid_float_v);
+    static_assert(detail::is_valid_method_v);
+    static_assert(detail::is_valid_task_v);
+
+    using base_t = detail::descriptor_base;
+
+public:
+    using float_t = Float;
+    using method_t = Method;
+    using task_t = Task;
+
+    /// Creates a new instance of the class with the given :literal:`tol` and
+    /// :literal:`maxiter` property values; both are validated by the setters.
+    explicit descriptor(double tol = 1e-4, std::int64_t maxiter = 100) {
+        set_tolerance(tol);
+        set_max_iteration(maxiter);
+    }
+
+    /// The convergence tolerance
+    /// @invariant :expr:`tol >= 0.0`
+    double get_tolerance() const {
+        return base_t::get_tolerance();
+    }
+
+    auto& set_tolerance(double tol) { // returns *this so calls can be chained
+        base_t::set_tolerance_impl(tol);
+        return *this;
+    }
+
+    /// The maximum iteration number
+    /// @invariant :expr:`maxiter >= 0`
+    std::int64_t get_max_iteration() const {
+        return base_t::get_max_iteration();
+    }
+
+    auto& set_max_iteration(std::int64_t maxiter) { // returns *this so calls can be chained
+        base_t::set_max_iteration_impl(maxiter);
+        return *this;
+    }
+};
+
+} // namespace v1
+
+using v1::descriptor;
+
+} // namespace oneapi::dal::newton_cg
diff --git a/cpp/oneapi/dal/backend/primitives/objective_function/logloss.hpp b/cpp/oneapi/dal/backend/primitives/objective_function/logloss.hpp
index 43a17dd5684..8d99cba8da8 100644
--- a/cpp/oneapi/dal/backend/primitives/objective_function/logloss.hpp
+++
b/cpp/oneapi/dal/backend/primitives/objective_function/logloss.hpp @@ -104,13 +104,13 @@ sycl::event compute_raw_hessian(sycl::queue& q, const event_vector& deps = {}); template -class LogLossHessianProduct : public BaseMatrixOperator { +class logloss_hessian_product : public base_matrix_operator { public: - LogLossHessianProduct(sycl::queue& q, - const table& data, - Float L2 = Float(0), - bool fit_intercept = true, - std::int64_t bsz = -1); + logloss_hessian_product(sycl::queue& q, + const table& data, + Float L2 = Float(0), + bool fit_intercept = true, + std::int64_t bsz = -1); sycl::event operator()(const ndview& vec, ndview& out, const event_vector& deps) final; @@ -130,23 +130,24 @@ class LogLossHessianProduct : public BaseMatrixOperator { bool fit_intercept_; ndarray raw_hessian_; ndarray buffer_; + ndarray tmp_gpu_; const std::int64_t n_; const std::int64_t p_; const std::int64_t bsz_; }; template -class LogLossFunction : public BaseFunction { +class logloss_function : public base_function { public: - LogLossFunction(sycl::queue queue, - const table& data, - ndview& labels, - Float L2 = 0.0, - bool fit_intercept = true, - std::int64_t bsz = -1); + logloss_function(sycl::queue queue, + const table& data, + const ndview& labels, + Float L2 = 0.0, + bool fit_intercept = true, + std::int64_t bsz = -1); Float get_value() final; ndview& get_gradient() final; - BaseMatrixOperator& get_hessian_product() final; + base_matrix_operator& get_hessian_product() final; event_vector update_x(const ndview& x, bool need_hessp = false, @@ -155,7 +156,7 @@ class LogLossFunction : public BaseFunction { private: sycl::queue q_; const table data_; - ndview labels_; + const ndview labels_; const std::int64_t n_; const std::int64_t p_; Float L2_; @@ -164,7 +165,7 @@ class LogLossFunction : public BaseFunction { ndarray probabilities_; ndarray gradient_; ndarray buffer_; - LogLossHessianProduct hessp_; + logloss_hessian_product hessp_; const std::int64_t dimension_; Float value_; }; 
diff --git a/cpp/oneapi/dal/backend/primitives/objective_function/logloss_dpc.cpp b/cpp/oneapi/dal/backend/primitives/objective_function/logloss_dpc.cpp index ef73aa0a107..d9c113fd312 100644 --- a/cpp/oneapi/dal/backend/primitives/objective_function/logloss_dpc.cpp +++ b/cpp/oneapi/dal/backend/primitives/objective_function/logloss_dpc.cpp @@ -45,7 +45,11 @@ sycl::event compute_probabilities(sycl::queue& q, Float w0 = fit_intercept ? parameters.get_slice(0, 1).at_device(q, 0l) : 0; // Poor perfomance ndview param_suf = fit_intercept ? parameters.get_slice(1, p + 1) : parameters; - auto event = gemv(q, data, param_suf, probabilities, Float(1), w0, { fill_event }); + sycl::event gemv_event; + { + gemv_event = gemv(q, data, param_suf, probabilities, Float(1), w0, { fill_event }); + gemv_event.wait_and_throw(); + } auto* const prob_ptr = probabilities.get_mutable_data(); const Float bottom = sizeof(Float) == 4 ? 1e-7 : 1e-15; @@ -53,7 +57,7 @@ sycl::event compute_probabilities(sycl::queue& q, // Log Loss is undefined for p = 0 and p = 1 so probabilities are clipped into [eps, 1 - eps] return q.submit([&](sycl::handler& cgh) { - cgh.depends_on(event); + cgh.depends_on(gemv_event); const auto range = make_range_1d(n); cgh.parallel_for(range, [=](sycl::id<1> idx) { prob_ptr[idx] = 1 / (1 + sycl::exp(-prob_ptr[idx])); @@ -179,8 +183,13 @@ sycl::event compute_logloss_with_der(sycl::queue& q, } auto out_der_suffix = fit_intercept ? out_derivative.get_slice(1, p + 1) : out_derivative; - - return gemv(q, data.t(), derivative_object, out_der_suffix, { loss_event, derw0_event }); + sycl::event gemv_event; + { + gemv_event = + gemv(q, data.t(), derivative_object, out_der_suffix, { loss_event, derw0_event }); + gemv_event.wait_and_throw(); + } + return gemv_event; } template @@ -248,8 +257,11 @@ sycl::event compute_derivative(sycl::queue& q, auto out_der_suffix = fit_intercept ? 
out_derivative.get_slice(1, p + 1) : out_derivative; - auto der_event = gemv(q, data.t(), derivative_object, out_der_suffix, { loss_event }); - + sycl::event der_event; + { + der_event = gemv(q, data.t(), derivative_object, out_der_suffix, { loss_event }); + der_event.wait_and_throw(); + } return der_event; } @@ -456,11 +468,11 @@ std::int64_t get_block_size(std::int64_t n, std::int64_t p) { } template -LogLossHessianProduct::LogLossHessianProduct(sycl::queue& q, - const table& data, - Float L2, - bool fit_intercept, - std::int64_t bsz) +logloss_hessian_product::logloss_hessian_product(sycl::queue& q, + const table& data, + Float L2, + bool fit_intercept, + std::int64_t bsz) : q_(q), data_(data), L2_(L2), @@ -470,52 +482,71 @@ LogLossHessianProduct::LogLossHessianProduct(sycl::queue& q, bsz_(bsz == -1 ? get_block_size(n_, p_) : bsz) { raw_hessian_ = ndarray::empty(q_, { n_ }, sycl::usm::alloc::device); buffer_ = ndarray::empty(q_, { n_ }, sycl::usm::alloc::device); + tmp_gpu_ = ndarray::empty(q_, { p_ + 1 }, sycl::usm::alloc::device); } template -ndview& LogLossHessianProduct::get_raw_hessian() { +ndview& logloss_hessian_product::get_raw_hessian() { return raw_hessian_; } template -sycl::event LogLossHessianProduct::compute_with_fit_intercept(const ndview& vec, - ndview& out, - const event_vector& deps) { - auto* const buffer_ptr = buffer_.get_mutable_data(); - const auto* const hess_ptr = raw_hessian_.get_data(); - auto* const out_ptr = out.get_mutable_data(); +sycl::event logloss_hessian_product::compute_with_fit_intercept(const ndview& vec, + ndview& out, + const event_vector& deps) { + auto* const tmp_ptr = tmp_gpu_.get_mutable_data(); ONEDAL_ASSERT(vec.get_dimension(0) == p_ + 1); ONEDAL_ASSERT(out.get_dimension(0) == p_ + 1); auto fill_buffer_event = fill(q_, buffer_, Float(1), deps); auto out_suf = out.get_slice(1, p_ + 1); + auto tmp_suf = tmp_gpu_.slice(1, p_); auto out_bias = out.get_slice(0, 1); auto vec_suf = vec.get_slice(1, p_ + 1); + ndview 
tmp_ndview = tmp_gpu_; sycl::event fill_out_event = fill(q_, out, Float(0), deps); Float v0 = vec.at_device(q_, 0, deps); - // TODO: Add batch matrix-vector multiplication - auto data_nd = table2ndarray(q_, data_, sycl::usm::alloc::device); - - sycl::event event_xv = gemv(q_, data_nd, vec_suf, buffer_, Float(1), v0, { fill_buffer_event }); - event_xv.wait_and_throw(); // Without this line gemv does not work correctly + const uniform_blocking blocking(n_, bsz_); - auto tmp_host = buffer_.to_host(q_); + row_accessor data_accessor(data_); + event_vector last_iter_deps = { fill_buffer_event, fill_out_event }; - sycl::event event_dxv = q_.submit([&](sycl::handler& cgh) { - cgh.depends_on({ event_xv, fill_out_event }); - const auto range = make_range_1d(n_); - auto sum_reduction = sycl::reduction(out_ptr, sycl::plus<>()); - cgh.parallel_for(range, sum_reduction, [=](sycl::id<1> idx, auto& sum_v0) { - buffer_ptr[idx] = buffer_ptr[idx] * hess_ptr[idx]; - sum_v0 += buffer_ptr[idx]; + for (std::int64_t b = 0; b < blocking.get_block_count(); ++b) { + const auto last = blocking.get_block_end_index(b); + const auto first = blocking.get_block_start_index(b); + const auto length = last - first; + auto x_rows = data_accessor.pull(q_, { first, last }, sycl::usm::alloc::device); + auto x_nd = pr::ndarray::wrap(x_rows, { length, p_ }); + auto buffer_batch = buffer_.slice(first, length); + sycl::event event_xv = gemv(q_, x_nd, vec_suf, buffer_batch, Float(1), v0, last_iter_deps); + event_xv.wait_and_throw(); // Without this line gemv does not work correctly + + auto* const buffer_ptr = buffer_batch.get_mutable_data(); + const auto* const hess_ptr = raw_hessian_.get_data() + first; + + auto fill_tmp_event = fill(q_, tmp_gpu_, Float(0), last_iter_deps); + + sycl::event event_dxv = q_.submit([&](sycl::handler& cgh) { + cgh.depends_on({ event_xv, fill_tmp_event }); + const auto range = make_range_1d(length); + auto sum_reduction = sycl::reduction(tmp_ptr, sycl::plus<>()); + 
cgh.parallel_for(range, sum_reduction, [=](sycl::id<1> idx, auto& sum_v0) { + buffer_ptr[idx] = buffer_ptr[idx] * hess_ptr[idx]; + sum_v0 += buffer_ptr[idx]; + }); }); - }); - sycl::event event_xtdxv = - gemv(q_, data_nd.t(), buffer_, out_suf, Float(1), Float(0), { event_dxv, fill_out_event }); - event_xtdxv.wait_and_throw(); // Without this line gemv does not work correctly + sycl::event event_xtdxv = + gemv(q_, x_nd.t(), buffer_batch, tmp_suf, Float(1), Float(0), { event_dxv }); + event_xtdxv.wait_and_throw(); // Without this line gemv does not work correctly + + sycl::event update_grad_e = + element_wise(q_, sycl::plus<>(), out, tmp_ndview, out, { event_xtdxv }); + + last_iter_deps = { update_grad_e }; + } const Float regularization_factor = L2_; @@ -524,34 +555,60 @@ sycl::event LogLossHessianProduct::compute_with_fit_intercept(const ndvie }; auto add_regularization_event = - element_wise(q_, kernel_regularization, out_suf, vec_suf, out_suf, { event_xtdxv }); + element_wise(q_, kernel_regularization, out_suf, vec_suf, out_suf, last_iter_deps); return add_regularization_event; } template -sycl::event LogLossHessianProduct::compute_without_fit_intercept(const ndview& vec, - ndview& out, - const event_vector& deps) { +sycl::event logloss_hessian_product::compute_without_fit_intercept( + const ndview& vec, + ndview& out, + const event_vector& deps) { ONEDAL_ASSERT(vec.get_dimension(0) == p_); ONEDAL_ASSERT(out.get_dimension(0) == p_); sycl::event fill_out_event = fill(q_, out, Float(0), deps); - // TODO: Add batch matrix-vector multiplication - auto data_nd = table2ndarray(q_, data_, sycl::usm::alloc::device); + const uniform_blocking blocking(n_, bsz_); - sycl::event event_xv = gemv(q_, data_nd, vec, buffer_, Float(1), Float(0), deps); - event_xv.wait_and_throw(); // Without this line gemv does not work correctly + ndview tmp_ndview = tmp_gpu_.slice(0, p_); - auto& buf_ndview = static_cast&>(buffer_); - auto& hess_ndview = static_cast&>(raw_hessian_); - constexpr 
sycl::multiplies kernel_mul{}; - auto event_dxv = - element_wise(q_, kernel_mul, buf_ndview, hess_ndview, buf_ndview, { event_xv }); + row_accessor data_accessor(data_); + event_vector last_iter_deps = { fill_out_event }; - sycl::event event_xtdxv = - gemv(q_, data_nd.t(), buffer_, out, Float(1), Float(0), { event_dxv, fill_out_event }); - event_xtdxv.wait_and_throw(); // Without this line gemv does not work correctly + for (std::int64_t b = 0; b < blocking.get_block_count(); ++b) { + const auto last = blocking.get_block_end_index(b); + const auto first = blocking.get_block_start_index(b); + const auto length = last - first; + ONEDAL_ASSERT(0l < length); + auto x_rows = data_accessor.pull(q_, { first, last }, sycl::usm::alloc::device); + auto x_nd = pr::ndarray::wrap(x_rows, { length, p_ }); + ndview buffer_batch = buffer_.slice(first, length); + ndview hess_batch = raw_hessian_.slice(first, length); + + sycl::event event_xv = + gemv(q_, x_nd, vec, buffer_batch, Float(1), Float(0), last_iter_deps); + event_xv.wait_and_throw(); // Without this line gemv does not work correctly + + constexpr sycl::multiplies kernel_mul{}; + auto event_dxv = + element_wise(q_, kernel_mul, buffer_batch, hess_batch, buffer_batch, { event_xv }); + + auto fill_tmp_event = fill(q_, tmp_ndview, Float(0), last_iter_deps); + + sycl::event event_xtdxv = gemv(q_, + x_nd.t(), + buffer_batch, + tmp_ndview, + Float(1), + Float(0), + { event_dxv, fill_tmp_event }); + event_xtdxv.wait_and_throw(); // Without this line gemv does not work correctly + + sycl::event update_grad_e = + element_wise(q_, sycl::plus<>(), out, tmp_ndview, out, { event_xtdxv }); + last_iter_deps = { update_grad_e }; + } const Float regularization_factor = L2_; @@ -560,15 +617,15 @@ sycl::event LogLossHessianProduct::compute_without_fit_intercept(const nd }; auto add_regularization_event = - element_wise(q_, kernel_regularization, out, vec, out, { event_xtdxv }); + element_wise(q_, kernel_regularization, out, vec, out, 
last_iter_deps); return add_regularization_event; } template -sycl::event LogLossHessianProduct::operator()(const ndview& vec, - ndview& out, - const event_vector& deps) { +sycl::event logloss_hessian_product::operator()(const ndview& vec, + ndview& out, + const event_vector& deps) { if (fit_intercept_) { return compute_with_fit_intercept(vec, out, deps); } @@ -578,12 +635,12 @@ sycl::event LogLossHessianProduct::operator()(const ndview& vec } template -LogLossFunction::LogLossFunction(sycl::queue q, - const table& data, - ndview& labels, - Float L2, - bool fit_intercept, - std::int64_t bsz) +logloss_function::logloss_function(sycl::queue q, + const table& data, + const ndview& labels, + Float L2, + bool fit_intercept, + std::int64_t bsz) : q_(q), data_(data), labels_(labels), @@ -601,9 +658,9 @@ LogLossFunction::LogLossFunction(sycl::queue q, } template -event_vector LogLossFunction::update_x(const ndview& x, - bool need_hessp, - const event_vector& deps) { +event_vector logloss_function::update_x(const ndview& x, + bool need_hessp, + const event_vector& deps) { using dal::backend::operator+; value_ = 0; auto fill_event = fill(q_, gradient_, Float(0), deps); @@ -621,6 +678,7 @@ event_vector LogLossFunction::update_x(const ndview& x, const auto first = blocking.get_block_start_index(b); const auto last = blocking.get_block_end_index(b); const std::int64_t cursize = last - first; + ONEDAL_ASSERT(0l < cursize); const auto data_rows = row_accessor(data_).pull(q_, { first, last }, sycl::usm::alloc::device); @@ -689,16 +747,16 @@ event_vector LogLossFunction::update_x(const ndview& x, } template -Float LogLossFunction::get_value() { +Float logloss_function::get_value() { return value_; } template -ndview& LogLossFunction::get_gradient() { +ndview& logloss_function::get_gradient() { return gradient_; } template -BaseMatrixOperator& LogLossFunction::get_hessian_product() { +base_matrix_operator& logloss_function::get_hessian_product() { return hessp_; } @@ -765,8 +823,8 
@@ BaseMatrixOperator& LogLossFunction::get_hessian_product() { const ndview&, \ ndview&, \ const event_vector&); \ - template class LogLossHessianProduct; \ - template class LogLossFunction; + template class logloss_hessian_product; \ + template class logloss_function; INSTANTIATE(float); INSTANTIATE(double); diff --git a/cpp/oneapi/dal/backend/primitives/objective_function/test/logloss_dpc.cpp b/cpp/oneapi/dal/backend/primitives/objective_function/test/logloss_dpc.cpp index b983e109232..826490455cc 100644 --- a/cpp/oneapi/dal/backend/primitives/objective_function/test/logloss_dpc.cpp +++ b/cpp/oneapi/dal/backend/primitives/objective_function/test/logloss_dpc.cpp @@ -179,8 +179,8 @@ class logloss_test : public te::float_algo_fixture { fit_intercept, { logloss_event }); logloss_reg_event.wait_and_throw(); - const float_t val_logloss1 = out_logloss.to_host(this->get_queue(), {}).at(0); + const float_t val_logloss1 = out_logloss.to_host(this->get_queue(), {}).at(0); check_val(val_logloss1, logloss, rtol, atol); auto fill_event = fill(this->get_queue(), out_logloss, float_t(0), {}); @@ -206,8 +206,8 @@ class logloss_test : public te::float_algo_fixture { auto out_derivative_host = out_derivative.to_host(this->get_queue()); const float_t val_logloss2 = out_logloss.to_host(this->get_queue(), {}).at(0); - check_val(val_logloss2, logloss, rtol, atol); + auto [out_derivative2, out_der_e2] = ndarray::zeros(this->get_queue(), { dim }, sycl::usm::alloc::device); auto der_event = compute_derivative(this->get_queue(), @@ -256,7 +256,6 @@ class logloss_test : public te::float_algo_fixture { fit_intercept, rtol, atol); - test_formula_hessian(data_host, predictions_host, hessian_host, @@ -270,26 +269,24 @@ class logloss_test : public te::float_algo_fixture { if (batch_test) { bsz = GENERATE(4, 8, 16, 20, 37, 512); } - - // LogLossFunction has different regularization so we need to multiply it by 2 to allign with other implementations - auto functor = 
LogLossFunction(this->get_queue(), - data_, - labels_gpu, - L2 * 2, - fit_intercept, - bsz); + // logloss_function has different regularization so we need to multiply it by 2 to align with other implementations + auto functor = logloss_function(this->get_queue(), + data_, + labels_gpu, + L2 * 2, + fit_intercept, + bsz); auto set_point_event = functor.update_x(params_gpu, true, {}); wait_or_pass(set_point_event).wait_and_throw(); check_val(logloss, functor.get_value(), rtol, atol); auto grad_func = functor.get_gradient(); auto grad_func_host = grad_func.to_host(this->get_queue()); - std::int64_t dim = fit_intercept ? p + 1 : p; for (std::int64_t i = 0; i < dim; ++i) { check_val(out_derivative_host.at(i), grad_func_host.at(i), rtol, atol); } - BaseMatrixOperator& hessp = functor.get_hessian_product(); + base_matrix_operator& hessp = functor.get_hessian_product(); test_hessian_product(hessian_host, hessp, fit_intercept, L2, rtol, atol); } } @@ -465,7 +462,7 @@ class logloss_test : public te::float_algo_fixture { } void test_hessian_product(const ndview& hessian_host, - BaseMatrixOperator& hessp, + base_matrix_operator& hessp, bool fit_intercept, double L2, const float_t rtol = 1e-3, diff --git a/cpp/oneapi/dal/backend/primitives/optimizers.hpp b/cpp/oneapi/dal/backend/primitives/optimizers.hpp index 570a6e41326..eac664683b4 100644 --- a/cpp/oneapi/dal/backend/primitives/optimizers.hpp +++ b/cpp/oneapi/dal/backend/primitives/optimizers.hpp @@ -16,6 +16,6 @@ #pragma once -#include "oneapi/dal/backend/primitives/newton_cg/cg_solver.hpp" -#include "oneapi/dal/backend/primitives/newton_cg/newton_cg.hpp" -#include "oneapi/dal/backend/primitives/newton_cg/line_search.hpp" +#include "oneapi/dal/backend/primitives/optimizers/cg_solver.hpp" +#include "oneapi/dal/backend/primitives/optimizers/newton_cg.hpp" +#include "oneapi/dal/backend/primitives/optimizers/line_search.hpp" diff --git a/cpp/oneapi/dal/backend/primitives/optimizers/BUILD 
b/cpp/oneapi/dal/backend/primitives/optimizers/BUILD index 3a0cf153916..81ab27df598 100644 --- a/cpp/oneapi/dal/backend/primitives/optimizers/BUILD +++ b/cpp/oneapi/dal/backend/primitives/optimizers/BUILD @@ -29,5 +29,6 @@ dal_test_suite( dal_deps = [ ":optimizers", "@onedal//cpp/oneapi/dal/backend/primitives:rng", + "@onedal//cpp/oneapi/dal/backend/primitives:objective_function", ], ) diff --git a/cpp/oneapi/dal/backend/primitives/optimizers/cg_solver.hpp b/cpp/oneapi/dal/backend/primitives/optimizers/cg_solver.hpp index d854fa092fb..d1681ac69fc 100644 --- a/cpp/oneapi/dal/backend/primitives/optimizers/cg_solver.hpp +++ b/cpp/oneapi/dal/backend/primitives/optimizers/cg_solver.hpp @@ -25,7 +25,7 @@ namespace oneapi::dal::backend::primitives { // https://nvlpubs.nist.gov/nistpubs/jres/049/jresv49n6p409_a1b.pdf template sycl::event cg_solve(sycl::queue& queue, - BaseMatrixOperator& mul_operator, + base_matrix_operator& mul_operator, const ndview& b, ndview& x, ndview& residual, diff --git a/cpp/oneapi/dal/backend/primitives/optimizers/cg_solver_dpc.cpp b/cpp/oneapi/dal/backend/primitives/optimizers/cg_solver_dpc.cpp index c9e996709e1..0eb1fb799cc 100644 --- a/cpp/oneapi/dal/backend/primitives/optimizers/cg_solver_dpc.cpp +++ b/cpp/oneapi/dal/backend/primitives/optimizers/cg_solver_dpc.cpp @@ -24,7 +24,7 @@ namespace oneapi::dal::backend::primitives { template sycl::event cg_solve(sycl::queue& queue, - BaseMatrixOperator& mul_operator, + base_matrix_operator& mul_operator, const ndview& b, ndview& x, ndview& residual, @@ -143,20 +143,21 @@ sycl::event cg_solve(sycl::queue& queue, conj_vector, { update_x_event, update_residual_event }); // p_i+1 = -r_i+1 + beta * p_i } + return compute_conj_event; } -#define INSTANTIATE(F) \ - template sycl::event cg_solve(sycl::queue&, \ - BaseMatrixOperator&, \ - const ndview&, \ - ndview&, \ - ndview&, \ - ndview&, \ - ndview&, \ - F, \ - F, \ - std::int64_t, \ +#define INSTANTIATE(F) \ + template sycl::event cg_solve(sycl::queue&, 
\ + base_matrix_operator&, \ + const ndview&, \ + ndview&, \ + ndview&, \ + ndview&, \ + ndview&, \ + F, \ + F, \ + std::int64_t, \ const event_vector&); INSTANTIATE(float); diff --git a/cpp/oneapi/dal/backend/primitives/optimizers/common.hpp b/cpp/oneapi/dal/backend/primitives/optimizers/common.hpp index 2164bb9c5d8..7b24282c805 100644 --- a/cpp/oneapi/dal/backend/primitives/optimizers/common.hpp +++ b/cpp/oneapi/dal/backend/primitives/optimizers/common.hpp @@ -17,8 +17,6 @@ #pragma once #include "oneapi/dal/backend/primitives/ndarray.hpp" -#include "oneapi/dal/backend/primitives/blas/gemv.hpp" -#include "oneapi/dal/backend/primitives/element_wise.hpp" namespace oneapi::dal::backend::primitives { @@ -38,18 +36,18 @@ sycl::event dot_product(sycl::queue& queue, const event_vector& deps = {}); template -class BaseMatrixOperator { +class base_matrix_operator { public: - virtual ~BaseMatrixOperator() {} + virtual ~base_matrix_operator() {} virtual sycl::event operator()(const ndview& vec, ndview& out, const event_vector& deps = {}) = 0; }; template -class LinearMatrixOperator : public BaseMatrixOperator { +class linear_matrix_operator : public base_matrix_operator { public: - LinearMatrixOperator(sycl::queue& q, const ndview& A); + linear_matrix_operator(sycl::queue& q, const ndview& A); sycl::event operator()(const ndview& vec, ndview& out, const event_vector& deps) final; @@ -60,12 +58,12 @@ class LinearMatrixOperator : public BaseMatrixOperator { }; template -class BaseFunction { +class base_function { public: - virtual ~BaseFunction() {} + virtual ~base_function() {} virtual Float get_value() = 0; virtual ndview& get_gradient() = 0; - virtual BaseMatrixOperator& get_hessian_product() = 0; + virtual base_matrix_operator& get_hessian_product() = 0; virtual event_vector update_x(const ndview& x, bool need_hessp = false, const event_vector& deps = {}) = 0; diff --git a/cpp/oneapi/dal/backend/primitives/optimizers/common_dpc.cpp 
b/cpp/oneapi/dal/backend/primitives/optimizers/common_dpc.cpp index 32010c1bae5..82dc9c05fd1 100644 --- a/cpp/oneapi/dal/backend/primitives/optimizers/common_dpc.cpp +++ b/cpp/oneapi/dal/backend/primitives/optimizers/common_dpc.cpp @@ -21,15 +21,15 @@ namespace oneapi::dal::backend::primitives { template -LinearMatrixOperator::LinearMatrixOperator(sycl::queue& q, const ndview& A) - : BaseMatrixOperator(), +linear_matrix_operator::linear_matrix_operator(sycl::queue& q, const ndview& A) + : base_matrix_operator(), q_(q), A_(A) {} template -sycl::event LinearMatrixOperator::operator()(const ndview& vec, - ndview& out, - const event_vector& deps) { +sycl::event linear_matrix_operator::operator()(const ndview& vec, + ndview& out, + const event_vector& deps) { ONEDAL_ASSERT(A_.get_dimension(1) == vec.get_dimension(0)); ONEDAL_ASSERT(out.get_dimension(0) == vec.get_dimension(0)); sycl::event fill_out_event = fill(q_, out, Float(0), deps); @@ -102,9 +102,9 @@ sycl::event l1_norm(sycl::queue& queue, F*, \ F*, \ const event_vector&); \ - template class BaseMatrixOperator; \ - template class LinearMatrixOperator; \ - template class BaseFunction; + template class base_matrix_operator; \ + template class linear_matrix_operator; \ + template class base_function; INSTANTIATE(float); INSTANTIATE(double); diff --git a/cpp/oneapi/dal/backend/primitives/optimizers/line_search.hpp b/cpp/oneapi/dal/backend/primitives/optimizers/line_search.hpp index 61a003bc314..43f3c47f121 100644 --- a/cpp/oneapi/dal/backend/primitives/optimizers/line_search.hpp +++ b/cpp/oneapi/dal/backend/primitives/optimizers/line_search.hpp @@ -39,7 +39,7 @@ namespace oneapi::dal::backend::primitives { /// @return template Float backtracking(sycl::queue queue, - BaseFunction& f, + base_function& f, const ndview& x, const ndview& direction, ndview& result, diff --git a/cpp/oneapi/dal/backend/primitives/optimizers/line_search_dpc.cpp b/cpp/oneapi/dal/backend/primitives/optimizers/line_search_dpc.cpp index 
41bb9cd3d17..e501aaf9955 100644 --- a/cpp/oneapi/dal/backend/primitives/optimizers/line_search_dpc.cpp +++ b/cpp/oneapi/dal/backend/primitives/optimizers/line_search_dpc.cpp @@ -22,7 +22,7 @@ namespace oneapi::dal::backend::primitives { template Float backtracking(sycl::queue queue, - BaseFunction& f, + base_function& f, const ndview& x, const ndview& direction, ndview& result, @@ -40,7 +40,7 @@ Float backtracking(sycl::queue queue, Float df0 = 0; dot_product(queue, grad_f0, direction, result.get_mutable_data(), &df0, deps + precompute) .wait_and_throw(); - std::int32_t iter_num = 0; + std::int64_t iter_num = 0; Float cur_val = 0; while ((iter_num == 0 || cur_val > f0 + c1 * alpha * df0) && iter_num < 100) { if (iter_num > 0) { @@ -61,7 +61,7 @@ Float backtracking(sycl::queue queue, #define INSTANTIATE(F) \ template F backtracking(sycl::queue, \ - BaseFunction&, \ + base_function&, \ const ndview&, \ const ndview&, \ ndview&, \ diff --git a/cpp/oneapi/dal/backend/primitives/optimizers/newton_cg.hpp b/cpp/oneapi/dal/backend/primitives/optimizers/newton_cg.hpp index baf62e1879e..886943cfd49 100644 --- a/cpp/oneapi/dal/backend/primitives/optimizers/newton_cg.hpp +++ b/cpp/oneapi/dal/backend/primitives/optimizers/newton_cg.hpp @@ -25,11 +25,11 @@ namespace oneapi::dal::backend::primitives { // pp. 
168 (also known as the truncated Newton method) // https://link.springer.com/book/10.1007/978-0-387-40065-5 template -sycl::event newton_cg(sycl::queue& queue, - BaseFunction& f, - ndview& x, - Float tol = 1.0e-5, - std::int64_t maxiter = 100l, - const event_vector& deps = {}); +std::pair newton_cg(sycl::queue& queue, + base_function& f, + ndview& x, + Float tol = 1.0e-5, + std::int64_t maxiter = 100l, + const event_vector& deps = {}); } // namespace oneapi::dal::backend::primitives diff --git a/cpp/oneapi/dal/backend/primitives/optimizers/newton_cg_dpc.cpp b/cpp/oneapi/dal/backend/primitives/optimizers/newton_cg_dpc.cpp index 4e4c94c6ca5..9aa368a8ce2 100644 --- a/cpp/oneapi/dal/backend/primitives/optimizers/newton_cg_dpc.cpp +++ b/cpp/oneapi/dal/backend/primitives/optimizers/newton_cg_dpc.cpp @@ -26,12 +26,12 @@ namespace oneapi::dal::backend::primitives { template -sycl::event newton_cg(sycl::queue& queue, - BaseFunction& f, - ndview& x, - Float tol, - std::int64_t maxiter, - const event_vector& deps) { +std::pair newton_cg(sycl::queue& queue, + base_function& f, + ndview& x, + Float tol, + std::int64_t maxiter, + const event_vector& deps) { std::int64_t n = x.get_dimension(0); const auto kernel_minus = [=](const Float val, Float) -> Float { @@ -51,25 +51,31 @@ sycl::event newton_cg(sycl::queue& queue, Float update_norm = tol + 1; - for (std::int64_t i = 0; i < maxiter; ++i) { - if (update_norm < tol) { - break; - } + std::int64_t cur_iter_id = 0; + + while (cur_iter_id < maxiter) { + cur_iter_id++; auto update_event_vec = f.update_x(x, true, last_iter_deps); auto gradient = f.get_gradient(); Float grad_norm = 0; l1_norm(queue, gradient, tmp_gpu, &grad_norm, update_event_vec).wait_and_throw(); + + if (grad_norm < tol) { + break; + } + Float tol_k = std::min(sqrt(grad_norm), 0.5); auto prepare_grad_event = element_wise(queue, kernel_minus, gradient, Float(0), gradient, update_event_vec); - auto copy_event = copy(queue, direction, gradient, { prepare_grad_event }); 
+ // Initialize direction with 0 + auto init_dir_event = fill(queue, direction, Float(0), { prepare_grad_event }); Float desc = -1; std::int32_t iter_num = 0; - auto last_event = copy_event; + auto last_event = init_dir_event; while (desc < 0 && iter_num < 10) { if (iter_num > 0) { tol_k /= 10; @@ -94,8 +100,8 @@ sycl::event newton_cg(sycl::queue& queue, } if (desc < 0) { - // failed to find a descent direction with cg-solver after 10 atempts - return last_event; + // failed to find descent direction + return { last_event, cur_iter_id }; } Float alpha_opt = backtracking(queue, @@ -110,21 +116,22 @@ sycl::event newton_cg(sycl::queue& queue, update_norm = 0; dot_product(queue, direction, direction, tmp_gpu, &update_norm, { last_event }) .wait_and_throw(); + update_norm = sqrt(update_norm) * alpha_opt; // updated x is in buffer2 last = copy(queue, x, buffer2, {}); last_iter_deps = { last }; } - return last; + return { last, cur_iter_id }; } -#define INSTANTIATE(F) \ - template sycl::event newton_cg(sycl::queue&, \ - BaseFunction&, \ - ndview&, \ - F, \ - std::int64_t, \ - const event_vector&); +#define INSTANTIATE(F) \ + template std::pair newton_cg(sycl::queue&, \ + base_function&, \ + ndview&, \ + F, \ + std::int64_t, \ + const event_vector&); INSTANTIATE(float); INSTANTIATE(double); diff --git a/cpp/oneapi/dal/backend/primitives/optimizers/test/cg_solver_dpc.cpp b/cpp/oneapi/dal/backend/primitives/optimizers/test/cg_solver_dpc.cpp index ad04f020fa8..9967e007cc5 100644 --- a/cpp/oneapi/dal/backend/primitives/optimizers/test/cg_solver_dpc.cpp +++ b/cpp/oneapi/dal/backend/primitives/optimizers/test/cg_solver_dpc.cpp @@ -61,7 +61,7 @@ class cg_solver_test : public te::float_algo_fixture { auto A = A_host_.to_device(this->get_queue()); auto b = b_host_.to_device(this->get_queue()); - LinearMatrixOperator mul_operator(this->get_queue(), A); + linear_matrix_operator mul_operator(this->get_queue(), A); auto [x0, x0_init_event] = ndarray::zeros(this->get_queue(), { n_ }, 
sycl::usm::alloc::device); x0_init_event.wait_and_throw(); diff --git a/cpp/oneapi/dal/backend/primitives/optimizers/test/fixture.hpp b/cpp/oneapi/dal/backend/primitives/optimizers/test/fixture.hpp index 2d612b99303..a6b87b2dcc1 100644 --- a/cpp/oneapi/dal/backend/primitives/optimizers/test/fixture.hpp +++ b/cpp/oneapi/dal/backend/primitives/optimizers/test/fixture.hpp @@ -22,6 +22,8 @@ #include "oneapi/dal/test/engine/common.hpp" #include "oneapi/dal/test/engine/fixtures.hpp" #include "oneapi/dal/backend/primitives/rng/rng_engine.hpp" +#include "oneapi/dal/backend/primitives/blas/gemv.hpp" +#include "oneapi/dal/backend/primitives/element_wise.hpp" namespace oneapi::dal::backend::primitives::test { @@ -29,9 +31,9 @@ namespace oneapi::dal::backend::primitives::test { // df / dx = Ax - b // df / d^2x = A template -class QuadraticFunction : public BaseFunction { +class quadratic_function : public base_function { public: - QuadraticFunction(sycl::queue& q, const ndview& A, const ndview& b) + quadratic_function(sycl::queue& q, const ndview& A, const ndview& b) : q_(q), n_(A.get_dimension(0)), A_(A), @@ -50,7 +52,7 @@ class QuadraticFunction : public BaseFunction { return gradient_; } - BaseMatrixOperator& get_hessian_product() final { + base_matrix_operator& get_hessian_product() final { return hessp_; } @@ -89,7 +91,7 @@ class QuadraticFunction : public BaseFunction { Float value_; ndarray tmp_; ndarray gradient_; - LinearMatrixOperator hessp_; + linear_matrix_operator hessp_; }; template diff --git a/cpp/oneapi/dal/backend/primitives/optimizers/test/newton_cg_dpc.cpp b/cpp/oneapi/dal/backend/primitives/optimizers/test/newton_cg_dpc.cpp index 89501ce53ef..8c07455824a 100644 --- a/cpp/oneapi/dal/backend/primitives/optimizers/test/newton_cg_dpc.cpp +++ b/cpp/oneapi/dal/backend/primitives/optimizers/test/newton_cg_dpc.cpp @@ -25,6 +25,8 @@ #include "oneapi/dal/backend/primitives/rng/rng_engine.hpp" #include +#include 
"oneapi/dal/backend/primitives/objective_function/logloss.hpp" + namespace oneapi::dal::backend::primitives::test { namespace te = dal::test::engine; @@ -34,9 +36,99 @@ class newton_cg_test : public te::float_algo_fixture { public: using float_t = Param; + void gen_and_test_logistic_loss(std::int64_t n = -1, + std::int64_t p = -1, + bool fit_intercept = true) { + if (n == -1 || p == -1) { + n_ = GENERATE(1000, 10000, 20000); + p_ = GENERATE(3, 10, 20, 50); + } + else { + n_ = n; + p_ = p; + } + std::int64_t bsz = GENERATE(-1, 1024); + auto X_host = + ndarray::empty(this->get_queue(), { n_, p_ }, sycl::usm::alloc::host); + auto y_prob = + ndarray::empty(this->get_queue(), { n_ + 1 }, sycl::usm::alloc::host); + auto y_host = + ndarray::empty(this->get_queue(), { n_ + 1 }, sycl::usm::alloc::host); + auto params_host = + ndarray::empty(this->get_queue(), { p_ + 1 }, sycl::usm::alloc::host); + primitives::rng rn_gen; + primitives::engine eng(2007 + n); + rn_gen.uniform(n_ * p_, X_host.get_mutable_data(), eng.get_state(), -10.0, 10.0); + rn_gen.uniform(p_ + 1, params_host.get_mutable_data(), eng.get_state(), -5.0, 5.0); + for (std::int64_t i = 0; i < n_; ++i) { + float_t val = 0; + for (std::int64_t j = 0; j < p_; ++j) { + val += X_host.at(i, j) * params_host.at(j + 1); + } + val += params_host.at(0); + val = float_t(1) / (1 + std::exp(-val)); + y_prob.at(i) = val; + if (val < 0.5) { + y_host.at(i) = 0; + } + else { + y_host.at(i) = 1; + } + } + + int train_size = n_ * 0.7; + int test_size = n_ - train_size; + auto X_train = X_host.slice(0, train_size); + auto X_test = X_host.slice(train_size, test_size); + auto y_train = y_host.slice(0, train_size); + auto y_test = y_host.slice(train_size, test_size); + + auto y_gpu = y_train.to_device(this->get_queue()); + A_ = X_train.to_device(this->get_queue()); + table data = homogen_table::wrap(A_.get_mutable_data(), train_size, p_); + auto logloss_func = + logloss_function(this->get_queue(), data, y_gpu, 3.0, true, bsz); + auto 
[solution_, fill_e] = + ndarray::zeros(this->get_queue(), { p_ + 1 }, sycl::usm::alloc::device); + auto [opt_event, num_iter] = + newton_cg(this->get_queue(), logloss_func, solution_, float_t(1e-8), 100, { fill_e }); + opt_event.wait_and_throw(); + auto solution_host = solution_.to_host(this->get_queue()); + + double train_score = 0; + for (std::int64_t i = 0; i < train_size; ++i) { + float_t val = 0; + for (int j = 0; j < p_; ++j) { + val += X_train.at(i, j) * solution_host.at(j + 1); + } + val += solution_host.at(0); + val = float_t(1) / (1 + std::exp(-val)); + std::int32_t pred = val > 0.5 ? 1 : 0; + if (pred == y_train.at(i)) { + train_score += 1; + } + } + + double val_score = 0; + for (std::int64_t i = 0; i < test_size; ++i) { + float_t val = 0; + for (std::int64_t j = 0; j < p_; ++j) { + val += X_test.at(i, j) * solution_host.at(j + 1); + } + val += solution_host.at(0); + val = float_t(1) / (1 + std::exp(-val)); + std::int32_t pred = val > 0.5 ? 1 : 0; + if (pred == y_test.at(i)) { + val_score += 1; + } + } + REQUIRE(train_score >= 0.97 * train_size); + REQUIRE(val_score >= 0.96 * test_size); + } + void gen_and_test_quadratic_function(std::int64_t n = -1) { if (n == -1) { - n_ = GENERATE(5, 14, 25, 50, 100); + n_ = GENERATE(5, 14, 41, 100); } else { n_ = n; @@ -61,7 +153,7 @@ class newton_cg_test : public te::float_algo_fixture { A_ = A_host.to_device(this->get_queue()); b_ = b_host.to_device(this->get_queue()); - func_ = std::make_shared>(this->get_queue(), A_, b_); + func_ = std::make_shared>(this->get_queue(), A_, b_); auto x_host = ndarray::empty(this->get_queue(), { n_ }, sycl::usm::alloc::host); auto buffer = ndarray::empty(this->get_queue(), { n_ }, sycl::usm::alloc::host); @@ -103,7 +195,9 @@ class newton_cg_test : public te::float_algo_fixture { ndarray::zeros(this->get_queue(), { n_ }, sycl::usm::alloc::device); float_t conv_tol = sizeof(float_t) == 4 ? 
1e-7 : 1e-14; - newton_cg(this->get_queue(), *func_, x, conv_tol, 100, { x_event }).wait_and_throw(); + auto [opt_event, num_iter] = + newton_cg(this->get_queue(), *func_, x, conv_tol, 100, { x_event }); + opt_event.wait_and_throw(); auto x_host = x.to_host(this->get_queue()); float_t tol = sizeof(float_t) == 4 ? 1e-4 : 1e-7; for (std::int64_t i = 0; i < n_; ++i) { @@ -113,7 +207,8 @@ class newton_cg_test : public te::float_algo_fixture { private: std::int64_t n_; - std::shared_ptr> func_; + std::int64_t p_; + std::shared_ptr> func_; ndarray solution_; ndarray A_; ndarray b_; @@ -138,4 +233,22 @@ TEMPLATE_TEST_M(newton_cg_test, this->test_newton_cg(); } +TEMPLATE_TEST_M(newton_cg_test, + "test newton-cg with logloss function - double", + "[newton-cg][gpu]", + double) { + SKIP_IF(this->not_float64_friendly()); + SKIP_IF(this->get_policy().is_cpu()); + this->gen_and_test_logistic_loss(); +} + +TEMPLATE_TEST_M(newton_cg_test, + "test newton-cg with logloss function - float", + "[newton-cg][gpu]", + float) { + SKIP_IF(this->not_float64_friendly()); + SKIP_IF(this->get_policy().is_cpu()); + this->gen_and_test_logistic_loss(); +} + } // namespace oneapi::dal::backend::primitives::test diff --git a/cpp/oneapi/dal/backend/serialization.hpp b/cpp/oneapi/dal/backend/serialization.hpp index 0880812a3f1..b3985ce402c 100644 --- a/cpp/oneapi/dal/backend/serialization.hpp +++ b/cpp/oneapi/dal/backend/serialization.hpp @@ -192,6 +192,9 @@ class serialization_ids { // Algorithms - KMeans ID(8010000000, kmeans_clustering_model_impl_id); + + // Algorithms - Logistic Regression + ID(9010000000, logistic_regression_model_impl_id); }; #undef ID diff --git a/cpp/oneapi/dal/detail/error_messages.cpp b/cpp/oneapi/dal/detail/error_messages.cpp index 74cada68dbf..1886cf2a3d4 100644 --- a/cpp/oneapi/dal/detail/error_messages.cpp +++ b/cpp/oneapi/dal/detail/error_messages.cpp @@ -148,6 +148,7 @@ MSG(archive_is_in_invalid_state, /* General algorithms */ MSG(accuracy_threshold_lt_zero, 
"Accuracy_threshold is lower than zero") MSG(class_count_leq_one, "Class count is lower than or equal to one") +MSG(conv_tol_lt_zero, "Convergence tolerance is less than zero") MSG(input_data_is_empty, "Input data is empty") MSG(input_data_rc_neq_input_responses_rc, "Input data row count is not equal to input responses row count") @@ -301,6 +302,16 @@ MSG(input_y_is_empty, "Input y is empty") MSG(intercept_result_option_requires_intercept_flag, "Intercept result option requires intercept flag") +/* Logistic Regression */ +MSG(class_count_neq_two, + "Only binary classification is supported so class count should be equal to 2") +MSG(inverse_regularization_leq_zero, "Inverse regularization factor should be a positive number") +MSG(l1_coef_neq_zero, + "Currently L1 regularization is not supported, so l1_coef should be equal to zero") +MSG(log_reg_dense_batch_method_is_not_implemented_for_cpu, + "LogisticRegression is not implemented for CPU") +MSG(unknown_optimizer, "Custom optimizers are not supported, use on of provided by the library") + /* Decision Forest */ MSG(bootstrap_is_incompatible_with_error_metric, "Values of bootstrap and error metric parameters provided " diff --git a/cpp/oneapi/dal/detail/error_messages.hpp b/cpp/oneapi/dal/detail/error_messages.hpp index 462e146e477..37c7b76b40e 100644 --- a/cpp/oneapi/dal/detail/error_messages.hpp +++ b/cpp/oneapi/dal/detail/error_messages.hpp @@ -162,6 +162,7 @@ class ONEDAL_EXPORT error_messages { /* General Algorithms */ MSG(accuracy_threshold_lt_zero); MSG(class_count_leq_one); + MSG(conv_tol_lt_zero); MSG(input_data_is_empty); MSG(input_data_rc_neq_input_responses_rc); MSG(input_data_rc_neq_input_weights_rc); @@ -240,6 +241,13 @@ class ONEDAL_EXPORT error_messages { /* Linear Regression */ MSG(intercept_result_option_requires_intercept_flag); + /* Logistic Regression */ + MSG(class_count_neq_two); + MSG(inverse_regularization_leq_zero); + MSG(l1_coef_neq_zero); + 
MSG(log_reg_dense_batch_method_is_not_implemented_for_cpu); + MSG(unknown_optimizer); + /* Louvain */ MSG(negative_resolution); MSG(input_initial_partition_table_rc_neq_vertex_count); diff --git a/examples/oneapi/dpc/BUILD b/examples/oneapi/dpc/BUILD index 754c4aecc62..f7f69bdd350 100644 --- a/examples/oneapi/dpc/BUILD +++ b/examples/oneapi/dpc/BUILD @@ -51,6 +51,7 @@ dal_algo_example_suite( "knn", "linear_kernel", "linear_regression", + "logistic_regression", "pca", "rbf_kernel", "svm", diff --git a/examples/oneapi/dpc/source/logistic_regression/logistic_regression_dense_batch.cpp b/examples/oneapi/dpc/source/logistic_regression/logistic_regression_dense_batch.cpp new file mode 100644 index 00000000000..8f254953b41 --- /dev/null +++ b/examples/oneapi/dpc/source/logistic_regression/logistic_regression_dense_batch.cpp @@ -0,0 +1,121 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef ONEDAL_DATA_PARALLEL +#define ONEDAL_DATA_PARALLEL +#endif + +#include "oneapi/dal/algo/logistic_regression.hpp" +#include "oneapi/dal/io/csv.hpp" +#include "oneapi/dal/exceptions.hpp" +#include "example_util/utils.hpp" +#include + +namespace dal = oneapi::dal; +namespace result_options = dal::logistic_regression::result_options; + +auto now = std::chrono::steady_clock::now(); + +float get_time_duration(std::chrono::time_point& a, + std::chrono::time_point& b) { + return (float)std::chrono::duration_cast(b - a).count() / 1000; +} + +void run(sycl::queue& q) { + const auto x_train_filename = get_data_path("df_binary_classification_train_data.csv"); + const auto y_train_filename = get_data_path("df_binary_classification_train_label.csv"); + const auto x_test_filename = get_data_path("df_binary_classification_test_data.csv"); + const auto y_test_filename = get_data_path("df_binary_classification_test_label.csv"); + + auto tm1 = std::chrono::steady_clock::now(); + + std::cout << "Loading dataset... "; + + const auto x_train = dal::read(dal::csv::data_source{ x_train_filename }); + const auto y_train = dal::read(dal::csv::data_source{ y_train_filename }); + const auto x_test = dal::read(dal::csv::data_source{ x_test_filename }); + const auto y_test = dal::read(dal::csv::data_source{ y_test_filename }); + + auto tm2 = std::chrono::steady_clock::now(); + std::cout << get_time_duration(tm1, tm2) << " s" << std::endl; + + std::cout << "Fitting model... 
"; + + using method_t = dal::logistic_regression::method::dense_batch; + using task_t = dal::logistic_regression::task::classification; + using optimizer_t = dal::newton_cg::descriptor<>; + + const auto optimizer_desc = dal::newton_cg::descriptor<>(1e-4, 10l); + + const auto log_reg_desc = + dal::logistic_regression::descriptor(true, + 0.5, + optimizer_desc) + .set_result_options(result_options::coefficients | result_options::intercept | + result_options::iterations_count); + + const auto train_result = dal::train(q, log_reg_desc, x_train, y_train); + + auto tm3 = std::chrono::steady_clock::now(); + std::cout << get_time_duration(tm2, tm3) << " s" << std::endl; + + std::cout << "Coefficients:\n" << train_result.get_coefficients() << std::endl; + std::cout << "Intercept:\n" << train_result.get_intercept() << std::endl; + std::cout << "Iterations count: " << train_result.get_iterations_count() << std::endl; + + const auto log_reg_model = train_result.get_model(); + + std::cout << "Inference... 
"; + + const auto test_result = dal::infer(q, log_reg_desc, x_test, log_reg_model); + + auto tm4 = std::chrono::steady_clock::now(); + std::cout << get_time_duration(tm3, tm4) << " s" << std::endl; + + std::cout << "Test results:\n" << test_result.get_responses() << std::endl; + std::cout << "True responses:\n" << y_test << std::endl; + + auto y_true_arr = oneapi::dal::row_accessor(y_test).pull(); + const auto gth_ptr = y_true_arr.get_data(); + + auto pred_arr = + oneapi::dal::row_accessor(test_result.get_responses()).pull(); + const auto pred_ptr = pred_arr.get_data(); + + std::int64_t acc = 0; + + for (std::int64_t i = 0; i < y_test.get_row_count(); ++i) { + if (pred_ptr[i] == gth_ptr[i]) { + acc += 1; + } + } + + std::cout << "Accuracy on test: " << double(acc) / y_test.get_row_count() << " (" << acc + << " out of " << y_test.get_row_count() << ")" << std::endl; +} + +int main(int argc, char const* argv[]) { + std::vector devices; + try_add_device(devices, &sycl::gpu_selector_v); + for (auto d : devices) { + std::cout << "Running on " << d.get_platform().get_info() + << ", " << d.get_info() << "\n" + << std::endl; + auto q = sycl::queue{ d }; + run(q); + } + return 0; +} diff --git a/makefile.lst b/makefile.lst index 8035320cdc2..de7afb1090c 100755 --- a/makefile.lst +++ b/makefile.lst @@ -233,8 +233,10 @@ ONEAPI.ALGOS := \ knn \ linear_kernel \ linear_regression \ + logistic_regression \ logloss_objective \ louvain \ + newton_cg \ minkowski_distance \ objective_function \ pca \