diff --git a/include/qunit.hpp b/include/qunit.hpp index 0b2476e4a..4860a1db2 100644 --- a/include/qunit.hpp +++ b/include/qunit.hpp @@ -120,11 +120,16 @@ class QUnit : public QParity, public QInterface { const bitLenInt length = dest->GetQubitCount(); + if ((start + length) > qubitCount) { + throw std::invalid_argument("QUnit::TryDecompose qubit range out-of-bounds!"); + } + for (bitLenInt i = 0U; i < length; ++i) { - if (!shards[i].unit) { + QEngineShard& shard = shards[start + i]; + if (!shard.unit) { continue; } - if (!shards[i].unit->isBinaryDecisionTree()) { + if (!shard.unit->isBinaryDecisionTree()) { return QInterface::TryDecompose(start, dest, error_tol); } } @@ -135,17 +140,18 @@ class QUnit : public QParity, public QInterface { Swap(start + i, qubitCount - (i + 1U)); } - const bool isSeparable = TryDetach(nStart); + if (TryDetach(nStart)) { + Decompose(qubitCount - length, dest); + for (bitLenInt i = shift; i > 0U; --i) { + dest->Swap(i - 1U, dest->GetQubitCount() - i); + } + return true; + } for (bitLenInt i = shift; i > 0U; --i) { Swap(start + (i - 1U), qubitCount - i); } - if (isSeparable) { - Decompose(start, dest); - return true; - } - return false; } diff --git a/src/qunit.cpp b/src/qunit.cpp index 6d02a50f0..e91c51db4 100644 --- a/src/qunit.cpp +++ b/src/qunit.cpp @@ -386,7 +386,10 @@ void QUnit::Detach(bitLenInt start, bitLenInt length, QUnitPtr dest) bool QUnit::TryDetach(bitLenInt length) { - if (!length || (length > qubitCount)) { + if (!length || (length == qubitCount)) { + return true; + } + if (length > qubitCount) { throw std::invalid_argument("QUnit::Detach range is out-of-bounds!"); } @@ -417,26 +420,10 @@ bool QUnit::TryDetach(bitLenInt length) // After ordering all subunits contiguously, since the top level mapping is a contiguous array, all subunit sets are // also contiguous. From the lowest index bits, they are mapped simply for the length count of bits involved in the // entire subunit. - std::map decomposedUnits; - for (bitLenInt i = 0U; i < length; ++i) { - QEngineShard& shard = shards[start + i]; - QBdtPtr unit = std::dynamic_pointer_cast(shard.unit); - - if (unit == NULL) { - continue; - } - - if (decomposedUnits.find(unit) == decomposedUnits.end()) { - decomposedUnits[unit] = start + i; - const bitLenInt subLen = subunits[unit]; - const bitLenInt origLen = unit->GetQubitCount(); - if ((subLen != origLen) && !unit->IsSeparable(shard.mapped)) { - return false; - } - } - } + QEngineShard& shard = shards[start]; + QBdtPtr unit = std::dynamic_pointer_cast(shard.unit); - return true; + return (unit == NULL) || (shard.mapped == 0U) || unit->IsSeparable(shard.mapped); } QInterfacePtr QUnit::EntangleInCurrentBasis(