Skip to content

Commit

Permalink
Fix online lom performance (#2541)
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexandr-Solovev committed Oct 11, 2023
1 parent f8644d6 commit 7664e40
Show file tree
Hide file tree
Showing 8 changed files with 432 additions and 370 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,35 @@ namespace bk = dal::backend;
using task_t = task::compute;
using descriptor_t = detail::descriptor_base<task_t>;

/// Builds the descriptor describing which quantities have to be computed in
/// order to serve the result options requested in `desc`.
///
/// The request is reduced to two groups: min/max, and the sum-based
/// accumulators (sum, sum_squares, sum_squares_centered). The returned
/// descriptor enables min|max, the three sums, or both, so that it always
/// covers everything the caller asked for.
///
/// @tparam Float  Floating-point type used to instantiate the returned descriptor.
/// @param desc    Descriptor whose result options are inspected.
/// @return A dense `basic_statistics` compute descriptor with the selected
///         result options set.
template <typename Float>
inline auto get_desc_to_compute(const descriptor_t& desc) {
    const auto res_op = desc.get_result_options();

    const bool has_min_max =
        res_op.test(result_options::min) || res_op.test(result_options::max);

    // Derived statistics are treated as requiring the sum-based accumulators.
    const bool has_other_stat =
        res_op.test(result_options::mean) || res_op.test(result_options::variance) ||
        res_op.test(result_options::second_order_raw_moment) ||
        res_op.test(result_options::variation) ||
        res_op.test(result_options::standard_deviation);

    // Any directly requested accumulator counts as well. sum_squares is checked
    // explicitly: it used to be shadowed by a duplicated sum_squares_centered test.
    const bool has_sums = res_op.test(result_options::sum) ||
                          res_op.test(result_options::sum_squares) ||
                          res_op.test(result_options::sum_squares_centered) || has_other_stat;

    auto local_desc =
        basic_statistics::descriptor<Float, method::dense, basic_statistics::task::compute>();

    if (has_min_max && has_sums) {
        // Both groups requested: never drop min/max just because a sum is present.
        local_desc.set_result_options(result_options::min | result_options::max |
                                      result_options::sum | result_options::sum_squares |
                                      result_options::sum_squares_centered);
    }
    else if (has_min_max) {
        local_desc.set_result_options(result_options::min | result_options::max);
    }
    else {
        // No min/max requested (including an empty request): sums only.
        local_desc.set_result_options(result_options::sum | result_options::sum_squares |
                                      result_options::sum_squares_centered);
    }
    return local_desc;
}

inline auto get_daal_estimates_to_compute(const descriptor_t& desc) {
const auto res_op = desc.get_result_options();
const auto res_min_max = result_options::min | result_options::max;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,21 @@ static compute_result<Task> call_daal_kernel_finalize_compute(
const context_cpu& ctx,
const descriptor_t& desc,
const partial_compute_result<Task>& input) {
const auto result_ids = daal_lom::estimatesAll;
const auto result_ids = get_daal_estimates_to_compute(desc);
const auto daal_parameter = daal_lom::Parameter(result_ids);

auto column_count = input.get_partial_min().get_column_count();
std::int64_t column_count = 0;

const auto res_op = desc.get_result_options();

if (result_ids == daal_lom::estimatesMeanVariance || result_ids == daal_lom::estimatesAll) {
column_count = input.get_partial_sum().get_column_count();
}
if (result_ids == daal_lom::estimatesMinMax) {
column_count = input.get_partial_min().get_column_count();
}
ONEDAL_ASSERT(column_count > 0);
auto daal_partial = daal_lom::PartialResult();
auto daal_partial_obs = interop::copy_to_daal_homogen_table<Float>(input.get_partial_n_rows());
auto daal_partial_min = interop::copy_to_daal_homogen_table<Float>(input.get_partial_min());
auto daal_partial_max = interop::copy_to_daal_homogen_table<Float>(input.get_partial_max());
Expand All @@ -66,11 +76,10 @@ static compute_result<Task> call_daal_kernel_finalize_compute(

auto daal_means = interop::allocate_daal_homogen_table<Float>(1, column_count);
auto daal_rawt = interop::allocate_daal_homogen_table<Float>(1, column_count);

auto daal_variance = interop::allocate_daal_homogen_table<Float>(1, column_count);
auto daal_stdev = interop::allocate_daal_homogen_table<Float>(1, column_count);
auto daal_variation = interop::allocate_daal_homogen_table<Float>(1, column_count);
{
if (result_ids == daal_lom::estimatesMeanVariance || result_ids == daal_lom::estimatesAll) {
interop::status_to_exception(
interop::call_daal_kernel_finalize_compute<Float, daal_lom_online_kernel_t>(
ctx,
Expand All @@ -85,11 +94,8 @@ static compute_result<Task> call_daal_kernel_finalize_compute(
daal_variation.get(),
&daal_parameter));
}

compute_result<Task> res;
const auto res_op = desc.get_result_options();
res.set_result_options(desc.get_result_options());

if (res_op.test(result_options::min)) {
res.set_min(interop::convert_from_daal_homogen_table<Float>(daal_partial_min));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,22 +46,28 @@ using daal_lom_online_kernel_t =
daal_lom::internal::LowOrderMomentsOnlineKernel<Float, daal_lom::defaultDense, Cpu>;

template <typename Float, typename Task>
inline auto get_partial_result(daal_lom::PartialResult daal_partial_result) {
inline auto get_partial_result(daal_lom::PartialResult daal_partial_result,
const descriptor_t& desc) {
auto result = partial_compute_result();

const auto res_op = desc.get_result_options();
const auto result_ids = get_daal_estimates_to_compute(desc);
result.set_partial_n_rows(interop::convert_from_daal_homogen_table<Float>(
daal_partial_result.get(daal_lom::PartialResultId::nObservations)));
result.set_partial_min(interop::convert_from_daal_homogen_table<Float>(
daal_partial_result.get(daal_lom::PartialResultId::partialMinimum)));
result.set_partial_max(interop::convert_from_daal_homogen_table<Float>(
daal_partial_result.get(daal_lom::PartialResultId::partialMaximum)));
result.set_partial_sum(interop::convert_from_daal_homogen_table<Float>(
daal_partial_result.get(daal_lom::PartialResultId::partialSum)));
result.set_partial_sum_squares(interop::convert_from_daal_homogen_table<Float>(
daal_partial_result.get(daal_lom::PartialResultId::partialSumSquares)));
result.set_partial_sum_squares_centered(interop::convert_from_daal_homogen_table<Float>(
daal_partial_result.get(daal_lom::PartialResultId::partialSumSquaresCentered)));

if (result_ids == daal_lom::estimatesMinMax || res_op.test(result_options::min) ||
res_op.test(result_options::max)) {
result.set_partial_min(interop::convert_from_daal_homogen_table<Float>(
daal_partial_result.get(daal_lom::PartialResultId::partialMinimum)));
result.set_partial_max(interop::convert_from_daal_homogen_table<Float>(
daal_partial_result.get(daal_lom::PartialResultId::partialMaximum)));
}
if (result_ids == daal_lom::estimatesMeanVariance || result_ids == daal_lom::estimatesAll) {
result.set_partial_sum(interop::convert_from_daal_homogen_table<Float>(
daal_partial_result.get(daal_lom::PartialResultId::partialSum)));
result.set_partial_sum_squares(interop::convert_from_daal_homogen_table<Float>(
daal_partial_result.get(daal_lom::PartialResultId::partialSumSquares)));
result.set_partial_sum_squares_centered(interop::convert_from_daal_homogen_table<Float>(
daal_partial_result.get(daal_lom::PartialResultId::partialSumSquaresCentered)));
}
return result;
}

Expand All @@ -82,10 +88,12 @@ result_t call_daal_kernel_with_weights(const context_cpu& ctx,
auto daal_input = daal_lom::Input();
auto daal_partial = daal_lom::PartialResult();

const auto res_op = desc.get_result_options();

const auto input_ = input.get_prev();
row_accessor<const Float> data_accessor(data);
row_accessor<const Float> weights_accessor(weights);
const auto result_ids = daal_lom::estimatesAll;
const auto result_ids = get_daal_estimates_to_compute(desc);
const auto daal_parameter = daal_lom::Parameter(result_ids);
auto weights_arr = weights_accessor.pull();
auto gen_data_block = data_accessor.pull();
Expand All @@ -110,26 +118,32 @@ result_t call_daal_kernel_with_weights(const context_cpu& ctx,
}
const bool has_nobs_data = input_.get_partial_n_rows().has_data();
if (has_nobs_data) {
auto daal_partial_max =
interop::copy_to_daal_homogen_table<Float>(input_.get_partial_max());
auto daal_partial_min =
interop::copy_to_daal_homogen_table<Float>(input_.get_partial_min());
auto daal_partial_sums =
interop::copy_to_daal_homogen_table<Float>(input_.get_partial_sum());
auto daal_partial_sum_squares =
interop::copy_to_daal_homogen_table<Float>(input_.get_partial_sum_squares());
auto daal_partial_sum_squares_centered =
interop::copy_to_daal_homogen_table<Float>(input_.get_partial_sum_squares_centered());
auto daal_nobs = interop::copy_to_daal_homogen_table<Float>(input_.get_partial_n_rows());
daal_partial.set(daal_lom::PartialResultId::nObservations, daal_nobs);

daal_partial.set(daal_lom::PartialResultId::partialMaximum, daal_partial_max);
daal_partial.set(daal_lom::PartialResultId::partialMinimum, daal_partial_min);
daal_partial.set(daal_lom::PartialResultId::partialSum, daal_partial_sums);
daal_partial.set(daal_lom::PartialResultId::partialSumSquaresCentered,
daal_partial_sum_squares_centered);

daal_partial.set(daal_lom::PartialResultId::partialSumSquares, daal_partial_sum_squares);
if (result_ids == daal_lom::estimatesMinMax || res_op.test(result_options::min) ||
res_op.test(result_options::max)) {
auto daal_partial_max =
interop::copy_to_daal_homogen_table<Float>(input_.get_partial_max());
auto daal_partial_min =
interop::copy_to_daal_homogen_table<Float>(input_.get_partial_min());
daal_partial.set(daal_lom::PartialResultId::partialMaximum, daal_partial_max);
daal_partial.set(daal_lom::PartialResultId::partialMinimum, daal_partial_min);
}
if (result_ids == daal_lom::estimatesMeanVariance || result_ids == daal_lom::estimatesAll) {
auto daal_partial_sums =
interop::copy_to_daal_homogen_table<Float>(input_.get_partial_sum());
auto daal_partial_sum_squares =
interop::copy_to_daal_homogen_table<Float>(input_.get_partial_sum_squares());
auto daal_partial_sum_squares_centered = interop::copy_to_daal_homogen_table<Float>(
input_.get_partial_sum_squares_centered());

daal_partial.set(daal_lom::PartialResultId::partialSum, daal_partial_sums);
daal_partial.set(daal_lom::PartialResultId::partialSumSquaresCentered,
daal_partial_sum_squares_centered);

daal_partial.set(daal_lom::PartialResultId::partialSumSquares,
daal_partial_sum_squares);
}
{
interop::status_to_exception(
interop::call_daal_kernel<Float, daal_lom_online_kernel_t>(ctx,
Expand All @@ -138,7 +152,7 @@ result_t call_daal_kernel_with_weights(const context_cpu& ctx,
&daal_parameter,
is_online));
}
auto result = get_partial_result<Float, task_t>(daal_partial);
auto result = get_partial_result<Float, task_t>(daal_partial, desc);

return result;
}
Expand All @@ -151,7 +165,7 @@ result_t call_daal_kernel_with_weights(const context_cpu& ctx,
&daal_parameter,
is_online));
}
auto result = get_partial_result<Float, task_t>(daal_partial);
auto result = get_partial_result<Float, task_t>(daal_partial, desc);
return result;
}
}
Expand All @@ -169,47 +183,53 @@ result_t call_daal_kernel_without_weights(const context_cpu& ctx,

const auto input_ = input.get_prev();

const auto result_ids = daal_lom::estimatesAll;
const auto result_ids = get_daal_estimates_to_compute(desc);
const auto daal_parameter = daal_lom::Parameter(result_ids);

const auto daal_data = interop::convert_to_daal_table<Float>(data);

const auto res_op = desc.get_result_options();

daal_input.set(daal_lom::InputId::data, daal_data);
const bool has_nobs_data = input_.get_partial_n_rows().has_data();
{
alloc_result<Float>(daal_partial, &daal_input, &daal_parameter, result_ids);
initialize_result<Float>(daal_partial, &daal_input, &daal_parameter, result_ids);
}
if (has_nobs_data) {
auto daal_partial_max =
interop::copy_to_daal_homogen_table<Float>(input_.get_partial_max());
auto daal_partial_min =
interop::copy_to_daal_homogen_table<Float>(input_.get_partial_min());
auto daal_partial_sums =
interop::copy_to_daal_homogen_table<Float>(input_.get_partial_sum());
auto daal_partial_sum_squares =
interop::copy_to_daal_homogen_table<Float>(input_.get_partial_sum_squares());
auto daal_partial_sum_squares_centered =
interop::copy_to_daal_homogen_table<Float>(input_.get_partial_sum_squares_centered());
auto daal_nobs = interop::copy_to_daal_homogen_table<Float>(input_.get_partial_n_rows());

daal_partial.set(daal_lom::PartialResultId::nObservations, daal_nobs);

daal_partial.set(daal_lom::PartialResultId::partialMaximum, daal_partial_max);
daal_partial.set(daal_lom::PartialResultId::partialMinimum, daal_partial_min);
daal_partial.set(daal_lom::PartialResultId::partialSum, daal_partial_sums);
daal_partial.set(daal_lom::PartialResultId::partialSumSquaresCentered,
daal_partial_sum_squares_centered);

daal_partial.set(daal_lom::PartialResultId::partialSumSquares, daal_partial_sum_squares);

if (result_ids == daal_lom::estimatesMinMax || res_op.test(result_options::min) ||
res_op.test(result_options::max)) {
auto daal_partial_max =
interop::copy_to_daal_homogen_table<Float>(input_.get_partial_max());
auto daal_partial_min =
interop::copy_to_daal_homogen_table<Float>(input_.get_partial_min());
daal_partial.set(daal_lom::PartialResultId::partialMaximum, daal_partial_max);
daal_partial.set(daal_lom::PartialResultId::partialMinimum, daal_partial_min);
}
if (result_ids == daal_lom::estimatesMeanVariance || result_ids == daal_lom::estimatesAll) {
auto daal_partial_sums =
interop::copy_to_daal_homogen_table<Float>(input_.get_partial_sum());
auto daal_partial_sum_squares =
interop::copy_to_daal_homogen_table<Float>(input_.get_partial_sum_squares());
auto daal_partial_sum_squares_centered = interop::copy_to_daal_homogen_table<Float>(
input_.get_partial_sum_squares_centered());

daal_partial.set(daal_lom::PartialResultId::partialSum, daal_partial_sums);
daal_partial.set(daal_lom::PartialResultId::partialSumSquaresCentered,
daal_partial_sum_squares_centered);

daal_partial.set(daal_lom::PartialResultId::partialSumSquares,
daal_partial_sum_squares);
}
interop::status_to_exception(
interop::call_daal_kernel<Float, daal_lom_online_kernel_t>(ctx,
daal_data.get(),
&daal_partial,
&daal_parameter,
is_online));
auto result = get_partial_result<Float, task_t>(daal_partial);
auto result = get_partial_result<Float, task_t>(daal_partial, desc);
return result;
}
else {
Expand All @@ -221,7 +241,7 @@ result_t call_daal_kernel_without_weights(const context_cpu& ctx,
&daal_parameter,
is_online));
}
auto result = get_partial_result<Float, task_t>(daal_partial);
auto result = get_partial_result<Float, task_t>(daal_partial, desc);
return result;
}
}
Expand Down
Loading

0 comments on commit 7664e40

Please sign in to comment.