Skip to content

Commit

Permalink
Minor updates for PR #5511 (#5524)
Browse files Browse the repository at this point in the history
  • Loading branch information
albi3ro committed Apr 24, 2024
1 parent f834328 commit 10d59e7
Show file tree
Hide file tree
Showing 13 changed files with 110 additions and 14 deletions.
7 changes: 7 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,9 @@
* `qml.transforms.split_non_commuting` will now work with single-term operator arithmetic.
[(#5314)](https://github.com/PennyLaneAI/pennylane/pull/5314)

* `LinearCombination` and `Sum` now accept `_grouping_indices` on initialization.
[(#5524)](https://github.com/PennyLaneAI/pennylane/pull/5524)

<h4>Mid-circuit measurements and dynamic circuits</h4>

* The `QubitDevice` class and children classes support the `dynamic_one_shot` transform provided that they support `MidMeasureMP` operations natively.
Expand Down Expand Up @@ -465,6 +468,10 @@

<h3>Bug fixes 🐛</h3>

* `ApproxTimeEvolution`, `CommutingEvolution`, `QDrift`, and `TrotterProduct`
now de-queue their input observable.
[(#5524)](https://github.com/PennyLaneAI/pennylane/pull/5524)

* (In)equality of `qml.HilbertSchmidt` instances is now reported correctly by `qml.equal`.
[(#5538)](https://github.com/PennyLaneAI/pennylane/pull/5538)

Expand Down
15 changes: 9 additions & 6 deletions pennylane/ops/op_math/linear_combination.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,11 @@ class LinearCombination(Sum):

def _flatten(self):
# note that we are unable to restore grouping type or method without creating new properties
return (self._coeffs, self._ops, self.data), (self.grouping_indices,)
return self.terms(), (self.grouping_indices,)

@classmethod
def _unflatten(cls, data, metadata):
new_op = cls(data[0], data[1])
new_op._grouping_indices = metadata[0] # pylint: disable=protected-access
new_op.data = data[2]
return new_op
return cls(data[0], data[1], _grouping_indices=metadata[0])

def __init__(
self,
Expand All @@ -116,6 +113,7 @@ def __init__(
simplify=False,
grouping_type=None,
method="rlf",
_grouping_indices=None,
_pauli_rep=None,
id=None,
):
Expand Down Expand Up @@ -147,7 +145,12 @@ def __init__(
operands = tuple(qml.s_prod(c, op) for c, op in zip(coeffs, observables))

super().__init__(
*operands, grouping_type=grouping_type, method=method, id=id, _pauli_rep=_pauli_rep
*operands,
grouping_type=grouping_type,
method=method,
id=id,
_grouping_indices=_grouping_indices,
_pauli_rep=_pauli_rep,
)

@staticmethod
Expand Down
19 changes: 13 additions & 6 deletions pennylane/ops/op_math/sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,17 +216,24 @@ def _flatten(self):

@classmethod
def _unflatten(cls, data, metadata):
# pylint: disable=protected-access
new_op = cls(*data)
new_op._grouping_indices = metadata[0]
return new_op
return cls(*data, _grouping_indices=metadata[0])

def __init__(
self, *operands: Operator, grouping_type=None, method="rlf", id=None, _pauli_rep=None
self,
*operands: Operator,
grouping_type=None,
method="rlf",
id=None,
_grouping_indices=None,
_pauli_rep=None,
):
super().__init__(*operands, id=id, _pauli_rep=_pauli_rep)

self._grouping_indices = None
self._grouping_indices = _grouping_indices
if _grouping_indices is not None and grouping_type is not None:
raise ValueError(
"_grouping_indices and grouping_type cannot be specified at the same time."
)
if grouping_type is not None:
self.compute_grouping(grouping_type=grouping_type, method=method)

Expand Down
5 changes: 5 additions & 0 deletions pennylane/templates/subroutines/approx_time_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,11 @@ def __init__(self, hamiltonian, time, n, id=None):
# trainable parameters are passed to the base init method
super().__init__(*hamiltonian.data, time, wires=wires, id=id)

def queue(self, context=qml.QueuingManager):
context.remove(self.hyperparameters["hamiltonian"])
context.append(self)
return self

@staticmethod
def compute_decomposition(
*coeffs_and_time, wires, hamiltonian, n
Expand Down
5 changes: 5 additions & 0 deletions pennylane/templates/subroutines/commuting_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,11 @@ def __init__(self, hamiltonian, time, frequencies=None, shifts=None, id=None):

super().__init__(time, *hamiltonian.parameters, wires=hamiltonian.wires, id=id)

def queue(self, context=qml.QueuingManager):
context.remove(self.hyperparameters["hamiltonian"])
context.append(self)
return self

@staticmethod
def compute_decomposition(
time, *_, wires, hamiltonian, **__
Expand Down
5 changes: 5 additions & 0 deletions pennylane/templates/subroutines/qdrift.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,11 @@ def __init__( # pylint: disable=too-many-arguments
}
super().__init__(time, wires=hamiltonian.wires, id=id)

def queue(self, context=qml.QueuingManager):
context.remove(self.hyperparameters["base"])
context.append(self)
return self

@classmethod
def _unflatten(cls, data, metadata):
"""Recreate an operation from its serialized format.
Expand Down
5 changes: 5 additions & 0 deletions pennylane/templates/subroutines/trotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,11 @@ def __init__( # pylint: disable=too-many-arguments
}
super().__init__(time, wires=hamiltonian.wires, id=id)

def queue(self, context=qml.QueuingManager):
context.remove(self.hyperparameters["base"])
context.append(self)
return self

def error(
self, method: str = "commutator", fast: bool = True
): # pylint: disable=arguments-differ
Expand Down
9 changes: 7 additions & 2 deletions tests/ops/op_math/test_linear_combination.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,12 +618,11 @@ def test_flatten_unflatten(self, coeffs, ops, grouping_type):
data, metadata = H._flatten()
assert metadata[0] == H.grouping_indices
assert hash(metadata)
assert len(data) == 3
assert len(data) == 2
assert qml.math.allequal(
data[0], H._coeffs
) # Previously checking "is" instead of "==", problem?
assert data[1] == H._ops
assert data[2] == H.data

new_H = LinearCombination._unflatten(*H._flatten())
assert qml.equal(H, new_H)
Expand Down Expand Up @@ -1457,6 +1456,12 @@ def test_LinearCombination_matmul(self):
class TestGrouping:
"""Tests for the grouping functionality"""

def test_set_on_initialization(self):
"""Test that grouping indices can be set on initialization."""

op = qml.ops.LinearCombination([1, 1], [qml.X(0), qml.Y(1)], _grouping_indices=[[0, 1]])
assert op.grouping_indices == [[0, 1]]

def test_indentities_preserved(self):
"""Tests that the grouping indices do not drop identity terms when the wire order is nonstandard."""

Expand Down
12 changes: 12 additions & 0 deletions tests/ops/op_math/test_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -1326,6 +1326,18 @@ def test_adjoint(self):
class TestGrouping:
"""Test grouping functionality of Sum"""

def test_set_on_initialization(self):
"""Test that grouping indices can be set on initialization."""

op = qml.ops.Sum(qml.X(0), qml.Y(1), _grouping_indices=[[0, 1]])
assert op.grouping_indices == [[0, 1]]
op_ac = qml.ops.Sum(qml.X(0), qml.Y(1), grouping_type="anticommuting")
assert op_ac.grouping_indices == ((0,), (1,))
with pytest.raises(ValueError, match=r"cannot be specified at the same time."):
qml.ops.Sum(
qml.X(0), qml.Y(1), grouping_type="anticommuting", _grouping_indices=[[0, 1]]
)

def test_non_pauli_error(self):
"""Test that grouping non-Pauli observables is not supported."""
op = Sum(qml.PauliX(0), Prod(qml.PauliZ(0), qml.PauliX(1)), qml.Hadamard(2))
Expand Down
11 changes: 11 additions & 0 deletions tests/templates/test_subroutines/test_approx_time_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,17 @@ def test_flatten_unflatten():
assert new_op is not op


def test_queuing():
"""Test that ApproxTimeEvolution de-queues the input hamiltonian."""

with qml.queuing.AnnotatedQueue() as q:
H = qml.X(0) + qml.Y(1)
op = qml.ApproxTimeEvolution(H, 0.1, n=20)

assert len(q.queue) == 1
assert q.queue[0] is op


class TestDecomposition:
"""Tests that the template defines the correct decomposition."""

Expand Down
11 changes: 11 additions & 0 deletions tests/templates/test_subroutines/test_commuting_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,17 @@ def evolution_circuit(time):
assert all(np.isclose(state1, state2))


def test_queuing():
"""Test that CommutingEvolution de-queues the input hamiltonian."""

with qml.queuing.AnnotatedQueue() as q:
H = qml.X(0) + qml.Y(1)
op = qml.CommutingEvolution(H, 0.1, (2,))

assert len(q.queue) == 1
assert q.queue[0] is op


def test_decomposition_expand():
"""Test that the decomposition of CommutingEvolution is an ApproxTimeEvolution with one step."""

Expand Down
10 changes: 10 additions & 0 deletions tests/templates/test_subroutines/test_qdrift.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@
class TestInitialization:
"""Test that the class is intialized correctly."""

def test_queuing(self):
"""Test that QDrift de-queues the input hamiltonian."""

with qml.queuing.AnnotatedQueue() as q:
H = qml.X(0) + qml.Y(1)
op = qml.QDrift(H, 0.1, n=20)

assert len(q.queue) == 1
assert q.queue[0] is op

@pytest.mark.parametrize("n", (1, 2, 3))
@pytest.mark.parametrize("time", (0.5, 1, 2))
@pytest.mark.parametrize("seed", (None, 1234, 42))
Expand Down
10 changes: 10 additions & 0 deletions tests/templates/test_subroutines/test_trotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,16 @@ def test_convention_approx_time_evolv(self, time, n):
qml.matrix(op2, wire_order=hamiltonian.wires),
)

def test_queuing(self):
"""Test that the target operator is removed from the queue."""

with qml.queuing.AnnotatedQueue() as q:
H = qml.X(0) + qml.Y(1)
op = qml.TrotterProduct(H, time=2)

assert len(q.queue) == 1
assert q.queue[0] is op


class TestPrivateFunctions:
"""Test the private helper functions."""
Expand Down

0 comments on commit 10d59e7

Please sign in to comment.