From b6e96588f9e3d4236d1f31f73e22566298a22344 Mon Sep 17 00:00:00 2001 From: Christina Lee Date: Fri, 3 May 2024 15:28:52 -0400 Subject: [PATCH] Set up framework for an `assert_equal` function (#5634) **Context:** While `qml.equal` has been incredibly helpful in helping us compare operators, measurements, and tapes, it can sometimes be rather difficult to determine *why* two operations are not equal. This new `assert_equal` function will raise and `AssertionError` with context, making it easier to debug failures occuring in tests or scripts. **Description of the Change:** This PR just sets up the framework by: 1) Allowing the `_equal` function to return either a bool or str. 2) Making `qml.equal` interpret a string as `False` 3) Adding `assert_equal` to raise an `AssertionError` from the output of `_equal` 4) Implementing the bool-> str conversion for `BasisRotation` `BasisRotation` is as it an edge case where I can demonstrate what we want these functions to look like without taking away the core part of the work. **Benefits:** Easier debugging. **Possible Drawbacks:** More to maintain. **Related GitHub Issues:** --------- Co-authored-by: Thomas R. Bromley <49409390+trbromley@users.noreply.github.com> --- doc/releases/changelog-dev.md | 5 ++ pennylane/ops/functions/equal.py | 109 +++++++++++++++++++++++++++--- tests/ops/functions/test_equal.py | 44 +++++++++++- 3 files changed, 148 insertions(+), 10 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index ee36a5ff256..ad7395f90b5 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -60,6 +60,9 @@ `qml.devices.Device`, which follows the new device API. [(#5581)](https://github.com/PennyLaneAI/pennylane/pull/5581) +* Sets up the framework for the development of an `assert_equal` function for testing operator comparison. + [(#5634)](https://github.com/PennyLaneAI/pennylane/pull/5634) +

Breaking changes 💔

Deprecations 👋

@@ -71,5 +74,7 @@

Contributors ✍️

This release contains contributions from (in alphabetical order): + Pietropaolo Frisoni, +Christina Lee, David Wierichs. \ No newline at end of file diff --git a/pennylane/ops/functions/equal.py b/pennylane/ops/functions/equal.py index 2cf5e1f803f..8130fa02a91 100644 --- a/pennylane/ops/functions/equal.py +++ b/pennylane/ops/functions/equal.py @@ -49,7 +49,7 @@ def equal( check_trainability=True, rtol=1e-5, atol=1e-9, -): +) -> bool: r"""Function for determining operator or measurement equality. .. Warning:: @@ -163,7 +163,7 @@ def equal( if isinstance(op2, (Hamiltonian, Tensor)): return _equal(op2, op1) - return _equal( + dispatch_result = _equal( op1, op2, check_interface=check_interface, @@ -171,6 +171,91 @@ def equal( atol=atol, rtol=rtol, ) + if isinstance(dispatch_result, str): + return False + return dispatch_result + + +def assert_equal( + op1: Union[Operator, MeasurementProcess, QuantumTape], + op2: Union[Operator, MeasurementProcess, QuantumTape], + check_interface=True, + check_trainability=True, + rtol=1e-5, + atol=1e-9, +) -> None: + """Function to assert that two operators are equal with the requested configuration. + + Args: + op1 (.Operator or .MeasurementProcess or .QuantumTape): First object to compare + op2 (.Operator or .MeasurementProcess or .QuantumTape): Second object to compare + check_interface (bool, optional): Whether to compare interfaces. Default: ``True``. + Not used for comparing ``MeasurementProcess``, ``Hamiltonian`` or ``Tensor`` objects. + check_trainability (bool, optional): Whether to compare trainability status. Default: ``True``. + Not used for comparing ``MeasurementProcess``, ``Hamiltonian`` or ``Tensor`` objects. + rtol (float, optional): Relative tolerance for parameters. Not used for comparing ``MeasurementProcess``, ``Hamiltonian`` or ``Tensor`` objects. + atol (float, optional): Absolute tolerance for parameters. Not used for comparing ``MeasurementProcess``, ``Hamiltonian`` or ``Tensor`` objects. + + Returns: + None + + Raises: + + AssertionError: An ``AssertionError`` is raised if the two operators are not equal. + + .. warning:: + + This function is still under developement. + + .. see-also:: + + :func:`~.equal` + + >>> mat1 = qml.IsingXX.compute_matrix(0.1) + >>> op1 = qml.BasisRotation(wires=(0,1), unitary_matrix = mat1) + >>> mat2 = qml.IsingXX.compute_matrix(0.2) + >>> op2 = qml.BasisRotation(wires=(0,1), unitary_matrix = mat2) + >>> assert_equal(op1, op2) + AssertionError: The hyperparameter unitary_matrix is not equal for op1 and op2. + Got [[0.99875026+0.j 0. +0.j 0. +0.j + 0. -0.04997917j] + [0. +0.j 0.99875026+0.j 0. -0.04997917j + 0. +0.j ] + [0. +0.j 0. -0.04997917j 0.99875026+0.j + 0. +0.j ] + [0. -0.04997917j 0. +0.j 0. +0.j + 0.99875026+0.j ]] + and [[0.99500417+0.j 0. +0.j 0. +0.j + 0. -0.09983342j] + [0. +0.j 0.99500417+0.j 0. -0.09983342j + 0. +0.j ] + [0. +0.j 0. -0.09983342j 0.99500417+0.j + 0. +0.j ] + [0. -0.09983342j 0. +0.j 0. +0.j + 0.99500417+0.j ]]. + >>> mat3 = qml.numpy.array(0.3) + >>> op3 = qml.BasisRotation(wires=(0,1), unitary_matrix = mat3) + >>> assert_equal(op1, op3) + AssertionError: The hyperparameter unitary_matrix has different interfaces for op1 and op2. Got numpy and autograd. + + """ + if not isinstance(op2, type(op1)) and not isinstance(op1, Observable): + raise AssertionError( + f"op1 and op2 are of different types. Got {type(op1)} and {type(op2)}." + ) + + dispatch_result = _equal( + op1, + op2, + check_interface=check_interface, + check_trainability=check_trainability, + atol=atol, + rtol=rtol, + ) + if isinstance(dispatch_result, str): + raise AssertionError(dispatch_result) + if not dispatch_result: + raise AssertionError(f"{op1} and {op2} are not equal for an unspecified reason.") @singledispatch @@ -181,7 +266,7 @@ def _equal( check_trainability=True, rtol=1e-5, atol=1e-9, -): # pylint: disable=unused-argument +) -> Union[bool, str]: # pylint: disable=unused-argument raise NotImplementedError(f"Comparison of {type(op1)} and {type(op2)} not implemented") @@ -531,14 +616,20 @@ def _equal_basis_rotation( atol=atol, rtol=rtol, ): - return False + return ( + "The hyperparameter unitary_matrix is not equal for op1 and op2.\n" + f"Got {op1.hyperparameters['unitary_matrix']}\n and {op2.hyperparameters['unitary_matrix']}." + ) if op1.wires != op2.wires: - return False + return f"op1 and op2 have different wires. Got {op1.wires} and {op2.wires}." if check_interface: - if qml.math.get_interface(op1.hyperparameters["unitary_matrix"]) != qml.math.get_interface( - op2.hyperparameters["unitary_matrix"] - ): - return False + interface1 = qml.math.get_interface(op1.hyperparameters["unitary_matrix"]) + interface2 = qml.math.get_interface(op2.hyperparameters["unitary_matrix"]) + if interface1 != interface2: + return ( + "The hyperparameter unitary_matrix has different interfaces for op1 and op2." + f" Got {interface1} and {interface2}." + ) return True diff --git a/tests/ops/functions/test_equal.py b/tests/ops/functions/test_equal.py index 9e1824e3349..909e60eed9e 100644 --- a/tests/ops/functions/test_equal.py +++ b/tests/ops/functions/test_equal.py @@ -15,8 +15,9 @@ Unit tests for the equal function. Tests are divided by number of parameters and wires different operators take. """ -# pylint: disable=too-many-arguments, too-many-public-methods import itertools + +# pylint: disable=too-many-arguments, too-many-public-methods from copy import deepcopy import numpy as np @@ -26,6 +27,7 @@ from pennylane import numpy as npp from pennylane.measurements import ExpectationMP from pennylane.measurements.probs import ProbabilityMP +from pennylane.ops.functions.equal import _equal, assert_equal from pennylane.ops.op_math import Controlled, SymbolicOp from pennylane.templates.subroutines import ControlledSequence @@ -315,6 +317,34 @@ ] +def test_assert_equal_types(): + """Test that assert equal raises if the operator types are different.""" + + op1 = qml.S(0) + op2 = qml.T(0) + with pytest.raises(AssertionError, match="op1 and op2 are of different types"): + assert_equal(op1, op2) + + +def test_assert_equal_unspecified(): + + # pylint: disable=too-few-public-methods + class RandomType: + """dummy type""" + + def __init__(self): + pass + + # pylint: disable=unused-argument + @_equal.register + def _(op1: RandomType, op2, **_): + """always returns false""" + return False + + with pytest.raises(AssertionError, match=r"for an unspecified reason"): + assert_equal(RandomType(), RandomType()) + + class TestEqual: @pytest.mark.parametrize("ops", PARAMETRIZED_OPERATIONS_COMBINATIONS) def test_equal_simple_diff_op(self, ops): @@ -2062,16 +2092,24 @@ class TestBasisRotation: @pytest.mark.parametrize("op, other_op", [(op1, op3)]) def test_different_tolerances_comparison(self, op, other_op): assert qml.equal(op, other_op, atol=1e-5) + assert_equal(op, other_op, atol=1e-5) assert qml.equal(op, other_op, rtol=0, atol=1e-9) is False + with pytest.raises(AssertionError, match="The hyperparameter unitary_matrix is not equal"): + assert_equal(op, other_op, rtol=0, atol=1e-9) + @pytest.mark.parametrize("op, other_op", [(op1, op2)]) def test_non_equal_training_params_comparison(self, op, other_op): assert qml.equal(op, other_op) + assert_equal(op, other_op) @pytest.mark.parametrize("op, other_op", [(op1, op4)]) def test_non_equal_training_wires(self, op, other_op): assert qml.equal(op, other_op) is False + with pytest.raises(AssertionError, match="op1 and op2 have different wires."): + assert_equal(op, other_op) + @pytest.mark.jax @pytest.mark.parametrize("op", [op1]) def test_non_equal_interfaces(self, op): @@ -2085,8 +2123,12 @@ def test_non_equal_interfaces(self, op): ) other_op = qml.BasisRotation(wires=range(2), unitary_matrix=rotation_mat_jax) assert qml.equal(op, other_op, check_interface=False) + assert_equal(op, other_op, check_interface=False) assert qml.equal(op, other_op) is False + with pytest.raises(AssertionError, match=r"has different interfaces for op1 and op2"): + assert_equal(op, other_op) + class TestHilbertSchmidt: """Test that qml.equal works with qml.HilbertSchmidt."""