From 1e33c8bfd2ba079639f068fcdcd6f2c84d29d9c9 Mon Sep 17 00:00:00 2001 From: "Levytskyy, Vyacheslav" Date: Tue, 14 Nov 2023 09:39:22 -0800 Subject: [PATCH 1/6] adding OpCooperativeMatrixApplyFunctionINTEL operation and CooperativeMatrixInvocationInstructionsINTEL capability --- lib/SPIRV/libSPIRV/SPIRVInstruction.h | 20 ++++++++++++++++++++ lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h | 2 ++ lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h | 2 ++ lib/SPIRV/libSPIRV/spirv_internal.hpp | 5 +++++ 4 files changed, 29 insertions(+) diff --git a/lib/SPIRV/libSPIRV/SPIRVInstruction.h b/lib/SPIRV/libSPIRV/SPIRVInstruction.h index 5e2bdde59f..163be06842 100644 --- a/lib/SPIRV/libSPIRV/SPIRVInstruction.h +++ b/lib/SPIRV/libSPIRV/SPIRVInstruction.h @@ -3420,6 +3420,26 @@ class SPIRVCooperativeMatrixPrefetchINTELInstBase _SPIRV_OP(CooperativeMatrixPrefetch, false, 8, true, 5) #undef _SPIRV_OP +class SPIRVCooperativeMatrixInvocationInstructionsINTELInstBase + : public SPIRVInstTemplateBase { +protected: + std::optional getRequiredExtension() const override { + return ExtensionID::SPV_INTEL_joint_matrix; + } + SPIRVCapVec getRequiredCapability() const override { + return getVec( + internal::CapabilityCooperativeMatrixInvocationInstructionsINTEL); + } +}; + +#define _SPIRV_OP(x, ...) \ + typedef SPIRVInstTemplate< \ + SPIRVCooperativeMatrixInvocationInstructionsINTELInstBase, \ + internal::Op##x##INTEL, __VA_ARGS__> \ + SPIRV##x##INTEL; +_SPIRV_OP(CooperativeMatrixApplyFunction, true, 4, true) +#undef _SPIRV_OP + class SPIRVCooperativeMatrixKHRInstBase : public SPIRVInstTemplateBase { protected: std::optional getRequiredExtension() const override { diff --git a/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h b/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h index 41d34bb04c..17ffafdd53 100644 --- a/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h +++ b/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h @@ -657,6 +657,8 @@ template <> inline void SPIRVMap::init() { add(internal::CapabilityCacheControlsINTEL, "CacheControlsINTEL"); add(internal::CapabilityCooperativeMatrixPrefetchINTEL, "CooperativeMatrixPrefetchINTEL"); + add(internal::CapabilityCooperativeMatrixInvocationInstructionsINTEL, + "CooperativeMatrixInvocationInstructionsINTEL"); } SPIRV_DEF_NAMEMAP(Capability, SPIRVCapabilityNameMap) diff --git a/lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h b/lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h index 64a27e7e2c..5ce0a7118c 100644 --- a/lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h +++ b/lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h @@ -18,6 +18,8 @@ _SPIRV_OP_INTERNAL(JointMatrixGetElementCoordINTEL, internal::OpJointMatrixGetElementCoordINTEL) _SPIRV_OP_INTERNAL(CooperativeMatrixPrefetchINTEL, internal::OpCooperativeMatrixPrefetchINTEL) +_SPIRV_OP_INTERNAL(CooperativeMatrixApplyFunctionINTEL, + internal::OpCooperativeMatrixApplyFunctionINTEL) _SPIRV_OP_INTERNAL(ComplexFMulINTEL, internal::ComplexFMulINTEL) _SPIRV_OP_INTERNAL(ComplexFDivINTEL, internal::ComplexFDivINTEL) _SPIRV_OP_INTERNAL(MaskedGatherINTEL, internal::OpMaskedGatherINTEL) diff --git a/lib/SPIRV/libSPIRV/spirv_internal.hpp b/lib/SPIRV/libSPIRV/spirv_internal.hpp index 34d7ae2cc9..780ebf44ff 100644 --- a/lib/SPIRV/libSPIRV/spirv_internal.hpp +++ b/lib/SPIRV/libSPIRV/spirv_internal.hpp @@ -78,6 +78,7 @@ enum InternalOp { IOpMaskedScatterINTEL = 6429, IOpJointMatrixGetElementCoordINTEL = 6440, IOpCooperativeMatrixPrefetchINTEL = 6449, + IOpCooperativeMatrixApplyFunctionINTEL = 6448, IOpPrev = OpMax - 2, IOpForward }; @@ -104,6 +105,7 @@ enum InternalCapability { ICapabilityTensorFloat32RoundingINTEL = 6425, ICapabilityMaskedGatherScatterINTEL = 6427, ICapabilityJointMatrixWIInstructionsINTEL = 6435, + ICapabilityCooperativeMatrixInvocationInstructionsINTEL = 6435, ICapabilityJointMatrixTF32ComponentTypeINTEL = 6436, ICapabilityJointMatrixBF16ComponentTypeINTEL = 6437, ICapabilityJointMatrixPackedInt2ComponentTypeINTEL = 6438, @@ -178,6 +180,9 @@ _SPIRV_OP(Op, JointMatrixGetElementCoordINTEL) _SPIRV_OP(Capability, CooperativeMatrixPrefetchINTEL) _SPIRV_OP(Op, CooperativeMatrixPrefetchINTEL) +_SPIRV_OP(Capability, CooperativeMatrixInvocationInstructionsINTEL) +_SPIRV_OP(Op, CooperativeMatrixApplyFunctionINTEL) + _SPIRV_OP(Capability, HWThreadQueryINTEL) _SPIRV_OP(BuiltIn, SubDeviceIDINTEL) _SPIRV_OP(BuiltIn, GlobalHWThreadIDINTEL) From d039fddeee00a276d30a6592578530ec5d0c4a04 Mon Sep 17 00:00:00 2001 From: "Sidorov, Dmitry" Date: Tue, 14 Nov 2023 17:56:18 -0800 Subject: [PATCH 2/6] Add test TODO update spec Signed-off-by: Sidorov, Dmitry --- lib/SPIRV/libSPIRV/SPIRVInstruction.h | 2 +- .../cooperative_matrix_apply.ll | 163 ++++++++++++++++++ 2 files changed, 164 insertions(+), 1 deletion(-) create mode 100644 test/extensions/INTEL/SPV_INTEL_joint_matrix/cooperative_matrix_apply.ll diff --git a/lib/SPIRV/libSPIRV/SPIRVInstruction.h b/lib/SPIRV/libSPIRV/SPIRVInstruction.h index 163be06842..66be3e72d7 100644 --- a/lib/SPIRV/libSPIRV/SPIRVInstruction.h +++ b/lib/SPIRV/libSPIRV/SPIRVInstruction.h @@ -3437,7 +3437,7 @@ class SPIRVCooperativeMatrixInvocationInstructionsINTELInstBase SPIRVCooperativeMatrixInvocationInstructionsINTELInstBase, \ internal::Op##x##INTEL, __VA_ARGS__> \ SPIRV##x##INTEL; -_SPIRV_OP(CooperativeMatrixApplyFunction, true, 4, true) +_SPIRV_OP(CooperativeMatrixApplyFunction, true, 5) #undef _SPIRV_OP class SPIRVCooperativeMatrixKHRInstBase : public SPIRVInstTemplateBase { diff --git a/test/extensions/INTEL/SPV_INTEL_joint_matrix/cooperative_matrix_apply.ll b/test/extensions/INTEL/SPV_INTEL_joint_matrix/cooperative_matrix_apply.ll new file mode 100644 index 0000000000..2c7153cb61 --- /dev/null +++ b/test/extensions/INTEL/SPV_INTEL_joint_matrix/cooperative_matrix_apply.ll @@ -0,0 +1,163 @@ +;; compiled from joint_matrix_apply_bf16.cpp from intel/llvm with some modifications + +; RUN: llvm-as < %s -o %t.bc +; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_KHR_cooperative_matrix,+SPV_INTEL_joint_matrix -o %t.spv +; RUN: llvm-spirv %t.spv -to-text -o %t.spt +; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV + +; RUN: llvm-spirv -r %t.spv -o %t.rev.bc +; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM + +; CHECK-SPIRV-DAG: Capability CooperativeMatrixKHR +; CHECK-SPIRV-DAG: Capability CooperativeMatrixInvocationInstructionsINTEL +; CHECK-SPIRV-DAG: Extension "SPV_INTEL_joint_matrix" +; CHECK-SPIRV-DAG: Extension "SPV_KHR_cooperative_matrix" +; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy:]] +; CHECK-SPIRV: CompositeConstruct [[#MatTy]] [[#Mat:]] +; CHECK-SPIRV: PtrCastToGeneric [[#]] [[#Ptr:]] [[#]] +; CHECK-SPIRV: CooperativeMatrixApplyFunctionINTEL [[#MatTy]] [[#Apply:]] [[#Ptr]] [[#Mat]] +; CHECK-SPIRV: CooperativeMatrixStoreKHR [[#]] [[#Apply]] + +; CHECK-LLVM: %[[Mat:[%0-9a-z.]+]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) @"_Z26__spirv_CompositeConstructP38class.sycl::_V1::ext::oneapi::bfloat16" +; CHECK-LLVM: %[[Apply:[%0-9a-z.]+]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) @"_Z43__spirv_CooperativeMatrixApplyFunctionINTELPU3AS477class.sycl::_V1::ext::oneapi::experimental::matrix::helper::reference_wrapperPU3AS144__spirv_CooperativeMatrixKHR__short_8_16_0_3"(ptr addrspace(4) %ref.tmp.ascast.i21, target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) %[[Mat]]) +; CHECK-LLVM: call spir_func void @"_Z33__spirv_CooperativeMatrixStoreKHRPU3AS138class.sycl::_V1::ext::oneapi::bfloat16PU3AS144__spirv_CooperativeMatrixKHR__short_8_16_0_3liii"(ptr addrspace(1) %{{.*}}, target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) %[[Apply]], i64 32, i32 0, i32 3, i32 0) + + +; ModuleID = 'matrix_apply.bc' +source_filename = "../llvm/sycl/test-e2e/Matrix/joint_matrix_apply_bf16.cpp" +target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64" +target triple = "spir64-unknown-unknown" + +%"class.sycl::_V1::range" = type { %"class.sycl::_V1::detail::array" } +%"class.sycl::_V1::detail::array" = type { [2 x i64] } +%"class.sycl::_V1::id" = type { %"class.sycl::_V1::detail::array" } +%"class.sycl::_V1::ext::oneapi::experimental::matrix::helper::reference_wrapper" = type { ptr addrspace(4) } +%"class.sycl::_V1::ext::oneapi::bfloat16" = type { i16 } +%class.anon.0 = type <{ %"class.sycl::_V1::accessor", %class.anon, [7 x i8] }> +%"class.sycl::_V1::accessor" = type { %"class.sycl::_V1::detail::AccessorImplDevice", %union.anon } +%"class.sycl::_V1::detail::AccessorImplDevice" = type { %"class.sycl::_V1::id", %"class.sycl::_V1::range", %"class.sycl::_V1::range" } +%union.anon = type { ptr addrspace(1) } +%class.anon = type { i8 } + +$_ZTSZZ17matrix_verify_addIN4sycl3_V13ext6oneapi8bfloat16ELm16ELm32EZ4mainEUlRS4_E_EvNS1_5queueER10big_matrixIT_XT0_EXT1_EERNS1_8nd_rangeILi2EEEfOT2_ENKUlRNS1_7handlerEE_clESI_EUlNS1_7nd_itemILi2EEEE_ = comdat any + +@__spirv_BuiltInGlobalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32 +@__spirv_BuiltInLocalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32 + +; Function Attrs: convergent norecurse nounwind +define weak_odr dso_local spir_kernel void @_ZTSZZ17matrix_verify_addIN4sycl3_V13ext6oneapi8bfloat16ELm16ELm32EZ4mainEUlRS4_E_EvNS1_5queueER10big_matrixIT_XT0_EXT1_EERNS1_8nd_rangeILi2EEEfOT2_ENKUlRNS1_7handlerEE_clESI_EUlNS1_7nd_itemILi2EEEE_(ptr addrspace(1) noundef align 2 %_arg_accA, ptr noundef byval(%"class.sycl::_V1::range") align 8 %_arg_accA1, ptr noundef byval(%"class.sycl::_V1::range") align 8 %_arg_accA2, ptr noundef byval(%"class.sycl::_V1::id") align 8 %_arg_accA3) local_unnamed_addr { +entry: + call spir_func void @__itt_offload_wi_start_wrapper() + %ref.tmp.i20 = alloca %"class.sycl::_V1::ext::oneapi::experimental::matrix::helper::reference_wrapper", align 8 + %agg.tmp.i17 = alloca %"class.sycl::_V1::ext::oneapi::bfloat16", align 2 + %ref.tmp6.i = alloca float, align 4 + %__SYCLKernel = alloca %class.anon.0, align 8 + %__SYCLKernel.ascast = addrspacecast ptr %__SYCLKernel to ptr addrspace(4) + call void @llvm.lifetime.start.p0(i64 64, ptr nonnull %__SYCLKernel) + %agg.tmp.sroa.0.sroa.0.0.copyload = load i64, ptr %_arg_accA1, align 8 + %agg.tmp.sroa.0.sroa.2.0._arg_accA1.ascast.sroa_idx = getelementptr inbounds i8, ptr %_arg_accA1, i64 8 + %agg.tmp.sroa.0.sroa.2.0.copyload = load i64, ptr %agg.tmp.sroa.0.sroa.2.0._arg_accA1.ascast.sroa_idx, align 8 + %agg.tmp5.sroa.0.sroa.0.0.copyload = load i64, ptr %_arg_accA2, align 8 + %agg.tmp5.sroa.0.sroa.2.0._arg_accA2.ascast.sroa_idx = getelementptr inbounds i8, ptr %_arg_accA2, i64 8 + %agg.tmp5.sroa.0.sroa.2.0.copyload = load i64, ptr %agg.tmp5.sroa.0.sroa.2.0._arg_accA2.ascast.sroa_idx, align 8 + %agg.tmp6.sroa.0.sroa.0.0.copyload = load i64, ptr %_arg_accA3, align 8 + %agg.tmp6.sroa.0.sroa.2.0._arg_accA3.ascast.sroa_idx = getelementptr inbounds i8, ptr %_arg_accA3, i64 8 + %agg.tmp6.sroa.0.sroa.2.0.copyload = load i64, ptr %agg.tmp6.sroa.0.sroa.2.0._arg_accA3.ascast.sroa_idx, align 8 + %0 = getelementptr inbounds %"class.sycl::_V1::accessor", ptr %__SYCLKernel, i64 0, i32 1 + store i64 %agg.tmp6.sroa.0.sroa.0.0.copyload, ptr %__SYCLKernel, align 8 + %AccessRange.i.i.i.i.i = getelementptr inbounds %"class.sycl::_V1::detail::AccessorImplDevice", ptr %__SYCLKernel, i64 0, i32 1 + store i64 %agg.tmp.sroa.0.sroa.0.0.copyload, ptr %AccessRange.i.i.i.i.i, align 8 + %MemRange.i.i.i.i.i = getelementptr inbounds %"class.sycl::_V1::detail::AccessorImplDevice", ptr %__SYCLKernel, i64 0, i32 2 + store i64 %agg.tmp5.sroa.0.sroa.0.0.copyload, ptr %MemRange.i.i.i.i.i, align 8 + %arrayidx.i21.i.i.i.i = getelementptr inbounds [2 x i64], ptr %__SYCLKernel, i64 0, i64 1 + store i64 %agg.tmp6.sroa.0.sroa.2.0.copyload, ptr %arrayidx.i21.i.i.i.i, align 8 + %arrayidx.i25.i.i.i.i = getelementptr inbounds %"class.sycl::_V1::detail::AccessorImplDevice", ptr %__SYCLKernel, i64 0, i32 1, i32 0, i32 0, i64 1 + store i64 %agg.tmp.sroa.0.sroa.2.0.copyload, ptr %arrayidx.i25.i.i.i.i, align 8 + %arrayidx.i29.i.i.i.i = getelementptr inbounds %"class.sycl::_V1::detail::AccessorImplDevice", ptr %__SYCLKernel, i64 0, i32 2, i32 0, i32 0, i64 1 + store i64 %agg.tmp5.sroa.0.sroa.2.0.copyload, ptr %arrayidx.i29.i.i.i.i, align 8 + %mul.i6.i.i.i.i = mul i64 %agg.tmp6.sroa.0.sroa.0.0.copyload, %agg.tmp5.sroa.0.sroa.2.0.copyload + %1 = getelementptr %"class.sycl::_V1::ext::oneapi::bfloat16", ptr addrspace(1) %_arg_accA, i64 %mul.i6.i.i.i.i + %add.ptr.i = getelementptr %"class.sycl::_V1::ext::oneapi::bfloat16", ptr addrspace(1) %1, i64 %agg.tmp6.sroa.0.sroa.2.0.copyload + store ptr addrspace(1) %add.ptr.i, ptr %0, align 8 + %2 = load i64, ptr addrspace(1) getelementptr inbounds (i8, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, i64 8), align 8 + %3 = load i64, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, align 32 + %4 = load i64, ptr addrspace(1) getelementptr inbounds (i8, ptr addrspace(1) @__spirv_BuiltInLocalInvocationId, i64 8), align 8 + %5 = load i64, ptr addrspace(1) @__spirv_BuiltInLocalInvocationId, align 32 + %ref.tmp6.ascast.i = addrspacecast ptr %ref.tmp6.i to ptr addrspace(4) + %cmp.i11 = icmp ult i64 %2, 2147483648 + tail call void @llvm.assume(i1 %cmp.i11) + %cmp.i = icmp ult i64 %3, 2147483648 + tail call void @llvm.assume(i1 %cmp.i) + %cmp.i15 = icmp ult i64 %4, 2147483648 + tail call void @llvm.assume(i1 %cmp.i15) + %sub.i = sub nsw i64 %2, %4 + %cmp.i12 = icmp ult i64 %5, 2147483648 + tail call void @llvm.assume(i1 %cmp.i12) + %sub5.i = sub nsw i64 %3, %5 + call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %ref.tmp6.i) + store float 5.000000e+00, ptr %ref.tmp6.i, align 4 + %call.i.i = call spir_func noundef zeroext i16 @__devicelib_ConvertFToBF16INTEL(ptr addrspace(4) noundef align 4 dereferenceable(4) %ref.tmp6.ascast.i) + call void @llvm.lifetime.start.p0(i64 2, ptr nonnull %agg.tmp.i17) + store i16 %call.i.i, ptr %agg.tmp.i17, align 2 + %call.i18 = call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) @_Z26__spirv_CompositeConstruct(ptr noundef nonnull byval(%"class.sycl::_V1::ext::oneapi::bfloat16") align 2 %agg.tmp.i17) + call void @llvm.lifetime.end.p0(i64 2, ptr nonnull %agg.tmp.i17) + call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %ref.tmp6.i) + %lambda.i = getelementptr inbounds %class.anon.0, ptr addrspace(4) %__SYCLKernel.ascast, i64 0, i32 1 + %ref.tmp.ascast.i21 = addrspacecast ptr %ref.tmp.i20 to ptr addrspace(4) + call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %ref.tmp.i20) + store ptr addrspace(4) %lambda.i, ptr %ref.tmp.i20, align 8 + %call.i22 = call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) @_Z43__spirv_CooperativeMatrixApplyFunctionINTEL(ptr addrspace(4) noundef align 8 dereferenceable(8) %ref.tmp.ascast.i21, target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) noundef %call.i18) + call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %ref.tmp.i20) + %6 = load ptr addrspace(1), ptr %0, align 8 + %7 = load i64, ptr %__SYCLKernel, align 8 + %8 = load i64, ptr %arrayidx.i29.i.i.i.i, align 8 + %mul.i6.i.i.i.i.i = mul i64 %7, %8 + %9 = load i64, ptr %arrayidx.i21.i.i.i.i, align 8 + %add.i7.i.i.i.i.i = add i64 %mul.i6.i.i.i.i.i, %9 + %idx.neg.i.i = sub i64 0, %add.i7.i.i.i.i.i + %add.ptr.i.i = getelementptr inbounds %"class.sycl::_V1::ext::oneapi::bfloat16", ptr addrspace(1) %6, i64 %idx.neg.i.i + %mul12.i = shl nsw i64 %sub.i, 8 + %add.ptr.i43 = getelementptr inbounds %"class.sycl::_V1::ext::oneapi::bfloat16", ptr addrspace(1) %add.ptr.i.i, i64 %mul12.i + %div14.i = and i64 %sub5.i, -16 + %add.ptr.i44 = getelementptr inbounds %"class.sycl::_V1::ext::oneapi::bfloat16", ptr addrspace(1) %add.ptr.i43, i64 %div14.i + call spir_func void @_Z33__spirv_CooperativeMatrixStoreKHRPU3AS4iPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3ili(ptr addrspace(1) noundef %add.ptr.i44, target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) noundef %call.i22, i64 noundef 32, i32 noundef 0, i32 noundef 3, i32 noundef 0) + call void @llvm.lifetime.end.p0(i64 64, ptr nonnull %__SYCLKernel) + call spir_func void @__itt_offload_wi_finish_wrapper() + ret void +} + +; Function Attrs: nocallback nofree nosync nounwind willreturn memory(argmem: readwrite) +declare void @llvm.lifetime.start.p0(i64 immarg, ptr nocapture) + +; Function Attrs: nocallback nofree nosync nounwind willreturn memory(argmem: readwrite) +declare void @llvm.lifetime.end.p0(i64 immarg, ptr nocapture) + +; Function Attrs: nocallback nofree nosync nounwind willreturn memory(inaccessiblemem: write) +declare void @llvm.assume(i1 noundef) + +; Function Attrs: convergent nounwind +declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) @_Z26__spirv_CompositeConstruct(ptr noundef byval(%"class.sycl::_V1::ext::oneapi::bfloat16") align 2) local_unnamed_addr + +; Function Attrs: convergent nounwind +declare dso_local spir_func zeroext i16 @__devicelib_ConvertFToBF16INTEL(ptr addrspace(4) noundef align 4 dereferenceable(4)) local_unnamed_addr + +; Function Attrs: convergent nounwind +declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) @_Z43__spirv_CooperativeMatrixApplyFunctionINTEL(ptr addrspace(4) noundef align 8 dereferenceable(8), target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) noundef) local_unnamed_addr + +; Function Attrs: convergent nounwind +declare dso_local spir_func void @_Z33__spirv_CooperativeMatrixStoreKHRPU3AS4iPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3ili(ptr addrspace(1) noundef, target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) noundef, i64 noundef, i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr + +declare spir_func void @__itt_offload_wi_start_wrapper() + +declare spir_func void @__itt_offload_wi_finish_wrapper() + +!llvm.module.flags = !{!0, !1} +!opencl.spir.version = !{!2} +!spirv.Source = !{!3} +!llvm.ident = !{!4} + +!0 = !{i32 1, !"wchar_size", i32 4} +!1 = !{i32 7, !"frame-pointer", i32 2} +!2 = !{i32 1, i32 2} +!3 = !{i32 4, i32 100000} +!4 = !{!"clang version 18.0.0 (https://github.com/intel/llvm.git)"} From 7682dd455470772fc979d835a916a154aaf6db65 Mon Sep 17 00:00:00 2001 From: "Sidorov, Dmitry" Date: Wed, 15 Nov 2023 07:22:23 -0800 Subject: [PATCH 3/6] Remove some unnecessary things Signed-off-by: Sidorov, Dmitry --- .../cooperative_matrix_apply.ll | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/test/extensions/INTEL/SPV_INTEL_joint_matrix/cooperative_matrix_apply.ll b/test/extensions/INTEL/SPV_INTEL_joint_matrix/cooperative_matrix_apply.ll index 2c7153cb61..616b813fd0 100644 --- a/test/extensions/INTEL/SPV_INTEL_joint_matrix/cooperative_matrix_apply.ll +++ b/test/extensions/INTEL/SPV_INTEL_joint_matrix/cooperative_matrix_apply.ll @@ -22,7 +22,6 @@ ; CHECK-LLVM: %[[Apply:[%0-9a-z.]+]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) @"_Z43__spirv_CooperativeMatrixApplyFunctionINTELPU3AS477class.sycl::_V1::ext::oneapi::experimental::matrix::helper::reference_wrapperPU3AS144__spirv_CooperativeMatrixKHR__short_8_16_0_3"(ptr addrspace(4) %ref.tmp.ascast.i21, target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) %[[Mat]]) ; CHECK-LLVM: call spir_func void @"_Z33__spirv_CooperativeMatrixStoreKHRPU3AS138class.sycl::_V1::ext::oneapi::bfloat16PU3AS144__spirv_CooperativeMatrixKHR__short_8_16_0_3liii"(ptr addrspace(1) %{{.*}}, target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) %[[Apply]], i64 32, i32 0, i32 3, i32 0) - ; ModuleID = 'matrix_apply.bc' source_filename = "../llvm/sycl/test-e2e/Matrix/joint_matrix_apply_bf16.cpp" target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64" @@ -47,7 +46,6 @@ $_ZTSZZ17matrix_verify_addIN4sycl3_V13ext6oneapi8bfloat16ELm16ELm32EZ4mainEUlRS4 ; Function Attrs: convergent norecurse nounwind define weak_odr dso_local spir_kernel void @_ZTSZZ17matrix_verify_addIN4sycl3_V13ext6oneapi8bfloat16ELm16ELm32EZ4mainEUlRS4_E_EvNS1_5queueER10big_matrixIT_XT0_EXT1_EERNS1_8nd_rangeILi2EEEfOT2_ENKUlRNS1_7handlerEE_clESI_EUlNS1_7nd_itemILi2EEEE_(ptr addrspace(1) noundef align 2 %_arg_accA, ptr noundef byval(%"class.sycl::_V1::range") align 8 %_arg_accA1, ptr noundef byval(%"class.sycl::_V1::range") align 8 %_arg_accA2, ptr noundef byval(%"class.sycl::_V1::id") align 8 %_arg_accA3) local_unnamed_addr { entry: - call spir_func void @__itt_offload_wi_start_wrapper() %ref.tmp.i20 = alloca %"class.sycl::_V1::ext::oneapi::experimental::matrix::helper::reference_wrapper", align 8 %agg.tmp.i17 = alloca %"class.sycl::_V1::ext::oneapi::bfloat16", align 2 %ref.tmp6.i = alloca float, align 4 @@ -85,14 +83,10 @@ entry: %5 = load i64, ptr addrspace(1) @__spirv_BuiltInLocalInvocationId, align 32 %ref.tmp6.ascast.i = addrspacecast ptr %ref.tmp6.i to ptr addrspace(4) %cmp.i11 = icmp ult i64 %2, 2147483648 - tail call void @llvm.assume(i1 %cmp.i11) %cmp.i = icmp ult i64 %3, 2147483648 - tail call void @llvm.assume(i1 %cmp.i) %cmp.i15 = icmp ult i64 %4, 2147483648 - tail call void @llvm.assume(i1 %cmp.i15) %sub.i = sub nsw i64 %2, %4 %cmp.i12 = icmp ult i64 %5, 2147483648 - tail call void @llvm.assume(i1 %cmp.i12) %sub5.i = sub nsw i64 %3, %5 call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %ref.tmp6.i) store float 5.000000e+00, ptr %ref.tmp6.i, align 4 @@ -122,7 +116,6 @@ entry: %add.ptr.i44 = getelementptr inbounds %"class.sycl::_V1::ext::oneapi::bfloat16", ptr addrspace(1) %add.ptr.i43, i64 %div14.i call spir_func void @_Z33__spirv_CooperativeMatrixStoreKHRPU3AS4iPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3ili(ptr addrspace(1) noundef %add.ptr.i44, target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) noundef %call.i22, i64 noundef 32, i32 noundef 0, i32 noundef 3, i32 noundef 0) call void @llvm.lifetime.end.p0(i64 64, ptr nonnull %__SYCLKernel) - call spir_func void @__itt_offload_wi_finish_wrapper() ret void } @@ -132,9 +125,6 @@ declare void @llvm.lifetime.start.p0(i64 immarg, ptr nocapture) ; Function Attrs: nocallback nofree nosync nounwind willreturn memory(argmem: readwrite) declare void @llvm.lifetime.end.p0(i64 immarg, ptr nocapture) -; Function Attrs: nocallback nofree nosync nounwind willreturn memory(inaccessiblemem: write) -declare void @llvm.assume(i1 noundef) - ; Function Attrs: convergent nounwind declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) @_Z26__spirv_CompositeConstruct(ptr noundef byval(%"class.sycl::_V1::ext::oneapi::bfloat16") align 2) local_unnamed_addr @@ -147,10 +137,6 @@ declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 8, ; Function Attrs: convergent nounwind declare dso_local spir_func void @_Z33__spirv_CooperativeMatrixStoreKHRPU3AS4iPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3ili(ptr addrspace(1) noundef, target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) noundef, i64 noundef, i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr -declare spir_func void @__itt_offload_wi_start_wrapper() - -declare spir_func void @__itt_offload_wi_finish_wrapper() - !llvm.module.flags = !{!0, !1} !opencl.spir.version = !{!2} !spirv.Source = !{!3} From 49838443b44a24283c2b365be762729fa4390467 Mon Sep 17 00:00:00 2001 From: "Levytskyy, Vyacheslav" Date: Thu, 16 Nov 2023 10:16:56 -0800 Subject: [PATCH 4/6] implement requirement: implicitly declares CooperativeMatrixKHR --- lib/SPIRV/libSPIRV/SPIRVEnum.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib/SPIRV/libSPIRV/SPIRVEnum.h b/lib/SPIRV/libSPIRV/SPIRVEnum.h index ccdc21a2a3..3b501fedbd 100644 --- a/lib/SPIRV/libSPIRV/SPIRVEnum.h +++ b/lib/SPIRV/libSPIRV/SPIRVEnum.h @@ -217,6 +217,8 @@ template <> inline void SPIRVMap::init() { {internal::CapabilityJointMatrixINTEL}); ADD_VEC_INIT(internal::CapabilityCooperativeMatrixPrefetchINTEL, {CapabilityCooperativeMatrixKHR}); + ADD_VEC_INIT(internal::CapabilityCooperativeMatrixInvocationInstructionsINTEL, + {CapabilityCooperativeMatrixKHR}); } template <> inline void SPIRVMap::init() { From e7ad51a811c2a1667cf4dddb7d6767624c3c1b00 Mon Sep 17 00:00:00 2001 From: "Levytskyy, Vyacheslav" Date: Fri, 17 Nov 2023 04:07:48 -0800 Subject: [PATCH 5/6] change tokens order --- lib/SPIRV/libSPIRV/spirv_internal.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/SPIRV/libSPIRV/spirv_internal.hpp b/lib/SPIRV/libSPIRV/spirv_internal.hpp index 780ebf44ff..2805a76115 100644 --- a/lib/SPIRV/libSPIRV/spirv_internal.hpp +++ b/lib/SPIRV/libSPIRV/spirv_internal.hpp @@ -77,8 +77,8 @@ enum InternalOp { IOpMaskedGatherINTEL = 6428, IOpMaskedScatterINTEL = 6429, IOpJointMatrixGetElementCoordINTEL = 6440, - IOpCooperativeMatrixPrefetchINTEL = 6449, IOpCooperativeMatrixApplyFunctionINTEL = 6448, + IOpCooperativeMatrixPrefetchINTEL = 6449, IOpPrev = OpMax - 2, IOpForward }; From a76e816b026394d1bd9f33b024149cb715a13248 Mon Sep 17 00:00:00 2001 From: "Levytskyy, Vyacheslav" Date: Fri, 17 Nov 2023 08:18:55 -0800 Subject: [PATCH 6/6] fix a wrong value of the Use parameter to match documented accepted range of CooperativeMatrixUse --- .../cooperative_matrix_apply.ll | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/test/extensions/INTEL/SPV_INTEL_joint_matrix/cooperative_matrix_apply.ll b/test/extensions/INTEL/SPV_INTEL_joint_matrix/cooperative_matrix_apply.ll index 616b813fd0..f85a5f0cc8 100644 --- a/test/extensions/INTEL/SPV_INTEL_joint_matrix/cooperative_matrix_apply.ll +++ b/test/extensions/INTEL/SPV_INTEL_joint_matrix/cooperative_matrix_apply.ll @@ -18,9 +18,9 @@ ; CHECK-SPIRV: CooperativeMatrixApplyFunctionINTEL [[#MatTy]] [[#Apply:]] [[#Ptr]] [[#Mat]] ; CHECK-SPIRV: CooperativeMatrixStoreKHR [[#]] [[#Apply]] -; CHECK-LLVM: %[[Mat:[%0-9a-z.]+]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) @"_Z26__spirv_CompositeConstructP38class.sycl::_V1::ext::oneapi::bfloat16" -; CHECK-LLVM: %[[Apply:[%0-9a-z.]+]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) @"_Z43__spirv_CooperativeMatrixApplyFunctionINTELPU3AS477class.sycl::_V1::ext::oneapi::experimental::matrix::helper::reference_wrapperPU3AS144__spirv_CooperativeMatrixKHR__short_8_16_0_3"(ptr addrspace(4) %ref.tmp.ascast.i21, target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) %[[Mat]]) -; CHECK-LLVM: call spir_func void @"_Z33__spirv_CooperativeMatrixStoreKHRPU3AS138class.sycl::_V1::ext::oneapi::bfloat16PU3AS144__spirv_CooperativeMatrixKHR__short_8_16_0_3liii"(ptr addrspace(1) %{{.*}}, target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) %[[Apply]], i64 32, i32 0, i32 3, i32 0) +; CHECK-LLVM: %[[Mat:[%0-9a-z.]+]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) @"_Z26__spirv_CompositeConstructP38class.sycl::_V1::ext::oneapi::bfloat16" +; CHECK-LLVM: %[[Apply:[%0-9a-z.]+]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) @"_Z43__spirv_CooperativeMatrixApplyFunctionINTELPU3AS477class.sycl::_V1::ext::oneapi::experimental::matrix::helper::reference_wrapperPU3AS144__spirv_CooperativeMatrixKHR__short_8_16_0_0"(ptr addrspace(4) %ref.tmp.ascast.i21, target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) %[[Mat]]) +; CHECK-LLVM: call spir_func void @"_Z33__spirv_CooperativeMatrixStoreKHRPU3AS138class.sycl::_V1::ext::oneapi::bfloat16PU3AS144__spirv_CooperativeMatrixKHR__short_8_16_0_0liii"(ptr addrspace(1) %{{.*}}, target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) %[[Apply]], i64 32, i32 0, i32 3, i32 0) ; ModuleID = 'matrix_apply.bc' source_filename = "../llvm/sycl/test-e2e/Matrix/joint_matrix_apply_bf16.cpp" @@ -93,14 +93,14 @@ entry: %call.i.i = call spir_func noundef zeroext i16 @__devicelib_ConvertFToBF16INTEL(ptr addrspace(4) noundef align 4 dereferenceable(4) %ref.tmp6.ascast.i) call void @llvm.lifetime.start.p0(i64 2, ptr nonnull %agg.tmp.i17) store i16 %call.i.i, ptr %agg.tmp.i17, align 2 - %call.i18 = call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) @_Z26__spirv_CompositeConstruct(ptr noundef nonnull byval(%"class.sycl::_V1::ext::oneapi::bfloat16") align 2 %agg.tmp.i17) + %call.i18 = call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) @_Z26__spirv_CompositeConstruct(ptr noundef nonnull byval(%"class.sycl::_V1::ext::oneapi::bfloat16") align 2 %agg.tmp.i17) call void @llvm.lifetime.end.p0(i64 2, ptr nonnull %agg.tmp.i17) call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %ref.tmp6.i) %lambda.i = getelementptr inbounds %class.anon.0, ptr addrspace(4) %__SYCLKernel.ascast, i64 0, i32 1 %ref.tmp.ascast.i21 = addrspacecast ptr %ref.tmp.i20 to ptr addrspace(4) call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %ref.tmp.i20) store ptr addrspace(4) %lambda.i, ptr %ref.tmp.i20, align 8 - %call.i22 = call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) @_Z43__spirv_CooperativeMatrixApplyFunctionINTEL(ptr addrspace(4) noundef align 8 dereferenceable(8) %ref.tmp.ascast.i21, target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) noundef %call.i18) + %call.i22 = call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) @_Z43__spirv_CooperativeMatrixApplyFunctionINTEL(ptr addrspace(4) noundef align 8 dereferenceable(8) %ref.tmp.ascast.i21, target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) noundef %call.i18) call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %ref.tmp.i20) %6 = load ptr addrspace(1), ptr %0, align 8 %7 = load i64, ptr %__SYCLKernel, align 8 @@ -114,7 +114,7 @@ entry: %add.ptr.i43 = getelementptr inbounds %"class.sycl::_V1::ext::oneapi::bfloat16", ptr addrspace(1) %add.ptr.i.i, i64 %mul12.i %div14.i = and i64 %sub5.i, -16 %add.ptr.i44 = getelementptr inbounds %"class.sycl::_V1::ext::oneapi::bfloat16", ptr addrspace(1) %add.ptr.i43, i64 %div14.i - call spir_func void @_Z33__spirv_CooperativeMatrixStoreKHRPU3AS4iPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3ili(ptr addrspace(1) noundef %add.ptr.i44, target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) noundef %call.i22, i64 noundef 32, i32 noundef 0, i32 noundef 3, i32 noundef 0) + call spir_func void @_Z33__spirv_CooperativeMatrixStoreKHRPU3AS4iPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3ili(ptr addrspace(1) noundef %add.ptr.i44, target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) noundef %call.i22, i64 noundef 32, i32 noundef 0, i32 noundef 3, i32 noundef 0) call void @llvm.lifetime.end.p0(i64 64, ptr nonnull %__SYCLKernel) ret void } @@ -126,16 +126,16 @@ declare void @llvm.lifetime.start.p0(i64 immarg, ptr nocapture) declare void @llvm.lifetime.end.p0(i64 immarg, ptr nocapture) ; Function Attrs: convergent nounwind -declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) @_Z26__spirv_CompositeConstruct(ptr noundef byval(%"class.sycl::_V1::ext::oneapi::bfloat16") align 2) local_unnamed_addr +declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) @_Z26__spirv_CompositeConstruct(ptr noundef byval(%"class.sycl::_V1::ext::oneapi::bfloat16") align 2) local_unnamed_addr ; Function Attrs: convergent nounwind declare dso_local spir_func zeroext i16 @__devicelib_ConvertFToBF16INTEL(ptr addrspace(4) noundef align 4 dereferenceable(4)) local_unnamed_addr ; Function Attrs: convergent nounwind -declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) @_Z43__spirv_CooperativeMatrixApplyFunctionINTEL(ptr addrspace(4) noundef align 8 dereferenceable(8), target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) noundef) local_unnamed_addr +declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) @_Z43__spirv_CooperativeMatrixApplyFunctionINTEL(ptr addrspace(4) noundef align 8 dereferenceable(8), target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) noundef) local_unnamed_addr ; Function Attrs: convergent nounwind -declare dso_local spir_func void @_Z33__spirv_CooperativeMatrixStoreKHRPU3AS4iPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3ili(ptr addrspace(1) noundef, target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) noundef, i64 noundef, i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr +declare dso_local spir_func void @_Z33__spirv_CooperativeMatrixStoreKHRPU3AS4iPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3ili(ptr addrspace(1) noundef, target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) noundef, i64 noundef, i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr !llvm.module.flags = !{!0, !1} !opencl.spir.version = !{!2}