[LLVM][CodeGen][SVE] Implement nxvf32 fpround to nxvbf16. (llvm#107420)
paulwalker-arm authored Sep 24, 2024
1 parent c1826ae commit 3e3780e
Showing 4 changed files with 180 additions and 7 deletions.
50 changes: 47 additions & 3 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1664,6 +1664,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::BITCAST, VT, Custom);
       setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
       setOperationAction(ISD::FP_EXTEND, VT, Custom);
+      setOperationAction(ISD::FP_ROUND, VT, Custom);
       setOperationAction(ISD::MLOAD, VT, Custom);
       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
       setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
@@ -4334,14 +4335,57 @@ SDValue AArch64TargetLowering::LowerFP_EXTEND(SDValue Op,
 SDValue AArch64TargetLowering::LowerFP_ROUND(SDValue Op,
                                              SelectionDAG &DAG) const {
   EVT VT = Op.getValueType();
-  if (VT.isScalableVector())
-    return LowerToPredicatedOp(Op, DAG, AArch64ISD::FP_ROUND_MERGE_PASSTHRU);
-
   bool IsStrict = Op->isStrictFPOpcode();
   SDValue SrcVal = Op.getOperand(IsStrict ? 1 : 0);
   EVT SrcVT = SrcVal.getValueType();
   bool Trunc = Op.getConstantOperandVal(IsStrict ? 2 : 1) == 1;
 
+  if (VT.isScalableVector()) {
+    if (VT.getScalarType() != MVT::bf16)
+      return LowerToPredicatedOp(Op, DAG, AArch64ISD::FP_ROUND_MERGE_PASSTHRU);
+
+    SDLoc DL(Op);
+    constexpr EVT I32 = MVT::nxv4i32;
+    auto ImmV = [&](int I) -> SDValue { return DAG.getConstant(I, DL, I32); };
+
+    SDValue NaN;
+    SDValue Narrow;
+
+    if (SrcVT == MVT::nxv2f32 || SrcVT == MVT::nxv4f32) {
+      if (Subtarget->hasBF16())
+        return LowerToPredicatedOp(Op, DAG,
+                                   AArch64ISD::FP_ROUND_MERGE_PASSTHRU);
+
+      Narrow = getSVESafeBitCast(I32, SrcVal, DAG);
+
+      // Set the quiet bit.
+      if (!DAG.isKnownNeverSNaN(SrcVal))
+        NaN = DAG.getNode(ISD::OR, DL, I32, Narrow, ImmV(0x400000));
+    } else
+      return SDValue();
+
+    if (!Trunc) {
+      SDValue Lsb = DAG.getNode(ISD::SRL, DL, I32, Narrow, ImmV(16));
+      Lsb = DAG.getNode(ISD::AND, DL, I32, Lsb, ImmV(1));
+      SDValue RoundingBias = DAG.getNode(ISD::ADD, DL, I32, Lsb, ImmV(0x7fff));
+      Narrow = DAG.getNode(ISD::ADD, DL, I32, Narrow, RoundingBias);
+    }
+
+    // Don't round if we had a NaN, we don't want to turn 0x7fffffff into
+    // 0x80000000.
+    if (NaN) {
+      EVT I1 = I32.changeElementType(MVT::i1);
+      EVT CondVT = VT.changeElementType(MVT::i1);
+      SDValue IsNaN = DAG.getSetCC(DL, CondVT, SrcVal, SrcVal, ISD::SETUO);
+      IsNaN = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, I1, IsNaN);
+      Narrow = DAG.getSelect(DL, I32, IsNaN, NaN, Narrow);
+    }
+
+    // Now that we have rounded, shift the bits into position.
+    Narrow = DAG.getNode(ISD::SRL, DL, I32, Narrow, ImmV(16));
+    return getSVESafeBitCast(VT, Narrow, DAG);
+  }
+
   if (useSVEForFixedLengthVectorVT(SrcVT, !Subtarget->isNeonAvailable()))
     return LowerFixedLengthFPRoundToSVE(Op, DAG);
 
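The scalable bf16 path above is the standard integer bit trick for float-to-bfloat16 round-to-nearest-even. A minimal scalar model of the same steps, as a sketch (the helper name float_to_bf16_bits and its signature are illustrative, not part of the patch):

#include <cstdint>
#include <cstring>

// Mirrors the DAG sequence: bitcast to i32, precompute a quietened-NaN
// alternative, optionally add the round-to-nearest-even bias, select the
// NaN result if needed, then keep the high 16 bits.
// Trunc mirrors the FP_ROUND flag operand: when it is 1 the conversion is
// known to be lossless, so no rounding bias is required.
uint16_t float_to_bf16_bits(float F, bool Trunc) {
  uint32_t Bits;
  std::memcpy(&Bits, &F, sizeof(Bits));     // getSVESafeBitCast(I32, ...)
  uint32_t NaN = Bits | 0x400000;           // set the quiet bit, pre-rounding
  if (!Trunc) {
    uint32_t Lsb = (Bits >> 16) & 1;        // lsb of the bf16 result
    Bits += 0x7fff + Lsb;                   // ties to even: 0x3F818000 -> 0x3F820000
  }
  if (F != F)                               // fcmuo: unordered self-compare
    Bits = NaN;                             // keep the NaN payload unrounded
  return static_cast<uint16_t>(Bits >> 16); // the final LSR #16
}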
2 changes: 1 addition & 1 deletion llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -2425,7 +2425,7 @@ let Predicates = [HasBF16, HasSVEorSME] in {
   defm BFMLALT_ZZZ : sve2_fp_mla_long<0b101, "bfmlalt", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlalt>;
   defm BFMLALB_ZZZI : sve2_fp_mla_long_by_indexed_elem<0b100, "bfmlalb", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlalb_lane_v2>;
   defm BFMLALT_ZZZI : sve2_fp_mla_long_by_indexed_elem<0b101, "bfmlalt", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlalt_lane_v2>;
-  defm BFCVT_ZPmZ : sve_bfloat_convert<0b1, "bfcvt", int_aarch64_sve_fcvt_bf16f32>;
+  defm BFCVT_ZPmZ : sve_bfloat_convert<0b1, "bfcvt", int_aarch64_sve_fcvt_bf16f32, AArch64fcvtr_mt>;
   defm BFCVTNT_ZPmZ : sve_bfloat_convert<0b0, "bfcvtnt", int_aarch64_sve_fcvtnt_bf16f32>;
 } // End HasBF16, HasSVEorSME

6 changes: 5 additions & 1 deletion llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -8807,9 +8807,13 @@ class sve_bfloat_convert<bit N, string asm>
   let mayRaiseFPException = 1;
 }
 
-multiclass sve_bfloat_convert<bit N, string asm, SDPatternOperator op> {
+multiclass sve_bfloat_convert<bit N, string asm, SDPatternOperator op,
+                              SDPatternOperator ir_op = null_frag> {
   def NAME : sve_bfloat_convert<N, asm>;
+
   def : SVE_3_Op_Pat<nxv8bf16, op, nxv8bf16, nxv8i1, nxv4f32, !cast<Instruction>(NAME)>;
+  def : SVE_1_Op_Passthru_Round_Pat<nxv4bf16, ir_op, nxv4i1, nxv4f32, !cast<Instruction>(NAME)>;
+  def : SVE_1_Op_Passthru_Round_Pat<nxv2bf16, ir_op, nxv2i1, nxv2f32, !cast<Instruction>(NAME)>;
 }
 
 //===----------------------------------------------------------------------===//
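Because ir_op defaults to null_frag, the existing BFCVTNT instantiation, which still passes only three arguments, is unaffected; only BFCVT, which now supplies AArch64fcvtr_mt (the FP_ROUND_MERGE_PASSTHRU node), picks up the two new patterns that select it for unpacked nxv4bf16 and nxv2bf16 results.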
129 changes: 127 additions & 2 deletions llvm/test/CodeGen/AArch64/sve-bf16-converts.ll
@@ -1,9 +1,15 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mattr=+sve < %s | FileCheck %s
-; RUN: llc -mattr=+sme -force-streaming < %s | FileCheck %s
+; RUN: llc -mattr=+sve < %s | FileCheck %s --check-prefixes=CHECK,NOBF16
+; RUN: llc -mattr=+sve --enable-no-nans-fp-math < %s | FileCheck %s --check-prefixes=CHECK,NOBF16NNAN
+; RUN: llc -mattr=+sve,+bf16 < %s | FileCheck %s --check-prefixes=CHECK,BF16
+; RUN: llc -mattr=+sme -force-streaming < %s | FileCheck %s --check-prefixes=CHECK,BF16
 
 target triple = "aarch64-unknown-linux-gnu"
 
+; NOTE: "fptrunc <# x double> to <# x bfloat>" is not supported because SVE
+; lacks a down convert that rounds to odd. Such IR will trigger the usual
+; failure (crash) when attempting to unroll a scalable vector.
+
 define <vscale x 2 x float> @fpext_nxv2bf16_to_nxv2f32(<vscale x 2 x bfloat> %a) {
 ; CHECK-LABEL: fpext_nxv2bf16_to_nxv2f32:
 ; CHECK: // %bb.0:
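The three check prefixes match the three lowerings exercised below: NOBF16 covers the integer expansion, NOBF16NNAN covers the same expansion minus the NaN quieting (valid under --enable-no-nans-fp-math), and BF16 covers the single-instruction bfcvt path, which the streaming SME configuration also takes.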
@@ -87,3 +93,122 @@ define <vscale x 8 x double> @fpext_nxv8bf16_to_nxv8f64(<vscale x 8 x bfloat> %a
   %res = fpext <vscale x 8 x bfloat> %a to <vscale x 8 x double>
   ret <vscale x 8 x double> %res
 }

+define <vscale x 2 x bfloat> @fptrunc_nxv2f32_to_nxv2bf16(<vscale x 2 x float> %a) {
+; NOBF16-LABEL: fptrunc_nxv2f32_to_nxv2bf16:
+; NOBF16: // %bb.0:
+; NOBF16-NEXT: mov z1.s, #32767 // =0x7fff
+; NOBF16-NEXT: lsr z2.s, z0.s, #16
+; NOBF16-NEXT: ptrue p0.d
+; NOBF16-NEXT: fcmuo p0.s, p0/z, z0.s, z0.s
+; NOBF16-NEXT: and z2.s, z2.s, #0x1
+; NOBF16-NEXT: add z1.s, z0.s, z1.s
+; NOBF16-NEXT: orr z0.s, z0.s, #0x400000
+; NOBF16-NEXT: add z1.s, z2.s, z1.s
+; NOBF16-NEXT: sel z0.s, p0, z0.s, z1.s
+; NOBF16-NEXT: lsr z0.s, z0.s, #16
+; NOBF16-NEXT: ret
+;
+; NOBF16NNAN-LABEL: fptrunc_nxv2f32_to_nxv2bf16:
+; NOBF16NNAN: // %bb.0:
+; NOBF16NNAN-NEXT: mov z1.s, #32767 // =0x7fff
+; NOBF16NNAN-NEXT: lsr z2.s, z0.s, #16
+; NOBF16NNAN-NEXT: and z2.s, z2.s, #0x1
+; NOBF16NNAN-NEXT: add z0.s, z0.s, z1.s
+; NOBF16NNAN-NEXT: add z0.s, z2.s, z0.s
+; NOBF16NNAN-NEXT: lsr z0.s, z0.s, #16
+; NOBF16NNAN-NEXT: ret
+;
+; BF16-LABEL: fptrunc_nxv2f32_to_nxv2bf16:
+; BF16: // %bb.0:
+; BF16-NEXT: ptrue p0.d
+; BF16-NEXT: bfcvt z0.h, p0/m, z0.s
+; BF16-NEXT: ret
+  %res = fptrunc <vscale x 2 x float> %a to <vscale x 2 x bfloat>
+  ret <vscale x 2 x bfloat> %res
+}
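Note the ptrue p0.d rather than ptrue p0.s: for unpacked nxv2f32 the compare yields an nxv2i1 predicate that the lowering reinterprets (AArch64ISD::REINTERPRET_CAST) for the 32-bit select, which is fine because only the even 32-bit lanes of the input carry data.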

+define <vscale x 4 x bfloat> @fptrunc_nxv4f32_to_nxv4bf16(<vscale x 4 x float> %a) {
+; NOBF16-LABEL: fptrunc_nxv4f32_to_nxv4bf16:
+; NOBF16: // %bb.0:
+; NOBF16-NEXT: mov z1.s, #32767 // =0x7fff
+; NOBF16-NEXT: lsr z2.s, z0.s, #16
+; NOBF16-NEXT: ptrue p0.s
+; NOBF16-NEXT: fcmuo p0.s, p0/z, z0.s, z0.s
+; NOBF16-NEXT: and z2.s, z2.s, #0x1
+; NOBF16-NEXT: add z1.s, z0.s, z1.s
+; NOBF16-NEXT: orr z0.s, z0.s, #0x400000
+; NOBF16-NEXT: add z1.s, z2.s, z1.s
+; NOBF16-NEXT: sel z0.s, p0, z0.s, z1.s
+; NOBF16-NEXT: lsr z0.s, z0.s, #16
+; NOBF16-NEXT: ret
+;
+; NOBF16NNAN-LABEL: fptrunc_nxv4f32_to_nxv4bf16:
+; NOBF16NNAN: // %bb.0:
+; NOBF16NNAN-NEXT: mov z1.s, #32767 // =0x7fff
+; NOBF16NNAN-NEXT: lsr z2.s, z0.s, #16
+; NOBF16NNAN-NEXT: and z2.s, z2.s, #0x1
+; NOBF16NNAN-NEXT: add z0.s, z0.s, z1.s
+; NOBF16NNAN-NEXT: add z0.s, z2.s, z0.s
+; NOBF16NNAN-NEXT: lsr z0.s, z0.s, #16
+; NOBF16NNAN-NEXT: ret
+;
+; BF16-LABEL: fptrunc_nxv4f32_to_nxv4bf16:
+; BF16: // %bb.0:
+; BF16-NEXT: ptrue p0.s
+; BF16-NEXT: bfcvt z0.h, p0/m, z0.s
+; BF16-NEXT: ret
+  %res = fptrunc <vscale x 4 x float> %a to <vscale x 4 x bfloat>
+  ret <vscale x 4 x bfloat> %res
+}

+define <vscale x 8 x bfloat> @fptrunc_nxv8f32_to_nxv8bf16(<vscale x 8 x float> %a) {
+; NOBF16-LABEL: fptrunc_nxv8f32_to_nxv8bf16:
+; NOBF16: // %bb.0:
+; NOBF16-NEXT: mov z2.s, #32767 // =0x7fff
+; NOBF16-NEXT: lsr z3.s, z1.s, #16
+; NOBF16-NEXT: lsr z4.s, z0.s, #16
+; NOBF16-NEXT: ptrue p0.s
+; NOBF16-NEXT: and z3.s, z3.s, #0x1
+; NOBF16-NEXT: and z4.s, z4.s, #0x1
+; NOBF16-NEXT: fcmuo p1.s, p0/z, z1.s, z1.s
+; NOBF16-NEXT: add z5.s, z1.s, z2.s
+; NOBF16-NEXT: add z2.s, z0.s, z2.s
+; NOBF16-NEXT: fcmuo p0.s, p0/z, z0.s, z0.s
+; NOBF16-NEXT: orr z1.s, z1.s, #0x400000
+; NOBF16-NEXT: orr z0.s, z0.s, #0x400000
+; NOBF16-NEXT: add z3.s, z3.s, z5.s
+; NOBF16-NEXT: add z2.s, z4.s, z2.s
+; NOBF16-NEXT: sel z1.s, p1, z1.s, z3.s
+; NOBF16-NEXT: sel z0.s, p0, z0.s, z2.s
+; NOBF16-NEXT: lsr z1.s, z1.s, #16
+; NOBF16-NEXT: lsr z0.s, z0.s, #16
+; NOBF16-NEXT: uzp1 z0.h, z0.h, z1.h
+; NOBF16-NEXT: ret
+;
+; NOBF16NNAN-LABEL: fptrunc_nxv8f32_to_nxv8bf16:
+; NOBF16NNAN: // %bb.0:
+; NOBF16NNAN-NEXT: mov z2.s, #32767 // =0x7fff
+; NOBF16NNAN-NEXT: lsr z3.s, z1.s, #16
+; NOBF16NNAN-NEXT: lsr z4.s, z0.s, #16
+; NOBF16NNAN-NEXT: and z3.s, z3.s, #0x1
+; NOBF16NNAN-NEXT: and z4.s, z4.s, #0x1
+; NOBF16NNAN-NEXT: add z1.s, z1.s, z2.s
+; NOBF16NNAN-NEXT: add z0.s, z0.s, z2.s
+; NOBF16NNAN-NEXT: add z1.s, z3.s, z1.s
+; NOBF16NNAN-NEXT: add z0.s, z4.s, z0.s
+; NOBF16NNAN-NEXT: lsr z1.s, z1.s, #16
+; NOBF16NNAN-NEXT: lsr z0.s, z0.s, #16
+; NOBF16NNAN-NEXT: uzp1 z0.h, z0.h, z1.h
+; NOBF16NNAN-NEXT: ret
+;
+; BF16-LABEL: fptrunc_nxv8f32_to_nxv8bf16:
+; BF16: // %bb.0:
+; BF16-NEXT: ptrue p0.s
+; BF16-NEXT: bfcvt z1.h, p0/m, z1.s
+; BF16-NEXT: bfcvt z0.h, p0/m, z0.s
+; BF16-NEXT: uzp1 z0.h, z0.h, z1.h
+; BF16-NEXT: ret
+  %res = fptrunc <vscale x 8 x float> %a to <vscale x 8 x bfloat>
+  ret <vscale x 8 x bfloat> %res
+}
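For nxv8f32 the source spans two Z registers, so each half is narrowed within its 32-bit container and uzp1 z0.h, z0.h, z1.h concatenates the even-numbered 16-bit lanes of both halves into a single packed nxv8bf16 result.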
