Skip to content

Commit

Permalink
split-non-commuting use wire grouping with non-pauli-word observable (
Browse files Browse the repository at this point in the history
#5827)

**Context:**
The new unified `split_non_commuting` uses `qwc` grouping by default,
but it is not supported for non-pauli-word observables.

**Description of the Change:**
Use wire-based grouping if one of the observables is not a pauli word

**Related GitHub Issues:**
Fixes #4362
[sc-42003]

---------

Co-authored-by: Christina Lee <christina@xanadu.ai>
  • Loading branch information
astralcai and albi3ro authored Jun 10, 2024
1 parent bec5bb0 commit 8a51dda
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 0 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,9 @@
* Fixes a bug where the gradient of `ControlledSequence`, `Reflection`, `AmplitudeAmplification`, and `Qubitization` is incorrect on `default.qubit.legacy` with `parameter_shift`.
[(#5806)](https://github.com/PennyLaneAI/pennylane/pull/5806)

* Fixed a bug where `split_non_commuting` raises an error when the circuit contains measurements of observables that are not pauli words.
[(#5827)](https://github.com/PennyLaneAI/pennylane/pull/5827)

<h3>Contributors ✍️</h3>

This release contains contributions from (in alphabetical order):
Expand Down
3 changes: 3 additions & 0 deletions pennylane/transforms/split_non_commuting.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,9 @@ def circuit(x):
isinstance(m, ExpectationMP) and isinstance(m.obs, (LinearCombination, Hamiltonian))
for m in tape.measurements
)
or any(
m.obs is not None and not qml.pauli.is_pauli_word(m.obs) for m in single_term_obs_mps
)
):
# This is a loose check to see whether wires grouping or qwc grouping should be used,
# which does not necessarily make perfect sense but is consistent with the old decision
Expand Down
73 changes: 73 additions & 0 deletions tests/transforms/test_split_non_commuting.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,60 @@ def test_batch_of_tapes(self, batch_type):
result = ([0.1, 0.2], 0.2, 0.3, 0.4)
assert fn(result) == ((0.1, 0.2, 0.2), (0.3, 0.4))

@pytest.mark.parametrize(
"non_pauli_obs",
[
[
qml.Projector([0], wires=[1]),
qml.Projector([1, 1, 0, 1], wires=[0, 1]),
],
[
qml.Hadamard(wires=[1]),
qml.Hadamard(wires=[0]) @ qml.PauliX(wires=[1]),
],
],
)
def test_tape_with_non_pauli_obs(self, non_pauli_obs):
"""Tests that the tape is split correctly when containing non-Pauli observables"""

obs_list = single_term_obs_list + non_pauli_obs
measurements = [
qml.expval(c * o) for c, o in zip([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], obs_list)
]
tape = qml.tape.QuantumScript([], measurements, shots=100)

expected_tapes_no_grouping = [
qml.tape.QuantumScript([], [qml.expval(o)], shots=100) for o in obs_list
]

tapes, fn = split_non_commuting(tape, grouping_strategy=None)
for actual_tape, expected_tape in zip(tapes, expected_tapes_no_grouping):
assert qml.equal(actual_tape, expected_tape)
assert qml.math.allclose(
fn([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]), [0.01, 0.04, 0.09, 0.16, 0.25, 0.36, 0.49]
)

wires_groups = [
[qml.X(0), qml.Z(1)],
[qml.Y(0), non_pauli_obs[0]],
[qml.X(0) @ qml.Y(1)],
[qml.Y(0) @ qml.Z(1)],
[non_pauli_obs[1]],
]

# wires grouping produces [[0, 2], [1, 5], [3], [4], [6]]
expected_tapes_wires_grouping = [
qml.tape.QuantumScript([], [qml.expval(o) for o in group], shots=100)
for group in wires_groups
]

tapes, fn = split_non_commuting(tape)
for actual_tape, expected_tape in zip(tapes, expected_tapes_wires_grouping):
assert qml.equal(actual_tape, expected_tape)
assert qml.math.allclose(
fn([[0.1, 0.2], [0.3, 0.6], 0.4, 0.5, 0.7]), [0.01, 0.06, 0.06, 0.16, 0.25, 0.36, 0.49]
)


class TestIntegration:
"""Tests the ``split_non_commuting`` transform performed on a QNode"""
Expand Down Expand Up @@ -816,6 +870,25 @@ def circuit():
assert _dev.tracker.totals == {}
assert qml.math.allclose(res, [1.5, 2.5])

def test_non_pauli_obs_in_circuit(self):
"""Tests that the tape is executed correctly with non-pauli observables"""

_dev = qml.device("default.qubit", wires=1)

@qml.transforms.split_non_commuting
@qml.qnode(_dev)
def circuit():
qml.Hadamard(0)
return (
qml.expval(qml.Projector([0], wires=[0])),
qml.expval(qml.Projector([1], wires=[0])),
)

with _dev.tracker:
res = circuit()
assert _dev.tracker.totals["simulations"] == 2
assert qml.math.allclose(res, [0.5, 0.5])


expected_grad_param_0 = [
0.125,
Expand Down

0 comments on commit 8a51dda

Please sign in to comment.