Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add kwargs for equal functions #5668

Merged
merged 10 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@
allowing error types to be more consistent with the context the `decompose` function is used in.
[(#5669)](https://github.com/PennyLaneAI/pennylane/pull/5669)

<h4>Community contributions 🥳</h4>

* Implemented kwargs (`check_interface`, `check_trainability`, `rtol` and `atol`) support in `qml.equal` for the operators `Pow`, `Adjoint`, `Exp`, and `SProd`.
[(#5668)](https://github.com/PennyLaneAI/pennylane/issues/5668)

<h3>Breaking changes 💔</h3>

* `qml.is_commuting` no longer accepts the `wire_map` argument, which does not bring any functionality.
Expand Down Expand Up @@ -128,4 +133,5 @@ Pietropaolo Frisoni,
Soran Jahangiri,
Christina Lee,
Vincent Michaud-Rioux,
Kenya Sakka,
David Wierichs.
53 changes: 46 additions & 7 deletions pennylane/ops/functions/equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,42 +417,81 @@ def _equal_controlled_sequence(op1: ControlledSequence, op2: ControlledSequence,
# pylint: disable=unused-argument
def _equal_pow(op1: Pow, op2: Pow, **kwargs):
"""Determine whether two Pow objects are equal"""
check_interface, check_trainability = kwargs["check_interface"], kwargs["check_trainability"]

if check_interface:
if qml.math.get_interface(op1.z) != qml.math.get_interface(op2.z):
return False
if check_trainability:
if qml.math.requires_grad(op1.z) != qml.math.requires_grad(op2.z):
return False
kenya-sk marked this conversation as resolved.
Show resolved Hide resolved

if op1.z != op2.z:
return False
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
return qml.equal(op1.base, op2.base)

return qml.equal(op1.base, op2.base, **kwargs)


@_equal.register
# pylint: disable=unused-argument
def _equal_adjoint(op1: Adjoint, op2: Adjoint, **kwargs):
"""Determine whether two Adjoint objects are equal"""
# first line of top-level equal function already confirms both are Adjoint - only need to compare bases
return qml.equal(op1.base, op2.base)
return qml.equal(op1.base, op2.base, **kwargs)


@_equal.register
# pylint: disable=unused-argument
def _equal_exp(op1: Exp, op2: Exp, **kwargs):
"""Determine whether two Exp objects are equal"""
rtol, atol = (kwargs["rtol"], kwargs["atol"])
check_interface, check_trainability, rtol, atol = (
kwargs["check_interface"],
kwargs["check_trainability"],
kwargs["rtol"],
kwargs["atol"],
)

if check_interface:
for params1, params2 in zip(op1.data, op2.data):
if qml.math.get_interface(params1) != qml.math.get_interface(params2):
return False
if check_trainability:
for params1, params2 in zip(op1.data, op2.data):
if qml.math.requires_grad(params1) != qml.math.requires_grad(params2):
return False

if not qml.math.allclose(op1.coeff, op2.coeff, rtol=rtol, atol=atol):
return False
return qml.equal(op1.base, op2.base)

return qml.equal(op1.base, op2.base, **kwargs)


@_equal.register
# pylint: disable=unused-argument
def _equal_sprod(op1: SProd, op2: SProd, **kwargs):
"""Determine whether two SProd objects are equal"""
rtol, atol = (kwargs["rtol"], kwargs["atol"])
check_interface, check_trainability, rtol, atol = (
kwargs["check_interface"],
kwargs["check_trainability"],
kwargs["rtol"],
kwargs["atol"],
)

if check_interface:
for params1, params2 in zip(op1.data, op2.data):
if qml.math.get_interface(params1) != qml.math.get_interface(params2):
return False
if check_trainability:
for params1, params2 in zip(op1.data, op2.data):
if qml.math.requires_grad(params1) != qml.math.requires_grad(params2):
return False

if op1.pauli_rep is not None and (op1.pauli_rep == op2.pauli_rep): # shortcut check
return True

if not qml.math.allclose(op1.scalar, op2.scalar, rtol=rtol, atol=atol):
return False
return qml.equal(op1.base, op2.base)

return qml.equal(op1.base, op2.base, **kwargs)


@_equal.register
Expand Down
156 changes: 146 additions & 10 deletions tests/ops/functions/test_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1622,6 +1622,32 @@ def test_adjoint_comparison(self, base):
assert qml.equal(op1, op2)
assert not qml.equal(op1, op3)

def test_adjoint_comparison_with_tolerance(self):
"""Test that equal compares the parameters within a provided tolerance of the Adjoint class."""
op1 = qml.adjoint(qml.RX(1.2, wires=0))
op2 = qml.adjoint(qml.RX(1.2 + 1e-4, wires=0))

assert qml.equal(op1, op2, atol=1e-3, rtol=0)
assert not qml.equal(op1, op2, atol=1e-5, rtol=0)
assert qml.equal(op1, op2, atol=0, rtol=1e-3)
assert not qml.equal(op1, op2, atol=0, rtol=1e-5)

def test_adjoint_base_op_comparison_with_interface(self):
"""Test that equal compares the parameters within a provided interface of the base operator of Adjoint class."""
op1 = qml.adjoint(qml.RX(1.2, wires=0))
op2 = qml.adjoint(qml.RX(npp.array(1.2), wires=0))

assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
assert not qml.equal(op1, op2, check_interface=True, check_trainability=False)

def test_adjoint_base_op_comparison_with_trainability(self):
"""Test that equal compares the parameters within a provided trainability of the base operator of Adjoint class."""
op1 = qml.adjoint(qml.RX(npp.array(1.2, requires_grad=False), wires=0))
op2 = qml.adjoint(qml.RX(npp.array(1.2, requires_grad=True), wires=0))

assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
assert not qml.equal(op1, op2, check_interface=False, check_trainability=True)

@pytest.mark.parametrize("bases_bases_match", BASES)
@pytest.mark.parametrize("params_params_match", PARAMS)
def test_pow_comparison(self, bases_bases_match, params_params_match):
Expand All @@ -1632,6 +1658,48 @@ def test_pow_comparison(self, bases_bases_match, params_params_match):
op2 = qml.pow(base2, param2)
assert qml.equal(op1, op2) == (bases_match and params_match)

def test_pow_comparison_with_tolerance(self):
"""Test that equal compares the parameters within a provided tolerance of the Pow class."""
op1 = qml.pow(qml.RX(1.2, wires=0), 2)
op2 = qml.pow(qml.RX(1.2 + 1e-4, wires=0), 2)

assert qml.equal(op1, op2, atol=1e-3, rtol=0)
assert not qml.equal(op1, op2, atol=1e-5, rtol=0)
assert qml.equal(op1, op2, atol=0, rtol=1e-3)
assert not qml.equal(op1, op2, atol=0, rtol=1e-5)

def test_pow_comparison_with_interface(self):
"""Test that equal compares the parameters within a provided interface of the Pow class."""
op1 = qml.pow(qml.RX(1.2, wires=0), 2)
op2 = qml.pow(qml.RX(1.2, wires=0), npp.array(2))

assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
assert not qml.equal(op1, op2, check_interface=True, check_trainability=False)

def test_pow_comparison_with_trainability(self):
"""Test that equal compares the parameters within a provided trainability of the Pow class."""
op1 = qml.pow(qml.RX(1.2, wires=0), npp.array(2, requires_grad=False))
op2 = qml.pow(qml.RX(1.2, wires=0), npp.array(2, requires_grad=True))

assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
assert not qml.equal(op1, op2, check_interface=False, check_trainability=True)

def test_pow_base_op_comparison_with_interface(self):
"""Test that equal compares the parameters within a provided interface of the base operator of Pow class."""
op1 = qml.pow(qml.RX(1.2, wires=0), 2)
op2 = qml.pow(qml.RX(npp.array(1.2), wires=0), 2)

assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
assert not qml.equal(op1, op2, check_interface=True, check_trainability=False)

def test_pow_base_op_comparison_with_trainability(self):
"""Test that equal compares the parameters within a provided trainability of the base operator of Pow class."""
op1 = qml.pow(qml.RX(npp.array(1.2, requires_grad=False), wires=0), 2)
op2 = qml.pow(qml.RX(npp.array(1.2, requires_grad=True), wires=0), 2)

assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
assert not qml.equal(op1, op2, check_interface=False, check_trainability=True)
kenya-sk marked this conversation as resolved.
Show resolved Hide resolved

@pytest.mark.parametrize("bases_bases_match", BASES)
@pytest.mark.parametrize("params_params_match", PARAMS)
def test_exp_comparison(self, bases_bases_match, params_params_match):
Expand All @@ -1643,12 +1711,46 @@ def test_exp_comparison(self, bases_bases_match, params_params_match):
assert qml.equal(op1, op2) == (bases_match and params_match)

def test_exp_comparison_with_tolerance(self):
"""Test that equal compares the parameters within a provided tolerance."""
op1 = qml.exp(qml.PauliX(0), 0.12345)
op2 = qml.exp(qml.PauliX(0), 0.12356)
"""Test that equal compares the parameters within a provided tolerance of the Exp class."""
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
op1 = qml.exp(qml.PauliX(0), 0.12)
op2 = qml.exp(qml.PauliX(0), 0.12 + 1e-4)

assert qml.equal(op1, op2, atol=1e-3, rtol=0)
assert not qml.equal(op1, op2, atol=1e-5, rtol=0)
assert qml.equal(op1, op2, atol=0, rtol=1e-2)
assert not qml.equal(op1, op2, atol=0, rtol=1e-5)

assert qml.equal(op1, op2, atol=1e-3, rtol=1e-2)
assert not qml.equal(op1, op2, atol=1e-5, rtol=1e-4)
def test_exp_comparison_with_interface(self):
"""Test that equal compares the parameters within a provided interface of the Exp class."""
op1 = qml.exp(qml.PauliX(0), 1.2)
op2 = qml.exp(qml.PauliX(0), npp.array(1.2))

assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
assert not qml.equal(op1, op2, check_interface=True, check_trainability=False)

def test_exp_comparison_with_trainability(self):
"""Test that equal compares the parameters within a provided trainability of the Exp class."""
op1 = qml.exp(qml.PauliX(0), npp.array(1.2, requires_grad=False))
op2 = qml.exp(qml.PauliX(0), npp.array(1.2, requires_grad=True))

assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
assert not qml.equal(op1, op2, check_interface=False, check_trainability=True)

def test_exp_base_op_comparison_with_interface(self):
"""Test that equal compares the parameters within a provided interface of the base operator of Exp class."""
op1 = qml.exp(qml.RX(0.5, wires=0), 1.2)
op2 = qml.exp(qml.RX(npp.array(0.5), wires=0), 1.2)

assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
assert not qml.equal(op1, op2, check_interface=True, check_trainability=False)

def test_exp_base_op_comparison_with_trainability(self):
"""Test that equal compares the parameters within a provided trainability of the base operator of Exp class."""
op1 = qml.exp(qml.RX(npp.array(0.5, requires_grad=False), wires=0), 1.2)
op2 = qml.exp(qml.RX(npp.array(0.5, requires_grad=True), wires=0), 1.2)

assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
assert not qml.equal(op1, op2, check_interface=False, check_trainability=True)

additional_cases = [
(qml.sum(qml.PauliX(0), qml.PauliY(0)), qml.sum(qml.PauliY(0), qml.PauliX(0)), True),
Expand All @@ -1668,12 +1770,46 @@ def test_s_prod_comparison(self, bases_bases_match, params_params_match):
assert qml.equal(op1, op2) == (bases_match and params_match)

def test_s_prod_comparison_with_tolerance(self):
"""Test that equal compares the parameters within a provided tolerance."""
op1 = qml.s_prod(0.12345, qml.PauliX(0))
op2 = qml.s_prod(0.12356, qml.PauliX(0))
"""Test that equal compares the parameters within a provided tolerance of the SProd class."""
op1 = qml.s_prod(0.12, qml.PauliX(0))
op2 = qml.s_prod(0.12 + 1e-4, qml.PauliX(0))

assert qml.equal(op1, op2, atol=1e-3, rtol=0)
assert not qml.equal(op1, op2, atol=1e-5, rtol=0)
assert qml.equal(op1, op2, atol=0, rtol=1e-3)
assert not qml.equal(op1, op2, atol=0, rtol=1e-5)

assert qml.equal(op1, op2, atol=1e-3, rtol=1e-2)
assert not qml.equal(op1, op2, atol=1e-5, rtol=1e-4)
def test_s_prod_comparison_with_interface(self):
"""Test that equal compares the parameters within a provided interface of the SProd class."""
op1 = qml.s_prod(0.12, qml.PauliX(0))
op2 = qml.s_prod(npp.array(0.12), qml.PauliX(0))

assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
assert not qml.equal(op1, op2, check_interface=True, check_trainability=False)

def test_s_prod_comparison_with_trainability(self):
"""Test that equal compares the parameters within a provided trainability of the SProd class."""
op1 = qml.s_prod(npp.array(0.12, requires_grad=False), qml.PauliX(0))
op2 = qml.s_prod(npp.array(0.12, requires_grad=True), qml.PauliX(0))

assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
assert not qml.equal(op1, op2, check_interface=False, check_trainability=True)

def test_s_prod_base_op_comparison_with_interface(self):
"""Test that equal compares the parameters within a provided interface of the base operator of SProd class."""
op1 = qml.s_prod(0.12, qml.RX(0.5, wires=0))
op2 = qml.s_prod(0.12, qml.RX(npp.array(0.5), wires=0))

assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
assert not qml.equal(op1, op2, check_interface=True, check_trainability=False)

def test_s_prod_base_op_comparison_with_trainability(self):
"""Test that equal compares the parameters within a provided trainability of the base operator of SProd class."""
op1 = qml.s_prod(0.12, qml.RX(npp.array(0.5, requires_grad=False), wires=0))
op2 = qml.s_prod(0.12, qml.RX(npp.array(0.5, requires_grad=True), wires=0))

assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
assert not qml.equal(op1, op2, check_interface=False, check_trainability=True)


@pytest.mark.usefixtures("use_new_opmath")
Expand Down
4 changes: 3 additions & 1 deletion tests/ops/functions/test_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ def test_jit_simplification(self):
sum_op = qml.sum(qml.PauliX(0), qml.PauliX(0))
simp_op = jax.jit(qml.simplify)(sum_op)

assert qml.equal(simp_op, qml.s_prod(2.0, qml.PauliX(0)))
assert qml.equal(
simp_op, qml.s_prod(2.0, qml.PauliX(0)), check_interface=False, check_trainability=False
)


class TestSimplifyTapes:
Expand Down
Loading