Skip to content

Commit

Permalink
include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h - …
Browse files Browse the repository at this point in the history
…introduce __parallel_merge_submitter_large for merge of biggest data sizes

Signed-off-by: Sergey Kopienko <sergey.kopienko@intel.com>
  • Loading branch information
SergeyKopienko committed Nov 7, 2024
1 parent a2e142d commit b33656a
Showing 1 changed file with 107 additions and 0 deletions.
107 changes: 107 additions & 0 deletions include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,10 +232,16 @@ __serial_merge(const _Rng1& __rng1, const _Rng2& __rng2, _Rng3& __rng3, _Index _
}
}

template <typename... _Name>
class _find_split_points_kernel_on_mid_diagonal;

// Please see the comment for __parallel_for_submitter for optional kernel name explanation
template <typename _IdType, typename _Name>
struct __parallel_merge_submitter;

template <typename _IdType, typename _Name>
struct __parallel_merge_submitter_large;

template <typename _IdType, typename... _Name>
struct __parallel_merge_submitter<_IdType, __internal::__optional_kernel_name<_Name...>>
{
Expand Down Expand Up @@ -269,6 +275,107 @@ struct __parallel_merge_submitter<_IdType, __internal::__optional_kernel_name<_N
}
};

template <typename _IdType, typename... _Name>
struct __parallel_merge_submitter_large<_IdType, __internal::__optional_kernel_name<_Name...>>
{
template <typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _Range3, typename _Compare>
auto
operator()(_ExecutionPolicy&& __exec, _Range1&& __rng1, _Range2&& __rng2, _Range3&& __rng3, _Compare __comp) const
{
const _IdType __n1 = __rng1.size();
const _IdType __n2 = __rng2.size();
const _IdType __n = __n1 + __n2;

assert(__n1 > 0 || __n2 > 0);

_PRINT_INFO_IN_DEBUG_MODE(__exec);

using _CustomName = oneapi::dpl::__internal::__policy_kernel_name<_ExecutionPolicy>;

using _FindSplitPointsKernelOnMidDiagonal =
oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_generator<
_find_split_points_kernel_on_mid_diagonal, _CustomName, _Range1, _Range2, _IdType, _Compare>;

// Empirical number of values to process per work-item
const std::uint8_t __chunk = __exec.queue().get_device().is_cpu() ? 128 : 4;

const _IdType __steps = oneapi::dpl::__internal::__dpl_ceiling_div(__n, __chunk);
const _IdType __base_diag_count = 1'024 * 32;
const _IdType __base_diag_part = oneapi::dpl::__internal::__dpl_ceiling_div(__steps, __base_diag_count);

using _split_point_t = std::pair<_IdType, _IdType>;

using __result_and_scratch_storage_t = __result_and_scratch_storage<_ExecutionPolicy, _split_point_t>;
__result_and_scratch_storage_t __result_and_scratch{__exec, 0, __base_diag_count + 1};

sycl::event __event = __exec.queue().submit([&](sycl::handler& __cgh) {

oneapi::dpl::__ranges::__require_access(__cgh, __rng1, __rng2);
auto __scratch_acc = __result_and_scratch.template __get_scratch_acc<sycl::access_mode::write>(
__cgh, __dpl_sycl::__no_init{});

__cgh.parallel_for<_FindSplitPointsKernelOnMidDiagonal>(
sycl::range</*dim=*/1>(__base_diag_count + 1), [=](sycl::item</*dim=*/1> __item_id)
{
auto __global_idx = __item_id.get_linear_id();
auto __scratch_ptr = __result_and_scratch_storage_t::__get_usm_or_buffer_accessor_ptr(__scratch_acc);

if (__global_idx == 0)
{
__scratch_ptr[0] = std::make_pair((_IdType)0, (_IdType)0);
}
else if (__global_idx == __base_diag_count)
{
__scratch_ptr[__base_diag_count] = std::make_pair(__n1, __n2);
}
else
{
const _IdType __i_elem = __global_idx * __base_diag_part * __chunk;
__scratch_ptr[__global_idx] = __find_start_point(__rng1, __rng2, __i_elem, __n1, __n2, __comp);
}
});
});

__event = __exec.queue().submit([&](sycl::handler& __cgh) {

oneapi::dpl::__ranges::__require_access(__cgh, __rng1, __rng2, __rng3);
auto __scratch_acc = __result_and_scratch.template __get_scratch_acc<sycl::access_mode::read>(__cgh);

__cgh.depends_on(__event);

__cgh.parallel_for<_Name...>(sycl::range</*dim=*/1>(__steps), [=](sycl::item</*dim=*/1> __item_id) {
auto __global_idx = __item_id.get_linear_id();
const _IdType __i_elem = __global_idx * __chunk;

auto __scratch_ptr = __result_and_scratch_storage_t::__get_usm_or_buffer_accessor_ptr(__scratch_acc);
auto __scratch_idx = __global_idx / __base_diag_part;

_split_point_t __start;
if (__global_idx % __base_diag_part != 0)
{
// Check that we fit into size of scratch
assert(__scratch_idx + 1 < __base_diag_count + 1);

const _split_point_t __sp_left = __scratch_ptr[__scratch_idx];
const _split_point_t __sp_right = __scratch_ptr[__scratch_idx + 1];

__start = __find_start_point_in(__rng1, __sp_left.first, __sp_right.first,
__rng2, __sp_left.second, __sp_right.second,
__i_elem, __comp);
}
else
{
__start = __scratch_ptr[__scratch_idx];
}

__serial_merge(__rng1, __rng2, __rng3, __start.first, __start.second, __i_elem, __chunk, __n1, __n2,
__comp);
});
});
return __future(__event);
}
};

template <typename... _Name>
class __merge_kernel_name;

Expand Down

0 comments on commit b33656a

Please sign in to comment.