From 3e3780ef6ab5902cd1763e28bb143e47091bd23a Mon Sep 17 00:00:00 2001
From: Paul Walker
Date: Tue, 24 Sep 2024 13:15:26 +0100
Subject: [PATCH] [LLVM][CodeGen][SVE] Implement nxvf32 fpround to nxvbf16.
 (#107420)

---
 .../Target/AArch64/AArch64ISelLowering.cpp    |  50 ++++++-
 .../lib/Target/AArch64/AArch64SVEInstrInfo.td |   2 +-
 llvm/lib/Target/AArch64/SVEInstrFormats.td    |   6 +-
 .../test/CodeGen/AArch64/sve-bf16-converts.ll | 129 +++++++++++++++++-
 4 files changed, 180 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index b11ac81069f660..4166d9bd22bc01 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1664,6 +1664,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::BITCAST, VT, Custom);
       setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
       setOperationAction(ISD::FP_EXTEND, VT, Custom);
+      setOperationAction(ISD::FP_ROUND, VT, Custom);
       setOperationAction(ISD::MLOAD, VT, Custom);
       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
       setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
@@ -4334,14 +4335,57 @@ SDValue AArch64TargetLowering::LowerFP_EXTEND(SDValue Op,
 SDValue AArch64TargetLowering::LowerFP_ROUND(SDValue Op,
                                              SelectionDAG &DAG) const {
   EVT VT = Op.getValueType();
-  if (VT.isScalableVector())
-    return LowerToPredicatedOp(Op, DAG, AArch64ISD::FP_ROUND_MERGE_PASSTHRU);
-
   bool IsStrict = Op->isStrictFPOpcode();
   SDValue SrcVal = Op.getOperand(IsStrict ? 1 : 0);
   EVT SrcVT = SrcVal.getValueType();
   bool Trunc = Op.getConstantOperandVal(IsStrict ? 2 : 1) == 1;
 
+  if (VT.isScalableVector()) {
+    if (VT.getScalarType() != MVT::bf16)
+      return LowerToPredicatedOp(Op, DAG, AArch64ISD::FP_ROUND_MERGE_PASSTHRU);
+
+    SDLoc DL(Op);
+    constexpr EVT I32 = MVT::nxv4i32;
+    auto ImmV = [&](int I) -> SDValue { return DAG.getConstant(I, DL, I32); };
+
+    SDValue NaN;
+    SDValue Narrow;
+
+    if (SrcVT == MVT::nxv2f32 || SrcVT == MVT::nxv4f32) {
+      if (Subtarget->hasBF16())
+        return LowerToPredicatedOp(Op, DAG,
+                                   AArch64ISD::FP_ROUND_MERGE_PASSTHRU);
+
+      Narrow = getSVESafeBitCast(I32, SrcVal, DAG);
+
+      // Set the quiet bit.
+      if (!DAG.isKnownNeverSNaN(SrcVal))
+        NaN = DAG.getNode(ISD::OR, DL, I32, Narrow, ImmV(0x400000));
+    } else
+      return SDValue();
+
+    if (!Trunc) {
+      SDValue Lsb = DAG.getNode(ISD::SRL, DL, I32, Narrow, ImmV(16));
+      Lsb = DAG.getNode(ISD::AND, DL, I32, Lsb, ImmV(1));
+      SDValue RoundingBias = DAG.getNode(ISD::ADD, DL, I32, Lsb, ImmV(0x7fff));
+      Narrow = DAG.getNode(ISD::ADD, DL, I32, Narrow, RoundingBias);
+    }
+
+    // Don't round if we had a NaN, we don't want to turn 0x7fffffff into
+    // 0x80000000.
+    if (NaN) {
+      EVT I1 = I32.changeElementType(MVT::i1);
+      EVT CondVT = VT.changeElementType(MVT::i1);
+      SDValue IsNaN = DAG.getSetCC(DL, CondVT, SrcVal, SrcVal, ISD::SETUO);
+      IsNaN = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, I1, IsNaN);
+      Narrow = DAG.getSelect(DL, I32, IsNaN, NaN, Narrow);
+    }
+
+    // Now that we have rounded, shift the bits into position.
+    Narrow = DAG.getNode(ISD::SRL, DL, I32, Narrow, ImmV(16));
+    return getSVESafeBitCast(VT, Narrow, DAG);
+  }
+
   if (useSVEForFixedLengthVectorVT(SrcVT, !Subtarget->isNeonAvailable()))
     return LowerFixedLengthFPRoundToSVE(Op, DAG);
 
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 1f3d63a216c6dd..7240f6a22a87bd 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -2425,7 +2425,7 @@ let Predicates = [HasBF16, HasSVEorSME] in {
   defm BFMLALT_ZZZ  : sve2_fp_mla_long<0b101, "bfmlalt", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlalt>;
   defm BFMLALB_ZZZI : sve2_fp_mla_long_by_indexed_elem<0b100, "bfmlalb", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlalb_lane_v2>;
   defm BFMLALT_ZZZI : sve2_fp_mla_long_by_indexed_elem<0b101, "bfmlalt", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlalt_lane_v2>;
-  defm BFCVT_ZPmZ   : sve_bfloat_convert<0b1, "bfcvt", int_aarch64_sve_fcvt_bf16f32>;
+  defm BFCVT_ZPmZ   : sve_bfloat_convert<0b1, "bfcvt", int_aarch64_sve_fcvt_bf16f32, AArch64fcvtr_mt>;
   defm BFCVTNT_ZPmZ : sve_bfloat_convert<0b0, "bfcvtnt", int_aarch64_sve_fcvtnt_bf16f32>;
 } // End HasBF16, HasSVEorSME
 
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index 8119198a48aa59..0bfac6465a1f30 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -8807,9 +8807,13 @@ class sve_bfloat_convert<bit N, string asm>
   let mayRaiseFPException = 1;
 }
 
-multiclass sve_bfloat_convert<bit N, string asm, SDPatternOperator op> {
+multiclass sve_bfloat_convert<bit N, string asm, SDPatternOperator op,
+                              SDPatternOperator ir_op = null_frag> {
   def NAME : sve_bfloat_convert<N, asm>;
+
   def : SVE_3_Op_Pat<nxv8bf16, op, nxv8bf16, nxv8i1, nxv4f32, !cast<Instruction>(NAME)>;
+  def : SVE_1_Op_Passthru_Round_Pat<nxv4bf16, ir_op, nxv4i1, nxv4f32, !cast<Instruction>(NAME)>;
+  def : SVE_1_Op_Passthru_Round_Pat<nxv2bf16, ir_op, nxv2i1, nxv2f32, !cast<Instruction>(NAME)>;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/llvm/test/CodeGen/AArch64/sve-bf16-converts.ll b/llvm/test/CodeGen/AArch64/sve-bf16-converts.ll
index d72f92c1dac1ff..d63f7e6f3242e0 100644
--- a/llvm/test/CodeGen/AArch64/sve-bf16-converts.ll
+++ b/llvm/test/CodeGen/AArch64/sve-bf16-converts.ll
@@ -1,9 +1,15 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mattr=+sve < %s | FileCheck %s
-; RUN: llc -mattr=+sme -force-streaming < %s | FileCheck %s
+; RUN: llc -mattr=+sve < %s | FileCheck %s --check-prefixes=CHECK,NOBF16
+; RUN: llc -mattr=+sve --enable-no-nans-fp-math < %s | FileCheck %s --check-prefixes=CHECK,NOBF16NNAN
+; RUN: llc -mattr=+sve,+bf16 < %s | FileCheck %s --check-prefixes=CHECK,BF16
+; RUN: llc -mattr=+sme -force-streaming < %s | FileCheck %s --check-prefixes=CHECK,BF16
 
 target triple = "aarch64-unknown-linux-gnu"
 
+; NOTE: "fptrunc <# x double> to <# x bfloat>" is not supported because SVE
+; lacks a down convert that rounds to odd. Such IR will trigger the usual
+; failure (crash) when attempting to unroll a scalable vector.
+
 define <vscale x 2 x float> @fpext_nxv2bf16_to_nxv2f32(<vscale x 2 x bfloat> %a) {
 ; CHECK-LABEL: fpext_nxv2bf16_to_nxv2f32:
 ; CHECK:       // %bb.0:
@@ -87,3 +93,122 @@ define <vscale x 8 x double> @fpext_nxv8bf16_to_nxv8f64(<vscale x 8 x bfloat> %a
   %res = fpext <vscale x 8 x bfloat> %a to <vscale x 8 x double>
   ret <vscale x 8 x double> %res
 }
+
+define <vscale x 2 x bfloat> @fptrunc_nxv2f32_to_nxv2bf16(<vscale x 2 x float> %a) {
+; NOBF16-LABEL: fptrunc_nxv2f32_to_nxv2bf16:
+; NOBF16:       // %bb.0:
+; NOBF16-NEXT:    mov z1.s, #32767 // =0x7fff
+; NOBF16-NEXT:    lsr z2.s, z0.s, #16
+; NOBF16-NEXT:    ptrue p0.d
+; NOBF16-NEXT:    fcmuo p0.s, p0/z, z0.s, z0.s
+; NOBF16-NEXT:    and z2.s, z2.s, #0x1
+; NOBF16-NEXT:    add z1.s, z0.s, z1.s
+; NOBF16-NEXT:    orr z0.s, z0.s, #0x400000
+; NOBF16-NEXT:    add z1.s, z2.s, z1.s
+; NOBF16-NEXT:    sel z0.s, p0, z0.s, z1.s
+; NOBF16-NEXT:    lsr z0.s, z0.s, #16
+; NOBF16-NEXT:    ret
+;
+; NOBF16NNAN-LABEL: fptrunc_nxv2f32_to_nxv2bf16:
+; NOBF16NNAN:       // %bb.0:
+; NOBF16NNAN-NEXT:    mov z1.s, #32767 // =0x7fff
+; NOBF16NNAN-NEXT:    lsr z2.s, z0.s, #16
+; NOBF16NNAN-NEXT:    and z2.s, z2.s, #0x1
+; NOBF16NNAN-NEXT:    add z0.s, z0.s, z1.s
+; NOBF16NNAN-NEXT:    add z0.s, z2.s, z0.s
+; NOBF16NNAN-NEXT:    lsr z0.s, z0.s, #16
+; NOBF16NNAN-NEXT:    ret
+;
+; BF16-LABEL: fptrunc_nxv2f32_to_nxv2bf16:
+; BF16:       // %bb.0:
+; BF16-NEXT:    ptrue p0.d
+; BF16-NEXT:    bfcvt z0.h, p0/m, z0.s
+; BF16-NEXT:    ret
+  %res = fptrunc <vscale x 2 x float> %a to <vscale x 2 x bfloat>
+  ret <vscale x 2 x bfloat> %res
+}
+
+define <vscale x 4 x bfloat> @fptrunc_nxv4f32_to_nxv4bf16(<vscale x 4 x float> %a) {
+; NOBF16-LABEL: fptrunc_nxv4f32_to_nxv4bf16:
+; NOBF16:       // %bb.0:
+; NOBF16-NEXT:    mov z1.s, #32767 // =0x7fff
+; NOBF16-NEXT:    lsr z2.s, z0.s, #16
+; NOBF16-NEXT:    ptrue p0.s
+; NOBF16-NEXT:    fcmuo p0.s, p0/z, z0.s, z0.s
+; NOBF16-NEXT:    and z2.s, z2.s, #0x1
+; NOBF16-NEXT:    add z1.s, z0.s, z1.s
+; NOBF16-NEXT:    orr z0.s, z0.s, #0x400000
+; NOBF16-NEXT:    add z1.s, z2.s, z1.s
+; NOBF16-NEXT:    sel z0.s, p0, z0.s, z1.s
+; NOBF16-NEXT:    lsr z0.s, z0.s, #16
+; NOBF16-NEXT:    ret
+;
+; NOBF16NNAN-LABEL: fptrunc_nxv4f32_to_nxv4bf16:
+; NOBF16NNAN:       // %bb.0:
+; NOBF16NNAN-NEXT:    mov z1.s, #32767 // =0x7fff
+; NOBF16NNAN-NEXT:    lsr z2.s, z0.s, #16
+; NOBF16NNAN-NEXT:    and z2.s, z2.s, #0x1
+; NOBF16NNAN-NEXT:    add z0.s, z0.s, z1.s
+; NOBF16NNAN-NEXT:    add z0.s, z2.s, z0.s
+; NOBF16NNAN-NEXT:    lsr z0.s, z0.s, #16
+; NOBF16NNAN-NEXT:    ret
+;
+; BF16-LABEL: fptrunc_nxv4f32_to_nxv4bf16:
+; BF16:       // %bb.0:
+; BF16-NEXT:    ptrue p0.s
+; BF16-NEXT:    bfcvt z0.h, p0/m, z0.s
+; BF16-NEXT:    ret
+  %res = fptrunc <vscale x 4 x float> %a to <vscale x 4 x bfloat>
+  ret <vscale x 4 x bfloat> %res
+}
+
+define <vscale x 8 x bfloat> @fptrunc_nxv8f32_to_nxv8bf16(<vscale x 8 x float> %a) {
+; NOBF16-LABEL: fptrunc_nxv8f32_to_nxv8bf16:
+; NOBF16:       // %bb.0:
+; NOBF16-NEXT:    mov z2.s, #32767 // =0x7fff
+; NOBF16-NEXT:    lsr z3.s, z1.s, #16
+; NOBF16-NEXT:    lsr z4.s, z0.s, #16
+; NOBF16-NEXT:    ptrue p0.s
+; NOBF16-NEXT:    and z3.s, z3.s, #0x1
+; NOBF16-NEXT:    and z4.s, z4.s, #0x1
+; NOBF16-NEXT:    fcmuo p1.s, p0/z, z1.s, z1.s
+; NOBF16-NEXT:    add z5.s, z1.s, z2.s
+; NOBF16-NEXT:    add z2.s, z0.s, z2.s
+; NOBF16-NEXT:    fcmuo p0.s, p0/z, z0.s, z0.s
+; NOBF16-NEXT:    orr z1.s, z1.s, #0x400000
+; NOBF16-NEXT:    orr z0.s, z0.s, #0x400000
+; NOBF16-NEXT:    add z3.s, z3.s, z5.s
+; NOBF16-NEXT:    add z2.s, z4.s, z2.s
+; NOBF16-NEXT:    sel z1.s, p1, z1.s, z3.s
+; NOBF16-NEXT:    sel z0.s, p0, z0.s, z2.s
+; NOBF16-NEXT:    lsr z1.s, z1.s, #16
+; NOBF16-NEXT:    lsr z0.s, z0.s, #16
+; NOBF16-NEXT:    uzp1 z0.h, z0.h, z1.h
+; NOBF16-NEXT:    ret
+;
+; NOBF16NNAN-LABEL: fptrunc_nxv8f32_to_nxv8bf16:
+; NOBF16NNAN:       // %bb.0:
+; NOBF16NNAN-NEXT:    mov z2.s, #32767 // =0x7fff
+; NOBF16NNAN-NEXT:    lsr z3.s, z1.s, #16
+; NOBF16NNAN-NEXT:    lsr z4.s, z0.s, #16
+; NOBF16NNAN-NEXT:    and z3.s, z3.s, #0x1
+; NOBF16NNAN-NEXT:    and z4.s, z4.s, #0x1
+; NOBF16NNAN-NEXT:    add z1.s, z1.s, z2.s
+; NOBF16NNAN-NEXT:    add z0.s, z0.s, z2.s
+; NOBF16NNAN-NEXT:    add z1.s, z3.s, z1.s
+; NOBF16NNAN-NEXT:    add z0.s, z4.s, z0.s
+; NOBF16NNAN-NEXT:    lsr z1.s, z1.s, #16
+; NOBF16NNAN-NEXT:    lsr z0.s, z0.s, #16
+; NOBF16NNAN-NEXT:    uzp1 z0.h, z0.h, z1.h
+; NOBF16NNAN-NEXT:    ret
+;
+; BF16-LABEL: fptrunc_nxv8f32_to_nxv8bf16:
+; BF16:       // %bb.0:
+; BF16-NEXT:    ptrue p0.s
+; BF16-NEXT:    bfcvt z1.h, p0/m, z1.s
+; BF16-NEXT:    bfcvt z0.h, p0/m, z0.s
+; BF16-NEXT:    uzp1 z0.h, z0.h, z1.h
+; BF16-NEXT:    ret
+  %res = fptrunc <vscale x 8 x float> %a to <vscale x 8 x bfloat>
+  ret <vscale x 8 x bfloat> %res
+}
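
Illustration (not part of the patch): when +bf16 is unavailable, the NOBF16 lowering above
emulates "fptrunc float -> bfloat" with integer arithmetic: quiet any NaN, add a
round-to-nearest-even bias of 0x7fff plus the lsb of the retained mantissa, then shift the
top 16 bits into place. Below is a minimal scalar sketch of that sequence for one lane; the
function name and scalar types are illustrative only, and it models the default rounding
path (the !Trunc case) rather than the truncating variant.

#include <cstdint>
#include <cstring>

// Scalar model of the integer sequence the NOBF16 path emits per lane.
uint16_t TruncToBF16(float F) {
  uint32_t Bits;
  std::memcpy(&Bits, &F, sizeof(Bits));           // bitcast, as getSVESafeBitCast does
  uint32_t Quieted = Bits | 0x400000;             // orr: set the quiet bit
  uint32_t Lsb = (Bits >> 16) & 1;                // lsr + and: lsb of the kept mantissa
  uint32_t Rounded = Bits + (Lsb + 0x7fff);       // add + add: round-to-nearest-even bias
  bool IsNaN = (Bits & 0x7fffffff) > 0x7f800000;  // fcmuo: source unordered with itself
  uint32_t Result = IsNaN ? Quieted : Rounded;    // sel: keep the quieted NaN unrounded
  return static_cast<uint16_t>(Result >> 16);     // final lsr #16
}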