diff --git a/CMakeLists.txt b/CMakeLists.txt index 3d9f41bdc..3bf2f4f5c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,5 +1,5 @@ cmake_minimum_required (VERSION 3.9) -project (Qrack VERSION 9.12.3 DESCRIPTION "High Performance Quantum Bit Simulation" LANGUAGES CXX) +project (Qrack VERSION 9.12.4 DESCRIPTION "High Performance Quantum Bit Simulation" LANGUAGES CXX) # Installation commands include (GNUInstallDirs) diff --git a/include/qbdt.hpp b/include/qbdt.hpp index 7a7c436f5..6e8a59931 100644 --- a/include/qbdt.hpp +++ b/include/qbdt.hpp @@ -186,6 +186,19 @@ class QBdt : public QParity, public QInterface { return sample; } + void Copy(QInterfacePtr orig) { Copy(std::dynamic_pointer_cast(orig)); } + void Copy(QBdtPtr orig) + { + QInterface::Copy(orig); + bdtStride = orig->bdtStride; + devID = orig->devID; + root = orig->root; + bdtMaxQPower = orig->bdtMaxQPower; + deviceIDs = orig->deviceIDs; + engines = orig->engines; + shards = orig->shards; + } + public: QBdt(std::vector eng, bitLenInt qBitCount, const bitCapInt& initState = ZERO_BCI, qrack_rand_gen_ptr rgp = nullptr, const complex& phaseFac = CMPLX_DEFAULT_ARG, bool doNorm = false, diff --git a/include/qengine.hpp b/include/qengine.hpp index 71d058937..7fbbc0477 100644 --- a/include/qengine.hpp +++ b/include/qengine.hpp @@ -67,6 +67,15 @@ class QEngine : public QParity, public QInterface { void EitherMtrx(const std::vector& controls, const complex* mtrx, bitLenInt target, bool isAnti); + virtual void Copy(QInterfacePtr orig) { Copy(std::dynamic_pointer_cast(orig)); } + virtual void Copy(QEnginePtr orig) + { + QInterface::Copy(orig); + useHostRam = orig->useHostRam; + runningNorm = orig->runningNorm; + maxQPowerOcl = orig->maxQPowerOcl; + } + public: QEngine(bitLenInt qBitCount, qrack_rand_gen_ptr rgp = nullptr, bool doNorm = false, bool randomGlobalPhase = true, bool useHostMem = false, bool useHardwareRNG = true, real1_f norm_thresh = REAL1_EPSILON) @@ -95,6 +104,8 @@ class QEngine : public QParity, public QInterface { // Virtual destructor for inheritance } + using QInterface::Copy; + virtual void SetQubitCount(bitLenInt qb) { QInterface::SetQubitCount(qb); diff --git a/include/qengine_cpu.hpp b/include/qengine_cpu.hpp index 44ea37d8a..c865f1cad 100644 --- a/include/qengine_cpu.hpp +++ b/include/qengine_cpu.hpp @@ -40,6 +40,13 @@ class QEngineCPU : public QEngine { DispatchQueue dispatchQueue; #endif + void Copy(QInterfacePtr orig) { Copy(std::dynamic_pointer_cast(orig)); } + void Copy(QEngineCPUPtr orig) + { + QEngine::Copy(std::dynamic_pointer_cast(orig)); + stateVec = orig->stateVec; + } + public: QEngineCPU(bitLenInt qBitCount, const bitCapInt& initState, qrack_rand_gen_ptr rgp = nullptr, const complex& phaseFac = CMPLX_DEFAULT_ARG, bool doNorm = false, bool randomGlobalPhase = true, diff --git a/include/qengine_cuda.hpp b/include/qengine_cuda.hpp index db0a0e243..dd99bd8d5 100644 --- a/include/qengine_cuda.hpp +++ b/include/qengine_cuda.hpp @@ -232,6 +232,29 @@ class QEngineCUDA : public QEngine { throw std::runtime_error(message + ", error code: " + std::to_string(error)); } + void Copy(QInterfacePtr orig) {Copy(std::dynamic_pointer_cast(orig); } + void Copy(QEngineOCLPtr orig) + { + didInit = orig->didInit; + usingHostRam = orig->usingHostRam; + unlockHostMem = orig->unlockHostMem; + nrmGroupCount = orig->nrmGroupCount; + nrmGroupSize = orig->nrmGroupSize; + AddAlloc(orig->totalOclAllocSize); + deviceID = orig->deviceID; + lockSyncFlags = orig->lockSyncFlags; + permutationAmp = orig->permutationAmp; + stateVec = orig->stateVec; + // queue_mutex = orig->queue_mutex; + stateBuffer = orig->stateBuffer; + nrmBuffer = orig->nrmBuffer; + device_context = orig->device_context; + wait_refs = orig->wait_refs; + wait_queue_items = orig->wait_queue_items; + poolItems = orig->poolItems; + nrmArray = orig->nrmArray; + } + public: /// 1 / OclMemDenom is the maximum fraction of total OCL device RAM that a single state vector should occupy, by /// design of the QEngine. @@ -451,6 +474,7 @@ class QEngineCUDA : public QEngine { bool isFinished() { return wait_queue_items.empty(); }; QInterfacePtr Clone(); + QInterfacePtr Copy(); void PopQueue(); void DispatchQueue(); diff --git a/include/qengine_opencl.hpp b/include/qengine_opencl.hpp index b3faf773c..18ee87c07 100644 --- a/include/qengine_opencl.hpp +++ b/include/qengine_opencl.hpp @@ -246,6 +246,32 @@ class QEngineOCL : public QEngine { throw std::runtime_error(message + ", error code: " + std::to_string(error)); } + void Copy(QInterfacePtr orig) { Copy(std::dynamic_pointer_cast(orig)); } + void Copy(QEngineOCLPtr orig) + { + QEngine::Copy(std::dynamic_pointer_cast(orig)); + didInit = orig->didInit; + usingHostRam = orig->usingHostRam; + unlockHostMem = orig->unlockHostMem; + callbackError = orig->callbackError; + nrmGroupCount = orig->nrmGroupCount; + nrmGroupSize = orig->nrmGroupSize; + AddAlloc(orig->totalOclAllocSize); + deviceID = orig->deviceID; + lockSyncFlags = orig->lockSyncFlags; + permutationAmp = orig->permutationAmp; + stateVec = orig->stateVec; + // queue_mutex = orig->queue_mutex; + queue = orig->queue; + context = orig->context; + stateBuffer = orig->stateBuffer; + nrmBuffer = orig->nrmBuffer; + device_context = orig->device_context; + wait_refs = orig->wait_refs; + wait_queue_items = orig->wait_queue_items; + poolItems = orig->poolItems; + } + public: /// 1 / OclMemDenom is the maximum fraction of total OCL device RAM that a single state vector should occupy, by /// design of the QEngine. diff --git a/include/qhybrid.hpp b/include/qhybrid.hpp index 9ba0c8d7c..b70c42a75 100644 --- a/include/qhybrid.hpp +++ b/include/qhybrid.hpp @@ -46,6 +46,23 @@ class QHybrid : public QEngine { complex phaseFactor; std::vector deviceIDs; + void Copy(QInterfacePtr orig) { Copy(std::dynamic_pointer_cast(orig)); } + void Copy(QHybridPtr orig) + { + QEngine::Copy(std::dynamic_pointer_cast(orig)); + isGpu = orig->isGpu; + isPager = orig->isPager; + useRDRAND = orig->useRDRAND; + isSparse = orig->isSparse; + gpuThresholdQubits = orig->gpuThresholdQubits; + pagerThresholdQubits = orig->pagerThresholdQubits; + separabilityThreshold = orig->separabilityThreshold; + devID = orig->devID; + engine = orig->engine; + phaseFactor = orig->phaseFactor; + deviceIDs = orig->deviceIDs; + } + public: QHybrid(bitLenInt qBitCount, const bitCapInt& initState = ZERO_BCI, qrack_rand_gen_ptr rgp = nullptr, const complex& phaseFac = CMPLX_DEFAULT_ARG, bool doNorm = false, bool randomGlobalPhase = true, diff --git a/include/qinterface.hpp b/include/qinterface.hpp index 8f5045b26..9cd4e3fe8 100644 --- a/include/qinterface.hpp +++ b/include/qinterface.hpp @@ -216,6 +216,21 @@ class QInterface : public ParallelFor { return isExp ? ExpectationBitsFactorized(bits, perms, offset) : VarianceBitsFactorized(bits, perms, offset); } + virtual void Copy(QInterfacePtr orig) + { + orig->Finish(); + doNormalize = orig->doNormalize; + randGlobalPhase = orig->randGlobalPhase; + useRDRAND = orig->useRDRAND; + qubitCount = orig->qubitCount; + randomSeed = orig->randomSeed; + amplitudeFloor = orig->amplitudeFloor; + maxQPower = orig->maxQPower; + rand_generator = orig->rand_generator; + rand_distribution = orig->rand_distribution; + hardware_rand_generator = orig->hardware_rand_generator; + } + public: QInterface(bitLenInt n, qrack_rand_gen_ptr rgp = nullptr, bool doNorm = false, bool useHardwareRNG = true, bool randomGlobalPhase = true, real1_f norm_thresh = REAL1_EPSILON); diff --git a/include/qinterface_noisy.hpp b/include/qinterface_noisy.hpp index 0a443e792..cb632025a 100644 --- a/include/qinterface_noisy.hpp +++ b/include/qinterface_noisy.hpp @@ -49,6 +49,16 @@ class QInterfaceNoisy : public QInterface { } } + void Copy(QInterfacePtr orig) { Copy(std::dynamic_pointer_cast(orig)); } + void Copy(QInterfaceNoisyPtr orig) + { + QInterface::Copy(orig); + logFidelity = orig->logFidelity; + noiseParam = orig->noiseParam; + engine = orig->engine; + engines = orig->engines; + } + public: QInterfaceNoisy(bitLenInt qBitCount, const bitCapInt& initState = ZERO_BCI, qrack_rand_gen_ptr rgp = nullptr, const complex& phaseFac = CMPLX_DEFAULT_ARG, bool doNorm = false, bool randomGlobalPhase = true, diff --git a/include/qpager.hpp b/include/qpager.hpp index 7c3bff24d..8cad862e8 100644 --- a/include/qpager.hpp +++ b/include/qpager.hpp @@ -99,6 +99,29 @@ class QPager : public QEngine, public std::enable_shared_from_this { real1_f ExpVarBitsAll(bool isExp, const std::vector& bits, const bitCapInt& offset = ZERO_BCI); + void Copy(QInterfacePtr orig) { Copy(std::dynamic_pointer_cast(orig)); } + void Copy(QPagerPtr orig) + { + QEngine::Copy(std::dynamic_pointer_cast(orig)); + useGpuThreshold = orig->useGpuThreshold; + isSparse = orig->isSparse; + useTGadget = orig->useTGadget; + maxPageSetting = orig->maxPageSetting; + maxPageQubits = orig->maxPageQubits; + thresholdQubitsPerPage = orig->thresholdQubitsPerPage; + baseQubitsPerPage = orig->baseQubitsPerPage; + maxQubits = orig->maxQubits; + devID = orig->devID; + rootEngine = orig->rootEngine; + basePageMaxQPower = orig->basePageMaxQPower; + basePageCount = orig->basePageCount; + phaseFactor = orig->phaseFactor; + devicesHostPointer = orig->devicesHostPointer; + deviceIDs = orig->deviceIDs; + engines = orig->engines; + qPages = orig->qPages; + } + public: QPager(std::vector eng, bitLenInt qBitCount, const bitCapInt& initState = ZERO_BCI, qrack_rand_gen_ptr rgp = nullptr, const complex& phaseFac = CMPLX_DEFAULT_ARG, bool doNorm = false, diff --git a/include/qstabilizer.hpp b/include/qstabilizer.hpp index 24cf7e5fe..cc487de28 100644 --- a/include/qstabilizer.hpp +++ b/include/qstabilizer.hpp @@ -80,6 +80,12 @@ class QStabilizer : public QInterface { } } + void Copy(QInterfacePtr orig) + { + throw std::domain_error("Can't TryDecompose() on QStabilizerHybrid! (If you know the system is exactly " + "separable, just use Decompose() instead.)"); + } + public: QStabilizer(bitLenInt n, const bitCapInt& perm = ZERO_BCI, qrack_rand_gen_ptr rgp = nullptr, const complex& phaseFac = CMPLX_DEFAULT_ARG, bool doNorm = false, bool randomGlobalPhase = true, diff --git a/include/qstabilizerhybrid.hpp b/include/qstabilizerhybrid.hpp index 22d925e94..ce4008248 100644 --- a/include/qstabilizerhybrid.hpp +++ b/include/qstabilizerhybrid.hpp @@ -320,6 +320,11 @@ class QStabilizerHybrid : public QParity, public QInterface { complex GetAmplitudeOrProb(const bitCapInt& perm, bool isProb = false); QInterfacePtr CloneBody(bool isCopy); + void Copy(QInterfacePtr orig) + { + throw std::domain_error("Can't TryDecompose() on QStabilizerHybrid! (If you know the system is exactly " + "separable, just use Decompose() instead.)"); + } public: QStabilizerHybrid(std::vector eng, bitLenInt qBitCount, const bitCapInt& initState = ZERO_BCI, diff --git a/include/qtensornetwork.hpp b/include/qtensornetwork.hpp index ec1b48fe1..e6cece510 100644 --- a/include/qtensornetwork.hpp +++ b/include/qtensornetwork.hpp @@ -130,6 +130,12 @@ class QTensorNetwork : public QInterface { } } + void Copy(QInterfacePtr orig) + { + throw std::domain_error("Can't TryDecompose() on QTensorNetwork! (QTensorNetwork does not allow Schmidt " + "decomposition in general!)"); + } + public: QTensorNetwork(std::vector eng, bitLenInt qBitCount, const bitCapInt& initState = ZERO_BCI, qrack_rand_gen_ptr rgp = nullptr, const complex& phaseFac = CMPLX_DEFAULT_ARG, bool doNorm = false, diff --git a/include/qunit.hpp b/include/qunit.hpp index 6a7fa3fea..647fbbb9b 100644 --- a/include/qunit.hpp +++ b/include/qunit.hpp @@ -47,6 +47,26 @@ class QUnit : public QParity, public QInterface { QInterfacePtr MakeEngine(bitLenInt length, const bitCapInt& perm); + void Copy(QInterfacePtr orig) { Copy(std::dynamic_pointer_cast(orig)); } + void Copy(QUnitPtr orig) + { + QInterface::Copy(orig); + freezeBasis2Qb = orig->freezeBasis2Qb; + useHostRam = orig->useHostRam; + isSparse = orig->isSparse; + isReactiveSeparate = orig->isReactiveSeparate; + useTGadget = orig->useTGadget; + thresholdQubits = orig->thresholdQubits; + separabilityThreshold = orig->separabilityThreshold; + roundingThreshold = orig->roundingThreshold; + logFidelity = orig->logFidelity; + devID = orig->devID; + phaseFactor = orig->phaseFactor; + shards = orig->shards; + deviceIDs = orig->deviceIDs; + engines = orig->engines; + } + public: QUnit(std::vector eng, bitLenInt qBitCount, const bitCapInt& initState = ZERO_BCI, qrack_rand_gen_ptr rgp = nullptr, const complex& phaseFac = CMPLX_DEFAULT_ARG, bool doNorm = false, diff --git a/include/qunitclifford.hpp b/include/qunitclifford.hpp index 90686749b..66451c696 100644 --- a/include/qunitclifford.hpp +++ b/include/qunitclifford.hpp @@ -44,6 +44,12 @@ class QUnitClifford : public QInterface { complex phaseOffset; std::vector shards; + void Copy(QInterfacePtr orig) + { + throw std::domain_error("Can't TryDecompose() on QUnitClifford! (If you know the system is exactly separable, " + "just use Decompose() instead.)"); + } + void CombinePhaseOffsets(QStabilizerPtr unit) { if (randGlobalPhase) { diff --git a/src/qengine/cuda.cu b/src/qengine/cuda.cu index 3396adf3f..f4a958b16 100644 --- a/src/qengine/cuda.cu +++ b/src/qengine/cuda.cu @@ -3112,6 +3112,17 @@ QEnginePtr QEngineCUDA::CloneEmpty() return copyPtr; } +QInterfacePtr QEngineCUDA::Copy() +{ + QEngineCUDAPtr copyPtr = std::dynamic_pointer_cast(CloneEmpty()); + copyPtr->stateVec = stateVec; + copyPtr->stateBuffer = stateBuffer; + // TODO: This is a hack for TryDecompose(): + AddAlloc(sizeof(complex) * maxQPowerOcl); + + return copyPtr; +} + void QEngineCUDA::NormalizeState(real1_f nrm, real1_f norm_thresh, real1_f phaseArg) { CHECK_ZERO_SKIP(); diff --git a/src/qinterface/qinterface.cpp b/src/qinterface/qinterface.cpp index a2f7c8d68..61d77e613 100644 --- a/src/qinterface/qinterface.cpp +++ b/src/qinterface/qinterface.cpp @@ -828,19 +828,16 @@ bool QInterface::TryDecompose(bitLenInt start, QInterfacePtr dest, real1_f error { Finish(); - const bool tempDoNorm = doNormalize; - doNormalize = false; - QInterfacePtr unitCopy = Copy(); - doNormalize = tempDoNorm; + QInterfacePtr orig = Copy(); + orig->Decompose(start, dest); + QInterfacePtr output = orig->Copy(); + orig->Compose(dest, start); - unitCopy->Decompose(start, dest); - unitCopy->Compose(dest, start); - - const bool didSeparate = ApproxCompare(unitCopy, error_tol); + const bool didSeparate = ApproxCompare(orig, error_tol); if (didSeparate) { // The subsystem is separable. - Dispose(start, dest->GetQubitCount()); + Copy(output); } return didSeparate;