diff --git a/include/checker/zx/ZXChecker.hpp b/include/checker/zx/ZXChecker.hpp index c8419b44..bb9067cd 100644 --- a/include/checker/zx/ZXChecker.hpp +++ b/include/checker/zx/ZXChecker.hpp @@ -4,6 +4,7 @@ #include "Definitions.hpp" #include "EquivalenceCriterion.hpp" #include "QuantumComputation.hpp" +#include "Simplify.hpp" #include "ZXDiagram.hpp" #include "checker/EquivalenceChecker.hpp" #include "nlohmann/json.hpp" @@ -24,6 +25,77 @@ namespace ec { zx::ZXDiagram miter; zx::fp tolerance; bool ancilla = false; + + // the following methods are adaptations of the core ZX simplification routines + // that additionally check a criterion for early termination of the simplification. + std::size_t fullReduceApproximate(); + std::size_t fullReduce(); + + std::size_t gadgetSimp(); + std::size_t interiorCliffordSimp(); + std::size_t cliffordSimp(); + + std::size_t idSimp() { + return simplifyVertices(zx::checkIdSimp, zx::removeId); + } + + std::size_t spiderSimp() { + return simplifyEdges(zx::checkSpiderFusion, zx::fuseSpiders); + } + + std::size_t localCompSimp() { + return simplifyVertices(zx::checkLocalComp, zx::localComp); + } + + std::size_t pivotPauliSimp() { + return simplifyEdges(zx::checkPivotPauli, zx::pivotPauli); + } + + std::size_t pivotSimp() { + return simplifyEdges(zx::checkPivot, zx::pivot); + } + + std::size_t pivotGadgetSimp() { + return simplifyEdges(zx::checkPivotGadget, zx::pivotGadget); + } + + template + std::size_t simplifyVertices(CheckFun check, RuleFun rule) { + std::size_t nSimplifications = 0; + bool newMatches = true; + + while (!isDone() && newMatches) { + newMatches = false; + for (const auto& [v, _]: miter.getVertices()) { + if (isDone() || !check(miter, v)) { + continue; + } + rule(miter, v); + newMatches = true; + nSimplifications++; + } + } + return nSimplifications; + } + + template + std::size_t simplifyEdges(CheckFun check, RuleFun rule) { + std::size_t nSimplifications = 0; + bool newMatches = true; + + while (!isDone() && newMatches) { + newMatches = false; + for (const auto& [v0, v1]: miter.getEdges()) { + if (isDone() || miter.isDeleted(v0) || miter.isDeleted(v1) || !check(miter, v0, v1)) { + continue; + } + rule(miter, v0, v1); + newMatches = true; + nSimplifications++; + } + } + return nSimplifications; + } }; qc::Permutation complete(const qc::Permutation& p, dd::Qubit n); diff --git a/src/checker/zx/ZXChecker.cpp b/src/checker/zx/ZXChecker.cpp index 25295c6c..3354da1e 100644 --- a/src/checker/zx/ZXChecker.cpp +++ b/src/checker/zx/ZXChecker.cpp @@ -37,7 +37,7 @@ namespace ec { EquivalenceCriterion ZXEquivalenceChecker::run() { const auto start = std::chrono::steady_clock::now(); - zx::fullReduceApproximate(miter, tolerance); + fullReduceApproximate(); bool equivalent = true; @@ -61,8 +61,8 @@ namespace ec { const auto end = std::chrono::steady_clock::now(); runtime += std::chrono::duration(end - start).count(); - // non-equivalence might be due to incorrect assumption about the state of ancillaries, so no information can be given - if (!equivalent && ancilla) + // non-equivalence might be due to incorrect assumption about the state of ancillaries or the check was aborted prematurely, so no information can be given + if ((!equivalent && ancilla) || isDone()) equivalence = EquivalenceCriterion::NoInformation; else equivalence = equivalent ? EquivalenceCriterion::EquivalentUpToGlobalPhase : EquivalenceCriterion::ProbablyNotEquivalent; @@ -141,6 +141,97 @@ namespace ec { } qc::Permutation invertPermutations(const qc::QuantumComputation& qc) { - return concat(invert(complete(qc.outputPermutation, qc.getNqubits())), complete(qc.initialLayout, qc.getNqubits())); + return concat( + invert(complete(qc.outputPermutation, static_cast(qc.getNqubits()))), + complete(qc.initialLayout, static_cast(qc.getNqubits()))); } + + std::size_t ZXEquivalenceChecker::fullReduceApproximate() { + auto nSimplifications = fullReduce(); + std::size_t newSimps; + do { + miter.approximateCliffords(tolerance); + newSimps = fullReduce(); + nSimplifications += newSimps; + } while (!isDone() && (newSimps > 0)); + return nSimplifications; + } + + std::size_t ZXEquivalenceChecker::fullReduce() { + if (!isDone()) { + miter.toGraphlike(); + } + interiorCliffordSimp(); + + std::size_t nSimplifications = 0; + while (!isDone()) { + cliffordSimp(); + const auto nGadget = gadgetSimp(); + interiorCliffordSimp(); + const auto nPivot = pivotGadgetSimp(); + if (nGadget + nPivot == 0) + break; + nSimplifications += nGadget + nPivot; + } + if (!isDone()) { + miter.removeDisconnectedSpiders(); + } + + return nSimplifications; + } + + std::size_t ZXEquivalenceChecker::gadgetSimp() { + std::size_t nSimplifications = 0; + bool new_matches = true; + + while (!isDone() && new_matches) { + new_matches = false; + for (const auto& [v, _]: miter.getVertices()) { + if (miter.isDeleted(v)) + continue; + + if (!isDone() && checkAndFuseGadget(miter, v)) { + new_matches = true; + nSimplifications++; + } + } + } + return nSimplifications; + } + + std::size_t ZXEquivalenceChecker::interiorCliffordSimp() { + spiderSimp(); + + bool newMatches = true; + std::size_t nSimplifications = 0; + while (!isDone() && newMatches) { + newMatches = false; + const auto nId = idSimp(); + const auto nSpider = spiderSimp(); + const auto nPivot = pivotPauliSimp(); + const auto nLocalComp = localCompSimp(); + + if (nId + nSpider + nPivot + nLocalComp != 0) { + newMatches = true; + nSimplifications++; + } + } + return nSimplifications; + } + + std::size_t ZXEquivalenceChecker::cliffordSimp() { + bool newMatches = true; + std::size_t nSimplifications = 0; + while (!isDone() && newMatches) { + newMatches = false; + const auto nClifford = interiorCliffordSimp(); + const auto nPivot = pivotSimp(); + if (nClifford + nPivot != 0) { + newMatches = true; + nSimplifications++; + } + } + return nSimplifications; + } + } // namespace ec