From 7b29f08ec2ff3de75d73f7f9fae351c4a676ad93 Mon Sep 17 00:00:00 2001 From: Dmitry Sidorov Date: Tue, 21 Feb 2023 12:24:02 +0100 Subject: [PATCH] [Bckport to 14] Add JointMatrixGetElementCoordINTEL instruction The instruction returns (Row, Column) coordinate of dynamically selected element of a matrix Updated version of the spec is here intel/llvm#8175 Instruction correctness checks will be added later among non-backward compatible changes. Signed-off-by: Sidorov, Dmitry dmitry.sidorov@intel.com Signed-off-by: Sidorov, Dmitry --- lib/SPIRV/libSPIRV/SPIRVEnum.h | 2 ++ lib/SPIRV/libSPIRV/SPIRVInstruction.h | 15 +++++++++++++ lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h | 2 ++ lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h | 2 ++ lib/SPIRV/libSPIRV/spirv_internal.hpp | 5 +++++ .../joint_matrix_element.ll | 21 +++++++++++++------ 6 files changed, 41 insertions(+), 6 deletions(-) diff --git a/lib/SPIRV/libSPIRV/SPIRVEnum.h b/lib/SPIRV/libSPIRV/SPIRVEnum.h index 1bd5cc7e40..43dcdc6e82 100644 --- a/lib/SPIRV/libSPIRV/SPIRVEnum.h +++ b/lib/SPIRV/libSPIRV/SPIRVEnum.h @@ -205,6 +205,8 @@ template <> inline void SPIRVMap::init() { {CapabilitySubgroupAvcMotionEstimationINTEL}); ADD_VEC_INIT(CapabilitySubgroupAvcMotionEstimationChromaINTEL, {CapabilitySubgroupAvcMotionEstimationIntraINTEL}); + ADD_VEC_INIT(internal::CapabilityJointMatrixWIInstructionsINTEL, + {internal::CapabilityJointMatrixINTEL}); } template <> inline void SPIRVMap::init() { diff --git a/lib/SPIRV/libSPIRV/SPIRVInstruction.h b/lib/SPIRV/libSPIRV/SPIRVInstruction.h index 7f33b44f2d..cde72f383e 100644 --- a/lib/SPIRV/libSPIRV/SPIRVInstruction.h +++ b/lib/SPIRV/libSPIRV/SPIRVInstruction.h @@ -3375,6 +3375,7 @@ _SPIRV_OP(JointMatrixMad, true, 7) _SPIRV_OP(JointMatrixSUMad, true, 7) _SPIRV_OP(JointMatrixUSMad, true, 7) _SPIRV_OP(JointMatrixUUMad, true, 7) +// TODO: move to SPIRVJointMatrixINTELWorkItemInst _SPIRV_OP(JointMatrixWorkItemLength, true, 4) #undef _SPIRV_OP @@ -3398,6 +3399,20 @@ _SPIRV_OP(CooperativeMatrixLengthKHR, true, 4, false) _SPIRV_OP(CooperativeMatrixMulAddKHR, true, 6, true, 3) #undef _SPIRV_OP +class SPIRVJointMatrixINTELWorkItemInst : public SPIRVJointMatrixINTELInstBase { +protected: + SPIRVCapVec getRequiredCapability() const override { + return getVec(internal::CapabilityJointMatrixWIInstructionsINTEL); + } +}; + +#define _SPIRV_OP(x, ...) \ + typedef SPIRVInstTemplate \ + SPIRV##x##INTEL; +_SPIRV_OP(JointMatrixGetElementCoord, true, 5) +#undef _SPIRV_OP + class SPIRVSplitBarrierINTELBase : public SPIRVInstTemplateBase { protected: SPIRVCapVec getRequiredCapability() const override { diff --git a/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h b/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h index a55869ecc8..ab5865a7e4 100644 --- a/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h +++ b/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h @@ -622,6 +622,8 @@ template <> inline void SPIRVMap::init() { add(internal::CapabilityTensorFloat32RoundingINTEL, "TensorFloat32RoundingINTEL"); add(internal::CapabilityCacheControlsINTEL, "CacheControlsINTEL"); + add(internal::CapabilityJointMatrixWIInstructionsINTEL, + "JointMatrixWIInstructionsINTEL"); } SPIRV_DEF_NAMEMAP(Capability, SPIRVCapabilityNameMap) diff --git a/lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h b/lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h index 2682f86937..b4d4b58d63 100644 --- a/lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h +++ b/lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h @@ -14,6 +14,8 @@ _SPIRV_OP_INTERNAL(JointMatrixUSMadINTEL, internal::OpJointMatrixUSMadINTEL) _SPIRV_OP_INTERNAL(JointMatrixUUMadINTEL, internal::OpJointMatrixUUMadINTEL) _SPIRV_OP_INTERNAL(JointMatrixWorkItemLengthINTEL, internal::OpJointMatrixWorkItemLengthINTEL) +_SPIRV_OP_INTERNAL(JointMatrixGetElementCoordINTEL, + internal::OpJointMatrixGetElementCoordINTEL) _SPIRV_OP_INTERNAL(ComplexFMulINTEL, internal::ComplexFMulINTEL) _SPIRV_OP_INTERNAL(ComplexFDivINTEL, internal::ComplexFDivINTEL) _SPIRV_OP_INTERNAL(MaskedGatherINTEL, internal::OpMaskedGatherINTEL) diff --git a/lib/SPIRV/libSPIRV/spirv_internal.hpp b/lib/SPIRV/libSPIRV/spirv_internal.hpp index e0acc6bb8d..0a90b5922a 100644 --- a/lib/SPIRV/libSPIRV/spirv_internal.hpp +++ b/lib/SPIRV/libSPIRV/spirv_internal.hpp @@ -75,6 +75,7 @@ enum InternalOp { IOpRoundFToTF32INTEL = 6426, IOpMaskedGatherINTEL = 6428, IOpMaskedScatterINTEL = 6429, + IOpJointMatrixGetElementCoordINTEL = 6440, IOpPrev = OpMax - 2, IOpForward }; @@ -109,6 +110,7 @@ enum InternalCapability { ICapabilityComplexFloatMulDivINTEL = 6414, ICapabilityTensorFloat32RoundingINTEL = 6425, ICapabilityMaskedGatherScatterINTEL = 6427, + ICapabilityJointMatrixWIInstructionsINTEL = 6435, ICapabilityCacheControlsINTEL = 6441 }; @@ -155,6 +157,7 @@ enum class StoreCacheControlINTEL { #define _SPIRV_OP(x, y) constexpr x x##y = static_cast(I##x##y); _SPIRV_OP(Capability, JointMatrixINTEL) +_SPIRV_OP(Capability, JointMatrixWIInstructionsINTEL) _SPIRV_OP(Op, TypeJointMatrixINTEL) _SPIRV_OP(Op, JointMatrixLoadINTEL) _SPIRV_OP(Op, JointMatrixStoreINTEL) @@ -163,6 +166,8 @@ _SPIRV_OP(Op, JointMatrixSUMadINTEL) _SPIRV_OP(Op, JointMatrixUSMadINTEL) _SPIRV_OP(Op, JointMatrixUUMadINTEL) _SPIRV_OP(Op, JointMatrixWorkItemLengthINTEL) +_SPIRV_OP(Op, JointMatrixGetElementCoordINTEL) + _SPIRV_OP(Capability, HWThreadQueryINTEL) _SPIRV_OP(BuiltIn, SubDeviceIDINTEL) _SPIRV_OP(BuiltIn, GlobalHWThreadIDINTEL) diff --git a/test/transcoding/SPV_INTEL_joint_matrix/joint_matrix_element.ll b/test/transcoding/SPV_INTEL_joint_matrix/joint_matrix_element.ll index f6519cfd91..f0262979e5 100644 --- a/test/transcoding/SPV_INTEL_joint_matrix/joint_matrix_element.ll +++ b/test/transcoding/SPV_INTEL_joint_matrix/joint_matrix_element.ll @@ -5,21 +5,26 @@ ; RUN: llvm-spirv -r %t.spv -o %t.rev.bc ; RUN: llvm-dis %t.rev.bc -o - | FileCheck %s --check-prefix=CHECK-LLVM -; CHECK-SPIRV: Capability JointMatrixINTEL -; CHECK-SPIRV: Extension "SPV_INTEL_joint_matrix" -; CHECK-SPIRV: TypeInt [[#TypeInt:]] 64 -; CHECK-SPIRV: TypeFloat [[#TypeFloat:]] 32 -; CHECK-SPIRV: TypeJointMatrixINTEL [[#TypeMatrix:]] [[#TypeFloat]] [[#]] [[#]] [[#]] [[#]] +; CHECK-SPIRV-DAG: Capability JointMatrixINTEL +; CHECK-SPIRV-DAG: Capability JointMatrixWIInstructionsINTEL +; CHECK-SPIRV-DAG: Extension "SPV_INTEL_joint_matrix" +; CHECK-SPIRV-DAG: TypeInt [[#TypeInt32:]] 32 +; CHECK-SPIRV-DAG: TypeInt [[#TypeInt64:]] 64 +; CHECK-SPIRV-DAG: TypeFloat [[#TypeFloat:]] 32 +; CHECK-SPIRV-DAG: TypeJointMatrixINTEL [[#TypeMatrix:]] [[#TypeFloat]] [[#]] [[#]] [[#]] [[#]] +; CHECK-SPIRV-DAG: TypeVector [[#TypeVec:]] [[#TypeInt32]] 2 ; CHECK-SPIRV: Phi [[#TypeMatrix]] [[#Matrix:]] -; CHECK-SPIRV: JointMatrixWorkItemLengthINTEL [[#TypeInt]] [[#]] [[#Matrix]] +; CHECK-SPIRV: JointMatrixWorkItemLengthINTEL [[#TypeInt64]] [[#]] [[#Matrix]] ; CHECK-SPIRV: VectorExtractDynamic [[#TypeFloat]] [[#]] [[#Matrix]] [[#Index:]] ; CHECK-SPIRV: FMul [[#TypeFloat]] [[#NewVal:]] [[#]] [[#]] ; CHECK-SPIRV: VectorInsertDynamic [[#TypeMatrix]] [[#]] [[#Matrix]] [[#NewVal]] [[#Index]] +; CHECK-SPIRV: JointMatrixGetElementCoordINTEL [[#TypeVec]] [[#]] [[#Matrix]] [[#Index]] ; CHECK-LLVM: [[Length:%.*]] = call spir_func i64 @_Z38__spirv_JointMatrixWorkItemLengthINTELPU3AS141__spirv_JointMatrixINTEL__float_16_16_0_3(%spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(1)* [[Matrix:%.*]]) ; CHECK-LLVM: [[Elem:%.*]] = call spir_func float @_Z28__spirv_VectorExtractDynamicPU3AS141__spirv_JointMatrixINTEL__float_16_16_0_3l(%spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(1)* [[Matrix]], i64 [[Index:%.*]]) ; CHECK-LLVM: [[NewVal:%.*]] = fmul float [[Elem]], 5.000000e+00 ; CHECK-LLVM: {{%.*}} = call spir_func %spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(1)* @_Z27__spirv_VectorInsertDynamicPU3AS141__spirv_JointMatrixINTEL__float_16_16_0_3fl(%spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(1)* [[Matrix]], float [[NewVal]], i64 [[Index]]) +; CHECK-LLVM: {{%.*}} = call spir_func <2 x i32> @_Z39__spirv_JointMatrixGetElementCoordINTELPU3AS141__spirv_JointMatrixINTEL__float_16_16_0_3l(%spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(1)* [[Matrix]], i64 [[Index]]) source_filename = "/work/tmp/matrix-slice.cpp" target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64" @@ -69,6 +74,7 @@ for.body.i: ; preds = %for.cond.i %call.i.i = tail call spir_func float @_Z28__spirv_VectorExtractDynamicIfLm16ELm16ELN5__spv12MatrixLayoutE0ELNS0_5Scope4FlagE3EmET_PNS0_24__spirv_JointMatrixINTELIS4_XT0_EXT1_EXT2_EXT3_EEET4_(%spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(4)* %A.sroa.0.0.i, i64 %conv.i) #2 %mul.i.i = fmul float %call.i.i, 5.000000e+00 %call5.i.i = tail call spir_func %spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(4)* @_Z27__spirv_VectorInsertDynamicIfLm16ELm16ELN5__spv12MatrixLayoutE0ELNS0_5Scope4FlagE3EmEPNS0_24__spirv_JointMatrixINTELIT_XT0_EXT1_EXT2_EXT3_EEES7_T4_S5_(%spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(4)* %A.sroa.0.0.i, float %mul.i.i, i64 %conv.i) #2 + %call6 = tail call spir_func <2 x i32> @_Z39__spirv_JointMatrixGetElementCoordINTELIaLm8ELm32ELN5__spv9MatrixUseE0ELNS0_12MatrixLayoutE0ELNS0_5Scope4FlagE3EEDv2_jPNS0_24__spirv_JointMatrixINTELIT_XT0_EXT1_EXT3_EXT4_EXT2_EEEm(%spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(4)* %A.sroa.0.0.i, i64 %conv.i) #2 %inc.i = add nuw nsw i32 %i.0.i, 1 br label %for.cond.i, !llvm.loop !7 @@ -92,6 +98,9 @@ declare dso_local spir_func %spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(4 ; Function Attrs: convergent declare dso_local spir_func void @_Z29__spirv_JointMatrixStoreINTELIfLm16ELm16ELN5__spv12MatrixLayoutE0ELNS0_5Scope4FlagE3EEvPT_PNS0_24__spirv_JointMatrixINTELIS4_XT0_EXT1_EXT2_EXT3_EEEmS1_S3_i(float addrspace(4)*, %spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(4)*, i64, i32, i32, i32) local_unnamed_addr #1 +; Function Attrs: convergent +declare dso_local spir_func <2 x i32> @_Z39__spirv_JointMatrixGetElementCoordINTELIaLm8ELm32ELN5__spv9MatrixUseE0ELNS0_12MatrixLayoutE0ELNS0_5Scope4FlagE3EEDv2_jPNS0_24__spirv_JointMatrixINTELIT_XT0_EXT1_EXT3_EXT4_EXT2_EEEm(%spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(4)*, i64) #2 + attributes #0 = { convergent norecurse "frame-pointer"="all" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "sycl-module-id"="/work/tmp/matrix-slice.cpp" "uniform-work-group-size"="true" } attributes #1 = { convergent "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" } attributes #2 = { convergent }