Skip to content

Commit

Permalink
[SYCL][ESIMD] Use LLVM IR for USM/SLM scatter (#12628)
Browse files Browse the repository at this point in the history
  • Loading branch information
fineg74 committed Feb 8, 2024
1 parent 573e28b commit 7ee7e90
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 14 deletions.
38 changes: 38 additions & 0 deletions llvm/lib/SYCLLowerIR/ESIMD/LowerESIMD.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1004,6 +1004,37 @@ static void translateGatherLoad(CallInst &CI, bool IsSLM) {
CI.replaceAllUsesWith(LI);
}

// Lowers __esimd_scatter_st / __esimd_slm_scatter_st calls into the standard
// 'llvm.masked.scatter' LLVM IR intrinsic.
// Call operands: 0 - vector of values to store, 1 - vector of integer
// addresses (byte offsets into SLM when \p IsSLM is true), 2 - store mask
// as <N x i16>. The store alignment is carried as a template argument of the
// intrinsic, not as a call operand.
// The caller is responsible for erasing \p CI afterwards.
static void translateScatterStore(CallInst &CI, bool IsSLM) {
  IRBuilder<> Builder(&CI);

  // The alignment is template argument #2 of the ESIMD intrinsic.
  constexpr int AlignmentTemplateArgIdx = 2;
  APInt Val = parseTemplateArg(CI, AlignmentTemplateArgIdx,
                               ESIMDIntrinDesc::GenXArgConversion::TO_I64);
  Align AlignValue(Val.getZExtValue());

  auto ValsOp = CI.getArgOperand(0);
  auto OffsetsOp = CI.getArgOperand(1);
  auto MaskOp = CI.getArgOperand(2);
  auto DataType = ValsOp->getType();

  // Convert the mask from <N x i16> to the <N x i1> form required by
  // llvm.masked.scatter: a lane is stored iff its mask element is non-zero.
  Value *Zero = ConstantInt::get(MaskOp->getType(), 0);
  MaskOp = Builder.CreateICmp(ICmpInst::ICMP_NE, MaskOp, Zero);

  // The address space may be 3 (SLM), 1 (global) or private.
  // At the moment of calling the 'scatter()' operation the pointer passed to
  // it is already 4 (generic). Thus, simply use 4 (generic) for global and
  // private and let the GPU BE deduce the actual address space from the
  // use-def graph.
  unsigned AS = IsSLM ? 3 : 4;
  auto ElemType = DataType->getScalarType();
  auto NumElems = (cast<VectorType>(DataType))->getElementCount();
  auto VPtrType = VectorType::get(PointerType::get(ElemType, AS), NumElems);
  auto VPtrOp = Builder.CreateIntToPtr(OffsetsOp, VPtrType);

  auto SI = Builder.CreateMaskedScatter(ValsOp, VPtrOp, AlignValue, MaskOp);
  SI->setDebugLoc(CI.getDebugLoc());
  // Note: no replaceAllUsesWith() here — both the original intrinsic call and
  // the masked-scatter are void-typed, so CI cannot have any uses; the
  // previous RAUW call was a no-op and has been dropped.
}

// TODO Specify document behavior for slm_init and nbarrier_init when:
// 1) they are called not from kernels
// 2) there are multiple such calls reachable from a kernel
Expand Down Expand Up @@ -1987,6 +2018,13 @@ size_t SYCLLowerESIMDPass::runOnFunction(Function &F,
continue;
}

if (Name.starts_with("__esimd_scatter_st") ||
Name.starts_with("__esimd_slm_scatter_st")) {
translateScatterStore(*CI, Name.starts_with("__esimd_slm_scatter_st"));
ToErase.push_back(CI);
continue;
}

if (Name.starts_with("__esimd_nbarrier_init")) {
translateNbarrierInit(*CI);
ToErase.push_back(CI);
Expand Down
14 changes: 14 additions & 0 deletions sycl/include/sycl/ext/intel/esimd/detail/memory_intrin.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,20 @@ __ESIMD_INTRIN __ESIMD_DNS::vector_type_t<T, N> __esimd_slm_gather_ld(
__ESIMD_DNS::simd_mask_storage_t<N> pred,
__ESIMD_DNS::vector_type_t<T, N> pass_thru) __ESIMD_INTRIN_END;

// Scatter data to given global or private addresses.
// Translated by the lowering pass into the standard 'llvm.masked.scatter'
// LLVM IR intrinsic.
// Template parameters:
//   T     - element type of the stored values;
//   N     - number of elements (lanes) in vals/vptr/pred;
//   Align - assumed alignment, in bytes, of each store; it is read from the
//           template arguments by the lowering pass rather than passed as a
//           call operand.
// Arguments:
//   vals - N values to store;
//   vptr - N 64-bit addresses to store the corresponding values to;
//   pred - per-lane store mask; a lane is stored only when its mask element
//          is non-zero.
template <typename T, int N, size_t Align>
__ESIMD_INTRIN void
__esimd_scatter_st(__ESIMD_DNS::vector_type_t<T, N> vals,
__ESIMD_DNS::vector_type_t<uint64_t, N> vptr,
__ESIMD_DNS::simd_mask_storage_t<N> pred) __ESIMD_INTRIN_END;

// Scatter data to given SLM addresses.
// Same contract as __esimd_scatter_st above, except that 'vptr' holds 32-bit
// byte offsets into shared local memory (SLM) instead of 64-bit addresses.
template <typename T, int N, size_t Align>
__ESIMD_INTRIN void __esimd_slm_scatter_st(
__ESIMD_DNS::vector_type_t<T, N> vals,
__ESIMD_DNS::vector_type_t<uint32_t, N> vptr,
__ESIMD_DNS::simd_mask_storage_t<N> pred) __ESIMD_INTRIN_END;

/// Surface-based gather.
/// Supported platforms: DG2, PVC
///
Expand Down
19 changes: 17 additions & 2 deletions sycl/include/sycl/ext/intel/esimd/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -724,11 +724,20 @@ scatter(T *p, simd<OffsetT, N / VS> byte_offsets, simd<T, N> vals,

// Use LSC lowering if L1/L2 or VS > 1.
if constexpr (L1Hint != cache_hint::none || L2Hint != cache_hint::none ||
VS > 1 || !__ESIMD_DNS::isPowerOf2(N, 32)) {
VS > 1 ||
(!__ESIMD_DNS::isPowerOf2(N, 32) &&
!detail::isMaskedGatherScatterLLVMAvailable())) {
static_assert(VS == 1 || sizeof(T) >= 4,
"VS > 1 is supported only for 4- and 8-byte elements");
return detail::scatter_impl<T, VS, detail::lsc_data_size::default_size,
L1Hint, L2Hint>(p, byte_offsets, vals, mask);
} else if constexpr (detail::isMaskedGatherScatterLLVMAvailable()) {
simd<uint64_t, N> Addrs(reinterpret_cast<uint64_t>(p));
Addrs = Addrs + convert<uint64_t>(byte_offsets);
using MsgT = detail::__raw_t<T>;
__esimd_scatter_st<MsgT, N, Alignment>(
sycl::bit_cast<__ESIMD_DNS::vector_type_t<MsgT, N>>(vals.data()),
Addrs.data(), mask.data());
} else {
using Tx = detail::__raw_t<T>;
simd<uint64_t, N> byte_offsets_i = convert<uint64_t>(byte_offsets);
Expand Down Expand Up @@ -4227,9 +4236,15 @@ slm_scatter(simd<uint32_t, N / VS> byte_offsets, simd<T, N> vals,
"slm_scatter() requires at least element-size alignment");

// Use LSC lowering if VS > 1.
if constexpr (VS > 1 || !(detail::isPowerOf2(N, 32) && sizeof(T) <= 4)) {
if constexpr (VS > 1 || (!(detail::isPowerOf2(N, 32) && sizeof(T) <= 4) &&
!detail::isMaskedGatherScatterLLVMAvailable())) {
__ESIMD_DNS::slm_scatter_impl<T, VS, detail::lsc_data_size::default_size>(
byte_offsets, vals, mask);
} else if constexpr (detail::isMaskedGatherScatterLLVMAvailable()) {
using MsgT = detail::__raw_t<T>;
__esimd_slm_scatter_st<MsgT, N, Alignment>(
sycl::bit_cast<__ESIMD_DNS::vector_type_t<MsgT, N>>(vals.data()),
byte_offsets.data(), mask.data());
} else {
detail::LocalAccessorMarker acc;
detail::scatter_impl<T, N>(acc, vals, byte_offsets, 0, mask);
Expand Down
6 changes: 3 additions & 3 deletions sycl/test-e2e/ESIMD/unified_memory_api/scatter_usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===------------------------------------------------------------------===//
// RUN: %{build} -fsycl-device-code-split=per_kernel -o %t.out
// RUN: %{build} -fsycl-device-code-split=per_kernel -D__ESIMD_GATHER_SCATTER_LLVM_IR -o %t.out
// RUN: %{run} %t.out

// The test verifies esimd::scatter() functions accepting USM pointer
// and optional compile-time esimd::properties.
// The scatter() calls in this test do not use cache-hint
// properties to not impose using DG2/PVC features.
// The scatter() calls in this test do not use cache-hint properties
// or VS > 1 (number of stores per offset) to not impose using PVC features.

#include "Inputs/scatter.hpp"

Expand Down
21 changes: 21 additions & 0 deletions sycl/test-e2e/ESIMD/unified_memory_api/scatter_usm_legacy.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//==------- scatter_usm_legacy.cpp - DPC++ ESIMD on-device test -----------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Use per-kernel compilation to have more information about failing cases.
// RUN: %{build} -fsycl-device-code-split=per_kernel -o %t.out
// RUN: %{run} %t.out

// The test verifies esimd::scatter() functions accepting USM pointer
// and optional compile-time esimd::properties.
// The scatter() calls in this test do not use cache-hint properties
// or VS > 1 (number of stores per offset) to not impose using PVC features.
//
// TODO: Remove this test when GPU driver issue with llvm.masked.scatter is
// resolved and ESIMD starts using llvm.masked.scatter by default.
// "-D__ESIMD_GATHER_SCATTER_LLVM_IR" is not used here.

#include "scatter_usm.cpp"
8 changes: 4 additions & 4 deletions sycl/test-e2e/ESIMD/unified_memory_api/slm_scatter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===------------------------------------------------------------------===//
// RUN: %{build} -fsycl-device-code-split=per_kernel -o %t.out
// RUN: %{build} -fsycl-device-code-split=per_kernel -D__ESIMD_GATHER_SCATTER_LLVM_IR -o %t.out
// RUN: %{run} %t.out

// The test verifies esimd::slm_scatter() functions accepting
// optional compile-time esimd::properties.
// The scatter() calls in this test do not use DG2/PVC features.
// The test verifies esimd::slm_scatter() functions accepting optional
// compile-time esimd::properties. The slm_scatter() calls in this test do not
// use VS > 1 (number of stores per offset) to not impose using PVC features.

#include "Inputs/scatter.hpp"

Expand Down
20 changes: 20 additions & 0 deletions sycl/test-e2e/ESIMD/unified_memory_api/slm_scatter_legacy.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//==------- slm_scatter_legacy.cpp - DPC++ ESIMD on-device test -----------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Use per-kernel compilation to have more information about failing cases.
// RUN: %{build} -fsycl-device-code-split=per_kernel -o %t.out
// RUN: %{run} %t.out

// The test verifies esimd::slm_scatter() functions accepting optional
// compile-time esimd::properties. The slm_scatter() calls in this test do not
// use VS > 1 (number of stores per offset) to not impose using PVC features.
//
// TODO: Remove this test when GPU driver issue with llvm.masked.scatter is
// resolved and ESIMD starts using llvm.masked.scatter by default.
// "-D__ESIMD_GATHER_SCATTER_LLVM_IR" is not used here.

#include "slm_scatter.cpp"
25 changes: 20 additions & 5 deletions sycl/test/esimd/memory_properties.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1224,7 +1224,7 @@ test_gather_scatter(AccType &acc, LocalAccType &local_acc, float *ptrf,
acc_res = gather<float, 32>(local_acc, ioffset_n32, 0);
acc_res = gather<float, 32>(local_acc, ioffset_n32, 0, mask_n32);

// CHECK-COUNT-4: call void @llvm.genx.svm.scatter.v32i1.v32i64.v32f32(<32 x i1> {{[^)]+}}, i32 0, <32 x i64> {{[^)]+}}, <32 x float> {{[^)]+}})
// CHECK-COUNT-4: call void @llvm.masked.scatter.v32f32.v32p4(<32 x float> {{[^)]+}}, <32 x ptr addrspace(4)> {{[^)]+}}, i32 4, <32 x i1> {{[^)]+}})
scatter(ptrf, ioffset_n32, usm, mask_n32);

scatter(ptrf, ioffset_n32, usm);
Expand Down Expand Up @@ -1281,6 +1281,14 @@ test_gather_scatter(AccType &acc, LocalAccType &local_acc, float *ptrf,
scatter<float, 32, 2>(ptrf, ioffset_n16_view, usm_view, mask_n16);

scatter<float, 32, 2>(ptrf, ioffset_n16_view, usm_view);

simd<uint32_t, 10> ioffset_n10(byte_offset32, 8);
simd<float, 10> usm_n10;

// Check the special case: when N is not a power of 2, the LLVM
// masked-scatter intrinsic is used.
// CHECK-COUNT-1: call void @llvm.masked.scatter.v10f32.v10p4(<10 x float> {{[^)]+}}, <10 x ptr addrspace(4)> {{[^)]+}}, i32 4, <10 x i1> {{[^)]+}})
scatter(ptrf, ioffset_n10, usm_n10);
}

// CHECK-LABEL: define {{.*}} @_Z23test_slm_gather_scatter{{.*}}
Expand Down Expand Up @@ -1381,26 +1389,26 @@ test_slm_gather_scatter(int byte_offset32) {
// 3) slm_scatter(...): same as (1), (2) above, but with VS > 1.

// 1) slm_scatter(offsets): offsets is simd or simd_view
// CHECK-COUNT-4: call void @llvm.genx.scatter.scaled.v32i1.v32i32.v32f32(<32 x i1> {{[^)]+}}, i32 2, i16 0, i32 {{[^)]+}}, i32 {{[^)]+}}, <32 x i32> {{[^)]+}}, <32 x float> {{[^)]+}})
// CHECK-COUNT-4: call void @llvm.masked.scatter.v32f32.v32p3(<32 x float> {{[^)]+}}, <32 x ptr addrspace(3)> {{[^)]+}}, i32 4, <32 x i1> {{[^)]+}})
slm_scatter<float>(ioffset_n32, slm);
slm_scatter<float, 32>(ioffset_n32_view, slm);
slm_scatter<float, 32>(ioffset_n32, slm_view);
slm_scatter<float, 32>(ioffset_n32_view, slm_view);

// CHECK-COUNT-4: call void @llvm.genx.scatter.scaled.v32i1.v32i32.v32f32(<32 x i1> {{[^)]+}}, i32 2, i16 0, i32 {{[^)]+}}, i32 {{[^)]+}}, <32 x i32> {{[^)]+}}, <32 x float> {{[^)]+}})
// CHECK-COUNT-4: call void @llvm.masked.scatter.v32f32.v32p3(<32 x float> {{[^)]+}}, <32 x ptr addrspace(3)> {{[^)]+}}, i32 8, <32 x i1> {{[^)]+}})
slm_scatter<float>(ioffset_n32, slm, props_align8);
slm_scatter<float, 32>(ioffset_n32_view, slm, props_align8);
slm_scatter<float, 32>(ioffset_n32, slm_view, props_align8);
slm_scatter<float, 32>(ioffset_n32_view, slm_view, props_align8);

// 2) slm_gather(offsets, mask): offsets is simd or simd_view
// CHECK-COUNT-4: call void @llvm.genx.scatter.scaled.v32i1.v32i32.v32f32(<32 x i1> {{[^)]+}}, i32 2, i16 0, i32 {{[^)]+}}, i32 {{[^)]+}}, <32 x i32> {{[^)]+}}, <32 x float> {{[^)]+}})
// CHECK-COUNT-4: call void @llvm.masked.scatter.v32f32.v32p3(<32 x float> {{[^)]+}}, <32 x ptr addrspace(3)> {{[^)]+}}, i32 4, <32 x i1> {{[^)]+}})
slm_scatter<float>(ioffset_n32, slm, mask_n32);
slm_scatter<float, 32>(ioffset_n32_view, slm, mask_n32);
slm_scatter<float, 32>(ioffset_n32, slm_view, mask_n32);
slm_scatter<float, 32>(ioffset_n32_view, slm_view, mask_n32);

// CHECK-COUNT-4: call void @llvm.genx.scatter.scaled.v32i1.v32i32.v32f32(<32 x i1> {{[^)]+}}, i32 2, i16 0, i32 {{[^)]+}}, i32 {{[^)]+}}, <32 x i32> {{[^)]+}}, <32 x float> {{[^)]+}})
// CHECK-COUNT-4: call void @llvm.masked.scatter.v32f32.v32p3(<32 x float> {{[^)]+}}, <32 x ptr addrspace(3)> {{[^)]+}}, i32 8, <32 x i1> {{[^)]+}})
slm_scatter<float>(ioffset_n32, slm, mask_n32, props_align8);
slm_scatter<float, 32>(ioffset_n32_view, slm, mask_n32, props_align8);
slm_scatter<float, 32>(ioffset_n32, slm_view, mask_n32, props_align8);
Expand Down Expand Up @@ -1429,4 +1437,11 @@ test_slm_gather_scatter(int byte_offset32) {
slm_scatter<float, 32, 2>(ioffset_n16_view, slm, mask_n16, props_align4);
slm_scatter<float, 32, 2>(ioffset_n16, slm_view, mask_n16, props_align4);
slm_scatter<float, 32, 2>(ioffset_n16_view, slm_view, mask_n16, props_align4);

simd<uint32_t, 10> ioffset_n10(byte_offset32, 8);
simd<float, 10> usm_n10;
// Check the special case: when N is not a power of 2, the LLVM
// masked-scatter intrinsic is used.
// CHECK-COUNT-1: call void @llvm.masked.scatter.v10f32.v10p3(<10 x float> {{[^)]+}}, <10 x ptr addrspace(3)> {{[^)]+}}, i32 4, <10 x i1> {{[^)]+}})
slm_scatter(ioffset_n10, usm_n10);
}

0 comments on commit 7ee7e90

Please sign in to comment.