Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pauli word] Rework the implementation from front to back. #2338

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
3 changes: 3 additions & 0 deletions include/cudaq/Optimizer/Builder/Factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,9 @@ bool hasSRet(mlir::func::FuncOp funcOp);
mlir::FunctionType toHostSideFuncType(mlir::FunctionType funcTy,
bool addThisPtr, mlir::ModuleOp module);

/// Convert device type, \p ty, to host side type.
mlir::Type convertToHostSideType(mlir::Type ty);

// Return `true` if the given type corresponds to a standard vector type
// according to our convention.
// The convention is a `ptr<struct<ptr<T>, ptr<T>, ptr<T>>>`.
Expand Down
5 changes: 5 additions & 0 deletions include/cudaq/Optimizer/Builder/Intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,16 @@ static constexpr const char getCudaqSizeFromTriple[] =
// typically specialized to be bit packed).
static constexpr const char stdvecBoolCtorFromInitList[] =
"__nvqpp_initializer_list_to_vector_bool";

// Convert a (likely packed) std::vector<bool> into a sequence of bytes, each
// holding a boolean value.
static constexpr const char stdvecBoolUnpackToInitList[] =
"__nvqpp_vector_bool_to_initializer_list";

// Free any temporary buffers used to hold std::vector<bool> data.
static constexpr const char stdvecBoolFreeTemporaryLists[] =
"__nvqpp_vector_bool_free_temporary_initlists";

// The internal data of the cudaq::state object must be `2**n` in length. This
// function returns the value `n`.
static constexpr const char getNumQubitsFromCudaqState[] =
Expand Down
5 changes: 5 additions & 0 deletions include/cudaq/Optimizer/Dialect/CC/CCTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,12 @@ def cc_StructType : CCType<"Struct", "struct",
];

let extraClassDeclaration = [{
// O(1)
bool isEmpty() const { return getMembers().empty(); }

// O(n)
std::size_t getNumMembers() const { return getMembers().size(); }

Type getMember(unsigned position) { return getMembers()[position]; }
}];
}
Expand Down
30 changes: 20 additions & 10 deletions lib/Optimizer/Builder/Factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,22 @@ cc::StructType factory::stlVectorType(Type eleTy) {
return cc::StructType::get(ctx, ArrayRef<Type>{ptrTy, ptrTy, ptrTy});
}

// Note that this is the raw host type, where std::vector<bool> is distinct.
// When converting to the device side, the distinction is deliberately removed
// making std::vector<bool> the same format as std::vector<char>.
static cc::StructType stlHostVectorType(Type eleTy) {
MLIRContext *ctx = eleTy.getContext();
if (eleTy != IntegerType::get(ctx, 1)) {
// std::vector<T> where T != bool.
return factory::stlVectorType(eleTy);
}
// std::vector<bool> is a different type than std::vector<T>.
auto ptrTy = cc::PointerType::get(eleTy);
auto i8Ty = IntegerType::get(ctx, 8);
auto padout = cc::ArrayType::get(ctx, i8Ty, 32);
return cc::StructType::get(ctx, ArrayRef<Type>{ptrTy, padout});
}

// FIXME: Give these front-end names so we can disambiguate more types.
cc::StructType factory::getDynamicBufferType(MLIRContext *ctx) {
auto ptrTy = cc::PointerType::get(IntegerType::get(ctx, 8));
Expand All @@ -342,19 +358,13 @@ Type factory::getSRetElementType(FunctionType funcTy) {
return funcTy.getResult(0);
}

static Type convertToHostSideType(Type ty) {
Type factory::convertToHostSideType(Type ty) {
if (auto memrefTy = dyn_cast<cc::StdvecType>(ty))
return convertToHostSideType(
factory::stlVectorType(memrefTy.getElementType()));
return stlHostVectorType(convertToHostSideType(memrefTy.getElementType()));
if (isa<cc::IndirectCallableType>(ty))
return cc::PointerType::get(IntegerType::get(ty.getContext(), 8));
if (auto memrefTy = dyn_cast<cc::CharspanType>(ty)) {
// `pauli_word` is an object with a std::vector in the header files at
// present. This data type *must* be updated if it becomes a std::string
// once again.
return convertToHostSideType(
factory::stlVectorType(IntegerType::get(ty.getContext(), 8)));
}
if (isa<cc::CharspanType>(ty))
return factory::stlStringType(ty.getContext());
auto *ctx = ty.getContext();
if (auto structTy = dyn_cast<cc::StructType>(ty)) {
SmallVector<Type> newMembers;
Expand Down
8 changes: 7 additions & 1 deletion lib/Optimizer/Builder/Intrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,11 +307,17 @@ static constexpr IntrinsicCode intrinsicTable[] = {
return %0 : !cc.ptr<i8>
})#"},

// __nvqpp_vector_bool_free_temporary_lists
{cudaq::stdvecBoolFreeTemporaryLists,
{},
R"#(
func.func private @__nvqpp_vector_bool_free_temporary_initlists(!cc.ptr<i8>) -> ())#"},

// __nvqpp_vector_bool_to_initializer_list
{cudaq::stdvecBoolUnpackToInitList,
{},
R"#(
func.func private @__nvqpp_vector_bool_to_initializer_list(!cc.ptr<!cc.struct<{!cc.ptr<i1>, !cc.ptr<i1>, !cc.ptr<i1>}>>, !cc.ptr<!cc.struct<{!cc.ptr<i1>, !cc.ptr<i1>, !cc.ptr<i1>}>>) -> ())#"},
func.func private @__nvqpp_vector_bool_to_initializer_list(!cc.ptr<!cc.struct<{!cc.ptr<i1>, !cc.ptr<i1>, !cc.ptr<i1>}>>, !cc.ptr<!cc.struct<{!cc.ptr<i1>, !cc.array<i8 x 32>}>>, !cc.ptr<!cc.ptr<i8>>) -> ())#"},

{"__nvqpp_zeroDynamicResult", {}, R"#(
func.func private @__nvqpp_zeroDynamicResult() -> !cc.struct<{!cc.ptr<i8>, i64}> {
Expand Down
2 changes: 1 addition & 1 deletion lib/Optimizer/Dialect/CC/CCTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ Type cc::SpanLikeType::getElementType() const {
}

bool isDynamicType(Type ty) {
if (isa<StdvecType>(ty))
if (isa<SpanLikeType>(ty))
return true;
if (auto strTy = dyn_cast<StructType>(ty)) {
for (auto memTy : strTy.getMembers())
Expand Down
5 changes: 4 additions & 1 deletion lib/Optimizer/Transforms/DecompositionPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,14 +362,17 @@ struct ExpPauliDecomposition : public OpRewritePattern<quake::ExpPauliOp> {
auto strAttr = cast<mlir::StringAttr>(attr.value());
optPauliWordStr = strAttr.getValue();
}
} else if (auto lit = addrOp.getDefiningOp<
cudaq::cc::CreateStringLiteralOp>()) {
optPauliWordStr = lit.getStringLiteral();
}
}
}
}

// Assert that we have a constant known pauli word
if (!optPauliWordStr.has_value())
return failure();
return expPauliOp.emitOpError("cannot determine pauli word string");

auto pauliWordStr = optPauliWordStr.value();

Expand Down
Loading
Loading