diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index f0e4f66368a..5ea039d3a3b 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -6,6 +6,10 @@

Improvements 🛠

+* `qml.equal` no longer raises errors when operators or measurements of different types are compared. + Instead, it returns `False`. + [(#4315)](https://github.com/PennyLaneAI/pennylane/pull/4315) + * The `qml.gradients` module no longer mutates operators in-place for any gradient transforms. Instead, operators that need to be mutated are copied with new parameters. [(#4220)](https://github.com/PennyLaneAI/pennylane/pull/4220) diff --git a/pennylane/ops/functions/equal.py b/pennylane/ops/functions/equal.py index d0a93167deb..5c74bd521e8 100644 --- a/pennylane/ops/functions/equal.py +++ b/pennylane/ops/functions/equal.py @@ -94,11 +94,11 @@ def equal( >>> qml.equal(H1, H2), qml.equal(H1, H3) (True, False) - >>> qml.equal(qml.expval(qml.PauliX(0)), qml.expval(qml.PauliX(0)) ) + >>> qml.equal(qml.expval(qml.PauliX(0)), qml.expval(qml.PauliX(0))) True - >>> qml.equal(qml.probs(wires=(0,1)), qml.probs(wires=(1,2)) ) + >>> qml.equal(qml.probs(wires=(0,1)), qml.probs(wires=(1,2))) False - >>> qml.equal(qml.classical_shadow(wires=[0,1]), qml.classical_shadow(wires=[0,1]) ) + >>> qml.equal(qml.classical_shadow(wires=[0,1]), qml.classical_shadow(wires=[0,1])) True .. details:: @@ -160,7 +160,7 @@ def _equal( check_trainability=True, rtol=1e-5, atol=1e-9, -): +): # pylint: disable=unused-argument raise NotImplementedError(f"Comparison of {type(op1)} and {type(op2)} not implemented") @@ -183,9 +183,10 @@ def _equal_operators( return False if op1.arithmetic_depth > 0: - raise NotImplementedError( - "Comparison of operators with an arithmetic depth larger than 0 is not yet implemented." - ) + # Other dispatches cover cases of operations with arithmetic depth > 0. + # If any new operations are added with arithmetic depth > 0, a new dispatch + # should be created for them. + return False if not all( qml.math.allclose(d1, d2, rtol=rtol, atol=atol) for d1, d2 in zip(op1.data, op2.data) ): @@ -235,12 +236,8 @@ def _equal_controlled(op1: Controlled, op2: Controlled, **kwargs): op2.arithmetic_depth, ]: return False - try: - return qml.equal(op1.base, op2.base, **kwargs) - except NotImplementedError as e: - raise NotImplementedError( - f"Unable to compare base operators {op1.base} and {op2.base}." - ) from e + + return qml.equal(op1.base, op2.base, **kwargs) @_equal.register @@ -291,7 +288,7 @@ def _equal_tensor(op1: Tensor, op2: Observable, **kwargs): if isinstance(op2, Tensor): return op1._obs_data() == op2._obs_data() # pylint: disable=protected-access - raise NotImplementedError(f"Comparison of {type(op1)} and {type(op2)} not implemented") + return False @_equal.register diff --git a/tests/ops/functions/test_equal.py b/tests/ops/functions/test_equal.py index 42aa05b6e48..bda02c3272d 100644 --- a/tests/ops/functions/test_equal.py +++ b/tests/ops/functions/test_equal.py @@ -15,6 +15,7 @@ Unit tests for the equal function. Tests are divided by number of parameters and wires different operators take. """ +# pylint: disable=too-many-arguments import itertools import numpy as np @@ -582,7 +583,7 @@ def test_equal_simple_op_3p1w(self, op1): ) @pytest.mark.all_interfaces - def test_equal_op_remaining(self): + def test_equal_op_remaining(self): # pylint: disable=too-many-statements """Test optional arguments are working""" # pylint: disable=too-many-statements wire = 0 @@ -1105,9 +1106,9 @@ def test_equal_with_different_arithmetic_depth(self): op2 = qml.prod(op1, qml.RY(0.25, wires=1)) assert not qml.equal(op1, op2) - def test_equal_with_unsupported_nested_operators_raises_error(self): - """Test that the equal method with two operators with the same arithmetic depth (>0) raises - an error unless there is a singledispatch function specifically comparing that operator type. + def test_equal_with_unsupported_nested_operators_returns_false(self): + """Test that the equal method with two operators with the same arithmetic depth (>0) returns + `False` unless there is a singledispatch function specifically comparing that operator type. """ op1 = SymbolicOp(qml.PauliY(0)) @@ -1116,12 +1117,7 @@ def test_equal_with_unsupported_nested_operators_raises_error(self): assert op1.arithmetic_depth == op2.arithmetic_depth assert op1.arithmetic_depth > 0 - with pytest.raises( - NotImplementedError, - match="Comparison of operators with an arithmetic" - + " depth larger than 0 is not yet implemented.", - ): - qml.equal(op1, op2) + assert not qml.equal(op1, op2) # Measurements test cases @pytest.mark.parametrize("ops", PARAMETRIZED_MEASUREMENTS_COMBINATIONS) @@ -1258,13 +1254,12 @@ def test_tensor_and_operation_not_equal(self): assert qml.equal(op1, op2) is False assert qml.equal(op2, op1) is False - def test_tensor_and_unsupported_observable_not_implemented(self): - """Tests that trying to compare a Tensor to something other than another Tensor or a Hamiltonian raises a NotImplmenetedError""" + def test_tensor_and_unsupported_observable_returns_false(self): + """Tests that trying to compare a Tensor to something other than another Tensor or a Hamiltonian returns False""" op1 = qml.PauliX(0) @ qml.PauliY(1) op2 = qml.Hermitian([[0, 1], [1, 0]], 0) - with pytest.raises(NotImplementedError, match="Comparison of"): - qml.equal(op1, op2) + assert not qml.equal(op1, op2) def test_unsupported_object_type_not_implemented(self): dev = qml.device("default.qubit", wires=1) @@ -1301,14 +1296,13 @@ def test_mismatched_arithmetic_depth(self): assert op2.arithmetic_depth == 2 assert qml.equal(op1, op2) is False - def test_comparison_of_base_not_implemented_error(self): - """Test that comparing SymbolicOps of base operators whose comparison is not yet implemented raises an error""" + def test_comparison_of_base_not_implemented_returns_false(self): + """Test that comparing SymbolicOps of base operators whose comparison is not yet implemented returns False""" base = SymbolicOp(qml.RX(1.2, 0)) op1 = Controlled(base, control_wires=2) op2 = Controlled(base, control_wires=2) - with pytest.raises(NotImplementedError, match="Unable to compare base operators "): - qml.equal(op1, op2) + assert not qml.equal(op1, op2) @pytest.mark.torch @pytest.mark.jax