Skip to content

Commit

Permalink
Fix AccessChain on cooperative matrix processing
Browse files Browse the repository at this point in the history
 Fixes an issue for sycl::half and sycl::bfloat16 types
  • Loading branch information
MrSidims authored and igcbot committed Aug 9, 2024
1 parent 537f3a1 commit 41bc660
Show file tree
Hide file tree
Showing 4 changed files with 378 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2013,15 +2013,19 @@ void JointMatrixFuncsResolutionPass::preprocessAccessChain(Function *F) {
// and more efficient to preprocess it and map on VectorInsertDynamic
// and VectorExtractDynamic here, before resolving and substituting matrix
// types
llvm::SmallPtrSet<llvm::Instruction *, 8> toErase;
SmallVector<Instruction*, 16> toErase;
SmallVector<std::pair<Instruction*, Instruction*>, 8> replaces;
for (auto It = F->use_begin(); It != F->use_end(); It++) {
auto user = It->getUser();
if (!isa<CallInst>(user))
continue;
auto *CI = cast<CallInst>(user);
IRBuilder<> builder(CI);
IGC_ASSERT_MESSAGE(CI->hasOneUse(),
"Unexpected matrix insert/extract format");
if (CI->getNumUses() == 0) {
toErase.push_back(CI);
continue;
}

Type *chainBaseTy =
IGCLLVM::getNonOpaquePtrEltTy(CI->getArgOperand(0)->getType());
IGC_ASSERT_MESSAGE(isMatrixType(chainBaseTy),
Expand All @@ -2030,53 +2034,86 @@ void JointMatrixFuncsResolutionPass::preprocessAccessChain(Function *F) {
Value *matrix = builder.CreateLoad(chainBaseTy, ptrToMatrix, "");
Value *index = CI->getArgOperand(1);

Instruction *memInst =
cast<Instruction>(IGCLLVM::getUniqueUndroppableUser(CI));
// Expected format is:
// %matrix_ptr = alloca %matrix_type
// (may be some casts)
// %element_ptr = __spirv_AccessChain(%matrix_ptr, %index)
// 1. For extract
// %element = load %element_ptr
// 2. For insert
// store %element %element_tr
IGC_ASSERT_MESSAGE(isa<LoadInst>(memInst) || isa<StoreInst>(memInst),
"Unexpected matrix insert/extract format");
// Get __spirv_AccessChain function's mangling postfix to reuse it for
// overloading of insert/extract
constexpr unsigned ACNameLenght = 23;
std::string funcPostfix = (F->getName().drop_front(ACNameLenght)).str();
if (isa<LoadInst>(memInst)) {
std::vector<Value *> Args = { matrix, index };
FunctionType *funcType = FunctionType::get(
memInst->getType(), { matrix->getType(), index->getType() },
false);
std::string funcName = std::string(SPIRVPrefix) +
std::string(JointMatrixSliceExtract) +
funcPostfix;
Instruction *extractCall = builder.CreateCall(
F->getParent()->getOrInsertFunction(funcName, funcType), Args);
memInst->replaceAllUsesWith(extractCall);
}
if (isa<StoreInst>(memInst)) {
Value *component = cast<StoreInst>(memInst)->getValueOperand();
std::vector<Value *> Args = { matrix, component, index };
FunctionType *funcType = FunctionType::get(
chainBaseTy, { matrix->getType(), component->getType(),
index->getType() }, false);
std::string funcName = std::string(SPIRVPrefix) +
std::string(JointMatrixSliceInsert) +
funcPostfix;
Instruction *extractCall =
builder.CreateCall(F->getParent()->getOrInsertFunction(
funcName, funcType), Args);
builder.CreateStore(extractCall, ptrToMatrix);
for (const auto &U : CI->users()) {
Instruction *memInst = dyn_cast<Instruction>(U);
if (!memInst)
continue;
// In case of sycl::half and sycl::bfloat16 storage will be accessed
// via zero GEP or pointer cast that we need to strip, for example:
// %call = call spir_func %structtype addrspace(4)* @__spirv_AccessChain(
// %spirv.CooperativeMatrixKHR._half_3_8_16_0 addrspace(1)* addrspace(4)* %matrix,
// i64 %idx)
// %gep = getelementptr inbounds %structtype, %structtype addrspace(4)* %call, i64 0, i32 0
// %cast = bitcast i16 addrspace(4)* %gep to half addrspace(4)*
// %extract = load half, half addrspace(4)* %cast
while (isa<GetElementPtrInst>(memInst) ||
isa<BitCastInst>(memInst) ||
isa<AddrSpaceCastInst>(memInst)) {
if (isa<GetElementPtrInst>(memInst))
IGC_ASSERT_MESSAGE(
cast<GetElementPtrInst>(memInst)->hasAllZeroIndices(),
"Unexpected matrix insert/extract format");
toErase.push_back(memInst);
memInst = cast<Instruction>(
IGCLLVM::getUniqueUndroppableUser(memInst));
}
// Expected format is:
// %matrix_ptr = alloca %matrix_type
// (may be some casts)
// %element_ptr = __spirv_AccessChain(%matrix_ptr, %index)
// 1. For extract
// %element = load %element_ptr
// 2. For insert
// store %element %element_ptr
IGC_ASSERT_MESSAGE(
isa<LoadInst>(memInst) || isa<StoreInst>(memInst),
"Unexpected matrix insert/extract format");
// Get __spirv_AccessChain function's mangling postfix to reuse it
// for overloading of insert/extract
constexpr unsigned ACNameLength = 23; // "_Z19__spirv_AccessChain"
std::string funcPostfix =
(F->getName().drop_front(ACNameLength)).str();
builder.SetInsertPoint(memInst);
if (isa<LoadInst>(memInst)) {
std::vector<Value *> Args = { matrix, index };
FunctionType *funcType = FunctionType::get(
memInst->getType(), { matrix->getType(), index->getType() },
false);
std::string funcName = std::string("_Z28") +
std::string(SPIRVPrefix) +
std::string(JointMatrixSliceExtract) +
funcPostfix;
Instruction *extractCall = builder.CreateCall(
F->getParent()->getOrInsertFunction(funcName, funcType),
Args);
replaces.push_back(std::make_pair(memInst, extractCall));
}
if (isa<StoreInst>(memInst)) {
Value *component = cast<StoreInst>(memInst)->getValueOperand();
std::vector<Value *> Args = { matrix, component, index };
FunctionType *funcType = FunctionType::get(
chainBaseTy, { matrix->getType(), component->getType(),
index->getType() }, false);
std::string funcName = std::string("_Z27") +
std::string(SPIRVPrefix) +
std::string(JointMatrixSliceInsert) +
funcPostfix;
Instruction *extractCall =
builder.CreateCall(F->getParent()->getOrInsertFunction(
funcName, funcType), Args);
builder.CreateStore(extractCall, ptrToMatrix);
}
toErase.push_back(memInst);
}
toErase.insert(memInst);
toErase.insert(CI);
toErase.push_back(CI);
}
for (Instruction *I : toErase) {
for (const auto &InstPair : replaces) {
InstPair.first->replaceAllUsesWith(InstPair.second);
}
for (Instruction *I : llvm::reverse(toErase)) {
I->dropAllReferences();
}
for (Instruction *I : toErase) {
I->eraseFromParent();
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
;=========================== begin_copyright_notice ============================
;
; Copyright (C) 2024 Intel Corporation
;
; SPDX-License-Identifier: MIT
;
; This software and the related documents are Intel copyrighted materials,
; and your use of them is governed by the express license under which they were
; provided to you ("License"). Unless the License provides otherwise,
; you may not use, modify, copy, publish, distribute, disclose or transmit this
; software or the related documents without Intel's prior written permission.
;
; This software and the related documents are provided as is, with no express or
; implied warranties, other than those that are expressly stated in the License.
;
;============================ end_copyright_notice =============================
; REQUIRES: llvm-14-plus

; RUN: igc_opt -platformpvc -igc-joint-matrix-resolution -S 2>&1 < %s | FileCheck %s
; ------------------------------------------------
; JointMatrixFuncsResolutionPass
; ------------------------------------------------
; Checks that an unused __spirv_AccessChain function call is removed and does
; not cause any problems during resolution

; CHECK-NOT: call spir_func float addrspace(4)* @_Z19__spirv_AccessChain

target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v24:32:32-v32:32:32-v48:64:64-v64:64:64-v96:128:128-v128:128:128-v192:256:256-v256:256:256-v512:512:512-v1024:1024:1024-n8:16:32"
target triple = "spir64-unknown-unknown"

%spirv.CooperativeMatrixKHR._float_3_8_16_0 = type opaque
%"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix" = type { %spirv.CooperativeMatrixKHR._float_3_8_16_0 addrspace(1)* }

; Function Attrs: nounwind
; Test kernel: builds a cooperative matrix, stores its handle to a stack slot and
; calls __spirv_AccessChain on that slot without ever using the returned pointer.
; The pass under test is expected to remove the unused call.
define spir_kernel void @_ZTS5logicILm8ELm16EE() {
entry:
; Stack slot that holds the matrix handle.
%0 = alloca %spirv.CooperativeMatrixKHR._float_3_8_16_0 addrspace(1)*
; Create the matrix from the scalar 0.0 and save the handle into the slot.
%1 = call spir_func %spirv.CooperativeMatrixKHR._float_3_8_16_0 addrspace(1)* @_Z26__spirv_CompositeConstructf(float 0.0)
store %spirv.CooperativeMatrixKHR._float_3_8_16_0 addrspace(1)* %1, %spirv.CooperativeMatrixKHR._float_3_8_16_0 addrspace(1)** %0
; Access chain with index 4; %call has no users in this function.
%call = call spir_func float addrspace(4)* @_Z19__spirv_AccessChainPU3AS4PU3AS143__spirv_CooperativeMatrixKHR__float_3_8_16_0l(%spirv.CooperativeMatrixKHR._float_3_8_16_0 addrspace(1)** %0, i64 4)
ret void
}

; Function Attrs: nounwind
declare spir_func %spirv.CooperativeMatrixKHR._float_3_8_16_0 addrspace(1)* @_Z26__spirv_CompositeConstructf(float %0)

; Function Attrs: nounwind
declare spir_func float addrspace(4)* @_Z19__spirv_AccessChainPU3AS4PU3AS143__spirv_CooperativeMatrixKHR__float_3_8_16_0l(%spirv.CooperativeMatrixKHR._float_3_8_16_0 addrspace(1)** %0, i64 %1)

!spirv.MemoryModel = !{!0}
!spirv.Source = !{!1}
!spirv.Generator = !{!2}
!igc.functions = !{!3}

!0 = !{i32 2, i32 2}
!1 = !{i32 4, i32 100000}
!2 = !{i16 6, i16 14}
!3 = !{void ()* @_ZTS5logicILm8ELm16EE, !4}
!4 = !{!5}
!5 = !{!"function_type", i32 0}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
;=========================== begin_copyright_notice ============================
;
; Copyright (C) 2024 Intel Corporation
;
; SPDX-License-Identifier: MIT
;
; This software and the related documents are Intel copyrighted materials,
; and your use of them is governed by the express license under which they were
; provided to you ("License"). Unless the License provides otherwise,
; you may not use, modify, copy, publish, distribute, disclose or transmit this
; software or the related documents without Intel's prior written permission.
;
; This software and the related documents are provided as is, with no express or
; implied warranties, other than those that are expressly stated in the License.
;
;============================ end_copyright_notice =============================
; REQUIRES: llvm-14-plus

; RUN: igc_opt -platformpvc -igc-joint-matrix-resolution -S 2>&1 < %s | FileCheck %s
; ------------------------------------------------
; JointMatrixFuncsResolutionPass
; ------------------------------------------------
; Checks the case where the result of a __spirv_AccessChain function call has
; multiple uses (a load plus a store): it must result in an extract followed by
; an insert of an element into the matrix's slice

; CHECK: [[SLICE:%.*]] = load <8 x i16>, <8 x i16>* %{{.*}}, align 8
; CHECK: [[ELEMENT:%.*]] = extractelement <8 x i16> [[SLICE]], i64 4, !joint_matrix_apply
; CHECK: [[ADD:%.*]] = add i16 [[ELEMENT]], 1
; CHECK: [[INSERT:%.*]] = insertelement <8 x i16> [[SLICE]], i16 [[ADD]], i64 4
; CHECK: store <8 x i16> [[INSERT]], <8 x i16>* %{{.*}}, align 8

target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v24:32:32-v32:32:32-v48:64:64-v64:64:64-v96:128:128-v128:128:128-v192:256:256-v256:256:256-v512:512:512-v1024:1024:1024-n8:16:32"
target triple = "spir64-unknown-unknown"

%spirv.CooperativeMatrixKHR._short_3_8_16_0 = type opaque
%"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix" = type { %spirv.CooperativeMatrixKHR._short_3_8_16_0 addrspace(1)* }

; Function Attrs: nounwind
; Test kernel: loads a cooperative matrix from %arg, then reads, increments and
; writes back one element through a single __spirv_AccessChain pointer, so the
; access chain result has two users (a load and a store).
define spir_kernel void @_ZTS5logicILm8ELm16EE(i16 addrspace(1)* %arg) {
entry:
; Stack slot that holds the matrix handle.
%0 = alloca %spirv.CooperativeMatrixKHR._short_3_8_16_0 addrspace(1)*
; Load the matrix from %arg and save the handle into the slot.
%1 = call spir_func %spirv.CooperativeMatrixKHR._short_3_8_16_0 addrspace(1)* @_Z86__spirv_CooperativeMatrixLoadKHR_RPU3AS143__spirv_CooperativeMatrixKHR__short_3_8_16_0PU3AS1slii(i16 addrspace(1)* %arg, i32 0, i64 64, i32 0)
store %spirv.CooperativeMatrixKHR._short_3_8_16_0 addrspace(1)* %1, %spirv.CooperativeMatrixKHR._short_3_8_16_0 addrspace(1)** %0
; Pointer to the element at index 4; used by both the load and the store below.
%ptr = call spir_func i16 addrspace(4)* @_Z19__spirv_AccessChainPU3AS4PU3AS143__spirv_CooperativeMatrixKHR._short_3_8_16_0l(%spirv.CooperativeMatrixKHR._short_3_8_16_0 addrspace(1)** %0, i64 4)
; First use: read the element (expected to become an element extract).
%extract = load i16, i16 addrspace(4)* %ptr
%add = add i16 %extract, 1
; Second use: write the incremented element back (expected to become an element insert).
store i16 %add, i16 addrspace(4)* %ptr
ret void
}

; Function Attrs: nounwind
declare spir_func %spirv.CooperativeMatrixKHR._short_3_8_16_0 addrspace(1)* @_Z86__spirv_CooperativeMatrixLoadKHR_RPU3AS143__spirv_CooperativeMatrixKHR__short_3_8_16_0PU3AS1slii(i16 addrspace(1)* %0, i32 %1, i64 %2, i32 %3)

; Function Attrs: nounwind
declare spir_func i16 addrspace(4)* @_Z19__spirv_AccessChainPU3AS4PU3AS143__spirv_CooperativeMatrixKHR._short_3_8_16_0l(%spirv.CooperativeMatrixKHR._short_3_8_16_0 addrspace(1)** %0, i64 %1)

!spirv.MemoryModel = !{!0}
!spirv.Source = !{!1}
!spirv.Generator = !{!2}
!igc.functions = !{!3}

!0 = !{i32 2, i32 2}
!1 = !{i32 4, i32 100000}
!2 = !{i16 6, i16 14}
!3 = !{void (i16 addrspace(1)*)* @_ZTS5logicILm8ELm16EE, !4}
!4 = !{!5}
!5 = !{!"function_type", i32 0}
Loading

0 comments on commit 41bc660

Please sign in to comment.