diff --git a/tests/ops/qubit/test_parametric_ops.py b/tests/ops/qubit/test_parametric_ops.py index 2fe069bb444..176d7960d59 100644 --- a/tests/ops/qubit/test_parametric_ops.py +++ b/tests/ops/qubit/test_parametric_ops.py @@ -2772,7 +2772,9 @@ def test_PauliRot_all_Identity(self): decomp_op = decomp_ops[0] - assert qml.equal(decomp_op, qml.GlobalPhase(theta/2)) + assert qml.equal(decomp_op, qml.GlobalPhase(theta / 2)) + + assert qml.math.allclose(op.matrix(), decomp_op.matrix() * np.eye(4)) def test_PauliRot_all_Identity_broadcasted(self): """Test handling of the broadcasted all-identity Pauli.""" @@ -2785,16 +2787,12 @@ def test_PauliRot_all_Identity_broadcasted(self): decomp_op = decomp_ops[0] - assert decomp_op.name == "GlobalPhase" - - # global phase acts on all wires so wire attribute is unused - assert decomp_op.wires == Wires([]) - - assert len(decomp_op.data[0]) == len(theta) - for param, angle in zip(decomp_op.data[0], theta): - assert qml.math.allclose(param, angle / 2) + assert qml.equal(decomp_op, qml.GlobalPhase(theta / 2)) - for op_matrix, decomp_phase in zip(op.matrix(), decomp_op.matrix().T): + op_matrices = op.matrix() + decomp_op_matrices = decomp_op.matrix().T + assert len(op_matrices) == len(decomp_op_matrices) + for op_matrix, decomp_phase in zip(op_matrices, decomp_op_matrices): assert qml.math.allclose(op_matrix, decomp_phase * np.eye(4)) @pytest.mark.parametrize("theta", [0.4, np.array([np.pi / 3, 0.1, -0.9])])