Skip to content

Commit

Permalink
Legacy device API handles Prod observables (#5475)
Browse files Browse the repository at this point in the history
**Context:**
The legacy device API has special handling of `Tensor` and `Sum` as
observables, but `Prod` is not covered by that.

**Description of the Change:**
- `Prod` observables are expanded if they simplify to a `Sum`
- In `check_validity`, the operands of `Prod` observables are checked.

**Possible Drawbacks:**
`simplify` is called on `Prod` observables many times, which might be
inefficient

**Related Shortcut Story:**
[sc-60584]

---------

Co-authored-by: Korbinian Kottmann <43949391+Qottmann@users.noreply.github.com>
  • Loading branch information
astralcai and Qottmann authored Apr 12, 2024
1 parent a04d1a2 commit 8bdf226
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 66 deletions.
25 changes: 22 additions & 3 deletions pennylane/_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
)

from pennylane.operation import Observable, Operation, Tensor, Operator, StatePrepBase
from pennylane.ops import Hamiltonian, Sum, LinearCombination
from pennylane.ops import Hamiltonian, Sum, LinearCombination, Prod
from pennylane.tape import QuantumScript, QuantumTape, expand_tape_state_prep
from pennylane.wires import WireError, Wires
from pennylane.queuing import QueuingManager
Expand Down Expand Up @@ -744,7 +744,6 @@ def batch_transform(self, circuit: QuantumTape):
to be applied to the list of evaluated circuit results.
"""
supports_hamiltonian = self.supports_observable("Hamiltonian")

supports_sum = self.supports_observable("Sum")
finite_shots = self.shots is not None
grouping_known = all(
Expand All @@ -759,7 +758,12 @@ def batch_transform(self, circuit: QuantumTape):
isinstance(obs, (Hamiltonian, LinearCombination)) for obs in circuit.observables
)
expval_sum_in_obs = any(
isinstance(m.obs, Sum) and isinstance(m, ExpectationMP) for m in circuit.measurements
(
isinstance(m.obs, Sum)
or (isinstance(m.obs, Prod) and isinstance(m.obs.simplify(), Sum))
)
and isinstance(m, ExpectationMP)
for m in circuit.measurements
)

is_shadow = any(isinstance(m, ShadowExpvalMP) for m in circuit.measurements)
Expand Down Expand Up @@ -1007,6 +1011,21 @@ def check_validity(self, queue, observables):
raise DeviceError(
f"Observable {i.name} not supported on device {self.short_name}"
)

elif isinstance(o, qml.ops.Prod):

supports_prod = self.supports_observable(o.name)
if not supports_prod:
raise DeviceError(f"Observable Prod not supported on device {self.short_name}")

simplified_op = o.simplify()
if isinstance(simplified_op, qml.ops.Prod):
for i in o.simplify().operands:
if not self.supports_observable(i.name):
raise DeviceError(
f"Observable {i.name} not supported on device {self.short_name}"
)

else:
observable_name = o.name

Expand Down
2 changes: 2 additions & 0 deletions pennylane/devices/tests/test_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,10 @@ class TestHamiltonianSupport:
"""Separate test to ensure that the device can differentiate Hamiltonian observables."""

@pytest.mark.parametrize("ham_constructor", [qml.ops.Hamiltonian, qml.ops.LinearCombination])
@pytest.mark.filterwarnings("ignore::pennylane.PennyLaneDeprecationWarning")
def test_hamiltonian_diff(self, ham_constructor, device_kwargs, tol):
"""Tests a simple VQE gradient using parameter-shift rules."""

device_kwargs["wires"] = 1
dev = qml.device(**device_kwargs)
coeffs = np.array([-0.05, 0.17])
Expand Down
4 changes: 3 additions & 1 deletion pennylane/transforms/hamiltonian_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import pennylane as qml
from pennylane.measurements import ExpectationMP, MeasurementProcess
from pennylane.ops import SProd, Sum
from pennylane.ops import SProd, Sum, Prod
from pennylane.tape import QuantumScript, QuantumTape
from pennylane.transforms import transform

Expand Down Expand Up @@ -341,6 +341,8 @@ def sum_expand(tape: QuantumTape, group: bool = True) -> (Sequence[QuantumTape],
idxs_coeffs_dict = {} # {m_hash: [(location_idx, coeff)]}
for idx, m in enumerate(tape.measurements):
obs = m.obs
if isinstance(obs, Prod) and isinstance(m, ExpectationMP):
obs = obs.simplify()
if isinstance(obs, Sum) and isinstance(m, ExpectationMP):
for summand in obs.operands:
coeff = 1
Expand Down
141 changes: 79 additions & 62 deletions tests/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,36 +30,6 @@
# pylint: disable=abstract-class-instantiated, no-self-use, redefined-outer-name, invalid-name, missing-function-docstring


@pytest.fixture(scope="function")
def mock_device_with_operations(monkeypatch):
"""A function to create a mock device with non-empty operations"""
with monkeypatch.context() as m:
m.setattr(Device, "__abstractmethods__", frozenset())
m.setattr(Device, "operations", mock_device_paulis)
m.setattr(Device, "observables", mock_device_paulis)
m.setattr(Device, "short_name", "MockDevice")

def get_device(wires=1):
return Device(wires=wires)

yield get_device


@pytest.fixture(scope="function")
def mock_device_with_observables(monkeypatch):
"""A function to create a mock device with non-empty observables"""
with monkeypatch.context() as m:
m.setattr(Device, "__abstractmethods__", frozenset())
m.setattr(Device, "operations", mock_device_paulis)
m.setattr(Device, "observables", mock_device_paulis)
m.setattr(Device, "short_name", "MockDevice")

def get_device(wires=1):
return Device(wires=wires)

yield get_device


@pytest.fixture(scope="function")
def mock_device_with_identity(monkeypatch):
"""A function to create a mock device with non-empty observables"""
Expand Down Expand Up @@ -203,19 +173,15 @@ def get_device(wires=1):


@pytest.fixture(scope="function")
def mock_device_arbitrary_wires(monkeypatch):
def mock_device_supporting_prod(monkeypatch):
with monkeypatch.context() as m:
m.setattr(Device, "__abstractmethods__", frozenset())
m.setattr(Device, "_capabilities", mock_device_capabilities)
m.setattr(Device, "operations", ["PauliY", "RX", "Rot"])
m.setattr(Device, "observables", ["PauliZ"])
m.setattr(Device, "operations", ["PauliX", "PauliZ"])
m.setattr(Device, "observables", ["PauliX", "PauliZ", "Prod"])
m.setattr(Device, "short_name", "MockDevice")
m.setattr(Device, "expval", lambda self, x, y, z: 0)
m.setattr(Device, "var", lambda self, x, y, z: 0)
m.setattr(Device, "sample", lambda self, x, y, z: 0)
m.setattr(Device, "apply", lambda self, x, y, z: None)

def get_device(wires):
def get_device(wires=1):
return Device(wires=wires)

yield get_device
Expand Down Expand Up @@ -245,22 +211,22 @@ class TestDeviceSupportedLogic:

# pylint: disable=no-self-use, redefined-outer-name

def test_supports_operation_argument_types(self, mock_device_with_operations):
def test_supports_operation_argument_types(self, mock_device_supporting_paulis):
"""Checks that device.supports_operations returns the correct result
when passed both string and Operation class arguments"""

dev = mock_device_with_operations()
dev = mock_device_supporting_paulis()

assert dev.supports_operation("PauliX")
assert dev.supports_operation(qml.PauliX)

assert not dev.supports_operation("S")
assert not dev.supports_operation(qml.CNOT)

def test_supports_observable_argument_types(self, mock_device_with_observables):
def test_supports_observable_argument_types(self, mock_device_supporting_paulis):
"""Checks that device.supports_observable returns the correct result
when passed both string and Operation class arguments"""
dev = mock_device_with_observables()
dev = mock_device_supporting_paulis()

assert dev.supports_observable("PauliX")
assert dev.supports_observable(qml.PauliX)
Expand Down Expand Up @@ -309,14 +275,14 @@ class TestInternalFunctions: # pylint:disable=too-many-public-methods
"""Test the internal functions of the abstract Device class"""

# pylint: disable=unnecessary-dunder-call
def test_repr(self, mock_device_with_operations):
def test_repr(self, mock_device_supporting_paulis):
"""Tests the __repr__ function"""
dev = mock_device_with_operations()
dev = mock_device_supporting_paulis()
assert "<Device device (wires=1, shots=1000) at " in dev.__repr__()

def test_str(self, mock_device_with_operations):
def test_str(self, mock_device_supporting_paulis):
"""Tests the __str__ function"""
dev = mock_device_with_operations()
dev = mock_device_supporting_paulis()
string = str(dev)
assert "Short name: MockDevice" in string
assert "Package: pennylane" in string
Expand All @@ -340,6 +306,43 @@ def test_check_validity_on_valid_queue(self, mock_device_supporting_paulis):
# Raises an error if queue or observables are invalid
dev.check_validity(queue, observables)

def test_check_validity_containing_prod(self, mock_device_supporting_prod):
"""Tests that the function Device.check_validity works with Prod"""

dev = mock_device_supporting_prod()

queue = [
qml.PauliX(wires=0),
qml.PauliZ(wires=1),
]

observables = [
qml.expval(qml.PauliX(0) @ qml.PauliZ(1)),
qml.expval(qml.PauliZ(0) @ (qml.PauliX(1) @ qml.PauliZ(2))),
]

dev.check_validity(queue, observables)

def test_prod_containing_unsupported_nested_observables(self, mock_device_supporting_prod):
"""Tests that the observables nested within Prod are checked for validity"""

dev = mock_device_supporting_prod()

queue = [
qml.PauliX(wires=0),
qml.PauliZ(wires=1),
]

unsupported_nested_observables = [
qml.expval(qml.PauliZ(0) @ (qml.PauliX(1) @ qml.PauliY(2)))
]

with pytest.raises(
DeviceError,
match="Observable PauliY not supported",
):
dev.check_validity(queue, unsupported_nested_observables)

@pytest.mark.usefixtures("use_legacy_opmath")
def test_check_validity_on_tensor_support_legacy_opmath(self, mock_device_supporting_paulis):
"""Tests the function Device.check_validity with tensor support capability"""
Expand Down Expand Up @@ -429,9 +432,9 @@ def test_check_validity_on_invalid_observable(self, mock_device_supporting_pauli
with pytest.raises(DeviceError, match="Observable Hadamard not supported on device"):
dev.check_validity(queue, observables)

def test_check_validity_on_projector_as_operation(self, mock_device_with_operations):
def test_check_validity_on_projector_as_operation(self, mock_device_supporting_paulis):
"""Test that an error is raised if the operation queue contains qml.Projector"""
dev = mock_device_with_operations(wires=1)
dev = mock_device_supporting_paulis(wires=1)

queue = [qml.PauliX(0), qml.Projector([0], wires=0), qml.PauliZ(0)]
observables = []
Expand Down Expand Up @@ -592,8 +595,8 @@ def test_conditional_ops_unsupported_error(self, mock_device_with_paulis_and_met
(Wires([0]), Wires([0]), Wires([0])),
],
)
def test_order_wires(self, wires, subset, expected_subset, mock_device_arbitrary_wires):
dev = mock_device_arbitrary_wires(wires=wires)
def test_order_wires(self, wires, subset, expected_subset, mock_device):
dev = mock_device(wires=wires)
ordered_subset = dev.order_wires(subset_wires=subset)
assert ordered_subset == expected_subset

Expand All @@ -606,8 +609,8 @@ def test_order_wires(self, wires, subset, expected_subset, mock_device_arbitrary
(Wires([0]), Wires([2])),
],
)
def test_order_wires_raises_value_error(self, wires, subset, mock_device_arbitrary_wires):
dev = mock_device_arbitrary_wires(wires=wires)
def test_order_wires_raises_value_error(self, wires, subset, mock_device):
dev = mock_device(wires=wires)
with pytest.raises(ValueError, match="Could not find some or all subset wires"):
_ = dev.order_wires(subset_wires=subset)

Expand Down Expand Up @@ -658,11 +661,11 @@ def test_default_expand_with_initial_state(self, op, decomp):
assert new_tape.batch_size == tape.batch_size
assert new_tape.output_dim == tape.output_dim

def test_default_expand_fn_with_invalid_op(self, mock_device_with_operations, recwarn):
def test_default_expand_fn_with_invalid_op(self, mock_device_supporting_paulis, recwarn):
"""Test that default_expand_fn works with an invalid op and some measurement."""
invalid_tape = qml.tape.QuantumScript([qml.S(0)], [qml.expval(qml.PauliZ(0))])
expected_tape = qml.tape.QuantumScript([qml.RZ(np.pi / 2, 0)], [qml.expval(qml.PauliZ(0))])
dev = mock_device_with_operations(wires=1)
dev = mock_device_supporting_paulis(wires=1)
expanded_tape = dev.expand_fn(invalid_tape, max_expansion=3)
assert qml.equal(expanded_tape, expected_tape)
assert len(recwarn) == 0
Expand Down Expand Up @@ -799,33 +802,33 @@ def test_unsupported_operations_raise_error(self, mock_device_with_paulis_and_me
with pytest.raises(DeviceError, match="Gate Hadamard not supported on device"):
dev.execute(queue, observables)

def test_execute_obs_probs(self, mock_device_with_observables):
def test_execute_obs_probs(self, mock_device_supporting_paulis):
"""Tests that the execute function raises an error if probabilities are
not supported by the device"""
dev = mock_device_with_observables()
dev = mock_device_supporting_paulis()
obs = qml.probs(op=qml.PauliZ(0))
with pytest.raises(NotImplementedError):
dev.execute([], [obs])

def test_var(self, mock_device_with_observables):
def test_var(self, mock_device_supporting_paulis):
"""Tests that the variance method are not implemented by the device by
default"""
dev = mock_device_with_observables()
dev = mock_device_supporting_paulis()
with pytest.raises(NotImplementedError):
dev.var(qml.PauliZ, 0, [])

def test_sample(self, mock_device_with_observables):
def test_sample(self, mock_device_supporting_paulis):
"""Tests that the sample method are not implemented by the device by
default"""
dev = mock_device_with_observables()
dev = mock_device_supporting_paulis()
with pytest.raises(NotImplementedError):
dev.sample(qml.PauliZ, 0, [])

@pytest.mark.parametrize("wires", [None, []])
def test_probability(self, mock_device_with_observables, wires):
def test_probability(self, mock_device_supporting_paulis, wires):
"""Tests that the probability method are not implemented by the device
by default"""
dev = mock_device_with_observables()
dev = mock_device_supporting_paulis()
with pytest.raises(NotImplementedError):
dev.probability(wires=wires)

Expand Down Expand Up @@ -1204,3 +1207,17 @@ def test_batch_transform_expands_not_supported_sums(self, mocker):

assert len(new_qscripts) == 2
spy.assert_called()

def test_batch_transform_expands_prod_containing_sums(self, mocker):
"""Tests that batch_transform expands a Prod with a nested Sum"""

H = qml.prod(qml.PauliX(0), qml.sum(qml.PauliY(0), qml.PauliZ(0)))
qs = qml.tape.QuantumScript(measurements=[qml.expval(H)])
spy = mocker.spy(qml.transforms, "sum_expand")

dev = self.SomeDevice()
dev.supports_observable = lambda *args, **kwargs: False
new_qscripts, _ = dev.batch_transform(qs)

assert len(new_qscripts) == 2
spy.assert_called()

0 comments on commit 8bdf226

Please sign in to comment.