From 34a8a1fbf0d9e71f3f2953b5c7128fe3a73e8a8d Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Mon, 16 Sep 2024 11:03:03 -0400 Subject: [PATCH] use fisher exact test --- tests/devices/qubit/test_simulate.py | 47 +++++++++++++++++++++------- 1 file changed, 35 insertions(+), 12 deletions(-) diff --git a/tests/devices/qubit/test_simulate.py b/tests/devices/qubit/test_simulate.py index a281dd689e9..30abb4a0afa 100644 --- a/tests/devices/qubit/test_simulate.py +++ b/tests/devices/qubit/test_simulate.py @@ -16,7 +16,7 @@ import numpy as np import pytest from dummy_debugger import Debugger -from scipy.stats import ttest_1samp +from scipy.stats import ttest_1samp, fisher_exact import pennylane as qml from pennylane.devices.qubit import get_final_state, measure_final_state, simulate @@ -1190,6 +1190,20 @@ def test_qinfo_tf(self): ] +def verify_binary_sample(actual, expected, outcomes=(0, 1), threshold=0.1): + """Checks that a binary sample matches the expected distribution using the Fisher exact test.""" + + actual, expected = np.asarray(actual), np.asarray(expected) + contingency_table = np.array( + [ + [np.sum(actual == outcomes[0]), np.sum(actual == outcomes[1])], + [np.sum(expected == outcomes[0]), np.sum(expected == outcomes[1])], + ] + ) + _, p_value = fisher_exact(contingency_table) + assert p_value > threshold, "The sample does not match the expected distribution." + + # pylint:disable=too-few-public-methods @pytest.mark.unit class TestMidCircuitMeasurements: @@ -1228,20 +1242,29 @@ def test_simulate_one_shot_native_mcm(self, ml_framework, postselect_mode): if postselect_mode == "fill-shots": assert all(ms == 0 for ms in mcm_results) - ttest_result = ttest_1samp(terminal_results, np.cos(np.pi / 4)) - assert ttest_result.pvalue > 0.05 + equivalent_tape = qml.tape.QuantumScript( + [qml.RX(np.pi / 4, wires=0)], [qml.expval(qml.Z(0))], shots=n_shots + ) + expected_sample = simulate(equivalent_tape, rng=rng) + verify_binary_sample(terminal_results, expected_sample, outcomes=(-1, 1)) else: - expected_mcm_average = (1 - np.cos(np.pi / 4)) / 2 - ttest_result = ttest_1samp(mcm_results, expected_mcm_average) - assert ttest_result.pvalue > 0.05 + equivalent_tape = qml.tape.QuantumScript( + [qml.RX(np.pi / 4, wires=0)], [qml.sample(wires=0)], shots=n_shots + ) + expected_result = simulate(equivalent_tape, rng=rng) + verify_binary_sample(mcm_results, expected_result) subset = [ts for ms, ts in zip(mcm_results, terminal_results) if ms == 0] - expected_average = np.cos(np.pi / 4) - ttest_result = ttest_1samp(subset, expected_average) - assert ttest_result.pvalue > 0.05 + equivalent_tape = qml.tape.QuantumScript( + [qml.RX(np.pi / 4, wires=0)], [qml.expval(qml.Z(0))], shots=n_shots + ) + expected_sample = simulate(equivalent_tape, rng=rng) + verify_binary_sample(subset, expected_sample, outcomes=(-1, 1)) subset = [ts for ms, ts in zip(mcm_results, terminal_results) if ms == 1] - expected_average = -np.cos(np.pi / 4) - ttest_result = ttest_1samp(subset, expected_average) - assert ttest_result.pvalue > 0.05 + equivalent_tape = qml.tape.QuantumScript( + [qml.X(0), qml.RX(np.pi / 4, wires=0)], [qml.expval(qml.Z(0))], shots=n_shots + ) + expected_sample = simulate(equivalent_tape, rng=rng) + verify_binary_sample(subset, expected_sample, outcomes=(-1, 1))