Skip to content

Commit

Permalink
[SYCL][ESIMD] Add more stringent compile time checks to local_accesso…
Browse files Browse the repository at this point in the history
…r version of block_load/block_store, gather/scatter API (intel#11653)
  • Loading branch information
fineg74 authored Oct 26, 2023
1 parent 331e513 commit 8d7396d
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 24 deletions.
49 changes: 27 additions & 22 deletions sycl/include/sycl/ext/intel/esimd/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3057,11 +3057,12 @@ __ESIMD_API void media_block_store(AccessorTy acc, unsigned x, unsigned y,
///
template <typename Tx, int N, typename AccessorTy,
typename Flags = overaligned_tag<detail::OperandSize::OWORD>>
__ESIMD_API std::enable_if_t<
sycl::detail::acc_properties::is_local_accessor_v<AccessorTy> &&
is_simd_flag_type_v<Flags>,
simd<Tx, N>>
block_load(AccessorTy acc, uint32_t offset, Flags = {}) {
__ESIMD_API
std::enable_if_t<detail::is_local_accessor_with_v<
AccessorTy, detail::accessor_mode_cap::can_read> &&
is_simd_flag_type_v<Flags>,
simd<Tx, N>>
block_load(AccessorTy acc, uint32_t offset, Flags = {}) {
return slm_block_load<Tx, N, Flags>(offset +
__ESIMD_DNS::localAccessorToOffset(acc));
}
Expand All @@ -3085,10 +3086,11 @@ block_load(AccessorTy acc, uint32_t offset, Flags = {}) {
///
template <typename Tx, int N, typename AccessorTy,
typename Flags = overaligned_tag<detail::OperandSize::OWORD>>
__ESIMD_API std::enable_if_t<
sycl::detail::acc_properties::is_local_accessor_v<AccessorTy> &&
is_simd_flag_type_v<Flags>>
block_store(AccessorTy acc, uint32_t offset, simd<Tx, N> vals, Flags = {}) {
__ESIMD_API
std::enable_if_t<detail::is_local_accessor_with_v<
AccessorTy, detail::accessor_mode_cap::can_write> &&
is_simd_flag_type_v<Flags>>
block_store(AccessorTy acc, uint32_t offset, simd<Tx, N> vals, Flags = {}) {
slm_block_store<Tx, N, Flags>(
offset + __ESIMD_DNS::localAccessorToOffset(acc), vals);
}
Expand All @@ -3111,10 +3113,12 @@ block_store(AccessorTy acc, uint32_t offset, simd<Tx, N> vals, Flags = {}) {
/// undefined.
///
template <typename T, int N, typename AccessorTy>
__ESIMD_API std::enable_if_t<
sycl::detail::acc_properties::is_local_accessor_v<AccessorTy>, simd<T, N>>
gather(AccessorTy acc, simd<uint32_t, N> offsets, uint32_t glob_offset = 0,
simd_mask<N> mask = 1) {
__ESIMD_API
std::enable_if_t<detail::is_local_accessor_with_v<
AccessorTy, detail::accessor_mode_cap::can_read>,
simd<T, N>>
gather(AccessorTy acc, simd<uint32_t, N> offsets, uint32_t glob_offset = 0,
simd_mask<N> mask = 1) {
return slm_gather<T, N>(
offsets + glob_offset + __ESIMD_DNS::localAccessorToOffset(acc), mask);
}
Expand All @@ -3138,8 +3142,8 @@ gather(AccessorTy acc, simd<uint32_t, N> offsets, uint32_t glob_offset = 0,
///
///
template <typename T, int N, typename AccessorTy>
__ESIMD_API std::enable_if_t<
sycl::detail::acc_properties::is_local_accessor_v<AccessorTy>>
__ESIMD_API std::enable_if_t<detail::is_local_accessor_with_v<
AccessorTy, detail::accessor_mode_cap::can_write>>
scatter(AccessorTy acc, simd<uint32_t, N> offsets, simd<T, N> vals,
uint32_t glob_offset = 0, simd_mask<N> mask = 1) {
slm_scatter<T, N>(offsets + glob_offset +
Expand Down Expand Up @@ -3174,11 +3178,12 @@ scatter(AccessorTy acc, simd<uint32_t, N> offsets, simd<T, N> vals,
template <rgba_channel_mask RGBAMask = rgba_channel_mask::ABGR,
typename AccessorT, int N,
typename T = typename AccessorT::value_type>
__ESIMD_API std::enable_if_t<
sycl::detail::acc_properties::is_local_accessor_v<AccessorT>,
simd<T, N * get_num_channels_enabled(RGBAMask)>>
gather_rgba(AccessorT acc, simd<uint32_t, N> offsets,
uint32_t global_offset = 0, simd_mask<N> mask = 1) {
__ESIMD_API
std::enable_if_t<detail::is_local_accessor_with_v<
AccessorT, detail::accessor_mode_cap::can_read>,
simd<T, N * get_num_channels_enabled(RGBAMask)>>
gather_rgba(AccessorT acc, simd<uint32_t, N> offsets,
uint32_t global_offset = 0, simd_mask<N> mask = 1) {
return slm_gather_rgba<T, N, RGBAMask>(
offsets + global_offset + __ESIMD_DNS::localAccessorToOffset(acc), mask);
}
Expand All @@ -3202,8 +3207,8 @@ gather_rgba(AccessorT acc, simd<uint32_t, N> offsets,
template <rgba_channel_mask RGBAMask = rgba_channel_mask::ABGR,
typename AccessorT, int N,
typename T = typename AccessorT::value_type>
__ESIMD_API std::enable_if_t<
sycl::detail::acc_properties::is_local_accessor_v<AccessorT>>
__ESIMD_API std::enable_if_t<detail::is_local_accessor_with_v<
AccessorT, detail::accessor_mode_cap::can_write>>
scatter_rgba(AccessorT acc, simd<uint32_t, N> offsets,
simd<T, N * get_num_channels_enabled(RGBAMask)> vals,
uint32_t global_offset = 0, simd_mask<N> mask = 1) {
Expand Down
13 changes: 11 additions & 2 deletions sycl/test/esimd/block_load_store.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ SYCL_EXTERNAL void kernel2(int *ptr) SYCL_ESIMD_FUNCTION {

// Incompatible mode (write).
SYCL_EXTERNAL void
kernel3(accessor<int, 1, access::mode::write, access::target::device> &buf)
kernel4(accessor<int, 1, access::mode::write, access::target::device> &buf)
SYCL_ESIMD_FUNCTION {
simd<int, 32> v;
// CHECK: block_load_store.cpp:38{{.*}}error: no matching function
Expand All @@ -40,10 +40,19 @@ kernel3(accessor<int, 1, access::mode::write, access::target::device> &buf)

// Incompatible mode (read).
SYCL_EXTERNAL void
kernel4(accessor<int, 1, access::mode::read, access::target::device> &buf)
kernel5(accessor<int, 1, access::mode::read, access::target::device> &buf)
SYCL_ESIMD_FUNCTION {
simd<int, 32> v(0, 1);
// CHECK: block_load_store.cpp:48{{.*}}error: no matching function
// function for call to 'block_store'
block_store<int, 32>(buf, 0, v);
}

// Incompatible mode (read).
SYCL_EXTERNAL void
kernel6(local_accessor<const int, 1> &buf) SYCL_ESIMD_FUNCTION {
simd<int, 32> v(0, 1);
// CHECK: block_load_store.cpp:57{{.*}}error: no matching function
// function for call to 'block_store'
block_store<int, 32>(buf, 0, v);
}
20 changes: 20 additions & 0 deletions sycl/test/esimd/gather_scatter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,23 @@ kernel5(accessor<int, 1, access::mode::read, access::target::device> &buf)
// function for call to 'scatter'
scatter_rgba(buf, offset, v);
}

// Incompatible mode (read).
SYCL_EXTERNAL void
kernel6(local_accessor<const int, 1> &buf) SYCL_ESIMD_FUNCTION {
simd<int, 32> v(0, 1);
simd<uint32_t, 32> offset(0, 1);
// CHECK: gather_scatter.cpp:115{{.*}}error: no matching function
// function for call to 'scatter'
scatter<int, 32>(buf, offset, v);
}

// Incompatible mode (read).
SYCL_EXTERNAL void
kernel7(local_accessor<const int, 1> &buf) SYCL_ESIMD_FUNCTION {
simd<int, 32 * 4> v(0, 1);
simd<uint32_t, 32> offset(0, sizeof(int) * 4);
// CHECK: gather_scatter.cpp:125{{.*}}error: no matching function
// function for call to 'scatter'
scatter_rgba(buf, offset, v);
}

0 comments on commit 8d7396d

Please sign in to comment.