Skip to content

Commit

Permalink
[Bckport to 14] Add JointMatrixGetElementCoordINTEL instruction
Browse files Browse the repository at this point in the history
The instruction returns (Row, Column) coordinate of dynamically selected
element of a matrix

Updated version of the spec is here
intel/llvm#8175

Instruction correctness checks will be added later among non-backward
compatible changes.

Signed-off-by: Sidorov, Dmitry dmitry.sidorov@intel.com
Signed-off-by: Sidorov, Dmitry <dmitry.sidorov@intel.com>
  • Loading branch information
MrSidims committed Nov 20, 2023
1 parent 3b16190 commit 7b29f08
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 6 deletions.
2 changes: 2 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,8 @@ template <> inline void SPIRVMap<SPIRVCapabilityKind, SPIRVCapVec>::init() {
{CapabilitySubgroupAvcMotionEstimationINTEL});
ADD_VEC_INIT(CapabilitySubgroupAvcMotionEstimationChromaINTEL,
{CapabilitySubgroupAvcMotionEstimationIntraINTEL});
ADD_VEC_INIT(internal::CapabilityJointMatrixWIInstructionsINTEL,
{internal::CapabilityJointMatrixINTEL});
}

template <> inline void SPIRVMap<SPIRVExecutionModelKind, SPIRVCapVec>::init() {
Expand Down
15 changes: 15 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -3375,6 +3375,7 @@ _SPIRV_OP(JointMatrixMad, true, 7)
_SPIRV_OP(JointMatrixSUMad, true, 7)
_SPIRV_OP(JointMatrixUSMad, true, 7)
_SPIRV_OP(JointMatrixUUMad, true, 7)
// TODO: move to SPIRVJointMatrixINTELWorkItemInst
_SPIRV_OP(JointMatrixWorkItemLength, true, 4)
#undef _SPIRV_OP

Expand All @@ -3398,6 +3399,20 @@ _SPIRV_OP(CooperativeMatrixLengthKHR, true, 4, false)
_SPIRV_OP(CooperativeMatrixMulAddKHR, true, 6, true, 3)
#undef _SPIRV_OP

class SPIRVJointMatrixINTELWorkItemInst : public SPIRVJointMatrixINTELInstBase {
protected:
SPIRVCapVec getRequiredCapability() const override {
return getVec(internal::CapabilityJointMatrixWIInstructionsINTEL);
}
};

#define _SPIRV_OP(x, ...) \
typedef SPIRVInstTemplate<SPIRVJointMatrixINTELWorkItemInst, \
internal::Op##x##INTEL, __VA_ARGS__> \
SPIRV##x##INTEL;
_SPIRV_OP(JointMatrixGetElementCoord, true, 5)
#undef _SPIRV_OP

class SPIRVSplitBarrierINTELBase : public SPIRVInstTemplateBase {
protected:
SPIRVCapVec getRequiredCapability() const override {
Expand Down
2 changes: 2 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,8 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
add(internal::CapabilityTensorFloat32RoundingINTEL,
"TensorFloat32RoundingINTEL");
add(internal::CapabilityCacheControlsINTEL, "CacheControlsINTEL");
add(internal::CapabilityJointMatrixWIInstructionsINTEL,
"JointMatrixWIInstructionsINTEL");
}
SPIRV_DEF_NAMEMAP(Capability, SPIRVCapabilityNameMap)

Expand Down
2 changes: 2 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ _SPIRV_OP_INTERNAL(JointMatrixUSMadINTEL, internal::OpJointMatrixUSMadINTEL)
_SPIRV_OP_INTERNAL(JointMatrixUUMadINTEL, internal::OpJointMatrixUUMadINTEL)
_SPIRV_OP_INTERNAL(JointMatrixWorkItemLengthINTEL,
internal::OpJointMatrixWorkItemLengthINTEL)
_SPIRV_OP_INTERNAL(JointMatrixGetElementCoordINTEL,
internal::OpJointMatrixGetElementCoordINTEL)
_SPIRV_OP_INTERNAL(ComplexFMulINTEL, internal::ComplexFMulINTEL)
_SPIRV_OP_INTERNAL(ComplexFDivINTEL, internal::ComplexFDivINTEL)
_SPIRV_OP_INTERNAL(MaskedGatherINTEL, internal::OpMaskedGatherINTEL)
Expand Down
5 changes: 5 additions & 0 deletions lib/SPIRV/libSPIRV/spirv_internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ enum InternalOp {
IOpRoundFToTF32INTEL = 6426,
IOpMaskedGatherINTEL = 6428,
IOpMaskedScatterINTEL = 6429,
IOpJointMatrixGetElementCoordINTEL = 6440,
IOpPrev = OpMax - 2,
IOpForward
};
Expand Down Expand Up @@ -109,6 +110,7 @@ enum InternalCapability {
ICapabilityComplexFloatMulDivINTEL = 6414,
ICapabilityTensorFloat32RoundingINTEL = 6425,
ICapabilityMaskedGatherScatterINTEL = 6427,
ICapabilityJointMatrixWIInstructionsINTEL = 6435,
ICapabilityCacheControlsINTEL = 6441
};

Expand Down Expand Up @@ -155,6 +157,7 @@ enum class StoreCacheControlINTEL {

#define _SPIRV_OP(x, y) constexpr x x##y = static_cast<x>(I##x##y);
_SPIRV_OP(Capability, JointMatrixINTEL)
_SPIRV_OP(Capability, JointMatrixWIInstructionsINTEL)

Check warning on line 160 in lib/SPIRV/libSPIRV/spirv_internal.hpp

View workflow job for this annotation

GitHub Actions / clang-format & clang-tidy

unused variable 'CapabilityJointMatrixWIInstructionsINTEL' [clang-diagnostic-unused-const-variable]
_SPIRV_OP(Op, TypeJointMatrixINTEL)
_SPIRV_OP(Op, JointMatrixLoadINTEL)
_SPIRV_OP(Op, JointMatrixStoreINTEL)
Expand All @@ -163,6 +166,8 @@ _SPIRV_OP(Op, JointMatrixSUMadINTEL)
_SPIRV_OP(Op, JointMatrixUSMadINTEL)
_SPIRV_OP(Op, JointMatrixUUMadINTEL)
_SPIRV_OP(Op, JointMatrixWorkItemLengthINTEL)
_SPIRV_OP(Op, JointMatrixGetElementCoordINTEL)

Check warning on line 169 in lib/SPIRV/libSPIRV/spirv_internal.hpp

View workflow job for this annotation

GitHub Actions / clang-format & clang-tidy

unused variable 'OpJointMatrixGetElementCoordINTEL' [clang-diagnostic-unused-const-variable]

_SPIRV_OP(Capability, HWThreadQueryINTEL)
_SPIRV_OP(BuiltIn, SubDeviceIDINTEL)
_SPIRV_OP(BuiltIn, GlobalHWThreadIDINTEL)
Expand Down
21 changes: 15 additions & 6 deletions test/transcoding/SPV_INTEL_joint_matrix/joint_matrix_element.ll
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,26 @@
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
; RUN: llvm-dis %t.rev.bc -o - | FileCheck %s --check-prefix=CHECK-LLVM

; CHECK-SPIRV: Capability JointMatrixINTEL
; CHECK-SPIRV: Extension "SPV_INTEL_joint_matrix"
; CHECK-SPIRV: TypeInt [[#TypeInt:]] 64
; CHECK-SPIRV: TypeFloat [[#TypeFloat:]] 32
; CHECK-SPIRV: TypeJointMatrixINTEL [[#TypeMatrix:]] [[#TypeFloat]] [[#]] [[#]] [[#]] [[#]]
; CHECK-SPIRV-DAG: Capability JointMatrixINTEL
; CHECK-SPIRV-DAG: Capability JointMatrixWIInstructionsINTEL
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_joint_matrix"
; CHECK-SPIRV-DAG: TypeInt [[#TypeInt32:]] 32
; CHECK-SPIRV-DAG: TypeInt [[#TypeInt64:]] 64
; CHECK-SPIRV-DAG: TypeFloat [[#TypeFloat:]] 32
; CHECK-SPIRV-DAG: TypeJointMatrixINTEL [[#TypeMatrix:]] [[#TypeFloat]] [[#]] [[#]] [[#]] [[#]]
; CHECK-SPIRV-DAG: TypeVector [[#TypeVec:]] [[#TypeInt32]] 2
; CHECK-SPIRV: Phi [[#TypeMatrix]] [[#Matrix:]]
; CHECK-SPIRV: JointMatrixWorkItemLengthINTEL [[#TypeInt]] [[#]] [[#Matrix]]
; CHECK-SPIRV: JointMatrixWorkItemLengthINTEL [[#TypeInt64]] [[#]] [[#Matrix]]
; CHECK-SPIRV: VectorExtractDynamic [[#TypeFloat]] [[#]] [[#Matrix]] [[#Index:]]
; CHECK-SPIRV: FMul [[#TypeFloat]] [[#NewVal:]] [[#]] [[#]]
; CHECK-SPIRV: VectorInsertDynamic [[#TypeMatrix]] [[#]] [[#Matrix]] [[#NewVal]] [[#Index]]
; CHECK-SPIRV: JointMatrixGetElementCoordINTEL [[#TypeVec]] [[#]] [[#Matrix]] [[#Index]]

; CHECK-LLVM: [[Length:%.*]] = call spir_func i64 @_Z38__spirv_JointMatrixWorkItemLengthINTELPU3AS141__spirv_JointMatrixINTEL__float_16_16_0_3(%spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(1)* [[Matrix:%.*]])
; CHECK-LLVM: [[Elem:%.*]] = call spir_func float @_Z28__spirv_VectorExtractDynamicPU3AS141__spirv_JointMatrixINTEL__float_16_16_0_3l(%spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(1)* [[Matrix]], i64 [[Index:%.*]])
; CHECK-LLVM: [[NewVal:%.*]] = fmul float [[Elem]], 5.000000e+00
; CHECK-LLVM: {{%.*}} = call spir_func %spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(1)* @_Z27__spirv_VectorInsertDynamicPU3AS141__spirv_JointMatrixINTEL__float_16_16_0_3fl(%spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(1)* [[Matrix]], float [[NewVal]], i64 [[Index]])
; CHECK-LLVM: {{%.*}} = call spir_func <2 x i32> @_Z39__spirv_JointMatrixGetElementCoordINTELPU3AS141__spirv_JointMatrixINTEL__float_16_16_0_3l(%spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(1)* [[Matrix]], i64 [[Index]])

source_filename = "/work/tmp/matrix-slice.cpp"
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
Expand Down Expand Up @@ -69,6 +74,7 @@ for.body.i: ; preds = %for.cond.i
%call.i.i = tail call spir_func float @_Z28__spirv_VectorExtractDynamicIfLm16ELm16ELN5__spv12MatrixLayoutE0ELNS0_5Scope4FlagE3EmET_PNS0_24__spirv_JointMatrixINTELIS4_XT0_EXT1_EXT2_EXT3_EEET4_(%spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(4)* %A.sroa.0.0.i, i64 %conv.i) #2
%mul.i.i = fmul float %call.i.i, 5.000000e+00
%call5.i.i = tail call spir_func %spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(4)* @_Z27__spirv_VectorInsertDynamicIfLm16ELm16ELN5__spv12MatrixLayoutE0ELNS0_5Scope4FlagE3EmEPNS0_24__spirv_JointMatrixINTELIT_XT0_EXT1_EXT2_EXT3_EEES7_T4_S5_(%spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(4)* %A.sroa.0.0.i, float %mul.i.i, i64 %conv.i) #2
%call6 = tail call spir_func <2 x i32> @_Z39__spirv_JointMatrixGetElementCoordINTELIaLm8ELm32ELN5__spv9MatrixUseE0ELNS0_12MatrixLayoutE0ELNS0_5Scope4FlagE3EEDv2_jPNS0_24__spirv_JointMatrixINTELIT_XT0_EXT1_EXT3_EXT4_EXT2_EEEm(%spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(4)* %A.sroa.0.0.i, i64 %conv.i) #2
%inc.i = add nuw nsw i32 %i.0.i, 1
br label %for.cond.i, !llvm.loop !7

Expand All @@ -92,6 +98,9 @@ declare dso_local spir_func %spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(4
; Function Attrs: convergent
declare dso_local spir_func void @_Z29__spirv_JointMatrixStoreINTELIfLm16ELm16ELN5__spv12MatrixLayoutE0ELNS0_5Scope4FlagE3EEvPT_PNS0_24__spirv_JointMatrixINTELIS4_XT0_EXT1_EXT2_EXT3_EEEmS1_S3_i(float addrspace(4)*, %spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(4)*, i64, i32, i32, i32) local_unnamed_addr #1

; Function Attrs: convergent
declare dso_local spir_func <2 x i32> @_Z39__spirv_JointMatrixGetElementCoordINTELIaLm8ELm32ELN5__spv9MatrixUseE0ELNS0_12MatrixLayoutE0ELNS0_5Scope4FlagE3EEDv2_jPNS0_24__spirv_JointMatrixINTELIT_XT0_EXT1_EXT3_EXT4_EXT2_EEEm(%spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(4)*, i64) #2

attributes #0 = { convergent norecurse "frame-pointer"="all" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "sycl-module-id"="/work/tmp/matrix-slice.cpp" "uniform-work-group-size"="true" }
attributes #1 = { convergent "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
attributes #2 = { convergent }
Expand Down

0 comments on commit 7b29f08

Please sign in to comment.