diff --git a/pennylane/templates/subroutines/prepselprep.py b/pennylane/templates/subroutines/prepselprep.py index f6c02265a1c..53df7f96cfc 100644 --- a/pennylane/templates/subroutines/prepselprep.py +++ b/pennylane/templates/subroutines/prepselprep.py @@ -75,9 +75,10 @@ class PrepSelPrep(Operation): """ def __init__(self, lcu, control=None, id=None): + coeffs, ops = lcu.terms() control = qml.wires.Wires(control) - self.hyperparameters["lcu"] = lcu + self.hyperparameters["lcu"] = qml.ops.LinearCombination(coeffs, ops) self.hyperparameters["coeffs"] = coeffs self.hyperparameters["ops"] = ops self.hyperparameters["control"] = control @@ -95,14 +96,11 @@ def __init__(self, lcu, control=None, id=None): super().__init__(*self.data, wires=all_wires, id=id) def _flatten(self): - return tuple(self.lcu), (self.control) + return (self.lcu,), (self.control,) @classmethod def _unflatten(cls, data, metadata) -> "PrepSelPrep": - coeffs = [term.terms()[0][0] for term in data] - ops = [term.terms()[1][0] for term in data] - lcu = qml.ops.LinearCombination(coeffs, ops) - return cls(lcu, metadata) + return cls(data[0], metadata[0]) def __repr__(self): return f"PrepSelPrep(coeffs={tuple(self.coeffs)}, ops={tuple(self.ops)}, control={self.control})" diff --git a/tests/templates/test_subroutines/test_prepselprep.py b/tests/templates/test_subroutines/test_prepselprep.py index ea6b71e9347..564ca3bf93e 100644 --- a/tests/templates/test_subroutines/test_prepselprep.py +++ b/tests/templates/test_subroutines/test_prepselprep.py @@ -27,9 +27,13 @@ ("lcu", "control"), [ (qml.ops.LinearCombination([0.25, 0.75], [qml.Z(2), qml.X(1) @ qml.X(2)]), [0]), + (qml.dot([0.25, 0.75], [qml.Z(2), qml.X(1) @ qml.X(2)]), [0]), + (qml.Hamiltonian([0.25, 0.75], [qml.Z(2), qml.X(1) @ qml.X(2)]), [0]), + (0.25 * qml.Z(2) - 0.75 * qml.X(1) @ qml.X(2), [0]), + (qml.Z(2) + qml.X(1) @ qml.X(2), [0]), (qml.ops.LinearCombination([-0.25, 0.75j], [qml.Z(3), qml.X(2) @ qml.X(3)]), [0, 1]), ( - qml.ops.LinearCombination([-0.25 + 0.1j, 0.75j], [qml.Z(4), qml.X(4) @ qml.X(4)]), + qml.ops.LinearCombination([-0.25 + 0.1j, 0.75j], [qml.Z(4), qml.X(4) @ qml.X(5)]), [0, 1, 2, 3], ), ], @@ -258,25 +262,35 @@ def test_copy(self): assert qml.equal(op, op_copy) - def test_flatten_unflatten(self): + @pytest.mark.parametrize( + ("lcu"), + [ + qml.ops.LinearCombination([0.25, 0.75], [qml.Z(2), qml.X(1) @ qml.X(2)]), + qml.dot([0.25, 0.75], [qml.Z(2), qml.X(1) @ qml.X(2)]), + qml.Hamiltonian([0.25, 0.75], [qml.Z(2), qml.X(1) @ qml.X(2)]), + 0.25 * qml.Z(2) - 0.75 * qml.X(1) @ qml.X(2), + qml.Z(2) + qml.X(1) @ qml.X(2), + qml.ops.LinearCombination([-0.25, 0.75j], [qml.Z(3), qml.X(2) @ qml.X(3)]), + qml.ops.LinearCombination([-0.25 + 0.1j, 0.75j], [qml.Z(4), qml.X(4) @ qml.X(5)]), + ], + ) + def test_flatten_unflatten(self, lcu): """Test that the class can be correctly flattened and unflattened""" - lcu = qml.ops.LinearCombination([1 / 2, 1 / 2], [qml.Identity(1), qml.PauliZ(1)]) lcu_coeffs, lcu_ops = lcu.terms() op = qml.PrepSelPrep(lcu, control=0) data, metadata = op._flatten() - data_coeffs = [term.terms()[0][0] for term in data] - data_ops = [term.terms()[1][0] for term in data] + data_coeffs, data_ops = data[0].terms() assert hash(metadata) - assert len(data) == len(lcu) + assert len(data[0]) == len(lcu) assert all(coeff1 == coeff2 for coeff1, coeff2 in zip(lcu_coeffs, data_coeffs)) assert all(op1 == op2 for op1, op2 in zip(lcu_ops, data_ops)) - assert metadata == op.control + assert metadata[0] == op.control new_op = type(op)._unflatten(*op._flatten()) assert op.lcu == new_op.lcu