Skip to content

Commit

Permalink
use fisher exact test
Browse files Browse the repository at this point in the history
  • Loading branch information
astralcai committed Sep 16, 2024
1 parent ddac214 commit 34a8a1f
Showing 1 changed file with 35 additions and 12 deletions.
47 changes: 35 additions & 12 deletions tests/devices/qubit/test_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import numpy as np
import pytest
from dummy_debugger import Debugger
from scipy.stats import ttest_1samp
from scipy.stats import ttest_1samp, fisher_exact

import pennylane as qml
from pennylane.devices.qubit import get_final_state, measure_final_state, simulate
Expand Down Expand Up @@ -1190,6 +1190,20 @@ def test_qinfo_tf(self):
]


def verify_binary_sample(actual, expected, outcomes=(0, 1), threshold=0.1):
"""Checks that a binary sample matches the expected distribution using the Fisher exact test."""

actual, expected = np.asarray(actual), np.asarray(expected)
contingency_table = np.array(
[
[np.sum(actual == outcomes[0]), np.sum(actual == outcomes[1])],
[np.sum(expected == outcomes[0]), np.sum(expected == outcomes[1])],
]
)
_, p_value = fisher_exact(contingency_table)
assert p_value > threshold, "The sample does not match the expected distribution."


# pylint:disable=too-few-public-methods
@pytest.mark.unit
class TestMidCircuitMeasurements:
Expand Down Expand Up @@ -1228,20 +1242,29 @@ def test_simulate_one_shot_native_mcm(self, ml_framework, postselect_mode):

if postselect_mode == "fill-shots":
assert all(ms == 0 for ms in mcm_results)
ttest_result = ttest_1samp(terminal_results, np.cos(np.pi / 4))
assert ttest_result.pvalue > 0.05
equivalent_tape = qml.tape.QuantumScript(
[qml.RX(np.pi / 4, wires=0)], [qml.expval(qml.Z(0))], shots=n_shots
)
expected_sample = simulate(equivalent_tape, rng=rng)
verify_binary_sample(terminal_results, expected_sample, outcomes=(-1, 1))

else:
expected_mcm_average = (1 - np.cos(np.pi / 4)) / 2
ttest_result = ttest_1samp(mcm_results, expected_mcm_average)
assert ttest_result.pvalue > 0.05
equivalent_tape = qml.tape.QuantumScript(
[qml.RX(np.pi / 4, wires=0)], [qml.sample(wires=0)], shots=n_shots
)
expected_result = simulate(equivalent_tape, rng=rng)
verify_binary_sample(mcm_results, expected_result)

subset = [ts for ms, ts in zip(mcm_results, terminal_results) if ms == 0]
expected_average = np.cos(np.pi / 4)
ttest_result = ttest_1samp(subset, expected_average)
assert ttest_result.pvalue > 0.05
equivalent_tape = qml.tape.QuantumScript(
[qml.RX(np.pi / 4, wires=0)], [qml.expval(qml.Z(0))], shots=n_shots
)
expected_sample = simulate(equivalent_tape, rng=rng)
verify_binary_sample(subset, expected_sample, outcomes=(-1, 1))

subset = [ts for ms, ts in zip(mcm_results, terminal_results) if ms == 1]
expected_average = -np.cos(np.pi / 4)
ttest_result = ttest_1samp(subset, expected_average)
assert ttest_result.pvalue > 0.05
equivalent_tape = qml.tape.QuantumScript(
[qml.X(0), qml.RX(np.pi / 4, wires=0)], [qml.expval(qml.Z(0))], shots=n_shots
)
expected_sample = simulate(equivalent_tape, rng=rng)
verify_binary_sample(subset, expected_sample, outcomes=(-1, 1))

0 comments on commit 34a8a1f

Please sign in to comment.