Skip to content

Commit

Permalink
Fix bug in Tensor equal check (#5877)
Browse files Browse the repository at this point in the history
Fixes bug where `qml.equal` incorrectly returns `True` when a `Tensor`
is not equal with another non-`Tensor` observable.

---------

Co-authored-by: Christina Lee <christina@xanadu.ai>
  • Loading branch information
astralcai and albi3ro authored Jun 19, 2024
1 parent 11d3d93 commit b38fe79
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 13 deletions.
3 changes: 2 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,8 @@
[(#5502)](https://github.com/PennyLaneAI/pennylane/pull/5502)

* Implement support in `assert_equal` for `Operator`, `Controlled`, `Adjoint`, `Pow`, `Exp`, `SProd`, `ControlledSequence`, `Prod`, `Sum`, `Tensor` and `Hamiltonian`
[(#5780)](https://github.com/PennyLaneAI/pennylane/pull/5780)
[(#5780)](https://github.com/PennyLaneAI/pennylane/pull/5780)
[(#5877)](https://github.com/PennyLaneAI/pennylane/pull/5877)

* `qml.QutritChannel` has been added, enabling the specification of noise using a collection of (3x3) Kraus matrices on the `default.qutrit.mixed` device.
[(#5793)](https://github.com/PennyLaneAI/pennylane/issues/5793)
Expand Down
20 changes: 11 additions & 9 deletions pennylane/ops/functions/equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,31 +597,33 @@ def _equal_sprod(op1: SProd, op2: SProd, **kwargs):
# pylint: disable=unused-argument
def _equal_tensor(op1: Tensor, op2: Observable, **kwargs):
"""Determine whether a Tensor object is equal to a Hamiltonian/Tensor"""

if not isinstance(op2, Observable):
return f"{op2} is not of type Observable"

if isinstance(op2, (Hamiltonian, LinearCombination, Hermitian)):
if not op2.compare(op1):
return f"'{op1}' and '{op2}' are not same"
return (
op2.compare(op1) or f"'{op1}' and '{op2}' are not the same for an unspecified reason."
)

if isinstance(op2, Tensor):
if not op1._obs_data() == op2._obs_data(): # pylint: disable=protected-access
return "op1 and op2 have different _obs_data outputs"
return (
op1._obs_data() == op2._obs_data() # pylint: disable=protected-access
or f"{op1} and {op2} have different _obs_data outputs"
)

return True
return f"{op1} is of type {type(op1)} and {op2} is of type {type(op2)}"


@_equal_dispatch.register
# pylint: disable=unused-argument
def _equal_hamiltonian(op1: Hamiltonian, op2: Observable, **kwargs):
"""Determine whether a Hamiltonian object is equal to a Hamiltonian/Tensor objects"""

if not isinstance(op2, Observable):
return f"{op2} is not of type Observable"

if not op1.compare(op2):
return f"'{op1}' and '{op2}' are not same"

return True
return op1.compare(op2) or f"'{op1}' and '{op2}' are not the same for an unspecified reason"


@_equal_dispatch.register
Expand Down
15 changes: 12 additions & 3 deletions tests/ops/functions/test_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1467,7 +1467,7 @@ def test_hamiltonian_equal(self, H1, H2, res):
assert qml.equal(H1, H2) == qml.equal(H2, H1)
assert qml.equal(H1, H2) == res
if not res:
error_message_pattern = re.compile(r"'([^']+)' and '([^']+)' are not same")
error_message_pattern = re.compile(r"'([^']+)' and '([^']+)' are not the same")
with pytest.raises(AssertionError, match=error_message_pattern):
assert_equal(H1, H2)

Expand All @@ -1481,7 +1481,7 @@ def test_tensors_not_equal(self):
"""Tensors are not equal because of different observable data"""
op1 = qml.operation.Tensor(qml.X(0), qml.Y(1))
op2 = qml.operation.Tensor(qml.Y(0), qml.X(1))
with pytest.raises(AssertionError, match="op1 and op2 have different _obs_data outputs"):
with pytest.raises(AssertionError, match="have different _obs_data outputs"):
assert_equal(op1, op2)

@pytest.mark.parametrize(("H", "T", "res"), equal_hamiltonians_and_tensors)
Expand Down Expand Up @@ -1522,13 +1522,22 @@ def test_tensor_and_operation_not_equal(self):
with pytest.raises(AssertionError, match="is not of type Observable"):
assert_equal(op1, op2)

def test_tensor_and_observable_not_equal(self):
"""Tests that comparing a Tensor with an Observable that is not a Tensor returns False"""
op1 = qml.PauliX(0) @ qml.PauliY(1)
op2 = qml.Z(0)
assert qml.equal(op1, op2) is False
assert qml.equal(op2, op1) is False
with pytest.raises(AssertionError, match="is of type <class 'pennylane.operation.Tensor'>"):
assert_equal(op1, op2)

def test_tensor_and_unsupported_observable_returns_false(self):
"""Tests that trying to compare a Tensor to something other than another Tensor or a Hamiltonian returns False"""
op1 = qml.PauliX(0) @ qml.PauliY(1)
op2 = qml.Hermitian([[0, 1], [1, 0]], 0)

assert not qml.equal(op1, op2)
error_message_pattern = re.compile(r"'([^']+)' and '([^']+)' are not same")
error_message_pattern = re.compile(r"'([^']+)' and '([^']+)' are not the same")
with pytest.raises(AssertionError, match=error_message_pattern):
assert_equal(op1, op2)

Expand Down

0 comments on commit b38fe79

Please sign in to comment.