Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up Device.batch_transform to use split_non_commuting #5828

Merged
merged 40 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
de98f51
`split-non-commuting` use wire grouping with non-pauli-word observable
astralcai Jun 10, 2024
8f47d26
Clean up `Device.batch_transform` to use `split_non_commuting`
astralcai Jun 10, 2024
b858673
remove unused import
astralcai Jun 10, 2024
a0f44e6
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
astralcai Jun 10, 2024
29cdd05
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
astralcai Jun 10, 2024
c99b441
add tests
astralcai Jun 10, 2024
46f0d10
Merge branch 'split-non-com-non-pauli' of https://github.com/PennyLan…
astralcai Jun 10, 2024
4c951ba
update conditions for null processing
astralcai Jun 10, 2024
f24e0f3
add special logic for braket plugin
astralcai Jun 10, 2024
efea67f
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
astralcai Jun 10, 2024
8ad6d3a
revert accidental removal
astralcai Jun 10, 2024
c922164
more testcase fixes
astralcai Jun 11, 2024
eddbd3e
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
astralcai Jun 11, 2024
61dee38
Fix bug where `split_non_commuting` erases trainability of observables
astralcai Jun 11, 2024
51dc029
remove accidental fixture
astralcai Jun 11, 2024
9ae3a0e
Merge branch 'split-non-com-diff' of https://github.com/PennyLaneAI/p…
astralcai Jun 11, 2024
a7dc2db
update logic
astralcai Jun 11, 2024
ccc6cfa
apply suggestions from code review
astralcai Jun 11, 2024
4d5925d
update test cases
astralcai Jun 11, 2024
d90503c
Merge branch 'split-non-com-diff' of https://github.com/PennyLaneAI/p…
astralcai Jun 11, 2024
b1bd7f9
fix autograd tests
astralcai Jun 11, 2024
f72158e
fix device test
astralcai Jun 11, 2024
5ce4c9a
update logic
astralcai Jun 11, 2024
6489e9a
fix vqe tests
astralcai Jun 12, 2024
3a011fe
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
astralcai Jun 12, 2024
05eb3dd
Merge branch 'master' into batch-transform
astralcai Jun 12, 2024
844352b
fix test case
astralcai Jun 12, 2024
f10b1c6
add changelog entry
astralcai Jun 12, 2024
f9b160e
Merge branch 'master' into batch-transform
astralcai Jun 12, 2024
ae00cef
fix device tests
astralcai Jun 12, 2024
8e75a38
make codecov happy
astralcai Jun 12, 2024
ebd7906
Merge branch 'master' into batch-transform
astralcai Jun 12, 2024
f85d723
fix bug
astralcai Jun 12, 2024
16d1cfd
fix bug
astralcai Jun 12, 2024
a0cf500
Merge branch 'master' into batch-transform
astralcai Jun 12, 2024
bba54f7
minor update
astralcai Jun 12, 2024
e438931
update name and doc
astralcai Jun 12, 2024
421a9a1
apply suggestions from code review
astralcai Jun 13, 2024
6425c89
Merge branch 'master' into batch-transform
astralcai Jun 18, 2024
eaf611f
Merge branch 'master' into batch-transform
astralcai Jun 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@
* `qml.transforms.split_non_commuting` can now handle circuits containing measurements of multi-term observables.
[(#5729)](https://github.com/PennyLaneAI/pennylane/pull/5729)
[(#5853)](https://github.com/PennyLaneAI/pennylane/pull/5838)
[(#5828)](https://github.com/PennyLaneAI/pennylane/pull/5828)
[(#5869)](https://github.com/PennyLaneAI/pennylane/pull/5869)

* The qchem module has dedicated functions for calling `pyscf` and `openfermion` backends. The
Expand Down
113 changes: 61 additions & 52 deletions pennylane/_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,11 @@

import pennylane as qml
from pennylane.measurements import (
CountsMP,
Expectation,
ExpectationMP,
MeasurementProcess,
MidMeasureMP,
Probability,
ProbabilityMP,
Sample,
SampleMP,
ShadowExpvalMP,
State,
Variance,
Expand Down Expand Up @@ -738,79 +734,92 @@ def batch_transform(self, circuit: QuantumTape):
the sequence of circuits to be executed, and a post-processing function
to be applied to the list of evaluated circuit results.
"""
supports_hamiltonian = self.supports_observable("Hamiltonian")
supports_sum = self.supports_observable("Sum")

def null_postprocess(results):
return results[0]

finite_shots = self.shots is not None
grouping_known = all(
obs.grouping_indices is not None
for obs in circuit.observables
if isinstance(obs, (Hamiltonian, LinearCombination))
has_shadow = any(isinstance(m, ShadowExpvalMP) for m in circuit.measurements)
is_analytic_or_shadow = not finite_shots or has_shadow
all_obs_usable = self._all_multi_term_obs_supported(circuit)
exists_multi_term_obs = any(
isinstance(m.obs, (Hamiltonian, Sum, Prod, SProd)) for m in circuit.measurements
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
)
# device property present in braket plugin
use_grouping = getattr(self, "use_grouping", True)

hamiltonian_in_obs = any(
isinstance(obs, (Hamiltonian, LinearCombination)) for obs in circuit.observables
has_overlapping_wires = len(circuit.obs_sharing_wires) > 0
single_hamiltonian = len(circuit.measurements) == 1 and isinstance(
circuit.measurements[0].obs, (Hamiltonian, Sum)
)

expval_sum_or_prod_in_obs = any(
isinstance(m.obs, (Sum, Prod, SProd)) and isinstance(m, ExpectationMP)
for m in circuit.measurements
single_hamiltonian_with_grouping_known = (
single_hamiltonian and circuit.measurements[0].obs.grouping_indices is not None
)

is_shadow = any(isinstance(m, ShadowExpvalMP) for m in circuit.measurements)
if not getattr(self, "use_grouping", True) and single_hamiltonian and all_obs_usable:
# Special logic for the braket plugin
circuits = [circuit]
processing_fn = null_postprocess

hamiltonian_unusable = not supports_hamiltonian or (finite_shots and not is_shadow)
elif not exists_multi_term_obs and not has_overlapping_wires:
circuits = [circuit]
processing_fn = null_postprocess

if hamiltonian_in_obs and (hamiltonian_unusable or (use_grouping and grouping_known)):
# If the observable contains a Hamiltonian and the device does not
# support Hamiltonians, or if the simulation uses finite shots, or
# if the Hamiltonian explicitly specifies an observable grouping,
# split tape into multiple tapes of diagonalizable known observables.
try:
circuits, hamiltonian_fn = qml.transforms.hamiltonian_expand(circuit, group=False)
except ValueError:
circuits, hamiltonian_fn = qml.transforms.sum_expand(circuit)
elif is_analytic_or_shadow and all_obs_usable and not has_overlapping_wires:
circuits = [circuit]
processing_fn = null_postprocess

elif expval_sum_or_prod_in_obs and not is_shadow and not supports_sum:
circuits, hamiltonian_fn = qml.transforms.sum_expand(circuit)
elif single_hamiltonian_with_grouping_known:

elif (
len(circuit.obs_sharing_wires) > 0
and not hamiltonian_in_obs
and all(
not isinstance(m, (SampleMP, ProbabilityMP, CountsMP)) for m in circuit.measurements
)
):
# Check for case of non-commuting terms and that there are no Hamiltonians
# TODO: allow for Hamiltonians in list of observables as well.
circuits, hamiltonian_fn = qml.transforms.split_non_commuting(circuit)
# Use qwc grouping if the circuit contains a single measurement of a
# Hamiltonian/Sum with grouping indices already calculated.
circuits, processing_fn = qml.transforms.split_non_commuting(circuit, "qwc")

else:
# otherwise, return the output of an identity transform
circuits = [circuit]
elif any(isinstance(m.obs, (Hamiltonian, LinearCombination)) for m in circuit.measurements):

def hamiltonian_fn(res):
return res[0]
# Otherwise, use wire-based grouping if the circuit contains a Hamiltonian
# that is potentially very large.
circuits, processing_fn = qml.transforms.split_non_commuting(circuit, "wires")

# Check whether the circuit was broadcasted (then the Hamiltonian-expanded
# ones will be as well) and whether broadcasting is supported
else:
circuits, processing_fn = qml.transforms.split_non_commuting(circuit)

# Check whether the circuit was broadcasted and whether broadcasting is supported
if circuit.batch_size is None or self.capabilities().get("supports_broadcasting"):
# If the circuit wasn't broadcasted or broadcasting is supported, no action required
return circuits, hamiltonian_fn
return circuits, processing_fn

# Expand each of the broadcasted Hamiltonian-expanded circuits
expanded_tapes, expanded_fn = qml.transforms.broadcast_expand(circuits)

# Chain the postprocessing functions of the broadcasted-tape expansions and the Hamiltonian
# expansion. Note that the application order is reversed compared to the expansion order,
# i.e. while we first applied `hamiltonian_expand` to the tape, we need to process the
# i.e. while we first applied `split_non_commuting` to the tape, we need to process the
# results from the broadcast expansion first.
def total_processing(results):
return hamiltonian_fn(expanded_fn(results))
return processing_fn(expanded_fn(results))

return expanded_tapes, total_processing

def _all_multi_term_obs_supported(self, circuit):
"""Check whether all multi-term observables in the circuit are supported."""

for mp in circuit.measurements:

if mp.obs is None:
# Some measurements are not observable based.
continue
astralcai marked this conversation as resolved.
Show resolved Hide resolved

if mp.obs.name == "LinearCombination" and not self.supports_observable("Hamiltonian"):
return False

if mp.obs.name in (
"Hamiltonian",
"Sum",
"Prod",
"SProd",
) and not self.supports_observable(mp.obs.name):
return False

return True

@property
def op_queue(self):
"""The operation queue to be applied.
Expand Down
25 changes: 24 additions & 1 deletion pennylane/devices/qubit/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,30 @@ def _get_num_executions_for_expval_H(obs):
indices = obs.grouping_indices
if indices:
return len(indices)
return sum(int(not isinstance(o, qml.Identity)) for o in obs.terms()[1])
return _get_num_wire_groups_for_expval_H(obs)


def _get_num_wire_groups_for_expval_H(obs):
_, obs_list = obs.terms()
wires_list = []
added_obs = []
num_groups = 0
for o in obs_list:
if o in added_obs:
continue
if isinstance(o, qml.Identity):
continue
added = False
for wires in wires_list:
if len(qml.wires.Wires.shared_wires([wires, o.wires])) == 0:
added_obs.append(o)
added = True
break
if not added:
added_obs.append(o)
wires_list.append(o.wires)
num_groups += 1
return num_groups


def _get_num_executions_for_sum(obs):
Expand Down
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,12 @@ def new_opmath_only():
pytest.skip("This feature only works with new opmath enabled")


@pytest.fixture
def legacy_opmath_only():
if qml.operation.active_new_opmath():
pytest.skip("This test exclusively tests legacy opmath")


#######################################################################

try:
Expand Down
36 changes: 12 additions & 24 deletions tests/devices/default_qubit/test_default_qubit_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,40 +195,28 @@ def circuit_3(y):

shot_testing_combos = [
# expval combinations
([qml.expval(qml.PauliX(0))], 1, 10),
([qml.expval(qml.PauliX(0)), qml.expval(qml.PauliY(0))], 2, 20),
([qml.expval(qml.X(0))], 1, 10),
([qml.expval(qml.X(0)), qml.expval(qml.Y(0))], 2, 20),
# Hamiltonian test cases
([qml.expval(qml.Hamiltonian([1, 1], [qml.PauliX(0), qml.PauliX(1)]))], 2, 20),
(
[qml.expval(qml.Hamiltonian([1, 1], [qml.PauliX(0), qml.PauliX(1)], grouping_type="qwc"))],
1,
10,
),
(
[qml.expval(qml.Hamiltonian([1, 1], [qml.PauliX(0), qml.PauliY(0)], grouping_type="qwc"))],
2,
20,
),
([qml.expval(qml.Hamiltonian([1, 0.5, 1], [qml.X(0), qml.Y(0), qml.X(1)]))], 2, 20),
([qml.expval(qml.Hamiltonian([1, 1], [qml.X(0), qml.X(1)], grouping_type="qwc"))], 1, 10),
([qml.expval(qml.Hamiltonian([1, 1], [qml.X(0), qml.Y(0)], grouping_type="qwc"))], 2, 20),
# op arithmetic test cases
([qml.expval(qml.sum(qml.PauliX(0), qml.PauliY(0)))], 2, 20),
([qml.expval(qml.sum(qml.PauliX(0), qml.PauliX(0) @ qml.PauliX(1)))], 1, 10),
([qml.expval(qml.sum(qml.PauliX(0), qml.Hadamard(0)))], 2, 20),
(
[qml.expval(qml.sum(qml.PauliX(0), qml.PauliY(1) @ qml.PauliX(1), grouping_type="qwc"))],
1,
10,
),
([qml.expval(qml.sum(qml.X(0), qml.Y(0)))], 2, 20),
([qml.expval(qml.sum(qml.X(0), qml.X(0) @ qml.X(1)))], 1, 10),
([qml.expval(qml.sum(qml.X(0), qml.Hadamard(0)))], 2, 20),
([qml.expval(qml.sum(qml.X(0), qml.Y(1) @ qml.X(1), grouping_type="qwc"))], 1, 10),
(
[
qml.expval(qml.prod(qml.PauliX(0), qml.PauliX(1))),
qml.expval(qml.prod(qml.PauliX(1), qml.PauliX(2))),
qml.expval(qml.prod(qml.X(0), qml.X(1))),
qml.expval(qml.prod(qml.X(1), qml.X(2))),
],
1,
10,
),
# computational basis measurements
([qml.probs(wires=(0, 1)), qml.sample(wires=(0, 1))], 1, 10),
([qml.probs(wires=(0, 1)), qml.sample(wires=(0, 1)), qml.expval(qml.PauliX(0))], 2, 20),
([qml.probs(wires=(0, 1)), qml.sample(wires=(0, 1)), qml.expval(qml.X(0))], 2, 20),
# classical shadows
([qml.shadow_expval(H0)], 10, 10),
([qml.shadow_expval(H0), qml.probs(wires=(0, 1))], 11, 20),
Expand Down
2 changes: 1 addition & 1 deletion tests/devices/qutrit_mixed/test_qutrit_mixed_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def circuit_3(y):
([qml.expval(qml.GellMann(0, 1))], 1, 10),
([qml.expval(qml.GellMann(0, 1)), qml.expval(qml.GellMann(0, 2))], 2, 20),
# Hamiltonian test cases
([qml.expval(qml.Hamiltonian([1, 1], [qml.GellMann(0, 1), qml.GellMann(1, 5)]))], 2, 20),
([qml.expval(qml.Hamiltonian([1, 1], [qml.GellMann(0, 1), qml.GellMann(1, 5)]))], 1, 10),
# op arithmetic test cases
([qml.expval(qml.sum(qml.GellMann(0, 1), qml.GellMann(1, 4)))], 2, 20),
(
Expand Down
50 changes: 11 additions & 39 deletions tests/interfaces/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,6 @@ class TestBatchTransformExecution:
"""Tests to ensure batch transforms can be correctly executed
via qml.execute and map_batch_transform"""

@pytest.mark.usefixtures("use_new_opmath")
def test_no_batch_transform(self, mocker):
"""Test that batch transforms can be disabled and enabled"""
dev = qml.device("default.qubit.legacy", wires=2, shots=100000)
Expand All @@ -200,47 +199,20 @@ def test_no_batch_transform(self, mocker):
tape = qml.tape.QuantumScript.from_queue(q)
spy = mocker.spy(dev, "batch_transform")

res = qml.execute([tape], dev, None, device_batch_transform=False)
assert np.allclose(res[0], np.cos(y), atol=0.1)

spy.assert_not_called()

res = qml.execute([tape], dev, None, device_batch_transform=True)
spy.assert_called()

assert isinstance(res[0], np.ndarray)
assert res[0].shape == ()
assert np.allclose(res[0], np.cos(y), atol=0.1)

@pytest.mark.usefixtures("use_legacy_opmath")
def test_no_batch_transform_legacy_opmath(self, mocker):
"""Test functionality to enable and disable"""
dev = qml.device("default.qubit.legacy", wires=2, shots=100000)

H = qml.PauliZ(0) @ qml.PauliZ(1) - qml.PauliX(0)
x = 0.6
y = 0.2

with qml.queuing.AnnotatedQueue() as q:
qml.RX(x, wires=0)
qml.RY(y, wires=1)
qml.CNOT(wires=[0, 1])
qml.expval(H)

tape = qml.tape.QuantumScript.from_queue(q)
spy = mocker.spy(dev, "batch_transform")

with pytest.raises(AssertionError, match="Hamiltonian must be used with shots=None"):
if not qml.operation.active_new_opmath():
with pytest.raises(AssertionError, match="Hamiltonian must be used with shots=None"):
_ = qml.execute([tape], dev, None, device_batch_transform=False)
else:
res = qml.execute([tape], dev, None, device_batch_transform=False)
assert np.allclose(res[0], np.cos(y), atol=0.1)

spy.assert_not_called()

res = qml.execute([tape], dev, None, device_batch_transform=True)
spy.assert_called()

assert isinstance(res[0], np.ndarray)
assert res[0].shape == ()
assert np.allclose(res[0], np.cos(y), atol=0.1)
assert qml.math.shape(res[0]) == ()
assert np.allclose(res[0], np.cos(y), rtol=0.05)

def test_batch_transform_dynamic_shots(self):
"""Tests that the batch transform considers the number of shots for the execution, not those
Expand Down Expand Up @@ -462,10 +434,10 @@ def f(x):
assert qml.math.allclose(out, expected)

def test_single_backward_pass_split_hamiltonian(self):
"""Tests that the backward pass is one single batch, not a bunch of batches, when parameter shift
derivatives are requested for a a tape that the device split into batches."""
"""Tests that the backward pass is one single batch, not a bunch of batches, when parameter
shift derivatives are requested for a tape that the device split into batches."""

dev = qml.device("default.qubit.legacy", wires=2)
dev = qml.device("default.qubit.legacy", wires=2, shots=50000)

H = qml.Hamiltonian([1, 1], [qml.PauliY(0), qml.PauliZ(0)], grouping_type="qwc")

Expand All @@ -480,7 +452,7 @@ def f(x):
assert dev.tracker.totals["batches"] == 2
assert dev.tracker.history["batch_len"] == [2, 4]

assert qml.math.allclose(out, -np.cos(x) - np.sin(x))
assert qml.math.allclose(out, -np.cos(x) - np.sin(x), atol=0.05)


execute_kwargs_integration = [
Expand Down
2 changes: 1 addition & 1 deletion tests/interfaces/test_autograd_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -1666,7 +1666,7 @@ def test_hamiltonian_expansion_finite_shots(
gradient_kwargs = {"h": 0.05}

dev = qml.device(dev_name, wires=3, shots=50000)
spy = mocker.spy(qml.transforms, "hamiltonian_expand")
spy = mocker.spy(qml.transforms, "split_non_commuting")
obs = [qml.PauliX(0), qml.PauliX(0) @ qml.PauliZ(1), qml.PauliZ(0) @ qml.PauliZ(1)]

@qnode(
Expand Down
4 changes: 2 additions & 2 deletions tests/interfaces/test_jax_jit_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -1540,7 +1540,7 @@ def test_hamiltonian_expansion_analytic(
tol = TOL_FOR_SPSA

dev = qml.device(dev_name, wires=3, shots=None)
spy = mocker.spy(qml.transforms, "hamiltonian_expand")
spy = mocker.spy(qml.transforms, "split_non_commuting")
obs = [qml.PauliX(0), qml.PauliX(0) @ qml.PauliZ(1), qml.PauliZ(0) @ qml.PauliZ(1)]

@qnode(
Expand Down Expand Up @@ -1609,7 +1609,7 @@ def test_hamiltonian_expansion_finite_shots(
tol = TOL_FOR_SPSA

dev = qml.device(dev_name, wires=3, shots=50000)
spy = mocker.spy(qml.transforms, "hamiltonian_expand")
spy = mocker.spy(qml.transforms, "split_non_commuting")
obs = [qml.PauliX(0), qml.PauliX(0) @ qml.PauliZ(1), qml.PauliZ(0) @ qml.PauliZ(1)]

@qnode(
Expand Down
Loading
Loading