Skip to content

Commit

Permalink
support trainable Sum observables (#4251)
Browse files Browse the repository at this point in the history
* support trainable Sum observables (analytic only)

* just use pre-rotated state; use super with finite shots

* fix tests; add test for trainable Sum coeffs

* changelog

* use overlapping wires to prove they work

* add hacky gradient support

* Revert "add hacky gradient support"

This reverts commit 09009c8.

* set interface=None to keep test behaviour
  • Loading branch information
timmysilv authored Jun 16, 2023
1 parent ffafa34 commit 723cdc8
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 15 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,9 @@
An if conditional was intended to prevent divide by zero errors but the division was by the sine of the argument so any multiple of $\pi$ should trigger the conditional, but it was only checking if the argument was 0. Example: `qml.Rot(2.3, 2.3, 2.3)`
[(#4210)](https://github.com/PennyLaneAI/pennylane/pull/4210)
* Allow for `Sum` observables with trainable parameters.
[(#4251)](https://github.com/PennyLaneAI/pennylane/pull/4251)
<h3>Contributors ✍️</h3>
This release contains contributions from (in alphabetical order):
Expand Down
9 changes: 9 additions & 0 deletions pennylane/devices/default_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,12 @@

import pennylane as qml
from pennylane import BasisState, DeviceError, QubitDevice, QubitStateVector, Snapshot
from pennylane.devices.qubit import measure
from pennylane.operation import Operation
from pennylane.ops import Sum
from pennylane.ops.qubit.attributes import diagonal_in_z_basis
from pennylane.pulse import ParametrizedEvolution
from pennylane.measurements import ExpectationMP
from pennylane.typing import TensorLike
from pennylane.wires import WireError

Expand Down Expand Up @@ -567,6 +570,12 @@ def expval(self, observable, shot_range=None, bin_size=None):
Hamiltonian is not NumPy or Autograd
"""
# intercept Sums
if isinstance(observable, Sum) and not self.shots:
return measure(
ExpectationMP(observable.map_wires(self.wire_map)), self._pre_rotated_state
)

# intercept other Hamiltonians
# TODO: Ideally, this logic should not live in the Device, but be moved
# to a component that can be re-used by devices as needed.
Expand Down
24 changes: 15 additions & 9 deletions pennylane/devices/qubit/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,18 +135,24 @@ def get_measurement_function(
if measurementprocess.obs.name == "SparseHamiltonian":
return csr_dot_products

if isinstance(measurementprocess.obs, Hamiltonian) or (
isinstance(measurementprocess.obs, Sum)
and measurementprocess.obs.has_overlapping_wires
and len(measurementprocess.obs.wires) > 7
):
# Use tensor contraction for `Sum` expectation values with non-commuting summands
# and 8 or more wires as it's faster than using eigenvalues.

backprop_mode = math.get_interface(state) != "numpy"
if isinstance(measurementprocess.obs, Hamiltonian):
# need to work out thresholds for when its faster to use "backprop mode" measurements
backprop_mode = math.get_interface(state) != "numpy"
return sum_of_terms_method if backprop_mode else csr_dot_products

if isinstance(measurementprocess.obs, Sum):
if backprop_mode:
# always use sum_of_terms_method for Sum observables in backprop mode
return sum_of_terms_method
if (
measurementprocess.obs.has_overlapping_wires
and len(measurementprocess.obs.wires) > 7
):
# Use tensor contraction for `Sum` expectation values with non-commuting summands
# and 8 or more wires as it's faster than using eigenvalues.

return csr_dot_products

if measurementprocess.obs is None or measurementprocess.obs.has_diagonalizing_gates:
return state_diagonalizing_gates

Expand Down
4 changes: 1 addition & 3 deletions pennylane/measurements/measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
from enum import Enum
from typing import Sequence, Tuple, Optional

import numpy as np

import pennylane as qml
from pennylane.operation import Operator
from pennylane.wires import Wires
Expand Down Expand Up @@ -154,7 +152,7 @@ def __init__(
if obs is not None:
raise ValueError("Cannot set the eigenvalues if an observable is provided.")

self._eigvals = np.array(eigvals)
self._eigvals = qml.math.asarray(eigvals)

# TODO: remove the following lines once devices
# have been refactored to accept and understand receiving
Expand Down
5 changes: 4 additions & 1 deletion pennylane/ops/op_math/sprod.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,10 @@ def eigvals(self):
Returns:
array: array containing the eigenvalues of the operator.
"""
return self.scalar * self.base.eigvals()
base_eigs = self.base.eigvals()
if qml.math.get_interface(self.scalar) == "torch" and self.scalar.requires_grad:
base_eigs = qml.math.convert_like(base_eigs, self.scalar)
return self.scalar * base_eigs

def sparse_matrix(self, wire_order=None):
"""Computes, by default, a `scipy.sparse.csr_matrix` representation of this Tensor.
Expand Down
69 changes: 69 additions & 0 deletions tests/devices/test_default_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2361,6 +2361,75 @@ def test_Hamiltonian_filtered_from_rotations(self, mocker):
assert qml.equal(call_args.measurements[0], qml.expval(qml.PauliX(0)))


class TestSumSupport:
"""Tests for custom Sum support in DefaultQubit."""

expected_grad = [-np.sin(1.3), np.cos(1.3)]

@staticmethod
def circuit(y, z):
qml.RX(1.3, 0)
return qml.expval(
qml.sum(
qml.s_prod(y, qml.PauliY(0)),
qml.s_prod(z, qml.PauliZ(0)),
)
)

def test_super_expval_not_called(self, mocker):
"""Tests basic expval result, and ensures QubitDevice.expval is not called."""
dev = qml.device("default.qubit", wires=1)
spy = mocker.spy(qml.QubitDevice, "expval")
obs = qml.sum(qml.s_prod(0.1, qml.PauliX(0)), qml.s_prod(0.2, qml.PauliZ(0)))
assert np.isclose(dev.expval(obs), 0.2)
spy.assert_not_called()

@pytest.mark.autograd
def test_trainable_autograd(self):
"""Tests that coeffs passed to a sum are trainable with autograd."""
dev = qml.device("default.qubit", wires=1)
qnode = qml.QNode(self.circuit, dev, interface="autograd")
y, z = np.array([1.1, 2.2])
actual = qml.grad(qnode)(y, z)
assert np.allclose(actual, self.expected_grad)

@pytest.mark.torch
def test_trainable_torch(self):
"""Tests that coeffs passed to a sum are trainable with torch."""
import torch

dev = qml.device("default.qubit", wires=1)
qnode = qml.QNode(self.circuit, dev, interface="torch")
y, z = torch.tensor(1.1, requires_grad=True), torch.tensor(2.2, requires_grad=True)
qnode(y, z).backward()
actual = [y.grad, z.grad]
assert np.allclose(actual, self.expected_grad)

@pytest.mark.tf
def test_trainable_tf(self):
"""Tests that coeffs passed to a sum are trainable with tf."""
import tensorflow as tf

dev = qml.device("default.qubit", wires=1)
qnode = qml.QNode(self.circuit, dev, interface="tensorflow")
y, z = tf.Variable(1.1, dtype=tf.float64), tf.Variable(2.2, dtype=tf.float64)
with tf.GradientTape() as tape:
res = qnode(y, z)
actual = tape.gradient(res, [y, z])
assert np.allclose(actual, self.expected_grad)

@pytest.mark.jax
def test_trainable_jax(self):
"""Tests that coeffs passed to a sum are trainable with jax."""
import jax

dev = qml.device("default.qubit", wires=1)
qnode = qml.QNode(self.circuit, dev, interface="jax")
y, z = jax.numpy.array([1.1, 2.2])
actual = jax.grad(qnode, argnums=[0, 1])(y, z)
assert np.allclose(actual, self.expected_grad)


class TestGetBatchSize:
"""Tests for the helper method ``_get_batch_size`` of ``QubitDevice``."""

Expand Down
4 changes: 2 additions & 2 deletions tests/ops/op_math/test_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -1067,7 +1067,7 @@ def test_non_hermitian_op_in_measurement_process(self):
dev = qml.device("default.qubit", wires=wires)
sum_op = Sum(Prod(qml.RX(1.23, wires=0), qml.Identity(wires=1)), qml.Identity(wires=1))

@qml.qnode(dev)
@qml.qnode(dev, interface=None)
def my_circ():
qml.PauliX(0)
return qml.expval(sum_op)
Expand All @@ -1079,7 +1079,7 @@ def test_params_can_be_considered_trainable(self):
"""Tests that the parameters of a Sum are considered trainable."""
dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev)
@qml.qnode(dev, interface=None)
def circuit():
return qml.expval(Sum(qml.RX(1.1, 0), qml.RY(qnp.array(2.2), 0)))

Expand Down

0 comments on commit 723cdc8

Please sign in to comment.