Skip to content

Commit

Permalink
Change qml.equal default behaviour to return False instead of rai…
Browse files Browse the repository at this point in the history
…sing error (#4315)

* Update qml.equal

* Update changelog

* Update doc/releases/changelog-dev.md

* Improve changelog entry

* Updated tests; linting

* Added test to linting file

* Updated tests

* Added comment with info

* Updated test
  • Loading branch information
mudit2812 committed Jul 11, 2023
1 parent 58c01e5 commit b781be0
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 32 deletions.
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@

<h3>Improvements 🛠</h3>

* `qml.equal` no longer raises errors when operators or measurements of different types are compared.
Instead, it returns `False`.
[(#4315)](https://github.com/PennyLaneAI/pennylane/pull/4315)

* The `qml.gradients` module no longer mutates operators in-place for any gradient transforms.
Instead, operators that need to be mutated are copied with new parameters.
[(#4220)](https://github.com/PennyLaneAI/pennylane/pull/4220)
Expand Down
25 changes: 11 additions & 14 deletions pennylane/ops/functions/equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,11 @@ def equal(
>>> qml.equal(H1, H2), qml.equal(H1, H3)
(True, False)
>>> qml.equal(qml.expval(qml.PauliX(0)), qml.expval(qml.PauliX(0)) )
>>> qml.equal(qml.expval(qml.PauliX(0)), qml.expval(qml.PauliX(0)))
True
>>> qml.equal(qml.probs(wires=(0,1)), qml.probs(wires=(1,2)) )
>>> qml.equal(qml.probs(wires=(0,1)), qml.probs(wires=(1,2)))
False
>>> qml.equal(qml.classical_shadow(wires=[0,1]), qml.classical_shadow(wires=[0,1]) )
>>> qml.equal(qml.classical_shadow(wires=[0,1]), qml.classical_shadow(wires=[0,1]))
True
.. details::
Expand Down Expand Up @@ -160,7 +160,7 @@ def _equal(
check_trainability=True,
rtol=1e-5,
atol=1e-9,
):
): # pylint: disable=unused-argument
raise NotImplementedError(f"Comparison of {type(op1)} and {type(op2)} not implemented")


Expand All @@ -183,9 +183,10 @@ def _equal_operators(
return False

if op1.arithmetic_depth > 0:
raise NotImplementedError(
"Comparison of operators with an arithmetic depth larger than 0 is not yet implemented."
)
# Other dispatches cover cases of operations with arithmetic depth > 0.
# If any new operations are added with arithmetic depth > 0, a new dispatch
# should be created for them.
return False
if not all(
qml.math.allclose(d1, d2, rtol=rtol, atol=atol) for d1, d2 in zip(op1.data, op2.data)
):
Expand Down Expand Up @@ -235,12 +236,8 @@ def _equal_controlled(op1: Controlled, op2: Controlled, **kwargs):
op2.arithmetic_depth,
]:
return False
try:
return qml.equal(op1.base, op2.base, **kwargs)
except NotImplementedError as e:
raise NotImplementedError(
f"Unable to compare base operators {op1.base} and {op2.base}."
) from e

return qml.equal(op1.base, op2.base, **kwargs)


@_equal.register
Expand Down Expand Up @@ -291,7 +288,7 @@ def _equal_tensor(op1: Tensor, op2: Observable, **kwargs):
if isinstance(op2, Tensor):
return op1._obs_data() == op2._obs_data() # pylint: disable=protected-access

raise NotImplementedError(f"Comparison of {type(op1)} and {type(op2)} not implemented")
return False


@_equal.register
Expand Down
30 changes: 12 additions & 18 deletions tests/ops/functions/test_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Unit tests for the equal function.
Tests are divided by number of parameters and wires different operators take.
"""
# pylint: disable=too-many-arguments
import itertools

import numpy as np
Expand Down Expand Up @@ -582,7 +583,7 @@ def test_equal_simple_op_3p1w(self, op1):
)

@pytest.mark.all_interfaces
def test_equal_op_remaining(self):
def test_equal_op_remaining(self): # pylint: disable=too-many-statements
"""Test optional arguments are working"""
# pylint: disable=too-many-statements
wire = 0
Expand Down Expand Up @@ -1105,9 +1106,9 @@ def test_equal_with_different_arithmetic_depth(self):
op2 = qml.prod(op1, qml.RY(0.25, wires=1))
assert not qml.equal(op1, op2)

def test_equal_with_unsupported_nested_operators_raises_error(self):
"""Test that the equal method with two operators with the same arithmetic depth (>0) raises
an error unless there is a singledispatch function specifically comparing that operator type.
def test_equal_with_unsupported_nested_operators_returns_false(self):
"""Test that the equal method with two operators with the same arithmetic depth (>0) returns
`False` unless there is a singledispatch function specifically comparing that operator type.
"""

op1 = SymbolicOp(qml.PauliY(0))
Expand All @@ -1116,12 +1117,7 @@ def test_equal_with_unsupported_nested_operators_raises_error(self):
assert op1.arithmetic_depth == op2.arithmetic_depth
assert op1.arithmetic_depth > 0

with pytest.raises(
NotImplementedError,
match="Comparison of operators with an arithmetic"
+ " depth larger than 0 is not yet implemented.",
):
qml.equal(op1, op2)
assert not qml.equal(op1, op2)

# Measurements test cases
@pytest.mark.parametrize("ops", PARAMETRIZED_MEASUREMENTS_COMBINATIONS)
Expand Down Expand Up @@ -1258,13 +1254,12 @@ def test_tensor_and_operation_not_equal(self):
assert qml.equal(op1, op2) is False
assert qml.equal(op2, op1) is False

def test_tensor_and_unsupported_observable_not_implemented(self):
"""Tests that trying to compare a Tensor to something other than another Tensor or a Hamiltonian raises a NotImplmenetedError"""
def test_tensor_and_unsupported_observable_returns_false(self):
"""Tests that trying to compare a Tensor to something other than another Tensor or a Hamiltonian returns False"""
op1 = qml.PauliX(0) @ qml.PauliY(1)
op2 = qml.Hermitian([[0, 1], [1, 0]], 0)

with pytest.raises(NotImplementedError, match="Comparison of"):
qml.equal(op1, op2)
assert not qml.equal(op1, op2)

def test_unsupported_object_type_not_implemented(self):
dev = qml.device("default.qubit", wires=1)
Expand Down Expand Up @@ -1301,14 +1296,13 @@ def test_mismatched_arithmetic_depth(self):
assert op2.arithmetic_depth == 2
assert qml.equal(op1, op2) is False

def test_comparison_of_base_not_implemented_error(self):
"""Test that comparing SymbolicOps of base operators whose comparison is not yet implemented raises an error"""
def test_comparison_of_base_not_implemented_returns_false(self):
"""Test that comparing SymbolicOps of base operators whose comparison is not yet implemented returns False"""
base = SymbolicOp(qml.RX(1.2, 0))
op1 = Controlled(base, control_wires=2)
op2 = Controlled(base, control_wires=2)

with pytest.raises(NotImplementedError, match="Unable to compare base operators "):
qml.equal(op1, op2)
assert not qml.equal(op1, op2)

@pytest.mark.torch
@pytest.mark.jax
Expand Down

0 comments on commit b781be0

Please sign in to comment.