Skip to content

Commit

Permalink
Set up framework for an assert_equal function (#5634)
Browse files Browse the repository at this point in the history
**Context:**

While `qml.equal` has been incredibly helpful in helping us compare
operators, measurements, and tapes, it can sometimes be rather difficult
to determine *why* two operations are not equal.

This new `assert_equal` function will raise and `AssertionError` with
context, making it easier to debug failures occuring in tests or
scripts.

**Description of the Change:**

This PR just sets up the framework by:

1) Allowing the `_equal` function to return either a bool or str.
2) Making `qml.equal` interpret a string as `False`
3) Adding `assert_equal` to raise an `AssertionError` from the output of
`_equal`
4) Implementing the bool-> str conversion for `BasisRotation`

`BasisRotation` is as it an edge case where I can demonstrate what we
want these functions to look like without taking away the core part of
the work.

**Benefits:**

Easier debugging.

**Possible Drawbacks:**

More to maintain.

**Related GitHub Issues:**

---------

Co-authored-by: Thomas R. Bromley <49409390+trbromley@users.noreply.github.com>
  • Loading branch information
albi3ro and trbromley authored May 3, 2024
1 parent 465d337 commit b6e9658
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 10 deletions.
5 changes: 5 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@
`qml.devices.Device`, which follows the new device API.
[(#5581)](https://github.com/PennyLaneAI/pennylane/pull/5581)

* Sets up the framework for the development of an `assert_equal` function for testing operator comparison.
[(#5634)](https://github.com/PennyLaneAI/pennylane/pull/5634)

<h3>Breaking changes 💔</h3>

<h3>Deprecations 👋</h3>
Expand All @@ -71,5 +74,7 @@
<h3>Contributors ✍️</h3>

This release contains contributions from (in alphabetical order):

Pietropaolo Frisoni,
Christina Lee,
David Wierichs.
109 changes: 100 additions & 9 deletions pennylane/ops/functions/equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def equal(
check_trainability=True,
rtol=1e-5,
atol=1e-9,
):
) -> bool:
r"""Function for determining operator or measurement equality.
.. Warning::
Expand Down Expand Up @@ -163,14 +163,99 @@ def equal(
if isinstance(op2, (Hamiltonian, Tensor)):
return _equal(op2, op1)

return _equal(
dispatch_result = _equal(
op1,
op2,
check_interface=check_interface,
check_trainability=check_trainability,
atol=atol,
rtol=rtol,
)
if isinstance(dispatch_result, str):
return False
return dispatch_result


def assert_equal(
op1: Union[Operator, MeasurementProcess, QuantumTape],
op2: Union[Operator, MeasurementProcess, QuantumTape],
check_interface=True,
check_trainability=True,
rtol=1e-5,
atol=1e-9,
) -> None:
"""Function to assert that two operators are equal with the requested configuration.
Args:
op1 (.Operator or .MeasurementProcess or .QuantumTape): First object to compare
op2 (.Operator or .MeasurementProcess or .QuantumTape): Second object to compare
check_interface (bool, optional): Whether to compare interfaces. Default: ``True``.
Not used for comparing ``MeasurementProcess``, ``Hamiltonian`` or ``Tensor`` objects.
check_trainability (bool, optional): Whether to compare trainability status. Default: ``True``.
Not used for comparing ``MeasurementProcess``, ``Hamiltonian`` or ``Tensor`` objects.
rtol (float, optional): Relative tolerance for parameters. Not used for comparing ``MeasurementProcess``, ``Hamiltonian`` or ``Tensor`` objects.
atol (float, optional): Absolute tolerance for parameters. Not used for comparing ``MeasurementProcess``, ``Hamiltonian`` or ``Tensor`` objects.
Returns:
None
Raises:
AssertionError: An ``AssertionError`` is raised if the two operators are not equal.
.. warning::
This function is still under developement.
.. see-also::
:func:`~.equal`
>>> mat1 = qml.IsingXX.compute_matrix(0.1)
>>> op1 = qml.BasisRotation(wires=(0,1), unitary_matrix = mat1)
>>> mat2 = qml.IsingXX.compute_matrix(0.2)
>>> op2 = qml.BasisRotation(wires=(0,1), unitary_matrix = mat2)
>>> assert_equal(op1, op2)
AssertionError: The hyperparameter unitary_matrix is not equal for op1 and op2.
Got [[0.99875026+0.j 0. +0.j 0. +0.j
0. -0.04997917j]
[0. +0.j 0.99875026+0.j 0. -0.04997917j
0. +0.j ]
[0. +0.j 0. -0.04997917j 0.99875026+0.j
0. +0.j ]
[0. -0.04997917j 0. +0.j 0. +0.j
0.99875026+0.j ]]
and [[0.99500417+0.j 0. +0.j 0. +0.j
0. -0.09983342j]
[0. +0.j 0.99500417+0.j 0. -0.09983342j
0. +0.j ]
[0. +0.j 0. -0.09983342j 0.99500417+0.j
0. +0.j ]
[0. -0.09983342j 0. +0.j 0. +0.j
0.99500417+0.j ]].
>>> mat3 = qml.numpy.array(0.3)
>>> op3 = qml.BasisRotation(wires=(0,1), unitary_matrix = mat3)
>>> assert_equal(op1, op3)
AssertionError: The hyperparameter unitary_matrix has different interfaces for op1 and op2. Got numpy and autograd.
"""
if not isinstance(op2, type(op1)) and not isinstance(op1, Observable):
raise AssertionError(
f"op1 and op2 are of different types. Got {type(op1)} and {type(op2)}."
)

dispatch_result = _equal(
op1,
op2,
check_interface=check_interface,
check_trainability=check_trainability,
atol=atol,
rtol=rtol,
)
if isinstance(dispatch_result, str):
raise AssertionError(dispatch_result)
if not dispatch_result:
raise AssertionError(f"{op1} and {op2} are not equal for an unspecified reason.")


@singledispatch
Expand All @@ -181,7 +266,7 @@ def _equal(
check_trainability=True,
rtol=1e-5,
atol=1e-9,
): # pylint: disable=unused-argument
) -> Union[bool, str]: # pylint: disable=unused-argument
raise NotImplementedError(f"Comparison of {type(op1)} and {type(op2)} not implemented")


Expand Down Expand Up @@ -531,14 +616,20 @@ def _equal_basis_rotation(
atol=atol,
rtol=rtol,
):
return False
return (
"The hyperparameter unitary_matrix is not equal for op1 and op2.\n"
f"Got {op1.hyperparameters['unitary_matrix']}\n and {op2.hyperparameters['unitary_matrix']}."
)
if op1.wires != op2.wires:
return False
return f"op1 and op2 have different wires. Got {op1.wires} and {op2.wires}."
if check_interface:
if qml.math.get_interface(op1.hyperparameters["unitary_matrix"]) != qml.math.get_interface(
op2.hyperparameters["unitary_matrix"]
):
return False
interface1 = qml.math.get_interface(op1.hyperparameters["unitary_matrix"])
interface2 = qml.math.get_interface(op2.hyperparameters["unitary_matrix"])
if interface1 != interface2:
return (
"The hyperparameter unitary_matrix has different interfaces for op1 and op2."
f" Got {interface1} and {interface2}."
)
return True


Expand Down
44 changes: 43 additions & 1 deletion tests/ops/functions/test_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
Unit tests for the equal function.
Tests are divided by number of parameters and wires different operators take.
"""
# pylint: disable=too-many-arguments, too-many-public-methods
import itertools

# pylint: disable=too-many-arguments, too-many-public-methods
from copy import deepcopy

import numpy as np
Expand All @@ -26,6 +27,7 @@
from pennylane import numpy as npp
from pennylane.measurements import ExpectationMP
from pennylane.measurements.probs import ProbabilityMP
from pennylane.ops.functions.equal import _equal, assert_equal
from pennylane.ops.op_math import Controlled, SymbolicOp
from pennylane.templates.subroutines import ControlledSequence

Expand Down Expand Up @@ -315,6 +317,34 @@
]


def test_assert_equal_types():
"""Test that assert equal raises if the operator types are different."""

op1 = qml.S(0)
op2 = qml.T(0)
with pytest.raises(AssertionError, match="op1 and op2 are of different types"):
assert_equal(op1, op2)


def test_assert_equal_unspecified():

# pylint: disable=too-few-public-methods
class RandomType:
"""dummy type"""

def __init__(self):
pass

# pylint: disable=unused-argument
@_equal.register
def _(op1: RandomType, op2, **_):
"""always returns false"""
return False

with pytest.raises(AssertionError, match=r"for an unspecified reason"):
assert_equal(RandomType(), RandomType())


class TestEqual:
@pytest.mark.parametrize("ops", PARAMETRIZED_OPERATIONS_COMBINATIONS)
def test_equal_simple_diff_op(self, ops):
Expand Down Expand Up @@ -2062,16 +2092,24 @@ class TestBasisRotation:
@pytest.mark.parametrize("op, other_op", [(op1, op3)])
def test_different_tolerances_comparison(self, op, other_op):
assert qml.equal(op, other_op, atol=1e-5)
assert_equal(op, other_op, atol=1e-5)
assert qml.equal(op, other_op, rtol=0, atol=1e-9) is False

with pytest.raises(AssertionError, match="The hyperparameter unitary_matrix is not equal"):
assert_equal(op, other_op, rtol=0, atol=1e-9)

@pytest.mark.parametrize("op, other_op", [(op1, op2)])
def test_non_equal_training_params_comparison(self, op, other_op):
assert qml.equal(op, other_op)
assert_equal(op, other_op)

@pytest.mark.parametrize("op, other_op", [(op1, op4)])
def test_non_equal_training_wires(self, op, other_op):
assert qml.equal(op, other_op) is False

with pytest.raises(AssertionError, match="op1 and op2 have different wires."):
assert_equal(op, other_op)

@pytest.mark.jax
@pytest.mark.parametrize("op", [op1])
def test_non_equal_interfaces(self, op):
Expand All @@ -2085,8 +2123,12 @@ def test_non_equal_interfaces(self, op):
)
other_op = qml.BasisRotation(wires=range(2), unitary_matrix=rotation_mat_jax)
assert qml.equal(op, other_op, check_interface=False)
assert_equal(op, other_op, check_interface=False)
assert qml.equal(op, other_op) is False

with pytest.raises(AssertionError, match=r"has different interfaces for op1 and op2"):
assert_equal(op, other_op)


class TestHilbertSchmidt:
"""Test that qml.equal works with qml.HilbertSchmidt."""
Expand Down

0 comments on commit b6e9658

Please sign in to comment.