diff --git a/tests/devices/default_qubit/test_default_qubit_native_mcm.py b/tests/devices/default_qubit/test_default_qubit_native_mcm.py index 5f63ec7fe2d..80467c66d3b 100644 --- a/tests/devices/default_qubit/test_default_qubit_native_mcm.py +++ b/tests/devices/default_qubit/test_default_qubit_native_mcm.py @@ -20,7 +20,7 @@ import pytest import pennylane as qml -from pennylane.devices.qubit.apply_operation import MidMeasureMP, apply_mid_measure +from pennylane.devices.qubit.apply_operation import MidMeasureMP from pennylane.devices.qubit.simulate import combine_measurements_core, measurement_with_no_shots from pennylane.transforms.dynamic_one_shot import fill_in_value @@ -53,16 +53,9 @@ def test_measurement_with_no_shots(): assert all(np.isnan(probs).tolist()) -def test_apply_mid_measure(): - """Test that apply_mid_measure raises if applied to a batched state.""" - with pytest.raises(ValueError, match="MidMeasureMP cannot be applied to batched states."): - _ = apply_mid_measure( - MidMeasureMP(0), np.zeros((2, 2)), is_state_batched=True, mid_measurements={} - ) - - +@pytest.mark.parametrize("obs", ["mid-meas", "pauli"]) @pytest.mark.parametrize("mcm_method", ["one-shot", "tree-traversal"]) -def test_all_invalid_shots_circuit(mcm_method): +def test_all_invalid_shots_circuit(obs, mcm_method): """Test that circuits in which all shots mismatch with post-selection conditions return the same answer as ``defer_measurements``.""" dev = get_device() dev_shots = get_device(shots=10) @@ -71,9 +64,13 @@ def circuit_op(): m = qml.measure(0, postselect=1) qml.cond(m, qml.PauliX)(1) return ( - qml.expval(op=qml.PauliZ(1)), - qml.probs(op=qml.PauliY(0) @ qml.PauliZ(1)), - qml.var(op=qml.PauliZ(1)), + ( + qml.expval(op=qml.PauliZ(1)), + qml.probs(op=qml.PauliY(0) @ qml.PauliZ(1)), + qml.var(op=qml.PauliZ(1)), + ) + if obs == "pauli" + else (qml.expval(op=m), qml.probs(op=m), qml.var(op=m)) ) res1 = qml.QNode(circuit_op, dev, mcm_method="deferred")() @@ -84,19 +81,6 @@ def circuit_op(): assert np.all(np.isnan(r1)) assert np.all(np.isnan(r2)) - def circuit_mcm(): - m = qml.measure(0, postselect=1) - qml.cond(m, qml.PauliX)(1) - return qml.expval(op=m), qml.probs(op=m), qml.var(op=m) - - res1 = qml.QNode(circuit_mcm, dev, mcm_method="deferred")() - res2 = qml.QNode(circuit_mcm, dev_shots, mcm_method=mcm_method)() - for r1, r2 in zip(res1, res2): - if isinstance(r1, Sequence): - assert len(r1) == len(r2) - assert np.all(np.isnan(r1)) - assert np.all(np.isnan(r2)) - @pytest.mark.parametrize("mcm_method", ["one-shot", "tree-traversal"]) def test_unsupported_measurement(mcm_method): @@ -118,41 +102,6 @@ def func(x, y): func(*params) -@pytest.mark.parametrize("postselect_mode", ["hw-like", "fill-shots"]) -def test_tree_traversal_postselect_mode(postselect_mode): - """Test that invalid shots are discarded if requested""" - shots = 100 - dev = qml.device("default.qubit", shots=shots) - - @qml.qnode(dev, mcm_method="tree-traversal", postselect_mode=postselect_mode) - def f(x): - qml.RX(x, 0) - _ = qml.measure(0, postselect=1) - return qml.sample(wires=[0, 1]) - - res = f(np.pi / 2) - - if postselect_mode == "hw-like": - assert len(res) < shots - else: - assert len(res) == shots - assert np.all(res != np.iinfo(np.int32).min) - - -def test_deep_circuit(): - """Tests that DefaultQubit handles a circuit with more than 1000 mid-circuit measurements.""" - - dev = qml.device("default.qubit", shots=10) - - def func(x): - for _ in range(600): - qml.RX(x, wires=0) - m0 = qml.measure(0) - return qml.expval(qml.PauliY(0)), qml.expval(m0) - - _ = qml.QNode(func, dev, mcm_method="tree-traversal")(0.1234) - - # pylint: disable=unused-argument def obs_tape(x, y, z, reset=False, postselect=None): qml.RX(x, 0) @@ -174,69 +123,16 @@ def obs_tape(x, y, z, reset=False, postselect=None): @pytest.mark.parametrize("mcm_method", ["one-shot", "tree-traversal"]) @pytest.mark.parametrize("shots", [None, 5500, [5500, 5501]]) -@pytest.mark.parametrize("postselect", [None, 0, 1]) -@pytest.mark.parametrize("measure_f", [qml.counts, qml.expval, qml.probs, qml.sample, qml.var]) @pytest.mark.parametrize( - "meas_obj", - [qml.PauliZ(0), qml.PauliY(1), [0], [0, 1], [1, 0], "mcm", "composite_mcm", "mcm_list"], + "params", + [ + [np.pi / 2.5, np.pi / 3, -np.pi / 3.5], + [[np.pi / 2.5, -np.pi / 3.5], [np.pi / 4.5, np.pi / 3.2], [np.pi, -np.pi / 0.5]], + ], ) -def test_simple_dynamic_circuit(mcm_method, shots, measure_f, postselect, meas_obj): - """Tests that DefaultQubit handles a simple dynamic circuit with the following measurements: - - * qml.counts with obs (comp basis or not), single wire, multiple wires (ordered/unordered), MCM, f(MCM), MCM list - * qml.expval with obs (comp basis or not), MCM, f(MCM), MCM list - * qml.probs with obs (comp basis or not), single wire, multiple wires (ordered/unordered), MCM, f(MCM), MCM list - * qml.sample with obs (comp basis or not), single wire, multiple wires (ordered/unordered), MCM, f(MCM), MCM list - * qml.var with obs (comp basis or not), MCM, f(MCM), MCM list - - The above combinations should work for finite shots, shot vectors and post-selecting of either the 0 or 1 branch. - """ - - if ( - isinstance(meas_obj, (qml.Z, qml.Y)) - and measure_f == qml.var - and mcm_method == "tree-traversal" - and not qml.operation.active_new_opmath() - ): - pytest.xfail( - "The tree-traversal method does not work with legacy opmath with " - "`qml.var` of pauli observables in the circuit." - ) - - if mcm_method == "one-shot" and shots is None: - pytest.skip("`mcm_method='one-shot'` is incompatible with analytic mode (`shots=None`)") - - if measure_f in (qml.expval, qml.var) and ( - isinstance(meas_obj, list) or meas_obj == "mcm_list" - ): - pytest.skip("Can't use wires/mcm lists with var or expval") - - if measure_f in (qml.counts, qml.sample) and shots is None: - pytest.skip("Can't measure counts/sample in analytic mode (`shots=None`)") - - dev = get_device(shots=shots) - params = [np.pi / 2.5, np.pi / 3, -np.pi / 3.5] - - def func(x, y, z): - m0, m1 = obs_tape(x, y, z, postselect=postselect) - mid_measure = ( - m0 if meas_obj == "mcm" else (0.5 * m0 if meas_obj == "composite_mcm" else [m0, m1]) - ) - measurement_key = "wires" if isinstance(meas_obj, list) else "op" - measurement_value = mid_measure if isinstance(meas_obj, str) else meas_obj - return measure_f(**{measurement_key: measurement_value}) - - results0 = qml.QNode(func, dev, mcm_method=mcm_method)(*params) - results1 = qml.QNode(func, dev, mcm_method="deferred")(*params) - - mcm_utils.validate_measurements(measure_f, shots, results1, results0) - - -@pytest.mark.parametrize("mcm_method", ["one-shot", "tree-traversal"]) -@pytest.mark.parametrize("shots", [None, 5000]) @pytest.mark.parametrize("postselect", [None, 0, 1]) @pytest.mark.parametrize("reset", [False, True]) -def test_multiple_measurements_and_reset(mcm_method, shots, postselect, reset): +def test_multiple_measurements_and_reset(mcm_method, shots, params, postselect, reset): """Tests that DefaultQubit handles a circuit with a single mid-circuit measurement with reset and a conditional gate. Multiple measurements of the mid-circuit measurement value are performed. This function also tests `reset` parametrizing over the parameter.""" @@ -250,8 +146,11 @@ def test_multiple_measurements_and_reset(mcm_method, shots, postselect, reset): if mcm_method == "one-shot" and shots is None: pytest.skip("`mcm_method='one-shot'` is incompatible with analytic mode (`shots=None`)") + batch_size = len(params[0]) if isinstance(params[0], list) else None + if batch_size is not None and shots is not None and postselect is not None: + pytest.skip("Postselection with samples doesn't work with broadcasting") + dev = get_device(shots=shots) - params = [np.pi / 2.5, np.pi / 3, -np.pi / 3.5] obs = qml.PauliY(1) state = qml.math.zeros((4,)) state[0] = 1.0 @@ -286,27 +185,34 @@ def func(x, y, z): if shots is None else [qml.counts, qml.expval, qml.probs, qml.sample, qml.var] ) - for measure_f, r1, r0 in zip(measurements, results1, results0): - mcm_utils.validate_measurements(measure_f, shots, r1, r0) + + if not isinstance(shots, list): + shots, results0, results1 = [shots], [results0], [results1] + + for shot, res1, res0 in zip(shots, results1, results0): + for measure_f, r1, r0 in zip(measurements, res1, res0): + if shots is None and measure_f in [qml.expval, qml.probs] and batch_size is not None: + r0 = qml.math.squeeze(r0) + mcm_utils.validate_measurements(measure_f, shot, r1, r0, batch_size=batch_size) @pytest.mark.parametrize("mcm_method", ["one-shot", "tree-traversal"]) -@pytest.mark.parametrize("shots", [None, 3000]) +@pytest.mark.parametrize("shots", [None, 5000, [5000, 5001]]) @pytest.mark.parametrize( - "mcm_f", + "mcm_name, mcm_func", [ - lambda x: x * -1, - lambda x: x * 1, - lambda x: x * 2, - lambda x: 1 - x, - lambda x: x + 1, - lambda x: x & 3, - "mix", - "list", + ("single", lambda x: x * -1), + ("single", lambda x: x * 2), + ("single", lambda x: 1 - x), + ("single", lambda x: x & 3), + ("mix", lambda x, y: x == y), + ("mix", lambda x, y: 4 * x + 2 * y), + ("all", lambda x, y, z: [x, y, z]), + ("all", lambda x, y, z: (x - 2 * y) * z + 7), ], ) @pytest.mark.parametrize("measure_f", [qml.counts, qml.expval, qml.probs, qml.sample, qml.var]) -def test_composite_mcms(mcm_method, shots, mcm_f, measure_f): +def test_composite_mcms(mcm_method, shots, mcm_name, mcm_func, measure_f): """Tests that DefaultQubit handles a circuit with a composite mid-circuit measurement and a conditional gate. A single measurement of a composite mid-circuit measurement is performed at the end.""" @@ -317,20 +223,20 @@ def test_composite_mcms(mcm_method, shots, mcm_f, measure_f): if measure_f in (qml.counts, qml.sample) and shots is None: pytest.skip("Can't measure counts/sample in analytic mode (`shots=None`)") - if measure_f in (qml.expval, qml.var) and (mcm_f in ("list", "mix")): + if measure_f in (qml.expval, qml.var) and mcm_name in ["mix", "all"]: pytest.skip( "expval/var does not support measuring sequences of measurements or observables." ) - if measure_f == qml.probs and mcm_f == "mix": + if measure_f in (qml.probs,) and mcm_name in ["mix", "all"]: pytest.skip( "Cannot use qml.probs() when measuring multiple mid-circuit measurements collected using arithmetic operators." ) dev = get_device(shots=shots) - param = np.pi / 3 + param = qml.numpy.array([np.pi / 3, np.pi / 6]) - def func(x): + def func(x, y): qml.RX(x, 0) m0 = qml.measure(0) qml.RX(0.5 * x, 1) @@ -338,52 +244,18 @@ def func(x): qml.cond((m0 + m1) == 2, qml.RY)(2.0 * x, 0) m2 = qml.measure(0) obs = ( - (m0 - 2 * m1) * m2 + 7 - if mcm_f == "mix" - else ([m0, m1, m2] if mcm_f == "list" else mcm_f(m2)) + mcm_func(m2) + if mcm_name == "single" + else (mcm_func(m0, m1) if mcm_name == "mix" else mcm_func(m0, m1, m2)) ) return measure_f(op=obs) - results0 = qml.QNode(func, dev, mcm_method=mcm_method)(param) - results1 = qml.QNode(func, dev, mcm_method="deferred")(param) + results0 = qml.QNode(func, dev, mcm_method=mcm_method)(*param) + results1 = qml.QNode(func, dev, mcm_method="deferred")(*param) mcm_utils.validate_measurements(measure_f, shots, results1, results0) -@pytest.mark.parametrize("mcm_method", ["one-shot", "tree-traversal"]) -@pytest.mark.parametrize( - "mcm_f", - [ - lambda x, y: x + y, - lambda x, y: x - 7 * y, - lambda x, y: x & y, - lambda x, y: x == y, - lambda x, y: 4.0 * x + 2.0 * y, - ], -) -def test_counts_return_type(mcm_method, mcm_f): - """Tests that DefaultQubit returns the same keys for ``qml.counts`` measurements with ``dynamic_one_shot`` and ``defer_measurements``.""" - - shots = 500 - - dev = get_device(shots=shots) - param = np.pi / 3 - - def func(x): - qml.RX(x, 0) - m0 = qml.measure(0) - qml.RX(0.5 * x, 1) - m1 = qml.measure(1) - qml.cond((m0 + m1) == 2, qml.RY)(2.0 * x, 0) - return qml.counts(op=mcm_f(m0, m1)) - - results0 = qml.QNode(func, dev, mcm_method=mcm_method)(param) - results1 = qml.QNode(func, dev, mcm_method="deferred")(param) - - for r1, r0 in zip(results1.keys(), results0.keys()): - assert r1 == r0 - - @pytest.mark.parametrize("shots", [5000]) @pytest.mark.parametrize("postselect", [None, 0, 1]) @pytest.mark.parametrize("reset", [False, True]) @@ -421,44 +293,6 @@ def func(x, y): assert np.allclose(grad1, grad2, atol=0.01, rtol=0.3) -@pytest.mark.parametrize("mcm_method", ["one-shot", "tree-traversal"]) -@pytest.mark.parametrize("shots", [None, 5500, [5500, 5501]]) -@pytest.mark.parametrize("postselect", [None, 0]) -@pytest.mark.parametrize("measure_f", [qml.counts, qml.expval, qml.probs, qml.sample]) -def test_broadcasting_qnode(mcm_method, shots, postselect, measure_f): - """Test that executing qnodes with broadcasting works as expected""" - - if mcm_method == "one-shot" and shots is None: - pytest.skip("`mcm_method='one-shot'` is incompatible with analytic mode (`shots=None`)") - - if measure_f in (qml.counts, qml.sample) and shots is None: - pytest.skip("Can't measure counts/sample in analytic mode (`shots=None`)") - - if measure_f is qml.sample and postselect is not None: - pytest.skip("Postselection with samples doesn't work with broadcasting") - - dev = get_device(shots=shots) - param = [[np.pi / 3, np.pi / 4], [np.pi / 6, 2 * np.pi / 3]] - obs = qml.PauliZ(0) @ qml.PauliZ(1) - - def func(x, y): - obs_tape(x, y, None, postselect=postselect) - return measure_f(op=obs) - - results0 = qml.QNode(func, dev, mcm_method=mcm_method)(*param) - results1 = qml.QNode(func, dev, mcm_method="deferred")(*param) - - mcm_utils.validate_measurements(measure_f, shots, results1, results0, batch_size=2) - - if measure_f is qml.sample and postselect is None: - for i in range(2): # batch_size - if isinstance(shots, list): - for s, r1, r2 in zip(shots, results1, results0): - assert len(r1[i]) == len(r2[i]) == s - else: - assert len(results1[i]) == len(results0[i]) == shots - - @pytest.mark.parametrize("mcm_method", ["one-shot", "tree-traversal"]) def test_sample_with_broadcasting_and_postselection_error(mcm_method): """Test that an error is raised if returning qml.sample if postselecting with broadcasting"""