diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 4411b25e69c..7cbbcd67bd7 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -82,6 +82,11 @@
allowing error types to be more consistent with the context the `decompose` function is used in.
[(#5669)](https://github.com/PennyLaneAI/pennylane/pull/5669)
+
Community contributions 🥳
+
+* Implemented kwargs (`check_interface`, `check_trainability`, `rtol` and `atol`) support in `qml.equal` for the operators `Pow`, `Adjoint`, `Exp`, and `SProd`.
+ [(#5668)](https://github.com/PennyLaneAI/pennylane/issues/5668)
+
Breaking changes 💔
* `qml.is_commuting` no longer accepts the `wire_map` argument, which does not bring any functionality.
@@ -128,4 +133,5 @@ Pietropaolo Frisoni,
Soran Jahangiri,
Christina Lee,
Vincent Michaud-Rioux,
+Kenya Sakka,
David Wierichs.
diff --git a/pennylane/ops/functions/equal.py b/pennylane/ops/functions/equal.py
index 8130fa02a91..44cfd42fee6 100644
--- a/pennylane/ops/functions/equal.py
+++ b/pennylane/ops/functions/equal.py
@@ -417,9 +417,19 @@ def _equal_controlled_sequence(op1: ControlledSequence, op2: ControlledSequence,
# pylint: disable=unused-argument
def _equal_pow(op1: Pow, op2: Pow, **kwargs):
"""Determine whether two Pow objects are equal"""
+ check_interface, check_trainability = kwargs["check_interface"], kwargs["check_trainability"]
+
+ if check_interface:
+ if qml.math.get_interface(op1.z) != qml.math.get_interface(op2.z):
+ return False
+ if check_trainability:
+ if qml.math.requires_grad(op1.z) != qml.math.requires_grad(op2.z):
+ return False
+
if op1.z != op2.z:
return False
- return qml.equal(op1.base, op2.base)
+
+ return qml.equal(op1.base, op2.base, **kwargs)
@_equal.register
@@ -427,32 +437,61 @@ def _equal_pow(op1: Pow, op2: Pow, **kwargs):
def _equal_adjoint(op1: Adjoint, op2: Adjoint, **kwargs):
"""Determine whether two Adjoint objects are equal"""
# first line of top-level equal function already confirms both are Adjoint - only need to compare bases
- return qml.equal(op1.base, op2.base)
+ return qml.equal(op1.base, op2.base, **kwargs)
@_equal.register
# pylint: disable=unused-argument
def _equal_exp(op1: Exp, op2: Exp, **kwargs):
"""Determine whether two Exp objects are equal"""
- rtol, atol = (kwargs["rtol"], kwargs["atol"])
+ check_interface, check_trainability, rtol, atol = (
+ kwargs["check_interface"],
+ kwargs["check_trainability"],
+ kwargs["rtol"],
+ kwargs["atol"],
+ )
+
+ if check_interface:
+ for params1, params2 in zip(op1.data, op2.data):
+ if qml.math.get_interface(params1) != qml.math.get_interface(params2):
+ return False
+ if check_trainability:
+ for params1, params2 in zip(op1.data, op2.data):
+ if qml.math.requires_grad(params1) != qml.math.requires_grad(params2):
+ return False
if not qml.math.allclose(op1.coeff, op2.coeff, rtol=rtol, atol=atol):
return False
- return qml.equal(op1.base, op2.base)
+
+ return qml.equal(op1.base, op2.base, **kwargs)
@_equal.register
# pylint: disable=unused-argument
def _equal_sprod(op1: SProd, op2: SProd, **kwargs):
"""Determine whether two SProd objects are equal"""
- rtol, atol = (kwargs["rtol"], kwargs["atol"])
+ check_interface, check_trainability, rtol, atol = (
+ kwargs["check_interface"],
+ kwargs["check_trainability"],
+ kwargs["rtol"],
+ kwargs["atol"],
+ )
+
+ if check_interface:
+ for params1, params2 in zip(op1.data, op2.data):
+ if qml.math.get_interface(params1) != qml.math.get_interface(params2):
+ return False
+ if check_trainability:
+ for params1, params2 in zip(op1.data, op2.data):
+ if qml.math.requires_grad(params1) != qml.math.requires_grad(params2):
+ return False
if op1.pauli_rep is not None and (op1.pauli_rep == op2.pauli_rep): # shortcut check
return True
-
if not qml.math.allclose(op1.scalar, op2.scalar, rtol=rtol, atol=atol):
return False
- return qml.equal(op1.base, op2.base)
+
+ return qml.equal(op1.base, op2.base, **kwargs)
@_equal.register
diff --git a/tests/ops/functions/test_equal.py b/tests/ops/functions/test_equal.py
index 909e60eed9e..0cc8e633748 100644
--- a/tests/ops/functions/test_equal.py
+++ b/tests/ops/functions/test_equal.py
@@ -1622,6 +1622,32 @@ def test_adjoint_comparison(self, base):
assert qml.equal(op1, op2)
assert not qml.equal(op1, op3)
+ def test_adjoint_comparison_with_tolerance(self):
+ """Test that equal compares the parameters within a provided tolerance of the Adjoint class."""
+ op1 = qml.adjoint(qml.RX(1.2, wires=0))
+ op2 = qml.adjoint(qml.RX(1.2 + 1e-4, wires=0))
+
+ assert qml.equal(op1, op2, atol=1e-3, rtol=0)
+ assert not qml.equal(op1, op2, atol=1e-5, rtol=0)
+ assert qml.equal(op1, op2, atol=0, rtol=1e-3)
+ assert not qml.equal(op1, op2, atol=0, rtol=1e-5)
+
+ def test_adjoint_base_op_comparison_with_interface(self):
+ """Test that equal compares the parameters within a provided interface of the base operator of Adjoint class."""
+ op1 = qml.adjoint(qml.RX(1.2, wires=0))
+ op2 = qml.adjoint(qml.RX(npp.array(1.2), wires=0))
+
+ assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
+ assert not qml.equal(op1, op2, check_interface=True, check_trainability=False)
+
+ def test_adjoint_base_op_comparison_with_trainability(self):
+ """Test that equal compares the parameters within a provided trainability of the base operator of Adjoint class."""
+ op1 = qml.adjoint(qml.RX(npp.array(1.2, requires_grad=False), wires=0))
+ op2 = qml.adjoint(qml.RX(npp.array(1.2, requires_grad=True), wires=0))
+
+ assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
+ assert not qml.equal(op1, op2, check_interface=False, check_trainability=True)
+
@pytest.mark.parametrize("bases_bases_match", BASES)
@pytest.mark.parametrize("params_params_match", PARAMS)
def test_pow_comparison(self, bases_bases_match, params_params_match):
@@ -1632,6 +1658,48 @@ def test_pow_comparison(self, bases_bases_match, params_params_match):
op2 = qml.pow(base2, param2)
assert qml.equal(op1, op2) == (bases_match and params_match)
+ def test_pow_comparison_with_tolerance(self):
+ """Test that equal compares the parameters within a provided tolerance of the Pow class."""
+ op1 = qml.pow(qml.RX(1.2, wires=0), 2)
+ op2 = qml.pow(qml.RX(1.2 + 1e-4, wires=0), 2)
+
+ assert qml.equal(op1, op2, atol=1e-3, rtol=0)
+ assert not qml.equal(op1, op2, atol=1e-5, rtol=0)
+ assert qml.equal(op1, op2, atol=0, rtol=1e-3)
+ assert not qml.equal(op1, op2, atol=0, rtol=1e-5)
+
+ def test_pow_comparison_with_interface(self):
+ """Test that equal compares the parameters within a provided interface of the Pow class."""
+ op1 = qml.pow(qml.RX(1.2, wires=0), 2)
+ op2 = qml.pow(qml.RX(1.2, wires=0), npp.array(2))
+
+ assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
+ assert not qml.equal(op1, op2, check_interface=True, check_trainability=False)
+
+ def test_pow_comparison_with_trainability(self):
+ """Test that equal compares the parameters within a provided trainability of the Pow class."""
+ op1 = qml.pow(qml.RX(1.2, wires=0), npp.array(2, requires_grad=False))
+ op2 = qml.pow(qml.RX(1.2, wires=0), npp.array(2, requires_grad=True))
+
+ assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
+ assert not qml.equal(op1, op2, check_interface=False, check_trainability=True)
+
+ def test_pow_base_op_comparison_with_interface(self):
+ """Test that equal compares the parameters within a provided interface of the base operator of Pow class."""
+ op1 = qml.pow(qml.RX(1.2, wires=0), 2)
+ op2 = qml.pow(qml.RX(npp.array(1.2), wires=0), 2)
+
+ assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
+ assert not qml.equal(op1, op2, check_interface=True, check_trainability=False)
+
+ def test_pow_base_op_comparison_with_trainability(self):
+ """Test that equal compares the parameters within a provided trainability of the base operator of Pow class."""
+ op1 = qml.pow(qml.RX(npp.array(1.2, requires_grad=False), wires=0), 2)
+ op2 = qml.pow(qml.RX(npp.array(1.2, requires_grad=True), wires=0), 2)
+
+ assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
+ assert not qml.equal(op1, op2, check_interface=False, check_trainability=True)
+
@pytest.mark.parametrize("bases_bases_match", BASES)
@pytest.mark.parametrize("params_params_match", PARAMS)
def test_exp_comparison(self, bases_bases_match, params_params_match):
@@ -1643,12 +1711,46 @@ def test_exp_comparison(self, bases_bases_match, params_params_match):
assert qml.equal(op1, op2) == (bases_match and params_match)
def test_exp_comparison_with_tolerance(self):
- """Test that equal compares the parameters within a provided tolerance."""
- op1 = qml.exp(qml.PauliX(0), 0.12345)
- op2 = qml.exp(qml.PauliX(0), 0.12356)
+ """Test that equal compares the parameters within a provided tolerance of the Exp class."""
+ op1 = qml.exp(qml.PauliX(0), 0.12)
+ op2 = qml.exp(qml.PauliX(0), 0.12 + 1e-4)
+
+ assert qml.equal(op1, op2, atol=1e-3, rtol=0)
+ assert not qml.equal(op1, op2, atol=1e-5, rtol=0)
+ assert qml.equal(op1, op2, atol=0, rtol=1e-2)
+ assert not qml.equal(op1, op2, atol=0, rtol=1e-5)
- assert qml.equal(op1, op2, atol=1e-3, rtol=1e-2)
- assert not qml.equal(op1, op2, atol=1e-5, rtol=1e-4)
+ def test_exp_comparison_with_interface(self):
+ """Test that equal compares the parameters within a provided interface of the Exp class."""
+ op1 = qml.exp(qml.PauliX(0), 1.2)
+ op2 = qml.exp(qml.PauliX(0), npp.array(1.2))
+
+ assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
+ assert not qml.equal(op1, op2, check_interface=True, check_trainability=False)
+
+ def test_exp_comparison_with_trainability(self):
+ """Test that equal compares the parameters within a provided trainability of the Exp class."""
+ op1 = qml.exp(qml.PauliX(0), npp.array(1.2, requires_grad=False))
+ op2 = qml.exp(qml.PauliX(0), npp.array(1.2, requires_grad=True))
+
+ assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
+ assert not qml.equal(op1, op2, check_interface=False, check_trainability=True)
+
+ def test_exp_base_op_comparison_with_interface(self):
+ """Test that equal compares the parameters within a provided interface of the base operator of Exp class."""
+ op1 = qml.exp(qml.RX(0.5, wires=0), 1.2)
+ op2 = qml.exp(qml.RX(npp.array(0.5), wires=0), 1.2)
+
+ assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
+ assert not qml.equal(op1, op2, check_interface=True, check_trainability=False)
+
+ def test_exp_base_op_comparison_with_trainability(self):
+ """Test that equal compares the parameters within a provided trainability of the base operator of Exp class."""
+ op1 = qml.exp(qml.RX(npp.array(0.5, requires_grad=False), wires=0), 1.2)
+ op2 = qml.exp(qml.RX(npp.array(0.5, requires_grad=True), wires=0), 1.2)
+
+ assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
+ assert not qml.equal(op1, op2, check_interface=False, check_trainability=True)
additional_cases = [
(qml.sum(qml.PauliX(0), qml.PauliY(0)), qml.sum(qml.PauliY(0), qml.PauliX(0)), True),
@@ -1668,12 +1770,46 @@ def test_s_prod_comparison(self, bases_bases_match, params_params_match):
assert qml.equal(op1, op2) == (bases_match and params_match)
def test_s_prod_comparison_with_tolerance(self):
- """Test that equal compares the parameters within a provided tolerance."""
- op1 = qml.s_prod(0.12345, qml.PauliX(0))
- op2 = qml.s_prod(0.12356, qml.PauliX(0))
+ """Test that equal compares the parameters within a provided tolerance of the SProd class."""
+ op1 = qml.s_prod(0.12, qml.PauliX(0))
+ op2 = qml.s_prod(0.12 + 1e-4, qml.PauliX(0))
+
+ assert qml.equal(op1, op2, atol=1e-3, rtol=0)
+ assert not qml.equal(op1, op2, atol=1e-5, rtol=0)
+ assert qml.equal(op1, op2, atol=0, rtol=1e-3)
+ assert not qml.equal(op1, op2, atol=0, rtol=1e-5)
- assert qml.equal(op1, op2, atol=1e-3, rtol=1e-2)
- assert not qml.equal(op1, op2, atol=1e-5, rtol=1e-4)
+ def test_s_prod_comparison_with_interface(self):
+ """Test that equal compares the parameters within a provided interface of the SProd class."""
+ op1 = qml.s_prod(0.12, qml.PauliX(0))
+ op2 = qml.s_prod(npp.array(0.12), qml.PauliX(0))
+
+ assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
+ assert not qml.equal(op1, op2, check_interface=True, check_trainability=False)
+
+ def test_s_prod_comparison_with_trainability(self):
+ """Test that equal compares the parameters within a provided trainability of the SProd class."""
+ op1 = qml.s_prod(npp.array(0.12, requires_grad=False), qml.PauliX(0))
+ op2 = qml.s_prod(npp.array(0.12, requires_grad=True), qml.PauliX(0))
+
+ assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
+ assert not qml.equal(op1, op2, check_interface=False, check_trainability=True)
+
+ def test_s_prod_base_op_comparison_with_interface(self):
+ """Test that equal compares the parameters within a provided interface of the base operator of SProd class."""
+ op1 = qml.s_prod(0.12, qml.RX(0.5, wires=0))
+ op2 = qml.s_prod(0.12, qml.RX(npp.array(0.5), wires=0))
+
+ assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
+ assert not qml.equal(op1, op2, check_interface=True, check_trainability=False)
+
+ def test_s_prod_base_op_comparison_with_trainability(self):
+ """Test that equal compares the parameters within a provided trainability of the base operator of SProd class."""
+ op1 = qml.s_prod(0.12, qml.RX(npp.array(0.5, requires_grad=False), wires=0))
+ op2 = qml.s_prod(0.12, qml.RX(npp.array(0.5, requires_grad=True), wires=0))
+
+ assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
+ assert not qml.equal(op1, op2, check_interface=False, check_trainability=True)
@pytest.mark.usefixtures("use_new_opmath")
diff --git a/tests/ops/functions/test_simplify.py b/tests/ops/functions/test_simplify.py
index d1f19cabdb3..738d873db77 100644
--- a/tests/ops/functions/test_simplify.py
+++ b/tests/ops/functions/test_simplify.py
@@ -80,7 +80,9 @@ def test_jit_simplification(self):
sum_op = qml.sum(qml.PauliX(0), qml.PauliX(0))
simp_op = jax.jit(qml.simplify)(sum_op)
- assert qml.equal(simp_op, qml.s_prod(2.0, qml.PauliX(0)))
+ assert qml.equal(
+ simp_op, qml.s_prod(2.0, qml.PauliX(0)), check_interface=False, check_trainability=False
+ )
class TestSimplifyTapes: