Skip to content

Commit

Permalink
Remove QBdtQEngineNode (#995)
Browse files Browse the repository at this point in the history
* Cut QBdtQEngineNode

* Revert "Optimize QBdt::MCInvert()"

This reverts commit 5e87c00.

* Remove bdtQubitCount
  • Loading branch information
WrathfulSpatula authored Aug 6, 2023
1 parent 0e4d54e commit 85cc9bb
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 863 deletions.
1 change: 0 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,6 @@ install (FILES
include/qbdt.hpp
include/qbdt_node.hpp
include/qbdt_node_interface.hpp
include/qbdt_qengine_node.hpp
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/qrack
)

Expand Down
1 change: 0 additions & 1 deletion cmake/Qbdt.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ if (ENABLE_QBDT)
target_sources (qrack PRIVATE
src/qbdt/node_interface.cpp
src/qbdt/node.cpp
src/qbdt/qengine_node.cpp
src/qbdt/tree.cpp
)
endif (ENABLE_QBDT)
172 changes: 73 additions & 99 deletions include/qbdt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@

#pragma once

#include "qbdt_qengine_node.hpp"
#include "qbdt_node.hpp"
#include "qengine.hpp"

#define NODE_TO_QENGINE(leaf) (std::dynamic_pointer_cast<QBdtQEngineNode>(leaf)->qReg)
#define QINTERFACE_TO_QALU(qReg) std::dynamic_pointer_cast<QAlu>(qReg)
#define QINTERFACE_TO_QPARITY(qReg) std::dynamic_pointer_cast<QParity>(qReg)

Expand All @@ -34,64 +33,68 @@ class QBdt : public QAlu, public QParity, public QInterface {
class QBdt : public QParity, public QInterface {
#endif
protected:
bitLenInt attachedQubitCount;
bitLenInt bdtQubitCount;
bitLenInt maxPageQubits;
bitLenInt bdtStride;
int64_t devID;
QBdtNodeInterfacePtr root;
bitCapInt bdtMaxQPower;
std::vector<int64_t> deviceIDs;
std::vector<QInterfaceEngine> engines;

void SetQubitCount(bitLenInt qb, bitLenInt aqb)
{
attachedQubitCount = aqb;
SetQubitCount(qb);
}
QEnginePtr MakeQEngine(bitLenInt qbCount, bitCapInt perm = 0U);

void SetQubitCount(bitLenInt qb)
template <typename Fn> void GetTraversal(Fn getLambda)
{
QInterface::SetQubitCount(qb);
bdtQubitCount = qubitCount - attachedQubitCount;
bdtMaxQPower = pow2(bdtQubitCount);
Finish();

for (bitCapInt i = 0U; i < maxQPower; ++i) {
QBdtNodeInterfacePtr leaf = root;
complex scale = leaf->scale;
for (bitLenInt j = 0U; j < qubitCount; ++j) {
if (norm(leaf->scale) <= _qrack_qbdt_sep_thresh) {
break;
}
leaf = leaf->branches[SelectBit(i, j)];
scale *= leaf->scale;
}

getLambda((bitCapIntOcl)i, scale);
}
}
template <typename Fn> void SetTraversal(Fn setLambda)
{
Dump();

QBdtQEngineNodePtr MakeQEngineNode(complex scale, bitLenInt qbCount, bitCapInt perm = 0U);
root = std::make_shared<QBdtNode>();
root->Branch(qubitCount);

QInterfacePtr MakeTempStateVector()
{
QInterfacePtr copyPtr = NODE_TO_QENGINE(MakeQEngineNode(ONE_R1, qubitCount));
Finish();
GetQuantumState(copyPtr);
_par_for(maxQPower, [&](const bitCapInt& i, const unsigned& cpu) {
QBdtNodeInterfacePtr prevLeaf = root;
QBdtNodeInterfacePtr leaf = root;
for (bitLenInt j = 0U; j < qubitCount; ++j) {
prevLeaf = leaf;
leaf = leaf->branches[SelectBit(i, j)];
}

// If the calling function fully deferences our return, it's automatically freed.
return copyPtr;
}
setLambda((bitCapIntOcl)i, leaf);
});

template <typename Fn> void GetTraversal(Fn getLambda);
template <typename Fn> void SetTraversal(Fn setLambda);
root->PopStateVector(qubitCount);
root->Prune(qubitCount);
}
template <typename Fn> void ExecuteAsStateVector(Fn operation)
{
if (!bdtQubitCount) {
operation(NODE_TO_QENGINE(root));
return;
}

SetStateVector();
operation(NODE_TO_QENGINE(root));
ResetStateVector();
QInterfacePtr qReg = MakeQEngine(qubitCount);
GetQuantumState(qReg);
operation(qReg);
SetQuantumState(qReg);
}

template <typename Fn> bitCapInt BitCapIntAsStateVector(Fn operation)
{
if (!bdtQubitCount) {
return operation(NODE_TO_QENGINE(root));
}

SetStateVector();
bitCapInt toRet = operation(NODE_TO_QENGINE(root));
ResetStateVector();
QInterfacePtr qReg = MakeQEngine(qubitCount);
GetQuantumState(qReg);
const bitCapInt toRet = operation(qReg);
SetQuantumState(qReg);

return toRet;
}
Expand Down Expand Up @@ -133,29 +136,9 @@ class QBdt : public QParity, public QInterface {
{
}

QBdt(QEnginePtr enginePtr, std::vector<QInterfaceEngine> eng, bitLenInt qBitCount, bitCapInt ignored = 0U,
qrack_rand_gen_ptr rgp = nullptr, complex phaseFac = CMPLX_DEFAULT_ARG, bool doNorm = false,
bool randomGlobalPhase = true, bool useHostMem = false, int64_t deviceId = -1, bool useHardwareRNG = true,
bool useSparseStateVec = false, real1_f norm_thresh = REAL1_EPSILON, std::vector<int64_t> devList = {},
bitLenInt qubitThreshold = 0U, real1_f separation_thresh = FP_NORM_EPSILON_F);

QEnginePtr ReleaseEngine()
{
if (bdtQubitCount) {
throw std::domain_error("Cannot release QEngine from QBdt with BDT qubits!");
}

return NODE_TO_QENGINE(root);
}

void LockEngine(QEnginePtr eng) { root = std::make_shared<QBdtQEngineNode>(ONE_CMPLX, eng); }

bool isBinaryDecisionTree() { return true; };

void SetStateVector();
void ResetStateVector(bitLenInt aqb = 0U);

void SetDevice(int64_t dID);
void SetDevice(int64_t dID) { devID = dID; }

void UpdateRunningNorm(real1_f norm_thresh = REAL1_DEFAULT_ARG)
{
Expand All @@ -165,7 +148,7 @@ class QBdt : public QParity, public QInterface {
void NormalizeState(
real1_f nrm = REAL1_DEFAULT_ARG, real1_f norm_thresh = REAL1_DEFAULT_ARG, real1_f phaseArg = ZERO_R1_F)
{
root->Normalize(bdtQubitCount);
root->Normalize(qubitCount);
}

real1_f SumSqrDiff(QInterfacePtr toCompare) { return SumSqrDiff(std::dynamic_pointer_cast<QBdt>(toCompare)); }
Expand All @@ -175,11 +158,26 @@ class QBdt : public QParity, public QInterface {

QInterfacePtr Clone();

void GetQuantumState(complex* state);
void GetQuantumState(QInterfacePtr eng);
void SetQuantumState(const complex* state);
void SetQuantumState(QInterfacePtr eng);
void GetProbs(real1* outputProbs);
void GetQuantumState(complex* state)
{
GetTraversal([state](bitCapIntOcl i, complex scale) { state[i] = scale; });
}
void GetQuantumState(QInterfacePtr eng)
{
GetTraversal([eng](bitCapIntOcl i, complex scale) { eng->SetAmplitude(i, scale); });
}
void SetQuantumState(const complex* state)
{
SetTraversal([state](bitCapIntOcl i, QBdtNodeInterfacePtr leaf) { leaf->scale = state[i]; });
}
void SetQuantumState(QInterfacePtr eng)
{
SetTraversal([eng](bitCapIntOcl i, QBdtNodeInterfacePtr leaf) { leaf->scale = eng->GetAmplitude(i); });
}
void GetProbs(real1* outputProbs)
{
GetTraversal([outputProbs](bitCapIntOcl i, complex scale) { outputProbs[i] = norm(scale); });
}

complex GetAmplitude(bitCapInt perm);
void SetAmplitude(bitCapInt perm, complex amp)
Expand All @@ -196,41 +194,14 @@ class QBdt : public QParity, public QInterface {
void Decompose(bitLenInt start, QInterfacePtr dest)
{
QBdtPtr d = std::dynamic_pointer_cast<QBdt>(dest);
if (!bdtQubitCount) {
d->root = d->MakeQEngineNode(ONE_CMPLX, d->qubitCount, 0U);
NODE_TO_QENGINE(root)->Decompose(start, NODE_TO_QENGINE(d->root));
d->SetQubitCount(d->qubitCount, d->qubitCount);
SetQubitCount(qubitCount - d->qubitCount, qubitCount - d->qubitCount);

return;
}

DecomposeDispose(start, dest->GetQubitCount(), d);
}
QInterfacePtr Decompose(bitLenInt start, bitLenInt length);
void Dispose(bitLenInt start, bitLenInt length)
{
if (!bdtQubitCount) {
NODE_TO_QENGINE(root)->Dispose(start, length);
SetQubitCount(qubitCount - length, qubitCount - length);

return;
}

DecomposeDispose(start, length, NULL);
}
void Dispose(bitLenInt start, bitLenInt length) { DecomposeDispose(start, length, NULL); }

void Dispose(bitLenInt start, bitLenInt length, bitCapInt disposedPerm)
{
if (!bdtQubitCount) {
NODE_TO_QENGINE(root)->Dispose(start, length, disposedPerm);
SetQubitCount(qubitCount - length, qubitCount - length);

return;
}

ForceMReg(start, length, disposedPerm);

DecomposeDispose(start, length, NULL);
}

Expand Down Expand Up @@ -262,8 +233,8 @@ class QBdt : public QParity, public QInterface {
}

real1_f toRet;
ExecuteAsStateVector(
[&](QInterfacePtr eng) { toRet = QINTERFACE_TO_QPARITY(NODE_TO_QENGINE(root))->ProbParity(mask); });
ExecuteAsStateVector([&](QInterfacePtr eng) { toRet = QINTERFACE_TO_QPARITY(eng)->ProbParity(mask); });

return toRet;
}
void CUniformParityRZ(const std::vector<bitLenInt>& controls, bitCapInt mask, real1_f angle)
Expand All @@ -283,8 +254,11 @@ class QBdt : public QParity, public QInterface {
return ForceM(log2(mask), result, doForce);
}

SetStateVector();
return QINTERFACE_TO_QPARITY(NODE_TO_QENGINE(root))->ForceMParity(mask, result, doForce);
bool toRet;
ExecuteAsStateVector(
[&](QInterfacePtr eng) { toRet = QINTERFACE_TO_QPARITY(eng)->ForceMParity(mask, result, doForce); });

return toRet;
}

#if ENABLE_ALU
Expand Down
105 changes: 0 additions & 105 deletions include/qbdt_qengine_node.hpp

This file was deleted.

Loading

0 comments on commit 85cc9bb

Please sign in to comment.