diff --git a/lib/SPIRV/SPIRVWriter.cpp b/lib/SPIRV/SPIRVWriter.cpp index 02d7e14b28..1feb79cc9f 100644 --- a/lib/SPIRV/SPIRVWriter.cpp +++ b/lib/SPIRV/SPIRVWriter.cpp @@ -5731,6 +5731,11 @@ LLVMToSPIRVBase::transBuiltinToInstWithoutDecoration(Op OC, CallInst *CI, return BM->addCompositeConstructInst(transType(CI->getType()), Operands, BB); } + case OpMatrixTimesScalar: { + return BM->addMatrixTimesScalarInst( + transType(CI->getType()), transValue(CI->getArgOperand(0), BB)->getId(), + transValue(CI->getArgOperand(1), BB)->getId(), BB); + } default: { if (isCvtOpCode(OC) && OC != OpGenericCastToPtrExplicit) { return BM->addUnaryInst(OC, transScavengedType(CI), diff --git a/test/extensions/KHR/SPV_KHR_cooperative_matrix/matrix_times_scalar.ll b/test/extensions/KHR/SPV_KHR_cooperative_matrix/matrix_times_scalar.ll new file mode 100644 index 0000000000..d1025d097e --- /dev/null +++ b/test/extensions/KHR/SPV_KHR_cooperative_matrix/matrix_times_scalar.ll @@ -0,0 +1,51 @@ +; RUN: llvm-as < %s -o %t.bc +; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_KHR_cooperative_matrix -o %t.spv +; TODO: Validation is disabled till the moment the tools in CI are updated +; R/UN: spirv-val %t.spv +; RUN: llvm-spirv %t.spv -to-text -o %t.spt +; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV + +; RUN: llvm-spirv -r %t.spv -o %t.bc +; RUN: llvm-dis < %t.bc | FileCheck %s --check-prefix=CHECK-LLVM + +; CHECK-SPIRV: TypeFloat [[#TypeFloat:]] 32 +; CHECK-SPIRV: TypeCooperativeMatrixKHR [[#MatrixType:]] + +; CHECK-SPIRV: CompositeConstruct [[#MatrixType]] [[#Matrix:]] [[#]] {{$}} +; CHECK-SPIRV: Load [[#TypeFloat]] [[#Scalar:]] +; CHECK-SPIRV: MatrixTimesScalar [[#MatrixType]] [[#]] [[#Matrix]] [[#Scalar]] + +; CHECK-LLVM: %[[#Matrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructf(float 0.000000e+00) +; CHECK-LLVM: %[[#Scalar:]] = load float, ptr %scalar +; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z25__spirv_MatrixTimesScalarPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_3f(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %[[#Matrix]], float %[[#Scalar]]) + +target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128" +target triple = "spir64-unknown-unknown" + +; Function Attrs: mustprogress uwtable +define dso_local void @matrix_times_scalar(ptr %scalar) local_unnamed_addr #0 { +entry: + %0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstruct(float 0.000000e+00) #4 + %1 = load float, ptr %scalar, align 4 + %call = call noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z25__spirv_MatrixTimesScalar(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, float %1) + ret void +} + +declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstruct(float noundef) local_unnamed_addr #2 + +declare noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z25__spirv_MatrixTimesScalar(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) noundef, float noundef) local_unnamed_addr #2 + +attributes #0 = { mustprogress uwtable "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" } +attributes #1 = { mustprogress nocallback nofree nosync nounwind willreturn memory(argmem: readwrite) } +attributes #2 = { "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" } +attributes #3 = { nounwind } + +!llvm.module.flags = !{!0, !1, !2, !3, !4} +!llvm.ident = !{!5} + +!0 = !{i32 7, !"Dwarf Version", i32 4} +!1 = !{i32 1, !"wchar_size", i32 4} +!2 = !{i32 8, !"PIC Level", i32 2} +!3 = !{i32 7, !"PIE Level", i32 2} +!4 = !{i32 7, !"uwtable", i32 2} +!5 = !{!"clang version 16.0.0 (https://github.com/llvm/llvm-project.git 08d094a0e457360ad8b94b017d2dc277e697ca76)"} diff --git a/test/extensions/KHR/SPV_KHR_cooperative_matrix/matrix_times_scalar.spt b/test/extensions/KHR/SPV_KHR_cooperative_matrix/matrix_times_scalar.spt deleted file mode 100644 index 96d6c3922d..0000000000 --- a/test/extensions/KHR/SPV_KHR_cooperative_matrix/matrix_times_scalar.spt +++ /dev/null @@ -1,44 +0,0 @@ -; RUN: llvm-spirv %s -to-binary -o %t.spv -; Validation is disabled till the moment the tools in CI are updated -; R/UN: spirv-val %t.spv -; RUN: llvm-spirv -r %t.spv -o %t.bc -; RUN: llvm-dis < %t.bc | FileCheck %s --check-prefix=CHECK-LLVM - -; CHECK-LLVM: %[[#Matrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructf(float 0.000000e+00) -; CHECK-LLVM: %[[#Scalar:]] = load float, ptr %rhs -; CHECK-LLVM: %3 = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z25__spirv_MatrixTimesScalarPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_3f(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %[[#Matrix]], float %[[#Scalar]]) - -119734787 65536 458752 24 0 -2 Capability Addresses -2 Capability Linkage -2 Capability Kernel -2 Capability Float64 -2 Capability Matrix -2 Capability CooperativeMatrixKHR -8 Extension "SPV_KHR_cooperative_matrix" -3 MemoryModel 2 2 -8 EntryPoint 6 20 "matrix_times_scalar" -3 Source 3 102000 -3 Name 14 "rhs" - -4 TypeInt 19 32 0 -4 Constant 19 21 3 -4 Constant 19 22 12 -2 TypeVoid 5 -3 TypeFloat 6 32 -4 Constant 6 23 0 -7 TypeCooperativeMatrixKHR 8 6 21 22 22 21 -4 TypePointer 10 7 6 ; 10 : Pointer to Scalar -4 TypeFunction 11 5 10 - -5 Function 5 20 0 11 -3 FunctionParameter 10 14 ; rhs : Pointer to Scalar - -2 Label 15 -4 CompositeConstruct 8 16 23 -4 Load 6 17 14 - -5 MatrixTimesScalar 8 18 16 17 -1 Return - -1 FunctionEnd