From 0ef8e71874e128560fdc77b6234d1bef3e18d3bd Mon Sep 17 00:00:00 2001
From: Craig Topper
Date: Mon, 26 Aug 2024 21:48:32 -0700
Subject: [PATCH] [RISCV] Custom legalize vXbf16 BUILD_VECTOR without Zfbfmin.

By default, type legalization will try to promote the build_vector, but
the generic type legalizer doesn't support that. Bitcast to vXi16 instead.

Same as what we do for vXf16 without Zfhmin.

Fixes #100846.
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   |  16 ++-
 .../RISCV/rvv/fixed-vectors-fp-splat-bf16.ll  | 111 ++++++++++++++++++
 2 files changed, 122 insertions(+), 5 deletions(-)
 create mode 100644 llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp-splat-bf16.ll

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 9b96e32c5ab394..790107b772fcb3 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1285,8 +1285,14 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
 
       if (VT.getVectorElementType() == MVT::bf16) {
         setOperationAction({ISD::VP_FP_ROUND, ISD::VP_FP_EXTEND}, VT, Custom);
-        // FIXME: We should prefer BUILD_VECTOR over SPLAT_VECTOR.
-        setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
+        if (Subtarget.hasStdExtZfbfmin()) {
+          // FIXME: We should prefer BUILD_VECTOR over SPLAT_VECTOR.
+          setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
+        } else {
+          // We need to custom legalize bf16 build vectors if Zfbfmin isn't
+          // available.
+          setOperationAction(ISD::BUILD_VECTOR, MVT::bf16, Custom);
+        }
         setOperationAction(
             {ISD::VP_MERGE, ISD::VP_SELECT, ISD::VSELECT, ISD::SELECT}, VT,
             Custom);
@@ -3935,9 +3941,9 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
   MVT VT = Op.getSimpleValueType();
   assert(VT.isFixedLengthVector() && "Unexpected vector!");
 
-  // If we don't have scalar f16, we need to bitcast to an i16 vector.
-  if (VT.getVectorElementType() == MVT::f16 &&
-      !Subtarget.hasStdExtZfhmin())
+  // If we don't have scalar f16/bf16, we need to bitcast to an i16 vector.
+  if ((VT.getVectorElementType() == MVT::f16 && !Subtarget.hasStdExtZfhmin()) ||
+      (VT.getVectorElementType() == MVT::bf16 && !Subtarget.hasStdExtZfbfmin()))
     return lowerBUILD_VECTORvXf16(Op, DAG);
 
   if (ISD::isBuildVectorOfConstantSDNodes(Op.getNode()) ||
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp-splat-bf16.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp-splat-bf16.ll
new file mode 100644
index 00000000000000..b1250f4804549a
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp-splat-bf16.ll
@@ -0,0 +1,111 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv32 -target-abi=ilp32d -mattr=+v,+zfbfmin,+zvfbfmin -verify-machineinstrs < %s | FileCheck %s --check-prefixes=ZFBFMIN-ZVFBFMIN
+; RUN: llc -mtriple=riscv32 -target-abi=ilp32d -mattr=+v,+zvfbfmin -verify-machineinstrs < %s | FileCheck %s --check-prefixes=ZVFBFMIN
+; RUN: llc -mtriple=riscv64 -target-abi=lp64d -mattr=+v,+zfbfmin,+zvfbfmin -verify-machineinstrs < %s | FileCheck %s --check-prefixes=ZFBFMIN-ZVFBFMIN
+; RUN: llc -mtriple=riscv64 -target-abi=lp64d -mattr=+v,+zvfbfmin -verify-machineinstrs < %s | FileCheck %s --check-prefixes=ZVFBFMIN
+
+define <8 x bfloat> @splat_v8bf16(ptr %x, bfloat %y) {
+; ZFBFMIN-ZVFBFMIN-LABEL: splat_v8bf16:
+; ZFBFMIN-ZVFBFMIN: # %bb.0:
+; ZFBFMIN-ZVFBFMIN-NEXT: fcvt.s.bf16 fa5, fa0
+; ZFBFMIN-ZVFBFMIN-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; ZFBFMIN-ZVFBFMIN-NEXT: vfmv.v.f v10, fa5
+; ZFBFMIN-ZVFBFMIN-NEXT: vsetvli zero, zero, e16, m1, ta, ma
+; ZFBFMIN-ZVFBFMIN-NEXT: vfncvtbf16.f.f.w v8, v10
+; ZFBFMIN-ZVFBFMIN-NEXT: ret
+;
+; ZVFBFMIN-LABEL: splat_v8bf16:
+; ZVFBFMIN: # %bb.0:
+; ZVFBFMIN-NEXT: fmv.x.w a0, fa0
+; ZVFBFMIN-NEXT: vsetivli zero, 8, e16, m1, ta, ma
+; ZVFBFMIN-NEXT: vmv.v.x v8, a0
+; ZVFBFMIN-NEXT: ret
+  %a = insertelement <8 x bfloat> poison, bfloat %y, i32 0
+  %b = shufflevector <8 x bfloat> %a, <8 x bfloat> poison, <8 x i32> zeroinitializer
+  ret <8 x bfloat> %b
+}
+
+define <16 x bfloat> @splat_16bf16(ptr %x, bfloat %y) {
+; ZFBFMIN-ZVFBFMIN-LABEL: splat_16bf16:
+; ZFBFMIN-ZVFBFMIN: # %bb.0:
+; ZFBFMIN-ZVFBFMIN-NEXT: fcvt.s.bf16 fa5, fa0
+; ZFBFMIN-ZVFBFMIN-NEXT: vsetvli a0, zero, e32, m4, ta, ma
+; ZFBFMIN-ZVFBFMIN-NEXT: vfmv.v.f v12, fa5
+; ZFBFMIN-ZVFBFMIN-NEXT: vsetvli zero, zero, e16, m2, ta, ma
+; ZFBFMIN-ZVFBFMIN-NEXT: vfncvtbf16.f.f.w v8, v12
+; ZFBFMIN-ZVFBFMIN-NEXT: ret
+;
+; ZVFBFMIN-LABEL: splat_16bf16:
+; ZVFBFMIN: # %bb.0:
+; ZVFBFMIN-NEXT: fmv.x.w a0, fa0
+; ZVFBFMIN-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; ZVFBFMIN-NEXT: vmv.v.x v8, a0
+; ZVFBFMIN-NEXT: ret
+  %a = insertelement <16 x bfloat> poison, bfloat %y, i32 0
+  %b = shufflevector <16 x bfloat> %a, <16 x bfloat> poison, <16 x i32> zeroinitializer
+  ret <16 x bfloat> %b
+}
+
+define <8 x bfloat> @splat_zero_v8bf16(ptr %x) {
+; ZFBFMIN-ZVFBFMIN-LABEL: splat_zero_v8bf16:
+; ZFBFMIN-ZVFBFMIN: # %bb.0:
+; ZFBFMIN-ZVFBFMIN-NEXT: vsetvli a0, zero, e16, m1, ta, ma
+; ZFBFMIN-ZVFBFMIN-NEXT: vmv.v.i v8, 0
+; ZFBFMIN-ZVFBFMIN-NEXT: ret
+;
+; ZVFBFMIN-LABEL: splat_zero_v8bf16:
+; ZVFBFMIN: # %bb.0:
+; ZVFBFMIN-NEXT: vsetivli zero, 8, e16, m1, ta, ma
+; ZVFBFMIN-NEXT: vmv.v.i v8, 0
+; ZVFBFMIN-NEXT: ret
+  ret <8 x bfloat> splat (bfloat 0.0)
+}
+
+define <16 x bfloat> @splat_zero_16bf16(ptr %x) {
+; ZFBFMIN-ZVFBFMIN-LABEL: splat_zero_16bf16:
+; ZFBFMIN-ZVFBFMIN: # %bb.0:
+; ZFBFMIN-ZVFBFMIN-NEXT: vsetvli a0, zero, e16, m2, ta, ma
+; ZFBFMIN-ZVFBFMIN-NEXT: vmv.v.i v8, 0
+; ZFBFMIN-ZVFBFMIN-NEXT: ret
+;
+; ZVFBFMIN-LABEL: splat_zero_16bf16:
+; ZVFBFMIN: # %bb.0:
+; ZVFBFMIN-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; ZVFBFMIN-NEXT: vmv.v.i v8, 0
+; ZVFBFMIN-NEXT: ret
+  ret <16 x bfloat> splat (bfloat 0.0)
+}
+
+define <8 x bfloat> @splat_negzero_v8bf16(ptr %x) {
+; ZFBFMIN-ZVFBFMIN-LABEL: splat_negzero_v8bf16:
+; ZFBFMIN-ZVFBFMIN: # %bb.0:
+; ZFBFMIN-ZVFBFMIN-NEXT: lui a0, 1048568
+; ZFBFMIN-ZVFBFMIN-NEXT: vsetvli a1, zero, e16, m1, ta, ma
+; ZFBFMIN-ZVFBFMIN-NEXT: vmv.v.x v8, a0
+; ZFBFMIN-ZVFBFMIN-NEXT: ret
+;
+; ZVFBFMIN-LABEL: splat_negzero_v8bf16:
+; ZVFBFMIN: # %bb.0:
+; ZVFBFMIN-NEXT: lui a0, 1048568
+; ZVFBFMIN-NEXT: vsetivli zero, 8, e16, m1, ta, ma
+; ZVFBFMIN-NEXT: vmv.v.x v8, a0
+; ZVFBFMIN-NEXT: ret
+  ret <8 x bfloat> splat (bfloat -0.0)
+}
+
+define <16 x bfloat> @splat_negzero_16bf16(ptr %x) {
+; ZFBFMIN-ZVFBFMIN-LABEL: splat_negzero_16bf16:
+; ZFBFMIN-ZVFBFMIN: # %bb.0:
+; ZFBFMIN-ZVFBFMIN-NEXT: lui a0, 1048568
+; ZFBFMIN-ZVFBFMIN-NEXT: vsetvli a1, zero, e16, m2, ta, ma
+; ZFBFMIN-ZVFBFMIN-NEXT: vmv.v.x v8, a0
+; ZFBFMIN-ZVFBFMIN-NEXT: ret
+;
+; ZVFBFMIN-LABEL: splat_negzero_16bf16:
+; ZVFBFMIN: # %bb.0:
+; ZVFBFMIN-NEXT: lui a0, 1048568
+; ZVFBFMIN-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; ZVFBFMIN-NEXT: vmv.v.x v8, a0
+; ZVFBFMIN-NEXT: ret
+  ret <16 x bfloat> splat (bfloat -0.0)
+}
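
For reference, the helper this change now routes bf16 elements through, lowerBUILD_VECTORvXf16, is what implements the "bitcast to vXi16" strategy described in the commit message. A minimal sketch of that approach follows; the function name, comments, and exact structure here are illustrative only (assuming the usual SelectionDAG context of RISCVISelLowering.cpp), not the in-tree implementation.

// Sketch: when the scalar f16/bf16 type is not legal (no Zfhmin/Zfbfmin),
// lower BUILD_VECTOR by bitcasting each element to i16, building a vXi16
// vector, and bitcasting the result back to the FP vector type.
static SDValue lowerBuildVectorViaI16Bitcast(SDValue Op, SelectionDAG &DAG) {
  MVT VT = Op.getSimpleValueType();                // e.g. v8bf16
  MVT IVT = VT.changeVectorElementTypeToInteger(); // v8i16

  // Reinterpret every scalar operand as i16.
  SmallVector<SDValue, 16> NewOps(Op.getNumOperands());
  for (unsigned I = 0, E = Op.getNumOperands(); I != E; ++I)
    NewOps[I] = DAG.getBitcast(MVT::i16, Op.getOperand(I));

  // Build the integer vector, then view it as the requested FP vector type.
  SDValue IVec = DAG.getNode(ISD::BUILD_VECTOR, SDLoc(Op), IVT, NewOps);
  return DAG.getBitcast(VT, IVec);
}

With ISD::BUILD_VECTOR marked Custom for bf16, the ZVFBFMIN-only RUN lines above splat through an integer vmv.v.x instead of reaching the promotion that the generic type legalizer doesn't support.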