diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 1376ed4bc96..20cc600aa99 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -10,6 +10,9 @@ * The `qml.PrepSelPrep` template is added. The template implements a block-encoding of a linear combination of unitaries. [(#5756)](https://github.com/PennyLaneAI/pennylane/pull/5756) + +* `SProd.terms` now flattens out the terms if the base is a multi-term observable. + [(#5885)](https://github.com/PennyLaneAI/pennylane/pull/5885)

Improvements 🛠

@@ -59,7 +62,7 @@ This release contains contributions from (in alphabetical order): -Ahmed Darwish +Ahmed Darwish, Astral Cai, Yushao Chen, Pietropaolo Frisoni, diff --git a/pennylane/ops/op_math/sprod.py b/pennylane/ops/op_math/sprod.py index e9c5af3abfc..b3b2a9cc96f 100644 --- a/pennylane/ops/op_math/sprod.py +++ b/pennylane/ops/op_math/sprod.py @@ -20,7 +20,7 @@ import pennylane as qml import pennylane.math as qnp -from pennylane.operation import Operator, convert_to_opmath +from pennylane.operation import Operator, TermsUndefinedError, convert_to_opmath from pennylane.ops.op_math.pow import Pow from pennylane.ops.op_math.sum import Sum from pennylane.queuing import QueuingManager @@ -181,7 +181,7 @@ def num_params(self): """ return 1 + self.base.num_params - def terms(self): # is this method necessary for this class? + def terms(self): r"""Representation of the operator as a linear combination of other operators. .. math:: O = \sum_i c_i O_i @@ -191,8 +191,13 @@ def terms(self): # is this method necessary for this class? Returns: tuple[list[tensor_like or float], list[.Operation]]: list of coefficients :math:`c_i` and list of operations :math:`O_i` + """ - return [self.scalar], [self.base] + try: + base_coeffs, base_ops = self.base.terms() + return [self.scalar * coeff for coeff in base_coeffs], base_ops + except TermsUndefinedError: + return [self.scalar], [self.base] @property def is_hermitian(self): diff --git a/pennylane/transforms/split_non_commuting.py b/pennylane/transforms/split_non_commuting.py index 90566b92d7d..128b22276a6 100644 --- a/pennylane/transforms/split_non_commuting.py +++ b/pennylane/transforms/split_non_commuting.py @@ -537,10 +537,6 @@ def _split_all_multi_term_obs_mps(tape: qml.tape.QuantumScript): obs = mp.obs offset = 0 if isinstance(mp, ExpectationMP) and isinstance(obs, (Hamiltonian, Sum, Prod, SProd)): - if isinstance(obs, SProd): - # This is necessary because SProd currently does not flatten into - # multiple terms if the base is a sum, which is needed here. - obs = obs.simplify() # Break the observable into terms, and construct an ExpectationMP with each term. for c, o in zip(*obs.terms()): # If the observable is an identity, track it with a constant offset diff --git a/tests/ops/op_math/test_sprod.py b/tests/ops/op_math/test_sprod.py index 437a94d92f7..0bcfca4b7ab 100644 --- a/tests/ops/op_math/test_sprod.py +++ b/tests/ops/op_math/test_sprod.py @@ -164,6 +164,24 @@ def test_terms(self, op, scalar): for op1, op2 in zip(op2, [op]): qml.assert_equal(op1, op2) + @pytest.mark.parametrize( + "sprod_op, coeffs_exp, ops_exp", + [ + (qml.s_prod(1.23, qml.sum(qml.X(0), qml.Y(0))), [1.23, 1.23], [qml.X(0), qml.Y(0)]), + ( + qml.s_prod(1.23, qml.Hamiltonian([0.1, 0.2], [qml.X(0), qml.Y(0)])), + [0.123, 0.246], + [qml.X(0), qml.Y(0)], + ), + ], + ) + def test_terms_nested(self, sprod_op, coeffs_exp, ops_exp): + """Tests that SProd.terms() flattens a nested structure.""" + coeffs, ops_actual = sprod_op.terms() + assert coeffs == coeffs_exp + for op1, op2 in zip(ops_actual, ops_exp): + qml.assert_equal(op1, op2) + def test_decomposition_raises_error(self): sprod_op = s_prod(3.14, qml.Identity(wires=1)) diff --git a/tests/templates/test_subroutines/test_prepselprep.py b/tests/templates/test_subroutines/test_prepselprep.py index 70708218786..587bbdccb26 100644 --- a/tests/templates/test_subroutines/test_prepselprep.py +++ b/tests/templates/test_subroutines/test_prepselprep.py @@ -29,7 +29,7 @@ (qml.ops.LinearCombination([0.25, 0.75], [qml.Z(2), qml.X(1) @ qml.X(2)]), [0]), (qml.ops.LinearCombination([-0.25, 0.75j], [qml.Z(3), qml.X(2) @ qml.X(3)]), [0, 1]), ( - qml.ops.LinearCombination([-0.25 + 0.1j, 0.75j], [qml.Z(4), qml.X(4) @ qml.X(4)]), + qml.ops.LinearCombination([-0.25 + 0.1j, 0.75j], [qml.Z(4), qml.X(4) @ qml.X(5)]), [0, 1, 2, 3], ), ],