Skip to content

Commit

Permalink
[core] Fix memtoreg bug with loop body (#1953)
Browse files Browse the repository at this point in the history
When the body of a loop defines a value that was live-in and
live-out of that loop, the region arguments have to be added for
all loop regions.

These arguments must be added to balance all the region arguments
even if the argument is technically not live-in to the region
because it is defined unconditionally therein.
  • Loading branch information
schweitzpgi authored Jul 17, 2024
1 parent da39e39 commit 42ee0e0
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 0 deletions.
7 changes: 7 additions & 0 deletions lib/Optimizer/Transforms/MemToReg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,11 @@ class RegionDataFlow {
return addLiveInToBlock(block, mr);
}

void maybeAddBalancedLiveInToBlock(Block *block, MemRef mr) {
if (liveOutSet.count(mr))
maybeAddLiveInToBlock(block, mr);
}

/// Record the memory reference \p mr as live-in to \p block. The live-in
/// value is specified as \p val. Consequently, \p val \em{must dominate} \p
/// block.
Expand Down Expand Up @@ -957,6 +962,8 @@ class MemToRegPass : public cudaq::opt::impl::MemToRegBase<MemToRegPass> {
auto *block = term->getBlock();
for (auto liveOut : bindings) {
if (dataFlow.hasBinding(block, liveOut)) {
if (!isFunctionBlock(block) && !usePromo && !onlyLinear)
dataFlow.maybeAddBalancedLiveInToBlock(block, liveOut);
auto oldVal = dataFlow.getBinding(block, liveOut);
addTerminatorArgument(term, target, oldVal);
} else if ((usePromo ||
Expand Down
102 changes: 102 additions & 0 deletions test/Quake/memtoreg-7.qke
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// ========================================================================== //
// Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. //
// All rights reserved. //
// //
// This source code and the accompanying materials are made available under //
// the terms of the Apache License 2.0 which accompanies this distribution. //
// ========================================================================== //

// RUN: cudaq-opt --memtoreg %s | FileCheck %s

func.func @__nvqpp__mlirgen__test() attributes {"cudaq-entrypoint", qubitMeasurementFeedback = true} {
%false = arith.constant false
%c1_i64 = arith.constant 1 : i64
%c0_i64 = arith.constant 0 : i64
%true = arith.constant true
%c2_i64 = arith.constant 2 : i64
%0 = quake.alloca !quake.veq<2>
%1 = cc.alloca i1
cc.store %true, %1 : !cc.ptr<i1>
%2 = cc.loop while ((%arg0 = %c0_i64) -> (i64)) {
%5 = arith.cmpi slt, %arg0, %c2_i64 : i64
cc.condition %5(%arg0 : i64)
} do {
^bb0(%arg0: i64):
%5 = quake.extract_ref %0[%arg0] : (!quake.veq<2>, i64) -> !quake.ref
%measOut = quake.mz %5 name "res" : (!quake.ref) -> !quake.measure
%6 = quake.discriminate %measOut : (!quake.measure) -> i1
cc.store %6, %1 : !cc.ptr<i1>
%7 = cc.load %1 : !cc.ptr<i1>
%8 = arith.cmpi eq, %7, %false : i1
cc.if(%8) {
%measOut_0 = quake.mz %0 name "inner_mz" : (!quake.veq<2>) -> !cc.stdvec<!quake.measure>
%9 = quake.discriminate %measOut_0 : (!cc.stdvec<!quake.measure>) -> !cc.stdvec<i1>
cc.scope {
%10 = cc.alloca !cc.stdvec<i1>
cc.store %9, %10 : !cc.ptr<!cc.stdvec<i1>>
}
}
cc.continue %arg0 : i64
} step {
^bb0(%arg0: i64):
%5 = arith.addi %arg0, %c1_i64 : i64
cc.continue %5 : i64
} {invariant}
%3 = cc.load %1 : !cc.ptr<i1>
%4 = arith.cmpi eq, %3, %true : i1
cc.if(%4) {
%measOut = quake.mz %0 name "outer_mz" : (!quake.veq<2>) -> !cc.stdvec<!quake.measure>
%5 = quake.discriminate %measOut : (!cc.stdvec<!quake.measure>) -> !cc.stdvec<i1>
cc.scope {
%6 = cc.alloca !cc.stdvec<i1>
cc.store %5, %6 : !cc.ptr<!cc.stdvec<i1>>
}
}
return
}

// CHECK-LABEL: func.func @__nvqpp__mlirgen__test() attributes {"cudaq-entrypoint", qubitMeasurementFeedback = true} {
// CHECK-DAG: %[[VAL_0:.*]] = arith.constant false
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 1 : i64
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : i64
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant true
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2 : i64
// CHECK-DAG: %[[VAL_5:.*]] = quake.alloca !quake.veq<2>
// CHECK-DAG: %[[VAL_6:.*]] = cc.undef i1
// CHECK: %[[VAL_7:.*]]:2 = cc.loop while ((%[[VAL_8:.*]] = %[[VAL_2]], %[[VAL_9:.*]] = %[[VAL_3]]) -> (i64, i1)) {
// CHECK: %[[VAL_10:.*]] = arith.cmpi slt, %[[VAL_8]], %[[VAL_4]] : i64
// CHECK: cc.condition %[[VAL_10]](%[[VAL_8]], %[[VAL_9]] : i64, i1)
// CHECK: } do {
// CHECK: ^bb0(%[[VAL_11:.*]]: i64, %[[VAL_12:.*]]: i1):
// CHECK: %[[VAL_13:.*]] = quake.extract_ref %[[VAL_5]]{{\[}}%[[VAL_11]]] : (!quake.veq<2>, i64) -> !quake.ref
// CHECK: %[[VAL_14:.*]] = quake.unwrap %[[VAL_13]] : (!quake.ref) -> !quake.wire
// CHECK: %[[VAL_15:.*]], %[[VAL_16:.*]] = quake.mz %[[VAL_14]] name "res" : (!quake.wire) -> (!quake.measure, !quake.wire)
// CHECK: quake.wrap %[[VAL_16]] to %[[VAL_13]] : !quake.wire, !quake.ref
// CHECK: %[[VAL_17:.*]] = quake.discriminate %[[VAL_15]] : (!quake.measure) -> i1
// CHECK: %[[VAL_18:.*]] = arith.cmpi eq, %[[VAL_17]], %[[VAL_0]] : i1
// CHECK: cc.if(%[[VAL_18]]) {
// CHECK: %[[VAL_19:.*]] = quake.mz %[[VAL_5]] name "inner_mz" : (!quake.veq<2>) -> !cc.stdvec<!quake.measure>
// CHECK: %[[VAL_20:.*]] = quake.discriminate %[[VAL_19]] : (!cc.stdvec<!quake.measure>) -> !cc.stdvec<i1>
// CHECK: cc.scope {
// CHECK: %[[VAL_21:.*]] = cc.undef !cc.stdvec<i1>
// CHECK: }
// CHECK: } else {
// CHECK: }
// CHECK: cc.continue %[[VAL_11]], %[[VAL_17]] : i64, i1
// CHECK: } step {
// CHECK: ^bb0(%[[VAL_22:.*]]: i64, %[[VAL_23:.*]]: i1):
// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_22]], %[[VAL_1]] : i64
// CHECK: cc.continue %[[VAL_24]], %[[VAL_23]] : i64, i1
// CHECK: } {invariant}
// CHECK: %[[VAL_25:.*]] = arith.cmpi eq, %[[VAL_26:.*]]#1, %[[VAL_3]] : i1
// CHECK: cc.if(%[[VAL_25]]) {
// CHECK: %[[VAL_27:.*]] = quake.mz %[[VAL_5]] name "outer_mz" : (!quake.veq<2>) -> !cc.stdvec<!quake.measure>
// CHECK: %[[VAL_28:.*]] = quake.discriminate %[[VAL_27]] : (!cc.stdvec<!quake.measure>) -> !cc.stdvec<i1>
// CHECK: cc.scope {
// CHECK: %[[VAL_29:.*]] = cc.undef !cc.stdvec<i1>
// CHECK: }
// CHECK: } else {
// CHECK: }
// CHECK: return
// CHECK: }

0 comments on commit 42ee0e0

Please sign in to comment.