Skip to content

Commit

Permalink
Clean up Device.batch_transform to use split_non_commuting (#5828)
Browse files Browse the repository at this point in the history
**Context:**
#5729 introduced the
unified `split_non_commuting`. Now the legacy device can use
`split_non_commuting` in all scenarios.

**Description of the Change:**
Cleans up `batch_transform` to use `split_non_commuting`

**Benefits:**
Cleaner code, prepares for the deprecation of `hamiltonian_expand` and
`sum_expand`.

**Related GitHub Issues:**
[sc-61253]
  • Loading branch information
astralcai committed Jun 18, 2024
1 parent 209867b commit eb39402
Show file tree
Hide file tree
Showing 15 changed files with 175 additions and 213 deletions.
1 change: 1 addition & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@
* `qml.transforms.split_non_commuting` can now handle circuits containing measurements of multi-term observables.
[(#5729)](https://github.com/PennyLaneAI/pennylane/pull/5729)
[(#5853)](https://github.com/PennyLaneAI/pennylane/pull/5838)
[(#5828)](https://github.com/PennyLaneAI/pennylane/pull/5828)
[(#5869)](https://github.com/PennyLaneAI/pennylane/pull/5869)

* The qchem module has dedicated functions for calling `pyscf` and `openfermion` backends. The
Expand Down
113 changes: 61 additions & 52 deletions pennylane/_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,11 @@

import pennylane as qml
from pennylane.measurements import (
CountsMP,
Expectation,
ExpectationMP,
MeasurementProcess,
MidMeasureMP,
Probability,
ProbabilityMP,
Sample,
SampleMP,
ShadowExpvalMP,
State,
Variance,
Expand Down Expand Up @@ -738,79 +734,92 @@ def batch_transform(self, circuit: QuantumTape):
the sequence of circuits to be executed, and a post-processing function
to be applied to the list of evaluated circuit results.
"""
supports_hamiltonian = self.supports_observable("Hamiltonian")
supports_sum = self.supports_observable("Sum")

def null_postprocess(results):
return results[0]

finite_shots = self.shots is not None
grouping_known = all(
obs.grouping_indices is not None
for obs in circuit.observables
if isinstance(obs, (Hamiltonian, LinearCombination))
has_shadow = any(isinstance(m, ShadowExpvalMP) for m in circuit.measurements)
is_analytic_or_shadow = not finite_shots or has_shadow
all_obs_usable = self._all_multi_term_obs_supported(circuit)
exists_multi_term_obs = any(
isinstance(m.obs, (Hamiltonian, Sum, Prod, SProd)) for m in circuit.measurements
)
# device property present in braket plugin
use_grouping = getattr(self, "use_grouping", True)

hamiltonian_in_obs = any(
isinstance(obs, (Hamiltonian, LinearCombination)) for obs in circuit.observables
has_overlapping_wires = len(circuit.obs_sharing_wires) > 0
single_hamiltonian = len(circuit.measurements) == 1 and isinstance(
circuit.measurements[0].obs, (Hamiltonian, Sum)
)

expval_sum_or_prod_in_obs = any(
isinstance(m.obs, (Sum, Prod, SProd)) and isinstance(m, ExpectationMP)
for m in circuit.measurements
single_hamiltonian_with_grouping_known = (
single_hamiltonian and circuit.measurements[0].obs.grouping_indices is not None
)

is_shadow = any(isinstance(m, ShadowExpvalMP) for m in circuit.measurements)
if not getattr(self, "use_grouping", True) and single_hamiltonian and all_obs_usable:
# Special logic for the braket plugin
circuits = [circuit]
processing_fn = null_postprocess

hamiltonian_unusable = not supports_hamiltonian or (finite_shots and not is_shadow)
elif not exists_multi_term_obs and not has_overlapping_wires:
circuits = [circuit]
processing_fn = null_postprocess

if hamiltonian_in_obs and (hamiltonian_unusable or (use_grouping and grouping_known)):
# If the observable contains a Hamiltonian and the device does not
# support Hamiltonians, or if the simulation uses finite shots, or
# if the Hamiltonian explicitly specifies an observable grouping,
# split tape into multiple tapes of diagonalizable known observables.
try:
circuits, hamiltonian_fn = qml.transforms.hamiltonian_expand(circuit, group=False)
except ValueError:
circuits, hamiltonian_fn = qml.transforms.sum_expand(circuit)
elif is_analytic_or_shadow and all_obs_usable and not has_overlapping_wires:
circuits = [circuit]
processing_fn = null_postprocess

elif expval_sum_or_prod_in_obs and not is_shadow and not supports_sum:
circuits, hamiltonian_fn = qml.transforms.sum_expand(circuit)
elif single_hamiltonian_with_grouping_known:

elif (
len(circuit.obs_sharing_wires) > 0
and not hamiltonian_in_obs
and all(
not isinstance(m, (SampleMP, ProbabilityMP, CountsMP)) for m in circuit.measurements
)
):
# Check for case of non-commuting terms and that there are no Hamiltonians
# TODO: allow for Hamiltonians in list of observables as well.
circuits, hamiltonian_fn = qml.transforms.split_non_commuting(circuit)
# Use qwc grouping if the circuit contains a single measurement of a
# Hamiltonian/Sum with grouping indices already calculated.
circuits, processing_fn = qml.transforms.split_non_commuting(circuit, "qwc")

else:
# otherwise, return the output of an identity transform
circuits = [circuit]
elif any(isinstance(m.obs, (Hamiltonian, LinearCombination)) for m in circuit.measurements):

def hamiltonian_fn(res):
return res[0]
# Otherwise, use wire-based grouping if the circuit contains a Hamiltonian
# that is potentially very large.
circuits, processing_fn = qml.transforms.split_non_commuting(circuit, "wires")

# Check whether the circuit was broadcasted (then the Hamiltonian-expanded
# ones will be as well) and whether broadcasting is supported
else:
circuits, processing_fn = qml.transforms.split_non_commuting(circuit)

# Check whether the circuit was broadcasted and whether broadcasting is supported
if circuit.batch_size is None or self.capabilities().get("supports_broadcasting"):
# If the circuit wasn't broadcasted or broadcasting is supported, no action required
return circuits, hamiltonian_fn
return circuits, processing_fn

# Expand each of the broadcasted Hamiltonian-expanded circuits
expanded_tapes, expanded_fn = qml.transforms.broadcast_expand(circuits)

# Chain the postprocessing functions of the broadcasted-tape expansions and the Hamiltonian
# expansion. Note that the application order is reversed compared to the expansion order,
# i.e. while we first applied `hamiltonian_expand` to the tape, we need to process the
# i.e. while we first applied `split_non_commuting` to the tape, we need to process the
# results from the broadcast expansion first.
def total_processing(results):
return hamiltonian_fn(expanded_fn(results))
return processing_fn(expanded_fn(results))

return expanded_tapes, total_processing

def _all_multi_term_obs_supported(self, circuit):
"""Check whether all multi-term observables in the circuit are supported."""

for mp in circuit.measurements:

if mp.obs is None:
# Some measurements are not observable based.
continue

if mp.obs.name == "LinearCombination" and not self.supports_observable("Hamiltonian"):
return False

if mp.obs.name in (
"Hamiltonian",
"Sum",
"Prod",
"SProd",
) and not self.supports_observable(mp.obs.name):
return False

return True

@property
def op_queue(self):
"""The operation queue to be applied.
Expand Down
25 changes: 24 additions & 1 deletion pennylane/devices/qubit/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,30 @@ def _get_num_executions_for_expval_H(obs):
indices = obs.grouping_indices
if indices:
return len(indices)
return sum(int(not isinstance(o, qml.Identity)) for o in obs.terms()[1])
return _get_num_wire_groups_for_expval_H(obs)


def _get_num_wire_groups_for_expval_H(obs):
_, obs_list = obs.terms()
wires_list = []
added_obs = []
num_groups = 0
for o in obs_list:
if o in added_obs:
continue
if isinstance(o, qml.Identity):
continue
added = False
for wires in wires_list:
if len(qml.wires.Wires.shared_wires([wires, o.wires])) == 0:
added_obs.append(o)
added = True
break
if not added:
added_obs.append(o)
wires_list.append(o.wires)
num_groups += 1
return num_groups


def _get_num_executions_for_sum(obs):
Expand Down
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,12 @@ def new_opmath_only():
pytest.skip("This feature only works with new opmath enabled")


@pytest.fixture
def legacy_opmath_only():
if qml.operation.active_new_opmath():
pytest.skip("This test exclusively tests legacy opmath")


#######################################################################

try:
Expand Down
36 changes: 12 additions & 24 deletions tests/devices/default_qubit/test_default_qubit_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,40 +195,28 @@ def circuit_3(y):

shot_testing_combos = [
# expval combinations
([qml.expval(qml.PauliX(0))], 1, 10),
([qml.expval(qml.PauliX(0)), qml.expval(qml.PauliY(0))], 2, 20),
([qml.expval(qml.X(0))], 1, 10),
([qml.expval(qml.X(0)), qml.expval(qml.Y(0))], 2, 20),
# Hamiltonian test cases
([qml.expval(qml.Hamiltonian([1, 1], [qml.PauliX(0), qml.PauliX(1)]))], 2, 20),
(
[qml.expval(qml.Hamiltonian([1, 1], [qml.PauliX(0), qml.PauliX(1)], grouping_type="qwc"))],
1,
10,
),
(
[qml.expval(qml.Hamiltonian([1, 1], [qml.PauliX(0), qml.PauliY(0)], grouping_type="qwc"))],
2,
20,
),
([qml.expval(qml.Hamiltonian([1, 0.5, 1], [qml.X(0), qml.Y(0), qml.X(1)]))], 2, 20),
([qml.expval(qml.Hamiltonian([1, 1], [qml.X(0), qml.X(1)], grouping_type="qwc"))], 1, 10),
([qml.expval(qml.Hamiltonian([1, 1], [qml.X(0), qml.Y(0)], grouping_type="qwc"))], 2, 20),
# op arithmetic test cases
([qml.expval(qml.sum(qml.PauliX(0), qml.PauliY(0)))], 2, 20),
([qml.expval(qml.sum(qml.PauliX(0), qml.PauliX(0) @ qml.PauliX(1)))], 1, 10),
([qml.expval(qml.sum(qml.PauliX(0), qml.Hadamard(0)))], 2, 20),
(
[qml.expval(qml.sum(qml.PauliX(0), qml.PauliY(1) @ qml.PauliX(1), grouping_type="qwc"))],
1,
10,
),
([qml.expval(qml.sum(qml.X(0), qml.Y(0)))], 2, 20),
([qml.expval(qml.sum(qml.X(0), qml.X(0) @ qml.X(1)))], 1, 10),
([qml.expval(qml.sum(qml.X(0), qml.Hadamard(0)))], 2, 20),
([qml.expval(qml.sum(qml.X(0), qml.Y(1) @ qml.X(1), grouping_type="qwc"))], 1, 10),
(
[
qml.expval(qml.prod(qml.PauliX(0), qml.PauliX(1))),
qml.expval(qml.prod(qml.PauliX(1), qml.PauliX(2))),
qml.expval(qml.prod(qml.X(0), qml.X(1))),
qml.expval(qml.prod(qml.X(1), qml.X(2))),
],
1,
10,
),
# computational basis measurements
([qml.probs(wires=(0, 1)), qml.sample(wires=(0, 1))], 1, 10),
([qml.probs(wires=(0, 1)), qml.sample(wires=(0, 1)), qml.expval(qml.PauliX(0))], 2, 20),
([qml.probs(wires=(0, 1)), qml.sample(wires=(0, 1)), qml.expval(qml.X(0))], 2, 20),
# classical shadows
([qml.shadow_expval(H0)], 10, 10),
([qml.shadow_expval(H0), qml.probs(wires=(0, 1))], 11, 20),
Expand Down
2 changes: 1 addition & 1 deletion tests/devices/qutrit_mixed/test_qutrit_mixed_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def circuit_3(y):
([qml.expval(qml.GellMann(0, 1))], 1, 10),
([qml.expval(qml.GellMann(0, 1)), qml.expval(qml.GellMann(0, 2))], 2, 20),
# Hamiltonian test cases
([qml.expval(qml.Hamiltonian([1, 1], [qml.GellMann(0, 1), qml.GellMann(1, 5)]))], 2, 20),
([qml.expval(qml.Hamiltonian([1, 1], [qml.GellMann(0, 1), qml.GellMann(1, 5)]))], 1, 10),
# op arithmetic test cases
([qml.expval(qml.sum(qml.GellMann(0, 1), qml.GellMann(1, 4)))], 2, 20),
(
Expand Down
50 changes: 11 additions & 39 deletions tests/interfaces/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,6 @@ class TestBatchTransformExecution:
"""Tests to ensure batch transforms can be correctly executed
via qml.execute and map_batch_transform"""

@pytest.mark.usefixtures("use_new_opmath")
def test_no_batch_transform(self, mocker):
"""Test that batch transforms can be disabled and enabled"""
dev = qml.device("default.qubit.legacy", wires=2, shots=100000)
Expand All @@ -200,47 +199,20 @@ def test_no_batch_transform(self, mocker):
tape = qml.tape.QuantumScript.from_queue(q)
spy = mocker.spy(dev, "batch_transform")

res = qml.execute([tape], dev, None, device_batch_transform=False)
assert np.allclose(res[0], np.cos(y), atol=0.1)

spy.assert_not_called()

res = qml.execute([tape], dev, None, device_batch_transform=True)
spy.assert_called()

assert isinstance(res[0], np.ndarray)
assert res[0].shape == ()
assert np.allclose(res[0], np.cos(y), atol=0.1)

@pytest.mark.usefixtures("use_legacy_opmath")
def test_no_batch_transform_legacy_opmath(self, mocker):
"""Test functionality to enable and disable"""
dev = qml.device("default.qubit.legacy", wires=2, shots=100000)

H = qml.PauliZ(0) @ qml.PauliZ(1) - qml.PauliX(0)
x = 0.6
y = 0.2

with qml.queuing.AnnotatedQueue() as q:
qml.RX(x, wires=0)
qml.RY(y, wires=1)
qml.CNOT(wires=[0, 1])
qml.expval(H)

tape = qml.tape.QuantumScript.from_queue(q)
spy = mocker.spy(dev, "batch_transform")

with pytest.raises(AssertionError, match="Hamiltonian must be used with shots=None"):
if not qml.operation.active_new_opmath():
with pytest.raises(AssertionError, match="Hamiltonian must be used with shots=None"):
_ = qml.execute([tape], dev, None, device_batch_transform=False)
else:
res = qml.execute([tape], dev, None, device_batch_transform=False)
assert np.allclose(res[0], np.cos(y), atol=0.1)

spy.assert_not_called()

res = qml.execute([tape], dev, None, device_batch_transform=True)
spy.assert_called()

assert isinstance(res[0], np.ndarray)
assert res[0].shape == ()
assert np.allclose(res[0], np.cos(y), atol=0.1)
assert qml.math.shape(res[0]) == ()
assert np.allclose(res[0], np.cos(y), rtol=0.05)

def test_batch_transform_dynamic_shots(self):
"""Tests that the batch transform considers the number of shots for the execution, not those
Expand Down Expand Up @@ -462,10 +434,10 @@ def f(x):
assert qml.math.allclose(out, expected)

def test_single_backward_pass_split_hamiltonian(self):
"""Tests that the backward pass is one single batch, not a bunch of batches, when parameter shift
derivatives are requested for a a tape that the device split into batches."""
"""Tests that the backward pass is one single batch, not a bunch of batches, when parameter
shift derivatives are requested for a tape that the device split into batches."""

dev = qml.device("default.qubit.legacy", wires=2)
dev = qml.device("default.qubit.legacy", wires=2, shots=50000)

H = qml.Hamiltonian([1, 1], [qml.PauliY(0), qml.PauliZ(0)], grouping_type="qwc")

Expand All @@ -480,7 +452,7 @@ def f(x):
assert dev.tracker.totals["batches"] == 2
assert dev.tracker.history["batch_len"] == [2, 4]

assert qml.math.allclose(out, -np.cos(x) - np.sin(x))
assert qml.math.allclose(out, -np.cos(x) - np.sin(x), atol=0.05)


execute_kwargs_integration = [
Expand Down
2 changes: 1 addition & 1 deletion tests/interfaces/test_autograd_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -1666,7 +1666,7 @@ def test_hamiltonian_expansion_finite_shots(
gradient_kwargs = {"h": 0.05}

dev = qml.device(dev_name, wires=3, shots=50000)
spy = mocker.spy(qml.transforms, "hamiltonian_expand")
spy = mocker.spy(qml.transforms, "split_non_commuting")
obs = [qml.PauliX(0), qml.PauliX(0) @ qml.PauliZ(1), qml.PauliZ(0) @ qml.PauliZ(1)]

@qnode(
Expand Down
4 changes: 2 additions & 2 deletions tests/interfaces/test_jax_jit_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -1540,7 +1540,7 @@ def test_hamiltonian_expansion_analytic(
tol = TOL_FOR_SPSA

dev = qml.device(dev_name, wires=3, shots=None)
spy = mocker.spy(qml.transforms, "hamiltonian_expand")
spy = mocker.spy(qml.transforms, "split_non_commuting")
obs = [qml.PauliX(0), qml.PauliX(0) @ qml.PauliZ(1), qml.PauliZ(0) @ qml.PauliZ(1)]

@qnode(
Expand Down Expand Up @@ -1609,7 +1609,7 @@ def test_hamiltonian_expansion_finite_shots(
tol = TOL_FOR_SPSA

dev = qml.device(dev_name, wires=3, shots=50000)
spy = mocker.spy(qml.transforms, "hamiltonian_expand")
spy = mocker.spy(qml.transforms, "split_non_commuting")
obs = [qml.PauliX(0), qml.PauliX(0) @ qml.PauliZ(1), qml.PauliZ(0) @ qml.PauliZ(1)]

@qnode(
Expand Down
Loading

0 comments on commit eb39402

Please sign in to comment.