Skip to content

Commit

Permalink
[DAG][RISCV] Use vp_reduce_fadd/fmul when widening types for FP reduc…
Browse files Browse the repository at this point in the history
…tions (llvm#105840)

This is a follow up to llvm#105455 which updates the VPIntrinsic mappings
for the fadd and fmul cases, and supports both ordered and unordered
reductions. This allows the use a single wider operation with a
restricted EVL instead of padding the vector with the neutral element.

This has all the same tradeoffs as the previous patch.
  • Loading branch information
preames committed Aug 24, 2024
1 parent 31b4bf9 commit 2cb25d5
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 86 deletions.
12 changes: 7 additions & 5 deletions llvm/include/llvm/IR/VPIntrinsics.def
Original file line number Diff line number Diff line change
Expand Up @@ -722,27 +722,29 @@ HELPER_REGISTER_REDUCTION_VP(vp_reduce_fminimum, VP_REDUCE_FMINIMUM,
#error \
"The internal helper macro HELPER_REGISTER_REDUCTION_SEQ_VP is already defined!"
#endif
#define HELPER_REGISTER_REDUCTION_SEQ_VP(VPID, VPSD, SEQ_VPSD, INTRIN) \
#define HELPER_REGISTER_REDUCTION_SEQ_VP(VPID, VPSD, SEQ_VPSD, SDOPC, SEQ_SDOPC, INTRIN) \
BEGIN_REGISTER_VP_INTRINSIC(VPID, 2, 3) \
BEGIN_REGISTER_VP_SDNODE(VPSD, 1, VPID, 2, 3) \
VP_PROPERTY_REDUCTION(0, 1) \
VP_PROPERTY_FUNCTIONAL_SDOPC(SDOPC) \
END_REGISTER_VP_SDNODE(VPSD) \
BEGIN_REGISTER_VP_SDNODE(SEQ_VPSD, 1, VPID, 2, 3) \
HELPER_MAP_VPID_TO_VPSD(VPID, SEQ_VPSD) \
VP_PROPERTY_FUNCTIONAL_SDOPC(SEQ_SDOPC) \
VP_PROPERTY_REDUCTION(0, 1) \
END_REGISTER_VP_SDNODE(SEQ_VPSD) \
VP_PROPERTY_FUNCTIONAL_INTRINSIC(INTRIN) \
END_REGISTER_VP_INTRINSIC(VPID)

// llvm.vp.reduce.fadd(start,x,mask,vlen)
HELPER_REGISTER_REDUCTION_SEQ_VP(vp_reduce_fadd, VP_REDUCE_FADD,
VP_REDUCE_SEQ_FADD,
vector_reduce_fadd)
VP_REDUCE_SEQ_FADD, VECREDUCE_FADD,
VECREDUCE_SEQ_FADD, vector_reduce_fadd)

// llvm.vp.reduce.fmul(start,x,mask,vlen)
HELPER_REGISTER_REDUCTION_SEQ_VP(vp_reduce_fmul, VP_REDUCE_FMUL,
VP_REDUCE_SEQ_FMUL,
vector_reduce_fmul)
VP_REDUCE_SEQ_FMUL, VECREDUCE_FMUL,
VECREDUCE_SEQ_FMUL, vector_reduce_fmul)

#undef HELPER_REGISTER_REDUCTION_SEQ_VP

Expand Down
20 changes: 16 additions & 4 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7311,8 +7311,6 @@ SDValue DAGTypeLegalizer::WidenVecOp_VECREDUCE(SDNode *N) {
// Generate a vp.reduce_op if it is custom/legal for the target. This avoids
// needing to pad the source vector, because the inactive lanes can simply be
// disabled and not contribute to the result.
// TODO: VECREDUCE_FADD, VECREDUCE_FMUL aren't currently mapped correctly,
// and thus don't take this path.
if (auto VPOpcode = ISD::getVPForBaseOpcode(Opc);
VPOpcode && TLI.isOperationLegalOrCustom(*VPOpcode, WideVT)) {
SDValue Start = NeutralElem;
Expand Down Expand Up @@ -7351,6 +7349,7 @@ SDValue DAGTypeLegalizer::WidenVecOp_VECREDUCE_SEQ(SDNode *N) {
SDValue VecOp = N->getOperand(1);
SDValue Op = GetWidenedVector(VecOp);

EVT VT = N->getValueType(0);
EVT OrigVT = VecOp.getValueType();
EVT WideVT = Op.getValueType();
EVT ElemVT = OrigVT.getVectorElementType();
Expand All @@ -7364,6 +7363,19 @@ SDValue DAGTypeLegalizer::WidenVecOp_VECREDUCE_SEQ(SDNode *N) {
unsigned OrigElts = OrigVT.getVectorMinNumElements();
unsigned WideElts = WideVT.getVectorMinNumElements();

// Generate a vp.reduce_op if it is custom/legal for the target. This avoids
// needing to pad the source vector, because the inactive lanes can simply be
// disabled and not contribute to the result.
if (auto VPOpcode = ISD::getVPForBaseOpcode(Opc);
VPOpcode && TLI.isOperationLegalOrCustom(*VPOpcode, WideVT)) {
EVT WideMaskVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
WideVT.getVectorElementCount());
SDValue Mask = DAG.getAllOnesConstant(dl, WideMaskVT);
SDValue EVL = DAG.getElementCount(dl, TLI.getVPExplicitVectorLengthTy(),
OrigVT.getVectorElementCount());
return DAG.getNode(*VPOpcode, dl, VT, {AccOp, Op, Mask, EVL}, Flags);
}

if (WideVT.isScalableVector()) {
unsigned GCD = std::gcd(OrigElts, WideElts);
EVT SplatVT = EVT::getVectorVT(*DAG.getContext(), ElemVT,
Expand All @@ -7372,14 +7384,14 @@ SDValue DAGTypeLegalizer::WidenVecOp_VECREDUCE_SEQ(SDNode *N) {
for (unsigned Idx = OrigElts; Idx < WideElts; Idx = Idx + GCD)
Op = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideVT, Op, SplatNeutral,
DAG.getVectorIdxConstant(Idx, dl));
return DAG.getNode(Opc, dl, N->getValueType(0), AccOp, Op, Flags);
return DAG.getNode(Opc, dl, VT, AccOp, Op, Flags);
}

for (unsigned Idx = OrigElts; Idx < WideElts; Idx++)
Op = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, WideVT, Op, NeutralElem,
DAG.getVectorIdxConstant(Idx, dl));

return DAG.getNode(Opc, dl, N->getValueType(0), AccOp, Op, Flags);
return DAG.getNode(Opc, dl, VT, AccOp, Op, Flags);
}

SDValue DAGTypeLegalizer::WidenVecOp_VP_REDUCE(SDNode *N) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -791,12 +791,7 @@ define float @reduce_fadd_16xi32_prefix5(ptr %p) {
; CHECK-NEXT: vle32.v v8, (a0)
; CHECK-NEXT: lui a0, 524288
; CHECK-NEXT: vmv.s.x v10, a0
; CHECK-NEXT: vsetivli zero, 6, e32, m2, tu, ma
; CHECK-NEXT: vslideup.vi v8, v10, 5
; CHECK-NEXT: vsetivli zero, 7, e32, m2, tu, ma
; CHECK-NEXT: vslideup.vi v8, v10, 6
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; CHECK-NEXT: vslideup.vi v8, v10, 7
; CHECK-NEXT: vsetivli zero, 5, e32, m2, ta, ma
; CHECK-NEXT: vfredusum.vs v8, v8, v10
; CHECK-NEXT: vfmv.f.s fa0, v8
; CHECK-NEXT: ret
Expand Down Expand Up @@ -880,7 +875,7 @@ define float @reduce_fadd_4xi32_non_associative(ptr %p) {
; CHECK-NEXT: vfmv.f.s fa5, v9
; CHECK-NEXT: lui a0, 524288
; CHECK-NEXT: vmv.s.x v9, a0
; CHECK-NEXT: vslideup.vi v8, v9, 3
; CHECK-NEXT: vsetivli zero, 3, e32, m1, ta, ma
; CHECK-NEXT: vfredusum.vs v8, v8, v9
; CHECK-NEXT: vfmv.f.s fa4, v8
; CHECK-NEXT: fadd.s fa0, fa4, fa5
Expand Down
12 changes: 0 additions & 12 deletions llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-fp.ll
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,6 @@ define half @vreduce_fadd_v7f16(ptr %x, half %s) {
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 7, e16, m1, ta, ma
; CHECK-NEXT: vle16.v v8, (a0)
; CHECK-NEXT: lui a0, 1048568
; CHECK-NEXT: vmv.s.x v9, a0
; CHECK-NEXT: vsetivli zero, 8, e16, m1, ta, ma
; CHECK-NEXT: vslideup.vi v8, v9, 7
; CHECK-NEXT: vfmv.s.f v9, fa0
; CHECK-NEXT: vfredusum.vs v8, v8, v9
; CHECK-NEXT: vfmv.f.s fa0, v8
Expand Down Expand Up @@ -470,10 +466,6 @@ define float @vreduce_fadd_v7f32(ptr %x, float %s) {
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 7, e32, m2, ta, ma
; CHECK-NEXT: vle32.v v8, (a0)
; CHECK-NEXT: lui a0, 524288
; CHECK-NEXT: vmv.s.x v10, a0
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; CHECK-NEXT: vslideup.vi v8, v10, 7
; CHECK-NEXT: vfmv.s.f v10, fa0
; CHECK-NEXT: vfredusum.vs v8, v8, v10
; CHECK-NEXT: vfmv.f.s fa0, v8
Expand All @@ -488,10 +480,6 @@ define float @vreduce_ord_fadd_v7f32(ptr %x, float %s) {
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 7, e32, m2, ta, ma
; CHECK-NEXT: vle32.v v8, (a0)
; CHECK-NEXT: lui a0, 524288
; CHECK-NEXT: vmv.s.x v10, a0
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; CHECK-NEXT: vslideup.vi v8, v10, 7
; CHECK-NEXT: vfmv.s.f v10, fa0
; CHECK-NEXT: vfredosum.vs v8, v8, v10
; CHECK-NEXT: vfmv.f.s fa0, v8
Expand Down
100 changes: 42 additions & 58 deletions llvm/test/CodeGen/RISCV/rvv/vreductions-fp-sdnode.ll
Original file line number Diff line number Diff line change
Expand Up @@ -889,17 +889,12 @@ define half @vreduce_ord_fadd_nxv3f16(<vscale x 3 x half> %v, half %s) {
; CHECK-NEXT: csrr a0, vlenb
; CHECK-NEXT: srli a0, a0, 3
; CHECK-NEXT: slli a1, a0, 1
; CHECK-NEXT: add a1, a1, a0
; CHECK-NEXT: add a0, a1, a0
; CHECK-NEXT: lui a2, 1048568
; CHECK-NEXT: vsetvli a3, zero, e16, m1, ta, ma
; CHECK-NEXT: vmv.v.x v9, a2
; CHECK-NEXT: vsetvli zero, a0, e16, m1, ta, ma
; CHECK-NEXT: vslideup.vx v8, v9, a1
; CHECK-NEXT: vsetvli a0, zero, e16, m1, ta, ma
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
; CHECK-NEXT: vfmv.s.f v9, fa0
; CHECK-NEXT: vfredosum.vs v8, v8, v9
; CHECK-NEXT: vfmv.f.s fa0, v8
; CHECK-NEXT: vsetvli zero, a0, e16, m1, ta, ma
; CHECK-NEXT: vfredosum.vs v9, v8, v9
; CHECK-NEXT: vfmv.f.s fa0, v9
; CHECK-NEXT: ret
%red = call half @llvm.vector.reduce.fadd.nxv3f16(half %s, <vscale x 3 x half> %v)
ret half %red
Expand All @@ -910,18 +905,15 @@ declare half @llvm.vector.reduce.fadd.nxv6f16(half, <vscale x 6 x half>)
define half @vreduce_ord_fadd_nxv6f16(<vscale x 6 x half> %v, half %s) {
; CHECK-LABEL: vreduce_ord_fadd_nxv6f16:
; CHECK: # %bb.0:
; CHECK-NEXT: lui a0, 1048568
; CHECK-NEXT: vsetvli a1, zero, e16, m1, ta, ma
; CHECK-NEXT: vmv.v.x v10, a0
; CHECK-NEXT: csrr a0, vlenb
; CHECK-NEXT: srli a0, a0, 2
; CHECK-NEXT: add a1, a0, a0
; CHECK-NEXT: vsetvli zero, a1, e16, m1, ta, ma
; CHECK-NEXT: vslideup.vx v9, v10, a0
; CHECK-NEXT: vsetvli a0, zero, e16, m2, ta, ma
; CHECK-NEXT: srli a1, a0, 3
; CHECK-NEXT: slli a1, a1, 1
; CHECK-NEXT: sub a0, a0, a1
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
; CHECK-NEXT: vfmv.s.f v10, fa0
; CHECK-NEXT: vfredosum.vs v8, v8, v10
; CHECK-NEXT: vfmv.f.s fa0, v8
; CHECK-NEXT: vsetvli zero, a0, e16, m2, ta, ma
; CHECK-NEXT: vfredosum.vs v10, v8, v10
; CHECK-NEXT: vfmv.f.s fa0, v10
; CHECK-NEXT: ret
%red = call half @llvm.vector.reduce.fadd.nxv6f16(half %s, <vscale x 6 x half> %v)
ret half %red
Expand All @@ -932,22 +924,15 @@ declare half @llvm.vector.reduce.fadd.nxv10f16(half, <vscale x 10 x half>)
define half @vreduce_ord_fadd_nxv10f16(<vscale x 10 x half> %v, half %s) {
; CHECK-LABEL: vreduce_ord_fadd_nxv10f16:
; CHECK: # %bb.0:
; CHECK-NEXT: lui a0, 1048568
; CHECK-NEXT: vsetvli a1, zero, e16, m1, ta, ma
; CHECK-NEXT: vmv.v.x v12, a0
; CHECK-NEXT: csrr a0, vlenb
; CHECK-NEXT: srli a0, a0, 2
; CHECK-NEXT: add a1, a0, a0
; CHECK-NEXT: vsetvli zero, a1, e16, m1, ta, ma
; CHECK-NEXT: vslideup.vx v10, v12, a0
; CHECK-NEXT: vsetvli zero, a0, e16, m1, tu, ma
; CHECK-NEXT: vmv.v.v v11, v12
; CHECK-NEXT: vsetvli zero, a1, e16, m1, ta, ma
; CHECK-NEXT: vslideup.vx v11, v12, a0
; CHECK-NEXT: vsetvli a0, zero, e16, m4, ta, ma
; CHECK-NEXT: srli a0, a0, 3
; CHECK-NEXT: li a1, 10
; CHECK-NEXT: mul a0, a0, a1
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
; CHECK-NEXT: vfmv.s.f v12, fa0
; CHECK-NEXT: vfredosum.vs v8, v8, v12
; CHECK-NEXT: vfmv.f.s fa0, v8
; CHECK-NEXT: vsetvli zero, a0, e16, m4, ta, ma
; CHECK-NEXT: vfredosum.vs v12, v8, v12
; CHECK-NEXT: vfmv.f.s fa0, v12
; CHECK-NEXT: ret
%red = call half @llvm.vector.reduce.fadd.nxv10f16(half %s, <vscale x 10 x half> %v)
ret half %red
Expand All @@ -958,13 +943,16 @@ declare half @llvm.vector.reduce.fadd.nxv12f16(half, <vscale x 12 x half>)
define half @vreduce_ord_fadd_nxv12f16(<vscale x 12 x half> %v, half %s) {
; CHECK-LABEL: vreduce_ord_fadd_nxv12f16:
; CHECK: # %bb.0:
; CHECK-NEXT: lui a0, 1048568
; CHECK-NEXT: vsetvli a1, zero, e16, m1, ta, ma
; CHECK-NEXT: vmv.v.x v11, a0
; CHECK-NEXT: csrr a0, vlenb
; CHECK-NEXT: srli a0, a0, 3
; CHECK-NEXT: slli a1, a0, 2
; CHECK-NEXT: slli a0, a0, 4
; CHECK-NEXT: sub a0, a0, a1
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
; CHECK-NEXT: vfmv.s.f v12, fa0
; CHECK-NEXT: vsetvli a0, zero, e16, m4, ta, ma
; CHECK-NEXT: vfredosum.vs v8, v8, v12
; CHECK-NEXT: vfmv.f.s fa0, v8
; CHECK-NEXT: vsetvli zero, a0, e16, m4, ta, ma
; CHECK-NEXT: vfredosum.vs v12, v8, v12
; CHECK-NEXT: vfmv.f.s fa0, v12
; CHECK-NEXT: ret
%red = call half @llvm.vector.reduce.fadd.nxv12f16(half %s, <vscale x 12 x half> %v)
ret half %red
Expand All @@ -977,17 +965,14 @@ define half @vreduce_fadd_nxv3f16(<vscale x 3 x half> %v, half %s) {
; CHECK-NEXT: csrr a0, vlenb
; CHECK-NEXT: srli a0, a0, 3
; CHECK-NEXT: slli a1, a0, 1
; CHECK-NEXT: add a1, a1, a0
; CHECK-NEXT: add a0, a1, a0
; CHECK-NEXT: lui a2, 1048568
; CHECK-NEXT: vsetvli a3, zero, e16, m1, ta, ma
; CHECK-NEXT: vmv.v.x v9, a2
; CHECK-NEXT: vsetvli zero, a0, e16, m1, ta, ma
; CHECK-NEXT: vslideup.vx v8, v9, a1
; CHECK-NEXT: vsetvli a0, zero, e16, m1, ta, ma
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
; CHECK-NEXT: vfmv.s.f v9, fa0
; CHECK-NEXT: vfredusum.vs v8, v8, v9
; CHECK-NEXT: vfmv.f.s fa0, v8
; CHECK-NEXT: lui a1, 1048568
; CHECK-NEXT: vmv.s.x v10, a1
; CHECK-NEXT: vsetvli zero, a0, e16, m1, ta, ma
; CHECK-NEXT: vfredusum.vs v10, v8, v9
; CHECK-NEXT: vfmv.f.s fa0, v10
; CHECK-NEXT: ret
%red = call reassoc half @llvm.vector.reduce.fadd.nxv3f16(half %s, <vscale x 3 x half> %v)
ret half %red
Expand All @@ -996,18 +981,17 @@ define half @vreduce_fadd_nxv3f16(<vscale x 3 x half> %v, half %s) {
define half @vreduce_fadd_nxv6f16(<vscale x 6 x half> %v, half %s) {
; CHECK-LABEL: vreduce_fadd_nxv6f16:
; CHECK: # %bb.0:
; CHECK-NEXT: lui a0, 1048568
; CHECK-NEXT: vsetvli a1, zero, e16, m1, ta, ma
; CHECK-NEXT: vmv.v.x v10, a0
; CHECK-NEXT: csrr a0, vlenb
; CHECK-NEXT: srli a0, a0, 2
; CHECK-NEXT: add a1, a0, a0
; CHECK-NEXT: vsetvli zero, a1, e16, m1, ta, ma
; CHECK-NEXT: vslideup.vx v9, v10, a0
; CHECK-NEXT: vsetvli a0, zero, e16, m2, ta, ma
; CHECK-NEXT: srli a1, a0, 3
; CHECK-NEXT: slli a1, a1, 1
; CHECK-NEXT: sub a0, a0, a1
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
; CHECK-NEXT: vfmv.s.f v10, fa0
; CHECK-NEXT: vfredusum.vs v8, v8, v10
; CHECK-NEXT: vfmv.f.s fa0, v8
; CHECK-NEXT: lui a1, 1048568
; CHECK-NEXT: vmv.s.x v11, a1
; CHECK-NEXT: vsetvli zero, a0, e16, m2, ta, ma
; CHECK-NEXT: vfredusum.vs v11, v8, v10
; CHECK-NEXT: vfmv.f.s fa0, v11
; CHECK-NEXT: ret
%red = call reassoc half @llvm.vector.reduce.fadd.nxv6f16(half %s, <vscale x 6 x half> %v)
ret half %red
Expand Down

0 comments on commit 2cb25d5

Please sign in to comment.