Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lightning qubit uses parameter shift if metric tensor applied #5624

Merged
merged 16 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/releases/changelog-0.36.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,10 @@

<h3>Bug fixes 🐛</h3>

* Patches the QNode so that the parameter-shift method is selected as the best differentiation
  method on `lightning.qubit` whenever `qml.metric_tensor` is present in the transform program.
trbromley marked this conversation as resolved.
Show resolved Hide resolved
[(#5624)](https://github.com/PennyLaneAI/pennylane/pull/5624)

* Improves the error message raised when setting shots on the new device interface, or when trying
  to access a property that no longer exists.
[(#5616)](https://github.com/PennyLaneAI/pennylane/pull/5616)
Expand Down
50 changes: 20 additions & 30 deletions pennylane/gradients/metric_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,19 +469,14 @@ def _metric_tensor_cov_matrix(tape, argnum, diag_approx): # pylint: disable=too
# Create a quantum tape with all operations
# prior to the parametrized layer, and the rotations
# to measure in the basis of the parametrized layer generators.
with qml.queuing.AnnotatedQueue() as layer_q:
for op in queue:
# TODO: Maybe there are gates that do not affect the
# generators of interest and thus need not be applied.
qml.apply(op)
# TODO: Maybe there are gates that do not affect the
# generators of interest and thus need not be applied.

for o, param_in_argnum in zip(layer_obs, in_argnum_list[-1]):
if param_in_argnum:
o.diagonalizing_gates()

qml.probs(wires=tape.wires)
for o, param_in_argnum in zip(layer_obs, in_argnum_list[-1]):
if param_in_argnum:
queue.extend(o.diagonalizing_gates())

layer_tape = qml.tape.QuantumScript.from_queue(layer_q)
layer_tape = qml.tape.QuantumScript(queue, [qml.probs(wires=tape.wires)], shots=tape.shots)
metric_tensor_tapes.append(layer_tape)

def processing_fn(probs):
Expand Down Expand Up @@ -573,7 +568,7 @@ def _get_gen_op(op, allow_nonunitary, aux_wire):
) from e


def _get_first_term_tapes(layer_i, layer_j, allow_nonunitary, aux_wire):
def _get_first_term_tapes(layer_i, layer_j, allow_nonunitary, aux_wire, shots):
r"""Obtain the tapes for the first term of all tensor entries
belonging to an off-diagonal block.

Expand Down Expand Up @@ -610,23 +605,16 @@ def _get_first_term_tapes(layer_i, layer_j, allow_nonunitary, aux_wire):
for diffed_op_j, par_idx_j in zip(layer_j.ops, layer_j.param_inds):
gen_op_j = _get_gen_op(WrappedObj(diffed_op_j), allow_nonunitary, aux_wire)

with qml.queuing.AnnotatedQueue() as q:
# Initialize auxiliary wire
qml.Hadamard(wires=aux_wire)
# Apply backward cone of first layer
for op in layer_i.pre_ops:
qml.apply(op)
# Controlled-generator operation of first diff'ed op
qml.apply(gen_op_i)
# Apply first layer and operations between layers
for op in ops_between_cgens:
qml.apply(op)
# Controlled-generator operation of second diff'ed op
qml.apply(gen_op_j)
# Measure X on auxiliary wire
qml.expval(qml.X(aux_wire))

tapes.append(qml.tape.QuantumScript.from_queue(q))
ops = [
qml.Hadamard(wires=aux_wire),
*layer_i.pre_ops,
gen_op_i,
*ops_between_cgens,
gen_op_j,
]
new_tape = qml.tape.QuantumScript(ops, [qml.expval(qml.X(aux_wire))], shots=shots)

tapes.append(new_tape)
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
# Memorize to which metric entry this tape belongs
ids.append((par_idx_i, par_idx_j))

Expand Down Expand Up @@ -707,7 +695,9 @@ def _metric_tensor_hadamard(
block_sizes.append(len(layer_i.param_inds))

for layer_j in layers[idx_i + 1 :]:
_tapes, _ids = _get_first_term_tapes(layer_i, layer_j, allow_nonunitary, aux_wire)
_tapes, _ids = _get_first_term_tapes(
layer_i, layer_j, allow_nonunitary, aux_wire, shots=tape.shots
)
first_term_tapes.extend(_tapes)
ids.extend(_ids)

Expand Down
14 changes: 12 additions & 2 deletions pennylane/workflow/qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,9 +527,9 @@ def __init__(
self.gradient_kwargs = {}
self._tape_cached = False

self._transform_program = qml.transforms.core.TransformProgram()
self._update_gradient_fn()
functools.update_wrapper(self, func)
self._transform_program = qml.transforms.core.TransformProgram()

def __copy__(self):
copied_qnode = QNode.__new__(QNode)
Expand Down Expand Up @@ -592,8 +592,17 @@ def _update_gradient_fn(self, shots=None, tape=None):
return
if tape is None and shots:
tape = qml.tape.QuantumScript([], [], shots=shots)

diff_method = self.diff_method
if (
self.device.name == "lightning.qubit"
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
and qml.metric_tensor in self.transform_program
and self.diff_method == "best"
):
diff_method = "parameter-shift"

self.gradient_fn, self.gradient_kwargs, self.device = self.get_gradient_fn(
self._original_device, self.interface, self.diff_method, tape=tape
self._original_device, self.interface, diff_method, tape=tape
)
self.gradient_kwargs.update(self._user_gradient_kwargs or {})

Expand Down Expand Up @@ -714,6 +723,7 @@ def get_best_method(device, interface, tape=None):
"""
config = _make_execution_config(None, "best")
if isinstance(device, qml.devices.Device):

if device.supports_derivatives(config, circuit=tape):
new_config = device.preprocess(config)[1]
return new_config.gradient_method, {}, device
Expand Down
27 changes: 17 additions & 10 deletions tests/gradients/core/test_metric_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1125,9 +1125,11 @@ class TestFullMetricTensor:
@pytest.mark.autograd
@pytest.mark.parametrize("ansatz, params", zip(fubini_ansatze, fubini_params))
@pytest.mark.parametrize("interface", ["auto", "autograd"])
def test_correct_output_autograd(self, ansatz, params, interface):
@pytest.mark.parametrize("dev_name", ("default.qubit", "lightning.qubit"))
def test_correct_output_autograd(self, dev_name, ansatz, params, interface):

expected = autodiff_metric_tensor(ansatz, self.num_wires)(*params)
dev = qml.device("default.qubit.autograd", wires=self.num_wires + 1)
dev = qml.device(dev_name, wires=self.num_wires + 1)

@qml.qnode(dev, interface=interface)
def circuit(*params):
Expand All @@ -1140,19 +1142,21 @@ def circuit(*params):
if isinstance(mt, tuple):
assert all(qml.math.allclose(_mt, _exp) for _mt, _exp in zip(mt, expected))
else:
print(mt - expected)
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
assert qml.math.allclose(mt, expected)

@pytest.mark.jax
@pytest.mark.parametrize("ansatz, params", zip(fubini_ansatze, fubini_params))
@pytest.mark.parametrize("interface", ["auto", "jax"])
def test_correct_output_jax(self, ansatz, params, interface):
@pytest.mark.parametrize("dev_name", ("default.qubit", "lightning.qubit"))
def test_correct_output_jax(self, dev_name, ansatz, params, interface):
import jax
from jax import numpy as jnp

jax.config.update("jax_enable_x64", True)

expected = autodiff_metric_tensor(ansatz, self.num_wires)(*params)
dev = qml.device("default.qubit.jax", wires=self.num_wires + 1)
dev = qml.device(dev_name, wires=self.num_wires + 1)

params = tuple(jnp.array(p) for p in params)

Expand All @@ -1176,10 +1180,11 @@ def circuit(*params):
@pytest.mark.jax
@pytest.mark.parametrize("ansatz, params", zip(fubini_ansatze, fubini_params))
@pytest.mark.parametrize("interface", ["auto", "jax"])
def test_jax_argnum_error(self, ansatz, params, interface):
@pytest.mark.parametrize("dev_name", ("default.qubit", "lightning.qubit"))
def test_jax_argnum_error(self, dev_name, ansatz, params, interface):
from jax import numpy as jnp

dev = qml.device("default.qubit.jax", wires=self.num_wires + 1)
dev = qml.device(dev_name, wires=self.num_wires + 1)

params = tuple(jnp.array(p) for p in params)

Expand All @@ -1198,11 +1203,12 @@ def circuit(*params):
@pytest.mark.torch
@pytest.mark.parametrize("ansatz, params", zip(fubini_ansatze, fubini_params))
@pytest.mark.parametrize("interface", ["auto", "torch"])
def test_correct_output_torch(self, ansatz, params, interface):
@pytest.mark.parametrize("dev_name", ("default.qubit", "lightning.qubit"))
def test_correct_output_torch(self, dev_name, ansatz, params, interface):
import torch

expected = autodiff_metric_tensor(ansatz, self.num_wires)(*params)
dev = qml.device("default.qubit.torch", wires=self.num_wires + 1)
dev = qml.device(dev_name, wires=self.num_wires + 1)

params = tuple(torch.tensor(p, dtype=torch.float64, requires_grad=True) for p in params)

Expand All @@ -1222,11 +1228,12 @@ def circuit(*params):
@pytest.mark.tf
@pytest.mark.parametrize("ansatz, params", zip(fubini_ansatze, fubini_params))
@pytest.mark.parametrize("interface", ["auto", "tf"])
def test_correct_output_tf(self, ansatz, params, interface):
@pytest.mark.parametrize("dev_name", ("default.qubit", "lightning.qubit"))
def test_correct_output_tf(self, dev_name, ansatz, params, interface):
import tensorflow as tf

expected = autodiff_metric_tensor(ansatz, self.num_wires)(*params)
dev = qml.device("default.qubit.tf", wires=self.num_wires + 1)
dev = qml.device(dev_name, wires=self.num_wires + 1)

params = tuple(tf.Variable(p, dtype=tf.float64) for p in params)

Expand Down
Loading