From 9370f6d7410c1655e40b3664d71fb870c72667f7 Mon Sep 17 00:00:00 2001 From: Aaron Sander <61705296+aaronleesander@users.noreply.github.com> Date: Wed, 14 Jun 2023 14:08:11 +0200 Subject: [PATCH] makeIdent() now returns a terminal node. Trace checks for terminal before calculation. --- include/dd/Package.hpp | 42 +++++++++++++----------------------------- test/test_package.cpp | 39 ++++++++++++++++++++++++--------------- 2 files changed, 37 insertions(+), 44 deletions(-) diff --git a/include/dd/Package.hpp b/include/dd/Package.hpp index 22f6b998..79e2b532 100644 --- a/include/dd/Package.hpp +++ b/include/dd/Package.hpp @@ -2066,7 +2066,15 @@ namespace dd { public: mEdge partialTrace(const mEdge& a, const std::vector& eliminate) { [[maybe_unused]] const auto before = cn.cacheCount(); - const auto result = trace(a, eliminate); + mEdge result; + // Check for identity case + if (a.isTerminal()) { + auto relevantQubits = std::count(eliminate.begin(), eliminate.end(), true); + auto ones = std::pow(2, relevantQubits); + result = mEdge::terminal(cn.lookup(ones)); + } else { + result = trace(a, eliminate); + } [[maybe_unused]] const auto after = cn.cacheCount(); assert(before == after); return result; @@ -2220,34 +2228,10 @@ namespace dd { /// /// Identity matrices /// - // create n-qubit identity DD. makeIdent(n) === makeIdent(0, n-1) - // mEdge makeIdent(QubitCount n) { return makeIdent(0, static_cast(n - 1)); } - // mEdge makeIdent(Qubit leastSignificantQubit, Qubit mostSignificantQubit) { - // if (mostSignificantQubit < leastSignificantQubit) { - // return mEdge::one; - // } -// - // if (leastSignificantQubit == 0 && idTable[static_cast(mostSignificantQubit)].p != nullptr) { - // return idTable[static_cast(mostSignificantQubit)]; - // } - // if (mostSignificantQubit >= 1 && (idTable[static_cast(mostSignificantQubit - 1)]).p != nullptr) { - // idTable[static_cast(mostSignificantQubit)] = makeDDNode(mostSignificantQubit, - // std::array{idTable[static_cast(mostSignificantQubit - 1)], - // mEdge::zero, - // mEdge::zero, - // idTable[static_cast(mostSignificantQubit - 1)]}); - // return idTable[static_cast(mostSignificantQubit)]; - // } -// - // auto e = makeDDNode(leastSignificantQubit, std::array{mEdge::one, mEdge::zero, mEdge::zero, mEdge::one}); - // for (auto k = static_cast(leastSignificantQubit + 1); k <= static_cast>(mostSignificantQubit); k++) { - // e = makeDDNode(static_cast(k), std::array{e, mEdge::zero, mEdge::zero, e}); - // } - // if (leastSignificantQubit == 0) { - // idTable[static_cast(mostSignificantQubit)] = e; - // } - // return e; - // } + // create n-qubit identity DD. Equivalent to creating a terminal with weight 1. + mEdge makeIdent() { + return mEdge::terminal(cn.lookup(1)); + } // identity table access and reset [[nodiscard]] const auto& getIdentityTable() const { return idTable; } diff --git a/test/test_package.cpp b/test/test_package.cpp index b451a860..99b7e65c 100644 --- a/test/test_package.cpp +++ b/test/test_package.cpp @@ -173,22 +173,27 @@ TEST(DDPackageTest, NegativeControl) { EXPECT_EQ(dd->getValueByPath(state01, 0b01).r, 1.); } -//TEST(DDPackageTest, IdentityTrace) { -// auto dd = std::make_unique>(4); -// auto identity = dd->makeIdent(4); -// std::string filename1 = "C:/Users/aaron/OneDrive/Documents/GitHub/ddsim/extern/qfr/extern/dd_package/graphs/Identity"; -// dd::export2Dot(identity, filename1); -// auto fullTrace = dd->trace(dd->makeIdent(4)); -// -// ASSERT_EQ(fullTrace, (dd::ComplexValue{16, 0})); -//} +TEST(DDPackageTest, IdentityTrace) { + auto dd = std::make_unique>(4); + auto identity = dd->makeIdent(); -//TEST(DDPackageTest, PartialIdentityTrace) { -// auto dd = std::make_unique>(2); -// auto tr = dd->partialTrace(dd->makeIdent(2), {false, true}); -// auto mul = dd->multiply(tr, tr); -// EXPECT_EQ(dd::CTEntry::val(mul.w.r), 4.0); -//} + auto fullTrace = dd->trace(identity); + + ASSERT_EQ(fullTrace, (dd::ComplexValue{16, 0})); +} + +TEST(DDPackageTest, PartialIdentityTrace) { + auto dd = std::make_unique>(2); + auto tr = dd->partialTrace(dd->makeIdent(), {false, true}); + std::string filename1 = "C:/Users/aaron/OneDrive/Documents/GitHub/ddsim/extern/qfr/extern/dd_package/graphs/Trace"; + dd::export2Dot(tr, filename1); + auto mul = dd->multiply(tr, tr); + std::string filename2 = "C:/Users/aaron/OneDrive/Documents/GitHub/ddsim/extern/qfr/extern/dd_package/graphs/Multiply"; + dd::export2Dot(mul, filename2); + dd::export2Dot(tr, filename1); + + EXPECT_EQ(dd::CTEntry::val(mul.w.r), 4.0); +} TEST(DDPackageTest, StateGenerationManipulation) { const std::size_t nqubits = 6; @@ -214,6 +219,8 @@ TEST(DDPackageTest, VectorSerializationTest) { auto zeroState = dd->makeZeroState(2); auto bellState = dd->multiply(dd->multiply(cxGate, hGate), zeroState); + std::string filename1 = "C:/Users/aaron/OneDrive/Documents/GitHub/ddsim/extern/qfr/extern/dd_package/graphs/BellState"; + dd::export2Dot(bellState, filename1, true, true); serialize(bellState, "bell_state.dd", false); auto deserializedBellState = dd->deserialize("bell_state.dd", false); @@ -232,6 +239,8 @@ TEST(DDPackageTest, BellMatrix) { auto bellMatrix = dd->multiply(cxGate, hGate); + std::string filename1 = "C:/Users/aaron/OneDrive/Documents/GitHub/ddsim/extern/qfr/extern/dd_package/graphs/BellMatrix"; + dd::export2Dot(bellMatrix, filename1, true, true); ASSERT_EQ(dd->getValueByPath(bellMatrix, "00"), (dd::ComplexValue{dd::SQRT2_2, 0})); ASSERT_EQ(dd->getValueByPath(bellMatrix, "02"), (dd::ComplexValue{0, 0})); ASSERT_EQ(dd->getValueByPath(bellMatrix, "20"), (dd::ComplexValue{0, 0}));