Skip to content

Commit

Permalink
SProd.terms() flattens out nested terms (#5885)
Browse files Browse the repository at this point in the history
**Context:**
The `.terms()` method of `Sum`, `LinearCombination`, `Hamiltonian`,
`Prod` all flatten out any nested terms, but `SProd` does not. For
examples:
```pycon
>>> coeffs, ops = (2 * (qml.X(0) + qml.Y(0))).terms()
>>> len(ops)
1
```

**Description of the Change:**
Make it so that `SProd` flattens out the terms.

**Benefits:**
More consistent behaviour

[sc-66627]
  • Loading branch information
astralcai committed Jul 12, 2024
1 parent dbb26f2 commit 2963166
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 9 deletions.
5 changes: 4 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
* The `qml.PrepSelPrep` template is added. The template implements a block-encoding of a linear
combination of unitaries.
[(#5756)](https://github.com/PennyLaneAI/pennylane/pull/5756)

* `SProd.terms` now flattens out the terms if the base is a multi-term observable.
[(#5885)](https://github.com/PennyLaneAI/pennylane/pull/5885)

<h3>Improvements 🛠</h3>

Expand Down Expand Up @@ -59,7 +62,7 @@

This release contains contributions from (in alphabetical order):

Ahmed Darwish
Ahmed Darwish,
Astral Cai,
Yushao Chen,
Pietropaolo Frisoni,
Expand Down
11 changes: 8 additions & 3 deletions pennylane/ops/op_math/sprod.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import pennylane as qml
import pennylane.math as qnp
from pennylane.operation import Operator, convert_to_opmath
from pennylane.operation import Operator, TermsUndefinedError, convert_to_opmath
from pennylane.ops.op_math.pow import Pow
from pennylane.ops.op_math.sum import Sum
from pennylane.queuing import QueuingManager
Expand Down Expand Up @@ -181,7 +181,7 @@ def num_params(self):
"""
return 1 + self.base.num_params

def terms(self): # is this method necessary for this class?
def terms(self):
r"""Representation of the operator as a linear combination of other operators.
.. math:: O = \sum_i c_i O_i
Expand All @@ -191,8 +191,13 @@ def terms(self): # is this method necessary for this class?
Returns:
tuple[list[tensor_like or float], list[.Operation]]: list of coefficients :math:`c_i`
and list of operations :math:`O_i`
"""
return [self.scalar], [self.base]
try:
base_coeffs, base_ops = self.base.terms()
return [self.scalar * coeff for coeff in base_coeffs], base_ops
except TermsUndefinedError:
return [self.scalar], [self.base]

@property
def is_hermitian(self):
Expand Down
4 changes: 0 additions & 4 deletions pennylane/transforms/split_non_commuting.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,10 +537,6 @@ def _split_all_multi_term_obs_mps(tape: qml.tape.QuantumScript):
obs = mp.obs
offset = 0
if isinstance(mp, ExpectationMP) and isinstance(obs, (Hamiltonian, Sum, Prod, SProd)):
if isinstance(obs, SProd):
# This is necessary because SProd currently does not flatten into
# multiple terms if the base is a sum, which is needed here.
obs = obs.simplify()
# Break the observable into terms, and construct an ExpectationMP with each term.
for c, o in zip(*obs.terms()):
# If the observable is an identity, track it with a constant offset
Expand Down
18 changes: 18 additions & 0 deletions tests/ops/op_math/test_sprod.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,24 @@ def test_terms(self, op, scalar):
for op1, op2 in zip(op2, [op]):
qml.assert_equal(op1, op2)

@pytest.mark.parametrize(
"sprod_op, coeffs_exp, ops_exp",
[
(qml.s_prod(1.23, qml.sum(qml.X(0), qml.Y(0))), [1.23, 1.23], [qml.X(0), qml.Y(0)]),
(
qml.s_prod(1.23, qml.Hamiltonian([0.1, 0.2], [qml.X(0), qml.Y(0)])),
[0.123, 0.246],
[qml.X(0), qml.Y(0)],
),
],
)
def test_terms_nested(self, sprod_op, coeffs_exp, ops_exp):
"""Tests that SProd.terms() flattens a nested structure."""
coeffs, ops_actual = sprod_op.terms()
assert coeffs == coeffs_exp
for op1, op2 in zip(ops_actual, ops_exp):
qml.assert_equal(op1, op2)

def test_decomposition_raises_error(self):
sprod_op = s_prod(3.14, qml.Identity(wires=1))

Expand Down
2 changes: 1 addition & 1 deletion tests/templates/test_subroutines/test_prepselprep.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
(qml.ops.LinearCombination([0.25, 0.75], [qml.Z(2), qml.X(1) @ qml.X(2)]), [0]),
(qml.ops.LinearCombination([-0.25, 0.75j], [qml.Z(3), qml.X(2) @ qml.X(3)]), [0, 1]),
(
qml.ops.LinearCombination([-0.25 + 0.1j, 0.75j], [qml.Z(4), qml.X(4) @ qml.X(4)]),
qml.ops.LinearCombination([-0.25 + 0.1j, 0.75j], [qml.Z(4), qml.X(4) @ qml.X(5)]),
[0, 1, 2, 3],
),
],
Expand Down

0 comments on commit 2963166

Please sign in to comment.