diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 4411b25e69c..7cbbcd67bd7 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -82,6 +82,11 @@ allowing error types to be more consistent with the context the `decompose` function is used in. [(#5669)](https://github.com/PennyLaneAI/pennylane/pull/5669) +

Community contributions 🥳

+ +* Implemented kwargs (`check_interface`, `check_trainability`, `rtol` and `atol`) support in `qml.equal` for the operators `Pow`, `Adjoint`, `Exp`, and `SProd`. + [(#5668)](https://github.com/PennyLaneAI/pennylane/issues/5668) +

Breaking changes 💔

* `qml.is_commuting` no longer accepts the `wire_map` argument, which does not bring any functionality. @@ -128,4 +133,5 @@ Pietropaolo Frisoni, Soran Jahangiri, Christina Lee, Vincent Michaud-Rioux, +Kenya Sakka, David Wierichs. diff --git a/pennylane/ops/functions/equal.py b/pennylane/ops/functions/equal.py index 8130fa02a91..44cfd42fee6 100644 --- a/pennylane/ops/functions/equal.py +++ b/pennylane/ops/functions/equal.py @@ -417,9 +417,19 @@ def _equal_controlled_sequence(op1: ControlledSequence, op2: ControlledSequence, # pylint: disable=unused-argument def _equal_pow(op1: Pow, op2: Pow, **kwargs): """Determine whether two Pow objects are equal""" + check_interface, check_trainability = kwargs["check_interface"], kwargs["check_trainability"] + + if check_interface: + if qml.math.get_interface(op1.z) != qml.math.get_interface(op2.z): + return False + if check_trainability: + if qml.math.requires_grad(op1.z) != qml.math.requires_grad(op2.z): + return False + if op1.z != op2.z: return False - return qml.equal(op1.base, op2.base) + + return qml.equal(op1.base, op2.base, **kwargs) @_equal.register @@ -427,32 +437,61 @@ def _equal_pow(op1: Pow, op2: Pow, **kwargs): def _equal_adjoint(op1: Adjoint, op2: Adjoint, **kwargs): """Determine whether two Adjoint objects are equal""" # first line of top-level equal function already confirms both are Adjoint - only need to compare bases - return qml.equal(op1.base, op2.base) + return qml.equal(op1.base, op2.base, **kwargs) @_equal.register # pylint: disable=unused-argument def _equal_exp(op1: Exp, op2: Exp, **kwargs): """Determine whether two Exp objects are equal""" - rtol, atol = (kwargs["rtol"], kwargs["atol"]) + check_interface, check_trainability, rtol, atol = ( + kwargs["check_interface"], + kwargs["check_trainability"], + kwargs["rtol"], + kwargs["atol"], + ) + + if check_interface: + for params1, params2 in zip(op1.data, op2.data): + if qml.math.get_interface(params1) != qml.math.get_interface(params2): + return False + if check_trainability: + for params1, params2 in zip(op1.data, op2.data): + if qml.math.requires_grad(params1) != qml.math.requires_grad(params2): + return False if not qml.math.allclose(op1.coeff, op2.coeff, rtol=rtol, atol=atol): return False - return qml.equal(op1.base, op2.base) + + return qml.equal(op1.base, op2.base, **kwargs) @_equal.register # pylint: disable=unused-argument def _equal_sprod(op1: SProd, op2: SProd, **kwargs): """Determine whether two SProd objects are equal""" - rtol, atol = (kwargs["rtol"], kwargs["atol"]) + check_interface, check_trainability, rtol, atol = ( + kwargs["check_interface"], + kwargs["check_trainability"], + kwargs["rtol"], + kwargs["atol"], + ) + + if check_interface: + for params1, params2 in zip(op1.data, op2.data): + if qml.math.get_interface(params1) != qml.math.get_interface(params2): + return False + if check_trainability: + for params1, params2 in zip(op1.data, op2.data): + if qml.math.requires_grad(params1) != qml.math.requires_grad(params2): + return False if op1.pauli_rep is not None and (op1.pauli_rep == op2.pauli_rep): # shortcut check return True - if not qml.math.allclose(op1.scalar, op2.scalar, rtol=rtol, atol=atol): return False - return qml.equal(op1.base, op2.base) + + return qml.equal(op1.base, op2.base, **kwargs) @_equal.register diff --git a/tests/ops/functions/test_equal.py b/tests/ops/functions/test_equal.py index 909e60eed9e..0cc8e633748 100644 --- a/tests/ops/functions/test_equal.py +++ b/tests/ops/functions/test_equal.py @@ -1622,6 +1622,32 @@ def test_adjoint_comparison(self, base): assert qml.equal(op1, op2) assert not qml.equal(op1, op3) + def test_adjoint_comparison_with_tolerance(self): + """Test that equal compares the parameters within a provided tolerance of the Adjoint class.""" + op1 = qml.adjoint(qml.RX(1.2, wires=0)) + op2 = qml.adjoint(qml.RX(1.2 + 1e-4, wires=0)) + + assert qml.equal(op1, op2, atol=1e-3, rtol=0) + assert not qml.equal(op1, op2, atol=1e-5, rtol=0) + assert qml.equal(op1, op2, atol=0, rtol=1e-3) + assert not qml.equal(op1, op2, atol=0, rtol=1e-5) + + def test_adjoint_base_op_comparison_with_interface(self): + """Test that equal compares the parameters within a provided interface of the base operator of Adjoint class.""" + op1 = qml.adjoint(qml.RX(1.2, wires=0)) + op2 = qml.adjoint(qml.RX(npp.array(1.2), wires=0)) + + assert qml.equal(op1, op2, check_interface=False, check_trainability=False) + assert not qml.equal(op1, op2, check_interface=True, check_trainability=False) + + def test_adjoint_base_op_comparison_with_trainability(self): + """Test that equal compares the parameters within a provided trainability of the base operator of Adjoint class.""" + op1 = qml.adjoint(qml.RX(npp.array(1.2, requires_grad=False), wires=0)) + op2 = qml.adjoint(qml.RX(npp.array(1.2, requires_grad=True), wires=0)) + + assert qml.equal(op1, op2, check_interface=False, check_trainability=False) + assert not qml.equal(op1, op2, check_interface=False, check_trainability=True) + @pytest.mark.parametrize("bases_bases_match", BASES) @pytest.mark.parametrize("params_params_match", PARAMS) def test_pow_comparison(self, bases_bases_match, params_params_match): @@ -1632,6 +1658,48 @@ def test_pow_comparison(self, bases_bases_match, params_params_match): op2 = qml.pow(base2, param2) assert qml.equal(op1, op2) == (bases_match and params_match) + def test_pow_comparison_with_tolerance(self): + """Test that equal compares the parameters within a provided tolerance of the Pow class.""" + op1 = qml.pow(qml.RX(1.2, wires=0), 2) + op2 = qml.pow(qml.RX(1.2 + 1e-4, wires=0), 2) + + assert qml.equal(op1, op2, atol=1e-3, rtol=0) + assert not qml.equal(op1, op2, atol=1e-5, rtol=0) + assert qml.equal(op1, op2, atol=0, rtol=1e-3) + assert not qml.equal(op1, op2, atol=0, rtol=1e-5) + + def test_pow_comparison_with_interface(self): + """Test that equal compares the parameters within a provided interface of the Pow class.""" + op1 = qml.pow(qml.RX(1.2, wires=0), 2) + op2 = qml.pow(qml.RX(1.2, wires=0), npp.array(2)) + + assert qml.equal(op1, op2, check_interface=False, check_trainability=False) + assert not qml.equal(op1, op2, check_interface=True, check_trainability=False) + + def test_pow_comparison_with_trainability(self): + """Test that equal compares the parameters within a provided trainability of the Pow class.""" + op1 = qml.pow(qml.RX(1.2, wires=0), npp.array(2, requires_grad=False)) + op2 = qml.pow(qml.RX(1.2, wires=0), npp.array(2, requires_grad=True)) + + assert qml.equal(op1, op2, check_interface=False, check_trainability=False) + assert not qml.equal(op1, op2, check_interface=False, check_trainability=True) + + def test_pow_base_op_comparison_with_interface(self): + """Test that equal compares the parameters within a provided interface of the base operator of Pow class.""" + op1 = qml.pow(qml.RX(1.2, wires=0), 2) + op2 = qml.pow(qml.RX(npp.array(1.2), wires=0), 2) + + assert qml.equal(op1, op2, check_interface=False, check_trainability=False) + assert not qml.equal(op1, op2, check_interface=True, check_trainability=False) + + def test_pow_base_op_comparison_with_trainability(self): + """Test that equal compares the parameters within a provided trainability of the base operator of Pow class.""" + op1 = qml.pow(qml.RX(npp.array(1.2, requires_grad=False), wires=0), 2) + op2 = qml.pow(qml.RX(npp.array(1.2, requires_grad=True), wires=0), 2) + + assert qml.equal(op1, op2, check_interface=False, check_trainability=False) + assert not qml.equal(op1, op2, check_interface=False, check_trainability=True) + @pytest.mark.parametrize("bases_bases_match", BASES) @pytest.mark.parametrize("params_params_match", PARAMS) def test_exp_comparison(self, bases_bases_match, params_params_match): @@ -1643,12 +1711,46 @@ def test_exp_comparison(self, bases_bases_match, params_params_match): assert qml.equal(op1, op2) == (bases_match and params_match) def test_exp_comparison_with_tolerance(self): - """Test that equal compares the parameters within a provided tolerance.""" - op1 = qml.exp(qml.PauliX(0), 0.12345) - op2 = qml.exp(qml.PauliX(0), 0.12356) + """Test that equal compares the parameters within a provided tolerance of the Exp class.""" + op1 = qml.exp(qml.PauliX(0), 0.12) + op2 = qml.exp(qml.PauliX(0), 0.12 + 1e-4) + + assert qml.equal(op1, op2, atol=1e-3, rtol=0) + assert not qml.equal(op1, op2, atol=1e-5, rtol=0) + assert qml.equal(op1, op2, atol=0, rtol=1e-2) + assert not qml.equal(op1, op2, atol=0, rtol=1e-5) - assert qml.equal(op1, op2, atol=1e-3, rtol=1e-2) - assert not qml.equal(op1, op2, atol=1e-5, rtol=1e-4) + def test_exp_comparison_with_interface(self): + """Test that equal compares the parameters within a provided interface of the Exp class.""" + op1 = qml.exp(qml.PauliX(0), 1.2) + op2 = qml.exp(qml.PauliX(0), npp.array(1.2)) + + assert qml.equal(op1, op2, check_interface=False, check_trainability=False) + assert not qml.equal(op1, op2, check_interface=True, check_trainability=False) + + def test_exp_comparison_with_trainability(self): + """Test that equal compares the parameters within a provided trainability of the Exp class.""" + op1 = qml.exp(qml.PauliX(0), npp.array(1.2, requires_grad=False)) + op2 = qml.exp(qml.PauliX(0), npp.array(1.2, requires_grad=True)) + + assert qml.equal(op1, op2, check_interface=False, check_trainability=False) + assert not qml.equal(op1, op2, check_interface=False, check_trainability=True) + + def test_exp_base_op_comparison_with_interface(self): + """Test that equal compares the parameters within a provided interface of the base operator of Exp class.""" + op1 = qml.exp(qml.RX(0.5, wires=0), 1.2) + op2 = qml.exp(qml.RX(npp.array(0.5), wires=0), 1.2) + + assert qml.equal(op1, op2, check_interface=False, check_trainability=False) + assert not qml.equal(op1, op2, check_interface=True, check_trainability=False) + + def test_exp_base_op_comparison_with_trainability(self): + """Test that equal compares the parameters within a provided trainability of the base operator of Exp class.""" + op1 = qml.exp(qml.RX(npp.array(0.5, requires_grad=False), wires=0), 1.2) + op2 = qml.exp(qml.RX(npp.array(0.5, requires_grad=True), wires=0), 1.2) + + assert qml.equal(op1, op2, check_interface=False, check_trainability=False) + assert not qml.equal(op1, op2, check_interface=False, check_trainability=True) additional_cases = [ (qml.sum(qml.PauliX(0), qml.PauliY(0)), qml.sum(qml.PauliY(0), qml.PauliX(0)), True), @@ -1668,12 +1770,46 @@ def test_s_prod_comparison(self, bases_bases_match, params_params_match): assert qml.equal(op1, op2) == (bases_match and params_match) def test_s_prod_comparison_with_tolerance(self): - """Test that equal compares the parameters within a provided tolerance.""" - op1 = qml.s_prod(0.12345, qml.PauliX(0)) - op2 = qml.s_prod(0.12356, qml.PauliX(0)) + """Test that equal compares the parameters within a provided tolerance of the SProd class.""" + op1 = qml.s_prod(0.12, qml.PauliX(0)) + op2 = qml.s_prod(0.12 + 1e-4, qml.PauliX(0)) + + assert qml.equal(op1, op2, atol=1e-3, rtol=0) + assert not qml.equal(op1, op2, atol=1e-5, rtol=0) + assert qml.equal(op1, op2, atol=0, rtol=1e-3) + assert not qml.equal(op1, op2, atol=0, rtol=1e-5) - assert qml.equal(op1, op2, atol=1e-3, rtol=1e-2) - assert not qml.equal(op1, op2, atol=1e-5, rtol=1e-4) + def test_s_prod_comparison_with_interface(self): + """Test that equal compares the parameters within a provided interface of the SProd class.""" + op1 = qml.s_prod(0.12, qml.PauliX(0)) + op2 = qml.s_prod(npp.array(0.12), qml.PauliX(0)) + + assert qml.equal(op1, op2, check_interface=False, check_trainability=False) + assert not qml.equal(op1, op2, check_interface=True, check_trainability=False) + + def test_s_prod_comparison_with_trainability(self): + """Test that equal compares the parameters within a provided trainability of the SProd class.""" + op1 = qml.s_prod(npp.array(0.12, requires_grad=False), qml.PauliX(0)) + op2 = qml.s_prod(npp.array(0.12, requires_grad=True), qml.PauliX(0)) + + assert qml.equal(op1, op2, check_interface=False, check_trainability=False) + assert not qml.equal(op1, op2, check_interface=False, check_trainability=True) + + def test_s_prod_base_op_comparison_with_interface(self): + """Test that equal compares the parameters within a provided interface of the base operator of SProd class.""" + op1 = qml.s_prod(0.12, qml.RX(0.5, wires=0)) + op2 = qml.s_prod(0.12, qml.RX(npp.array(0.5), wires=0)) + + assert qml.equal(op1, op2, check_interface=False, check_trainability=False) + assert not qml.equal(op1, op2, check_interface=True, check_trainability=False) + + def test_s_prod_base_op_comparison_with_trainability(self): + """Test that equal compares the parameters within a provided trainability of the base operator of SProd class.""" + op1 = qml.s_prod(0.12, qml.RX(npp.array(0.5, requires_grad=False), wires=0)) + op2 = qml.s_prod(0.12, qml.RX(npp.array(0.5, requires_grad=True), wires=0)) + + assert qml.equal(op1, op2, check_interface=False, check_trainability=False) + assert not qml.equal(op1, op2, check_interface=False, check_trainability=True) @pytest.mark.usefixtures("use_new_opmath") diff --git a/tests/ops/functions/test_simplify.py b/tests/ops/functions/test_simplify.py index d1f19cabdb3..738d873db77 100644 --- a/tests/ops/functions/test_simplify.py +++ b/tests/ops/functions/test_simplify.py @@ -80,7 +80,9 @@ def test_jit_simplification(self): sum_op = qml.sum(qml.PauliX(0), qml.PauliX(0)) simp_op = jax.jit(qml.simplify)(sum_op) - assert qml.equal(simp_op, qml.s_prod(2.0, qml.PauliX(0))) + assert qml.equal( + simp_op, qml.s_prod(2.0, qml.PauliX(0)), check_interface=False, check_trainability=False + ) class TestSimplifyTapes: