Skip to content

Commit

Permalink
Add kwargs for equal functions (#5668)
Browse files Browse the repository at this point in the history
### Before submitting

Please complete the following checklist when submitting a PR:

- [x] All new features must include a unit test.
If you've fixed a bug or added code that should be tested, add a test to
the
      test directory!

- [x] All new functions and code must be clearly commented and
documented.
If you do make documentation changes, make sure that the docs build and
      render correctly by running `make docs`.

- [x] Ensure that the test suite passes, by running `make test`.

- [x] Add a new entry to the `doc/releases/changelog-dev.md` file,
summarizing the
      change, and including a link back to the PR.

- [x] The PennyLane source code conforms to
      [PEP8 standards](https://www.python.org/dev/peps/pep-0008/).
We check all of our code against [Pylint](https://www.pylint.org/).
      To lint modified files, simply `pip install pylint`, and then
      run `pylint pennylane/path/to/file.py`.

When all the above are checked, delete everything above the dashed
line and fill in the pull request template.


------------------------------------------------------------------------------------------------------------

**Context:**
The kwargs (`check_interface`, `check_trainability`, `atol` and `rtol`)
were not available for `_equal_pow`, `_equal_adjoint`, `_equal_exp` and
`_equal_sprod`, which are overrides of `qml.equal`. Therefore,
comparisons that allowed for interfacing and numerical errors were not
possible.

**Description of the Change:**
Changed to allow kwargs to be used in override of `qml.equal`.

**Benefits:**
The following interfaces and comparisons with tolerance for numerical
errors are possible.

```
>>> op1 = qml.RX(1.2, wires=0)
>>> op2 = qml.RX(1.2 + 1e-4, wires=0)
>>> qml.equal(op1 **2, op2 ** 2, atol=1e-3)
True
>>> qml.equal(op1 **2, op2 ** 2, atol=1e-6)
False
>>> qml.equal(qml.adjoint(op1), qml.adjoint(op2), atol=1e-3)
True
>>> qml.equal(qml.adjoint(op1), qml.adjoint(op2), atol=1e-6)
False
>>> op3 = qml.exp( qml.s_prod(2j, qml.X(0)))
>>> op4 = qml.exp(qml.s_prod(qml.numpy.array(2j), qml.X(0)))
>>> qml.equal(op3, op4, check_interface=True)
False
>>> qml.equal(op3, op4, check_interface=False)
True
```

**Possible Drawbacks:**

**Related GitHub Issues:**
#5408

---------

Co-authored-by: Christina Lee <christina@xanadu.ai>
  • Loading branch information
kenya-sk and albi3ro committed May 15, 2024
1 parent f960c0e commit daf2d76
Show file tree
Hide file tree
Showing 4 changed files with 201 additions and 18 deletions.
6 changes: 6 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@
allowing error types to be more consistent with the context the `decompose` function is used in.
[(#5669)](https://github.com/PennyLaneAI/pennylane/pull/5669)

<h4>Community contributions 🥳</h4>

* Implemented kwargs (`check_interface`, `check_trainability`, `rtol` and `atol`) support in `qml.equal` for the operators `Pow`, `Adjoint`, `Exp`, and `SProd`.
[(#5668)](https://github.com/PennyLaneAI/pennylane/issues/5668)

<h3>Breaking changes 💔</h3>

* `qml.is_commuting` no longer accepts the `wire_map` argument, which does not bring any functionality.
Expand Down Expand Up @@ -128,4 +133,5 @@ Pietropaolo Frisoni,
Soran Jahangiri,
Christina Lee,
Vincent Michaud-Rioux,
Kenya Sakka,
David Wierichs.
53 changes: 46 additions & 7 deletions pennylane/ops/functions/equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,42 +417,81 @@ def _equal_controlled_sequence(op1: ControlledSequence, op2: ControlledSequence,
# pylint: disable=unused-argument
def _equal_pow(op1: Pow, op2: Pow, **kwargs):
"""Determine whether two Pow objects are equal"""
check_interface, check_trainability = kwargs["check_interface"], kwargs["check_trainability"]

if check_interface:
if qml.math.get_interface(op1.z) != qml.math.get_interface(op2.z):
return False
if check_trainability:
if qml.math.requires_grad(op1.z) != qml.math.requires_grad(op2.z):
return False

if op1.z != op2.z:
return False
return qml.equal(op1.base, op2.base)

return qml.equal(op1.base, op2.base, **kwargs)


@_equal.register
# pylint: disable=unused-argument
def _equal_adjoint(op1: Adjoint, op2: Adjoint, **kwargs):
"""Determine whether two Adjoint objects are equal"""
# first line of top-level equal function already confirms both are Adjoint - only need to compare bases
return qml.equal(op1.base, op2.base)
return qml.equal(op1.base, op2.base, **kwargs)


@_equal.register
# pylint: disable=unused-argument
def _equal_exp(op1: Exp, op2: Exp, **kwargs):
"""Determine whether two Exp objects are equal"""
rtol, atol = (kwargs["rtol"], kwargs["atol"])
check_interface, check_trainability, rtol, atol = (
kwargs["check_interface"],
kwargs["check_trainability"],
kwargs["rtol"],
kwargs["atol"],
)

if check_interface:
for params1, params2 in zip(op1.data, op2.data):
if qml.math.get_interface(params1) != qml.math.get_interface(params2):
return False
if check_trainability:
for params1, params2 in zip(op1.data, op2.data):
if qml.math.requires_grad(params1) != qml.math.requires_grad(params2):
return False

if not qml.math.allclose(op1.coeff, op2.coeff, rtol=rtol, atol=atol):
return False
return qml.equal(op1.base, op2.base)

return qml.equal(op1.base, op2.base, **kwargs)


@_equal.register
# pylint: disable=unused-argument
def _equal_sprod(op1: SProd, op2: SProd, **kwargs):
"""Determine whether two SProd objects are equal"""
rtol, atol = (kwargs["rtol"], kwargs["atol"])
check_interface, check_trainability, rtol, atol = (
kwargs["check_interface"],
kwargs["check_trainability"],
kwargs["rtol"],
kwargs["atol"],
)

if check_interface:
for params1, params2 in zip(op1.data, op2.data):
if qml.math.get_interface(params1) != qml.math.get_interface(params2):
return False
if check_trainability:
for params1, params2 in zip(op1.data, op2.data):
if qml.math.requires_grad(params1) != qml.math.requires_grad(params2):
return False

if op1.pauli_rep is not None and (op1.pauli_rep == op2.pauli_rep): # shortcut check
return True

if not qml.math.allclose(op1.scalar, op2.scalar, rtol=rtol, atol=atol):
return False
return qml.equal(op1.base, op2.base)

return qml.equal(op1.base, op2.base, **kwargs)


@_equal.register
Expand Down
156 changes: 146 additions & 10 deletions tests/ops/functions/test_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1622,6 +1622,32 @@ def test_adjoint_comparison(self, base):
assert qml.equal(op1, op2)
assert not qml.equal(op1, op3)

def test_adjoint_comparison_with_tolerance(self):
"""Test that equal compares the parameters within a provided tolerance of the Adjoint class."""
op1 = qml.adjoint(qml.RX(1.2, wires=0))
op2 = qml.adjoint(qml.RX(1.2 + 1e-4, wires=0))

assert qml.equal(op1, op2, atol=1e-3, rtol=0)
assert not qml.equal(op1, op2, atol=1e-5, rtol=0)
assert qml.equal(op1, op2, atol=0, rtol=1e-3)
assert not qml.equal(op1, op2, atol=0, rtol=1e-5)

def test_adjoint_base_op_comparison_with_interface(self):
"""Test that equal compares the parameters within a provided interface of the base operator of Adjoint class."""
op1 = qml.adjoint(qml.RX(1.2, wires=0))
op2 = qml.adjoint(qml.RX(npp.array(1.2), wires=0))

assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
assert not qml.equal(op1, op2, check_interface=True, check_trainability=False)

def test_adjoint_base_op_comparison_with_trainability(self):
"""Test that equal compares the parameters within a provided trainability of the base operator of Adjoint class."""
op1 = qml.adjoint(qml.RX(npp.array(1.2, requires_grad=False), wires=0))
op2 = qml.adjoint(qml.RX(npp.array(1.2, requires_grad=True), wires=0))

assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
assert not qml.equal(op1, op2, check_interface=False, check_trainability=True)

@pytest.mark.parametrize("bases_bases_match", BASES)
@pytest.mark.parametrize("params_params_match", PARAMS)
def test_pow_comparison(self, bases_bases_match, params_params_match):
Expand All @@ -1632,6 +1658,48 @@ def test_pow_comparison(self, bases_bases_match, params_params_match):
op2 = qml.pow(base2, param2)
assert qml.equal(op1, op2) == (bases_match and params_match)

def test_pow_comparison_with_tolerance(self):
"""Test that equal compares the parameters within a provided tolerance of the Pow class."""
op1 = qml.pow(qml.RX(1.2, wires=0), 2)
op2 = qml.pow(qml.RX(1.2 + 1e-4, wires=0), 2)

assert qml.equal(op1, op2, atol=1e-3, rtol=0)
assert not qml.equal(op1, op2, atol=1e-5, rtol=0)
assert qml.equal(op1, op2, atol=0, rtol=1e-3)
assert not qml.equal(op1, op2, atol=0, rtol=1e-5)

def test_pow_comparison_with_interface(self):
"""Test that equal compares the parameters within a provided interface of the Pow class."""
op1 = qml.pow(qml.RX(1.2, wires=0), 2)
op2 = qml.pow(qml.RX(1.2, wires=0), npp.array(2))

assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
assert not qml.equal(op1, op2, check_interface=True, check_trainability=False)

def test_pow_comparison_with_trainability(self):
"""Test that equal compares the parameters within a provided trainability of the Pow class."""
op1 = qml.pow(qml.RX(1.2, wires=0), npp.array(2, requires_grad=False))
op2 = qml.pow(qml.RX(1.2, wires=0), npp.array(2, requires_grad=True))

assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
assert not qml.equal(op1, op2, check_interface=False, check_trainability=True)

def test_pow_base_op_comparison_with_interface(self):
"""Test that equal compares the parameters within a provided interface of the base operator of Pow class."""
op1 = qml.pow(qml.RX(1.2, wires=0), 2)
op2 = qml.pow(qml.RX(npp.array(1.2), wires=0), 2)

assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
assert not qml.equal(op1, op2, check_interface=True, check_trainability=False)

def test_pow_base_op_comparison_with_trainability(self):
"""Test that equal compares the parameters within a provided trainability of the base operator of Pow class."""
op1 = qml.pow(qml.RX(npp.array(1.2, requires_grad=False), wires=0), 2)
op2 = qml.pow(qml.RX(npp.array(1.2, requires_grad=True), wires=0), 2)

assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
assert not qml.equal(op1, op2, check_interface=False, check_trainability=True)

@pytest.mark.parametrize("bases_bases_match", BASES)
@pytest.mark.parametrize("params_params_match", PARAMS)
def test_exp_comparison(self, bases_bases_match, params_params_match):
Expand All @@ -1643,12 +1711,46 @@ def test_exp_comparison(self, bases_bases_match, params_params_match):
assert qml.equal(op1, op2) == (bases_match and params_match)

def test_exp_comparison_with_tolerance(self):
"""Test that equal compares the parameters within a provided tolerance."""
op1 = qml.exp(qml.PauliX(0), 0.12345)
op2 = qml.exp(qml.PauliX(0), 0.12356)
"""Test that equal compares the parameters within a provided tolerance of the Exp class."""
op1 = qml.exp(qml.PauliX(0), 0.12)
op2 = qml.exp(qml.PauliX(0), 0.12 + 1e-4)

assert qml.equal(op1, op2, atol=1e-3, rtol=0)
assert not qml.equal(op1, op2, atol=1e-5, rtol=0)
assert qml.equal(op1, op2, atol=0, rtol=1e-2)
assert not qml.equal(op1, op2, atol=0, rtol=1e-5)

assert qml.equal(op1, op2, atol=1e-3, rtol=1e-2)
assert not qml.equal(op1, op2, atol=1e-5, rtol=1e-4)
def test_exp_comparison_with_interface(self):
"""Test that equal compares the parameters within a provided interface of the Exp class."""
op1 = qml.exp(qml.PauliX(0), 1.2)
op2 = qml.exp(qml.PauliX(0), npp.array(1.2))

assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
assert not qml.equal(op1, op2, check_interface=True, check_trainability=False)

def test_exp_comparison_with_trainability(self):
"""Test that equal compares the parameters within a provided trainability of the Exp class."""
op1 = qml.exp(qml.PauliX(0), npp.array(1.2, requires_grad=False))
op2 = qml.exp(qml.PauliX(0), npp.array(1.2, requires_grad=True))

assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
assert not qml.equal(op1, op2, check_interface=False, check_trainability=True)

def test_exp_base_op_comparison_with_interface(self):
"""Test that equal compares the parameters within a provided interface of the base operator of Exp class."""
op1 = qml.exp(qml.RX(0.5, wires=0), 1.2)
op2 = qml.exp(qml.RX(npp.array(0.5), wires=0), 1.2)

assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
assert not qml.equal(op1, op2, check_interface=True, check_trainability=False)

def test_exp_base_op_comparison_with_trainability(self):
"""Test that equal compares the parameters within a provided trainability of the base operator of Exp class."""
op1 = qml.exp(qml.RX(npp.array(0.5, requires_grad=False), wires=0), 1.2)
op2 = qml.exp(qml.RX(npp.array(0.5, requires_grad=True), wires=0), 1.2)

assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
assert not qml.equal(op1, op2, check_interface=False, check_trainability=True)

additional_cases = [
(qml.sum(qml.PauliX(0), qml.PauliY(0)), qml.sum(qml.PauliY(0), qml.PauliX(0)), True),
Expand All @@ -1668,12 +1770,46 @@ def test_s_prod_comparison(self, bases_bases_match, params_params_match):
assert qml.equal(op1, op2) == (bases_match and params_match)

def test_s_prod_comparison_with_tolerance(self):
"""Test that equal compares the parameters within a provided tolerance."""
op1 = qml.s_prod(0.12345, qml.PauliX(0))
op2 = qml.s_prod(0.12356, qml.PauliX(0))
"""Test that equal compares the parameters within a provided tolerance of the SProd class."""
op1 = qml.s_prod(0.12, qml.PauliX(0))
op2 = qml.s_prod(0.12 + 1e-4, qml.PauliX(0))

assert qml.equal(op1, op2, atol=1e-3, rtol=0)
assert not qml.equal(op1, op2, atol=1e-5, rtol=0)
assert qml.equal(op1, op2, atol=0, rtol=1e-3)
assert not qml.equal(op1, op2, atol=0, rtol=1e-5)

assert qml.equal(op1, op2, atol=1e-3, rtol=1e-2)
assert not qml.equal(op1, op2, atol=1e-5, rtol=1e-4)
def test_s_prod_comparison_with_interface(self):
"""Test that equal compares the parameters within a provided interface of the SProd class."""
op1 = qml.s_prod(0.12, qml.PauliX(0))
op2 = qml.s_prod(npp.array(0.12), qml.PauliX(0))

assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
assert not qml.equal(op1, op2, check_interface=True, check_trainability=False)

def test_s_prod_comparison_with_trainability(self):
"""Test that equal compares the parameters within a provided trainability of the SProd class."""
op1 = qml.s_prod(npp.array(0.12, requires_grad=False), qml.PauliX(0))
op2 = qml.s_prod(npp.array(0.12, requires_grad=True), qml.PauliX(0))

assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
assert not qml.equal(op1, op2, check_interface=False, check_trainability=True)

def test_s_prod_base_op_comparison_with_interface(self):
"""Test that equal compares the parameters within a provided interface of the base operator of SProd class."""
op1 = qml.s_prod(0.12, qml.RX(0.5, wires=0))
op2 = qml.s_prod(0.12, qml.RX(npp.array(0.5), wires=0))

assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
assert not qml.equal(op1, op2, check_interface=True, check_trainability=False)

def test_s_prod_base_op_comparison_with_trainability(self):
"""Test that equal compares the parameters within a provided trainability of the base operator of SProd class."""
op1 = qml.s_prod(0.12, qml.RX(npp.array(0.5, requires_grad=False), wires=0))
op2 = qml.s_prod(0.12, qml.RX(npp.array(0.5, requires_grad=True), wires=0))

assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
assert not qml.equal(op1, op2, check_interface=False, check_trainability=True)


@pytest.mark.usefixtures("use_new_opmath")
Expand Down
4 changes: 3 additions & 1 deletion tests/ops/functions/test_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ def test_jit_simplification(self):
sum_op = qml.sum(qml.PauliX(0), qml.PauliX(0))
simp_op = jax.jit(qml.simplify)(sum_op)

assert qml.equal(simp_op, qml.s_prod(2.0, qml.PauliX(0)))
assert qml.equal(
simp_op, qml.s_prod(2.0, qml.PauliX(0)), check_interface=False, check_trainability=False
)


class TestSimplifyTapes:
Expand Down

0 comments on commit daf2d76

Please sign in to comment.