Skip to content

Commit

Permalink
fixes for bazel
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexandr-Solovev committed Jul 26, 2024
1 parent 72c9671 commit 9bc24dc
Show file tree
Hide file tree
Showing 12 changed files with 313 additions and 131 deletions.
5 changes: 3 additions & 2 deletions cpp/daal/src/algorithms/engines/engine_batch_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ class BatchBaseImpl
public:
BatchBaseImpl(size_t seed) : _seed(seed) {}
size_t getSeed() const { return _seed; }
virtual void * getState() = 0;
virtual int getStateSize() const = 0;
virtual void * getState() = 0;
virtual int skipAheadoneDAL(size_t skip) = 0;
virtual int getStateSize() const = 0;
virtual ~BatchBaseImpl() {}
virtual bool hasSupport(ParallelizationTechnique technique) const = 0;

Expand Down
6 changes: 5 additions & 1 deletion cpp/daal/src/algorithms/engines/mcg59/mcg59_batch_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,11 @@ class BatchImpl : public algorithms::engines::mcg59::interface1::Batch<algorithm
{
return new BatchImpl<cpu, algorithmFPType, method>(*this);
}

int skipAheadoneDAL(size_t skip) DAAL_C11_OVERRIDE
{
skipAheadImpl(skip);
return 0;
}
bool hasSupport(engines::internal::ParallelizationTechnique technique) const DAAL_C11_OVERRIDE
{
switch (technique)
Expand Down
6 changes: 5 additions & 1 deletion cpp/daal/src/algorithms/engines/mt19937/mt19937_batch_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,11 @@ class BatchImpl : public algorithms::engines::mt19937::interface1::Batch<algorit
{
return new BatchImpl<cpu, algorithmFPType, method>(*this);
}

int skipAheadoneDAL(size_t skip) DAAL_C11_OVERRIDE
{
skipAheadImpl(skip);
return 0;
}
bool hasSupport(engines::internal::ParallelizationTechnique technique) const DAAL_C11_OVERRIDE
{
switch (technique)
Expand Down
6 changes: 5 additions & 1 deletion cpp/daal/src/algorithms/engines/mt2203/mt2203_batch_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,11 @@ class BatchImpl : public algorithms::engines::mt2203::interface1::Batch<algorith
size_t getNumberOfStreamsImpl() const DAAL_C11_OVERRIDE { return _streams.size(); }

size_t getMaxNumberOfStreamsImpl() const DAAL_C11_OVERRIDE { return 6024; }

int skipAheadoneDAL(size_t skip) DAAL_C11_OVERRIDE
{
skipAheadImpl(skip);
return 0;
}
bool hasSupport(engines::internal::ParallelizationTechnique technique) const DAAL_C11_OVERRIDE
{
switch (technique)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,24 @@ class train_kernel_hist_impl {
Index node_count,
const bk::event_vector& deps = {});

sycl::event compute_initial_imp_for_node_list_regression(
const train_context_t& ctx,
const pr::ndarray<Index, 1>& node_list,
const pr::ndarray<Float, 1>& local_sum_hist,
const pr::ndarray<Float, 1>& local_sum2cent_hist,
imp_data_t& imp_data_list,
Index node_count,
const bk::event_vector& deps = {});

sycl::event compute_local_sum_histogram(const train_context_t& ctx,
const pr::ndarray<Float, 1>& response,
const pr::ndarray<Index, 1>& tree_order,
const pr::ndarray<Index, 1>& node_list,
pr::ndarray<Float, 1>& local_sum_hist,
pr::ndarray<Float, 1>& local_sum2cent_hist,
Index node_count,
const bk::event_vector& deps = {});

/// Computes initial histograms for each node to compute impurity.
///
/// @param[in] ctx a training context structure for a GPU backend
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1009,6 +1009,136 @@ Float* local_buf_ptr = local_buf.get_pointer().get();
return event;
}

template <typename Float, typename Bin, typename Index, typename Task>
sycl::event train_kernel_hist_impl<Float, Bin, Index, Task>::compute_local_sum_histogram(
const train_context_t& ctx,
const pr::ndarray<Float, 1>& response,
const pr::ndarray<Index, 1>& tree_order,
const pr::ndarray<Index, 1>& node_list,
pr::ndarray<Float, 1>& local_sum_hist,
pr::ndarray<Float, 1>& local_sum2cent_hist,
Index node_count,
const bk::event_vector& deps) {
ONEDAL_ASSERT(response.get_count() == ctx.row_count_);
ONEDAL_ASSERT(tree_order.get_count() == ctx.tree_in_block_ * ctx.selected_row_total_count_);
ONEDAL_ASSERT(node_list.get_count() == node_count * impl_const_t::node_prop_count_);
ONEDAL_ASSERT(local_sum_hist.get_count() == node_count);
ONEDAL_ASSERT(local_sum2cent_hist.get_count() == node_count);

auto fill_event1 = local_sum_hist.fill(queue_, 0, deps);
auto fill_event2 = local_sum2cent_hist.fill(queue_, 0, deps);

fill_event1.wait_and_throw();
fill_event2.wait_and_throw();

const Float* response_ptr = response.get_data();
const Index* tree_order_ptr = tree_order.get_data();
const Index* node_list_ptr = node_list.get_data();
Float* local_sum_hist_ptr = local_sum_hist.get_mutable_data();
Float* local_sum2cent_hist_ptr = local_sum2cent_hist.get_mutable_data();

const Index node_prop_count = impl_const_t::node_prop_count_;

auto local_size = ctx.preferable_group_size_;
const sycl::nd_range<2> nd_range =
bk::make_multiple_nd_range_2d({ local_size, node_count }, { local_size, 1 });

auto event = queue_.submit([&](sycl::handler& cgh) {
cgh.depends_on(deps);
local_accessor_rw_t<Float> local_sum_buf(local_size, cgh);
local_accessor_rw_t<Float> local_sum2cent_buf(local_size, cgh);
cgh.parallel_for(nd_range, [=](sycl::nd_item<2> item) {
const Index node_id = item.get_global_id()[1];
const Index local_id = item.get_local_id()[0];
const Index local_size = item.get_local_range()[0];

const Index* node_ptr = node_list_ptr + node_id * node_prop_count;

const Index row_offset = node_ptr[impl_const_t::ind_ofs];
const Index row_count = node_ptr[impl_const_t::ind_lrc];

const Index* node_tree_order_ptr = &tree_order_ptr[row_offset];
#if __SYCL_COMPILER_VERSION >= 20230828
Float* local_sum_buf_ptr =
local_sum_buf.template get_multi_ptr<sycl::access::decorated::yes>().get_raw();
Float* local_sum2cent_buf_ptr =
local_sum2cent_buf.template get_multi_ptr<sycl::access::decorated::yes>().get_raw();
#else
Float* local_sum_buf_ptr = local_sum_buf.get_pointer().get();
Float* local_sum2cent_buf_ptr = local_sum2cent_buf.get_pointer().get();
#endif
Float local_sum = Float(0);
Float local_sum2cent = Float(0);
for (Index i = local_id; i < row_count; i += local_size) {
Float value = response_ptr[node_tree_order_ptr[i]];
local_sum += value;
local_sum2cent += value * value;
}

local_sum_buf_ptr[local_id] = local_sum;
local_sum2cent_buf_ptr[local_id] = local_sum2cent;

for (Index offset = local_size / 2; offset > 0; offset >>= 1) {
item.barrier(sycl::access::fence_space::local_space);
if (local_id < offset) {
local_sum_buf_ptr[local_id] += local_sum_buf_ptr[local_id + offset];
local_sum2cent_buf_ptr[local_id] += local_sum2cent_buf_ptr[local_id + offset];
}
}

if (local_id == 0) {
local_sum_hist_ptr[node_id] = local_sum_buf_ptr[local_id];
local_sum2cent_hist_ptr[node_id] = local_sum2cent_buf_ptr[local_id];
}
});
});

event.wait_and_throw();
return event;
}

template <typename Float, typename Bin, typename Index, typename Task>
sycl::event
train_kernel_hist_impl<Float, Bin, Index, Task>::compute_initial_imp_for_node_list_regression(
const train_context_t& ctx,
const pr::ndarray<Index, 1>& node_list,
const pr::ndarray<Float, 1>& local_sum_hist,
const pr::ndarray<Float, 1>& local_sum2cent_hist,
imp_data_t& imp_data_list,
Index node_count,
const bk::event_vector& deps) {
ONEDAL_ASSERT(node_list.get_count() == node_count * impl_const_t::node_prop_count_);
ONEDAL_ASSERT(local_sum_hist.get_count() == node_count);
ONEDAL_ASSERT(local_sum2cent_hist.get_count() == node_count);
ONEDAL_ASSERT(imp_data_list.imp_list_.get_count() ==
node_count * impl_const_t::node_imp_prop_count_);

const Index* node_list_ptr = node_list.get_data();
const Float* local_sum_hist_ptr = local_sum_hist.get_data();
const Float* local_sum2cent_hist_ptr = local_sum2cent_hist.get_data();
Float* imp_list_ptr = imp_data_list.imp_list_.get_mutable_data();

const sycl::range<1> range{ de::integral_cast<std::size_t>(node_count) };

auto last_event = queue_.submit([&](sycl::handler& cgh) {
cgh.depends_on(deps);
cgh.parallel_for(range, [=](sycl::id<1> node_idx) {
// set mean
imp_list_ptr[node_idx * impl_const_t::node_imp_prop_count_ + 0] =
local_sum_hist_ptr[node_idx] /
node_list_ptr[node_idx * impl_const_t::node_prop_count_ + impl_const_t::ind_grc];
// set sum2cent
imp_list_ptr[node_idx * impl_const_t::node_imp_prop_count_ + 1] =
local_sum2cent_hist_ptr[node_idx] -
(local_sum_hist_ptr[node_idx] * local_sum_hist_ptr[node_idx]) /
node_list_ptr[node_idx * impl_const_t::node_prop_count_ +
impl_const_t::ind_grc];
});
});

return last_event;
}

template <typename Float, typename Bin, typename Index, typename Task>
sycl::event train_kernel_hist_impl<Float, Bin, Index, Task>::compute_initial_sum2cent_local(
const train_context_t& ctx,
Expand Down Expand Up @@ -1150,8 +1280,8 @@ sycl::event train_kernel_hist_impl<Float, Bin, Index, Task>::compute_initial_his

sycl::event last_event;

if (ctx.distr_mode_) {
if constexpr (std::is_same_v<Task, task::classification>) {
if constexpr (std::is_same_v<Task, task::classification>) {
if (ctx.distr_mode_) {
last_event = compute_initial_histogram_local(ctx,
response,
tree_order,
Expand All @@ -1171,51 +1301,68 @@ sycl::event train_kernel_hist_impl<Float, Bin, Index, Task>::compute_initial_his
{ last_event });
}
else {
auto sum_list = pr::ndarray<Float, 1>::empty(queue_, { node_count });
auto sum2cent_list = pr::ndarray<Float, 1>::empty(queue_, { node_count });
last_event = compute_initial_sum_local(ctx,
response,
tree_order,
node_list,
sum_list,
node_count,
deps);
{
ONEDAL_PROFILER_TASK(sum_list, queue_);
comm_.allreduce(sum_list.flatten(queue_, { last_event })).wait();
}
last_event = compute_initial_sum2cent_local(ctx,
response,
tree_order,
node_list,
sum_list,
sum2cent_list,
node_count,
{ last_event });
{
ONEDAL_PROFILER_TASK(allreduce_sum2cent_list, queue_);
comm_.allreduce(sum2cent_list.flatten(queue_, { last_event })).wait();
}
last_event = fin_initial_imp(ctx,
node_list,
sum_list,
sum2cent_list,
imp_data_list,
node_count,
{ last_event });
last_event = compute_initial_histogram_local(ctx,
response,
tree_order,
node_list,
imp_data_list,
node_count,
deps);
last_event.wait_and_throw();
}
}
else {
last_event = compute_initial_histogram_local(ctx,
response,
tree_order,
node_list,
imp_data_list,
node_count,
deps);
auto local_sum_hist = pr::ndarray<Float, 1>::empty(queue_, { node_count });
auto local_sum2cent_hist = pr::ndarray<Float, 1>::empty(queue_, { node_count });

last_event = compute_local_sum_histogram(ctx,
response,
tree_order,
node_list,
local_sum_hist,
local_sum2cent_hist,
node_count,
deps);
{
ONEDAL_PROFILER_TASK(allreduce_sum_hist, queue_);
comm_.allreduce(local_sum_hist.flatten(queue_, { last_event })).wait();
}
{
ONEDAL_PROFILER_TASK(allreduce_sum2cent_hist, queue_);
comm_.allreduce(local_sum2cent_hist.flatten(queue_, { last_event })).wait();
}

auto host_arr_1 = local_sum_hist.to_host(queue_);
auto host_arr_2 = local_sum2cent_hist.to_host(queue_);
auto host_arr_1_ptr = host_arr_1.get_data();
auto host_arr_2_ptr = host_arr_2.get_data();
std::cout << "1st array output" << std::endl;
for (std::int64_t i = 0; i < node_count; i++) {
std::cout << host_arr_1_ptr[i] << " ";
}
std::cout << std::endl;
std::cout << "2nd array output" << std::endl;
for (std::int64_t i = 0; i < node_count; i++) {
std::cout << host_arr_2_ptr[i] << " ";
}
std::cout << std::endl;
last_event = compute_initial_imp_for_node_list_regression(ctx,
node_list,
local_sum_hist,
local_sum2cent_hist,
imp_data_list,
node_count,
{ last_event });
last_event.wait_and_throw();
}
// last_event = compute_initial_histogram_local(ctx,
// response,
// tree_order,
// node_list,
// imp_data_list,
// node_count,
// deps);
// last_event.wait_and_throw();

return last_event;
}
Expand Down
8 changes: 6 additions & 2 deletions cpp/oneapi/dal/backend/primitives/rng/rng.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

#include "oneapi/dal/backend/primitives/ndarray.hpp"

#include <daal/include/algorithms/engines/mt2203/mt2203.h>
#include <daal/include/algorithms/engines/mcg59/mcg59.h>

#include "oneapi/dal/backend/primitives/rng/utils.hpp"

Expand Down Expand Up @@ -82,7 +82,7 @@ class rng {
class engine {
public:
explicit engine(std::int64_t seed = 777)
: engine_(daal::algorithms::engines::mt2203::Batch<>::create(seed)) {
: engine_(daal::algorithms::engines::mcg59::Batch<>::create(seed)) {
impl_ = dynamic_cast<daal::algorithms::engines::internal::BatchBaseImpl*>(engine_.get());
if (!impl_) {
throw domain_error(dal::detail::error_messages::rng_engine_is_not_supported());
Expand Down Expand Up @@ -112,6 +112,10 @@ class engine {
return impl_->getState();
}

int skip_ahead(size_t nSkip) {
return impl_->skipAheadoneDAL(nSkip);
}

private:
daal::algorithms::engines::EnginePtr engine_;
daal::algorithms::engines::internal::BatchBaseImpl* impl_;
Expand Down
2 changes: 1 addition & 1 deletion cpp/oneapi/dal/backend/primitives/rng/rng_dpc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ void rng<Type, Size>::uniform(sycl::queue& queue,

// auto d = sycl::device(sycl::cpu_selector_v);
// sycl::queue cpu_queue(d);
auto engine = oneapi::mkl::rng::load_state<oneapi::mkl::rng::mrg32k3a>(queue, state);
auto engine = oneapi::mkl::rng::load_state<oneapi::mkl::rng::mcg59>(queue, state);

oneapi::mkl::rng::uniform<Type> distr(a, b);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class engine_collection {
public:
explicit engine_collection(Size count, std::int64_t seed = 777)
: count_(count),
engine_(daal::algorithms::engines::mt2203::Batch<>::create(seed)),
engine_(daal::algorithms::engines::mcg59::Batch<>::create(seed)),
params_(count),
technique_(daal::algorithms::engines::internal::family),
daal_engine_list_(count) {}
Expand Down
Loading

0 comments on commit 9bc24dc

Please sign in to comment.