Skip to content

Commit

Permalink
Fix sub-group load/store extension disabling logic
Browse files Browse the repository at this point in the history
  • Loading branch information
dmitriy-sobolev committed Dec 23, 2024
1 parent c2a226b commit 0c0eae5
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
10 changes: 8 additions & 2 deletions include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ struct __parallel_transform_scan_static_single_group_submitter<_Inclusive, _Elem
const ::std::uint16_t __subgroup_id = __subgroup.get_group_id();
const ::std::uint16_t __subgroup_size = __subgroup.get_local_linear_range();

#if _ONEDPL_SYCL_SUB_GROUP_LOAD_STORE_PRESENT
#if _ONEDPL_LIBSYCL_SUB_GROUP_LOAD_STORE_PRESENT
constexpr bool __can_use_subgroup_load_store =
_IsFullGroup && oneapi::dpl::__internal::__range_has_raw_ptr_iterator_v<::std::decay_t<_InRng>>;
#else
Expand All @@ -502,13 +502,15 @@ struct __parallel_transform_scan_static_single_group_submitter<_Inclusive, _Elem
auto __lacc_ptr = __dpl_sycl::__get_accessor_ptr(__lacc);
if constexpr (__can_use_subgroup_load_store)
{
#if _ONEDPL_LIBSYCL_SUB_GROUP_LOAD_STORE_PRESENT
_ONEDPL_PRAGMA_UNROLL
for (::std::uint16_t __i = 0; __i < _ElemsPerItem; ++__i)
{
auto __idx = __i * _WGSize + __subgroup_id * __subgroup_size;
auto __val = __unary_op(__subgroup.load(__in_rng.begin() + __idx));
__subgroup.store(__lacc_ptr + __idx, __val);
}
#endif
}
else
{
Expand All @@ -523,13 +525,15 @@ struct __parallel_transform_scan_static_single_group_submitter<_Inclusive, _Elem

if constexpr (__can_use_subgroup_load_store)
{
#if _ONEDPL_LIBSYCL_SUB_GROUP_LOAD_STORE_PRESENT
_ONEDPL_PRAGMA_UNROLL
for (::std::uint16_t __i = 0; __i < _ElemsPerItem; ++__i)
{
auto __idx = __i * _WGSize + __subgroup_id * __subgroup_size;
auto __val = __subgroup.load(__lacc_ptr + __idx);
__subgroup.store(__out_rng.begin() + __idx, __val);
}
#endif
}
else
{
Expand Down Expand Up @@ -598,7 +602,7 @@ struct __parallel_copy_if_static_single_group_submitter<_Size, _ElemsPerItem, _W
const ::std::uint16_t __subgroup_id = __subgroup.get_group_id();
const ::std::uint16_t __subgroup_size = __subgroup.get_local_linear_range();

#if _ONEDPL_SYCL_SUB_GROUP_LOAD_STORE_PRESENT
#if _ONEDPL_LIBSYCL_SUB_GROUP_LOAD_STORE_PRESENT
constexpr bool __can_use_subgroup_load_store =
_IsFullGroup && oneapi::dpl::__internal::__range_has_raw_ptr_iterator_v<::std::decay_t<_InRng>>;
#else
Expand All @@ -607,13 +611,15 @@ struct __parallel_copy_if_static_single_group_submitter<_Size, _ElemsPerItem, _W
auto __lacc_ptr = __dpl_sycl::__get_accessor_ptr(__lacc);
if constexpr (__can_use_subgroup_load_store)
{
#if _ONEDPL_LIBSYCL_SUB_GROUP_LOAD_STORE_PRESENT
_ONEDPL_PRAGMA_UNROLL
for (::std::uint16_t __i = 0; __i < _ElemsPerItem; ++__i)
{
auto __idx = __i * _WGSize + __subgroup_id * __subgroup_size;
uint16_t __val = __unary_op(__subgroup.load(__in_rng.begin() + __idx));
__subgroup.store(__lacc_ptr + __idx, __val);
}
#endif
}
else
{
Expand Down
6 changes: 3 additions & 3 deletions include/oneapi/dpl/pstl/hetero/dpcpp/sycl_defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,12 @@
#define _ONEDPL_LIBSYCL_KNOWN_IDENTITY_PRESENT (_ONEDPL_LIBSYCL_VERSION == 50200)
#define _ONEDPL_LIBSYCL_SUB_GROUP_MASK_PRESENT \
(SYCL_EXT_ONEAPI_SUB_GROUP_MASK >= 1 && _ONEDPL_LIBSYCL_VERSION >= 50700)
// TODO: consider replacing with SYCL_EXT_ONEAPI_GROUP_LOAD_STORE extension due to the deprecation with DPC++ 2025.1
// or using a unified approach for loading and storing across the patterns
#define _ONEDPL_LIBSYCL_SUB_GROUP_LOAD_STORE_PRESENT 0
#define _ONEDPL_SYCL_DEVICE_COPYABLE_SPECIALIZATION_BROKEN (_ONEDPL_LIBSYCL_VERSION_LESS_THAN(70100))
// TODO: determine which compiler configurations provide subgroup load/store
#define _ONEDPL_SYCL_SUB_GROUP_LOAD_STORE_PRESENT false
// Macro to check if we are compiling for SPIR-V devices. This macro must only be used within
// SYCL kernels for determining SPIR-V compilation. Using this macro on the host may lead to incorrect behavior.
#ifndef _ONEDPL_DETECT_SPIRV_COMPILATION // Check if overridden for testing
Expand Down

0 comments on commit 0c0eae5

Please sign in to comment.