Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SProd.terms() flattens out nested terms #5885

Merged
merged 9 commits into from
Jul 12, 2024
5 changes: 4 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
* The `qml.PrepSelPrep` template is added. The template implements a block-encoding of a linear
combination of unitaries.
[(#5756)](https://github.com/PennyLaneAI/pennylane/pull/5756)

* `SProd.terms` now flattens out the terms if the base is a multi-term observable.
[(#5885)](https://github.com/PennyLaneAI/pennylane/pull/5885)

<h3>Improvements 🛠</h3>

Expand Down Expand Up @@ -52,7 +55,7 @@

This release contains contributions from (in alphabetical order):

Ahmed Darwish
Ahmed Darwish,
Astral Cai,
Yushao Chen,
Christina Lee,
Expand Down
11 changes: 8 additions & 3 deletions pennylane/ops/op_math/sprod.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import pennylane as qml
import pennylane.math as qnp
from pennylane.operation import Operator, convert_to_opmath
from pennylane.operation import Operator, TermsUndefinedError, convert_to_opmath
from pennylane.ops.op_math.pow import Pow
from pennylane.ops.op_math.sum import Sum
from pennylane.queuing import QueuingManager
Expand Down Expand Up @@ -181,7 +181,7 @@ def num_params(self):
"""
return 1 + self.base.num_params

def terms(self): # is this method necessary for this class?
def terms(self):
r"""Representation of the operator as a linear combination of other operators.

.. math:: O = \sum_i c_i O_i
Expand All @@ -191,8 +191,13 @@ def terms(self): # is this method necessary for this class?
Returns:
tuple[list[tensor_like or float], list[.Operation]]: list of coefficients :math:`c_i`
and list of operations :math:`O_i`

Qottmann marked this conversation as resolved.
Show resolved Hide resolved
"""
return [self.scalar], [self.base]
try:
base_coeffs, base_ops = self.base.terms()
return [self.scalar * coeff for coeff in base_coeffs], base_ops
except TermsUndefinedError:
return [self.scalar], [self.base]
Qottmann marked this conversation as resolved.
Show resolved Hide resolved

@property
def is_hermitian(self):
Expand Down
4 changes: 0 additions & 4 deletions pennylane/transforms/split_non_commuting.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,10 +537,6 @@ def _split_all_multi_term_obs_mps(tape: qml.tape.QuantumScript):
obs = mp.obs
offset = 0
if isinstance(mp, ExpectationMP) and isinstance(obs, (Hamiltonian, Sum, Prod, SProd)):
if isinstance(obs, SProd):
# This is necessary because SProd currently does not flatten into
# multiple terms if the base is a sum, which is needed here.
obs = obs.simplify()
# Break the observable into terms, and construct an ExpectationMP with each term.
for c, o in zip(*obs.terms()):
# If the observable is an identity, track it with a constant offset
Expand Down
18 changes: 18 additions & 0 deletions tests/ops/op_math/test_sprod.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,24 @@ def test_terms(self, op, scalar):
for op1, op2 in zip(op2, [op]):
qml.assert_equal(op1, op2)

@pytest.mark.parametrize(
"sprod_op, coeffs_exp, ops_exp",
[
(qml.s_prod(1.23, qml.sum(qml.X(0), qml.Y(0))), [1.23, 1.23], [qml.X(0), qml.Y(0)]),
(
qml.s_prod(1.23, qml.Hamiltonian([0.1, 0.2], [qml.X(0), qml.Y(0)])),
[0.123, 0.246],
[qml.X(0), qml.Y(0)],
),
],
)
def test_terms_nested(self, sprod_op, coeffs_exp, ops_exp):
"""Tests that SProd.terms() flattens a nested structure."""
coeffs, ops_actual = sprod_op.terms()
assert coeffs == coeffs_exp
for op1, op2 in zip(ops_actual, ops_exp):
qml.assert_equal(op1, op2)

def test_decomposition_raises_error(self):
sprod_op = s_prod(3.14, qml.Identity(wires=1))

Expand Down
2 changes: 1 addition & 1 deletion tests/templates/test_subroutines/test_prepselprep.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
(qml.ops.LinearCombination([0.25, 0.75], [qml.Z(2), qml.X(1) @ qml.X(2)]), [0]),
(qml.ops.LinearCombination([-0.25, 0.75j], [qml.Z(3), qml.X(2) @ qml.X(3)]), [0, 1]),
(
qml.ops.LinearCombination([-0.25 + 0.1j, 0.75j], [qml.Z(4), qml.X(4) @ qml.X(4)]),
qml.ops.LinearCombination([-0.25 + 0.1j, 0.75j], [qml.Z(4), qml.X(4) @ qml.X(5)]),
lillian542 marked this conversation as resolved.
Show resolved Hide resolved
[0, 1, 2, 3],
),
],
Expand Down
Loading