Commit: Address PR comments
fineg74 committed Feb 6, 2024
1 parent 73baeac commit bb2f878
Showing 2 changed files with 43 additions and 32 deletions.
42 changes: 18 additions & 24 deletions sycl/include/sycl/ext/intel/esimd/memory.hpp
@@ -2822,10 +2822,9 @@ gather_impl(AccessorT acc, simd<OffsetT, N / VS> byte_offsets,
/// @return is a vector of type T and size N * NElts.
///
template <typename T, int NElts, lsc_data_size DS, int N>
-__ESIMD_API __ESIMD_NS::simd<T, N * NElts>
-slm_gather_impl(__ESIMD_NS::simd<uint32_t, N> offsets,
-                __ESIMD_NS::simd_mask<N> pred,
-                __ESIMD_NS::simd<T, N * NElts> pass_thru) {
+__ESIMD_API simd<T, N * NElts> slm_gather_impl(simd<uint32_t, N> offsets,
+                                               simd_mask<N> pred,
+                                               simd<T, N * NElts> pass_thru) {
check_lsc_vector_size<NElts>();
check_lsc_data_size<T, DS>();
constexpr uint16_t AddressScale = 1;
@@ -2834,9 +2833,8 @@ slm_gather_impl(__ESIMD_NS::simd<uint32_t, N> offsets,
constexpr lsc_vector_size LSCVS = to_lsc_vector_size<NElts>();
constexpr lsc_data_order Transposed = lsc_data_order::nontranspose;
using MsgT = typename lsc_expand_type<T>::type;
-  __ESIMD_NS::simd<MsgT, N * NElts> PassThruExpanded =
-      lsc_format_input<MsgT>(pass_thru);
-  __ESIMD_NS::simd<MsgT, N * NElts> Result =
+  simd<MsgT, N * NElts> PassThruExpanded = lsc_format_input<MsgT>(pass_thru);
+  simd<MsgT, N * NElts> Result =
__esimd_lsc_load_merge_slm<MsgT, cache_hint::none, cache_hint::none,
AddressScale, ImmOffset, EDS, LSCVS,
Transposed, N>(pred.data(), offsets.data(),
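
For context: slm_gather_impl implements the predicated SLM gather with merge semantics, meaning lanes whose predicate bit is zero return the matching pass_thru element instead of reading SLM. A minimal usage sketch of the public slm_gather overload it backs (the kernel scaffolding, offsets, and variable names are illustrative assumptions, not part of this diff):

    // Sketch only: assumes an ESIMD kernel where slm_init has already
    // reserved sufficient shared local memory.
    using namespace sycl::ext::intel::esimd;

    simd<uint32_t, 32> byte_offsets(0, sizeof(float)); // 0, 4, 8, ...
    simd_mask<32> pred = 1;
    pred[5] = 0;                       // lane 5 is masked off
    simd<float, 32> pass_thru = -1.0f; // substituted for masked-off lanes
    simd<float, 32> res = slm_gather<float>(byte_offsets, pred, pass_thru);
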
@@ -2859,21 +2857,17 @@ slm_gather_impl(__ESIMD_NS::simd<uint32_t, N> offsets,
/// @param pred is the vector of predicates.
///
template <typename T, int NElts, lsc_data_size DS, int N>
-__ESIMD_API void slm_scatter_impl(__ESIMD_NS::simd<uint32_t, N> offsets,
-                                  __ESIMD_NS::simd<T, N * NElts> vals,
-                                  __ESIMD_NS::simd_mask<N> pred) {
-  detail::check_lsc_vector_size<NElts>();
-  detail::check_lsc_data_size<T, DS>();
+__ESIMD_API void slm_scatter_impl(simd<uint32_t, N> offsets,
+                                  simd<T, N * NElts> vals, simd_mask<N> pred) {
+  check_lsc_vector_size<NElts>();
+  check_lsc_data_size<T, DS>();
constexpr uint16_t AddressScale = 1;
constexpr int ImmOffset = 0;
-  constexpr lsc_data_size EDS =
-      detail::expand_data_size(detail::finalize_data_size<T, DS>());
-  constexpr detail::lsc_vector_size LSCVS = detail::to_lsc_vector_size<NElts>();
-  constexpr detail::lsc_data_order Transposed =
-      detail::lsc_data_order::nontranspose;
-  using MsgT = typename detail::lsc_expand_type<T>::type;
-  using CstT = __ESIMD_DNS::uint_type_t<sizeof(T)>;
-  __ESIMD_NS::simd<MsgT, N * NElts> Tmp = vals.template bit_cast_view<CstT>();
+  constexpr lsc_data_size EDS = expand_data_size(finalize_data_size<T, DS>());
+  constexpr lsc_vector_size LSCVS = to_lsc_vector_size<NElts>();
+  constexpr lsc_data_order Transposed = lsc_data_order::nontranspose;
+  using MsgT = typename lsc_expand_type<T>::type;
+  simd<MsgT, N * NElts> Tmp = lsc_format_input<MsgT, T>(vals);
__esimd_lsc_store_slm<MsgT, cache_hint::none, cache_hint::none, AddressScale,
ImmOffset, EDS, LSCVS, Transposed, N>(
pred.data(), offsets.data(), Tmp.data());
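
Similarly: slm_scatter_impl backs the public slm_scatter overloads; vals are converted to the LSC message type via lsc_format_input, and lanes with a zero predicate bit store nothing. A minimal usage sketch (same illustrative assumptions as above):

    simd<uint32_t, 32> byte_offsets(0, sizeof(float)); // 0, 4, 8, ...
    simd<float, 32> vals = 2.0f;
    simd_mask<32> pred = 1;
    pred[0] = 0;                        // lane 0 performs no store
    slm_scatter<float>(byte_offsets, vals, pred);
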
@@ -4181,8 +4175,8 @@ template <typename T> __ESIMD_API T slm_scalar_load(uint32_t offset) {
/// template <typename T, int N, int VS = 1,
/// typename PropertyListT = empty_properties_t>
/// void slm_scatter(simd<uint32_t, N / VS> byte_offsets,
-///                simd<T, N> vals, simd_mask<N / VS> mask,
-///                PropertyListT props = {}); // (slm-sc-1)
+///                  simd<T, N> vals, simd_mask<N / VS> mask,
+///                  PropertyListT props = {}); // (slm-sc-1)
/// void slm_scatter(simd<uint32_t, N / VS> byte_offsets,
/// simd<T, N> vals, PropertyListT props = {}); // (slm-sc-2)
///
Expand Down Expand Up @@ -4252,7 +4246,7 @@ slm_scatter(simd<uint32_t, N / VS> byte_offsets, simd<T, N> vals,
/// @tparam N Number of elements to read.
/// @tparam VS Vector size. It can also be read as the number of reads per each
/// address. The parameter 'N' must be divisible by 'VS'. (VS > 1) is supported
-/// only on DG2 and PVC.
+/// only on DG2 and PVC and only for 4- and 8-byte element vectors.
/// @param byte_offsets the vector of 32-bit offsets in bytes.
/// For each i, (byte_offsets[i]) must be element size aligned.
/// @param vals The vector of values to store.
Expand Down Expand Up @@ -4283,7 +4277,7 @@ slm_scatter(simd<uint32_t, N / VS> byte_offsets, simd<T, N> vals,
/// @tparam N Number of elements to read.
/// @tparam VS Vector size. It can also be read as the number of reads per each
/// address. The parameter 'N' must be divisible by 'VS'. (VS > 1) is supported
-/// only on DG2 and PVC.
+/// only on DG2 and PVC and only for 4- and 8-byte element vectors.
/// @param byte_offsets the vector of 32-bit offsets in bytes.
/// For each i, (byte_offsets[i]) must be element size aligned.
/// If the alignment property is not passed, then it is assumed that each
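
On the (VS > 1) forms documented in this hunk: they take N / VS byte offsets for N values, so each address receives VS elements, and per the restriction added above they require DG2 or PVC and 4- or 8-byte elements. A sketch of the (slm-sc-1) and (slm-sc-2) call shapes with VS = 2 (offset values and strides are illustrative):

    simd<uint32_t, 16> byte_offsets(0, 8); // 16 addresses for 32 floats
    simd<float, 32> vals = 0.0f;           // VS = 2 values per address
    simd_mask<16> mask = 1;                // one predicate bit per address
    slm_scatter<float, 32, 2>(byte_offsets, vals, mask); // (slm-sc-1)
    slm_scatter<float, 32, 2>(byte_offsets, vals);       // (slm-sc-2)
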
33 changes: 25 additions & 8 deletions sycl/test/esimd/memory_properties.cpp
@@ -1302,6 +1302,7 @@ test_slm_gather_scatter(int byte_offset32) {
simd<float, 32> slm;
simd<float, 32> pass_thru;
auto pass_thru_view = pass_thru.select<32, 1>();
+  auto slm_view = slm.select<32, 1>();

// Test SLM gather using this plan:
// 1) slm_gather(offsets): offsets is simd or simd_view
@@ -1375,41 +1376,57 @@
props_align4);

// Test SLM scatter using this plan:
-  // 1) slm_scatter(offsets): offsets is simd or simd_view
-  // 2) slm_scatter(offsets, mask): offsets is simd or simd_view
-  // 4) slm_scatter(...): same as (1), (2) above, but with VS > 1.
+  // 1) slm_scatter(offsets, vals): offsets/vals is simd or simd_view
+  // 2) slm_scatter(offsets, vals, mask): offsets/vals is simd or simd_view
+  // 3) slm_scatter(...): same as (1), (2) above, but with VS > 1.

// 1) slm_scatter(offsets): offsets is simd or simd_view
-  // CHECK-COUNT-2: call void @llvm.genx.scatter.scaled.v32i1.v32i32.v32f32(<32 x i1> {{[^)]+}}, i32 2, i16 0, i32 {{[^)]+}}, i32 {{[^)]+}}, <32 x i32> {{[^)]+}}, <32 x float> {{[^)]+}})
+  // CHECK-COUNT-4: call void @llvm.genx.scatter.scaled.v32i1.v32i32.v32f32(<32 x i1> {{[^)]+}}, i32 2, i16 0, i32 {{[^)]+}}, i32 {{[^)]+}}, <32 x i32> {{[^)]+}}, <32 x float> {{[^)]+}})
slm_scatter<float>(ioffset_n32, slm);
slm_scatter<float, 32>(ioffset_n32_view, slm);
+  slm_scatter<float, 32>(ioffset_n32, slm_view);
+  slm_scatter<float, 32>(ioffset_n32_view, slm_view);

-  // CHECK-COUNT-2: call void @llvm.genx.scatter.scaled.v32i1.v32i32.v32f32(<32 x i1> {{[^)]+}}, i32 2, i16 0, i32 {{[^)]+}}, i32 {{[^)]+}}, <32 x i32> {{[^)]+}}, <32 x float> {{[^)]+}})
+  // CHECK-COUNT-4: call void @llvm.genx.scatter.scaled.v32i1.v32i32.v32f32(<32 x i1> {{[^)]+}}, i32 2, i16 0, i32 {{[^)]+}}, i32 {{[^)]+}}, <32 x i32> {{[^)]+}}, <32 x float> {{[^)]+}})
slm_scatter<float>(ioffset_n32, slm, props_align8);
slm_scatter<float, 32>(ioffset_n32_view, slm, props_align8);
+  slm_scatter<float, 32>(ioffset_n32, slm_view, props_align8);
+  slm_scatter<float, 32>(ioffset_n32_view, slm_view, props_align8);

  // 2) slm_scatter(offsets, mask): offsets is simd or simd_view
-  // CHECK-COUNT-2: call void @llvm.genx.scatter.scaled.v32i1.v32i32.v32f32(<32 x i1> {{[^)]+}}, i32 2, i16 0, i32 {{[^)]+}}, i32 {{[^)]+}}, <32 x i32> {{[^)]+}}, <32 x float> {{[^)]+}})
+  // CHECK-COUNT-4: call void @llvm.genx.scatter.scaled.v32i1.v32i32.v32f32(<32 x i1> {{[^)]+}}, i32 2, i16 0, i32 {{[^)]+}}, i32 {{[^)]+}}, <32 x i32> {{[^)]+}}, <32 x float> {{[^)]+}})
slm_scatter<float>(ioffset_n32, slm, mask_n32);
slm_scatter<float, 32>(ioffset_n32_view, slm, mask_n32);
+  slm_scatter<float, 32>(ioffset_n32, slm_view, mask_n32);
+  slm_scatter<float, 32>(ioffset_n32_view, slm_view, mask_n32);

-  // CHECK-COUNT-2: call void @llvm.genx.scatter.scaled.v32i1.v32i32.v32f32(<32 x i1> {{[^)]+}}, i32 2, i16 0, i32 {{[^)]+}}, i32 {{[^)]+}}, <32 x i32> {{[^)]+}}, <32 x float> {{[^)]+}})
+  // CHECK-COUNT-4: call void @llvm.genx.scatter.scaled.v32i1.v32i32.v32f32(<32 x i1> {{[^)]+}}, i32 2, i16 0, i32 {{[^)]+}}, i32 {{[^)]+}}, <32 x i32> {{[^)]+}}, <32 x float> {{[^)]+}})
slm_scatter<float>(ioffset_n32, slm, mask_n32, props_align8);
slm_scatter<float, 32>(ioffset_n32_view, slm, mask_n32, props_align8);
+  slm_scatter<float, 32>(ioffset_n32, slm_view, mask_n32, props_align8);
+  slm_scatter<float, 32>(ioffset_n32_view, slm_view, mask_n32, props_align8);

  // 4) slm_scatter(...): same as (1), (2) above, but with VS > 1.
-  // CHECK-COUNT-8: call void @llvm.genx.lsc.store.slm.v16i1.v16i32.v32i32(<16 x i1> {{[^)]+}}, i8 4, i8 0, i8 0, i16 1, i32 0, i8 3, i8 2, i8 1, i8 0, <16 x i32> {{[^)]+}}, <32 x i32>{{[^)]+}}, i32 0)
+  // CHECK-COUNT-16: call void @llvm.genx.lsc.store.slm.v16i1.v16i32.v32i32(<16 x i1> {{[^)]+}}, i8 4, i8 0, i8 0, i16 1, i32 0, i8 3, i8 2, i8 1, i8 0, <16 x i32> {{[^)]+}}, <32 x i32>{{[^)]+}}, i32 0)
// 4a) check VS > 1. no 'mask' operand first.
slm_scatter<float, 32, 2>(ioffset_n16, slm);
slm_scatter<float, 32, 2>(ioffset_n16_view, slm);
+  slm_scatter<float, 32, 2>(ioffset_n16, slm_view);
+  slm_scatter<float, 32, 2>(ioffset_n16_view, slm_view);

slm_scatter<float, 32, 2>(ioffset_n16, slm, props_align4);
slm_scatter<float, 32, 2>(ioffset_n16_view, slm, props_align4);
+  slm_scatter<float, 32, 2>(ioffset_n16, slm_view, props_align4);
+  slm_scatter<float, 32, 2>(ioffset_n16_view, slm_view, props_align4);

// 4b) check VS > 1. Pass the 'mask' operand this time.
slm_scatter<float, 32, 2>(ioffset_n16, slm, mask_n16);
slm_scatter<float, 32, 2>(ioffset_n16_view, slm, mask_n16);
+  slm_scatter<float, 32, 2>(ioffset_n16, slm_view, mask_n16);
+  slm_scatter<float, 32, 2>(ioffset_n16_view, slm_view, mask_n16);

slm_scatter<float, 32, 2>(ioffset_n16, slm, mask_n16, props_align4);
slm_scatter<float, 32, 2>(ioffset_n16_view, slm, mask_n16, props_align4);
+  slm_scatter<float, 32, 2>(ioffset_n16, slm_view, mask_n16, props_align4);
+  slm_scatter<float, 32, 2>(ioffset_n16_view, slm_view, mask_n16, props_align4);
}
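
On the simd_view overloads the added test lines exercise: a view produced by simd::select can be passed where the offsets or values vector is expected, with T and N spelled out explicitly, matching how the tests above call them. A reduced sketch of the pattern under test:

    simd<uint32_t, 32> offsets(0, sizeof(float));
    simd<float, 32> vals = 1.0f;
    auto vals_view = vals.select<32, 1>(); // simd_view over all 32 lanes
    // Explicit <float, 32> mirrors the test calls; a plain simd argument
    // would allow the shorter slm_scatter<float>(...) spelling.
    slm_scatter<float, 32>(offsets, vals_view);
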
