diff --git a/tests/devices/qubit/test_apply_operation.py b/tests/devices/qubit/test_apply_operation.py index e9239f27679..49abf3e462b 100644 --- a/tests/devices/qubit/test_apply_operation.py +++ b/tests/devices/qubit/test_apply_operation.py @@ -40,6 +40,8 @@ methods = [apply_operation_einsum, apply_operation_tensordot, apply_operation] +# pylint: disable=import-outside-toplevel,unsubscriptable-object,arguments-differ + def test_custom_operator_with_matrix(): """Test that apply_operation works with any operation that defines a matrix.""" @@ -53,6 +55,8 @@ def test_custom_operator_with_matrix(): # pylint: disable=too-few-public-methods class CustomOp(Operation): + """Custom Operation""" + num_wires = 1 def matrix(self): @@ -294,6 +298,8 @@ def f2(params, t): @pytest.mark.jax class TestApplyParametrizedEvolution: + """Test that apply_operation works with ParametrizedEvolution""" + @pytest.mark.parametrize("method", methods) def test_parameterized_evolution_time_independent(self, method): """Test that applying a ParametrizedEvolution gives the expected state @@ -651,6 +657,7 @@ class TestRXCalcGrad: ) def compare_expected_result(self, phi, state, new_state, g): + """Compares the new state against the expected state""" expected0 = np.cos(phi / 2) * state[0, :, :] + -1j * np.sin(phi / 2) * state[1, :, :] expected1 = -1j * np.sin(phi / 2) * state[0, :, :] + np.cos(phi / 2) * state[1, :, :] @@ -1267,3 +1274,121 @@ def circuit(init_state): results = circuit(tf.Variable(states)) assert qml.math.shape(results) == (3, 256) assert np.array_equal(results[:, 128], [-1.0 + 0.0j] * 3) + + +# pylint: disable=too-few-public-methods +class TestConditionalsAndMidMeasure: + """Test dispatching for mid-circuit measurements and conditionals.""" + + @pytest.mark.all_interfaces + @pytest.mark.parametrize("ml_framework", ml_frameworks_list) + @pytest.mark.parametrize("batched", (False, True)) + @pytest.mark.parametrize("unitary", (qml.CRX, qml.CRZ)) + @pytest.mark.parametrize("wires", ([0, 1], [1, 0])) + def test_conditional(self, wires, unitary, batched, ml_framework): + """Test the application of a Conditional on an arbitrary state.""" + + n_states = int(batched) + 1 + initial_state = np.array( + [ + [ + 0.3541035 + 0.05231577j, + 0.6912382 + 0.49474503j, + 0.29276263 + 0.06231887j, + 0.10736635 + 0.21947607j, + ], + [ + 0.09803567 + 0.47557068j, + 0.4427561 + 0.13810454j, + 0.26421703 + 0.5366283j, + 0.03825933 + 0.4357423j, + ], + ][:n_states] + ) + + rotated_state = qml.math.dot( + initial_state, qml.matrix(unitary(-0.238, wires), wire_order=[0, 1]).T + ) + rotated_state = qml.math.asarray(rotated_state, like=ml_framework) + rotated_state = qml.math.squeeze(qml.math.reshape(rotated_state, (n_states, 2, 2))) + + m0 = qml.measure(0) + op = qml.ops.op_math.Conditional(m0, unitary(0.238, wires)) + + mid_meas = {m0.measurements[0]: 0} + old_state = apply_operation( + op, rotated_state, batched, interface=ml_framework, mid_measurements=mid_meas + ) + assert qml.math.allclose(rotated_state, old_state) + + mid_meas[m0.measurements[0]] = 1 + new_state = apply_operation( + op, rotated_state, batched, interface=ml_framework, mid_measurements=mid_meas + ) + assert qml.math.allclose( + qml.math.squeeze(initial_state), qml.math.reshape(new_state, (n_states, 4)) + ) + + @pytest.mark.parametrize("rng_seed, m_res", ((12, (0, 0)), (42, (1, 1)))) + def test_mid_measure(self, rng_seed, m_res): + """Test the application of a MidMeasureMP on an arbitrary state to give a basis state.""" + + initial_state = np.array( + [ + [0.09068964 + 0.36775595j, 0.37578343 + 0.4786927j], + [0.3537292 + 0.27214766j, 0.01928256 + 0.53536021j], + ] + ) + + mid_state, end_state = np.zeros((2, 2), dtype=complex), np.zeros((2, 2), dtype=complex) + mid_state[m_res[0]] = initial_state[m_res[0]] / np.linalg.norm(initial_state[m_res[0]]) + end_state[m_res] = mid_state[m_res] / np.abs(mid_state[m_res]) + + rng = np.random.default_rng(rng_seed) + m0, m1 = qml.measure(0).measurements[0], qml.measure(1).measurements[0] + mid_meas = {} + + res_state = apply_operation(m0, initial_state, mid_measurements=mid_meas, rng=rng) + assert qml.math.allclose(mid_state, res_state) + + res_state = apply_operation(m1, res_state, mid_measurements=mid_meas, rng=rng) + assert qml.math.allclose(end_state, res_state) + + assert mid_meas == {m0: m_res[0], m1: m_res[1]} + + @pytest.mark.parametrize("reset", (False, True)) + @pytest.mark.parametrize("m_res", ([0, 0], [1, 0], [1, 1])) + def test_mid_measure_with_postselect_and_reset(self, m_res, reset): + """Test the application of a MidMeasureMP with postselection and reset.""" + + initial_state = np.array([[0.5 + 0.0j, 0.5 + 0.0j], [0.5 + 0.0j, 0.5 + 0.0j]]) + mid_state, end_state = np.zeros((4, 4)), np.zeros((4, 4)) + + if reset: + m_res[0] = 0 + + mid_state[2 * m_res[0] : 2 * (m_res[0] + 1), 2 * m_res[0] : 2 * (m_res[0] + 1)] = 0.5 + end_state[2 * m_res[0] + m_res[1], 2 * m_res[0] + m_res[1]] = 1.0 + + m0 = qml.measure(0, postselect=m_res[0], reset=reset).measurements[0] + m1 = qml.measure(1, postselect=m_res[1]).measurements[0] + mid_meas = {m0: m_res[0], m1: m_res[1]} + + new_state = apply_operation( + m0, initial_state, mid_measurements=mid_meas, postselect_mode="fill-shots" + ) + res_state = qml.math.reshape(new_state, 4) + assert qml.math.allclose(mid_state, qml.math.outer(res_state, res_state)) + + new_state = apply_operation( + m1, new_state, mid_measurements=mid_meas, postselect_mode="fill-shots" + ) + res_state = qml.math.reshape(new_state, 4) + assert qml.math.allclose(end_state, qml.math.outer(res_state, res_state)) + + def test_error_bactched_mid_measure(self): + """Test that an error is raised when mid_measure is applied to a batched input state.""" + + with pytest.raises(ValueError, match="MidMeasureMP cannot be applied to batched states."): + m0, input_state = qml.measure(0).measurements[0], qml.math.array([[1, 0], [1, 0]]) + apply_operation(m0, state=input_state, is_state_batched=True)