Skip to content

Commit

Permalink
Allow cooperative matrix types for OpMatrixTimesScalar
Browse files Browse the repository at this point in the history
Forward translation to `OpMatrixTimesScalar` is not implemented even for
`OpTypeMatrix`, so this patch introduces only reverse translation.

Spec:
https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/KHR/SPV_KHR_cooperative_matrix.asciidoc
  • Loading branch information
vmaksimo committed Aug 14, 2023
1 parent 2bb5fed commit dacf02a
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 0 deletions.
6 changes: 6 additions & 0 deletions lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1961,6 +1961,12 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
IRBuilder<> Builder(BB);
auto *Scalar = transValue(MTS->getScalar(), F, BB);
auto *Matrix = transValue(MTS->getMatrix(), F, BB);

if (MTS->getMatrix()->getType()->isTypeCooperativeMatrixKHR()) {
return mapValue(BV, transSPIRVBuiltinFromInst(
static_cast<SPIRVInstruction *>(BV), BB));
}

uint64_t ColNum = Matrix->getType()->getArrayNumElements();
auto *ColType = cast<ArrayType>(Matrix->getType())->getElementType();
auto VecSize = cast<FixedVectorType>(ColType)->getNumElements();
Expand Down
3 changes: 3 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@ SPIRVType *SPIRVType::getScalarType() const {
return getVectorComponentType();
case OpTypeMatrix:
return getMatrixColumnType()->getVectorComponentType();
case OpTypeCooperativeMatrixKHR:
return static_cast<const SPIRVTypeCooperativeMatrixKHR *>(this)
->getCompType();
case OpTypeInt:
case OpTypeFloat:
case OpTypeBool:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
; RUN: llvm-spirv %s -to-binary -o %t.spv
; RUN: spirv-val %t.spv
; RUN: llvm-spirv -r %t.spv -o %t.bc
; RUN: llvm-dis < %t.bc | FileCheck %s --check-prefix=CHECK-LLVM

; CHECK-LLVM: %[[#Matrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructf(float 0.000000e+00)
; CHECK-LLVM: %[[#Scalar:]] = load float, ptr %rhs
; CHECK-LLVM: %3 = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z25__spirv_MatrixTimesScalarPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_3f(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %[[#Matrix]], float %[[#Scalar]])

119734787 65536 458752 24 0
2 Capability Addresses
2 Capability Linkage
2 Capability Kernel
2 Capability Float64
2 Capability Matrix
2 Capability CooperativeMatrixKHR
8 Extension "SPV_KHR_cooperative_matrix"
3 MemoryModel 2 2
8 EntryPoint 6 20 "matrix_times_scalar"
3 Source 3 102000
3 Name 14 "rhs"

4 TypeInt 19 32 0
4 Constant 19 21 3
4 Constant 19 22 12
2 TypeVoid 5
3 TypeFloat 6 32
4 Constant 6 23 0
7 TypeCooperativeMatrixKHR 8 6 21 22 22 21
4 TypePointer 10 7 6 ; 10 : Pointer to Scalar
4 TypeFunction 11 5 10

5 Function 5 20 0 11
3 FunctionParameter 10 14 ; rhs : Pointer to Scalar

2 Label 15
4 CompositeConstruct 8 16 23
4 Load 6 17 14

5 MatrixTimesScalar 8 18 16 17
1 Return

1 FunctionEnd

0 comments on commit dacf02a

Please sign in to comment.