diff --git a/pennylane/devices/qubit/simulate.py b/pennylane/devices/qubit/simulate.py index f6ba32d4c7f..f535f5428e6 100644 --- a/pennylane/devices/qubit/simulate.py +++ b/pennylane/devices/qubit/simulate.py @@ -689,7 +689,8 @@ def prepend_state_prep(circuit, state, interface, wires): else state ) return qml.tape.QuantumScript( - [qml.StatePrep(state.ravel(), wires=wires, validate_norm=False)] + circuit.operations, + [qml.StatePrep(qml.math.ravel(state), wires=wires, validate_norm=False)] + + circuit.operations, circuit.measurements, shots=circuit.shots, ) @@ -942,7 +943,7 @@ def _(original_measurement: ProbabilityMP, measures): # pylint: disable=unused- @combine_measurements_core.register def _(original_measurement: SampleMP, measures): # pylint: disable=unused-argument - """The combined samples of two branches is obtained by concatenating the sample if each branch..""" + """The combined samples of two branches is obtained by concatenating the sample of each branch.""" new_sample = tuple( qml.math.atleast_1d(m[1]) for m in measures.values() if m[0] and not m[1] is tuple() ) diff --git a/tests/devices/qubit/test_simulate.py b/tests/devices/qubit/test_simulate.py index cb28c88648b..85c97acfa9d 100644 --- a/tests/devices/qubit/test_simulate.py +++ b/tests/devices/qubit/test_simulate.py @@ -13,15 +13,36 @@ # limitations under the License. """Unit tests for simulate in devices/qubit.""" +import mcm_utils import numpy as np import pytest +import scipy as sp from dummy_debugger import Debugger from flaky import flaky from stat_utils import fisher_exact_test import pennylane as qml from pennylane.devices.qubit import get_final_state, measure_final_state, simulate -from pennylane.devices.qubit.simulate import _FlexShots, simulate_one_shot_native_mcm +from pennylane.devices.qubit.simulate import ( + TreeTraversalStack, + _FlexShots, + branch_state, + combine_measurements_core, + counts_to_probs, + find_post_processed_mcms, + samples_to_counts, + simulate_one_shot_native_mcm, + simulate_tree_mcm, + split_circuit_at_mcms, +) + +ml_frameworks_list = [ + "numpy", + pytest.param("autograd", marks=pytest.mark.autograd), + pytest.param("jax", marks=pytest.mark.jax), + pytest.param("torch", marks=pytest.mark.torch), + pytest.param("tensorflow", marks=pytest.mark.tf), +] class TestCurrentlyUnsupportedCases: @@ -1182,21 +1203,362 @@ def test_qinfo_tf(self): assert qml.math.allclose(grad5, expected_grads[5]) -ml_frameworks_list = [ - "numpy", - pytest.param("autograd", marks=pytest.mark.autograd), - pytest.param("jax", marks=pytest.mark.jax), - pytest.param("torch", marks=pytest.mark.torch), - pytest.param("tensorflow", marks=pytest.mark.tf), -] +@pytest.mark.unit +class TestTreeTraversalStack: + """Unit tests for TreeTraversalStack""" + + @pytest.mark.parametrize( + "max_depth", + [0, 1, 10, 100], + ) + def test_init_with_depth(self, max_depth): + """Test that TreeTraversalStack is initialized correctly with given ``max_depth``""" + tree_stack = TreeTraversalStack(max_depth) + + assert tree_stack.counts.count(None) == max_depth + assert tree_stack.probs.count(None) == max_depth + assert tree_stack.results_0.count(None) == max_depth + assert tree_stack.results_1.count(None) == max_depth + assert tree_stack.states.count(None) == max_depth + + def test_full_prune_empty_methods(self): + """Test that TreeTraversalStack object's class methods work correctly.""" + + max_depth = 10 + tree_stack = TreeTraversalStack(max_depth) + + np.random.shuffle(r_depths := list(range(max_depth))) + for depth in r_depths: + counts_0 = np.random.randint(1, 9) + counts_1 = 10 - counts_0 + tree_stack.counts[depth] = [counts_0, counts_1] + tree_stack.probs[depth] = [counts_0 / 10, counts_1 / 10] + tree_stack.results_0[depth] = [0] * counts_0 + tree_stack.results_1[depth] = [1] * counts_1 + tree_stack.states[depth] = [np.sqrt(counts_0), np.sqrt(counts_1)] + assert tree_stack.is_full(depth) + + assert tree_stack.counts[depth] == list( + samples_to_counts( + np.array(tree_stack.results_0[depth] + tree_stack.results_1[depth]) + ).values() + ) + assert tree_stack.probs[depth] == list( + counts_to_probs(dict(zip([0, 1], tree_stack.counts[depth]))).values() + ) + + state_vec = np.array(tree_stack.states[depth]).T + state_vec /= np.linalg.norm(tree_stack.states[depth]) + meas, meas_r = qml.measure(0), qml.measure(0, reset=True) + assert np.allclose(branch_state(state_vec, 0, meas.measurements[0]), [1.0, 0.0]) + assert np.allclose(branch_state(state_vec, 1, meas_r.measurements[0]), [1.0, 0.0]) + + tree_stack.prune(depth) + assert tree_stack.any_is_empty(depth) -# pylint:disable=too-few-public-methods @pytest.mark.unit -class TestMidCircuitMeasurements: - """Unit tests for simulating mid-circuit measurements.""" +class TestMidMeasurements: + """Tests for simulating scripts with mid-circuit measurements using the ``simulate_tree_mcm``.""" + + @pytest.mark.parametrize("val", [0, 1]) + def test_basic_mid_meas_circuit(self, val): + """Test execution with a basic circuit with mid-circuit measurements.""" + qs = qml.tape.QuantumScript( + [qml.Hadamard(0), qml.CNOT([0, 1]), qml.measurements.MidMeasureMP(0, postselect=val)], + [qml.expval(qml.X(0)), qml.expval(qml.Z(0))], + ) + result = simulate_tree_mcm(qs) + assert result == (0, (-1.0) ** val) + + def test_basic_mid_meas_circuit_with_reset(self): + """Test execution with a basic circuit with mid-circuit measurements.""" + qs = qml.tape.QuantumScript( + [ + qml.Hadamard(0), + qml.Hadamard(1), + qml.CNOT([0, 1]), + (m0 := qml.measure(0, reset=True)).measurements[0], + qml.Hadamard(0), + qml.CNOT([1, 0]), + ], # equivalent to a circuit that gives equiprobable basis states + [qml.probs(op=m0), qml.probs(op=qml.Z(0)), qml.probs(op=qml.Z(1))], + ) + result = simulate_tree_mcm(qs) + assert qml.math.allclose(result, qml.math.array([0.5, 0.5])) + + # pylint: disable=too-many-arguments + @pytest.mark.parametrize("shots", [None, 5500]) + @pytest.mark.parametrize("postselect", [None, 0]) + @pytest.mark.parametrize("reset", [False, True]) + @pytest.mark.parametrize("measure_f", [qml.counts, qml.expval, qml.probs, qml.sample, qml.var]) + @pytest.mark.parametrize( + "meas_obj", [qml.Y(0), [1], [1, 0], "mcm", "composite_mcm", "mcm_list"] + ) + def test_simple_dynamic_circuit(self, shots, measure_f, postselect, reset, meas_obj): + """Tests that `simulate` can handles a simple dynamic circuit with the following measurements: + + * qml.counts with obs (comp basis or not), single wire, multiple wires (ordered/unordered), MCM, f(MCM), MCM list + * qml.expval with obs (comp basis or not), MCM, f(MCM), MCM list + * qml.probs with obs (comp basis or not), single wire, multiple wires (ordered/unordered), MCM, f(MCM), MCM list + * qml.sample with obs (comp basis or not), single wire, multiple wires (ordered/unordered), MCM, f(MCM), MCM list + * qml.var with obs (comp basis or not), MCM, f(MCM), MCM list + + The above combinations should work for finite shots, shot vectors and post-selecting of either the 0 or 1 branch. + """ + + if ( + isinstance(meas_obj, (qml.X, qml.Z, qml.Y)) + and measure_f in (qml.var,) + and not qml.operation.active_new_opmath() + ): + pytest.xfail( + "The tree-traversal method does not work with legacy opmath with " + "`qml.var` of pauli observables in the circuit." + ) + + if measure_f in (qml.expval, qml.var) and ( + isinstance(meas_obj, list) or meas_obj == "mcm_list" + ): + pytest.skip("Can't use wires/mcm lists with var or expval") + + if measure_f in (qml.counts, qml.sample) and shots is None: + pytest.skip("Can't measure counts/sample in analytic mode (`shots=None`)") + + if measure_f in (qml.probs,) and meas_obj in ["composite_mcm"]: + pytest.skip( + "Cannot use qml.probs() when measuring multiple mid-circuit measurements collected using arithmetic operators." + ) + + qscript = qml.tape.QuantumScript( + [ + qml.RX(np.pi / 2.5, 0), + qml.RZ(np.pi / 4, 0), + (m0 := qml.measure(0, reset=reset)).measurements[0], + qml.ops.op_math.Conditional(m0 == 0, qml.RX(np.pi / 4, 0)), + qml.ops.op_math.Conditional(m0 == 1, qml.RX(-np.pi / 4, 0)), + qml.RX(np.pi / 3, 1), + qml.RZ(np.pi / 4, 1), + (m1 := qml.measure(1, postselect=postselect)).measurements[0], + qml.ops.op_math.Conditional(m1 == 0, qml.RY(np.pi / 4, 1)), + qml.ops.op_math.Conditional(m1 == 1, qml.RY(-np.pi / 4, 1)), + ], + [ + measure_f( + **{ + "wires" if isinstance(meas_obj, list) else "op": ( + ( + m0 + if meas_obj == "mcm" + else (0.5 * m0 + m1 if meas_obj == "composite_mcm" else [m0, m1]) + ) + if isinstance(meas_obj, str) + else meas_obj + ) + } + ) + ], + shots=shots, + ) + + results0 = simulate(qscript, mcm_method="tree-traversal") + + deferred_tapes, deferred_func = qml.defer_measurements(qscript) + results1 = deferred_func([simulate(tape, mcm_method="deferred") for tape in deferred_tapes]) + mcm_utils.validate_measurements(measure_f, shots, results1, results0) + + if shots is not None: + one_shot_tapes, one_shot_func = qml.dynamic_one_shot(qscript) + results2 = one_shot_func( + [simulate(tape, mcm_method="one-shot") for tape in one_shot_tapes] + ) + mcm_utils.validate_measurements(measure_f, shots, results2, results0) + + @pytest.mark.parametrize("shots", [None, 5500, [5500, 5500]]) + @pytest.mark.parametrize("rng", [None, 42, np.array([37])]) + @pytest.mark.parametrize("angles", [(0.123, 0.015), (0.543, 0.057)]) + @pytest.mark.parametrize("measure_f", [qml.probs, qml.sample]) + def test_approx_dynamic_mid_meas_circuit(self, shots, rng, angles, measure_f): + """Test execution of a dynamic circuit with an equivalent static one.""" + + if measure_f in (qml.sample,) and shots is None: + pytest.skip("Can't measure samples in analytic mode (`shots=None`)") + + qs_with_mid_meas = qml.tape.QuantumScript( + [ + qml.Hadamard(0), + qml.Hadamard(1), + qml.CZ([0, 1]), + qml.CNOT([0, 2]), + qml.CNOT([2, 3]), + qml.CZ([1, 3]), + qml.Toffoli([3, 2, 0]), + (m0 := qml.measure(0)).measurements[0], + qml.ops.op_math.Conditional(m0, qml.RZ(angles[0], 1)), + qml.Hadamard(1), + qml.Z(1), + (m1 := qml.measure(1)).measurements[0], + qml.ops.op_math.Conditional(m1, qml.RX(angles[1], 3)), + ], + [measure_f(wires=[0, 1, 2, 3])], + shots=shots, + ) + qs_without_mid_meas = qml.tape.QuantumScript( + [ + qml.Hadamard(0), + qml.Hadamard(1), + qml.CZ([0, 1]), + qml.CNOT([0, 2]), + qml.CNOT([2, 3]), + qml.CZ([1, 3]), + qml.Toffoli([3, 2, 0]), + qml.Hadamard(1), + qml.Z(1), + qml.RX(angles[1], 3), + ], + [measure_f(wires=[0, 1, 2, 3])], + shots=shots, + ) # approximate compiled circuit of the above + res1 = simulate_tree_mcm(qs_with_mid_meas, rng=rng) + res2 = simulate(qs_without_mid_meas, rng=rng) + + if not isinstance(shots, list): + res1, res2 = (res1,), (res2,) + + for rs1, rs2 in zip(res1, res2): + prob_dist1, prob_dist2 = rs1, rs2 + if measure_f in (qml.sample,): + n_wires = rs1.shape[1] + prob_dist1, prob_dist2 = np.zeros(2**n_wires), np.zeros(2**n_wires) + for prob, rs in zip([prob_dist1, prob_dist2], [rs1, rs2]): + index, count = np.unique( + np.packbits(rs, axis=1, bitorder="little").squeeze(), return_counts=True + ) + prob[index] = count + + assert qml.math.allclose( + sp.stats.entropy(prob_dist1 + 1e-12, prob_dist2 + 1e-12), 0.0, atol=5e-2 + ) + + @pytest.mark.parametrize("ml_framework", ml_frameworks_list) + @pytest.mark.parametrize( + "postselect_mode", [None, "hw-like", "pad-invalid-samples", "fill-shots"] + ) + def test_tree_traversal_interface_mcm(self, ml_framework, postselect_mode): + """Test that tree traversal works numerically with different interfaces""" + # pylint:disable = singleton-comparison, import-outside-toplevel + + qscript = qml.tape.QuantumScript( + [ + qml.RX(np.pi / 4, wires=0), + (m0 := qml.measure(0, reset=True)).measurements[0], + qml.RX(np.pi / 4, wires=0), + ], + [qml.sample(qml.Z(0)), qml.sample(m0)], + shots=5500, + ) + + res1, res2 = simulate_tree_mcm(qscript, interface=ml_framework) + + p1 = [qml.math.mean(res1 == -1), qml.math.mean(res1 == 1)] + p2 = [qml.math.mean(res2 == True), qml.math.mean(res2 == False)] + assert qml.math.allclose(qml.math.sum(sp.special.rel_entr(p1, p2)), 0.0, atol=0.05) + + qscript2 = qml.tape.QuantumScript( + [ + qml.RX(np.pi / 4, wires=0), + (m0 := qml.measure(0, postselect=0)).measurements[0], + qml.RX(np.pi / 4, wires=0), + ], + [qml.sample(qml.Z(0))], + shots=5500, + ) + qscript3 = qml.tape.QuantumScript( + [qml.RX(np.pi / 4, wires=0)], [qml.sample(qml.Z(0))], shots=5500 + ) + + res3 = simulate_tree_mcm(qscript2, postselect_mode=postselect_mode) + res4 = simulate(qscript3) + + p3 = [qml.math.mean(res3 == -1), qml.math.mean(res3 == 1)] + p4 = [qml.math.mean(res4 == -1), qml.math.mean(res4 == 1)] + assert qml.math.allclose(qml.math.sum(sp.special.rel_entr(p3, p4)), 0.0, atol=0.05) + + @pytest.mark.parametrize("postselect_mode", ["hw-like", "fill-shots"]) + def test_tree_traversal_postselect_mode(self, postselect_mode): + """Test that invalid shots are discarded if requested""" + + shots = 100 + qscript = qml.tape.QuantumScript( + [ + qml.RX(np.pi / 2, 0), + (m0 := qml.measure(0, postselect=1)).measurements[0], + qml.ops.op_math.Conditional(m0, qml.RZ(1.57, 1)), + ], + [qml.sample(wires=[0, 1])], + shots=shots, + ) + + res = simulate_tree_mcm(qscript, postselect_mode=postselect_mode) + + assert (len(res) < shots) if postselect_mode == "hw-like" else (len(res) == shots) + assert np.all(res != np.iinfo(np.int32).min) + + def test_tree_traversal_deep_circuit(self): + """Test that `simulate_tree_mcm` works with circuits with many mid-circuit measurements""" + + n_circs = 500 + operations = [] + for _ in range(n_circs): + operations.extend( + [ + qml.RX(1.234, 0), + (m0 := qml.measure(0, postselect=1)).measurements[0], + qml.CNOT([0, 1]), + qml.ops.op_math.Conditional(m0, qml.RZ(1.786, 1)), + ] + ) + + qscript = qml.tape.QuantumScript( + operations, + [qml.sample(wires=[0, 1]), qml.counts(wires=[0, 1])], + shots=20, + ) + + mcms = find_post_processed_mcms(qscript) + assert len(mcms) == n_circs + + split_circs = split_circuit_at_mcms(qscript) + assert len(split_circs) == n_circs + 1 + for circ in split_circs: + assert not [o for o in circ.operations if isinstance(o, qml.measurements.MidMeasureMP)] + + res = simulate_tree_mcm(qscript, postselect_mode="fill-shots") + assert len(res[0]) == 20 + assert isinstance(res[1], dict) and sum(list(res[1].values())) == 20 + + @pytest.mark.parametrize( + "measurements, expected", + [ + [(qml.counts(0), {"a": (1, {0: 42}), "b": (2, {1: 58})}), {0: 42, 1: 58}], + [ + (qml.expval(qml.Z(0)), {"a": (1, (0.42, -1)), "b": (2, (0.58, 1))}), + [1.58 / 3, 1 / 3], + ], + [(qml.probs(wires=0), {"a": (1, (0.42, -1)), "b": (2, (0.58, 1))}), [1.58 / 3, 1 / 3]], + [(qml.sample(wires=0), {"a": (1, (0, 1, 0)), "b": (2, (1, 0, 1))}), [0, 1, 0, 1, 0, 1]], + ], + ) + def test_tree_traversal_combine_measurements(self, measurements, expected): + """Test that the measurement value of a given type can be combined""" + print(combine_measurements_core(*measurements)) + combined_measurement = combine_measurements_core(*measurements) + if isinstance(combined_measurement, dict): + assert combined_measurement == expected + else: + assert qml.math.allclose(combined_measurement, expected) - @flaky(max_runs=3, min_passes=2) + @flaky(max_runs=3, min_passes=1) @pytest.mark.parametrize("ml_framework", ml_frameworks_list) @pytest.mark.parametrize( "postselect_mode", [None, "hw-like", "pad-invalid-samples", "fill-shots"] @@ -1211,7 +1573,7 @@ def test_simulate_one_shot_native_mcm(self, ml_framework, postselect_mode): circuit = qml.tape.QuantumScript(q.queue, [qml.expval(qml.Z(0)), qml.sample(m)], shots=[1]) - n_shots = 200 + n_shots = 500 results = [ simulate_one_shot_native_mcm( circuit,