diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge_sort.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge_sort.h index 2e05f380522..a8f9deb527c 100644 --- a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge_sort.h +++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge_sort.h @@ -240,8 +240,6 @@ struct __merge_sort_global_submitter<_IndexT, using _merge_split_point_t = _split_point_t<_IndexT>; - static constexpr std::size_t __starting_size_limit_for_large_submitter = 4 * 1'048'576; // 4 MB - struct nd_range_params { std::size_t base_diag_count = 0; @@ -567,6 +565,8 @@ struct __merge_sort_global_submitter<_IndexT, bool __data_in_temp = false; + using __value_type = oneapi::dpl::__internal::__value_t<_Range>; + // Calculate nd-range params const nd_range_params __nd_range_params = eval_nd_range_params(__exec, __n, __n_sorted); @@ -583,7 +583,15 @@ struct __merge_sort_global_submitter<_IndexT, for (std::int64_t __i = 0; __i < __n_iter; ++__i) { - if (2 * __n_sorted >= __starting_size_limit_for_large_submitter) + if (2 * __n_sorted < __get_starting_size_limit_for_large_submitter<__value_type>()) + { + // Process parallel merge + __event_chain = run_parallel_merge(__event_chain, + __n_sorted, __data_in_temp, + __exec, __rng, __temp_buf, __comp, + __nd_range_params); + } + else { // Create storage for save split-points on each base diagonal // - for current iteration @@ -604,14 +612,6 @@ struct __merge_sort_global_submitter<_IndexT, __nd_range_params, *__p_base_diagonals_sp_storage); } - else - { - // Process parallel merge - __event_chain = run_parallel_merge(__event_chain, - __n_sorted, __data_in_temp, - __exec, __rng, __temp_buf, __comp, - __nd_range_params); - } __n_sorted *= 2; __data_in_temp = !__data_in_temp;