diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 1376ed4bc96..20cc600aa99 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -10,6 +10,9 @@
* The `qml.PrepSelPrep` template is added. The template implements a block-encoding of a linear
combination of unitaries.
[(#5756)](https://github.com/PennyLaneAI/pennylane/pull/5756)
+
+* `SProd.terms` now flattens out the terms if the base is a multi-term observable.
+ [(#5885)](https://github.com/PennyLaneAI/pennylane/pull/5885)
Improvements ðŸ›
@@ -59,7 +62,7 @@
This release contains contributions from (in alphabetical order):
-Ahmed Darwish
+Ahmed Darwish,
Astral Cai,
Yushao Chen,
Pietropaolo Frisoni,
diff --git a/pennylane/ops/op_math/sprod.py b/pennylane/ops/op_math/sprod.py
index e9c5af3abfc..b3b2a9cc96f 100644
--- a/pennylane/ops/op_math/sprod.py
+++ b/pennylane/ops/op_math/sprod.py
@@ -20,7 +20,7 @@
import pennylane as qml
import pennylane.math as qnp
-from pennylane.operation import Operator, convert_to_opmath
+from pennylane.operation import Operator, TermsUndefinedError, convert_to_opmath
from pennylane.ops.op_math.pow import Pow
from pennylane.ops.op_math.sum import Sum
from pennylane.queuing import QueuingManager
@@ -181,7 +181,7 @@ def num_params(self):
"""
return 1 + self.base.num_params
- def terms(self): # is this method necessary for this class?
+ def terms(self):
r"""Representation of the operator as a linear combination of other operators.
.. math:: O = \sum_i c_i O_i
@@ -191,8 +191,13 @@ def terms(self): # is this method necessary for this class?
Returns:
tuple[list[tensor_like or float], list[.Operation]]: list of coefficients :math:`c_i`
and list of operations :math:`O_i`
+
"""
- return [self.scalar], [self.base]
+ try:
+ base_coeffs, base_ops = self.base.terms()
+ return [self.scalar * coeff for coeff in base_coeffs], base_ops
+ except TermsUndefinedError:
+ return [self.scalar], [self.base]
@property
def is_hermitian(self):
diff --git a/pennylane/transforms/split_non_commuting.py b/pennylane/transforms/split_non_commuting.py
index 90566b92d7d..128b22276a6 100644
--- a/pennylane/transforms/split_non_commuting.py
+++ b/pennylane/transforms/split_non_commuting.py
@@ -537,10 +537,6 @@ def _split_all_multi_term_obs_mps(tape: qml.tape.QuantumScript):
obs = mp.obs
offset = 0
if isinstance(mp, ExpectationMP) and isinstance(obs, (Hamiltonian, Sum, Prod, SProd)):
- if isinstance(obs, SProd):
- # This is necessary because SProd currently does not flatten into
- # multiple terms if the base is a sum, which is needed here.
- obs = obs.simplify()
# Break the observable into terms, and construct an ExpectationMP with each term.
for c, o in zip(*obs.terms()):
# If the observable is an identity, track it with a constant offset
diff --git a/tests/ops/op_math/test_sprod.py b/tests/ops/op_math/test_sprod.py
index 437a94d92f7..0bcfca4b7ab 100644
--- a/tests/ops/op_math/test_sprod.py
+++ b/tests/ops/op_math/test_sprod.py
@@ -164,6 +164,24 @@ def test_terms(self, op, scalar):
for op1, op2 in zip(op2, [op]):
qml.assert_equal(op1, op2)
+ @pytest.mark.parametrize(
+ "sprod_op, coeffs_exp, ops_exp",
+ [
+ (qml.s_prod(1.23, qml.sum(qml.X(0), qml.Y(0))), [1.23, 1.23], [qml.X(0), qml.Y(0)]),
+ (
+ qml.s_prod(1.23, qml.Hamiltonian([0.1, 0.2], [qml.X(0), qml.Y(0)])),
+ [0.123, 0.246],
+ [qml.X(0), qml.Y(0)],
+ ),
+ ],
+ )
+ def test_terms_nested(self, sprod_op, coeffs_exp, ops_exp):
+ """Tests that SProd.terms() flattens a nested structure."""
+ coeffs, ops_actual = sprod_op.terms()
+ assert coeffs == coeffs_exp
+ for op1, op2 in zip(ops_actual, ops_exp):
+ qml.assert_equal(op1, op2)
+
def test_decomposition_raises_error(self):
sprod_op = s_prod(3.14, qml.Identity(wires=1))
diff --git a/tests/templates/test_subroutines/test_prepselprep.py b/tests/templates/test_subroutines/test_prepselprep.py
index 70708218786..587bbdccb26 100644
--- a/tests/templates/test_subroutines/test_prepselprep.py
+++ b/tests/templates/test_subroutines/test_prepselprep.py
@@ -29,7 +29,7 @@
(qml.ops.LinearCombination([0.25, 0.75], [qml.Z(2), qml.X(1) @ qml.X(2)]), [0]),
(qml.ops.LinearCombination([-0.25, 0.75j], [qml.Z(3), qml.X(2) @ qml.X(3)]), [0, 1]),
(
- qml.ops.LinearCombination([-0.25 + 0.1j, 0.75j], [qml.Z(4), qml.X(4) @ qml.X(4)]),
+ qml.ops.LinearCombination([-0.25 + 0.1j, 0.75j], [qml.Z(4), qml.X(4) @ qml.X(5)]),
[0, 1, 2, 3],
),
],