Skip to content

Commit

Permalink
Unit tests for `simulate_one_shot_native_mcm (#6124)
Browse files Browse the repository at this point in the history
Implement a unit test for `simulate_one_shot_native_mcm`

[sc-71560]

---------

Co-authored-by: Ali Asadi <10773383+maliasadi@users.noreply.github.com>
  • Loading branch information
astralcai and maliasadi committed Sep 19, 2024
1 parent cbc5a9d commit 8d8bece
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 3 deletions.
6 changes: 4 additions & 2 deletions pennylane/devices/qubit/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,6 +955,8 @@ def simulate_one_shot_native_mcm(
) -> Result:
"""Simulate a single shot of a single quantum script with native mid-circuit measurements.
Assumes that the circuit has been transformed by `dynamic_one_shot`.
Args:
circuit (QuantumTape): The single circuit to simulate
debugger (_Debugger): The debugger to use
Expand All @@ -968,8 +970,8 @@ def simulate_one_shot_native_mcm(
keep the same number of shots. Default is ``None``.
Returns:
tuple(TensorLike): The results of the simulation
dict: The mid-circuit measurement results of the simulation
Result: The results of the simulation
"""
prng_key = execution_kwargs.pop("prng_key", None)

Expand Down
75 changes: 74 additions & 1 deletion tests/devices/qubit/test_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
import numpy as np
import pytest
from dummy_debugger import Debugger
from flaky import flaky
from stat_utils import fisher_exact_test

import pennylane as qml
from pennylane.devices.qubit import get_final_state, measure_final_state, simulate
from pennylane.devices.qubit.simulate import _FlexShots
from pennylane.devices.qubit.simulate import _FlexShots, simulate_one_shot_native_mcm


class TestCurrentlyUnsupportedCases:
Expand Down Expand Up @@ -1178,3 +1180,74 @@ def test_qinfo_tf(self):

grad5 = grad_tape.jacobian(results[5], phi)
assert qml.math.allclose(grad5, expected_grads[5])


ml_frameworks_list = [
"numpy",
pytest.param("autograd", marks=pytest.mark.autograd),
pytest.param("jax", marks=pytest.mark.jax),
pytest.param("torch", marks=pytest.mark.torch),
pytest.param("tensorflow", marks=pytest.mark.tf),
]


# pylint:disable=too-few-public-methods
@pytest.mark.unit
class TestMidCircuitMeasurements:
"""Unit tests for simulating mid-circuit measurements."""

@flaky(max_runs=3, min_passes=2)
@pytest.mark.parametrize("ml_framework", ml_frameworks_list)
@pytest.mark.parametrize(
"postselect_mode", [None, "hw-like", "pad-invalid-samples", "fill-shots"]
)
def test_simulate_one_shot_native_mcm(self, ml_framework, postselect_mode):
"""Unit tests for simulate_one_shot_native_mcm"""

with qml.queuing.AnnotatedQueue() as q:
qml.RX(np.pi / 4, wires=0)
m = qml.measure(wires=0, postselect=0)
qml.RX(np.pi / 4, wires=0)

circuit = qml.tape.QuantumScript(q.queue, [qml.expval(qml.Z(0)), qml.sample(m)], shots=[1])

n_shots = 200
results = [
simulate_one_shot_native_mcm(
circuit,
n_shots,
interface=ml_framework,
postselect_mode=postselect_mode,
)
for _ in range(n_shots)
]
terminal_results, mcm_results = zip(*results)

if postselect_mode == "fill-shots":
assert all(ms == 0 for ms in mcm_results)
equivalent_tape = qml.tape.QuantumScript(
[qml.RX(np.pi / 4, wires=0)], [qml.expval(qml.Z(0))], shots=n_shots
)
expected_sample = simulate(equivalent_tape)
fisher_exact_test(terminal_results, expected_sample, outcomes=(-1, 1))

else:
equivalent_tape = qml.tape.QuantumScript(
[qml.RX(np.pi / 4, wires=0)], [qml.sample(wires=0)], shots=n_shots
)
expected_result = simulate(equivalent_tape)
fisher_exact_test(mcm_results, expected_result)

subset = [ts for ms, ts in zip(mcm_results, terminal_results) if ms == 0]
equivalent_tape = qml.tape.QuantumScript(
[qml.RX(np.pi / 4, wires=0)], [qml.expval(qml.Z(0))], shots=n_shots
)
expected_sample = simulate(equivalent_tape)
fisher_exact_test(subset, expected_sample, outcomes=(-1, 1))

subset = [ts for ms, ts in zip(mcm_results, terminal_results) if ms == 1]
equivalent_tape = qml.tape.QuantumScript(
[qml.X(0), qml.RX(np.pi / 4, wires=0)], [qml.expval(qml.Z(0))], shots=n_shots
)
expected_sample = simulate(equivalent_tape)
fisher_exact_test(subset, expected_sample, outcomes=(-1, 1))
32 changes: 32 additions & 0 deletions tests/helpers/stat_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Helper functions for testing stochastic processes.
"""
import numpy as np
from scipy.stats import fisher_exact


def fisher_exact_test(actual, expected, outcomes=(0, 1), threshold=0.1):
"""Checks that a binary sample matches the expected distribution using the Fisher exact test."""

actual, expected = np.asarray(actual), np.asarray(expected)
contingency_table = np.array(
[
[np.sum(actual == outcomes[0]), np.sum(actual == outcomes[1])],
[np.sum(expected == outcomes[0]), np.sum(expected == outcomes[1])],
]
)
_, p_value = fisher_exact(contingency_table)
assert p_value > threshold, "The sample does not match the expected distribution."

0 comments on commit 8d8bece

Please sign in to comment.