Skip to content

Commit

Permalink
Optimize TryDecompose()
Browse files Browse the repository at this point in the history
  • Loading branch information
WrathfulSpatula committed Nov 3, 2024
1 parent cbd2f7e commit ee5eb69
Show file tree
Hide file tree
Showing 17 changed files with 207 additions and 10 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Minimum CMake version required to configure this project.
cmake_minimum_required (VERSION 3.9)
# NOTE(review): two project() declarations follow that differ only in VERSION (9.12.3 vs 9.12.4).
# This looks like the before/after sides of a version bump - confirm that only one is intended.
project (Qrack VERSION 9.12.3 DESCRIPTION "High Performance Quantum Bit Simulation" LANGUAGES CXX)
project (Qrack VERSION 9.12.4 DESCRIPTION "High Performance Quantum Bit Simulation" LANGUAGES CXX)

# Installation commands
include (GNUInstallDirs)
Expand Down
13 changes: 13 additions & 0 deletions include/qbdt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,19 @@ class QBdt : public QParity, public QInterface {
return sample;
}

/// Generic entry point: downcast the source to QBdt and forward to the typed overload.
/// A source that is not a QBdt produces an empty pointer from the cast.
void Copy(QInterfacePtr orig)
{
    auto bdtOrig = std::dynamic_pointer_cast<QBdt>(orig);
    Copy(bdtOrig);
}
/// Overwrite this QBdt's fields with those of `orig`, by member-wise assignment.
/// @param orig Source QBdt; must be non-null, since it is dereferenced.
void Copy(QBdtPtr orig)
{
    // Base-class fields first.
    QInterface::Copy(orig);

    // QBdt-specific fields, assigned one by one from the source.
    QBdt& other = *orig;
    bdtStride = other.bdtStride;
    devID = other.devID;
    root = other.root;
    bdtMaxQPower = other.bdtMaxQPower;
    deviceIDs = other.deviceIDs;
    engines = other.engines;
    shards = other.shards;
}

public:
QBdt(std::vector<QInterfaceEngine> eng, bitLenInt qBitCount, const bitCapInt& initState = ZERO_BCI,
qrack_rand_gen_ptr rgp = nullptr, const complex& phaseFac = CMPLX_DEFAULT_ARG, bool doNorm = false,
Expand Down
11 changes: 11 additions & 0 deletions include/qengine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,15 @@ class QEngine : public QParity, public QInterface {

void EitherMtrx(const std::vector<bitLenInt>& controls, const complex* mtrx, bitLenInt target, bool isAnti);

/// Generic entry point: downcast the source to QEngine and forward to the typed overload.
virtual void Copy(QInterfacePtr orig)
{
    auto engineOrig = std::dynamic_pointer_cast<QEngine>(orig);
    Copy(engineOrig);
}
/// Overwrite the QEngine-level fields of this object with those of `orig`.
/// @param orig Source engine; must be non-null, since it is dereferenced.
virtual void Copy(QEnginePtr orig)
{
    // Base-class fields first.
    QInterface::Copy(orig);

    QEngine& other = *orig;
    useHostRam = other.useHostRam;
    runningNorm = other.runningNorm;
    maxQPowerOcl = other.maxQPowerOcl;
}

public:
QEngine(bitLenInt qBitCount, qrack_rand_gen_ptr rgp = nullptr, bool doNorm = false, bool randomGlobalPhase = true,
bool useHostMem = false, bool useHardwareRNG = true, real1_f norm_thresh = REAL1_EPSILON)
Expand Down Expand Up @@ -95,6 +104,8 @@ class QEngine : public QParity, public QInterface {
// Virtual destructor for inheritance
}

using QInterface::Copy;

virtual void SetQubitCount(bitLenInt qb)
{
QInterface::SetQubitCount(qb);
Expand Down
7 changes: 7 additions & 0 deletions include/qengine_cpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ class QEngineCPU : public QEngine {
DispatchQueue dispatchQueue;
#endif

/// Generic entry point: downcast the source to QEngineCPU and forward to the typed overload.
void Copy(QInterfacePtr orig)
{
    auto cpuOrig = std::dynamic_pointer_cast<QEngineCPU>(orig);
    Copy(cpuOrig);
}
/// Overwrite this engine's fields with those of `orig`: QEngine-level state, then `stateVec`.
/// @param orig Source engine; must be non-null, since it is dereferenced.
void Copy(QEngineCPUPtr orig)
{
    // Hand the QEngine view of the source to the base-class copy.
    auto baseOrig = std::dynamic_pointer_cast<QEngine>(orig);
    QEngine::Copy(baseOrig);

    // The state vector handle is assigned directly from the source.
    stateVec = orig->stateVec;
}

public:
QEngineCPU(bitLenInt qBitCount, const bitCapInt& initState, qrack_rand_gen_ptr rgp = nullptr,
const complex& phaseFac = CMPLX_DEFAULT_ARG, bool doNorm = false, bool randomGlobalPhase = true,
Expand Down
24 changes: 24 additions & 0 deletions include/qengine_cuda.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,29 @@ class QEngineCUDA : public QEngine {
throw std::runtime_error(message + ", error code: " + std::to_string(error));
}

void Copy(QInterfacePtr orig) {Copy(std::dynamic_pointer_cast<QEngineOCL>(orig); }
void Copy(QEngineOCLPtr orig)
{
didInit = orig->didInit;
usingHostRam = orig->usingHostRam;
unlockHostMem = orig->unlockHostMem;
nrmGroupCount = orig->nrmGroupCount;
nrmGroupSize = orig->nrmGroupSize;
AddAlloc(orig->totalOclAllocSize);
deviceID = orig->deviceID;
lockSyncFlags = orig->lockSyncFlags;
permutationAmp = orig->permutationAmp;
stateVec = orig->stateVec;
// queue_mutex = orig->queue_mutex;
stateBuffer = orig->stateBuffer;
nrmBuffer = orig->nrmBuffer;
device_context = orig->device_context;
wait_refs = orig->wait_refs;
wait_queue_items = orig->wait_queue_items;
poolItems = orig->poolItems;
nrmArray = orig->nrmArray;
}

public:
/// 1 / OclMemDenom is the maximum fraction of total OCL device RAM that a single state vector should occupy, by
/// design of the QEngine.
Expand Down Expand Up @@ -451,6 +474,7 @@ class QEngineCUDA : public QEngine {
/// Reports whether `wait_queue_items` is currently empty.
bool isFinished()
{
    const bool queueIsEmpty = wait_queue_items.empty();
    return queueIsEmpty;
}

QInterfacePtr Clone();
QInterfacePtr Copy();

void PopQueue();
void DispatchQueue();
Expand Down
26 changes: 26 additions & 0 deletions include/qengine_opencl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,32 @@ class QEngineOCL : public QEngine {
throw std::runtime_error(message + ", error code: " + std::to_string(error));
}

/// Generic entry point: downcast the source to QEngineOCL and forward to the typed overload.
void Copy(QInterfacePtr orig)
{
    auto oclOrig = std::dynamic_pointer_cast<QEngineOCL>(orig);
    Copy(oclOrig);
}
/// Overwrite this engine's fields with those of `orig`, by member-wise assignment.
/// @param orig Source engine; must be non-null, since it is dereferenced.
void Copy(QEngineOCLPtr orig)
{
    // QEngine-level state first.
    auto baseOrig = std::dynamic_pointer_cast<QEngine>(orig);
    QEngine::Copy(baseOrig);

    QEngineOCL& other = *orig;

    // Flags and work-group sizing.
    didInit = other.didInit;
    usingHostRam = other.usingHostRam;
    unlockHostMem = other.unlockHostMem;
    callbackError = other.callbackError;
    nrmGroupCount = other.nrmGroupCount;
    nrmGroupSize = other.nrmGroupSize;

    // Called with the source's total allocation size; deliberately kept ahead of the deviceID
    // assignment so the statement order matches the original.
    AddAlloc(other.totalOclAllocSize);
    deviceID = other.deviceID;
    lockSyncFlags = other.lockSyncFlags;
    permutationAmp = other.permutationAmp;
    stateVec = other.stateVec;

    // queue_mutex is not copied (its assignment was left commented out).
    queue = other.queue;
    context = other.context;
    stateBuffer = other.stateBuffer;
    nrmBuffer = other.nrmBuffer;
    device_context = other.device_context;
    wait_refs = other.wait_refs;
    wait_queue_items = other.wait_queue_items;
    poolItems = other.poolItems;
}

public:
/// 1 / OclMemDenom is the maximum fraction of total OCL device RAM that a single state vector should occupy, by
/// design of the QEngine.
Expand Down
17 changes: 17 additions & 0 deletions include/qhybrid.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,23 @@ class QHybrid : public QEngine {
complex phaseFactor;
std::vector<int64_t> deviceIDs;

/// Generic entry point: downcast the source to QHybrid and forward to the typed overload.
void Copy(QInterfacePtr orig)
{
    auto hybridOrig = std::dynamic_pointer_cast<QHybrid>(orig);
    Copy(hybridOrig);
}
/// Overwrite this QHybrid's fields with those of `orig`, by member-wise assignment.
/// @param orig Source simulator; must be non-null, since it is dereferenced.
void Copy(QHybridPtr orig)
{
    // QEngine-level state first.
    auto baseOrig = std::dynamic_pointer_cast<QEngine>(orig);
    QEngine::Copy(baseOrig);

    QHybrid& other = *orig;

    // Mode flags.
    isGpu = other.isGpu;
    isPager = other.isPager;
    useRDRAND = other.useRDRAND;
    isSparse = other.isSparse;

    // Thresholds.
    gpuThresholdQubits = other.gpuThresholdQubits;
    pagerThresholdQubits = other.pagerThresholdQubits;
    separabilityThreshold = other.separabilityThreshold;

    // Device selection, wrapped engine and remaining settings.
    devID = other.devID;
    engine = other.engine;
    phaseFactor = other.phaseFactor;
    deviceIDs = other.deviceIDs;
}

public:
QHybrid(bitLenInt qBitCount, const bitCapInt& initState = ZERO_BCI, qrack_rand_gen_ptr rgp = nullptr,
const complex& phaseFac = CMPLX_DEFAULT_ARG, bool doNorm = false, bool randomGlobalPhase = true,
Expand Down
15 changes: 15 additions & 0 deletions include/qinterface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,21 @@ class QInterface : public ParallelFor {
return isExp ? ExpectationBitsFactorized(bits, perms, offset) : VarianceBitsFactorized(bits, perms, offset);
}

/// Overwrite the QInterface-level fields of this object with those of `orig`.
/// @param orig Source simulator; must be non-null, since it is dereferenced.
virtual void Copy(QInterfacePtr orig)
{
    // Finish() is called on the source before any of its fields are read.
    orig->Finish();

    QInterface& other = *orig;

    // Configuration flags.
    doNormalize = other.doNormalize;
    randGlobalPhase = other.randGlobalPhase;
    useRDRAND = other.useRDRAND;

    // Size and numeric settings.
    qubitCount = other.qubitCount;
    randomSeed = other.randomSeed;
    amplitudeFloor = other.amplitudeFloor;
    maxQPower = other.maxQPower;

    // Random number sources.
    rand_generator = other.rand_generator;
    rand_distribution = other.rand_distribution;
    hardware_rand_generator = other.hardware_rand_generator;
}

public:
QInterface(bitLenInt n, qrack_rand_gen_ptr rgp = nullptr, bool doNorm = false, bool useHardwareRNG = true,
bool randomGlobalPhase = true, real1_f norm_thresh = REAL1_EPSILON);
Expand Down
10 changes: 10 additions & 0 deletions include/qinterface_noisy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,16 @@ class QInterfaceNoisy : public QInterface {
}
}

/// Generic entry point: downcast the source to QInterfaceNoisy and forward to the typed overload.
void Copy(QInterfacePtr orig)
{
    auto noisyOrig = std::dynamic_pointer_cast<QInterfaceNoisy>(orig);
    Copy(noisyOrig);
}
/// Overwrite this QInterfaceNoisy's fields with those of `orig`, by member-wise assignment.
/// @param orig Source simulator; must be non-null, since it is dereferenced.
void Copy(QInterfaceNoisyPtr orig)
{
    // Base-class fields first.
    QInterface::Copy(orig);

    QInterfaceNoisy& other = *orig;
    logFidelity = other.logFidelity;
    noiseParam = other.noiseParam;
    engine = other.engine;
    engines = other.engines;
}

public:
QInterfaceNoisy(bitLenInt qBitCount, const bitCapInt& initState = ZERO_BCI, qrack_rand_gen_ptr rgp = nullptr,
const complex& phaseFac = CMPLX_DEFAULT_ARG, bool doNorm = false, bool randomGlobalPhase = true,
Expand Down
23 changes: 23 additions & 0 deletions include/qpager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,29 @@ class QPager : public QEngine, public std::enable_shared_from_this<QPager> {

real1_f ExpVarBitsAll(bool isExp, const std::vector<bitLenInt>& bits, const bitCapInt& offset = ZERO_BCI);

/// Generic entry point: downcast the source to QPager and forward to the typed overload.
void Copy(QInterfacePtr orig)
{
    auto pagerOrig = std::dynamic_pointer_cast<QPager>(orig);
    Copy(pagerOrig);
}
/// Overwrite this QPager's fields with those of `orig`, by member-wise assignment.
/// @param orig Source pager; must be non-null, since it is dereferenced.
void Copy(QPagerPtr orig)
{
    // QEngine-level state first.
    auto baseOrig = std::dynamic_pointer_cast<QEngine>(orig);
    QEngine::Copy(baseOrig);

    QPager& other = *orig;

    // Mode flags.
    useGpuThreshold = other.useGpuThreshold;
    isSparse = other.isSparse;
    useTGadget = other.useTGadget;

    // Paging limits and sizes.
    maxPageSetting = other.maxPageSetting;
    maxPageQubits = other.maxPageQubits;
    thresholdQubitsPerPage = other.thresholdQubitsPerPage;
    baseQubitsPerPage = other.baseQubitsPerPage;
    maxQubits = other.maxQubits;

    // Device and engine selection.
    devID = other.devID;
    rootEngine = other.rootEngine;
    basePageMaxQPower = other.basePageMaxQPower;
    basePageCount = other.basePageCount;
    phaseFactor = other.phaseFactor;

    // Containers, including the pages themselves.
    devicesHostPointer = other.devicesHostPointer;
    deviceIDs = other.deviceIDs;
    engines = other.engines;
    qPages = other.qPages;
}

public:
QPager(std::vector<QInterfaceEngine> eng, bitLenInt qBitCount, const bitCapInt& initState = ZERO_BCI,
qrack_rand_gen_ptr rgp = nullptr, const complex& phaseFac = CMPLX_DEFAULT_ARG, bool doNorm = false,
Expand Down
6 changes: 6 additions & 0 deletions include/qstabilizer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ class QStabilizer : public QInterface {
}
}

/// Copy() is not supported by QStabilizer; this override always throws.
/// The message previously named QStabilizerHybrid (a copy-paste slip); it now names this class.
/// @param orig Unused.
/// @throws std::domain_error Always.
void Copy(QInterfacePtr orig)
{
    throw std::domain_error("Can't TryDecompose() on QStabilizer! (If you know the system is exactly "
                            "separable, just use Decompose() instead.)");
}

public:
QStabilizer(bitLenInt n, const bitCapInt& perm = ZERO_BCI, qrack_rand_gen_ptr rgp = nullptr,
const complex& phaseFac = CMPLX_DEFAULT_ARG, bool doNorm = false, bool randomGlobalPhase = true,
Expand Down
5 changes: 5 additions & 0 deletions include/qstabilizerhybrid.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,11 @@ class QStabilizerHybrid : public QParity, public QInterface {
complex GetAmplitudeOrProb(const bitCapInt& perm, bool isProb = false);

QInterfacePtr CloneBody(bool isCopy);
/// Copy() is not supported by QStabilizerHybrid; this override always throws.
/// @param orig Unused.
/// @throws std::domain_error Always.
void Copy(QInterfacePtr orig)
{
    static const char* const unsupportedMsg =
        "Can't TryDecompose() on QStabilizerHybrid! (If you know the system is exactly "
        "separable, just use Decompose() instead.)";
    throw std::domain_error(unsupportedMsg);
}

public:
QStabilizerHybrid(std::vector<QInterfaceEngine> eng, bitLenInt qBitCount, const bitCapInt& initState = ZERO_BCI,
Expand Down
6 changes: 6 additions & 0 deletions include/qtensornetwork.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,12 @@ class QTensorNetwork : public QInterface {
}
}

/// Copy() is not supported by QTensorNetwork; this override always throws.
/// @param orig Unused.
/// @throws std::domain_error Always.
void Copy(QInterfacePtr orig)
{
    static const char* const unsupportedMsg =
        "Can't TryDecompose() on QTensorNetwork! (QTensorNetwork does not allow Schmidt "
        "decomposition in general!)";
    throw std::domain_error(unsupportedMsg);
}

public:
QTensorNetwork(std::vector<QInterfaceEngine> eng, bitLenInt qBitCount, const bitCapInt& initState = ZERO_BCI,
qrack_rand_gen_ptr rgp = nullptr, const complex& phaseFac = CMPLX_DEFAULT_ARG, bool doNorm = false,
Expand Down
20 changes: 20 additions & 0 deletions include/qunit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,26 @@ class QUnit : public QParity, public QInterface {

QInterfacePtr MakeEngine(bitLenInt length, const bitCapInt& perm);

/// Generic entry point: downcast the source to QUnit and forward to the typed overload.
void Copy(QInterfacePtr orig)
{
    auto unitOrig = std::dynamic_pointer_cast<QUnit>(orig);
    Copy(unitOrig);
}
/// Overwrite this QUnit's fields with those of `orig`, by member-wise assignment.
/// @param orig Source QUnit; must be non-null, since it is dereferenced.
void Copy(QUnitPtr orig)
{
    // Base-class fields first.
    QInterface::Copy(orig);

    QUnit& other = *orig;

    // Behaviour flags.
    freezeBasis2Qb = other.freezeBasis2Qb;
    useHostRam = other.useHostRam;
    isSparse = other.isSparse;
    isReactiveSeparate = other.isReactiveSeparate;
    useTGadget = other.useTGadget;

    // Thresholds and fidelity bookkeeping.
    thresholdQubits = other.thresholdQubits;
    separabilityThreshold = other.separabilityThreshold;
    roundingThreshold = other.roundingThreshold;
    logFidelity = other.logFidelity;

    // Device selection, phase, and containers.
    devID = other.devID;
    phaseFactor = other.phaseFactor;
    shards = other.shards;
    deviceIDs = other.deviceIDs;
    engines = other.engines;
}

public:
QUnit(std::vector<QInterfaceEngine> eng, bitLenInt qBitCount, const bitCapInt& initState = ZERO_BCI,
qrack_rand_gen_ptr rgp = nullptr, const complex& phaseFac = CMPLX_DEFAULT_ARG, bool doNorm = false,
Expand Down
6 changes: 6 additions & 0 deletions include/qunitclifford.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ class QUnitClifford : public QInterface {
complex phaseOffset;
std::vector<CliffordShard> shards;

/// Copy() is not supported by QUnitClifford; this override always throws.
/// @param orig Unused.
/// @throws std::domain_error Always.
void Copy(QInterfacePtr orig)
{
    static const char* const unsupportedMsg =
        "Can't TryDecompose() on QUnitClifford! (If you know the system is exactly separable, "
        "just use Decompose() instead.)";
    throw std::domain_error(unsupportedMsg);
}

void CombinePhaseOffsets(QStabilizerPtr unit)
{
if (randGlobalPhase) {
Expand Down
11 changes: 11 additions & 0 deletions src/qengine/cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3112,6 +3112,17 @@ QEnginePtr QEngineCUDA::CloneEmpty()
return copyPtr;
}

/// Build a new engine via CloneEmpty() and point its `stateVec` and `stateBuffer` at this
/// engine's own, then return it.
QInterfacePtr QEngineCUDA::Copy()
{
    QEnginePtr emptyClone = CloneEmpty();
    QEngineCUDAPtr aliasEngine = std::dynamic_pointer_cast<QEngineCUDA>(emptyClone);

    // The new engine takes this engine's state handles rather than fresh ones.
    aliasEngine->stateVec = stateVec;
    aliasEngine->stateBuffer = stateBuffer;

    // TODO: This is a hack for TryDecompose():
    // Note that AddAlloc() is invoked on *this* engine, not on the new one.
    AddAlloc(sizeof(complex) * maxQPowerOcl);

    return aliasEngine;
}

void QEngineCUDA::NormalizeState(real1_f nrm, real1_f norm_thresh, real1_f phaseArg)
{
CHECK_ZERO_SKIP();
Expand Down
15 changes: 6 additions & 9 deletions src/qinterface/qinterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -828,19 +828,16 @@ bool QInterface::TryDecompose(bitLenInt start, QInterfacePtr dest, real1_f error
{
// Call Finish() on this simulator before any copies are taken.
Finish();

// NOTE(review): this listing appears to interleave the removed and added sides of a diff
// (for instance, `didSeparate` is declared twice below) - confirm which statements are live.
const bool tempDoNorm = doNormalize;
doNormalize = false;
QInterfacePtr unitCopy = Copy();
doNormalize = tempDoNorm;
// Work on a copy: split `dest` out of it, snapshot the remainder as `output`, then recompose.
QInterfacePtr orig = Copy();
orig->Decompose(start, dest);
QInterfacePtr output = orig->Copy();
orig->Compose(dest, start);

unitCopy->Decompose(start, dest);
unitCopy->Compose(dest, start);

const bool didSeparate = ApproxCompare(unitCopy, error_tol);
// Compare this simulator against the decomposed-then-recomposed copy, within error_tol.
const bool didSeparate = ApproxCompare(orig, error_tol);

if (didSeparate) {
// The subsystem is separable.
Dispose(start, dest->GetQubitCount());
// Overwrite this simulator's fields from the snapshot taken right after Decompose().
Copy(output);
}

// True when the comparison above succeeded; the separated state was adopted in that case.
return didSeparate;
}
Expand Down

0 comments on commit ee5eb69

Please sign in to comment.