From 1010efc8bd9e367597b2c677a0b0df0c14e7b051 Mon Sep 17 00:00:00 2001 From: Dmitry Sidorov Date: Thu, 16 Nov 2023 16:45:41 +0100 Subject: [PATCH] Enable BFloat16 and TensorFloat32 conversions for cooperative matrices (#2213) Previously added scalar/vector ConvertFToBF16INTEL, ConvertBF16ToFINTEL and RoundFToTF32INTEL conversions are now enabled for the cooperative matrix type under the SPV_INTEL_joint_matrix extension, following the spec: https://github.com/intel/llvm/blob/sycl/sycl/doc/design/spirv-extensions/SPV_INTEL_joint_matrix.asciidoc Note: joint matrices are not allowed as input/output for these conversions, as they are being deprecated. Signed-off-by: Sidorov, Dmitry --- lib/SPIRV/libSPIRV/SPIRVInstruction.h | 53 ++++++++++++- .../bf16_conversion_instructions.ll | 79 +++++++++++++++++++ .../tf32_conversion_instructions.ll | 53 +++++++++++++ 3 files changed, 183 insertions(+), 2 deletions(-) create mode 100644 test/extensions/INTEL/SPV_INTEL_joint_matrix/bf16_conversion_instructions.ll create mode 100644 test/extensions/INTEL/SPV_INTEL_joint_matrix/tf32_conversion_instructions.ll diff --git a/lib/SPIRV/libSPIRV/SPIRVInstruction.h b/lib/SPIRV/libSPIRV/SPIRVInstruction.h index 5e2bdde59f..fcf08720ba 100644 --- a/lib/SPIRV/libSPIRV/SPIRVInstruction.h +++ b/lib/SPIRV/libSPIRV/SPIRVInstruction.h @@ -3294,10 +3294,17 @@ template class SPIRVBfloat16ConversionINTELInstBase : public SPIRVUnaryInst { protected: SPIRVCapVec getRequiredCapability() const override { + SPIRVType *ResCompTy = this->getType(); + if (ResCompTy->isTypeCooperativeMatrixKHR()) + return getVec(internal::CapabilityBfloat16ConversionINTEL, + internal::CapabilityJointMatrixBF16ComponentTypeINTEL); return getVec(internal::CapabilityBfloat16ConversionINTEL); } std::optional getRequiredExtension() const override { + SPIRVType *ResCompTy = this->getType(); + if (ResCompTy->isTypeCooperativeMatrixKHR()) + this->getModule()->addExtension(ExtensionID::SPV_INTEL_joint_matrix); return 
ExtensionID::SPV_INTEL_bfloat16_conversion; } @@ -3326,8 +3333,25 @@ class SPIRVBfloat16ConversionINTELInstBase : public SPIRVUnaryInst { } auto InstName = OpCodeNameMap::map(OC); - SPIRVErrorLog &SPVErrLog = this->getModule()->getErrorLog(); + auto *Module = this->getModule(); + SPIRVErrorLog &SPVErrLog = Module->getErrorLog(); + // Cooperative matrix type is allowed as input/output of the instruction + // if SPV_INTEL_joint_matrix is enabled + if (ResCompTy->isTypeCooperativeMatrixKHR()) { + SPVErrLog.checkError( + Module->isAllowedToUseExtension(ExtensionID::SPV_INTEL_joint_matrix), + SPIRVEC_InvalidInstruction, + InstName + "\nCan be used with " + "cooperative matrices only when SPV_INTEL_joint_matrix is " + "enabled\n"); + assert(InCompTy->isTypeCooperativeMatrixKHR() && + "Input must also be a cooperative matrix"); + ResCompTy = static_cast(ResCompTy) + ->getCompType(); + InCompTy = + static_cast(InCompTy)->getCompType(); + } if (OC == internal::OpConvertFToBF16INTEL) { SPVErrLog.checkError( ResCompTy->isTypeInt(16), SPIRVEC_InvalidInstruction, @@ -3679,10 +3703,17 @@ template class SPIRVTensorFloat32RoundingINTELInstBase : public SPIRVUnaryInst { protected: SPIRVCapVec getRequiredCapability() const override { + SPIRVType *ResCompTy = this->getType(); + if (ResCompTy->isTypeCooperativeMatrixKHR()) + return getVec(internal::CapabilityTensorFloat32RoundingINTEL, + internal::CapabilityJointMatrixTF32ComponentTypeINTEL); return getVec(internal::CapabilityTensorFloat32RoundingINTEL); } std::optional getRequiredExtension() const override { + SPIRVType *ResCompTy = this->getType(); + if (ResCompTy->isTypeCooperativeMatrixKHR()) + this->getModule()->addExtension(ExtensionID::SPV_INTEL_joint_matrix); return ExtensionID::SPV_INTEL_tensor_float32_conversion; } @@ -3711,7 +3742,25 @@ class SPIRVTensorFloat32RoundingINTELInstBase : public SPIRVUnaryInst { } auto InstName = OpCodeNameMap::map(OC); - SPIRVErrorLog &SPVErrLog = this->getModule()->getErrorLog(); + auto 
*Module = this->getModule(); + SPIRVErrorLog &SPVErrLog = Module->getErrorLog(); + + // Cooperative matrix type is allowed as input/output of the instruction + // if SPV_INTEL_joint_matrix is enabled + if (ResCompTy->isTypeCooperativeMatrixKHR()) { + SPVErrLog.checkError( + Module->isAllowedToUseExtension(ExtensionID::SPV_INTEL_joint_matrix), + SPIRVEC_InvalidInstruction, + InstName + "\nCan be used with " + "cooperative matrices only when SPV_INTEL_joint_matrix is " + "enabled\n"); + assert(InCompTy->isTypeCooperativeMatrixKHR() && + "Input must also be a cooperative matrix"); + ResCompTy = static_cast(ResCompTy) + ->getCompType(); + InCompTy = + static_cast(InCompTy)->getCompType(); + } SPVErrLog.checkError( ResCompTy->isTypeFloat(32), SPIRVEC_InvalidInstruction, diff --git a/test/extensions/INTEL/SPV_INTEL_joint_matrix/bf16_conversion_instructions.ll b/test/extensions/INTEL/SPV_INTEL_joint_matrix/bf16_conversion_instructions.ll new file mode 100644 index 0000000000..eb1d1afe51 --- /dev/null +++ b/test/extensions/INTEL/SPV_INTEL_joint_matrix/bf16_conversion_instructions.ll @@ -0,0 +1,79 @@ +; RUN: llvm-as < %s -o %t.bc +; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_KHR_cooperative_matrix,+SPV_INTEL_joint_matrix,+SPV_INTEL_bfloat16_conversion -o %t.spv +; RUN: llvm-spirv %t.spv -to-text -o %t.spt +; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV + +; RUN: llvm-spirv -r %t.spv -o %t.rev.bc +; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-OCL-IR + +; RUN: llvm-spirv -r %t.spv -o %t.rev.bc --spirv-target-env=SPV-IR +; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-SPV-IR + +; RUN: not llvm-spirv %t.bc --spirv-ext=+SPV_KHR_cooperative_matrix,+SPV_INTEL_bfloat16_conversion 2>&1 \ +; RUN: | FileCheck %s --check-prefix=CHECK-ERROR + +; CHECK-ERROR: InvalidInstruction: Can't translate llvm instruction: +; CHECK-ERROR-NEXT: ConvertFToBF16INTEL +; CHECK-ERROR-NEXT: Can be used with cooperative matrices only when SPV_INTEL_joint_matrix is 
enabled + +; CHECK-SPIRV-DAG: Capability CooperativeMatrixKHR +; CHECK-SPIRV-DAG: Capability Bfloat16ConversionINTEL +; CHECK-SPIRV-DAG: Capability JointMatrixBF16ComponentTypeINTEL +; CHECK-SPIRV-DAG: Extension "SPV_INTEL_bfloat16_conversion" +; CHECK-SPIRV-DAG: Extension "SPV_KHR_cooperative_matrix" +; CHECK-SPIRV-DAG: Extension "SPV_INTEL_joint_matrix" +; CHECK-SPIRV-DAG: TypeInt [[#ShortTy:]] 16 0 +; CHECK-SPIRV-DAG: TypeFloat [[#FP32Ty:]] 32 +; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#FP32MatTy:]] [[#FP32Ty]] +; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#ShortMatTy:]] [[#ShortTy]] +; CHECK-SPIRV: CompositeConstruct [[#FP32MatTy]] [[#FP32Mat:]] +; CHECK-SPIRV: ConvertFToBF16INTEL [[#ShortMatTy]] [[#]] [[#FP32Mat]] +; CHECK-SPIRV: CompositeConstruct [[#ShortMatTy]] [[#ShortMat:]] +; CHECK-SPIRV: ConvertBF16ToFINTEL [[#FP32MatTy]] [[#]] [[#ShortMat]] + +; CHECK-OCL-IR: %[[#FP32Matrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructf(float 0.000000e+00) +; CHECK-OCL-IR: call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z32intel_convert_bfloat16_as_ushortPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_3(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %[[#FP32Matrix]]) +; CHECK-OCL-IR: %[[#ShortMatrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructs(i16 0) +; CHECK-OCL-IR: call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z31intel_convert_as_bfloat16_floatPU3AS145__spirv_CooperativeMatrixKHR__short_3_12_12_3(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) %[[#ShortMatrix]]) + + +; CHECK-SPV-IR: %[[#FP32Matrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructf(float 0.000000e+00) +; CHECK-SPV-IR: call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) 
@_Z27__spirv_ConvertFToBF16INTELPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_3(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %[[#FP32Matrix]]) +; CHECK-SPV-IR: %[[#ShortMatrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructs(i16 0) +; CHECK-SPV-IR: call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z27__spirv_ConvertBF16ToFINTELPU3AS145__spirv_CooperativeMatrixKHR__short_3_12_12_3(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) %[[#ShortMatrix]]) + + +target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128" +target triple = "spir64-unknown-unknown" + +define void @convert_f_to_bf() { +entry: + %0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00) + %call = call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z27__spirv_ConvertFToBF16INTEL(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0) + ret void +} + +define void @convert_bf_to_f() { +entry: + %0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt16(i16 0) + %call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z27__spirv_ConvertBF16ToFINTEL(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) %0) + ret void +} + +declare spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float noundef) + +declare spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt16(i16 noundef) + +declare spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z27__spirv_ConvertFToBF16INTEL(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) noundef) + +declare spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 
3) @_Z27__spirv_ConvertBF16ToFINTEL(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) noundef) + +!llvm.module.flags = !{!0, !1, !2, !3, !4} +!llvm.ident = !{!5} + +!0 = !{i32 7, !"Dwarf Version", i32 4} +!1 = !{i32 1, !"wchar_size", i32 4} +!2 = !{i32 8, !"PIC Level", i32 2} +!3 = !{i32 7, !"PIE Level", i32 2} +!4 = !{i32 7, !"uwtable", i32 2} +!5 = !{!"clang version 17.0.0"} diff --git a/test/extensions/INTEL/SPV_INTEL_joint_matrix/tf32_conversion_instructions.ll b/test/extensions/INTEL/SPV_INTEL_joint_matrix/tf32_conversion_instructions.ll new file mode 100644 index 0000000000..6392c94138 --- /dev/null +++ b/test/extensions/INTEL/SPV_INTEL_joint_matrix/tf32_conversion_instructions.ll @@ -0,0 +1,53 @@ +; RUN: llvm-as < %s -o %t.bc +; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_KHR_cooperative_matrix,+SPV_INTEL_joint_matrix,+SPV_INTEL_tensor_float32_conversion -o %t.spv +; RUN: llvm-spirv %t.spv -to-text -o %t.spt +; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV + +; RUN: llvm-spirv -r %t.spv -o %t.rev.bc +; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM + +; RUN: not llvm-spirv %t.bc --spirv-ext=+SPV_KHR_cooperative_matrix,+SPV_INTEL_tensor_float32_conversion 2>&1 \ +; RUN: | FileCheck %s --check-prefix=CHECK-ERROR + +; CHECK-ERROR: InvalidInstruction: Can't translate llvm instruction: +; CHECK-ERROR-NEXT: RoundFToTF32INTEL +; CHECK-ERROR-NEXT: Can be used with cooperative matrices only when SPV_INTEL_joint_matrix is enabled + +; CHECK-SPIRV-DAG: Capability CooperativeMatrixKHR +; CHECK-SPIRV-DAG: Capability TensorFloat32RoundingINTEL +; CHECK-SPIRV-DAG: Capability JointMatrixTF32ComponentTypeINTEL +; CHECK-SPIRV-DAG: Extension "SPV_INTEL_tensor_float32_conversion" +; CHECK-SPIRV-DAG: Extension "SPV_KHR_cooperative_matrix" +; CHECK-SPIRV-DAG: Extension "SPV_INTEL_joint_matrix" +; CHECK-SPIRV-DAG: TypeFloat [[#FP32Ty:]] 32 +; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#FP32MatTy:]] [[#FP32Ty]] +; CHECK-SPIRV: CompositeConstruct 
[[#FP32MatTy]] [[#FP32Mat:]] +; CHECK-SPIRV: RoundFToTF32INTEL [[#FP32MatTy]] [[#]] [[#FP32Mat]] + +; CHECK-LLVM: %[[#Mat:]] = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructf(float 0.000000e+00) +; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z25__spirv_RoundFToTF32INTELPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_3(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %[[#Mat]]) + + +target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128" +target triple = "spir64-unknown-unknown" + +define void @convert_f_to_tf() { +entry: + %0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00) + %call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z25__spirv_RoundFToTF32INTEL(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0) + ret void +} + +declare spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float noundef) + +declare spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z25__spirv_RoundFToTF32INTEL(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) noundef) + +!llvm.module.flags = !{!0, !1, !2, !3, !4} +!llvm.ident = !{!5} + +!0 = !{i32 7, !"Dwarf Version", i32 4} +!1 = !{i32 1, !"wchar_size", i32 4} +!2 = !{i32 8, !"PIC Level", i32 2} +!3 = !{i32 7, !"PIE Level", i32 2} +!4 = !{i32 7, !"uwtable", i32 2} +!5 = !{!"clang version 17.0.0"}