Skip to content

Commit

Permalink
Revert QBdt stabilizer work (putting back in branch)
Browse files Browse the repository at this point in the history
  • Loading branch information
WrathfulSpatula committed Aug 7, 2023
1 parent ca6d38b commit 82aef87
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 563 deletions.
39 changes: 0 additions & 39 deletions include/qbdt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,9 @@

#pragma once

#include "mpsshard.hpp"
#include "qbdt_node.hpp"
#include "qbdt_qstabilizer_node.hpp"
#include "qengine.hpp"

#define NODE_TO_STABILIZER(leaf) (std::dynamic_pointer_cast<QBdtQStabilizerNode>(leaf)->qReg)
#define QINTERFACE_TO_QALU(qReg) std::dynamic_pointer_cast<QAlu>(qReg)
#define QINTERFACE_TO_QPARITY(qReg) std::dynamic_pointer_cast<QParity>(qReg)

Expand All @@ -42,43 +39,7 @@ class QBdt : public QParity, public QInterface {
bitCapInt bdtMaxQPower;
std::vector<int64_t> deviceIDs;
std::vector<QInterfaceEngine> engines;
std::vector<MpsShardPtr> shards;

void DumpBuffers()
{
for (size_t i = 0; i < shards.size(); ++i) {
shards[i] = NULL;
}
}
void FlushBuffers()
{
for (size_t i = 0U; i < shards.size(); ++i) {
const MpsShardPtr shard = shards[i];
if (shard) {
shards[i] = NULL;
ApplySingle(shard->gate, i);
}
}
}

void FlushIfBlocked(bitLenInt target, const std::vector<bitLenInt>& controls = std::vector<bitLenInt>())
{
for (const bitLenInt& control : controls) {
const MpsShardPtr shard = shards[control];
if (shard && !shard->IsPhase()) {
shards[control] = NULL;
ApplySingle(shard->gate, control);
}
}

const MpsShardPtr shard = shards[target];
if (shard) {
shards[target] = NULL;
ApplySingle(shard->gate, target);
}
}

QBdtQStabilizerNodePtr MakeQStabilizerNode(complex scale, bitLenInt qbCount, bitCapInt perm = 0U);
QEnginePtr MakeQEngine(bitLenInt qbCount, bitCapInt perm = 0U);

template <typename Fn> void GetTraversal(Fn getLambda)
Expand Down
19 changes: 0 additions & 19 deletions include/qbdt_node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,25 +76,6 @@ class QBdtNode : public QBdtNodeInterface {
#else
virtual void Apply2x2(complex const* mtrx, bitLenInt depth);
#endif

virtual QBdtNodeInterfacePtr PopSpecial(bitLenInt depth = 1U)
{
if (!depth) {
return shared_from_this();
}

if (norm(scale) <= _qrack_qbdt_sep_thresh) {
SetZero();
return shared_from_this();
}

--depth;

branches[0U] = branches[0U]->PopSpecial(depth);
branches[1U] = branches[1U]->PopSpecial(depth);

return shared_from_this();
}
};

} // namespace Qrack
24 changes: 14 additions & 10 deletions include/qbdt_node_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ namespace Qrack {
class QBdtNodeInterface;
typedef std::shared_ptr<QBdtNodeInterface> QBdtNodeInterfacePtr;

class QBdtNodeInterface : public std::enable_shared_from_this<QBdtNodeInterface> {
class QBdtNodeInterface {
protected:
static size_t SelectBit(bitCapInt perm, bitLenInt bit) { return (size_t)((perm >> bit) & 1U); }
static void _par_for_qbdt(const bitCapInt end, BdtFunc fn);
Expand Down Expand Up @@ -82,8 +82,6 @@ class QBdtNodeInterface : public std::enable_shared_from_this<QBdtNodeInterface>
// Virtual destructor for inheritance
}

virtual bool IsStabilizer() { return false; }

virtual void InsertAtDepth(QBdtNodeInterfacePtr b, bitLenInt depth, const bitLenInt& size, bitLenInt parDepth = 1U)
{
throw std::out_of_range("QBdtNodeInterface::InsertAtDepth() not implemented! (You probably set "
Expand Down Expand Up @@ -148,20 +146,26 @@ class QBdtNodeInterface : public std::enable_shared_from_this<QBdtNodeInterface>
"QRACK_QBDT_SEPARABILITY_THRESHOLD too high.)");
}

#if ENABLE_COMPLEX_X2
virtual void Apply2x2(const complex2& mtrxCol1, const complex2& mtrxCol2, const complex2& mtrxColShuff1,
const complex2& mtrxColShuff2, bitLenInt depth)
{
throw std::out_of_range("QBdtQStabilizerNode::Apply2x2() not implemented!");
}
#else
virtual void Apply2x2(complex const* mtrx, bitLenInt depth)
#endif
{
throw std::out_of_range("QBdtQStabilizerNode::Apply2x2() not implemented!");
throw std::out_of_range("QBdtNodeInterface::Apply2x2() not implemented! (You probably set "
"QRACK_QBDT_SEPARABILITY_THRESHOLD too high.)");
}

virtual QBdtNodeInterfacePtr PopSpecial(bitLenInt depth = 1U)
#if ENABLE_COMPLEX_X2
virtual void PushSpecial(const complex2& mtrxCol1, const complex2& mtrxCol2, const complex2& mtrxColShuff1,
const complex2& mtrxColShuff2, QBdtNodeInterfacePtr& b1)
#else
virtual void PushSpecial(complex const* mtrx, QBdtNodeInterfacePtr& b1)
#endif
{
throw std::out_of_range(
"QBdtNodeInterface::PopSpecial() not implemented! (Check IsStabilizer() before PopSpecial().)");
throw std::out_of_range("QBdtNodeInterface::PushSpecial() not implemented! (You probably called "
"PushStateVector() past terminal depth.)");
}
};

Expand Down
24 changes: 18 additions & 6 deletions src/qbdt/node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -587,15 +587,21 @@ void QBdtNode::PushStateVector(const complex2& mtrxCol1, const complex2& mtrxCol
b0->Branch();
b1->Branch();

b0 = b0->PopSpecial();
b1 = b1->PopSpecial();

// For parallelism, keep shared_ptr from deallocating.
QBdtNodeInterfacePtr b00 = b0->branches[0U];
QBdtNodeInterfacePtr b01 = b0->branches[1U];
QBdtNodeInterfacePtr b10 = b1->branches[0U];
QBdtNodeInterfacePtr b11 = b1->branches[1U];

if (!b00) {
b0->PushSpecial(mtrxCol1, mtrxCol2, mtrxColShuff1, mtrxColShuff2, b1);

b0->PopStateVector();
b1->PopStateVector();

return;
}

if (true) {
std::lock(b00->mtx, b01->mtx);
std::lock_guard<std::mutex> lock0(b00->mtx, std::adopt_lock);
Expand Down Expand Up @@ -738,15 +744,21 @@ void QBdtNode::PushStateVector(
b0->Branch();
b1->Branch();

b0 = b0->PopSpecial();
b1 = b1->PopSpecial();

// For parallelism, keep shared_ptr from deallocating.
QBdtNodeInterfacePtr b00 = b0->branches[0U];
QBdtNodeInterfacePtr b01 = b0->branches[1U];
QBdtNodeInterfacePtr b10 = b1->branches[0U];
QBdtNodeInterfacePtr b11 = b1->branches[1U];

if (!b00) {
b0->PushSpecial(mtrx, b1);

b0->PopStateVector();
b1->PopStateVector();

return;
}

if (true) {
std::lock(b00->mtx, b01->mtx);
std::lock_guard<std::mutex> lock0(b00->mtx, std::adopt_lock);
Expand Down
Loading

0 comments on commit 82aef87

Please sign in to comment.