diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 5b1a246a19c8e8..47617642a5e9ae 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13709,6 +13709,57 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
   return InputRootReplacement;
 }
 
+// Fold (vwadd.wv y, (vmerge cond, x, 0)) -> vwadd.wv y, x, y, cond
+// y will be the Passthru and cond will be the Mask.
+static SDValue combineVWADDWSelect(SDNode *N, SelectionDAG &DAG) {
+  unsigned Opc = N->getOpcode();
+  assert(Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWADDU_W_VL);
+
+  SDValue Y = N->getOperand(0);
+  SDValue MergeOp = N->getOperand(1);
+  if (MergeOp.getOpcode() != RISCVISD::VMERGE_VL)
+    return SDValue();
+  SDValue X = MergeOp->getOperand(1);
+
+  if (!MergeOp.hasOneUse())
+    return SDValue();
+
+  // Passthru should be undef
+  SDValue Passthru = N->getOperand(2);
+  if (!Passthru.isUndef())
+    return SDValue();
+
+  // Mask should be all ones
+  SDValue Mask = N->getOperand(3);
+  if (Mask.getOpcode() != RISCVISD::VMSET_VL)
+    return SDValue();
+
+  // False value of MergeOp should be all zeros
+  SDValue Z = MergeOp->getOperand(2);
+  if (Z.getOpcode() != ISD::INSERT_SUBVECTOR)
+    return SDValue();
+  if (!ISD::isBuildVectorAllZeros(Z.getOperand(1).getNode()))
+    return SDValue();
+  if (!isNullOrNullSplat(Z.getOperand(0)) && !Z.getOperand(0).isUndef())
+    return SDValue();
+
+  return DAG.getNode(Opc, SDLoc(N), N->getValueType(0),
+                     {Y, X, Y, MergeOp->getOperand(0), N->getOperand(4)},
+                     N->getFlags());
+}
+
+static SDValue performVWADDW_VLCombine(SDNode *N,
+                                       TargetLowering::DAGCombinerInfo &DCI,
+                                       const RISCVSubtarget &Subtarget) {
+  unsigned Opc = N->getOpcode();
+  assert(Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWADDU_W_VL);
+
+  if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+    return V;
+
+  return combineVWADDWSelect(N, DCI.DAG);
+}
+
 // Helper function for performMemPairCombine.
 // Try to combine the memory loads/stores LSNode1 and LSNode2
 // into a single memory pair operation.
@@ -15777,9 +15828,10 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
       return V;
     return combineToVWMACC(N, DAG, Subtarget);
-  case RISCVISD::SUB_VL:
   case RISCVISD::VWADD_W_VL:
   case RISCVISD::VWADDU_W_VL:
+    return performVWADDW_VLCombine(N, DCI, Subtarget);
+  case RISCVISD::SUB_VL:
   case RISCVISD::VWSUB_W_VL:
   case RISCVISD::VWSUBU_W_VL:
   case RISCVISD::MUL_VL:
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 6e7be2647e8f83..cc44092700c66e 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -691,6 +691,27 @@ multiclass VPatTiedBinaryNoMaskVL_V
                          GPR:$vl, sew, TU_MU)>;
 }
 
+class VPatTiedBinaryMaskVL_V<SDNode vop,
+                             string instruction_name,
+                             string suffix,
+                             ValueType result_type,
+                             ValueType op2_type,
+                             ValueType mask_type,
+                             int sew,
+                             LMULInfo vlmul,
+                             VReg result_reg_class,
+                             VReg op2_reg_class> :
+    Pat<(result_type (vop
+                      (result_type result_reg_class:$rs1),
+                      (op2_type op2_reg_class:$rs2),
+                      (result_type result_reg_class:$rs1),
+                      (mask_type V0),
+                      VLOpFrag)),
+        (!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_MASK_TIED")
+                     result_reg_class:$rs1,
+                     op2_reg_class:$rs2,
+                     (mask_type V0), GPR:$vl, sew, TU_MU)>;
+
 multiclass VPatTiedBinaryNoMaskVL_V_RM
+      def : VPatTiedBinaryMaskVL_V
       def : VPatBinaryVL_V

+define <8 x i64> @vwadd_wv_mask_v8i32(<8 x i32> %x, <8 x i64> %y) {
+; CHECK-LABEL: vwadd_wv_mask_v8i32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 42
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vmslt.vx v0, v8, a0
+; CHECK-NEXT:    vsetvli zero, zero, e32, m2, tu, mu
+; CHECK-NEXT:    vwadd.wv v12, v12, v8, v0.t
+; CHECK-NEXT:    vmv4r.v v8, v12
+; CHECK-NEXT:    ret
+  %mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
+  %a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
+  %sa = sext <8 x i32> %a to <8 x i64>
+  %ret = add <8 x i64> %sa, %y
+  ret <8 x i64> %ret
+}
+
+define <8 x i64> @vwaddu_wv_mask_v8i32(<8 x i32> %x, <8 x i64> %y) {
+; CHECK-LABEL: vwaddu_wv_mask_v8i32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 42
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vmslt.vx v0, v8, a0
+; CHECK-NEXT:    vsetvli zero, zero, e32, m2, tu, mu
+; CHECK-NEXT:    vwaddu.wv v12, v12, v8, v0.t
+; CHECK-NEXT:    vmv4r.v v8, v12
+; CHECK-NEXT:    ret
+  %mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
+  %a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
+  %sa = zext <8 x i32> %a to <8 x i64>
+  %ret = add <8 x i64> %sa, %y
+  ret <8 x i64> %ret
+}
+
+define <8 x i64> @vwaddu_vv_mask_v8i32(<8 x i32> %x, <8 x i32> %y) {
+; CHECK-LABEL: vwaddu_vv_mask_v8i32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 42
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vmslt.vx v0, v8, a0
+; CHECK-NEXT:    vmv.v.i v12, 0
+; CHECK-NEXT:    vmerge.vvm v8, v12, v8, v0
+; CHECK-NEXT:    vwaddu.vv v12, v8, v10
+; CHECK-NEXT:    vmv4r.v v8, v12
+; CHECK-NEXT:    ret
+  %mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
+  %a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
+  %sa = zext <8 x i32> %a to <8 x i64>
+  %sy = zext <8 x i32> %y to <8 x i64>
+  %ret = add <8 x i64> %sa, %sy
+  ret <8 x i64> %ret
+}
+
+define <8 x i64> @vwadd_wv_mask_v8i32_commutative(<8 x i32> %x, <8 x i64> %y) {
+; CHECK-LABEL: vwadd_wv_mask_v8i32_commutative:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 42
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vmslt.vx v0, v8, a0
+; CHECK-NEXT:    vsetvli zero, zero, e32, m2, tu, mu
+; CHECK-NEXT:    vwadd.wv v12, v12, v8, v0.t
+; CHECK-NEXT:    vmv4r.v v8, v12
+; CHECK-NEXT:    ret
+  %mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
+  %a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
+  %sa = sext <8 x i32> %a to <8 x i64>
+  %ret = add <8 x i64> %y, %sa
+  ret <8 x i64> %ret
+}
+
+define <8 x i64> @vwadd_wv_mask_v8i32_nonzero(<8 x i32> %x, <8 x i64> %y) {
+; CHECK-LABEL: vwadd_wv_mask_v8i32_nonzero:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 42
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vmslt.vx v0, v8, a0
+; CHECK-NEXT:    vmv.v.i v10, 1
+; CHECK-NEXT:    vmerge.vvm v16, v10, v8, v0
+; CHECK-NEXT:    vwadd.wv v8, v12, v16
+; CHECK-NEXT:    ret
+  %mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
+  %a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>
+  %sa = sext <8 x i32> %a to <8 x i64>
+  %ret = add <8 x i64> %y, %sa
+  ret <8 x i64> %ret
+}