diff --git a/tests/transforms/test_optimization/test_optimization_utils.py b/tests/transforms/test_optimization/test_optimization_utils.py index 945797999e1..0872bc8fc30 100644 --- a/tests/transforms/test_optimization/test_optimization_utils.py +++ b/tests/transforms/test_optimization/test_optimization_utils.py @@ -75,24 +75,29 @@ class TestRotGateFusion: ([0.9, np.pi / 2, np.pi / 2], [-np.pi / 2, -np.pi / 2, 0.0]), ] - @staticmethod - def mat_from_prod(angles_1, angles_2): + def run_interface_test(self, angles_1, angles_2): + """Execute standard test lines for different interfaces and batch tests. + Note that the transpose calls only are relevant for tests with batching.""" + def original_ops(): - qml.Rot(*angles_1, wires=0) - qml.Rot(*angles_2, wires=0) + qml.Rot(*qml.math.transpose(angles_1), wires=0) + qml.Rot(*qml.math.transpose(angles_2), wires=0) + + matrix_expected = qml.matrix(original_ops, [0])() # pylint:disable=too-many-function-args + + fused_angles = fuse_rot_angles(angles_1, angles_2) + # The reshape is only used in the _mixed_batching test. Otherwise it is irrelevant. + matrix_obtained = qml.Rot( + *qml.math.transpose(qml.math.reshape(fused_angles, (-1, 3))), wires=0 + ).matrix() - return qml.matrix(original_ops, [0])() # pylint:disable=too-many-function-args + assert qml.math.allclose(matrix_expected, matrix_obtained) @pytest.mark.parametrize("angles_1, angles_2", generic_test_angles) def test_full_rot_fusion_numpy(self, angles_1, angles_2): """Test that the fusion of two Rot gates has the same effect as applying the Rots sequentially.""" - - matrix_expected = self.mat_from_prod(angles_1, angles_2) - fused_angles = fuse_rot_angles(angles_1, angles_2) - matrix_obtained = qml.Rot(*fused_angles, wires=0).matrix() - - assert qml.math.allclose(matrix_expected, matrix_obtained) + self.run_interface_test(angles_1, angles_2) mixed_batched_angles = [ ([[0.4, 0.1, 0.0], [0.7, 0.2, 0.1]], [-0.9, 1.2, 0.6]), # (2, None) @@ -110,14 +115,9 @@ def test_full_rot_fusion_mixed_batching(self, angles_1, angles_2): applying the Rots sequentially when the input angles are batched with mixed batching shapes.""" - fused_angles = fuse_rot_angles(angles_1, angles_2) - reshaped_angles_1 = np.reshape(angles_1, (-1, 3) if np.ndim(angles_1) > 1 else (3,)) reshaped_angles_2 = np.reshape(angles_2, (-1, 3) if np.ndim(angles_2) > 1 else (3,)) - matrix_expected = self.mat_from_prod(reshaped_angles_1.T, reshaped_angles_2.T) - matrix_obtained = qml.Rot(*fused_angles.reshape((-1, 3)).T, wires=0).matrix() - - assert qml.math.allclose(matrix_expected, matrix_obtained) + self.run_interface_test(reshaped_angles_1, reshaped_angles_2) @pytest.mark.autograd @pytest.mark.parametrize("angles_1, angles_2", generic_test_angles) @@ -126,11 +126,7 @@ def test_full_rot_fusion_autograd(self, angles_1, angles_2): applying the Rots sequentially, in Autograd.""" angles_1, angles_2 = qml.numpy.array(angles_1), qml.numpy.array(angles_1) - matrix_expected = self.mat_from_prod(angles_1, angles_2) - fused_angles = fuse_rot_angles(angles_1, angles_2) - matrix_obtained = qml.Rot(*fused_angles, wires=0).matrix() - - assert qml.math.allclose(matrix_expected, matrix_obtained) + self.run_interface_test(angles_1, angles_2) @pytest.mark.tf @pytest.mark.parametrize("angles_1, angles_2", generic_test_angles) @@ -139,14 +135,9 @@ def test_full_rot_fusion_tensorflow(self, angles_1, angles_2): applying the Rots sequentially, in Tensorflow.""" import tensorflow as tf - angles_1, angles_2 = tf.Variable(angles_1, dtype=tf.float64), tf.Variable( - angles_1, dtype=tf.float64 - ) - matrix_expected = self.mat_from_prod(angles_1, angles_2) - fused_angles = fuse_rot_angles(angles_1, angles_2) - matrix_obtained = qml.Rot(*fused_angles, wires=0).matrix() - - assert qml.math.allclose(matrix_expected, matrix_obtained) + angles_1 = tf.Variable(angles_1, dtype=tf.float64) + angles_2 = tf.Variable(angles_2, dtype=tf.float64) + self.run_interface_test(angles_1, angles_2) @pytest.mark.torch @pytest.mark.parametrize("angles_1, angles_2", generic_test_angles) @@ -155,14 +146,9 @@ def test_full_rot_fusion_torch(self, angles_1, angles_2): applying the Rots sequentially, in torch.""" import torch - angles_1, angles_2 = torch.tensor( - angles_1, requires_grad=True, dtype=torch.float64 - ), torch.tensor(angles_1, requires_grad=True, dtype=torch.float64) - matrix_expected = self.mat_from_prod(angles_1, angles_2) - fused_angles = fuse_rot_angles(angles_1, angles_2) - matrix_obtained = qml.Rot(*fused_angles, wires=0).matrix() - - assert qml.math.allclose(matrix_expected, matrix_obtained) + angles_1 = torch.tensor(angles_1, requires_grad=True, dtype=torch.float64) + angles_2 = torch.tensor(angles_2, requires_grad=True, dtype=torch.float64) + self.run_interface_test(angles_1, angles_2) @pytest.mark.jax @pytest.mark.parametrize("angles_1, angles_2", generic_test_angles) @@ -172,11 +158,7 @@ def test_full_rot_fusion_jax(self, angles_1, angles_2): import jax angles_1, angles_2 = jax.numpy.array(angles_1), jax.numpy.array(angles_1) - matrix_expected = self.mat_from_prod(angles_1, angles_2) - fused_angles = fuse_rot_angles(angles_1, angles_2) - matrix_obtained = qml.Rot(*fused_angles, wires=0).matrix() - - assert qml.math.allclose(matrix_expected, matrix_obtained) + self.run_interface_test(angles_1, angles_2) @pytest.mark.slow def test_full_rot_fusion_special_angles(self): @@ -188,15 +170,7 @@ def test_full_rot_fusion_special_angles(self): special_points = np.array([3 / 2, 1, 1 / 2, 0, -1 / 2, -1, -3 / 2]) * np.pi special_angles = np.array(list(product(special_points, repeat=6))).reshape((-1, 2, 3)) angles_1, angles_2 = np.transpose(special_angles, (1, 0, 2)) - - # Transpose to bring size-3 axis to front - matrix_expected = self.mat_from_prod(angles_1.T, angles_2.T) - - fused_angles = fuse_rot_angles(angles_1, angles_2) - # Transpose to bring size-3 axis to front - matrix_obtained = qml.Rot(*fused_angles.T, wires=0).matrix() - - assert qml.math.allclose(matrix_expected, matrix_obtained) + self.run_interface_test(angles_1, angles_2) @pytest.mark.slow @pytest.mark.jax