Skip to content

Commit

Permalink
clean unflatten in PrepSelPrep (#5987)
Browse files Browse the repository at this point in the history
**Context:**

LinearCombination was being cast within unflaten. It has been moved to
init to avoid issues

---------

Co-authored-by: Thomas R. Bromley <49409390+trbromley@users.noreply.github.com>
  • Loading branch information
KetpuntoG and trbromley committed Jul 22, 2024
1 parent ee4e7de commit 3ad15d7
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 12 deletions.
2 changes: 2 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
combination of unitaries.
[(#5756)](https://github.com/PennyLaneAI/pennylane/pull/5756)

[(#5987)](https://github.com/PennyLaneAI/pennylane/pull/5987)

* The `split_to_single_terms` transform is added. This transform splits expectation values of sums
into multiple single-term measurements on a single tape, providing better support for simulators
that can handle non-commuting observables but don't natively support multi-term observables.
Expand Down
10 changes: 4 additions & 6 deletions pennylane/templates/subroutines/prepselprep.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,10 @@ class PrepSelPrep(Operation):
"""

def __init__(self, lcu, control=None, id=None):

coeffs, ops = lcu.terms()
control = qml.wires.Wires(control)
self.hyperparameters["lcu"] = lcu
self.hyperparameters["lcu"] = qml.ops.LinearCombination(coeffs, ops)
self.hyperparameters["coeffs"] = coeffs
self.hyperparameters["ops"] = ops
self.hyperparameters["control"] = control
Expand All @@ -95,14 +96,11 @@ def __init__(self, lcu, control=None, id=None):
super().__init__(*self.data, wires=all_wires, id=id)

def _flatten(self):
return tuple(self.lcu), (self.control)
return (self.lcu,), (self.control,)

@classmethod
def _unflatten(cls, data, metadata) -> "PrepSelPrep":
coeffs = [term.terms()[0][0] for term in data]
ops = [term.terms()[1][0] for term in data]
lcu = qml.ops.LinearCombination(coeffs, ops)
return cls(lcu, metadata)
return cls(data[0], metadata[0])

def __repr__(self):
return f"PrepSelPrep(coeffs={tuple(self.coeffs)}, ops={tuple(self.ops)}, control={self.control})"
Expand Down
26 changes: 20 additions & 6 deletions tests/templates/test_subroutines/test_prepselprep.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
("lcu", "control"),
[
(qml.ops.LinearCombination([0.25, 0.75], [qml.Z(2), qml.X(1) @ qml.X(2)]), [0]),
(qml.dot([0.25, 0.75], [qml.Z(2), qml.X(1) @ qml.X(2)]), [0]),
(qml.Hamiltonian([0.25, 0.75], [qml.Z(2), qml.X(1) @ qml.X(2)]), [0]),
(0.25 * qml.Z(2) - 0.75 * qml.X(1) @ qml.X(2), [0]),
(qml.Z(2) + qml.X(1) @ qml.X(2), [0]),
(qml.ops.LinearCombination([-0.25, 0.75j], [qml.Z(3), qml.X(2) @ qml.X(3)]), [0, 1]),
(
qml.ops.LinearCombination([-0.25 + 0.1j, 0.75j], [qml.Z(4), qml.X(4) @ qml.X(5)]),
Expand Down Expand Up @@ -257,25 +261,35 @@ def test_copy(self):

assert qml.equal(op, op_copy)

def test_flatten_unflatten(self):
@pytest.mark.parametrize(
("lcu"),
[
qml.ops.LinearCombination([0.25, 0.75], [qml.Z(2), qml.X(1) @ qml.X(2)]),
qml.dot([0.25, 0.75], [qml.Z(2), qml.X(1) @ qml.X(2)]),
qml.Hamiltonian([0.25, 0.75], [qml.Z(2), qml.X(1) @ qml.X(2)]),
0.25 * qml.Z(2) - 0.75 * qml.X(1) @ qml.X(2),
qml.Z(2) + qml.X(1) @ qml.X(2),
qml.ops.LinearCombination([-0.25, 0.75j], [qml.Z(3), qml.X(2) @ qml.X(3)]),
qml.ops.LinearCombination([-0.25 + 0.1j, 0.75j], [qml.Z(4), qml.X(4) @ qml.X(5)]),
],
)
def test_flatten_unflatten(self, lcu):
"""Test that the class can be correctly flattened and unflattened"""

lcu = qml.ops.LinearCombination([1 / 2, 1 / 2], [qml.Identity(1), qml.PauliZ(1)])
lcu_coeffs, lcu_ops = lcu.terms()

op = qml.PrepSelPrep(lcu, control=0)
data, metadata = op._flatten()

data_coeffs = [term.terms()[0][0] for term in data]
data_ops = [term.terms()[1][0] for term in data]
data_coeffs, data_ops = data[0].terms()

assert hash(metadata)

assert len(data) == len(lcu)
assert len(data[0]) == len(lcu)
assert all(coeff1 == coeff2 for coeff1, coeff2 in zip(lcu_coeffs, data_coeffs))
assert all(op1 == op2 for op1, op2 in zip(lcu_ops, data_ops))

assert metadata == op.control
assert metadata[0] == op.control

new_op = type(op)._unflatten(*op._flatten())
assert op.lcu == new_op.lcu
Expand Down

0 comments on commit 3ad15d7

Please sign in to comment.