diff --git a/doc/releases/changelog-0.36.0.md b/doc/releases/changelog-0.36.0.md
index 17d8775df55..2798cba12aa 100644
--- a/doc/releases/changelog-0.36.0.md
+++ b/doc/releases/changelog-0.36.0.md
@@ -722,6 +722,10 @@
Bug fixes 🐛
+* Patched the `QNode` so that `diff_method="best"` resolves to the parameter-shift method
+  on `lightning.qubit` whenever `qml.metric_tensor` is in the transform program.
+ [(#5624)](https://github.com/PennyLaneAI/pennylane/pull/5624)
+
* Stopped printing the ID of `qcut.MeasureNode` and `qcut.PrepareNode` in tape drawing.
[(#5613)](https://github.com/PennyLaneAI/pennylane/pull/5613)
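For context, a minimal sketch of the workflow the new changelog entry refers to (the device, circuit, and parameters here are illustrative, not taken from the PR):

```python
import pennylane as qml
from pennylane import numpy as np

dev = qml.device("lightning.qubit", wires=2)  # assumes lightning is installed

@qml.qnode(dev, diff_method="best")
def circuit(weights):
    qml.RX(weights[0], wires=0)
    qml.RY(weights[1], wires=1)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.Z(0))

weights = np.array([0.1, 0.2], requires_grad=True)

# Applying the transform records qml.metric_tensor in the QNode's transform
# program; with this patch, diff_method="best" then resolves to
# parameter-shift instead of a lightning device gradient.
mt = qml.metric_tensor(circuit, approx="block-diag")(weights)
print(np.round(mt, 8))
```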
diff --git a/pennylane/gradients/metric_tensor.py b/pennylane/gradients/metric_tensor.py
index 62e16562475..fbde615c029 100644
--- a/pennylane/gradients/metric_tensor.py
+++ b/pennylane/gradients/metric_tensor.py
@@ -469,19 +469,14 @@ def _metric_tensor_cov_matrix(tape, argnum, diag_approx): # pylint: disable=too
# Create a quantum tape with all operations
# prior to the parametrized layer, and the rotations
# to measure in the basis of the parametrized layer generators.
- with qml.queuing.AnnotatedQueue() as layer_q:
- for op in queue:
- # TODO: Maybe there are gates that do not affect the
- # generators of interest and thus need not be applied.
- qml.apply(op)
+ # TODO: Maybe there are gates that do not affect the
+ # generators of interest and thus need not be applied.
- for o, param_in_argnum in zip(layer_obs, in_argnum_list[-1]):
- if param_in_argnum:
- o.diagonalizing_gates()
-
- qml.probs(wires=tape.wires)
+ for o, param_in_argnum in zip(layer_obs, in_argnum_list[-1]):
+ if param_in_argnum:
+ queue.extend(o.diagonalizing_gates())
- layer_tape = qml.tape.QuantumScript.from_queue(layer_q)
+ layer_tape = qml.tape.QuantumScript(queue, [qml.probs(wires=tape.wires)], shots=tape.shots)
metric_tensor_tapes.append(layer_tape)
def processing_fn(probs):
@@ -573,7 +568,7 @@ def _get_gen_op(op, allow_nonunitary, aux_wire):
) from e
-def _get_first_term_tapes(layer_i, layer_j, allow_nonunitary, aux_wire):
+def _get_first_term_tapes(layer_i, layer_j, allow_nonunitary, aux_wire, shots):
r"""Obtain the tapes for the first term of all tensor entries
belonging to an off-diagonal block.
@@ -610,23 +605,16 @@ def _get_first_term_tapes(layer_i, layer_j, allow_nonunitary, aux_wire):
for diffed_op_j, par_idx_j in zip(layer_j.ops, layer_j.param_inds):
gen_op_j = _get_gen_op(WrappedObj(diffed_op_j), allow_nonunitary, aux_wire)
- with qml.queuing.AnnotatedQueue() as q:
- # Initialize auxiliary wire
- qml.Hadamard(wires=aux_wire)
- # Apply backward cone of first layer
- for op in layer_i.pre_ops:
- qml.apply(op)
- # Controlled-generator operation of first diff'ed op
- qml.apply(gen_op_i)
- # Apply first layer and operations between layers
- for op in ops_between_cgens:
- qml.apply(op)
- # Controlled-generator operation of second diff'ed op
- qml.apply(gen_op_j)
- # Measure X on auxiliary wire
- qml.expval(qml.X(aux_wire))
-
- tapes.append(qml.tape.QuantumScript.from_queue(q))
+ ops = [
+ qml.Hadamard(wires=aux_wire),
+ *layer_i.pre_ops,
+ gen_op_i,
+ *ops_between_cgens,
+ gen_op_j,
+ ]
+ new_tape = qml.tape.QuantumScript(ops, [qml.expval(qml.X(aux_wire))], shots=shots)
+
+ tapes.append(new_tape)
# Memorize to which metric entry this tape belongs
ids.append((par_idx_i, par_idx_j))
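The ops list above assembles a Hadamard test; a self-contained sketch of that pattern (the controlled generator below is a stand-in, not the PR's `gen_op`):

```python
import pennylane as qml

# Hadamard test: prepare the auxiliary wire with H, apply a controlled-U,
# and measure <X> on the auxiliary wire to obtain Re<psi|U|psi>.
aux = 1
ops = [
    qml.Hadamard(wires=aux),
    qml.ctrl(qml.PauliZ(0), control=aux),  # stand-in controlled generator
]
tape = qml.tape.QuantumScript(ops, [qml.expval(qml.X(aux))])

dev = qml.device("default.qubit", wires=2)
print(qml.execute([tape], dev))  # Re<0|Z|0> = 1.0
```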
@@ -707,7 +695,9 @@ def _metric_tensor_hadamard(
block_sizes.append(len(layer_i.param_inds))
for layer_j in layers[idx_i + 1 :]:
- _tapes, _ids = _get_first_term_tapes(layer_i, layer_j, allow_nonunitary, aux_wire)
+ _tapes, _ids = _get_first_term_tapes(
+ layer_i, layer_j, allow_nonunitary, aux_wire, shots=tape.shots
+ )
first_term_tapes.extend(_tapes)
ids.extend(_ids)
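The refactor above trades queuing-based tape construction for direct `QuantumScript` instantiation and threads `shots` through; both patterns in a standalone sketch (the example circuit is illustrative):

```python
import pennylane as qml

# Old pattern: record operations into an AnnotatedQueue, then convert.
with qml.queuing.AnnotatedQueue() as q:
    qml.Hadamard(wires=0)
    qml.CNOT(wires=[0, 1])
    qml.probs(wires=[0, 1])
old_tape = qml.tape.QuantumScript.from_queue(q)

# New pattern: pass operations and measurements directly; shots can be
# forwarded from a parent tape, as the patch now does with tape.shots.
ops = [qml.Hadamard(wires=0), qml.CNOT(wires=[0, 1])]
new_tape = qml.tape.QuantumScript(ops, [qml.probs(wires=[0, 1])], shots=100)

print(all(qml.equal(o1, o2) for o1, o2 in zip(old_tape.circuit, new_tape.circuit)))
print(new_tape.shots.total_shots)  # 100
```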
diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py
index 7941fa89fe3..2cf3f89ec42 100644
--- a/pennylane/workflow/qnode.py
+++ b/pennylane/workflow/qnode.py
@@ -527,9 +527,9 @@ def __init__(
self.gradient_kwargs = {}
self._tape_cached = False
+ self._transform_program = qml.transforms.core.TransformProgram()
self._update_gradient_fn()
functools.update_wrapper(self, func)
- self._transform_program = qml.transforms.core.TransformProgram()
def __copy__(self):
copied_qnode = QNode.__new__(QNode)
@@ -592,8 +592,17 @@ def _update_gradient_fn(self, shots=None, tape=None):
return
if tape is None and shots:
tape = qml.tape.QuantumScript([], [], shots=shots)
+
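+        # Do not let "best" resolve to a lightning device gradient when the
+        # metric-tensor transform is present; its tapes require parameter-shift.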
+ diff_method = self.diff_method
+ if (
+ self.device.name == "lightning.qubit"
+ and qml.metric_tensor in self.transform_program
+ and self.diff_method == "best"
+ ):
+ diff_method = "parameter-shift"
+
self.gradient_fn, self.gradient_kwargs, self.device = self.get_gradient_fn(
- self._original_device, self.interface, self.diff_method, tape=tape
+ self._original_device, self.interface, diff_method, tape=tape
)
self.gradient_kwargs.update(self._user_gradient_kwargs or {})
@@ -714,6 +723,7 @@ def get_best_method(device, interface, tape=None):
"""
config = _make_execution_config(None, "best")
if isinstance(device, qml.devices.Device):
+
if device.supports_derivatives(config, circuit=tape):
new_config = device.preprocess(config)[1]
return new_config.gradient_method, {}, device
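A quick sketch of the condition the patched `_update_gradient_fn` checks (illustrative one-parameter circuit; membership testing on the transform program is the same API the patch uses):

```python
import pennylane as qml

dev = qml.device("lightning.qubit", wires=3)

@qml.qnode(dev, diff_method="best")
def circuit(x):
    qml.RX(x, wires=0)
    return qml.expval(qml.Z(0))

mt_qnode = qml.metric_tensor(circuit)

# Applying the transform records it in the QNode's transform program,
# which is exactly what the patched condition tests for.
assert qml.metric_tensor in mt_qnode.transform_program
```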
diff --git a/tests/gradients/core/test_jvp.py b/tests/gradients/core/test_jvp.py
index 7445b5f8a2b..54bdb051572 100644
--- a/tests/gradients/core/test_jvp.py
+++ b/tests/gradients/core/test_jvp.py
@@ -284,6 +284,7 @@ def test_dtype_jax(self, dtype1, dtype2):
determined by the dtype of the dy."""
import jax
+ jax.config.update("jax_enable_x64", True)
dtype = dtype1
dtype1 = getattr(jax.numpy, dtype1)
dtype2 = getattr(jax.numpy, dtype2)
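The x64 flag added above matters because JAX defaults to 32-bit floats; a standalone illustration (not part of the test):

```python
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

# With x64 enabled, Python floats promote to float64; without it, JAX
# silently downcasts to float32, which breaks dtype-sensitive assertions.
print(jnp.array(1.0).dtype)  # float64
```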
diff --git a/tests/gradients/core/test_metric_tensor.py b/tests/gradients/core/test_metric_tensor.py
index bd234062736..01b43bc8177 100644
--- a/tests/gradients/core/test_metric_tensor.py
+++ b/tests/gradients/core/test_metric_tensor.py
@@ -913,7 +913,7 @@ def test_no_trainable_params_tape(self):
mt_tapes, post_processing = qml.metric_tensor(tape)
res = post_processing(qml.execute(mt_tapes, dev, None))
- assert mt_tapes == []
+ assert mt_tapes == [] # pylint: disable=use-implicit-booleaness-not-comparison
assert res == ()
@@ -1091,8 +1091,13 @@ def qnode(*params):
def mt(*params):
state = qnode(*params)
- rqnode = lambda *params: np.real(qnode(*params))
- iqnode = lambda *params: np.imag(qnode(*params))
+
+ def rqnode(*params):
+ return np.real(qnode(*params))
+
+ def iqnode(*params):
+ return np.imag(qnode(*params))
+
rjac = qml.jacobian(rqnode)(*params)
ijac = qml.jacobian(iqnode)(*params)
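The tests replace several assigned lambdas with named functions; the pattern in isolation (the `qnode` below is a stand-in for the tests' complex-valued QNode):

```python
import numpy as np

def qnode(*params):  # stand-in for the tests' QNode
    return np.exp(1j * params[0])

# Before: a lambda bound to a name, as in the removed lines.
rqnode_lambda = lambda *params: np.real(qnode(*params))

# After: an equivalent named function; easier to debug and lint-friendly.
def rqnode(*params):
    return np.real(qnode(*params))

assert rqnode(0.3) == rqnode_lambda(0.3)
```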
@@ -1125,9 +1130,11 @@ class TestFullMetricTensor:
@pytest.mark.autograd
@pytest.mark.parametrize("ansatz, params", zip(fubini_ansatze, fubini_params))
@pytest.mark.parametrize("interface", ["auto", "autograd"])
- def test_correct_output_autograd(self, ansatz, params, interface):
+ @pytest.mark.parametrize("dev_name", ("default.qubit", "lightning.qubit"))
+ def test_correct_output_autograd(self, dev_name, ansatz, params, interface):
+
expected = autodiff_metric_tensor(ansatz, self.num_wires)(*params)
- dev = qml.device("default.qubit.autograd", wires=self.num_wires + 1)
+ dev = qml.device(dev_name, wires=self.num_wires + 1)
@qml.qnode(dev, interface=interface)
def circuit(*params):
@@ -1145,14 +1152,20 @@ def circuit(*params):
@pytest.mark.jax
@pytest.mark.parametrize("ansatz, params", zip(fubini_ansatze, fubini_params))
@pytest.mark.parametrize("interface", ["auto", "jax"])
- def test_correct_output_jax(self, ansatz, params, interface):
+ @pytest.mark.parametrize("dev_name", ("default.qubit", "lightning.qubit"))
+ def test_correct_output_jax(self, dev_name, ansatz, params, interface):
import jax
from jax import numpy as jnp
+ if ansatz == fubini_ansatz2:
+ pytest.xfail("Issue involving trainable indices to be resolved.")
+ if ansatz == fubini_ansatz3 and dev_name == "lightning.qubit":
+ pytest.xfail("Issue invovling trainable_params to be resolved.")
+
jax.config.update("jax_enable_x64", True)
expected = autodiff_metric_tensor(ansatz, self.num_wires)(*params)
- dev = qml.device("default.qubit.jax", wires=self.num_wires + 1)
+ dev = qml.device(dev_name, wires=self.num_wires + 1)
params = tuple(jnp.array(p) for p in params)
@@ -1176,10 +1189,11 @@ def circuit(*params):
@pytest.mark.jax
@pytest.mark.parametrize("ansatz, params", zip(fubini_ansatze, fubini_params))
@pytest.mark.parametrize("interface", ["auto", "jax"])
- def test_jax_argnum_error(self, ansatz, params, interface):
+ @pytest.mark.parametrize("dev_name", ("default.qubit", "lightning.qubit"))
+ def test_jax_argnum_error(self, dev_name, ansatz, params, interface):
from jax import numpy as jnp
- dev = qml.device("default.qubit.jax", wires=self.num_wires + 1)
+ dev = qml.device(dev_name, wires=self.num_wires + 1)
params = tuple(jnp.array(p) for p in params)
@@ -1198,11 +1212,12 @@ def circuit(*params):
@pytest.mark.torch
@pytest.mark.parametrize("ansatz, params", zip(fubini_ansatze, fubini_params))
@pytest.mark.parametrize("interface", ["auto", "torch"])
- def test_correct_output_torch(self, ansatz, params, interface):
+ @pytest.mark.parametrize("dev_name", ("default.qubit", "lightning.qubit"))
+ def test_correct_output_torch(self, dev_name, ansatz, params, interface):
import torch
expected = autodiff_metric_tensor(ansatz, self.num_wires)(*params)
- dev = qml.device("default.qubit.torch", wires=self.num_wires + 1)
+ dev = qml.device(dev_name, wires=self.num_wires + 1)
params = tuple(torch.tensor(p, dtype=torch.float64, requires_grad=True) for p in params)
@@ -1222,11 +1237,12 @@ def circuit(*params):
@pytest.mark.tf
@pytest.mark.parametrize("ansatz, params", zip(fubini_ansatze, fubini_params))
@pytest.mark.parametrize("interface", ["auto", "tf"])
- def test_correct_output_tf(self, ansatz, params, interface):
+ @pytest.mark.parametrize("dev_name", ("default.qubit", "lightning.qubit"))
+ def test_correct_output_tf(self, dev_name, ansatz, params, interface):
import tensorflow as tf
expected = autodiff_metric_tensor(ansatz, self.num_wires)(*params)
- dev = qml.device("default.qubit.tf", wires=self.num_wires + 1)
+ dev = qml.device(dev_name, wires=self.num_wires + 1)
params = tuple(tf.Variable(p, dtype=tf.float64) for p in params)
@@ -1254,17 +1270,18 @@ def diffability_ansatz_0(weights, wires=None):
qml.RZ(weights[2], wires=1)
-expected_diag_jac_0 = lambda weights: np.array(
- [
- [0, 0, 0],
- [0, 0, 0],
+def expected_diag_jac_0(weights):
+ return np.array(
[
- np.cos(weights[0] + weights[1]) * np.sin(weights[0] + weights[1]) / 2,
- np.cos(weights[0] + weights[1]) * np.sin(weights[0] + weights[1]) / 2,
- 0,
- ],
- ]
-)
+ [0, 0, 0],
+ [0, 0, 0],
+ [
+ np.cos(weights[0] + weights[1]) * np.sin(weights[0] + weights[1]) / 2,
+ np.cos(weights[0] + weights[1]) * np.sin(weights[0] + weights[1]) / 2,
+ 0,
+ ],
+ ]
+ )
def diffability_ansatz_1(weights, wires=None):
@@ -1275,17 +1292,18 @@ def diffability_ansatz_1(weights, wires=None):
qml.RZ(weights[2], wires=1)
-expected_diag_jac_1 = lambda weights: np.array(
- [
- [0, 0, 0],
- [-np.sin(2 * weights[0]) / 4, 0, 0],
+def expected_diag_jac_1(weights):
+ return np.array(
[
- np.cos(weights[0]) * np.cos(weights[1]) ** 2 * np.sin(weights[0]) / 2,
- np.cos(weights[0]) ** 2 * np.sin(2 * weights[1]) / 4,
- 0,
- ],
- ]
-)
+ [0, 0, 0],
+ [-np.sin(2 * weights[0]) / 4, 0, 0],
+ [
+ np.cos(weights[0]) * np.cos(weights[1]) ** 2 * np.sin(weights[0]) / 2,
+ np.cos(weights[0]) ** 2 * np.sin(2 * weights[1]) / 4,
+ 0,
+ ],
+ ]
+ )
def diffability_ansatz_2(weights, wires=None):
@@ -1296,17 +1314,19 @@ def diffability_ansatz_2(weights, wires=None):
qml.RZ(weights[2], wires=1)
-expected_diag_jac_2 = lambda weights: np.array(
- [
- [0, 0, 0],
- [0, 0, 0],
+def expected_diag_jac_2(weights):
+ return np.array(
[
- np.cos(weights[1]) ** 2 * np.sin(2 * weights[0]) / 4,
- np.cos(weights[0]) ** 2 * np.sin(2 * weights[1]) / 4,
- 0,
- ],
- ]
-)
+ [0, 0, 0],
+ [0, 0, 0],
+ [
+ np.cos(weights[1]) ** 2 * np.sin(2 * weights[0]) / 4,
+ np.cos(weights[0]) ** 2 * np.sin(2 * weights[1]) / 4,
+ 0,
+ ],
+ ]
+ )
+
weights_diff = np.array([0.432, 0.12, -0.292], requires_grad=True)
@@ -1466,7 +1486,9 @@ def test_autograd(self, diff_method, tol, ansatz, weights, interface):
def cost_full(*weights):
return np.array(qml.metric_tensor(qnode, approx=None)(*weights))
- _cost_full = lambda *weights: np.array(autodiff_metric_tensor(ansatz, 3)(*weights))
+ def _cost_full(*weights):
+ return np.array(autodiff_metric_tensor(ansatz, 3)(*weights))
+
_c = _cost_full(*weights)
c = cost_full(*weights)
assert all(