Skip to content

Commit

Permalink
Fix argument conversion issues (#2211)
Browse files Browse the repository at this point in the history
* temp

* temp

* temp

* Remove printing

* Fixed synthesis issues and added tests

* Remove changes for remote sim

* Address CR comments
  • Loading branch information
annagrin authored Sep 17, 2024
1 parent 176f1e7 commit e0faa09
Show file tree
Hide file tree
Showing 4 changed files with 453 additions and 21 deletions.
14 changes: 11 additions & 3 deletions runtime/common/ArgumentConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "cudaq/Optimizer/Builder/Intrinsics.h"
#include "cudaq/Optimizer/Builder/Runtime.h"
#include "cudaq/Todo.h"
#include "cudaq/qis/pauli_word.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
Expand Down Expand Up @@ -199,7 +200,7 @@ Value dispatchSubtype(OpBuilder &builder, Type ty, void *p, ModuleOp substMod,
return {};
})
.Case([&](cudaq::cc::CharspanType strTy) {
return genConstant(builder, *static_cast<const std::string *>(p),
return genConstant(builder, static_cast<cudaq::pauli_word *>(p)->str(),
substMod);
})
.Case([&](cudaq::cc::StdvecType ty) {
Expand All @@ -224,6 +225,11 @@ Value genConstant(OpBuilder &builder, cudaq::cc::StdvecType vecTy, void *p,
auto eleTy = vecTy.getElementType();
auto elePtrTy = cudaq::cc::PointerType::get(eleTy);
auto eleSize = cudaq::opt::getDataSize(layout, eleTy);
if (isa<cudaq::cc::CharspanType>(eleTy)) {
// char span type (i.e. pauli word) is a `vector<char>`
eleSize = sizeof(VectorType);
}

assert(eleSize && "element must have a size");
auto loc = builder.getUnknownLoc();
std::int32_t vecSize = delta / eleSize;
Expand Down Expand Up @@ -361,7 +367,7 @@ void cudaq::opt::ArgumentConverter::gen(const std::vector<void *> &arguments) {
return {};
})
.Case([&](cc::CharspanType strTy) {
return buildSubst(*static_cast<const std::string *>(argPtr),
return buildSubst(static_cast<cudaq::pauli_word *>(argPtr)->str(),
substModule);
})
.Case([&](cc::PointerType ptrTy) -> cc::ArgumentSubstitutionOp {
Expand Down Expand Up @@ -406,8 +412,10 @@ void cudaq::opt::ArgumentConverter::gen_drop_front(
if (numDrop >= arguments.size())
return;
std::vector<void *> partialArgs;
int drop = numDrop;
for (void *arg : arguments) {
if (numDrop--) {
if (drop > 0) {
drop--;
partialArgs.push_back(nullptr);
continue;
}
Expand Down
234 changes: 219 additions & 15 deletions runtime/test/test_argument_conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
#include "common/ArgumentConversion.h"
#include "cudaq/Optimizer/Dialect/CC/CCDialect.h"
#include "cudaq/Optimizer/Dialect/Quake/QuakeDialect.h"
#include "cudaq/qis/pauli_word.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Parser/Parser.h"
#include <numeric>

void doSimpleTest(mlir::MLIRContext *ctx, const std::string &typeName,
std::vector<void *> args) {
Expand All @@ -29,14 +31,58 @@ func.func @__nvqpp__mlirgen__testy(%0: )#" +
typeName + R"#() -> ()
return
})#";
// Create the Module
auto mod = mlir::parseSourceString<mlir::ModuleOp>(code, ctx);
llvm::outs() << "Source module:\n" << *mod << '\n';
cudaq::opt::ArgumentConverter ab{"testy", *mod};
// Create the argument conversions
ab.gen(args);
// Dump the conversions
llvm::outs() << "========================================\n"
"Substitution module:\n"
<< ab.getSubstitutionModule() << '\n';
}

void doTest(mlir::MLIRContext *ctx, std::vector<std::string> &typeNames,
std::vector<void *> args, std::size_t startingArgIdx = 0) {

std::string code;
llvm::raw_string_ostream ss(code);

// Create code
std::vector<int> indices(args.size());
std::iota(indices.begin(), indices.end(), 0);
auto argPairs = llvm::zip_equal(indices, typeNames);

ss << "func.func private @callee(";
llvm::interleaveComma(argPairs, ss, [&](auto p) {
ss << "%" << std::get<0>(p) << ": " << std::get<1>(p);
});
ss << ")\n";

ss << "func.func @__nvqpp__mlirgen__testy(";
llvm::interleaveComma(argPairs, ss, [&](auto p) {
ss << "%" << std::get<0>(p) << ": " << std::get<1>(p);
});
ss << ") {";

ss << " call @callee(";
llvm::interleaveComma(indices, ss, [&](auto p) { ss << "%" << p; });

ss << "): (";
llvm::interleaveComma(typeNames, ss, [&](auto t) { ss << t; });
ss << ") -> ()\n";

ss << " return\n";
ss << "}\n";

// Create the Module
auto mod = mlir::parseSourceString<mlir::ModuleOp>(code, ctx);
llvm::outs() << "Source module:\n" << *mod << '\n';
cudaq::opt::ArgumentConverter ab{"testy", *mod};

// Create the argument conversions
ab.gen(args);
ab.gen_drop_front(args, startingArgIdx);

// Dump the conversions
llvm::outs() << "========================================\n"
Expand Down Expand Up @@ -146,7 +192,7 @@ void test_scalars(mlir::MLIRContext *ctx) {
// clang-format on

{
std::string x = "Hi, there!";
cudaq::pauli_word x{"XYZ"};
std::vector<void *> v = {static_cast<void *>(&x)};
doSimpleTest(ctx, "!cc.charspan", v);
}
Expand All @@ -156,12 +202,12 @@ void test_scalars(mlir::MLIRContext *ctx) {
// CHECK: Substitution module:

// CHECK-LABEL: cc.arg_subst[0] {
// CHECK: %[[VAL_0:.*]] = cc.address_of @cstr.48692C2074686572652100 : !cc.ptr<!llvm.array<11 x i8>>
// CHECK: %[[VAL_1:.*]] = cc.cast %[[VAL_0]] : (!cc.ptr<!llvm.array<11 x i8>>) -> !cc.ptr<i8>
// CHECK: %[[VAL_2:.*]] = arith.constant 10 : i64
// CHECK: %[[VAL_0:.*]] = cc.address_of @cstr.58595A00 : !cc.ptr<!llvm.array<4 x i8>>
// CHECK: %[[VAL_1:.*]] = cc.cast %[[VAL_0]] : (!cc.ptr<!llvm.array<4 x i8>>) -> !cc.ptr<i8>
// CHECK: %[[VAL_2:.*]] = arith.constant 3 : i64
// CHECK: %[[VAL_3:.*]] = cc.stdvec_init %[[VAL_1]], %[[VAL_2]] : (!cc.ptr<i8>, i64) -> !cc.charspan
// CHECK: }
// CHECK: llvm.mlir.global private constant @cstr.48692C2074686572652100("Hi, there!\00") {addr_space = 0 : i32}
// CHECK-DAG: llvm.mlir.global private constant @cstr.58595A00("XYZ\00") {addr_space = 0 : i32}
// clang-format on
}

Expand Down Expand Up @@ -194,6 +240,34 @@ void test_vectors(mlir::MLIRContext *ctx) {
// CHECK: %[[VAL_10:.*]] = cc.stdvec_init %[[VAL_0]], %[[VAL_9]] : (!cc.ptr<!cc.array<i32 x 4>>, i64) -> !cc.stdvec<i32>
// CHECK: }
// clang-format on

{
std::vector<cudaq::pauli_word> x = {cudaq::pauli_word{"XX"},
cudaq::pauli_word{"XY"}};
std::vector<void *> v = {static_cast<void *>(&x)};
doSimpleTest(ctx, "!cc.stdvec<!cc.charspan>", v);
}
// clang-format off
// CHECK-LABEL: cc.arg_subst[0] {
// CHECK: %[[VAL_0:.*]] = cc.alloca !cc.array<!cc.charspan x 2>
// CHECK: %[[VAL_1:.*]] = cc.address_of @cstr.585800 : !cc.ptr<!llvm.array<3 x i8>>
// CHECK: %[[VAL_2:.*]] = cc.cast %[[VAL_1]] : (!cc.ptr<!llvm.array<3 x i8>>) -> !cc.ptr<i8>
// CHECK: %[[VAL_3:.*]] = arith.constant 2 : i64
// CHECK: %[[VAL_4:.*]] = cc.stdvec_init %[[VAL_2]], %[[VAL_3]] : (!cc.ptr<i8>, i64) -> !cc.charspan
// CHECK: %[[VAL_5:.*]] = cc.compute_ptr %[[VAL_0]][0] : (!cc.ptr<!cc.array<!cc.charspan x 2>>) -> !cc.ptr<!cc.charspan>
// CHECK: cc.store %[[VAL_4]], %[[VAL_5]] : !cc.ptr<!cc.charspan>
// CHECK: %[[VAL_6:.*]] = cc.address_of @cstr.585900 : !cc.ptr<!llvm.array<3 x i8>>
// CHECK: %[[VAL_7:.*]] = cc.cast %[[VAL_6]] : (!cc.ptr<!llvm.array<3 x i8>>) -> !cc.ptr<i8>
// CHECK: %[[VAL_8:.*]] = arith.constant 2 : i64
// CHECK: %[[VAL_9:.*]] = cc.stdvec_init %[[VAL_7]], %[[VAL_8]] : (!cc.ptr<i8>, i64) -> !cc.charspan
// CHECK: %[[VAL_10:.*]] = cc.compute_ptr %[[VAL_0]][1] : (!cc.ptr<!cc.array<!cc.charspan x 2>>) -> !cc.ptr<!cc.charspan>
// CHECK: cc.store %[[VAL_9:.*]], %[[VAL_10:.*]] : !cc.ptr<!cc.charspan>
// CHECK: %[[VAL_11:.*]] = arith.constant 2 : i64
// CHECK: %[[VAL_12:.*]] = cc.stdvec_init %[[VAL_0]], %[[VAL_11]] : (!cc.ptr<!cc.array<!cc.charspan x 2>>, i64) -> !cc.stdvec<!cc.charspan>
// CHECK: }
// CHECK-DAG: llvm.mlir.global private constant @cstr.585800("XX\00") {addr_space = 0 : i32}
// CHECK-DAG: llvm.mlir.global private constant @cstr.585900("XY\00") {addr_space = 0 : i32}
// clang-format on
}

void test_aggregates(mlir::MLIRContext *ctx) {
Expand Down Expand Up @@ -304,18 +378,147 @@ void test_state(mlir::MLIRContext *ctx) {
// CHECK: func.func private @callee(!cc.ptr<!cc.state>)
// CHECK: Substitution module:

// CHECK-LABEL: cc.arg_subst[0] {
// CHECK: %[[VAL_0:.*]] = cc.address_of @[[VAL_GC:.*]] : !cc.ptr<!cc.array<complex<f64> x 8>>
// CHECK: %[[VAL_1:.*]] = cc.load %[[VAL_0]] : !cc.ptr<!cc.array<complex<f64> x 8>>
// CHECK: %[[VAL_2:.*]] = arith.constant 8 : i64
// CHECK: %[[VAL_3:.*]] = cc.alloca !cc.array<complex<f64> x 8>
// CHECK: cc.store %[[VAL_1]], %[[VAL_3]] : !cc.ptr<!cc.array<complex<f64> x 8>>
// CHECK: %[[VAL_4:.*]] = cc.cast %[[VAL_3]] : (!cc.ptr<!cc.array<complex<f64> x 8>>) -> !cc.ptr<i8>
// CHECK: %[[VAL_5:.*]] = func.call @__nvqpp_cudaq_state_createFromData_fp64(%[[VAL_4]], %[[VAL_2]]) : (!cc.ptr<i8>, i64) -> !cc.ptr<!cc.state>
// CHECK: %[[VAL_6:.*]] = cc.cast %[[VAL_5]] : (!cc.ptr<!cc.state>) -> !cc.ptr<!cc.state>
// CHECK-LABEL: cc.arg_subst[0] {
// CHECK: %[[VAL_0:.*]] = cc.address_of @[[VAL_GC:.*]] : !cc.ptr<!cc.array<complex<f64> x 8>>
// CHECK: %[[VAL_1:.*]] = cc.load %[[VAL_0]] : !cc.ptr<!cc.array<complex<f64> x 8>>
// CHECK: %[[VAL_2:.*]] = arith.constant 8 : i64
// CHECK: %[[VAL_3:.*]] = cc.alloca !cc.array<complex<f64> x 8>
// CHECK: cc.store %[[VAL_1]], %[[VAL_3]] : !cc.ptr<!cc.array<complex<f64> x 8>>
// CHECK: %[[VAL_4:.*]] = cc.cast %[[VAL_3]] : (!cc.ptr<!cc.array<complex<f64> x 8>>) -> !cc.ptr<i8>
// CHECK: %[[VAL_5:.*]] = func.call @__nvqpp_cudaq_state_createFromData_fp64(%[[VAL_4]], %[[VAL_2]]) : (!cc.ptr<i8>, i64) -> !cc.ptr<!cc.state>
// CHECK: %[[VAL_6:.*]] = cc.cast %[[VAL_5]] : (!cc.ptr<!cc.state>) -> !cc.ptr<!cc.state>
// CHECK: }
// CHECK-DAG: cc.global constant @[[VAL_GC]] (dense<[(0.70710678118654757,0.000000e+00), (0.70710678118654757,0.000000e+00), (0.000000e+00,0.000000e+00), (0.000000e+00,0.000000e+00), (0.000000e+00,0.000000e+00), (0.000000e+00,0.000000e+00), (0.000000e+00,0.000000e+00), (0.000000e+00,0.000000e+00)]> : tensor<8xcomplex<f64>>) : !cc.array<complex<f64> x 8>
// CHECK-DAG: func.func private @__nvqpp_cudaq_state_createFromData_fp64(!cc.ptr<i8>, i64) -> !cc.ptr<!cc.state>
// clang-format on
}

void test_combinations(mlir::MLIRContext *ctx) {
{
bool x = true;
std::vector<void *> v = {static_cast<void *>(&x)};
std::vector<std::string> t = {"i1"};
doTest(ctx, t, v);
}
// clang-format off
// CHECK-LABEL: Source module:
// CHECK: func.func private @callee(i1)
// CHECK: Substitution module:

// CHECK-LABEL: cc.arg_subst[0] {
// CHECK: %[[VAL_0:.*]] = arith.constant true
// CHECK: }
// clang-format on

{
bool x = true;
bool y = false;
std::vector<void *> v = {static_cast<void *>(&x), static_cast<void *>(&y)};
std::vector<std::string> t = {"i1", "i1"};
doTest(ctx, t, v);
}
// clang-format off
// CHECK: Source module:
// CHECK: func.func private @callee(i1, i1)
// CHECK: Substitution module:

// CHECK-LABEL: cc.arg_subst[0] {
// CHECK: %[[VAL_0:.*]] = arith.constant true
// CHECK: }
// CHECK-LABEL: cc.arg_subst[1] {
// CHECK: %[[VAL_1:.*]] = arith.constant false
// CHECK: }
// clang-format on

{
bool x = true;
std::int32_t y = 42;
std::vector<void *> v = {static_cast<void *>(&x), static_cast<void *>(&y)};
std::vector<std::string> t = {"i1", "i32"};
doTest(ctx, t, v, 1);
}

// clang-format off
// CHECK: Source module:
// CHECK: func.func private @callee(i1, i32)
// CHECK: Substitution module:

// CHECK-LABEL: cc.arg_subst[1] {
// CHECK: %[[VAL_0:.*]] = arith.constant 42 : i32
// CHECK: }
// clang-format on

{
std::vector<std::complex<double>> data{M_SQRT1_2, M_SQRT1_2, 0., 0.,
0., 0., 0., 0.};

std::vector<double> x = {0.5, 0.6};
cudaq::state y{new FakeSimulationState(data.size(), data.data())};
std::vector<cudaq::pauli_word> z = {
cudaq::pauli_word{"XX"},
cudaq::pauli_word{"XY"},
};

std::vector<void *> v = {static_cast<void *>(&x), static_cast<void *>(&y),
static_cast<void *>(&z)};
std::vector<std::string> t = {"!cc.stdvec<f32>", "!cc.ptr<!cc.state>",
"!cc.stdvec<!cc.charspan>"};
doTest(ctx, t, v);
}

// clang-format off
// CHECK: Source module:
// CHECK: func.func private @callee(!cc.stdvec<f32>, !cc.ptr<!cc.state>, !cc.stdvec<!cc.charspan>)
// CHECK: Substitution module:

// CHECK-LABEL: cc.arg_subst[0] {
// CHECK: %[[VAL_0:.*]] = cc.alloca !cc.array<f32 x 4>
// CHECK: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAL_2:.*]] = cc.compute_ptr %[[VAL_0]][0] : (!cc.ptr<!cc.array<f32 x 4>>) -> !cc.ptr<f32>
// CHECK: cc.store %[[VAL_1]], %[[VAL_2]] : !cc.ptr<f32>
// CHECK: %[[VAL_3:.*]] = arith.constant 1.750000e+00 : f32
// CHECK: %[[VAL_4:.*]] = cc.compute_ptr %[[VAL_0]][1] : (!cc.ptr<!cc.array<f32 x 4>>) -> !cc.ptr<f32>
// CHECK: cc.store %[[VAL_3]], %[[VAL_4]] : !cc.ptr<f32>
// CHECK: %[[VAL_5:.*]] = arith.constant 4.17232506E-8 : f32
// CHECK: %[[VAL_6:.*]] = cc.compute_ptr %[[VAL_0]][2] : (!cc.ptr<!cc.array<f32 x 4>>) -> !cc.ptr<f32>
// CHECK: cc.store %[[VAL_5]], %[[VAL_6]] : !cc.ptr<f32>
// CHECK: %[[VAL_7:.*]] = arith.constant 1.775000e+00 : f32
// CHECK: %[[VAL_8:.*]] = cc.compute_ptr %[[VAL_0]][3] : (!cc.ptr<!cc.array<f32 x 4>>) -> !cc.ptr<f32>
// CHECK: cc.store %[[VAL_7]], %[[VAL_8]] : !cc.ptr<f32>
// CHECK: %[[VAL_9:.*]] = arith.constant 4 : i64
// CHECK: %[[VAL_10:.*]] = cc.stdvec_init %[[VAL_0]], %[[VAL_9]] : (!cc.ptr<!cc.array<f32 x 4>>, i64) -> !cc.stdvec<f32>
// CHECK: }
// CHECK-LABEL: cc.arg_subst[1] {
// CHECK: %[[VAL_0:.*]] = cc.address_of @[[VAL_GC:.*]] : !cc.ptr<!cc.array<complex<f64> x 8>>
// CHECK: %[[VAL_1:.*]] = cc.load %[[VAL_0]] : !cc.ptr<!cc.array<complex<f64> x 8>>
// CHECK: %[[VAL_2:.*]] = arith.constant 8 : i64
// CHECK: %[[VAL_3:.*]] = cc.alloca !cc.array<complex<f64> x 8>
// CHECK: cc.store %[[VAL_1]], %[[VAL_3]] : !cc.ptr<!cc.array<complex<f64> x 8>>
// CHECK: %[[VAL_4:.*]] = cc.cast %[[VAL_3]] : (!cc.ptr<!cc.array<complex<f64> x 8>>) -> !cc.ptr<i8>
// CHECK: %[[VAL_5:.*]] = func.call @__nvqpp_cudaq_state_createFromData_fp64(%[[VAL_4]], %[[VAL_2]]) : (!cc.ptr<i8>, i64) -> !cc.ptr<!cc.state>
// CHECK: %[[VAL_6:.*]] = cc.cast %[[VAL_5]] : (!cc.ptr<!cc.state>) -> !cc.ptr<!cc.state>
// CHECK: }
// CHECK-DAG: cc.global constant @[[VAL_GC]] (dense<[(0.70710678118654757,0.000000e+00), (0.70710678118654757,0.000000e+00), (0.000000e+00,0.000000e+00), (0.000000e+00,0.000000e+00), (0.000000e+00,0.000000e+00), (0.000000e+00,0.000000e+00), (0.000000e+00,0.000000e+00), (0.000000e+00,0.000000e+00)]> : tensor<8xcomplex<f64>>) : !cc.array<complex<f64> x 8>
// CHECK-DAG: func.func private @__nvqpp_cudaq_state_createFromData_fp64(!cc.ptr<i8>, i64) -> !cc.ptr<!cc.state>
// CHECK-LABEL: cc.arg_subst[2] {
// CHECK: %[[VAL_0:.*]] = cc.alloca !cc.array<!cc.charspan x 2>
// CHECK: %[[VAL_1:.*]] = cc.address_of @cstr.585800 : !cc.ptr<!llvm.array<3 x i8>>
// CHECK: %[[VAL_2:.*]] = cc.cast %[[VAL_1]] : (!cc.ptr<!llvm.array<3 x i8>>) -> !cc.ptr<i8>
// CHECK: %[[VAL_3:.*]] = arith.constant 2 : i64
// CHECK: %[[VAL_4:.*]] = cc.stdvec_init %[[VAL_2]], %[[VAL_3]] : (!cc.ptr<i8>, i64) -> !cc.charspan
// CHECK: %[[VAL_5:.*]] = cc.compute_ptr %[[VAL_0]][0] : (!cc.ptr<!cc.array<!cc.charspan x 2>>) -> !cc.ptr<!cc.charspan>
// CHECK: cc.store %[[VAL_4]], %[[VAL_5]] : !cc.ptr<!cc.charspan>
// CHECK: %[[VAL_6:.*]] = cc.address_of @cstr.585900 : !cc.ptr<!llvm.array<3 x i8>>
// CHECK: %[[VAL_7:.*]] = cc.cast %[[VAL_6]] : (!cc.ptr<!llvm.array<3 x i8>>) -> !cc.ptr<i8>
// CHECK: %[[VAL_8:.*]] = arith.constant 2 : i64
// CHECK: %[[VAL_9:.*]] = cc.stdvec_init %[[VAL_7]], %[[VAL_8]] : (!cc.ptr<i8>, i64) -> !cc.charspan
// CHECK: %[[VAL_10:.*]] = cc.compute_ptr %[[VAL_0]][1] : (!cc.ptr<!cc.array<!cc.charspan x 2>>) -> !cc.ptr<!cc.charspan>
// CHECK: cc.store %[[VAL_9]], %[[VAL_10]] : !cc.ptr<!cc.charspan>
// CHECK: %[[VAL_11:.*]] = arith.constant 2 : i64
// CHECK: %[[VAL_12:.*]] = cc.stdvec_init %[[VAL_0]], %[[VAL_11]] : (!cc.ptr<!cc.array<!cc.charspan x 2>>, i64) -> !cc.stdvec<!cc.charspan>
// CHECK: }
// CHECK-DAG: llvm.mlir.global private constant @cstr.585800("XX\00") {addr_space = 0 : i32}
// CHECK-DAG: llvm.mlir.global private constant @cstr.585900("XY\00") {addr_space = 0 : i32}
// clang-format on
}

Expand All @@ -330,5 +533,6 @@ int main() {
test_aggregates(&context);
test_recursive(&context);
test_state(&context);
test_combinations(&context);
return 0;
}
Loading

0 comments on commit e0faa09

Please sign in to comment.