Skip to content

Commit

Permalink
[RISCV][ISel] Remove redundant vmerge for the vwadd. (llvm#78403)
Browse files Browse the repository at this point in the history
This patch resolves the missed-optimization case shown below.

### Code
```
define <8 x i64> @vwadd_mask_v8i32(<8 x i32> %x, <8 x i64> %y) {
    %mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
    %a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
    %sa = sext <8 x i32> %a to <8 x i64>
    %ret = add <8 x i64> %sa, %y
    ret <8 x i64> %ret
}
```

### Before this patch
[Compiler Explorer](https://godbolt.org/z/cd1bKTrx6)
```
vwadd_mask_v8i32:
        li      a0, 42
        vsetivli        zero, 8, e32, m2, ta, ma
        vmslt.vx        v0, v8, a0
        vmv.v.i v10, 0
        vmerge.vvm      v16, v10, v8, v0
        vwadd.wv        v8, v12, v16
        ret
```

### After this patch
```
vwadd_mask_v8i32:
        li a0, 42
        vsetivli zero, 8, e32, m2, ta, ma
        vmslt.vx v0, v8, a0
        vsetvli zero, zero, e32, m2, tu, mu
        vwadd.wv v12, v12, v8, v0.t
        vmv4r.v v8, v12
        ret
```
This pattern can be found in a reduction with a widening destination.

Specifically, we first fold `(vwadd.wv y, (vmerge cond, x, 0))` into
`(vwadd.wv y, x, y, cond)`, and then pattern-match the folded node during instruction selection.
  • Loading branch information
sun-jacobi authored Jan 27, 2024
1 parent 608d602 commit 3855757
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 1 deletion.
54 changes: 53 additions & 1 deletion llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13709,6 +13709,57 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
return InputRootReplacement;
}

// Absorb a zero-filling select into a widening add:
//   (vwadd.wv y, (vmerge cond, x, 0)) -> (vwadd.wv y, x, y, cond)
// In the rewritten node the wide operand y doubles as the passthru and the
// select condition becomes the mask. Returns an empty SDValue when the
// pattern does not match.
//
// NOTE(review): only operands 0-2 of the vmerge are inspected; any further
// operands it carries are not compared against N's - confirm this is intended.
static SDValue combineVWADDWSelect(SDNode *N, SelectionDAG &DAG) {
  const unsigned Opc = N->getOpcode();
  assert(Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWADDU_W_VL);

  SDValue Wide = N->getOperand(0);
  SDValue Select = N->getOperand(1);

  // The narrow operand has to be a vmerge that nobody else uses.
  if (Select.getOpcode() != RISCVISD::VMERGE_VL || !Select.hasOneUse())
    return SDValue();

  // N itself must still be unpredicated: undef passthru (operand 2) and an
  // all-ones mask (operand 3).
  if (!N->getOperand(2).isUndef() ||
      N->getOperand(3).getOpcode() != RISCVISD::VMSET_VL)
    return SDValue();

  // The false value of the vmerge must be all zeros. The accepted shape is an
  // INSERT_SUBVECTOR of an all-zeros build vector into a base that is either
  // zero or undef.
  SDValue FalseVal = Select->getOperand(2);
  if (FalseVal.getOpcode() != ISD::INSERT_SUBVECTOR)
    return SDValue();
  if (!ISD::isBuildVectorAllZeros(FalseVal.getOperand(1).getNode()))
    return SDValue();
  SDValue Base = FalseVal.getOperand(0);
  if (!(isNullOrNullSplat(Base) || Base.isUndef()))
    return SDValue();

  // Rebuild N with the select's true value as the narrow operand, the wide
  // operand as passthru, and the select's condition as the mask.
  SDValue Cond = Select->getOperand(0);
  SDValue Narrow = Select->getOperand(1);
  SDValue VL = N->getOperand(4);
  return DAG.getNode(Opc, SDLoc(N), N->getValueType(0),
                     {Wide, Narrow, Wide, Cond, VL}, N->getFlags());
}

// DAG combine entry point for VWADD_W_VL / VWADDU_W_VL nodes.
// Tries combineBinOp_VLToVWBinOp_VL first; if that produces nothing, falls
// back to combineVWADDWSelect. Returns an empty SDValue when neither applies.
static SDValue performVWADDW_VLCombine(SDNode *N,
                                       TargetLowering::DAGCombinerInfo &DCI,
                                       const RISCVSubtarget &Subtarget) {
  // Opc is only read by the assert below; mark it so builds with assertions
  // disabled do not emit an unused-variable warning.
  [[maybe_unused]] unsigned Opc = N->getOpcode();
  assert(Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWADDU_W_VL);

  if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
    return V;

  return combineVWADDWSelect(N, DCI.DAG);
}

// Helper function for performMemPairCombine.
// Try to combine the memory loads/stores LSNode1 and LSNode2
// into a single memory pair operation.
Expand Down Expand Up @@ -15777,9 +15828,10 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
return V;
return combineToVWMACC(N, DAG, Subtarget);
case RISCVISD::SUB_VL:
case RISCVISD::VWADD_W_VL:
case RISCVISD::VWADDU_W_VL:
return performVWADDW_VLCombine(N, DCI, Subtarget);
case RISCVISD::SUB_VL:
case RISCVISD::VWSUB_W_VL:
case RISCVISD::VWSUBU_W_VL:
case RISCVISD::MUL_VL:
Expand Down
25 changes: 25 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,27 @@ multiclass VPatTiedBinaryNoMaskVL_V<SDNode vop,
GPR:$vl, sew, TU_MU)>;
}

// Selection pattern for the masked, tied form of a binary VL node.
// Matches `vop` when its first source operand and its passthru operand are the
// same result_reg_class value ($rs1) and the mask is in V0, and selects the
// `<instruction_name>_<suffix>_<MX>_MASK_TIED` instruction, passing $rs1, $rs2,
// the V0 mask, the VL, the SEW and the TU_MU policy.
class VPatTiedBinaryMaskVL_V<SDNode vop,
string instruction_name,
string suffix,
ValueType result_type,
ValueType op2_type,
ValueType mask_type,
int sew,
LMULInfo vlmul,
VReg result_reg_class,
VReg op2_reg_class> :
Pat<(result_type (vop
(result_type result_reg_class:$rs1),
(op2_type op2_reg_class:$rs2),
(result_type result_reg_class:$rs1),
(mask_type V0),
VLOpFrag)),
(!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_MASK_TIED")
result_reg_class:$rs1,
op2_reg_class:$rs2,
(mask_type V0), GPR:$vl, sew, TU_MU)>;

multiclass VPatTiedBinaryNoMaskVL_V_RM<SDNode vop,
string instruction_name,
string suffix,
Expand Down Expand Up @@ -819,6 +840,10 @@ multiclass VPatBinaryWVL_VV_VX_WV_WX<SDPatternOperator vop, SDNode vop_w,
defm : VPatTiedBinaryNoMaskVL_V<vop_w, instruction_name, "WV",
wti.Vector, vti.Vector, vti.Log2SEW,
vti.LMul, wti.RegClass, vti.RegClass>;
def : VPatTiedBinaryMaskVL_V<vop_w, instruction_name, "WV",
wti.Vector, vti.Vector, wti.Mask,
vti.Log2SEW, vti.LMul, wti.RegClass,
vti.RegClass>;
def : VPatBinaryVL_V<vop_w, instruction_name, "WV",
wti.Vector, wti.Vector, vti.Vector, vti.Mask,
vti.Log2SEW, vti.LMul, wti.RegClass, wti.RegClass,
Expand Down
90 changes: 90 additions & 0 deletions llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd-mask.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK
; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK

; Signed case: a select-with-zero that is sign-extended and added to a wide
; vector. The expected output is a single masked vwadd.wv with no vmerge.vvm.
define <8 x i64> @vwadd_wv_mask_v8i32(<8 x i32> %x, <8 x i64> %y) {
; CHECK-LABEL: vwadd_wv_mask_v8i32:
; CHECK: # %bb.0:
; CHECK-NEXT: li a0, 42
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; CHECK-NEXT: vmslt.vx v0, v8, a0
; CHECK-NEXT: vsetvli zero, zero, e32, m2, tu, mu
; CHECK-NEXT: vwadd.wv v12, v12, v8, v0.t
; CHECK-NEXT: vmv4r.v v8, v12
; CHECK-NEXT: ret
%mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
%a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
%sa = sext <8 x i32> %a to <8 x i64>
%ret = add <8 x i64> %sa, %y
ret <8 x i64> %ret
}

; Unsigned case: same shape as the signed test but with zext, so the expected
; output is a single masked vwaddu.wv with no vmerge.vvm.
define <8 x i64> @vwaddu_wv_mask_v8i32(<8 x i32> %x, <8 x i64> %y) {
; CHECK-LABEL: vwaddu_wv_mask_v8i32:
; CHECK: # %bb.0:
; CHECK-NEXT: li a0, 42
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; CHECK-NEXT: vmslt.vx v0, v8, a0
; CHECK-NEXT: vsetvli zero, zero, e32, m2, tu, mu
; CHECK-NEXT: vwaddu.wv v12, v12, v8, v0.t
; CHECK-NEXT: vmv4r.v v8, v12
; CHECK-NEXT: ret
%mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
%a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
%sa = zext <8 x i32> %a to <8 x i64>
%ret = add <8 x i64> %sa, %y
ret <8 x i64> %ret
}

; Both add operands are zero-extended from narrow vectors, so codegen uses
; vwaddu.vv rather than the .wv form. The expected output still contains the
; vmv.v.i + vmerge.vvm sequence, i.e. the select is not folded into a mask here.
define <8 x i64> @vwaddu_vv_mask_v8i32(<8 x i32> %x, <8 x i32> %y) {
; CHECK-LABEL: vwaddu_vv_mask_v8i32:
; CHECK: # %bb.0:
; CHECK-NEXT: li a0, 42
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; CHECK-NEXT: vmslt.vx v0, v8, a0
; CHECK-NEXT: vmv.v.i v12, 0
; CHECK-NEXT: vmerge.vvm v8, v12, v8, v0
; CHECK-NEXT: vwaddu.vv v12, v8, v10
; CHECK-NEXT: vmv4r.v v8, v12
; CHECK-NEXT: ret
%mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
%a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
%sa = zext <8 x i32> %a to <8 x i64>
%sy = zext <8 x i32> %y to <8 x i64>
%ret = add <8 x i64> %sa, %sy
ret <8 x i64> %ret
}

; Same as the signed test but with the add operands swapped (%y first). The
; expected output is identical: a single masked vwadd.wv with no vmerge.vvm.
define <8 x i64> @vwadd_wv_mask_v8i32_commutative(<8 x i32> %x, <8 x i64> %y) {
; CHECK-LABEL: vwadd_wv_mask_v8i32_commutative:
; CHECK: # %bb.0:
; CHECK-NEXT: li a0, 42
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; CHECK-NEXT: vmslt.vx v0, v8, a0
; CHECK-NEXT: vsetvli zero, zero, e32, m2, tu, mu
; CHECK-NEXT: vwadd.wv v12, v12, v8, v0.t
; CHECK-NEXT: vmv4r.v v8, v12
; CHECK-NEXT: ret
%mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
%a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
%sa = sext <8 x i32> %a to <8 x i64>
%ret = add <8 x i64> %y, %sa
ret <8 x i64> %ret
}

; Negative test: the false value of the select is a splat of 1, not zero, so
; the expected output keeps vmv.v.i + vmerge.vvm and an unmasked vwadd.wv.
define <8 x i64> @vwadd_wv_mask_v8i32_nonzero(<8 x i32> %x, <8 x i64> %y) {
; CHECK-LABEL: vwadd_wv_mask_v8i32_nonzero:
; CHECK: # %bb.0:
; CHECK-NEXT: li a0, 42
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; CHECK-NEXT: vmslt.vx v0, v8, a0
; CHECK-NEXT: vmv.v.i v10, 1
; CHECK-NEXT: vmerge.vvm v16, v10, v8, v0
; CHECK-NEXT: vwadd.wv v8, v12, v16
; CHECK-NEXT: ret
%mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
%a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>
%sa = sext <8 x i32> %a to <8 x i64>
%ret = add <8 x i64> %y, %sa
ret <8 x i64> %ret
}

0 comments on commit 3855757

Please sign in to comment.