[AArch64] Add more complete support for BF16
We can use a small amount of integer arithmetic to round FP32 to BF16
and extend BF16 to FP32.
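
For reference, here is what those two conversions look like in plain C++ — a minimal sketch of the standard bit-manipulation approach (including the NaN quieting mentioned below), not the exact node sequence the lowering emits:

#include <cstdint>
#include <cstring>

// f32 -> bf16, round to nearest even, using only integer arithmetic.
// bf16 is the high 16 bits of an f32, so rounding adds a bias of
// 0x7fff plus the LSB of the truncated result (ties go to even).
uint16_t roundF32ToBF16(float F) {
  uint32_t Bits;
  std::memcpy(&Bits, &F, sizeof(Bits));
  if ((Bits & 0x7fffffff) > 0x7f800000) // NaN input:
    return (Bits >> 16) | 0x0040;       // quiet it, keep top payload bits
  Bits += 0x7fff + ((Bits >> 16) & 1);  // round to nearest, ties to even
  return static_cast<uint16_t>(Bits >> 16);
}

// bf16 -> f32 is exact: place the 16 bits in the high half and
// zero-fill the rest (morally a zext + shift, which is how the
// AArch64 patterns below implement it).
float extendBF16ToF32(uint16_t H) {
  uint32_t Bits = static_cast<uint32_t>(H) << 16;
  float F;
  std::memcpy(&F, &Bits, sizeof(F));
  return F;
}
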

While a number of operations still require promotion, promotion could
be avoided for some rather simple operations like abs, copysign, and
fneg; those are left for a follow-up.

A few neat optimizations are implemented:
- round-inexact-to-odd is used for f64 -> bf16 rounding, which avoids
  double-rounding errors (see the worked example after this list).
- when quieting signaling NaNs for f32 -> bf16, we try to detect
  whether a prior operation already quieted them, making the extra
  step unnecessary.
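
To see why the intermediate f64 -> f32 step must round to odd (the FCVTXN behavior) rather than to nearest, here is a self-contained software model plus a worked double-rounding example. cvtF64ToF32RoundToOdd is a stand-in for the instruction's semantics, not code from this commit:

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Round-to-odd f64 -> f32 (what FCVTXN computes): when the conversion
// is inexact, the result's low mantissa bit must be 1, so an "even"
// nearest-even result is nudged to the odd neighbor on the side of
// the true value.
float cvtF64ToF32RoundToOdd(double D) {
  float F = static_cast<float>(D); // round to nearest even
  if (std::isnan(D) || static_cast<double>(F) == D)
    return F;                      // NaN or exact: nothing to fix up
  uint32_t Bits;
  std::memcpy(&Bits, &F, sizeof(Bits));
  if ((Bits & 1) == 0)             // even: step toward the true value
    F = std::nextafterf(F, D > F ? HUGE_VALF : -HUGE_VALF);
  return F;
}

int main() {
  // D sits just above the midpoint of two adjacent bf16 values,
  // 1.0 and 1 + 2^-7, so a correct f64 -> bf16 rounding must go up.
  double D = 1.0 + 0x1.0p-8 + 0x1.0p-30;
  // Naive nearest-even to f32 lands exactly on the midpoint 1 + 2^-8,
  // and rounding that to bf16 then ties-to-even *down* to 1.0: wrong.
  float Naive = static_cast<float>(D);
  // Round-to-odd keeps a sticky low bit (1 + 2^-8 + 2^-23), still
  // above the midpoint, so the final bf16 rounding correctly goes up.
  float Sticky = cvtF64ToF32RoundToOdd(D);
  std::printf("naive: %a  sticky: %a\n", Naive, Sticky);
}

The (BFCVT (FCVTXNv1i64 $Rn)) pattern added to AArch64InstrInfo.td below is this exact two-step sequence in hardware.
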
majnemer committed Mar 3, 2024
1 parent 2435dcd commit 3dd6750
Showing 13 changed files with 842 additions and 197 deletions.
5 changes: 3 additions & 2 deletions llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1573,13 +1573,14 @@ class TargetLoweringBase {
    assert((VT.isInteger() || VT.isFloatingPoint()) &&
           "Cannot autopromote this type, add it with AddPromotedToType.");

+    uint64_t VTBits = VT.getScalarSizeInBits();
    MVT NVT = VT;
    do {
      NVT = (MVT::SimpleValueType)(NVT.SimpleTy+1);
      assert(NVT.isInteger() == VT.isInteger() && NVT != MVT::isVoid &&
             "Didn't find type to promote to!");
-    } while (!isTypeLegal(NVT) ||
-             getOperationAction(Op, NVT) == Promote);
+    } while (VTBits >= NVT.getScalarSizeInBits() || !isTypeLegal(NVT) ||
+             getOperationAction(Op, NVT) == Promote);
    return NVT;
  }
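
One subtlety in this hunk: the promotion loop simply walks the MVT::SimpleValueType enumeration upward, and f16 and bf16 are adjacent 16-bit types in that order. With bf16 now handled as a legal AArch64 type, the old loop could otherwise pick the equally wide bf16 as the "promoted" type for an f16 operation, even though bf16 cannot represent all f16 values — presumably the case the new VTBits guard defends against by scanning on until a strictly wider type such as f32 is found.
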

7 changes: 7 additions & 0 deletions llvm/include/llvm/CodeGen/ValueTypes.h
@@ -107,6 +107,13 @@ namespace llvm {
      return changeExtendedVectorElementType(EltVT);
    }

+    /// Return a VT for a type whose attributes match ourselves with the
+    /// exception of the element type that is chosen by the caller.
+    EVT changeElementType(EVT EltVT) const {
+      EltVT = EltVT.getScalarType();
+      return isVector() ? changeVectorElementType(EltVT) : EltVT;
+    }
+
    /// Return the type converted to an equivalently sized integer or vector
    /// with integer element type. Similar to changeVectorElementTypeToInteger,
    /// but also handles scalars.
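
A quick usage sketch for the new changeElementType helper (illustrative values, not from the patch):

#include "llvm/CodeGen/ValueTypes.h"
using namespace llvm;

void example() {
  EVT Vec = MVT::v4f32;
  EVT Scl = MVT::f32;
  EVT A = Vec.changeElementType(MVT::bf16); // v4bf16: vector shape kept
  EVT B = Scl.changeElementType(MVT::bf16); // bf16: a scalar becomes the
                                            // element type itself
  (void)A;
  (void)B;
}
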
290 changes: 201 additions & 89 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -249,6 +249,9 @@ enum NodeType : unsigned {
  FCMLEz,
  FCMLTz,

+  // Round wide FP to narrow FP with inexact results to odd.
+  FCVTXN,
+
  // Vector across-lanes addition
  // Only the lower result lane is defined.
  SADDV,
2 changes: 1 addition & 1 deletion llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -7547,7 +7547,7 @@ class BaseSIMDCmpTwoScalar<bit U, bits<2> size, bits<2> size2, bits<5> opcode,
let mayRaiseFPException = 1, Uses = [FPCR] in
class SIMDInexactCvtTwoScalar<bits<5> opcode, string asm>
  : I<(outs FPR32:$Rd), (ins FPR64:$Rn), asm, "\t$Rd, $Rn", "",
-     [(set (f32 FPR32:$Rd), (int_aarch64_sisd_fcvtxn (f64 FPR64:$Rn)))]>,
+     [(set (f32 FPR32:$Rd), (AArch64fcvtxn (f64 FPR64:$Rn)))]>,
    Sched<[WriteVd]> {
  bits<5> Rd;
  bits<5> Rn;
37 changes: 37 additions & 0 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -756,6 +756,11 @@
def AArch64fcmgtz: SDNode<"AArch64ISD::FCMGTz", SDT_AArch64fcmpz>;
def AArch64fcmlez: SDNode<"AArch64ISD::FCMLEz", SDT_AArch64fcmpz>;
def AArch64fcmltz: SDNode<"AArch64ISD::FCMLTz", SDT_AArch64fcmpz>;

+def AArch64fcvtxn_n: SDNode<"AArch64ISD::FCVTXN", SDTFPRoundOp>;
+def AArch64fcvtxn: PatFrags<(ops node:$Rn),
+                            [(f32 (int_aarch64_sisd_fcvtxn (f64 node:$Rn))),
+                             (f32 (AArch64fcvtxn_n (f64 node:$Rn)))]>;
+
def AArch64bici: SDNode<"AArch64ISD::BICi", SDT_AArch64vecimm>;
def AArch64orri: SDNode<"AArch64ISD::ORRi", SDT_AArch64vecimm>;

@@ -1276,6 +1281,9 @@
def BFMLALTIdx : SIMDBF16MLALIndex<1, "bfmlalt", int_aarch64_neon_bfmlalt>;
def BFCVTN : SIMD_BFCVTN;
def BFCVTN2 : SIMD_BFCVTN2;

+def : Pat<(v4bf16 (any_fpround (v4f32 V128:$Rn))),
+          (EXTRACT_SUBREG (BFCVTN V128:$Rn), dsub)>;
+
// Vector-scalar BFDOT:
// The second source operand of the 64-bit variant of BF16DOTlane is a 128-bit
// register (the instruction uses a single 32-bit lane from it), so the pattern
@@ -1296,6 +1304,8 @@
def : Pat<(v2f32 (int_aarch64_neon_bfdot

let Predicates = [HasNEONorSME, HasBF16] in {
def BFCVT : BF16ToSinglePrecision<"bfcvt">;
+// Round FP32 to BF16.
+def : Pat<(bf16 (any_fpround (f32 FPR32:$Rn))), (BFCVT $Rn)>;
}
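
These patterns only apply when the BF16 extension is available: BFCVT performs the f32 -> bf16 rounding in a single instruction, so the integer-arithmetic expansion described in the commit message is, presumably, the fallback used on targets without it.
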

// ARMv8.6A AArch64 matrix multiplication
@@ -4648,6 +4658,22 @@ let Predicates = [HasFullFP16] in {
//===----------------------------------------------------------------------===//

defm FCVT : FPConversion<"fcvt">;
+// Helper to get bf16 into fp32.
+def cvt_bf16_to_fp32 :
+  OutPatFrag<(ops node:$Rn),
+             (f32 (COPY_TO_REGCLASS
+                     (i32 (UBFMWri
+                             (i32 (COPY_TO_REGCLASS (INSERT_SUBREG (f32 (IMPLICIT_DEF)),
+                                                                   node:$Rn, hsub), GPR32)),
+                             (i64 (i32shift_a (i64 16))),
+                             (i64 (i32shift_b (i64 16))))),
+                     FPR32))>;
+// Pattern for bf16 -> fp32.
+def : Pat<(f32 (any_fpextend (bf16 FPR16:$Rn))),
+          (cvt_bf16_to_fp32 FPR16:$Rn)>;
+// Pattern for bf16 -> fp64.
+def : Pat<(f64 (any_fpextend (bf16 FPR16:$Rn))),
+          (FCVTDSr (f32 (cvt_bf16_to_fp32 FPR16:$Rn)))>;
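
The UBFMWri operands in cvt_bf16_to_fp32 are opaque at first glance. i32shift_a and i32shift_b are existing TableGen helpers that compute the bitfield-move immediates for a left shift; a minimal C++ model of 32-bit UBFM (assuming the architectural definition) shows why immr=16, imms=15 means "lsl #16":

#include <cstdint>

// Sketch of 32-bit UBFM semantics, enough to decode the pattern above.
uint32_t ubfm32(uint32_t Wn, unsigned ImmR, unsigned ImmS) {
  if (ImmS >= ImmR) {                 // UBFX form: extract a bitfield
    unsigned Width = ImmS - ImmR + 1;
    uint32_t Mask = Width == 32 ? ~0u : (1u << Width) - 1;
    return (Wn >> ImmR) & Mask;
  }
  // ImmS < ImmR: UBFIZ/LSL form -- low ImmS+1 bits shifted up.
  return (Wn & ((1u << (ImmS + 1)) - 1)) << (32 - ImmR);
}
// i32shift_a(16) == (32 - 16) & 31 == 16 and i32shift_b(16) == 31 - 16
// == 15, so ubfm32(x, 16, 15) == x << 16: the bf16 payload is moved
// into the top half of an f32, mirroring extendBF16ToF32 above.
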

//===----------------------------------------------------------------------===//
// Floating point single operand instructions.
@@ -5002,6 +5028,9 @@
defm FCVTNU : SIMDTwoVectorFPToInt<1,0,0b11010, "fcvtnu",int_aarch64_neon_fcvtnu
defm FCVTN : SIMDFPNarrowTwoVector<0, 0, 0b10110, "fcvtn">;
def : Pat<(v4i16 (int_aarch64_neon_vcvtfp2hf (v4f32 V128:$Rn))),
          (FCVTNv4i16 V128:$Rn)>;
+//def : Pat<(concat_vectors V64:$Rd,
+//                          (v4bf16 (any_fpround (v4f32 V128:$Rn)))),
+//          (FCVTNv8bf16 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Rd, dsub), V128:$Rn)>;
def : Pat<(concat_vectors V64:$Rd,
                          (v4i16 (int_aarch64_neon_vcvtfp2hf (v4f32 V128:$Rn)))),
          (FCVTNv8i16 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Rd, dsub), V128:$Rn)>;
@@ -5686,6 +5715,11 @@
defm USQADD : SIMDTwoScalarBHSDTied< 1, 0b00011, "usqadd",
def : Pat<(v1i64 (AArch64vashr (v1i64 V64:$Rn), (i32 63))),
          (CMLTv1i64rz V64:$Rn)>;

+// Round FP64 to BF16.
+let Predicates = [HasNEONorSME, HasBF16] in
+def : Pat<(bf16 (any_fpround (f64 FPR64:$Rn))),
+          (BFCVT (FCVTXNv1i64 $Rn))>;
+
def : Pat<(v1i64 (int_aarch64_neon_fcvtas (v1f64 FPR64:$Rn))),
          (FCVTASv1i64 FPR64:$Rn)>;
def : Pat<(v1i64 (int_aarch64_neon_fcvtau (v1f64 FPR64:$Rn))),
@@ -7698,6 +7732,9 @@
def : Pat<(v4i32 (anyext (v4i16 V64:$Rn))), (USHLLv4i16_shift V64:$Rn, (i32 0))>;
def : Pat<(v2i64 (sext (v2i32 V64:$Rn))), (SSHLLv2i32_shift V64:$Rn, (i32 0))>;
def : Pat<(v2i64 (zext (v2i32 V64:$Rn))), (USHLLv2i32_shift V64:$Rn, (i32 0))>;
def : Pat<(v2i64 (anyext (v2i32 V64:$Rn))), (USHLLv2i32_shift V64:$Rn, (i32 0))>;
+// Vector bf16 -> fp32 is implemented morally as a zext + shift.
+def : Pat<(v4f32 (any_fpextend (v4bf16 V64:$Rn))),
+          (USHLLv4i16_shift V64:$Rn, (i32 16))>;
// Also match an extend from the upper half of a 128 bit source register.
def : Pat<(v8i16 (anyext (v8i8 (extract_high_v16i8 (v16i8 V128:$Rn)) ))),
          (USHLLv16i8_shift V128:$Rn, (i32 0))>;
@@ -1022,13 +1022,17 @@ void applyLowerVectorFCMP(MachineInstr &MI, MachineRegisterInfo &MRI,

  bool Invert = false;
  AArch64CC::CondCode CC, CC2 = AArch64CC::AL;
-  if (Pred == CmpInst::Predicate::FCMP_ORD && IsZero) {
+  if ((Pred == CmpInst::Predicate::FCMP_ORD ||
+       Pred == CmpInst::Predicate::FCMP_UNO) &&
+      IsZero) {
    // The special case "fcmp ord %a, 0" is the canonical check that LHS isn't
    // NaN, so equivalent to a == a and doesn't need the two comparisons an
    // "ord" normally would.
+    // Similarly, "fcmp uno %a, 0" is the canonical check that LHS is NaN and is
+    // thus equivalent to a != a.
    RHS = LHS;
    IsZero = false;
-    CC = AArch64CC::EQ;
+    CC = Pred == CmpInst::Predicate::FCMP_ORD ? AArch64CC::EQ : AArch64CC::NE;
  } else
    changeVectorFCMPPredToAArch64CC(Pred, CC, CC2, Invert);
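
The scalar equivalences this special case exploits, as a sketch (the real code operates on vector G_FCMP):

// For any IEEE float A: "fcmp ord A, 0.0" tests !isnan(A), i.e. A == A,
// and "fcmp uno A, 0.0" tests isnan(A), i.e. A != A. The zero operand
// contributes nothing, because 0.0 is never NaN.
bool ordZero(float A) { return A == A; }    // one FCMEQ
bool unoZero(float A) { return !(A == A); } // FCMEQ, then invert
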

8 changes: 4 additions & 4 deletions llvm/test/Analysis/CostModel/AArch64/reduce-fadd.ll
@@ -13,7 +13,7 @@ define void @strict_fp_reductions() {
; CHECK-NEXT: Cost Model: Found an estimated cost of 28 for instruction: %fadd_v8f32 = call float @llvm.vector.reduce.fadd.v8f32(float 0.000000e+00, <8 x float> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 6 for instruction: %fadd_v2f64 = call double @llvm.vector.reduce.fadd.v2f64(double 0.000000e+00, <2 x double> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 12 for instruction: %fadd_v4f64 = call double @llvm.vector.reduce.fadd.v4f64(double 0.000000e+00, <4 x double> undef)
-; CHECK-NEXT: Cost Model: Found an estimated cost of 12 for instruction: %fadd_v4f8 = call bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR0000, <4 x bfloat> undef)
+; CHECK-NEXT: Cost Model: Found an estimated cost of 18 for instruction: %fadd_v4f8 = call bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR0000, <4 x bfloat> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 20 for instruction: %fadd_v4f128 = call fp128 @llvm.vector.reduce.fadd.v4f128(fp128 undef, <4 x fp128> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 0 for instruction: ret void
;
@@ -24,7 +24,7 @@ define void @strict_fp_reductions() {
; FP16-NEXT: Cost Model: Found an estimated cost of 28 for instruction: %fadd_v8f32 = call float @llvm.vector.reduce.fadd.v8f32(float 0.000000e+00, <8 x float> undef)
; FP16-NEXT: Cost Model: Found an estimated cost of 6 for instruction: %fadd_v2f64 = call double @llvm.vector.reduce.fadd.v2f64(double 0.000000e+00, <2 x double> undef)
; FP16-NEXT: Cost Model: Found an estimated cost of 12 for instruction: %fadd_v4f64 = call double @llvm.vector.reduce.fadd.v4f64(double 0.000000e+00, <4 x double> undef)
-; FP16-NEXT: Cost Model: Found an estimated cost of 12 for instruction: %fadd_v4f8 = call bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR0000, <4 x bfloat> undef)
+; FP16-NEXT: Cost Model: Found an estimated cost of 18 for instruction: %fadd_v4f8 = call bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR0000, <4 x bfloat> undef)
; FP16-NEXT: Cost Model: Found an estimated cost of 20 for instruction: %fadd_v4f128 = call fp128 @llvm.vector.reduce.fadd.v4f128(fp128 undef, <4 x fp128> undef)
; FP16-NEXT: Cost Model: Found an estimated cost of 0 for instruction: ret void
;
@@ -72,7 +72,7 @@ define void @fast_fp_reductions() {
; CHECK-NEXT: Cost Model: Found an estimated cost of 5 for instruction: %fadd_v4f64_reassoc = call reassoc double @llvm.vector.reduce.fadd.v4f64(double 0.000000e+00, <4 x double> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 9 for instruction: %fadd_v7f64 = call fast double @llvm.vector.reduce.fadd.v7f64(double 0.000000e+00, <7 x double> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 15 for instruction: %fadd_v9f64_reassoc = call reassoc double @llvm.vector.reduce.fadd.v9f64(double 0.000000e+00, <9 x double> undef)
-; CHECK-NEXT: Cost Model: Found an estimated cost of 6 for instruction: %fadd_v4f8 = call reassoc bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR8000, <4 x bfloat> undef)
+; CHECK-NEXT: Cost Model: Found an estimated cost of 10 for instruction: %fadd_v4f8 = call reassoc bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR8000, <4 x bfloat> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 12 for instruction: %fadd_v4f128 = call reassoc fp128 @llvm.vector.reduce.fadd.v4f128(fp128 undef, <4 x fp128> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 0 for instruction: ret void
;
@@ -95,7 +95,7 @@ define void @fast_fp_reductions() {
; FP16-NEXT: Cost Model: Found an estimated cost of 5 for instruction: %fadd_v4f64_reassoc = call reassoc double @llvm.vector.reduce.fadd.v4f64(double 0.000000e+00, <4 x double> undef)
; FP16-NEXT: Cost Model: Found an estimated cost of 9 for instruction: %fadd_v7f64 = call fast double @llvm.vector.reduce.fadd.v7f64(double 0.000000e+00, <7 x double> undef)
; FP16-NEXT: Cost Model: Found an estimated cost of 15 for instruction: %fadd_v9f64_reassoc = call reassoc double @llvm.vector.reduce.fadd.v9f64(double 0.000000e+00, <9 x double> undef)
-; FP16-NEXT: Cost Model: Found an estimated cost of 6 for instruction: %fadd_v4f8 = call reassoc bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR8000, <4 x bfloat> undef)
+; FP16-NEXT: Cost Model: Found an estimated cost of 10 for instruction: %fadd_v4f8 = call reassoc bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR8000, <4 x bfloat> undef)
; FP16-NEXT: Cost Model: Found an estimated cost of 12 for instruction: %fadd_v4f128 = call reassoc fp128 @llvm.vector.reduce.fadd.v4f128(fp128 undef, <4 x fp128> undef)
; FP16-NEXT: Cost Model: Found an estimated cost of 0 for instruction: ret void
;
@@ -321,18 +321,15 @@
  bb.0:
    liveins: $q0, $q1
-    ; Should be inverted. Needs two compares.
    ; CHECK-LABEL: name: uno_zero
    ; CHECK: liveins: $q0, $q1
    ; CHECK-NEXT: {{ $}}
    ; CHECK-NEXT: %lhs:_(<2 x s64>) = COPY $q0
-    ; CHECK-NEXT: [[FCMGEZ:%[0-9]+]]:_(<2 x s64>) = G_FCMGEZ %lhs
-    ; CHECK-NEXT: [[FCMLTZ:%[0-9]+]]:_(<2 x s64>) = G_FCMLTZ %lhs
-    ; CHECK-NEXT: [[OR:%[0-9]+]]:_(<2 x s64>) = G_OR [[FCMLTZ]], [[FCMGEZ]]
+    ; CHECK-NEXT: [[FCMEQ:%[0-9]+]]:_(<2 x s64>) = G_FCMEQ %lhs, %lhs(<2 x s64>)
    ; CHECK-NEXT: [[C:%[0-9]+]]:_(s64) = G_CONSTANT i64 -1
    ; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<2 x s64>) = G_BUILD_VECTOR [[C]](s64), [[C]](s64)
-    ; CHECK-NEXT: [[XOR:%[0-9]+]]:_(<2 x s64>) = G_XOR [[OR]], [[BUILD_VECTOR]]
+    ; CHECK-NEXT: [[XOR:%[0-9]+]]:_(<2 x s64>) = G_XOR [[FCMEQ]], [[BUILD_VECTOR]]
    ; CHECK-NEXT: $q0 = COPY [[XOR]](<2 x s64>)
    ; CHECK-NEXT: RET_ReallyLR implicit $q0
    %lhs:_(<2 x s64>) = COPY $q0
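
The updated checks show the payoff of the lowering change: "fcmp uno %a, zeroinitializer" now produces a single G_FCMEQ %lhs, %lhs whose result is inverted by the G_XOR with all-ones, replacing the earlier G_FCMGEZ/G_FCMLTZ pair combined with G_OR.
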
@@ -187,10 +187,7 @@ entry:
define <8 x bfloat> @insertzero_v4bf16(<4 x bfloat> %a) {
; CHECK-LABEL: insertzero_v4bf16:
; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: movi d4, #0000000000000000
-; CHECK-NEXT: movi d5, #0000000000000000
-; CHECK-NEXT: movi d6, #0000000000000000
-; CHECK-NEXT: movi d7, #0000000000000000
+; CHECK-NEXT: fmov d0, d0
; CHECK-NEXT: ret
entry:
%shuffle.i = shufflevector <4 x bfloat> %a, <4 x bfloat> zeroinitializer, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
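
Widening v4bf16 with zero lanes now collapses to a single fmov d0, d0: writing a D register on AArch64 implicitly zeroes the upper 64 bits of the corresponding Q register, so the movi instructions that previously materialized zero vectors are no longer needed.
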