Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unit tests for simulate_tree_mcm #6231

Open
wants to merge 25 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
7740a41
init_tests
obliviateandsurrender Sep 6, 2024
58676a3
add tests
obliviateandsurrender Sep 6, 2024
1c18eb4
add more tests
obliviateandsurrender Sep 10, 2024
e6faa7b
happy `isort`
obliviateandsurrender Sep 10, 2024
b6861f7
Merge branch 'master' into tree_mcm_tests
obliviateandsurrender Sep 10, 2024
726e45e
happy `pylint`
obliviateandsurrender Sep 10, 2024
6214f74
tweak test
obliviateandsurrender Sep 10, 2024
348bdf3
fix `shots`
obliviateandsurrender Sep 10, 2024
2269b1f
add `marker`
obliviateandsurrender Sep 10, 2024
e0a20e0
add more unit tests
obliviateandsurrender Sep 14, 2024
c5f962b
fix `isort`
obliviateandsurrender Sep 14, 2024
caf70ba
Merge branch 'master' into tree_mcm_tests
obliviateandsurrender Sep 14, 2024
599b7e2
Merge branch 'master' into tree_mcm_tests
obliviateandsurrender Sep 17, 2024
e250c89
add another test
obliviateandsurrender Sep 18, 2024
5639b52
Merge branch 'tree_mcm_tests' of https://github.com/PennyLaneAI/penny…
obliviateandsurrender Sep 18, 2024
9255508
skip `tf`
obliviateandsurrender Sep 18, 2024
85c902b
revert `tf`
obliviateandsurrender Sep 18, 2024
2a4bea2
fix post selection
obliviateandsurrender Sep 18, 2024
875dbbd
Merge branch 'master' into tree_mcm_tests
obliviateandsurrender Sep 19, 2024
a84a6a9
shifting marker to test `class`
obliviateandsurrender Sep 19, 2024
3705690
apply suggestions
obliviateandsurrender Sep 19, 2024
6128b35
Merge branch 'master' into tree_mcm_tests
obliviateandsurrender Sep 19, 2024
aa045f3
reliable one-shot?
obliviateandsurrender Sep 19, 2024
20499eb
happier `one_shot`?
obliviateandsurrender Sep 19, 2024
833a101
Merge branch 'master' into tree_mcm_tests
obliviateandsurrender Sep 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pennylane/devices/qubit/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,7 +922,7 @@ def _(original_measurement: ExpectationMP, measures): # pylint: disable=unused-
for v in measures.values():
if not v[0] or v[1] is tuple():
continue
cum_value += v[0] * v[1]
cum_value += v[0] * qml.math.squeeze(v[1])
obliviateandsurrender marked this conversation as resolved.
Show resolved Hide resolved
total_counts += v[0]
return cum_value / total_counts

Expand All @@ -935,14 +935,14 @@ def _(original_measurement: ProbabilityMP, measures): # pylint: disable=unused-
for v in measures.values():
if not v[0] or v[1] is tuple():
continue
cum_value += v[0] * v[1]
cum_value += v[0] * qml.math.squeeze(v[1])
obliviateandsurrender marked this conversation as resolved.
Show resolved Hide resolved
total_counts += v[0]
return cum_value / total_counts


@combine_measurements_core.register
def _(original_measurement: SampleMP, measures): # pylint: disable=unused-argument
"""The combined samples of two branches is obtained by concatenating the sample if each branch.."""
"""The combined samples of two branches is obtained by concatenating the sample of each branch."""
new_sample = tuple(
qml.math.atleast_1d(m[1]) for m in measures.values() if m[0] and not m[1] is tuple()
)
Expand Down
233 changes: 232 additions & 1 deletion tests/devices/qubit/test_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,19 @@
# limitations under the License.
"""Unit tests for simulate in devices/qubit."""

import mcm_utils
import numpy as np
import pytest
from dummy_debugger import Debugger

import pennylane as qml
from pennylane.devices.qubit import get_final_state, measure_final_state, simulate
from pennylane.devices.qubit.simulate import _FlexShots
from pennylane.devices.qubit.simulate import (
_FlexShots,
find_post_processed_mcms,
simulate_tree_mcm,
split_circuit_at_mcms,
)


class TestCurrentlyUnsupportedCases:
Expand Down Expand Up @@ -1178,3 +1184,228 @@ def test_qinfo_tf(self):

grad5 = grad_tape.jacobian(results[5], phi)
assert qml.math.allclose(grad5, expected_grads[5])


class TestMidMeasurements:
"""Tests for simulating scripts with mid-circuit measurements using the ``simulate_tree_mcm``."""

@pytest.mark.parametrize("val", [0, 1])
def test_basic_mid_meas_circuit(self, val):
obliviateandsurrender marked this conversation as resolved.
Show resolved Hide resolved
"""Test execution with a basic circuit with mid-circuit measurements."""
qs = qml.tape.QuantumScript(
[qml.Hadamard(0), qml.CNOT([0, 1]), qml.measurements.MidMeasureMP(0, postselect=val)],
[qml.expval(qml.X(0)), qml.expval(qml.Z(0))],
)
result = simulate_tree_mcm(qs)
assert result == (0, (-1.0) ** val)

def test_basic_mid_meas_circuit_with_reset(self):
"""Test execution with a basic circuit with mid-circuit measurements."""
qs = qml.tape.QuantumScript(
[
qml.Hadamard(0),
qml.Hadamard(1),
qml.CNOT([0, 1]),
(m0 := qml.measure(0, reset=True)).measurements[0],
qml.Hadamard(0),
qml.CNOT([1, 0]),
], # equivalent to a circuit that gives equiprobable basis states
[qml.probs(op=m0), qml.probs(op=qml.Z(0)), qml.probs(op=qml.Z(1))],
)
result = simulate_tree_mcm(qs)
assert qml.math.allclose(result, qml.math.array([0.5, 0.5]))

# pylint: disable=too-many-arguments
@pytest.mark.parametrize("shots", [None, 5500])
@pytest.mark.parametrize("postselect", [None, 0])
@pytest.mark.parametrize("reset", [False, True])
@pytest.mark.parametrize("measure_f", [qml.counts, qml.expval, qml.probs, qml.sample, qml.var])
@pytest.mark.parametrize(
"meas_obj",
[qml.Y(0), [1], [1, 0], "mcm", "composite_mcm", "mcm_list"],
)
def test_simple_dynamic_circuit(self, shots, measure_f, postselect, reset, meas_obj):
obliviateandsurrender marked this conversation as resolved.
Show resolved Hide resolved
"""Tests that `simulate` can handles a simple dynamic circuit with the following measurements:
* qml.counts with obs (comp basis or not), single wire, multiple wires (ordered/unordered), MCM, f(MCM), MCM list
* qml.expval with obs (comp basis or not), MCM, f(MCM), MCM list
* qml.probs with obs (comp basis or not), single wire, multiple wires (ordered/unordered), MCM, f(MCM), MCM list
* qml.sample with obs (comp basis or not), single wire, multiple wires (ordered/unordered), MCM, f(MCM), MCM list
* qml.var with obs (comp basis or not), MCM, f(MCM), MCM list
The above combinations should work for finite shots, shot vectors and post-selecting of either the 0 or 1 branch.
"""

if (
isinstance(meas_obj, (qml.X, qml.Z, qml.Y))
and measure_f in (qml.var,)
and not qml.operation.active_new_opmath()
):
pytest.xfail(
"The tree-traversal method does not work with legacy opmath with "
"`qml.var` of pauli observables in the circuit."
)

if measure_f in (qml.expval, qml.var) and (
isinstance(meas_obj, list) or meas_obj == "mcm_list"
):
pytest.skip("Can't use wires/mcm lists with var or expval")

if measure_f in (qml.counts, qml.sample) and shots is None:
pytest.skip("Can't measure counts/sample in analytic mode (`shots=None`)")

if measure_f in (qml.probs,) and meas_obj in ["composite_mcm"]:
pytest.skip(
"Cannot use qml.probs() when measuring multiple mid-circuit measurements collected using arithmetic operators."
)

qscript = qml.tape.QuantumScript(
[
qml.RX(np.pi / 2.5, 0),
qml.RZ(np.pi / 4, 0),
(m0 := qml.measure(0, reset=reset)).measurements[0],
qml.ops.op_math.Conditional(m0 == 0, qml.RX(np.pi / 4, 0)),
qml.ops.op_math.Conditional(m0 == 1, qml.RX(-np.pi / 4, 0)),
qml.RX(np.pi / 3, 1),
qml.RZ(np.pi / 4, 1),
(m1 := qml.measure(1, postselect=postselect)).measurements[0],
qml.ops.op_math.Conditional(m1 == 0, qml.RY(np.pi / 4, 1)),
qml.ops.op_math.Conditional(m1 == 1, qml.RY(-np.pi / 4, 1)),
],
[
measure_f(
**{
"wires" if isinstance(meas_obj, list) else "op": (
(
m0
if meas_obj == "mcm"
else (0.5 * m0 + m1 if meas_obj == "composite_mcm" else [m0, m1])
)
if isinstance(meas_obj, str)
else meas_obj
)
}
)
],
shots=shots,
)

results0 = simulate(qscript, mcm_method="tree-traversal")

deferred_tapes, deferred_func = qml.defer_measurements(qscript)
results1 = deferred_func([simulate(tape, mcm_method="deferred") for tape in deferred_tapes])
mcm_utils.validate_measurements(measure_f, shots, results1, results0)

if shots is not None:
one_shot_tapes, one_shot_func = qml.dynamic_one_shot(qscript)
results2 = one_shot_func(
[simulate(tape, mcm_method="one-shot") for tape in one_shot_tapes]
)
mcm_utils.validate_measurements(measure_f, shots, results2, results0)

@pytest.mark.parametrize("shots", [None, 500000, [500000, 500001]])
@pytest.mark.parametrize("rng", [None, 42, np.array([37])])
@pytest.mark.parametrize("angles", [(0.123, 0.015), (0.543, 0.057)])
def test_approx_dynamic_mid_meas_circuit(self, shots, rng, angles):
"""Test execution with a basic circuit with mid-circuit measurements."""
qs_with_mid_meas = qml.tape.QuantumScript(
[
qml.Hadamard(0),
qml.Hadamard(1),
qml.CZ([0, 1]),
qml.CNOT([0, 2]),
qml.CNOT([2, 3]),
qml.CZ([1, 3]),
qml.Toffoli([3, 2, 0]),
(m0 := qml.measure(0)).measurements[0],
obliviateandsurrender marked this conversation as resolved.
Show resolved Hide resolved
qml.ops.op_math.Conditional(m0, qml.RZ(angles[0], 1)),
qml.Hadamard(1),
qml.Z(1),
(m1 := qml.measure(1)).measurements[0],
qml.ops.op_math.Conditional(m1, qml.RX(angles[1], 3)),
],
[
qml.probs(wires=[0, 1, 2, 3]),
qml.var(qml.X(0) @ qml.X(1) @ qml.Z(2) @ qml.Z(3)),
],
shots=shots,
)
qs_without_mid_meas = qml.tape.QuantumScript(
[
qml.Hadamard(0),
qml.Hadamard(1),
qml.CZ([0, 1]),
qml.CNOT([0, 2]),
qml.CNOT([2, 3]),
qml.CZ([1, 3]),
qml.Toffoli([3, 2, 0]),
qml.Hadamard(1),
qml.Z(1),
qml.RX(angles[1], 3),
],
[
qml.probs(wires=[0, 1, 2, 3]),
qml.var(qml.X(0) @ qml.X(1) @ qml.Z(2) @ qml.Z(3)),
],
shots=shots,
) # approximate compiled circuit of the above
res1 = simulate_tree_mcm(qs_with_mid_meas, rng=rng)
res2 = simulate(qs_without_mid_meas, rng=rng)
if not isinstance(shots, list):
assert all(qml.math.allclose(r1, r2, atol=1e-2) for r1, r2 in zip(res1, res2))
else:
for rs1, rs2 in zip(res1, res2):
assert all(qml.math.allclose(r1, r2, atol=1e-2) for r1, r2 in zip(rs1, rs2))

@pytest.mark.parametrize("postselect_mode", ["hw-like", "fill-shots"])
def test_tree_traversal_postselect_mode(self, postselect_mode):
"""Test that invalid shots are discarded if requested"""

shots = 100
qscript = qml.tape.QuantumScript(
[
qml.RX(np.pi / 2, 0),
(m0 := qml.measure(0, postselect=1)).measurements[0],
qml.ops.op_math.Conditional(m0, qml.RZ(1.57, 1)),
],
[qml.sample(wires=[0, 1])],
shots=shots,
)

res = simulate_tree_mcm(qscript, postselect_mode=postselect_mode)

assert (len(res) < shots) if postselect_mode == "hw-like" else (len(res) == shots)
assert np.all(res != np.iinfo(np.int32).min)

def test_tree_traversal_deep_circuit(self):
"""Test that `simulate_tree_mcm` works with circuits with many mid-circuit measurements"""

n_circs = 500
operations = []
for _ in range(n_circs):
operations.extend(
[
qml.RX(1.234, 0),
(m0 := qml.measure(0, postselect=1)).measurements[0],
qml.CNOT([0, 1]),
qml.ops.op_math.Conditional(m0, qml.RZ(1.786, 1)),
]
)

qscript = qml.tape.QuantumScript(
operations,
[qml.sample(wires=[0, 1]), qml.counts(wires=[0, 1])],
shots=20,
)

mcms = find_post_processed_mcms(qscript)
assert len(mcms) == n_circs

split_circs = split_circuit_at_mcms(qscript)
assert len(split_circs) == n_circs + 1
for circ in split_circs:
assert not [o for o in circ.operations if isinstance(o, qml.measurements.MidMeasureMP)]

res = simulate_tree_mcm(qscript, postselect_mode="fill-shots")
assert len(res[0]) == 20
assert isinstance(res[1], dict) and sum(list(res[1].values())) == 20
Loading