Skip to content

Commit

Permalink
Add tests for MidMeasureMP and Conditional dispatches (#6118)
Browse files Browse the repository at this point in the history
**Context:** Adds unit tests for `apply_operation`'s dispatches for
`MidMeasureMP` and `Conditional`

**Description of the Change:** Includes new tests in
`test_apply_operation.py`

**Benefits:** More reliance on unit-tests for catching bugs instead of
expensive system-level tests.

**Possible Drawbacks:** N/A

**Related GitHub Issues:** [sc-71559]
  • Loading branch information
obliviateandsurrender committed Sep 6, 2024
1 parent 40195fd commit df63953
Showing 1 changed file with 125 additions and 0 deletions.
125 changes: 125 additions & 0 deletions tests/devices/qubit/test_apply_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@

methods = [apply_operation_einsum, apply_operation_tensordot, apply_operation]

# pylint: disable=import-outside-toplevel,unsubscriptable-object,arguments-differ


def test_custom_operator_with_matrix():
"""Test that apply_operation works with any operation that defines a matrix."""
Expand All @@ -53,6 +55,8 @@ def test_custom_operator_with_matrix():

# pylint: disable=too-few-public-methods
class CustomOp(Operation):
"""Custom Operation"""

num_wires = 1

def matrix(self):
Expand Down Expand Up @@ -294,6 +298,8 @@ def f2(params, t):

@pytest.mark.jax
class TestApplyParametrizedEvolution:
"""Test that apply_operation works with ParametrizedEvolution"""

@pytest.mark.parametrize("method", methods)
def test_parameterized_evolution_time_independent(self, method):
"""Test that applying a ParametrizedEvolution gives the expected state
Expand Down Expand Up @@ -651,6 +657,7 @@ class TestRXCalcGrad:
)

def compare_expected_result(self, phi, state, new_state, g):
"""Compares the new state against the expected state"""
expected0 = np.cos(phi / 2) * state[0, :, :] + -1j * np.sin(phi / 2) * state[1, :, :]
expected1 = -1j * np.sin(phi / 2) * state[0, :, :] + np.cos(phi / 2) * state[1, :, :]

Expand Down Expand Up @@ -1267,3 +1274,121 @@ def circuit(init_state):
results = circuit(tf.Variable(states))
assert qml.math.shape(results) == (3, 256)
assert np.array_equal(results[:, 128], [-1.0 + 0.0j] * 3)


# pylint: disable=too-few-public-methods
class TestConditionalsAndMidMeasure:
"""Test dispatching for mid-circuit measurements and conditionals."""

@pytest.mark.all_interfaces
@pytest.mark.parametrize("ml_framework", ml_frameworks_list)
@pytest.mark.parametrize("batched", (False, True))
@pytest.mark.parametrize("unitary", (qml.CRX, qml.CRZ))
@pytest.mark.parametrize("wires", ([0, 1], [1, 0]))
def test_conditional(self, wires, unitary, batched, ml_framework):
"""Test the application of a Conditional on an arbitrary state."""

n_states = int(batched) + 1
initial_state = np.array(
[
[
0.3541035 + 0.05231577j,
0.6912382 + 0.49474503j,
0.29276263 + 0.06231887j,
0.10736635 + 0.21947607j,
],
[
0.09803567 + 0.47557068j,
0.4427561 + 0.13810454j,
0.26421703 + 0.5366283j,
0.03825933 + 0.4357423j,
],
][:n_states]
)

rotated_state = qml.math.dot(
initial_state, qml.matrix(unitary(-0.238, wires), wire_order=[0, 1]).T
)
rotated_state = qml.math.asarray(rotated_state, like=ml_framework)
rotated_state = qml.math.squeeze(qml.math.reshape(rotated_state, (n_states, 2, 2)))

m0 = qml.measure(0)
op = qml.ops.op_math.Conditional(m0, unitary(0.238, wires))

mid_meas = {m0.measurements[0]: 0}
old_state = apply_operation(
op, rotated_state, batched, interface=ml_framework, mid_measurements=mid_meas
)
assert qml.math.allclose(rotated_state, old_state)

mid_meas[m0.measurements[0]] = 1
new_state = apply_operation(
op, rotated_state, batched, interface=ml_framework, mid_measurements=mid_meas
)
assert qml.math.allclose(
qml.math.squeeze(initial_state), qml.math.reshape(new_state, (n_states, 4))
)

@pytest.mark.parametrize("rng_seed, m_res", ((12, (0, 0)), (42, (1, 1))))
def test_mid_measure(self, rng_seed, m_res):
"""Test the application of a MidMeasureMP on an arbitrary state to give a basis state."""

initial_state = np.array(
[
[0.09068964 + 0.36775595j, 0.37578343 + 0.4786927j],
[0.3537292 + 0.27214766j, 0.01928256 + 0.53536021j],
]
)

mid_state, end_state = np.zeros((2, 2), dtype=complex), np.zeros((2, 2), dtype=complex)
mid_state[m_res[0]] = initial_state[m_res[0]] / np.linalg.norm(initial_state[m_res[0]])
end_state[m_res] = mid_state[m_res] / np.abs(mid_state[m_res])

rng = np.random.default_rng(rng_seed)
m0, m1 = qml.measure(0).measurements[0], qml.measure(1).measurements[0]
mid_meas = {}

res_state = apply_operation(m0, initial_state, mid_measurements=mid_meas, rng=rng)
assert qml.math.allclose(mid_state, res_state)

res_state = apply_operation(m1, res_state, mid_measurements=mid_meas, rng=rng)
assert qml.math.allclose(end_state, res_state)

assert mid_meas == {m0: m_res[0], m1: m_res[1]}

@pytest.mark.parametrize("reset", (False, True))
@pytest.mark.parametrize("m_res", ([0, 0], [1, 0], [1, 1]))
def test_mid_measure_with_postselect_and_reset(self, m_res, reset):
"""Test the application of a MidMeasureMP with postselection and reset."""

initial_state = np.array([[0.5 + 0.0j, 0.5 + 0.0j], [0.5 + 0.0j, 0.5 + 0.0j]])
mid_state, end_state = np.zeros((4, 4)), np.zeros((4, 4))

if reset:
m_res[0] = 0

mid_state[2 * m_res[0] : 2 * (m_res[0] + 1), 2 * m_res[0] : 2 * (m_res[0] + 1)] = 0.5
end_state[2 * m_res[0] + m_res[1], 2 * m_res[0] + m_res[1]] = 1.0

m0 = qml.measure(0, postselect=m_res[0], reset=reset).measurements[0]
m1 = qml.measure(1, postselect=m_res[1]).measurements[0]
mid_meas = {m0: m_res[0], m1: m_res[1]}

new_state = apply_operation(
m0, initial_state, mid_measurements=mid_meas, postselect_mode="fill-shots"
)
res_state = qml.math.reshape(new_state, 4)
assert qml.math.allclose(mid_state, qml.math.outer(res_state, res_state))

new_state = apply_operation(
m1, new_state, mid_measurements=mid_meas, postselect_mode="fill-shots"
)
res_state = qml.math.reshape(new_state, 4)
assert qml.math.allclose(end_state, qml.math.outer(res_state, res_state))

def test_error_bactched_mid_measure(self):
"""Test that an error is raised when mid_measure is applied to a batched input state."""

with pytest.raises(ValueError, match="MidMeasureMP cannot be applied to batched states."):
m0, input_state = qml.measure(0).measurements[0], qml.math.array([[1, 0], [1, 0]])
apply_operation(m0, state=input_state, is_state_batched=True)

0 comments on commit df63953

Please sign in to comment.