diff --git a/lib/Optimizer/Transforms/MemToReg.cpp b/lib/Optimizer/Transforms/MemToReg.cpp index a1d9b13532..ebdba833bf 100644 --- a/lib/Optimizer/Transforms/MemToReg.cpp +++ b/lib/Optimizer/Transforms/MemToReg.cpp @@ -331,6 +331,11 @@ class RegionDataFlow { return addLiveInToBlock(block, mr); } + void maybeAddBalancedLiveInToBlock(Block *block, MemRef mr) { + if (liveOutSet.count(mr)) + maybeAddLiveInToBlock(block, mr); + } + /// Record the memory reference \p mr as live-in to \p block. The live-in /// value is specified as \p val. Consequently, \p val \em{must dominate} \p /// block. @@ -957,6 +962,8 @@ class MemToRegPass : public cudaq::opt::impl::MemToRegBase { auto *block = term->getBlock(); for (auto liveOut : bindings) { if (dataFlow.hasBinding(block, liveOut)) { + if (!isFunctionBlock(block) && !usePromo && !onlyLinear) + dataFlow.maybeAddBalancedLiveInToBlock(block, liveOut); auto oldVal = dataFlow.getBinding(block, liveOut); addTerminatorArgument(term, target, oldVal); } else if ((usePromo || diff --git a/test/Quake/memtoreg-7.qke b/test/Quake/memtoreg-7.qke new file mode 100644 index 0000000000..2d0245d684 --- /dev/null +++ b/test/Quake/memtoreg-7.qke @@ -0,0 +1,102 @@ +// ========================================================================== // +// Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. // +// All rights reserved. // +// // +// This source code and the accompanying materials are made available under // +// the terms of the Apache License 2.0 which accompanies this distribution. // +// ========================================================================== // + +// RUN: cudaq-opt --memtoreg %s | FileCheck %s + +func.func @__nvqpp__mlirgen__test() attributes {"cudaq-entrypoint", qubitMeasurementFeedback = true} { + %false = arith.constant false + %c1_i64 = arith.constant 1 : i64 + %c0_i64 = arith.constant 0 : i64 + %true = arith.constant true + %c2_i64 = arith.constant 2 : i64 + %0 = quake.alloca !quake.veq<2> + %1 = cc.alloca i1 + cc.store %true, %1 : !cc.ptr + %2 = cc.loop while ((%arg0 = %c0_i64) -> (i64)) { + %5 = arith.cmpi slt, %arg0, %c2_i64 : i64 + cc.condition %5(%arg0 : i64) + } do { + ^bb0(%arg0: i64): + %5 = quake.extract_ref %0[%arg0] : (!quake.veq<2>, i64) -> !quake.ref + %measOut = quake.mz %5 name "res" : (!quake.ref) -> !quake.measure + %6 = quake.discriminate %measOut : (!quake.measure) -> i1 + cc.store %6, %1 : !cc.ptr + %7 = cc.load %1 : !cc.ptr + %8 = arith.cmpi eq, %7, %false : i1 + cc.if(%8) { + %measOut_0 = quake.mz %0 name "inner_mz" : (!quake.veq<2>) -> !cc.stdvec + %9 = quake.discriminate %measOut_0 : (!cc.stdvec) -> !cc.stdvec + cc.scope { + %10 = cc.alloca !cc.stdvec + cc.store %9, %10 : !cc.ptr> + } + } + cc.continue %arg0 : i64 + } step { + ^bb0(%arg0: i64): + %5 = arith.addi %arg0, %c1_i64 : i64 + cc.continue %5 : i64 + } {invariant} + %3 = cc.load %1 : !cc.ptr + %4 = arith.cmpi eq, %3, %true : i1 + cc.if(%4) { + %measOut = quake.mz %0 name "outer_mz" : (!quake.veq<2>) -> !cc.stdvec + %5 = quake.discriminate %measOut : (!cc.stdvec) -> !cc.stdvec + cc.scope { + %6 = cc.alloca !cc.stdvec + cc.store %5, %6 : !cc.ptr> + } + } + return +} + +// CHECK-LABEL: func.func @__nvqpp__mlirgen__test() attributes {"cudaq-entrypoint", qubitMeasurementFeedback = true} { +// CHECK-DAG: %[[VAL_0:.*]] = arith.constant false +// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 1 : i64 +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : i64 +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant true +// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2 : i64 +// CHECK-DAG: %[[VAL_5:.*]] = quake.alloca !quake.veq<2> +// CHECK-DAG: %[[VAL_6:.*]] = cc.undef i1 +// CHECK: %[[VAL_7:.*]]:2 = cc.loop while ((%[[VAL_8:.*]] = %[[VAL_2]], %[[VAL_9:.*]] = %[[VAL_3]]) -> (i64, i1)) { +// CHECK: %[[VAL_10:.*]] = arith.cmpi slt, %[[VAL_8]], %[[VAL_4]] : i64 +// CHECK: cc.condition %[[VAL_10]](%[[VAL_8]], %[[VAL_9]] : i64, i1) +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_11:.*]]: i64, %[[VAL_12:.*]]: i1): +// CHECK: %[[VAL_13:.*]] = quake.extract_ref %[[VAL_5]]{{\[}}%[[VAL_11]]] : (!quake.veq<2>, i64) -> !quake.ref +// CHECK: %[[VAL_14:.*]] = quake.unwrap %[[VAL_13]] : (!quake.ref) -> !quake.wire +// CHECK: %[[VAL_15:.*]], %[[VAL_16:.*]] = quake.mz %[[VAL_14]] name "res" : (!quake.wire) -> (!quake.measure, !quake.wire) +// CHECK: quake.wrap %[[VAL_16]] to %[[VAL_13]] : !quake.wire, !quake.ref +// CHECK: %[[VAL_17:.*]] = quake.discriminate %[[VAL_15]] : (!quake.measure) -> i1 +// CHECK: %[[VAL_18:.*]] = arith.cmpi eq, %[[VAL_17]], %[[VAL_0]] : i1 +// CHECK: cc.if(%[[VAL_18]]) { +// CHECK: %[[VAL_19:.*]] = quake.mz %[[VAL_5]] name "inner_mz" : (!quake.veq<2>) -> !cc.stdvec +// CHECK: %[[VAL_20:.*]] = quake.discriminate %[[VAL_19]] : (!cc.stdvec) -> !cc.stdvec +// CHECK: cc.scope { +// CHECK: %[[VAL_21:.*]] = cc.undef !cc.stdvec +// CHECK: } +// CHECK: } else { +// CHECK: } +// CHECK: cc.continue %[[VAL_11]], %[[VAL_17]] : i64, i1 +// CHECK: } step { +// CHECK: ^bb0(%[[VAL_22:.*]]: i64, %[[VAL_23:.*]]: i1): +// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_22]], %[[VAL_1]] : i64 +// CHECK: cc.continue %[[VAL_24]], %[[VAL_23]] : i64, i1 +// CHECK: } {invariant} +// CHECK: %[[VAL_25:.*]] = arith.cmpi eq, %[[VAL_26:.*]]#1, %[[VAL_3]] : i1 +// CHECK: cc.if(%[[VAL_25]]) { +// CHECK: %[[VAL_27:.*]] = quake.mz %[[VAL_5]] name "outer_mz" : (!quake.veq<2>) -> !cc.stdvec +// CHECK: %[[VAL_28:.*]] = quake.discriminate %[[VAL_27]] : (!cc.stdvec) -> !cc.stdvec +// CHECK: cc.scope { +// CHECK: %[[VAL_29:.*]] = cc.undef !cc.stdvec +// CHECK: } +// CHECK: } else { +// CHECK: } +// CHECK: return +// CHECK: } +