squash tests
dwierichs committed Jul 25, 2024
1 parent 42c80d6 commit ccfcb34
Showing 1 changed file with 26 additions and 52 deletions.
78 changes: 26 additions & 52 deletions tests/transforms/test_optimization/test_optimization_utils.py
@@ -75,24 +75,29 @@ class TestRotGateFusion:
         ([0.9, np.pi / 2, np.pi / 2], [-np.pi / 2, -np.pi / 2, 0.0]),
     ]

-    @staticmethod
-    def mat_from_prod(angles_1, angles_2):
+    def run_interface_test(self, angles_1, angles_2):
+        """Execute standard test lines for different interfaces and batch tests.
+        Note that the transpose calls only are relevant for tests with batching."""
+
         def original_ops():
-            qml.Rot(*angles_1, wires=0)
-            qml.Rot(*angles_2, wires=0)
+            qml.Rot(*qml.math.transpose(angles_1), wires=0)
+            qml.Rot(*qml.math.transpose(angles_2), wires=0)

-        return qml.matrix(original_ops, [0])()  # pylint:disable=too-many-function-args
+        matrix_expected = qml.matrix(original_ops, [0])()  # pylint:disable=too-many-function-args
+
+        fused_angles = fuse_rot_angles(angles_1, angles_2)
+        # The reshape is only used in the _mixed_batching test. Otherwise it is irrelevant.
+        matrix_obtained = qml.Rot(
+            *qml.math.transpose(qml.math.reshape(fused_angles, (-1, 3))), wires=0
+        ).matrix()
+
+        assert qml.math.allclose(matrix_expected, matrix_obtained)

     @pytest.mark.parametrize("angles_1, angles_2", generic_test_angles)
     def test_full_rot_fusion_numpy(self, angles_1, angles_2):
         """Test that the fusion of two Rot gates has the same effect as
         applying the Rots sequentially."""

-        matrix_expected = self.mat_from_prod(angles_1, angles_2)
-        fused_angles = fuse_rot_angles(angles_1, angles_2)
-        matrix_obtained = qml.Rot(*fused_angles, wires=0).matrix()
-
-        assert qml.math.allclose(matrix_expected, matrix_obtained)
+        self.run_interface_test(angles_1, angles_2)

     mixed_batched_angles = [
         ([[0.4, 0.1, 0.0], [0.7, 0.2, 0.1]], [-0.9, 1.2, 0.6]),  # (2, None)
@@ -110,14 +115,9 @@ def test_full_rot_fusion_mixed_batching(self, angles_1, angles_2):
         applying the Rots sequentially when the input angles are batched
         with mixed batching shapes."""

-        fused_angles = fuse_rot_angles(angles_1, angles_2)
-
         reshaped_angles_1 = np.reshape(angles_1, (-1, 3) if np.ndim(angles_1) > 1 else (3,))
         reshaped_angles_2 = np.reshape(angles_2, (-1, 3) if np.ndim(angles_2) > 1 else (3,))
-        matrix_expected = self.mat_from_prod(reshaped_angles_1.T, reshaped_angles_2.T)
-        matrix_obtained = qml.Rot(*fused_angles.reshape((-1, 3)).T, wires=0).matrix()
-
-        assert qml.math.allclose(matrix_expected, matrix_obtained)
+        self.run_interface_test(reshaped_angles_1, reshaped_angles_2)

     @pytest.mark.autograd
     @pytest.mark.parametrize("angles_1, angles_2", generic_test_angles)
@@ -126,11 +126,7 @@ def test_full_rot_fusion_autograd(self, angles_1, angles_2):
         applying the Rots sequentially, in Autograd."""

         angles_1, angles_2 = qml.numpy.array(angles_1), qml.numpy.array(angles_1)
-        matrix_expected = self.mat_from_prod(angles_1, angles_2)
-        fused_angles = fuse_rot_angles(angles_1, angles_2)
-        matrix_obtained = qml.Rot(*fused_angles, wires=0).matrix()
-
-        assert qml.math.allclose(matrix_expected, matrix_obtained)
+        self.run_interface_test(angles_1, angles_2)

     @pytest.mark.tf
     @pytest.mark.parametrize("angles_1, angles_2", generic_test_angles)
@@ -139,14 +135,9 @@ def test_full_rot_fusion_tensorflow(self, angles_1, angles_2):
         applying the Rots sequentially, in Tensorflow."""
         import tensorflow as tf

-        angles_1, angles_2 = tf.Variable(angles_1, dtype=tf.float64), tf.Variable(
-            angles_1, dtype=tf.float64
-        )
-        matrix_expected = self.mat_from_prod(angles_1, angles_2)
-        fused_angles = fuse_rot_angles(angles_1, angles_2)
-        matrix_obtained = qml.Rot(*fused_angles, wires=0).matrix()
-
-        assert qml.math.allclose(matrix_expected, matrix_obtained)
+        angles_1 = tf.Variable(angles_1, dtype=tf.float64)
+        angles_2 = tf.Variable(angles_2, dtype=tf.float64)
+        self.run_interface_test(angles_1, angles_2)

     @pytest.mark.torch
     @pytest.mark.parametrize("angles_1, angles_2", generic_test_angles)
@@ -155,14 +146,9 @@ def test_full_rot_fusion_torch(self, angles_1, angles_2):
         applying the Rots sequentially, in torch."""
         import torch

-        angles_1, angles_2 = torch.tensor(
-            angles_1, requires_grad=True, dtype=torch.float64
-        ), torch.tensor(angles_1, requires_grad=True, dtype=torch.float64)
-        matrix_expected = self.mat_from_prod(angles_1, angles_2)
-        fused_angles = fuse_rot_angles(angles_1, angles_2)
-        matrix_obtained = qml.Rot(*fused_angles, wires=0).matrix()
-
-        assert qml.math.allclose(matrix_expected, matrix_obtained)
+        angles_1 = torch.tensor(angles_1, requires_grad=True, dtype=torch.float64)
+        angles_2 = torch.tensor(angles_2, requires_grad=True, dtype=torch.float64)
+        self.run_interface_test(angles_1, angles_2)

     @pytest.mark.jax
     @pytest.mark.parametrize("angles_1, angles_2", generic_test_angles)
@@ -172,11 +158,7 @@ def test_full_rot_fusion_jax(self, angles_1, angles_2):
         import jax

         angles_1, angles_2 = jax.numpy.array(angles_1), jax.numpy.array(angles_1)
-        matrix_expected = self.mat_from_prod(angles_1, angles_2)
-        fused_angles = fuse_rot_angles(angles_1, angles_2)
-        matrix_obtained = qml.Rot(*fused_angles, wires=0).matrix()
-
-        assert qml.math.allclose(matrix_expected, matrix_obtained)
+        self.run_interface_test(angles_1, angles_2)

     @pytest.mark.slow
     def test_full_rot_fusion_special_angles(self):
@@ -188,15 +170,7 @@ def test_full_rot_fusion_special_angles(self):
         special_points = np.array([3 / 2, 1, 1 / 2, 0, -1 / 2, -1, -3 / 2]) * np.pi
         special_angles = np.array(list(product(special_points, repeat=6))).reshape((-1, 2, 3))
         angles_1, angles_2 = np.transpose(special_angles, (1, 0, 2))
-
-        # Transpose to bring size-3 axis to front
-        matrix_expected = self.mat_from_prod(angles_1.T, angles_2.T)
-
-        fused_angles = fuse_rot_angles(angles_1, angles_2)
-        # Transpose to bring size-3 axis to front
-        matrix_obtained = qml.Rot(*fused_angles.T, wires=0).matrix()
-
-        assert qml.math.allclose(matrix_expected, matrix_obtained)
+        self.run_interface_test(angles_1, angles_2)

     @pytest.mark.slow
     @pytest.mark.jax
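For reference, the property that every test above now delegates to run_interface_test can be checked in a few standalone lines: fusing two qml.Rot gates via fuse_rot_angles should reproduce the matrix of the two gates applied in sequence. The snippet below is a minimal sketch, not part of the commit; the import path for fuse_rot_angles and the concrete angle values are assumptions for illustration.

    import numpy as np
    import pennylane as qml
    from pennylane.transforms.optimization.optimization_utils import fuse_rot_angles

    # Illustrative, unbatched Euler angles for two Rot gates acting on the same wire.
    angles_1 = [0.15, 0.25, -0.90]
    angles_2 = [-0.50, 0.25, 1.23]

    def original_ops():
        # The two rotations applied sequentially.
        qml.Rot(*angles_1, wires=0)
        qml.Rot(*angles_2, wires=0)

    # Matrix of the sequential circuit ...
    matrix_expected = qml.matrix(original_ops, wire_order=[0])()

    # ... and matrix of the single fused rotation.
    fused_angles = fuse_rot_angles(angles_1, angles_2)
    matrix_obtained = qml.Rot(*fused_angles, wires=0).matrix()

    assert np.allclose(matrix_expected, matrix_obtained)

The batched tests exercise the same check, with the transpose and reshape calls in run_interface_test bringing broadcasted angle inputs into the shape qml.Rot expects.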
