diff --git a/pennylane/devices/qubit/simulate.py b/pennylane/devices/qubit/simulate.py index 89c041b8f3e..f6ba32d4c7f 100644 --- a/pennylane/devices/qubit/simulate.py +++ b/pennylane/devices/qubit/simulate.py @@ -955,6 +955,8 @@ def simulate_one_shot_native_mcm( ) -> Result: """Simulate a single shot of a single quantum script with native mid-circuit measurements. + Assumes that the circuit has been transformed by `dynamic_one_shot`. + Args: circuit (QuantumTape): The single circuit to simulate debugger (_Debugger): The debugger to use @@ -968,8 +970,8 @@ def simulate_one_shot_native_mcm( keep the same number of shots. Default is ``None``. Returns: - tuple(TensorLike): The results of the simulation - dict: The mid-circuit measurement results of the simulation + Result: The results of the simulation + """ prng_key = execution_kwargs.pop("prng_key", None) diff --git a/tests/devices/qubit/test_simulate.py b/tests/devices/qubit/test_simulate.py index 4dce5afd4c5..cb28c88648b 100644 --- a/tests/devices/qubit/test_simulate.py +++ b/tests/devices/qubit/test_simulate.py @@ -16,10 +16,12 @@ import numpy as np import pytest from dummy_debugger import Debugger +from flaky import flaky +from stat_utils import fisher_exact_test import pennylane as qml from pennylane.devices.qubit import get_final_state, measure_final_state, simulate -from pennylane.devices.qubit.simulate import _FlexShots +from pennylane.devices.qubit.simulate import _FlexShots, simulate_one_shot_native_mcm class TestCurrentlyUnsupportedCases: @@ -1178,3 +1180,74 @@ def test_qinfo_tf(self): grad5 = grad_tape.jacobian(results[5], phi) assert qml.math.allclose(grad5, expected_grads[5]) + + +ml_frameworks_list = [ + "numpy", + pytest.param("autograd", marks=pytest.mark.autograd), + pytest.param("jax", marks=pytest.mark.jax), + pytest.param("torch", marks=pytest.mark.torch), + pytest.param("tensorflow", marks=pytest.mark.tf), +] + + +# pylint:disable=too-few-public-methods +@pytest.mark.unit +class TestMidCircuitMeasurements: + """Unit tests for simulating mid-circuit measurements.""" + + @flaky(max_runs=3, min_passes=2) + @pytest.mark.parametrize("ml_framework", ml_frameworks_list) + @pytest.mark.parametrize( + "postselect_mode", [None, "hw-like", "pad-invalid-samples", "fill-shots"] + ) + def test_simulate_one_shot_native_mcm(self, ml_framework, postselect_mode): + """Unit tests for simulate_one_shot_native_mcm""" + + with qml.queuing.AnnotatedQueue() as q: + qml.RX(np.pi / 4, wires=0) + m = qml.measure(wires=0, postselect=0) + qml.RX(np.pi / 4, wires=0) + + circuit = qml.tape.QuantumScript(q.queue, [qml.expval(qml.Z(0)), qml.sample(m)], shots=[1]) + + n_shots = 200 + results = [ + simulate_one_shot_native_mcm( + circuit, + n_shots, + interface=ml_framework, + postselect_mode=postselect_mode, + ) + for _ in range(n_shots) + ] + terminal_results, mcm_results = zip(*results) + + if postselect_mode == "fill-shots": + assert all(ms == 0 for ms in mcm_results) + equivalent_tape = qml.tape.QuantumScript( + [qml.RX(np.pi / 4, wires=0)], [qml.expval(qml.Z(0))], shots=n_shots + ) + expected_sample = simulate(equivalent_tape) + fisher_exact_test(terminal_results, expected_sample, outcomes=(-1, 1)) + + else: + equivalent_tape = qml.tape.QuantumScript( + [qml.RX(np.pi / 4, wires=0)], [qml.sample(wires=0)], shots=n_shots + ) + expected_result = simulate(equivalent_tape) + fisher_exact_test(mcm_results, expected_result) + + subset = [ts for ms, ts in zip(mcm_results, terminal_results) if ms == 0] + equivalent_tape = qml.tape.QuantumScript( + [qml.RX(np.pi / 4, wires=0)], [qml.expval(qml.Z(0))], shots=n_shots + ) + expected_sample = simulate(equivalent_tape) + fisher_exact_test(subset, expected_sample, outcomes=(-1, 1)) + + subset = [ts for ms, ts in zip(mcm_results, terminal_results) if ms == 1] + equivalent_tape = qml.tape.QuantumScript( + [qml.X(0), qml.RX(np.pi / 4, wires=0)], [qml.expval(qml.Z(0))], shots=n_shots + ) + expected_sample = simulate(equivalent_tape) + fisher_exact_test(subset, expected_sample, outcomes=(-1, 1)) diff --git a/tests/helpers/stat_utils.py b/tests/helpers/stat_utils.py new file mode 100644 index 00000000000..e0699df6eca --- /dev/null +++ b/tests/helpers/stat_utils.py @@ -0,0 +1,32 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Helper functions for testing stochastic processes. +""" +import numpy as np +from scipy.stats import fisher_exact + + +def fisher_exact_test(actual, expected, outcomes=(0, 1), threshold=0.1): + """Checks that a binary sample matches the expected distribution using the Fisher exact test.""" + + actual, expected = np.asarray(actual), np.asarray(expected) + contingency_table = np.array( + [ + [np.sum(actual == outcomes[0]), np.sum(actual == outcomes[1])], + [np.sum(expected == outcomes[0]), np.sum(expected == outcomes[1])], + ] + ) + _, p_value = fisher_exact(contingency_table) + assert p_value > threshold, "The sample does not match the expected distribution."