From 7740a41f8bbac6b52d70f80bee5c41e960280830 Mon Sep 17 00:00:00 2001 From: obliviateandsurrender Date: Fri, 6 Sep 2024 04:58:58 -0400 Subject: [PATCH 01/18] init_tests --- tests/devices/qubit/test_simulate.py | 34 +++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/tests/devices/qubit/test_simulate.py b/tests/devices/qubit/test_simulate.py index dbe9573b8df..49fd33f03fd 100644 --- a/tests/devices/qubit/test_simulate.py +++ b/tests/devices/qubit/test_simulate.py @@ -19,7 +19,7 @@ import pennylane as qml from pennylane.devices.qubit import get_final_state, measure_final_state, simulate -from pennylane.devices.qubit.simulate import _FlexShots +from pennylane.devices.qubit.simulate import _FlexShots, simulate_tree_mcm class TestCurrentlyUnsupportedCases: @@ -1178,3 +1178,35 @@ def test_qinfo_tf(self): grad5 = grad_tape.jacobian(results[5], phi) assert qml.math.allclose(grad5, expected_grads[5]) + + +class TestMidMeasurements: + """Tests for simulating scripts with mid-circuit measurements using the ``simulate_tree_mcm``.""" + + @pytest.mark.parametrize("postselect", [0, 1]) + def test_basic_mid_meas_circuit(self, postselect): + """Test execution with a basic circuit with mid-circuit measurements.""" + qs = qml.tape.QuantumScript( + [qml.Hadamard(0), qml.CNOT([0, 1]), qml.measurements.MidMeasureMP(0, postselect=1)], + [qml.expval(qml.X(0)), qml.expval(qml.Z(0))] + ) + result = simulate_tree_mcm(qs) + assert result == (0, -1 ** postselect) + + + def test_basic_mid_meas_circuit_with_reset(self): + """Test execution with a basic circuit with mid-circuit measurements.""" + phi = np.array(0.397) + qs = qml.tape.QuantumScript( + [qml.RX(phi, wires=0)], [qml.expval(qml.PauliY(0)), qml.expval(qml.PauliZ(0))] + ) + result = simulate(qs) + + def test_dynamic_mid_meas_circuit(self): + """Test execution with a basic circuit with mid-circuit measurements.""" + # qs = qml.tape.QuantumScript( + # [] + + # [qml.RX(phi, wires=0)], [qml.expval(qml.PauliY(0)), qml.expval(qml.PauliZ(0))] + # ) + result = simulate(qs) \ No newline at end of file From 58676a387448445adbd0adb728fbad96cf8d772f Mon Sep 17 00:00:00 2001 From: obliviateandsurrender Date: Fri, 6 Sep 2024 12:26:57 -0400 Subject: [PATCH 02/18] add tests --- tests/devices/qubit/test_simulate.py | 86 ++++++++++++++++++++++------ 1 file changed, 69 insertions(+), 17 deletions(-) diff --git a/tests/devices/qubit/test_simulate.py b/tests/devices/qubit/test_simulate.py index 49fd33f03fd..a75a7d790c0 100644 --- a/tests/devices/qubit/test_simulate.py +++ b/tests/devices/qubit/test_simulate.py @@ -1181,32 +1181,84 @@ def test_qinfo_tf(self): class TestMidMeasurements: - """Tests for simulating scripts with mid-circuit measurements using the ``simulate_tree_mcm``.""" + """Tests for simulating scripts with mid-circuit measurements using the ``simulate_tree_mcm``.""" - @pytest.mark.parametrize("postselect", [0, 1]) - def test_basic_mid_meas_circuit(self, postselect): + @pytest.mark.parametrize("val", [0, 1]) + def test_basic_mid_meas_circuit(self, val): """Test execution with a basic circuit with mid-circuit measurements.""" qs = qml.tape.QuantumScript( - [qml.Hadamard(0), qml.CNOT([0, 1]), qml.measurements.MidMeasureMP(0, postselect=1)], - [qml.expval(qml.X(0)), qml.expval(qml.Z(0))] + [qml.Hadamard(0), qml.CNOT([0, 1]), qml.measurements.MidMeasureMP(0, postselect=val)], + [qml.expval(qml.X(0)), qml.expval(qml.Z(0))], ) result = simulate_tree_mcm(qs) - assert result == (0, -1 ** postselect) - + assert result == (0, (-1.0) ** val) def test_basic_mid_meas_circuit_with_reset(self): """Test execution with a basic circuit with mid-circuit measurements.""" - phi = np.array(0.397) qs = qml.tape.QuantumScript( - [qml.RX(phi, wires=0)], [qml.expval(qml.PauliY(0)), qml.expval(qml.PauliZ(0))] + [ + qml.Hadamard(0), + qml.Hadamard(1), + qml.CNOT([0, 1]), + (m0 := qml.measure(0, reset=True)).measurements[0], + qml.Hadamard(0), + qml.CNOT([1, 0]), + ], # equivalent to a circuit that gives equiprobable basis states + [qml.probs(op=m0), qml.probs(op=qml.Z(0)), qml.probs(op=qml.Z(1))], ) - result = simulate(qs) + result = simulate_tree_mcm(qs) + assert qml.math.allclose(result, qml.math.array([0.5, 0.5])) - def test_dynamic_mid_meas_circuit(self): + @pytest.mark.parametrize("shots", [None, int(5e5), [int(4e5), int(6e5)]]) + @pytest.mark.parametrize("rng", [None, 42, np.array([37])]) + @pytest.mark.parametrize("angles", [(0.123, 0.015), (0.543, 0.057)]) + def test_dynamic_mid_meas_circuit(self, shots, rng, angles): """Test execution with a basic circuit with mid-circuit measurements.""" - # qs = qml.tape.QuantumScript( - # [] - - # [qml.RX(phi, wires=0)], [qml.expval(qml.PauliY(0)), qml.expval(qml.PauliZ(0))] - # ) - result = simulate(qs) \ No newline at end of file + qs_with_mid_meas = qml.tape.QuantumScript( + [ + qml.Hadamard(0), + qml.Hadamard(1), + qml.CZ([0, 1]), + qml.CNOT([0, 2]), + qml.CNOT([2, 3]), + qml.CZ([1, 3]), + qml.Toffoli([3, 2, 0]), + (m0 := qml.measure(0)).measurements[0], + qml.ops.op_math.Conditional(m0, qml.RZ(angles[0], 1)), + qml.Hadamard(1), + qml.Z(1), + (m1 := qml.measure(1)).measurements[0], + qml.ops.op_math.Conditional(m1, qml.RX(angles[1], 3)), + ], + [ + qml.probs(wires=[0, 1, 2, 3]), + qml.var(qml.X(0) @ qml.X(1) @ qml.Z(2) @ qml.Z(3)), + ], + shots=shots, + ) + qs_without_mid_meas = qml.tape.QuantumScript( + [ + qml.Hadamard(0), + qml.Hadamard(1), + qml.CZ([0, 1]), + qml.CNOT([0, 2]), + qml.CNOT([2, 3]), + qml.CZ([1, 3]), + qml.Toffoli([3, 2, 0]), + qml.Hadamard(1), + qml.Z(1), + qml.RX(angles[1], 3), + ], + [ + qml.probs(wires=[0, 1, 2, 3]), + qml.var(qml.X(0) @ qml.X(1) @ qml.Z(2) @ qml.Z(3)), + ], + shots=shots, + ) # approximate compiled circuit of the above + res1 = simulate_tree_mcm(qs_with_mid_meas, rng=rng) + res2 = simulate(qs_without_mid_meas, rng=rng) + if not isinstance(shots, list): + assert all(qml.math.allclose(r1, r2, atol=1e-2) for r1, r2 in zip(res1, res2)) + else: + for rs1, rs2 in zip(res1, res2): + assert all(qml.math.allclose(r1, r2, atol=1e-2) for r1, r2 in zip(rs1, rs2)) From 1c18eb46a4e18ee7f41c431a83934688176833ee Mon Sep 17 00:00:00 2001 From: obliviateandsurrender Date: Mon, 9 Sep 2024 23:43:14 -0400 Subject: [PATCH 03/18] add more tests --- pennylane/devices/qubit/simulate.py | 4 +- tests/devices/qubit/test_simulate.py | 133 ++++++++++++++++++++++++++- 2 files changed, 133 insertions(+), 4 deletions(-) diff --git a/pennylane/devices/qubit/simulate.py b/pennylane/devices/qubit/simulate.py index 56e4a8f1a48..54c25a3bf8d 100644 --- a/pennylane/devices/qubit/simulate.py +++ b/pennylane/devices/qubit/simulate.py @@ -922,7 +922,7 @@ def _(original_measurement: ExpectationMP, measures): # pylint: disable=unused- for v in measures.values(): if not v[0] or v[1] is tuple(): continue - cum_value += v[0] * v[1] + cum_value += v[0] * qml.math.squeeze(v[1]) total_counts += v[0] return cum_value / total_counts @@ -935,7 +935,7 @@ def _(original_measurement: ProbabilityMP, measures): # pylint: disable=unused- for v in measures.values(): if not v[0] or v[1] is tuple(): continue - cum_value += v[0] * v[1] + cum_value += v[0] * qml.math.squeeze(v[1]) total_counts += v[0] return cum_value / total_counts diff --git a/tests/devices/qubit/test_simulate.py b/tests/devices/qubit/test_simulate.py index a75a7d790c0..8314b02ade7 100644 --- a/tests/devices/qubit/test_simulate.py +++ b/tests/devices/qubit/test_simulate.py @@ -17,9 +17,11 @@ import pytest from dummy_debugger import Debugger +import mcm_utils + import pennylane as qml from pennylane.devices.qubit import get_final_state, measure_final_state, simulate -from pennylane.devices.qubit.simulate import _FlexShots, simulate_tree_mcm +from pennylane.devices.qubit.simulate import _FlexShots, simulate_tree_mcm, split_circuit_at_mcms class TestCurrentlyUnsupportedCases: @@ -1209,10 +1211,97 @@ def test_basic_mid_meas_circuit_with_reset(self): result = simulate_tree_mcm(qs) assert qml.math.allclose(result, qml.math.array([0.5, 0.5])) + @pytest.mark.parametrize("shots", [None, 5500]) + @pytest.mark.parametrize("postselect", [None, 0]) + @pytest.mark.parametrize("reset", [False, True]) + @pytest.mark.parametrize("measure_f", [qml.counts, qml.expval, qml.probs, qml.sample, qml.var]) + @pytest.mark.parametrize( + "meas_obj", + [qml.Y(0), [1], [1, 0], "mcm", "composite_mcm", "mcm_list"], + ) + def test_simple_dynamic_circuit(self, shots, measure_f, postselect, reset, meas_obj): + """Tests that `simulate` can handles a simple dynamic circuit with the following measurements: + + * qml.counts with obs (comp basis or not), single wire, multiple wires (ordered/unordered), MCM, f(MCM), MCM list + * qml.expval with obs (comp basis or not), MCM, f(MCM), MCM list + * qml.probs with obs (comp basis or not), single wire, multiple wires (ordered/unordered), MCM, f(MCM), MCM list + * qml.sample with obs (comp basis or not), single wire, multiple wires (ordered/unordered), MCM, f(MCM), MCM list + * qml.var with obs (comp basis or not), MCM, f(MCM), MCM list + + The above combinations should work for finite shots, shot vectors and post-selecting of either the 0 or 1 branch. + """ + + if ( + isinstance(meas_obj, (qml.X, qml.Z, qml.Y)) + and measure_f in (qml.var,) + and not qml.operation.active_new_opmath() + ): + pytest.xfail( + "The tree-traversal method does not work with legacy opmath with " + "`qml.var` of pauli observables in the circuit." + ) + + if measure_f in (qml.expval, qml.var) and ( + isinstance(meas_obj, list) or meas_obj == "mcm_list" + ): + pytest.skip("Can't use wires/mcm lists with var or expval") + + if measure_f in (qml.counts, qml.sample) and shots is None: + pytest.skip("Can't measure counts/sample in analytic mode (`shots=None`)") + + if measure_f in (qml.probs,) and meas_obj in ["composite_mcm"]: + pytest.skip( + "Cannot use qml.probs() when measuring multiple mid-circuit measurements collected using arithmetic operators." + ) + + qscript = qml.tape.QuantumScript( + [ + qml.RX(np.pi / 2.5, 0), + qml.RZ(np.pi / 4, 0), + (m0 := qml.measure(0, reset=reset)).measurements[0], + qml.ops.op_math.Conditional(m0 == 0, qml.RX(np.pi / 4, 0)), + qml.ops.op_math.Conditional(m0 == 1, qml.RX(-np.pi / 4, 0)), + qml.RX(np.pi / 3, 1), + qml.RZ(np.pi / 4, 1), + (m1 := qml.measure(1, postselect=postselect)).measurements[0], + qml.ops.op_math.Conditional(m1 == 0, qml.RY(np.pi / 4, 1)), + qml.ops.op_math.Conditional(m1 == 1, qml.RY(-np.pi / 4, 1)), + ], + [ + measure_f( + **{ + "wires" if isinstance(meas_obj, list) else "op": ( + ( + m0 + if meas_obj == "mcm" + else (0.5 * m0 if meas_obj == "composite_mcm" else [m0, m1]) + ) + if isinstance(meas_obj, str) + else meas_obj + ) + } + ) + ], + shots=shots, + ) + print(qscript.measurements) + results0 = simulate(qscript, mcm_method="tree-traversal") + + deferred_tapes, deferred_func = qml.defer_measurements(qscript) + results1 = deferred_func([simulate(tape, mcm_method="deferred") for tape in deferred_tapes]) + mcm_utils.validate_measurements(measure_f, shots, results1, results0) + + if shots is not None: + one_shot_tapes, one_shot_func = qml.dynamic_one_shot(qscript) + results2 = one_shot_func( + [simulate(tape, mcm_method="one-shot") for tape in one_shot_tapes] + ) + mcm_utils.validate_measurements(measure_f, shots, results2, results0) + @pytest.mark.parametrize("shots", [None, int(5e5), [int(4e5), int(6e5)]]) @pytest.mark.parametrize("rng", [None, 42, np.array([37])]) @pytest.mark.parametrize("angles", [(0.123, 0.015), (0.543, 0.057)]) - def test_dynamic_mid_meas_circuit(self, shots, rng, angles): + def test_approx_dynamic_mid_meas_circuit(self, shots, rng, angles): """Test execution with a basic circuit with mid-circuit measurements.""" qs_with_mid_meas = qml.tape.QuantumScript( [ @@ -1262,3 +1351,43 @@ def test_dynamic_mid_meas_circuit(self, shots, rng, angles): else: for rs1, rs2 in zip(res1, res2): assert all(qml.math.allclose(r1, r2, atol=1e-2) for r1, r2 in zip(rs1, rs2)) + + @pytest.mark.parametrize("postselect_mode", ["hw-like", "fill-shots"]) + def test_tree_traversal_postselect_mode(self, postselect_mode): + """Test that invalid shots are discarded if requested""" + + shots = 100 + qscript = qml.tape.QuantumScript( + [ + qml.RX(np.pi / 2, 0), + (m0 := qml.measure(0, postselect=1)).measurements[0], + qml.ops.op_math.Conditional(m0, qml.RZ(1.57, 1)), + ], + [qml.sample(wires=[0, 1])], + shots=shots, + ) + + res = simulate_tree_mcm(qscript, postselect_mode=postselect_mode) + + assert (len(res) < shots) if postselect_mode == "hw-like" else (len(res) == shots) + assert np.all(res != np.iinfo(np.int32).min) + + def test_tree_traversal_deep_circuit(self): + """Test that `simulate_tree_mcm` works with circuits with many mid-circuit measurements""" + + n_circs = 500 + qscript = qml.tape.QuantumScript( + [ + qml.RX(1.234, 0), + (m0 := qml.measure(0)).measurements[0], + qml.ops.op_math.Conditional(m0, qml.RZ(1.786, 1)), + ] + * n_circs, + [qml.sample(wires=[0, 1]), qml.expval(m0)], + shots=20, + ) + + res = simulate_tree_mcm(qscript) + assert len(res[0]) == 40 + assert isinstance(res[1], np.float64) + assert len(split_circuit_at_mcms(qscript)) == n_circs + 1 From e6faa7b2de3c259d9e892ce0c3ef77f975e494ee Mon Sep 17 00:00:00 2001 From: obliviateandsurrender Date: Mon, 9 Sep 2024 23:46:32 -0400 Subject: [PATCH 04/18] happy `isort` --- tests/devices/qubit/test_simulate.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/devices/qubit/test_simulate.py b/tests/devices/qubit/test_simulate.py index 8314b02ade7..f409fcb9703 100644 --- a/tests/devices/qubit/test_simulate.py +++ b/tests/devices/qubit/test_simulate.py @@ -13,12 +13,11 @@ # limitations under the License. """Unit tests for simulate in devices/qubit.""" +import mcm_utils import numpy as np import pytest from dummy_debugger import Debugger -import mcm_utils - import pennylane as qml from pennylane.devices.qubit import get_final_state, measure_final_state, simulate from pennylane.devices.qubit.simulate import _FlexShots, simulate_tree_mcm, split_circuit_at_mcms From 726e45ea0dd00d8bccc663c6bec94b819dc9326c Mon Sep 17 00:00:00 2001 From: obliviateandsurrender Date: Mon, 9 Sep 2024 23:59:00 -0400 Subject: [PATCH 05/18] happy `pylint` --- tests/devices/qubit/test_simulate.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/devices/qubit/test_simulate.py b/tests/devices/qubit/test_simulate.py index f409fcb9703..09b2f6deeb6 100644 --- a/tests/devices/qubit/test_simulate.py +++ b/tests/devices/qubit/test_simulate.py @@ -1210,6 +1210,7 @@ def test_basic_mid_meas_circuit_with_reset(self): result = simulate_tree_mcm(qs) assert qml.math.allclose(result, qml.math.array([0.5, 0.5])) + # pylint: disable=too-many-arguments @pytest.mark.parametrize("shots", [None, 5500]) @pytest.mark.parametrize("postselect", [None, 0]) @pytest.mark.parametrize("reset", [False, True]) From 6214f74d1a6d817057b3348a6026a3911c275c20 Mon Sep 17 00:00:00 2001 From: obliviateandsurrender Date: Tue, 10 Sep 2024 10:50:20 -0400 Subject: [PATCH 06/18] tweak test --- pennylane/devices/qubit/simulate.py | 2 +- tests/devices/qubit/test_simulate.py | 48 +++++++++++++++++++--------- 2 files changed, 34 insertions(+), 16 deletions(-) diff --git a/pennylane/devices/qubit/simulate.py b/pennylane/devices/qubit/simulate.py index 54c25a3bf8d..1244149ed8f 100644 --- a/pennylane/devices/qubit/simulate.py +++ b/pennylane/devices/qubit/simulate.py @@ -942,7 +942,7 @@ def _(original_measurement: ProbabilityMP, measures): # pylint: disable=unused- @combine_measurements_core.register def _(original_measurement: SampleMP, measures): # pylint: disable=unused-argument - """The combined samples of two branches is obtained by concatenating the sample if each branch..""" + """The combined samples of two branches is obtained by concatenating the sample of each branch.""" new_sample = tuple( qml.math.atleast_1d(m[1]) for m in measures.values() if m[0] and not m[1] is tuple() ) diff --git a/tests/devices/qubit/test_simulate.py b/tests/devices/qubit/test_simulate.py index 09b2f6deeb6..22465b5c485 100644 --- a/tests/devices/qubit/test_simulate.py +++ b/tests/devices/qubit/test_simulate.py @@ -20,7 +20,12 @@ import pennylane as qml from pennylane.devices.qubit import get_final_state, measure_final_state, simulate -from pennylane.devices.qubit.simulate import _FlexShots, simulate_tree_mcm, split_circuit_at_mcms +from pennylane.devices.qubit.simulate import ( + _FlexShots, + find_post_processed_mcms, + simulate_tree_mcm, + split_circuit_at_mcms, +) class TestCurrentlyUnsupportedCases: @@ -1274,7 +1279,7 @@ def test_simple_dynamic_circuit(self, shots, measure_f, postselect, reset, meas_ ( m0 if meas_obj == "mcm" - else (0.5 * m0 if meas_obj == "composite_mcm" else [m0, m1]) + else (0.5 * m0 + m1 if meas_obj == "composite_mcm" else [m0, m1]) ) if isinstance(meas_obj, str) else meas_obj @@ -1284,7 +1289,7 @@ def test_simple_dynamic_circuit(self, shots, measure_f, postselect, reset, meas_ ], shots=shots, ) - print(qscript.measurements) + results0 = simulate(qscript, mcm_method="tree-traversal") deferred_tapes, deferred_func = qml.defer_measurements(qscript) @@ -1298,7 +1303,7 @@ def test_simple_dynamic_circuit(self, shots, measure_f, postselect, reset, meas_ ) mcm_utils.validate_measurements(measure_f, shots, results2, results0) - @pytest.mark.parametrize("shots", [None, int(5e5), [int(4e5), int(6e5)]]) + @pytest.mark.parametrize("shots", [None, 5000, [5000, 5001]]) @pytest.mark.parametrize("rng", [None, 42, np.array([37])]) @pytest.mark.parametrize("angles", [(0.123, 0.015), (0.543, 0.057)]) def test_approx_dynamic_mid_meas_circuit(self, shots, rng, angles): @@ -1376,18 +1381,31 @@ def test_tree_traversal_deep_circuit(self): """Test that `simulate_tree_mcm` works with circuits with many mid-circuit measurements""" n_circs = 500 + operations = [] + for _ in range(n_circs): + operations.extend( + [ + qml.RX(1.234, 0), + (m0 := qml.measure(0, postselect=1)).measurements[0], + qml.CNOT([0, 1]), + qml.ops.op_math.Conditional(m0, qml.RZ(1.786, 1)), + ] + ) + qscript = qml.tape.QuantumScript( - [ - qml.RX(1.234, 0), - (m0 := qml.measure(0)).measurements[0], - qml.ops.op_math.Conditional(m0, qml.RZ(1.786, 1)), - ] - * n_circs, - [qml.sample(wires=[0, 1]), qml.expval(m0)], + operations, + [qml.sample(wires=[0, 1]), qml.counts(wires=[0, 1])], shots=20, ) - res = simulate_tree_mcm(qscript) - assert len(res[0]) == 40 - assert isinstance(res[1], np.float64) - assert len(split_circuit_at_mcms(qscript)) == n_circs + 1 + mcms = find_post_processed_mcms(qscript) + assert len(mcms) == n_circs + + split_circs = split_circuit_at_mcms(qscript) + assert len(split_circs) == n_circs + 1 + for circ in split_circs: + assert not [o for o in circ.operations if isinstance(o, qml.measurements.MidMeasureMP)] + + res = simulate_tree_mcm(qscript, postselect_mode="fill-shots") + assert len(res[0]) == 20 + assert isinstance(res[1], dict) and sum(list(res[1].values())) == 20 From 348bdf3fa3c80480f2135b9b9caf8e497132e497 Mon Sep 17 00:00:00 2001 From: obliviateandsurrender Date: Tue, 10 Sep 2024 11:59:04 -0400 Subject: [PATCH 07/18] fix `shots` --- tests/devices/qubit/test_simulate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/devices/qubit/test_simulate.py b/tests/devices/qubit/test_simulate.py index 22465b5c485..7bfa2efd67e 100644 --- a/tests/devices/qubit/test_simulate.py +++ b/tests/devices/qubit/test_simulate.py @@ -1303,7 +1303,7 @@ def test_simple_dynamic_circuit(self, shots, measure_f, postselect, reset, meas_ ) mcm_utils.validate_measurements(measure_f, shots, results2, results0) - @pytest.mark.parametrize("shots", [None, 5000, [5000, 5001]]) + @pytest.mark.parametrize("shots", [None, 500000, [500000, 500001]]) @pytest.mark.parametrize("rng", [None, 42, np.array([37])]) @pytest.mark.parametrize("angles", [(0.123, 0.015), (0.543, 0.057)]) def test_approx_dynamic_mid_meas_circuit(self, shots, rng, angles): From 2269b1f9f1e77f9fc0e9642e42a64e1b0326ed1e Mon Sep 17 00:00:00 2001 From: obliviateandsurrender Date: Tue, 10 Sep 2024 19:52:27 -0400 Subject: [PATCH 08/18] add `marker` --- pennylane/devices/qubit/simulate.py | 4 ++-- tests/devices/qubit/test_simulate.py | 8 ++++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/pennylane/devices/qubit/simulate.py b/pennylane/devices/qubit/simulate.py index 1244149ed8f..1b4cec10b33 100644 --- a/pennylane/devices/qubit/simulate.py +++ b/pennylane/devices/qubit/simulate.py @@ -922,7 +922,7 @@ def _(original_measurement: ExpectationMP, measures): # pylint: disable=unused- for v in measures.values(): if not v[0] or v[1] is tuple(): continue - cum_value += v[0] * qml.math.squeeze(v[1]) + cum_value += qml.math.multiply(v[0], v[1]) total_counts += v[0] return cum_value / total_counts @@ -935,7 +935,7 @@ def _(original_measurement: ProbabilityMP, measures): # pylint: disable=unused- for v in measures.values(): if not v[0] or v[1] is tuple(): continue - cum_value += v[0] * qml.math.squeeze(v[1]) + cum_value += qml.math.multiply(v[0], v[1]) total_counts += v[0] return cum_value / total_counts diff --git a/tests/devices/qubit/test_simulate.py b/tests/devices/qubit/test_simulate.py index 7bfa2efd67e..b8571f205a8 100644 --- a/tests/devices/qubit/test_simulate.py +++ b/tests/devices/qubit/test_simulate.py @@ -1189,6 +1189,7 @@ def test_qinfo_tf(self): class TestMidMeasurements: """Tests for simulating scripts with mid-circuit measurements using the ``simulate_tree_mcm``.""" + @pytest.mark.unit @pytest.mark.parametrize("val", [0, 1]) def test_basic_mid_meas_circuit(self, val): """Test execution with a basic circuit with mid-circuit measurements.""" @@ -1199,6 +1200,7 @@ def test_basic_mid_meas_circuit(self, val): result = simulate_tree_mcm(qs) assert result == (0, (-1.0) ** val) + @pytest.mark.unit def test_basic_mid_meas_circuit_with_reset(self): """Test execution with a basic circuit with mid-circuit measurements.""" qs = qml.tape.QuantumScript( @@ -1216,13 +1218,13 @@ def test_basic_mid_meas_circuit_with_reset(self): assert qml.math.allclose(result, qml.math.array([0.5, 0.5])) # pylint: disable=too-many-arguments + @pytest.mark.unit @pytest.mark.parametrize("shots", [None, 5500]) @pytest.mark.parametrize("postselect", [None, 0]) @pytest.mark.parametrize("reset", [False, True]) @pytest.mark.parametrize("measure_f", [qml.counts, qml.expval, qml.probs, qml.sample, qml.var]) @pytest.mark.parametrize( - "meas_obj", - [qml.Y(0), [1], [1, 0], "mcm", "composite_mcm", "mcm_list"], + "meas_obj", [qml.Y(0), [1], [1, 0], "mcm", "composite_mcm", "mcm_list"] ) def test_simple_dynamic_circuit(self, shots, measure_f, postselect, reset, meas_obj): """Tests that `simulate` can handles a simple dynamic circuit with the following measurements: @@ -1357,6 +1359,7 @@ def test_approx_dynamic_mid_meas_circuit(self, shots, rng, angles): for rs1, rs2 in zip(res1, res2): assert all(qml.math.allclose(r1, r2, atol=1e-2) for r1, r2 in zip(rs1, rs2)) + @pytest.mark.unit @pytest.mark.parametrize("postselect_mode", ["hw-like", "fill-shots"]) def test_tree_traversal_postselect_mode(self, postselect_mode): """Test that invalid shots are discarded if requested""" @@ -1377,6 +1380,7 @@ def test_tree_traversal_postselect_mode(self, postselect_mode): assert (len(res) < shots) if postselect_mode == "hw-like" else (len(res) == shots) assert np.all(res != np.iinfo(np.int32).min) + @pytest.mark.unit def test_tree_traversal_deep_circuit(self): """Test that `simulate_tree_mcm` works with circuits with many mid-circuit measurements""" From e0a20e0c4e83bfdad13e2e2a6c691520e63376fa Mon Sep 17 00:00:00 2001 From: obliviateandsurrender Date: Fri, 13 Sep 2024 21:53:20 -0400 Subject: [PATCH 09/18] add more unit tests --- tests/devices/qubit/test_simulate.py | 126 +++++++++++++++++++++++---- 1 file changed, 111 insertions(+), 15 deletions(-) diff --git a/tests/devices/qubit/test_simulate.py b/tests/devices/qubit/test_simulate.py index b8571f205a8..18483550cd4 100644 --- a/tests/devices/qubit/test_simulate.py +++ b/tests/devices/qubit/test_simulate.py @@ -16,15 +16,21 @@ import mcm_utils import numpy as np import pytest +import scipy as sp from dummy_debugger import Debugger import pennylane as qml from pennylane.devices.qubit import get_final_state, measure_final_state, simulate from pennylane.devices.qubit.simulate import ( _FlexShots, + branch_state, + combine_measurements_core, find_post_processed_mcms, + samples_to_counts, + counts_to_probs, simulate_tree_mcm, split_circuit_at_mcms, + TreeTraversalStack, ) @@ -1186,6 +1192,61 @@ def test_qinfo_tf(self): assert qml.math.allclose(grad5, expected_grads[5]) +class TestTreeTraversalStack: + """Unit tests for TreeTraversalStack""" + + @pytest.mark.unit + @pytest.mark.parametrize( + "max_depth", + [0, 1, 10, 100], + ) + def test_init_with_depth(self, max_depth): + """Test that TreeTraversalStack is initialized correctly with given ``max_depth``""" + tree_stack = TreeTraversalStack(max_depth) + + assert tree_stack.counts.count(None) == max_depth + assert tree_stack.counts.probs(None) == max_depth + assert tree_stack.counts.results_0(None) == max_depth + assert tree_stack.counts.results_1(None) == max_depth + assert tree_stack.counts.states(None) == max_depth + + @pytest.mark.unit + def test_full_prune_empty_methods(self): + """Test that TreeTraversalStack object's class methods work correctly.""" + + max_depth = 10 + tree_stack = TreeTraversalStack(max_depth) + + np.random.shuffle(r_depths := list(range(max_depth))) + for depth in r_depths: + counts_0 = np.random.randint(1, 9) + counts_1 = 10 - counts_0 + tree_stack.counts[depth] = [counts_0, counts_1] + tree_stack.probs[depth] = [counts_0 / 10, counts_1 / 10] + tree_stack.results_0[depth] = [0] * counts_0 + tree_stack.results_1[depth] = [1] * counts_1 + tree_stack.states[depth] = [np.sqrt(counts_0), np.sqrt(counts_1)] + assert tree_stack.is_full(depth) + + assert tree_stack.counts[depth] == list( + samples_to_counts( + np.array(tree_stack.results_0[depth] + tree_stack.results_1[depth]) + ).values() + ) + assert tree_stack.probs[depth] == list( + counts_to_probs(dict(zip([0, 1], tree_stack.counts[depth]))).values() + ) + + state_vec = np.array(tree_stack.states[depth]).T + state_vec /= np.linalg.norm(tree_stack.states[depth]) + meas, meas_r = qml.measure(0), qml.measure(0, reset=True) + assert np.allclose(branch_state(state_vec, 0, meas.measurements[0]), [1.0, 0.0]) + assert np.allclose(branch_state(state_vec, 1, meas_r.measurements[0]), [1.0, 0.0]) + + tree_stack.prune(depth) + assert tree_stack.any_is_empty(depth) + + class TestMidMeasurements: """Tests for simulating scripts with mid-circuit measurements using the ``simulate_tree_mcm``.""" @@ -1305,11 +1366,17 @@ def test_simple_dynamic_circuit(self, shots, measure_f, postselect, reset, meas_ ) mcm_utils.validate_measurements(measure_f, shots, results2, results0) - @pytest.mark.parametrize("shots", [None, 500000, [500000, 500001]]) + @pytest.mark.unit + @pytest.mark.parametrize("shots", [None, 5500, [5500, 5500]]) @pytest.mark.parametrize("rng", [None, 42, np.array([37])]) @pytest.mark.parametrize("angles", [(0.123, 0.015), (0.543, 0.057)]) - def test_approx_dynamic_mid_meas_circuit(self, shots, rng, angles): - """Test execution with a basic circuit with mid-circuit measurements.""" + @pytest.mark.parametrize("measure_f", [qml.probs, qml.sample]) + def test_approx_dynamic_mid_meas_circuit(self, shots, rng, angles, measure_f): + """Test execution of a dynamic circuit with an equivalent static one.""" + + if measure_f in (qml.sample,) and shots is None: + pytest.skip("Can't measure samples in analytic mode (`shots=None`)") + qs_with_mid_meas = qml.tape.QuantumScript( [ qml.Hadamard(0), @@ -1326,10 +1393,7 @@ def test_approx_dynamic_mid_meas_circuit(self, shots, rng, angles): (m1 := qml.measure(1)).measurements[0], qml.ops.op_math.Conditional(m1, qml.RX(angles[1], 3)), ], - [ - qml.probs(wires=[0, 1, 2, 3]), - qml.var(qml.X(0) @ qml.X(1) @ qml.Z(2) @ qml.Z(3)), - ], + [measure_f(wires=[0, 1, 2, 3])], shots=shots, ) qs_without_mid_meas = qml.tape.QuantumScript( @@ -1345,19 +1409,29 @@ def test_approx_dynamic_mid_meas_circuit(self, shots, rng, angles): qml.Z(1), qml.RX(angles[1], 3), ], - [ - qml.probs(wires=[0, 1, 2, 3]), - qml.var(qml.X(0) @ qml.X(1) @ qml.Z(2) @ qml.Z(3)), - ], + [measure_f(wires=[0, 1, 2, 3])], shots=shots, ) # approximate compiled circuit of the above res1 = simulate_tree_mcm(qs_with_mid_meas, rng=rng) res2 = simulate(qs_without_mid_meas, rng=rng) + if not isinstance(shots, list): - assert all(qml.math.allclose(r1, r2, atol=1e-2) for r1, r2 in zip(res1, res2)) - else: - for rs1, rs2 in zip(res1, res2): - assert all(qml.math.allclose(r1, r2, atol=1e-2) for r1, r2 in zip(rs1, rs2)) + res1, res2 = (res1,), (res2,) + + for rs1, rs2 in zip(res1, res2): + prob_dist1, prob_dist2 = rs1, rs2 + if measure_f in (qml.sample,): + n_wires = rs1.shape[1] + prob_dist1, prob_dist2 = np.zeros(2**n_wires), np.zeros(2**n_wires) + for prob, rs in zip([prob_dist1, prob_dist2], [rs1, rs2]): + index, count = np.unique( + np.packbits(rs, axis=1, bitorder="little").squeeze(), return_counts=True + ) + prob[index] = count + + assert qml.math.allclose( + sp.stats.entropy(prob_dist1 + 1e-12, prob_dist2 + 1e-12), 0.0, atol=5e-2 + ) @pytest.mark.unit @pytest.mark.parametrize("postselect_mode", ["hw-like", "fill-shots"]) @@ -1413,3 +1487,25 @@ def test_tree_traversal_deep_circuit(self): res = simulate_tree_mcm(qscript, postselect_mode="fill-shots") assert len(res[0]) == 20 assert isinstance(res[1], dict) and sum(list(res[1].values())) == 20 + + @pytest.mark.unit + @pytest.mark.parametrize( + "measurements, expected", + [ + [(qml.counts(0), {"a": (1, {0: 42}), "b": (2, {1: 58})}), {0: 42, 1: 58}], + [ + (qml.expval(qml.Z(0)), {"a": (1, (0.42, -1)), "b": (2, (0.58, 1))}), + [1.58 / 3, 1 / 3], + ], + [(qml.probs(wires=0), {"a": (1, (0.42, -1)), "b": (2, (0.58, 1))}), [1.58 / 3, 1 / 3]], + [(qml.sample(wires=0), {"a": (1, (0, 1, 0)), "b": (2, (1, 0, 1))}), [0, 1, 0, 1, 0, 1]], + ], + ) + def test_tree_traversal_combine_measurements(self, measurements, expected): + """Test that the measurement value of a given type can be combined""" + print(combine_measurements_core(*measurements)) + combined_measurement = combine_measurements_core(*measurements) + if isinstance(combined_measurement, dict): + assert combined_measurement == expected + else: + assert qml.math.allclose(combined_measurement, expected) From c5f962b9fa5eae29967c9ff05fc457c8141e935d Mon Sep 17 00:00:00 2001 From: obliviateandsurrender Date: Fri, 13 Sep 2024 23:40:07 -0400 Subject: [PATCH 10/18] fix `isort` --- tests/devices/qubit/test_simulate.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/devices/qubit/test_simulate.py b/tests/devices/qubit/test_simulate.py index 18483550cd4..256ea48125d 100644 --- a/tests/devices/qubit/test_simulate.py +++ b/tests/devices/qubit/test_simulate.py @@ -22,15 +22,15 @@ import pennylane as qml from pennylane.devices.qubit import get_final_state, measure_final_state, simulate from pennylane.devices.qubit.simulate import ( + TreeTraversalStack, _FlexShots, branch_state, combine_measurements_core, + counts_to_probs, find_post_processed_mcms, samples_to_counts, - counts_to_probs, simulate_tree_mcm, split_circuit_at_mcms, - TreeTraversalStack, ) @@ -1205,10 +1205,10 @@ def test_init_with_depth(self, max_depth): tree_stack = TreeTraversalStack(max_depth) assert tree_stack.counts.count(None) == max_depth - assert tree_stack.counts.probs(None) == max_depth - assert tree_stack.counts.results_0(None) == max_depth - assert tree_stack.counts.results_1(None) == max_depth - assert tree_stack.counts.states(None) == max_depth + assert tree_stack.probs.count(None) == max_depth + assert tree_stack.results_0.count(None) == max_depth + assert tree_stack.results_1.count(None) == max_depth + assert tree_stack.states.count(None) == max_depth @pytest.mark.unit def test_full_prune_empty_methods(self): From e250c89bb53bbb18c696e0f433c0c9eb42f1865a Mon Sep 17 00:00:00 2001 From: obliviateandsurrender Date: Wed, 18 Sep 2024 16:01:51 -0400 Subject: [PATCH 11/18] add another test --- tests/devices/qubit/test_simulate.py | 61 +++++++++++++++++++++++++++- 1 file changed, 60 insertions(+), 1 deletion(-) diff --git a/tests/devices/qubit/test_simulate.py b/tests/devices/qubit/test_simulate.py index 256ea48125d..36dde770eed 100644 --- a/tests/devices/qubit/test_simulate.py +++ b/tests/devices/qubit/test_simulate.py @@ -33,6 +33,14 @@ split_circuit_at_mcms, ) +ml_frameworks_list = [ + "numpy", + pytest.param("autograd", marks=pytest.mark.autograd), + pytest.param("jax", marks=pytest.mark.jax), + pytest.param("torch", marks=pytest.mark.torch), + pytest.param("tensorflow", marks=pytest.mark.tf), +] + class TestCurrentlyUnsupportedCases: # pylint: disable=too-few-public-methods @@ -1434,7 +1442,58 @@ def test_approx_dynamic_mid_meas_circuit(self, shots, rng, angles, measure_f): ) @pytest.mark.unit - @pytest.mark.parametrize("postselect_mode", ["hw-like", "fill-shots"]) + @pytest.mark.parametrize("ml_framework", ml_frameworks_list) + @pytest.mark.parametrize( + "postselect_mode", [None, "hw-like", "pad-invalid-samples", "fill-shots"] + ) + def test_tree_traversal_interface_mcm(self, ml_framework, postselect_mode): + """Test that tree traversal works numerically with different interfaces""" + # pylint:disable = singleton-comparison, import-outside-toplevel + if ml_framework == "tensorflow": + import tensorflow as tf + + tf.experimental.numpy.experimental_enable_numpy_behavior() + + qscript = qml.tape.QuantumScript( + [ + qml.RX(np.pi / 4, wires=0), + (m0 := qml.measure(0, reset=True)).measurements[0], + qml.RX(np.pi / 4, wires=0), + ], + [qml.sample(qml.Z(0)), qml.sample(m0)], + shots=5500, + ) + + res1, res2 = simulate_tree_mcm(qscript, interface=ml_framework) + + p1 = [qml.math.mean(res1 == -1), qml.math.mean(res1 == 1)] + p2 = [qml.math.mean(res2 == True), qml.math.mean(res2 == False)] + assert qml.math.allclose(qml.math.sum(sp.special.rel_entr(p1, p2)), 0.0, atol=0.05) + + qscript2 = qml.tape.QuantumScript( + [ + qml.RX(np.pi / 4, wires=0), + (m0 := qml.measure(0, postselect=0)).measurements[0], + qml.RX(np.pi / 4, wires=0), + ], + [qml.sample(qml.Z(0))], + shots=5500, + ) + qscript3 = qml.tape.QuantumScript( + [qml.RX(np.pi / 4, wires=0)], [qml.sample(qml.Z(0))], shots=5500 + ) + + res3 = simulate_tree_mcm(qscript2, postselect_mode=postselect_mode) + res4 = simulate(qscript3) + + p3 = [qml.math.mean(res3 == -1), qml.math.mean(res3 == 1)] + p4 = [qml.math.mean(res4 == -1), qml.math.mean(res4 == 1)] + assert qml.math.allclose(qml.math.sum(sp.special.rel_entr(p3, p4)), 0.0, atol=0.05) + + @pytest.mark.unit + @pytest.mark.parametrize( + "postselect_mode", [None, "hw-like", "pad-invalid-samples", "fill-shots"] + ) def test_tree_traversal_postselect_mode(self, postselect_mode): """Test that invalid shots are discarded if requested""" From 9255508def749ac01333817ad00ff6af1e25442e Mon Sep 17 00:00:00 2001 From: obliviateandsurrender Date: Wed, 18 Sep 2024 16:03:08 -0400 Subject: [PATCH 12/18] skip `tf` --- tests/devices/qubit/test_simulate.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/devices/qubit/test_simulate.py b/tests/devices/qubit/test_simulate.py index e9addaae190..74e847a345f 100644 --- a/tests/devices/qubit/test_simulate.py +++ b/tests/devices/qubit/test_simulate.py @@ -1448,11 +1448,7 @@ def test_approx_dynamic_mid_meas_circuit(self, shots, rng, angles, measure_f): ) def test_tree_traversal_interface_mcm(self, ml_framework, postselect_mode): """Test that tree traversal works numerically with different interfaces""" - # pylint:disable = singleton-comparison, import-outside-toplevel - if ml_framework == "tensorflow": - import tensorflow as tf - - tf.experimental.numpy.experimental_enable_numpy_behavior() + # pylint:disable = singleton-comparison qscript = qml.tape.QuantumScript( [ From 85c902bca8ab7e32950f57e0ca1ce11092d39663 Mon Sep 17 00:00:00 2001 From: obliviateandsurrender Date: Wed, 18 Sep 2024 16:11:54 -0400 Subject: [PATCH 13/18] revert `tf` --- tests/devices/qubit/test_simulate.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/devices/qubit/test_simulate.py b/tests/devices/qubit/test_simulate.py index 74e847a345f..e9addaae190 100644 --- a/tests/devices/qubit/test_simulate.py +++ b/tests/devices/qubit/test_simulate.py @@ -1448,7 +1448,11 @@ def test_approx_dynamic_mid_meas_circuit(self, shots, rng, angles, measure_f): ) def test_tree_traversal_interface_mcm(self, ml_framework, postselect_mode): """Test that tree traversal works numerically with different interfaces""" - # pylint:disable = singleton-comparison + # pylint:disable = singleton-comparison, import-outside-toplevel + if ml_framework == "tensorflow": + import tensorflow as tf + + tf.experimental.numpy.experimental_enable_numpy_behavior() qscript = qml.tape.QuantumScript( [ From 2a4bea285db32640fc6284897c8dcfe851f7fbfc Mon Sep 17 00:00:00 2001 From: obliviateandsurrender Date: Wed, 18 Sep 2024 18:11:06 -0400 Subject: [PATCH 14/18] fix post selection --- tests/devices/qubit/test_simulate.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/devices/qubit/test_simulate.py b/tests/devices/qubit/test_simulate.py index e9addaae190..6626ca8c8b4 100644 --- a/tests/devices/qubit/test_simulate.py +++ b/tests/devices/qubit/test_simulate.py @@ -1491,9 +1491,7 @@ def test_tree_traversal_interface_mcm(self, ml_framework, postselect_mode): assert qml.math.allclose(qml.math.sum(sp.special.rel_entr(p3, p4)), 0.0, atol=0.05) @pytest.mark.unit - @pytest.mark.parametrize( - "postselect_mode", [None, "hw-like", "pad-invalid-samples", "fill-shots"] - ) + @pytest.mark.parametrize("postselect_mode", ["hw-like", "fill-shots"]) def test_tree_traversal_postselect_mode(self, postselect_mode): """Test that invalid shots are discarded if requested""" From a84a6a9518c190a1294ec5b8ed2815aa9dc5c6bb Mon Sep 17 00:00:00 2001 From: obliviateandsurrender Date: Thu, 19 Sep 2024 10:45:33 -0400 Subject: [PATCH 15/18] shifting marker to test `class` --- tests/devices/qubit/test_simulate.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/tests/devices/qubit/test_simulate.py b/tests/devices/qubit/test_simulate.py index 6626ca8c8b4..8684dd27c1c 100644 --- a/tests/devices/qubit/test_simulate.py +++ b/tests/devices/qubit/test_simulate.py @@ -1200,10 +1200,10 @@ def test_qinfo_tf(self): assert qml.math.allclose(grad5, expected_grads[5]) +@pytest.mark.unit class TestTreeTraversalStack: """Unit tests for TreeTraversalStack""" - @pytest.mark.unit @pytest.mark.parametrize( "max_depth", [0, 1, 10, 100], @@ -1218,7 +1218,6 @@ def test_init_with_depth(self, max_depth): assert tree_stack.results_1.count(None) == max_depth assert tree_stack.states.count(None) == max_depth - @pytest.mark.unit def test_full_prune_empty_methods(self): """Test that TreeTraversalStack object's class methods work correctly.""" @@ -1255,10 +1254,10 @@ def test_full_prune_empty_methods(self): assert tree_stack.any_is_empty(depth) +@pytest.mark.unit class TestMidMeasurements: """Tests for simulating scripts with mid-circuit measurements using the ``simulate_tree_mcm``.""" - @pytest.mark.unit @pytest.mark.parametrize("val", [0, 1]) def test_basic_mid_meas_circuit(self, val): """Test execution with a basic circuit with mid-circuit measurements.""" @@ -1269,7 +1268,6 @@ def test_basic_mid_meas_circuit(self, val): result = simulate_tree_mcm(qs) assert result == (0, (-1.0) ** val) - @pytest.mark.unit def test_basic_mid_meas_circuit_with_reset(self): """Test execution with a basic circuit with mid-circuit measurements.""" qs = qml.tape.QuantumScript( @@ -1287,7 +1285,6 @@ def test_basic_mid_meas_circuit_with_reset(self): assert qml.math.allclose(result, qml.math.array([0.5, 0.5])) # pylint: disable=too-many-arguments - @pytest.mark.unit @pytest.mark.parametrize("shots", [None, 5500]) @pytest.mark.parametrize("postselect", [None, 0]) @pytest.mark.parametrize("reset", [False, True]) @@ -1374,7 +1371,6 @@ def test_simple_dynamic_circuit(self, shots, measure_f, postselect, reset, meas_ ) mcm_utils.validate_measurements(measure_f, shots, results2, results0) - @pytest.mark.unit @pytest.mark.parametrize("shots", [None, 5500, [5500, 5500]]) @pytest.mark.parametrize("rng", [None, 42, np.array([37])]) @pytest.mark.parametrize("angles", [(0.123, 0.015), (0.543, 0.057)]) @@ -1441,7 +1437,6 @@ def test_approx_dynamic_mid_meas_circuit(self, shots, rng, angles, measure_f): sp.stats.entropy(prob_dist1 + 1e-12, prob_dist2 + 1e-12), 0.0, atol=5e-2 ) - @pytest.mark.unit @pytest.mark.parametrize("ml_framework", ml_frameworks_list) @pytest.mark.parametrize( "postselect_mode", [None, "hw-like", "pad-invalid-samples", "fill-shots"] @@ -1490,7 +1485,6 @@ def test_tree_traversal_interface_mcm(self, ml_framework, postselect_mode): p4 = [qml.math.mean(res4 == -1), qml.math.mean(res4 == 1)] assert qml.math.allclose(qml.math.sum(sp.special.rel_entr(p3, p4)), 0.0, atol=0.05) - @pytest.mark.unit @pytest.mark.parametrize("postselect_mode", ["hw-like", "fill-shots"]) def test_tree_traversal_postselect_mode(self, postselect_mode): """Test that invalid shots are discarded if requested""" @@ -1511,7 +1505,6 @@ def test_tree_traversal_postselect_mode(self, postselect_mode): assert (len(res) < shots) if postselect_mode == "hw-like" else (len(res) == shots) assert np.all(res != np.iinfo(np.int32).min) - @pytest.mark.unit def test_tree_traversal_deep_circuit(self): """Test that `simulate_tree_mcm` works with circuits with many mid-circuit measurements""" @@ -1545,7 +1538,6 @@ def test_tree_traversal_deep_circuit(self): assert len(res[0]) == 20 assert isinstance(res[1], dict) and sum(list(res[1].values())) == 20 - @pytest.mark.unit @pytest.mark.parametrize( "measurements, expected", [ From 3705690165981318576fcb8601ae7721d33274b5 Mon Sep 17 00:00:00 2001 From: obliviateandsurrender Date: Thu, 19 Sep 2024 11:12:01 -0400 Subject: [PATCH 16/18] apply suggestions --- pennylane/devices/qubit/simulate.py | 3 ++- tests/devices/qubit/test_simulate.py | 4 ---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/pennylane/devices/qubit/simulate.py b/pennylane/devices/qubit/simulate.py index 1b4cec10b33..13ab3be5b53 100644 --- a/pennylane/devices/qubit/simulate.py +++ b/pennylane/devices/qubit/simulate.py @@ -689,7 +689,8 @@ def prepend_state_prep(circuit, state, interface, wires): else state ) return qml.tape.QuantumScript( - [qml.StatePrep(state.ravel(), wires=wires, validate_norm=False)] + circuit.operations, + [qml.StatePrep(qml.math.ravel(state), wires=wires, validate_norm=False)] + + circuit.operations, circuit.measurements, shots=circuit.shots, ) diff --git a/tests/devices/qubit/test_simulate.py b/tests/devices/qubit/test_simulate.py index 8684dd27c1c..2adb4ddc145 100644 --- a/tests/devices/qubit/test_simulate.py +++ b/tests/devices/qubit/test_simulate.py @@ -1444,10 +1444,6 @@ def test_approx_dynamic_mid_meas_circuit(self, shots, rng, angles, measure_f): def test_tree_traversal_interface_mcm(self, ml_framework, postselect_mode): """Test that tree traversal works numerically with different interfaces""" # pylint:disable = singleton-comparison, import-outside-toplevel - if ml_framework == "tensorflow": - import tensorflow as tf - - tf.experimental.numpy.experimental_enable_numpy_behavior() qscript = qml.tape.QuantumScript( [ From aa045f369b84a732e5f38e26511092a31f4aa8b3 Mon Sep 17 00:00:00 2001 From: obliviateandsurrender Date: Thu, 19 Sep 2024 15:23:47 -0400 Subject: [PATCH 17/18] reliable one-shot? --- tests/devices/qubit/test_simulate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/devices/qubit/test_simulate.py b/tests/devices/qubit/test_simulate.py index c720d687451..b82a998af57 100644 --- a/tests/devices/qubit/test_simulate.py +++ b/tests/devices/qubit/test_simulate.py @@ -1573,7 +1573,7 @@ def test_simulate_one_shot_native_mcm(self, ml_framework, postselect_mode): circuit = qml.tape.QuantumScript(q.queue, [qml.expval(qml.Z(0)), qml.sample(m)], shots=[1]) - n_shots = 200 + n_shots = 500 results = [ simulate_one_shot_native_mcm( circuit, From 20499eb2d05fde7d7c9fad9d69c900b7baffb62e Mon Sep 17 00:00:00 2001 From: obliviateandsurrender Date: Thu, 19 Sep 2024 15:42:22 -0400 Subject: [PATCH 18/18] happier `one_shot`? --- tests/devices/qubit/test_simulate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/devices/qubit/test_simulate.py b/tests/devices/qubit/test_simulate.py index b82a998af57..85c97acfa9d 100644 --- a/tests/devices/qubit/test_simulate.py +++ b/tests/devices/qubit/test_simulate.py @@ -1558,7 +1558,7 @@ def test_tree_traversal_combine_measurements(self, measurements, expected): else: assert qml.math.allclose(combined_measurement, expected) - @flaky(max_runs=3, min_passes=2) + @flaky(max_runs=3, min_passes=1) @pytest.mark.parametrize("ml_framework", ml_frameworks_list) @pytest.mark.parametrize( "postselect_mode", [None, "hw-like", "pad-invalid-samples", "fill-shots"]