diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index f0e4f66368a..5ea039d3a3b 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -6,6 +6,10 @@
Improvements ðŸ›
+* `qml.equal` no longer raises errors when operators or measurements of different types are compared.
+ Instead, it returns `False`.
+ [(#4315)](https://github.com/PennyLaneAI/pennylane/pull/4315)
+
* The `qml.gradients` module no longer mutates operators in-place for any gradient transforms.
Instead, operators that need to be mutated are copied with new parameters.
[(#4220)](https://github.com/PennyLaneAI/pennylane/pull/4220)
diff --git a/pennylane/ops/functions/equal.py b/pennylane/ops/functions/equal.py
index d0a93167deb..5c74bd521e8 100644
--- a/pennylane/ops/functions/equal.py
+++ b/pennylane/ops/functions/equal.py
@@ -94,11 +94,11 @@ def equal(
>>> qml.equal(H1, H2), qml.equal(H1, H3)
(True, False)
- >>> qml.equal(qml.expval(qml.PauliX(0)), qml.expval(qml.PauliX(0)) )
+ >>> qml.equal(qml.expval(qml.PauliX(0)), qml.expval(qml.PauliX(0)))
True
- >>> qml.equal(qml.probs(wires=(0,1)), qml.probs(wires=(1,2)) )
+ >>> qml.equal(qml.probs(wires=(0,1)), qml.probs(wires=(1,2)))
False
- >>> qml.equal(qml.classical_shadow(wires=[0,1]), qml.classical_shadow(wires=[0,1]) )
+ >>> qml.equal(qml.classical_shadow(wires=[0,1]), qml.classical_shadow(wires=[0,1]))
True
.. details::
@@ -160,7 +160,7 @@ def _equal(
check_trainability=True,
rtol=1e-5,
atol=1e-9,
-):
+): # pylint: disable=unused-argument
raise NotImplementedError(f"Comparison of {type(op1)} and {type(op2)} not implemented")
@@ -183,9 +183,10 @@ def _equal_operators(
return False
if op1.arithmetic_depth > 0:
- raise NotImplementedError(
- "Comparison of operators with an arithmetic depth larger than 0 is not yet implemented."
- )
+ # Other dispatches cover cases of operations with arithmetic depth > 0.
+ # If any new operations are added with arithmetic depth > 0, a new dispatch
+ # should be created for them.
+ return False
if not all(
qml.math.allclose(d1, d2, rtol=rtol, atol=atol) for d1, d2 in zip(op1.data, op2.data)
):
@@ -235,12 +236,8 @@ def _equal_controlled(op1: Controlled, op2: Controlled, **kwargs):
op2.arithmetic_depth,
]:
return False
- try:
- return qml.equal(op1.base, op2.base, **kwargs)
- except NotImplementedError as e:
- raise NotImplementedError(
- f"Unable to compare base operators {op1.base} and {op2.base}."
- ) from e
+
+ return qml.equal(op1.base, op2.base, **kwargs)
@_equal.register
@@ -291,7 +288,7 @@ def _equal_tensor(op1: Tensor, op2: Observable, **kwargs):
if isinstance(op2, Tensor):
return op1._obs_data() == op2._obs_data() # pylint: disable=protected-access
- raise NotImplementedError(f"Comparison of {type(op1)} and {type(op2)} not implemented")
+ return False
@_equal.register
diff --git a/tests/ops/functions/test_equal.py b/tests/ops/functions/test_equal.py
index 42aa05b6e48..bda02c3272d 100644
--- a/tests/ops/functions/test_equal.py
+++ b/tests/ops/functions/test_equal.py
@@ -15,6 +15,7 @@
Unit tests for the equal function.
Tests are divided by number of parameters and wires different operators take.
"""
+# pylint: disable=too-many-arguments
import itertools
import numpy as np
@@ -582,7 +583,7 @@ def test_equal_simple_op_3p1w(self, op1):
)
@pytest.mark.all_interfaces
- def test_equal_op_remaining(self):
+ def test_equal_op_remaining(self): # pylint: disable=too-many-statements
"""Test optional arguments are working"""
# pylint: disable=too-many-statements
wire = 0
@@ -1105,9 +1106,9 @@ def test_equal_with_different_arithmetic_depth(self):
op2 = qml.prod(op1, qml.RY(0.25, wires=1))
assert not qml.equal(op1, op2)
- def test_equal_with_unsupported_nested_operators_raises_error(self):
- """Test that the equal method with two operators with the same arithmetic depth (>0) raises
- an error unless there is a singledispatch function specifically comparing that operator type.
+ def test_equal_with_unsupported_nested_operators_returns_false(self):
+ """Test that the equal method with two operators with the same arithmetic depth (>0) returns
+ `False` unless there is a singledispatch function specifically comparing that operator type.
"""
op1 = SymbolicOp(qml.PauliY(0))
@@ -1116,12 +1117,7 @@ def test_equal_with_unsupported_nested_operators_raises_error(self):
assert op1.arithmetic_depth == op2.arithmetic_depth
assert op1.arithmetic_depth > 0
- with pytest.raises(
- NotImplementedError,
- match="Comparison of operators with an arithmetic"
- + " depth larger than 0 is not yet implemented.",
- ):
- qml.equal(op1, op2)
+ assert not qml.equal(op1, op2)
# Measurements test cases
@pytest.mark.parametrize("ops", PARAMETRIZED_MEASUREMENTS_COMBINATIONS)
@@ -1258,13 +1254,12 @@ def test_tensor_and_operation_not_equal(self):
assert qml.equal(op1, op2) is False
assert qml.equal(op2, op1) is False
- def test_tensor_and_unsupported_observable_not_implemented(self):
- """Tests that trying to compare a Tensor to something other than another Tensor or a Hamiltonian raises a NotImplmenetedError"""
+ def test_tensor_and_unsupported_observable_returns_false(self):
+ """Tests that trying to compare a Tensor to something other than another Tensor or a Hamiltonian returns False"""
op1 = qml.PauliX(0) @ qml.PauliY(1)
op2 = qml.Hermitian([[0, 1], [1, 0]], 0)
- with pytest.raises(NotImplementedError, match="Comparison of"):
- qml.equal(op1, op2)
+ assert not qml.equal(op1, op2)
def test_unsupported_object_type_not_implemented(self):
dev = qml.device("default.qubit", wires=1)
@@ -1301,14 +1296,13 @@ def test_mismatched_arithmetic_depth(self):
assert op2.arithmetic_depth == 2
assert qml.equal(op1, op2) is False
- def test_comparison_of_base_not_implemented_error(self):
- """Test that comparing SymbolicOps of base operators whose comparison is not yet implemented raises an error"""
+ def test_comparison_of_base_not_implemented_returns_false(self):
+ """Test that comparing SymbolicOps of base operators whose comparison is not yet implemented returns False"""
base = SymbolicOp(qml.RX(1.2, 0))
op1 = Controlled(base, control_wires=2)
op2 = Controlled(base, control_wires=2)
- with pytest.raises(NotImplementedError, match="Unable to compare base operators "):
- qml.equal(op1, op2)
+ assert not qml.equal(op1, op2)
@pytest.mark.torch
@pytest.mark.jax