
[SYCL][ESIMD] Use LLVM IR for USM/SLM scatter #12628

Merged: 11 commits merged on Feb 8, 2024
38 changes: 38 additions & 0 deletions llvm/lib/SYCLLowerIR/ESIMD/LowerESIMD.cpp
@@ -1004,6 +1004,37 @@ static void translateGatherLoad(CallInst &CI, bool IsSLM) {
CI.replaceAllUsesWith(LI);
}

static void translateScatterStore(CallInst &CI, bool IsSLM) {
IRBuilder<> Builder(&CI);
constexpr int AlignmentTemplateArgIdx = 2;
APInt Val = parseTemplateArg(CI, AlignmentTemplateArgIdx,
ESIMDIntrinDesc::GenXArgConversion::TO_I64);
Align AlignValue(Val.getZExtValue());

auto ValsOp = CI.getArgOperand(0);
auto OffsetsOp = CI.getArgOperand(1);
auto MaskOp = CI.getArgOperand(2);
auto DataType = ValsOp->getType();

// Convert the mask from <N x i16> to <N x i1>.
Value *Zero = ConstantInt::get(MaskOp->getType(), 0);
MaskOp = Builder.CreateICmp(ICmpInst::ICMP_NE, MaskOp, Zero);

// The address space is 3 (SLM) for SLM accesses; otherwise the memory is
// global or private. By the time 'scatter()' is called, the pointer passed
// to it is already generic (4), so simply use the generic address space for
// global and private accesses and let the GPU BE deduce the actual address
// space from the use-def graph.
unsigned AS = IsSLM ? 3 : 4;
auto ElemType = DataType->getScalarType();
auto NumElems = (cast<VectorType>(DataType))->getElementCount();
auto VPtrType = VectorType::get(PointerType::get(ElemType, AS), NumElems);
auto VPtrOp = Builder.CreateIntToPtr(OffsetsOp, VPtrType);

auto SI = Builder.CreateMaskedScatter(ValsOp, VPtrOp, AlignValue, MaskOp);
SI->setDebugLoc(CI.getDebugLoc());
CI.replaceAllUsesWith(SI);
}

// TODO: Specify and document the behavior of slm_init and nbarrier_init when:
// 1) they are called not from kernels
// 2) there are multiple such calls reachable from a kernel
@@ -1987,6 +2018,13 @@ size_t SYCLLowerESIMDPass::runOnFunction(Function &F,
continue;
}

if (Name.starts_with("__esimd_scatter_st") ||
Name.starts_with("__esimd_slm_scatter_st")) {
translateScatterStore(*CI, Name.starts_with("__esimd_slm_scatter_st"));
ToErase.push_back(CI);
continue;
}

if (Name.starts_with("__esimd_nbarrier_init")) {
translateNbarrierInit(*CI);
ToErase.push_back(CI);
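For reference, here is a minimal scalar sketch (illustrative only, not part of this patch; the function name is hypothetical) of the per-lane semantics carried by the llvm.masked.scatter call that translateScatterStore() emits: each lane whose predicate word is non-zero stores its element to its own address, which is why the pass first compares the <N x i16> mask against zero to obtain <N x i1>.

#include <cstdint>

template <typename T, int N>
void reference_masked_scatter(const T (&vals)[N], const uint64_t (&addrs)[N],
                              const int16_t (&pred)[N]) {
  for (int I = 0; I < N; ++I)
    if (pred[I] != 0) // lane enabled: mask word is compared against zero
      *reinterpret_cast<T *>(static_cast<uintptr_t>(addrs[I])) = vals[I];
}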
14 changes: 14 additions & 0 deletions sycl/include/sycl/ext/intel/esimd/detail/memory_intrin.hpp
@@ -213,6 +213,20 @@ __ESIMD_INTRIN __ESIMD_DNS::vector_type_t<T, N> __esimd_slm_gather_ld(
__ESIMD_DNS::simd_mask_storage_t<N> pred,
__ESIMD_DNS::vector_type_t<T, N> pass_thru) __ESIMD_INTRIN_END;

// Scatter data to given global or private addresses.
template <typename T, int N, size_t Align>
__ESIMD_INTRIN void
__esimd_scatter_st(__ESIMD_DNS::vector_type_t<T, N> vals,
__ESIMD_DNS::vector_type_t<uint64_t, N> vptr,
__ESIMD_DNS::simd_mask_storage_t<N> pred) __ESIMD_INTRIN_END;

// Scatter data to given SLM addresses.
template <typename T, int N, size_t Align>
__ESIMD_INTRIN void __esimd_slm_scatter_st(
__ESIMD_DNS::vector_type_t<T, N> vals,
__ESIMD_DNS::vector_type_t<uint32_t, N> vptr,
__ESIMD_DNS::simd_mask_storage_t<N> pred) __ESIMD_INTRIN_END;

/// Surface-based gather.
/// Supported platforms: DG2, PVC
///
19 changes: 17 additions & 2 deletions sycl/include/sycl/ext/intel/esimd/memory.hpp
@@ -724,11 +724,20 @@ scatter(T *p, simd<OffsetT, N / VS> byte_offsets, simd<T, N> vals,

// Use LSC lowering if L1/L2 or VS > 1.
if constexpr (L1Hint != cache_hint::none || L2Hint != cache_hint::none ||
VS > 1 || !__ESIMD_DNS::isPowerOf2(N, 32)) {
VS > 1 ||
(!__ESIMD_DNS::isPowerOf2(N, 32) &&
!detail::isMaskedGatherScatterLLVMAvailable())) {
static_assert(VS == 1 || sizeof(T) >= 4,
"VS > 1 is supprted only for 4- and 8-byte elements");
return detail::scatter_impl<T, VS, detail::lsc_data_size::default_size,
L1Hint, L2Hint>(p, byte_offsets, vals, mask);
} else if constexpr (detail::isMaskedGatherScatterLLVMAvailable()) {
simd<uint64_t, N> Addrs(reinterpret_cast<uint64_t>(p));
Addrs = Addrs + convert<uint64_t>(byte_offsets);
using MsgT = detail::__raw_t<T>;
__esimd_scatter_st<MsgT, N, Alignment>(
sycl::bit_cast<__ESIMD_DNS::vector_type_t<MsgT, N>>(vals.data()),
Addrs.data(), mask.data());
} else {
using Tx = detail::__raw_t<T>;
simd<uint64_t, N> byte_offsets_i = convert<uint64_t>(byte_offsets);
@@ -4227,9 +4236,15 @@ slm_scatter(simd<uint32_t, N / VS> byte_offsets, simd<T, N> vals,
"slm_scatter() requires at least element-size alignment");

// Use LSC lowering if VS > 1.
if constexpr (VS > 1 || !(detail::isPowerOf2(N, 32) && sizeof(T) <= 4)) {
if constexpr (VS > 1 || (!(detail::isPowerOf2(N, 32) && sizeof(T) <= 4) &&
!detail::isMaskedGatherScatterLLVMAvailable())) {
__ESIMD_DNS::slm_scatter_impl<T, VS, detail::lsc_data_size::default_size>(
byte_offsets, vals, mask);
} else if constexpr (detail::isMaskedGatherScatterLLVMAvailable()) {
using MsgT = detail::__raw_t<T>;
__esimd_slm_scatter_st<MsgT, N, Alignment>(
sycl::bit_cast<__ESIMD_DNS::vector_type_t<MsgT, N>>(vals.data()),
byte_offsets.data(), mask.data());
} else {
detail::LocalAccessorMarker acc;
detail::scatter_impl<T, N>(acc, vals, byte_offsets, 0, mask);
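For orientation, a hedged usage sketch (an assumed example, not taken from this PR; buffer size and names are illustrative) of user code whose calls meet the new dispatch conditions above: no cache hints and VS == 1, so with the LLVM IR lowering enabled they map to @llvm.masked.scatter rather than the LSC or legacy GenX paths.

#include <sycl/sycl.hpp>
#include <sycl/ext/intel/esimd.hpp>

namespace esimd = sycl::ext::intel::esimd;

int main() {
  sycl::queue q;
  float *usm_ptr = sycl::malloc_shared<float>(32, q);

  q.single_task([=]() SYCL_ESIMD_KERNEL {
     esimd::slm_init<32 * sizeof(float)>(); // reserve SLM for the SLM path

     esimd::simd<uint32_t, 32> byte_offsets(0, sizeof(float)); // 0, 4, 8, ...
     esimd::simd<float, 32> vals(1.0f);
     esimd::simd_mask<32> mask = 1;

     // USM scatter: no cache hints, VS == 1; with the LLVM IR lowering this
     // becomes a call to @llvm.masked.scatter.v32f32.v32p4.
     esimd::scatter(usm_ptr, byte_offsets, vals, mask);

     // SLM scatter: same conditions with 32-bit SLM byte offsets; this
     // becomes a call to @llvm.masked.scatter.v32f32.v32p3.
     esimd::slm_scatter(byte_offsets, vals);
   }).wait();

  sycl::free(usm_ptr, q);
  return 0;
}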
6 changes: 3 additions & 3 deletions sycl/test-e2e/ESIMD/unified_memory_api/scatter_usm.cpp
@@ -5,13 +5,13 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===------------------------------------------------------------------===//
// RUN: %{build} -fsycl-device-code-split=per_kernel -o %t.out
// RUN: %{build} -fsycl-device-code-split=per_kernel -D__ESIMD_GATHER_SCATTER_LLVM_IR -o %t.out
// RUN: %{run} %t.out

// The test verifies esimd::scatter() functions accepting USM pointer
// and optional compile-time esimd::properties.
// The scatter() calls in this test do not use cache-hint
// properties to not impose using DG2/PVC features.
// The scatter() calls in this test do not use cache-hint properties
// or VS > 1 (number of stores per offset), so they do not require PVC features.

#include "Inputs/scatter.hpp"

21 changes: 21 additions & 0 deletions sycl/test-e2e/ESIMD/unified_memory_api/scatter_usm_legacy.cpp
@@ -0,0 +1,21 @@
//==------- scatter_usm_legacy.cpp - DPC++ ESIMD on-device test -----------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Use per-kernel compilation to have more information about failing cases.
// RUN: %{build} -fsycl-device-code-split=per_kernel -o %t.out
// RUN: %{run} %t.out

// The test verifies esimd::scatter() functions accepting USM pointer
// and optional compile-time esimd::properties.
// The scatter() calls in this test do not use cache-hint properties
// or VS > 1 (number of stores per offset), so they do not require PVC features.
//
// TODO: Remove this test when GPU driver issue with llvm.masked.scatter is
// resolved and ESIMD starts using llvm.masked.scatter by default.
// "-D__ESIMD_GATHER_SCATTER_LLVM_IR" is not used here.

#include "scatter_usm.cpp"
8 changes: 4 additions & 4 deletions sycl/test-e2e/ESIMD/unified_memory_api/slm_scatter.cpp
@@ -5,12 +5,12 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===------------------------------------------------------------------===//
// RUN: %{build} -fsycl-device-code-split=per_kernel -o %t.out
// RUN: %{build} -fsycl-device-code-split=per_kernel -D__ESIMD_GATHER_SCATTER_LLVM_IR -o %t.out
// RUN: %{run} %t.out

// The test verifies esimd::slm_scatter() functions accepting
// optional compile-time esimd::properties.
// The scatter() calls in this test do not use DG2/PVC features.
// The test verifies esimd::slm_scatter() functions accepting optional
// compile-time esimd::properties. The slm_scatter() calls in this test do not
// use VS > 1 (number of stores per offset), so they do not require PVC features.

#include "Inputs/scatter.hpp"

20 changes: 20 additions & 0 deletions sycl/test-e2e/ESIMD/unified_memory_api/slm_scatter_legacy.cpp
@@ -0,0 +1,20 @@
//==------- slm_scatter_legacy.cpp - DPC++ ESIMD on-device test -----------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Use per-kernel compilation to have more information about failing cases.
// RUN: %{build} -fsycl-device-code-split=per_kernel -o %t.out
// RUN: %{run} %t.out

// The test verifies esimd::slm_scatter() functions accepting optional
// compile-time esimd::properties. The slm_scatter() calls in this test do not
// use VS > 1 (number of stores per offset), so they do not require PVC features.
//
// TODO: Remove this test when GPU driver issue with llvm.masked.scatter is
// resolved and ESIMD starts using llvm.masked.scatter by default.
// "-D__ESIMD_GATHER_SCATTER_LLVM_IR" is not used here.

#include "slm_scatter.cpp"
25 changes: 20 additions & 5 deletions sycl/test/esimd/memory_properties.cpp
@@ -1224,7 +1224,7 @@ test_gather_scatter(AccType &acc, LocalAccType &local_acc, float *ptrf,
acc_res = gather<float, 32>(local_acc, ioffset_n32, 0);
acc_res = gather<float, 32>(local_acc, ioffset_n32, 0, mask_n32);

// CHECK-COUNT-4: call void @llvm.genx.svm.scatter.v32i1.v32i64.v32f32(<32 x i1> {{[^)]+}}, i32 0, <32 x i64> {{[^)]+}}, <32 x float> {{[^)]+}})
// CHECK-COUNT-4: call void @llvm.masked.scatter.v32f32.v32p4(<32 x float> {{[^)]+}}, <32 x ptr addrspace(4)> {{[^)]+}}, i32 4, <32 x i1> {{[^)]+}})
scatter(ptrf, ioffset_n32, usm, mask_n32);

scatter(ptrf, ioffset_n32, usm);
@@ -1281,6 +1281,14 @@ test_gather_scatter(AccType &acc, LocalAccType &local_acc, float *ptrf,
scatter<float, 32, 2>(ptrf, ioffset_n16_view, usm_view, mask_n16);

scatter<float, 32, 2>(ptrf, ioffset_n16_view, usm_view);

simd<uint32_t, 10> ioffset_n10(byte_offset32, 8);
simd<float, 10> usm_n10;

// Special case: verify that the LLVM intrinsic is used when N is not a
// power of 2.
// CHECK-COUNT-1: call void @llvm.masked.scatter.v10f32.v10p4(<10 x float> {{[^)]+}}, <10 x ptr addrspace(4)> {{[^)]+}}, i32 4, <10 x i1> {{[^)]+}})
scatter(ptrf, ioffset_n10, usm_n10);
}

// CHECK-LABEL: define {{.*}} @_Z23test_slm_gather_scatter{{.*}}
@@ -1381,26 +1389,26 @@ test_slm_gather_scatter(int byte_offset32) {
// 3) slm_scatter(...): same as (1), (2) above, but with VS > 1.

// 1) slm_scatter(offsets): offsets is simd or simd_view
// CHECK-COUNT-4: call void @llvm.genx.scatter.scaled.v32i1.v32i32.v32f32(<32 x i1> {{[^)]+}}, i32 2, i16 0, i32 {{[^)]+}}, i32 {{[^)]+}}, <32 x i32> {{[^)]+}}, <32 x float> {{[^)]+}})
// CHECK-COUNT-4: call void @llvm.masked.scatter.v32f32.v32p3(<32 x float> {{[^)]+}}, <32 x ptr addrspace(3)> {{[^)]+}}, i32 4, <32 x i1> {{[^)]+}})
slm_scatter<float>(ioffset_n32, slm);
slm_scatter<float, 32>(ioffset_n32_view, slm);
slm_scatter<float, 32>(ioffset_n32, slm_view);
slm_scatter<float, 32>(ioffset_n32_view, slm_view);

// CHECK-COUNT-4: call void @llvm.genx.scatter.scaled.v32i1.v32i32.v32f32(<32 x i1> {{[^)]+}}, i32 2, i16 0, i32 {{[^)]+}}, i32 {{[^)]+}}, <32 x i32> {{[^)]+}}, <32 x float> {{[^)]+}})
// CHECK-COUNT-4: call void @llvm.masked.scatter.v32f32.v32p3(<32 x float> {{[^)]+}}, <32 x ptr addrspace(3)> {{[^)]+}}, i32 8, <32 x i1> {{[^)]+}})
slm_scatter<float>(ioffset_n32, slm, props_align8);
slm_scatter<float, 32>(ioffset_n32_view, slm, props_align8);
slm_scatter<float, 32>(ioffset_n32, slm_view, props_align8);
slm_scatter<float, 32>(ioffset_n32_view, slm_view, props_align8);

// 2) slm_scatter(offsets, mask): offsets is simd or simd_view
// CHECK-COUNT-4: call void @llvm.genx.scatter.scaled.v32i1.v32i32.v32f32(<32 x i1> {{[^)]+}}, i32 2, i16 0, i32 {{[^)]+}}, i32 {{[^)]+}}, <32 x i32> {{[^)]+}}, <32 x float> {{[^)]+}})
// CHECK-COUNT-4: call void @llvm.masked.scatter.v32f32.v32p3(<32 x float> {{[^)]+}}, <32 x ptr addrspace(3)> {{[^)]+}}, i32 4, <32 x i1> {{[^)]+}})
slm_scatter<float>(ioffset_n32, slm, mask_n32);
slm_scatter<float, 32>(ioffset_n32_view, slm, mask_n32);
slm_scatter<float, 32>(ioffset_n32, slm_view, mask_n32);
slm_scatter<float, 32>(ioffset_n32_view, slm_view, mask_n32);

// CHECK-COUNT-4: call void @llvm.genx.scatter.scaled.v32i1.v32i32.v32f32(<32 x i1> {{[^)]+}}, i32 2, i16 0, i32 {{[^)]+}}, i32 {{[^)]+}}, <32 x i32> {{[^)]+}}, <32 x float> {{[^)]+}})
// CHECK-COUNT-4: call void @llvm.masked.scatter.v32f32.v32p3(<32 x float> {{[^)]+}}, <32 x ptr addrspace(3)> {{[^)]+}}, i32 8, <32 x i1> {{[^)]+}})
slm_scatter<float>(ioffset_n32, slm, mask_n32, props_align8);
slm_scatter<float, 32>(ioffset_n32_view, slm, mask_n32, props_align8);
slm_scatter<float, 32>(ioffset_n32, slm_view, mask_n32, props_align8);
@@ -1429,4 +1437,11 @@ slm_scatter<float, 32, 2>(ioffset_n16_view, slm, mask_n16, props_align4);
slm_scatter<float, 32, 2>(ioffset_n16_view, slm, mask_n16, props_align4);
slm_scatter<float, 32, 2>(ioffset_n16, slm_view, mask_n16, props_align4);
slm_scatter<float, 32, 2>(ioffset_n16_view, slm_view, mask_n16, props_align4);

simd<uint32_t, 10> ioffset_n10(byte_offset32, 8);
simd<float, 10> usm_n10;
// Special case: verify that the LLVM intrinsic is used when N is not a
// power of 2.
// CHECK-COUNT-1: call void @llvm.masked.scatter.v10f32.v10p3(<10 x float> {{[^)]+}}, <10 x ptr addrspace(3)> {{[^)]+}}, i32 4, <10 x i1> {{[^)]+}})
slm_scatter(ioffset_n10, usm_n10);
}