Deprecate `QNode.gradient_fn` (#6244)
**Context:**

The existence of `QNode.gradient_fn` ties us to a rather object-oriented
framework with a lot of in-place mutation of the QNode. By freeing ourselves
from this property, we can move toward a more functional structure with less
coupling and fewer side effects. It also frees us up to start making other
logical changes and improvements.

`QNode.gradient_fn` is also not well defined, so it is hard to tell what it
should actually be and reflect. With the recent move to more dynamic gradient
validation, it no longer carries the same information it did when it was
added.

There isn't a good analog of `QNode.gradient_fn` yet, since it is effectively
a "processed diff method". We have stories for next quarter to start adding
helper transforms for things like this, but it won't be immediate.
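
A minimal sketch of the access pattern this change enables, assuming the
`QNode.get_gradient_fn` call signature used elsewhere in this diff (the toy
`circuit` below is illustrative only):

```python
import pennylane as qml

dev = qml.device("default.qubit", wires=1)

@qml.qnode(dev, diff_method="parameter-shift")
def circuit(x):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))

# Before: the processed gradient function lived on the QNode and was
# mutated in place during execution:
#     circuit.gradient_fn   # deprecated as of this PR

# After: process the diff method on demand, with no QNode mutation.
gradient_fn, gradient_kwargs, device = qml.QNode.get_gradient_fn(
    circuit.device, circuit.interface, circuit.diff_method
)
```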

**Description of the Change:**

Deprecates `QNode.gradient_fn` in favour of `QNode.diff_method` and the static
`QNode.get_gradient_fn`. The internal `_update_gradient_fn` mutation is
removed, and the processed gradient function is instead computed on demand and
threaded through `qml.specs`, `get_transform_program`/`construct_batch`, and
QNode execution.

**Benefits:**

**Possible Drawbacks:**

**Related GitHub Issues:**

[sc-71844]

---------

Co-authored-by: Pietropaolo Frisoni <pietropaolo.frisoni@xanadu.ai>
albi3ro and PietropaoloFrisoni authored Sep 12, 2024
1 parent 62b1f06 commit 30331fd
Showing 21 changed files with 258 additions and 288 deletions.
6 changes: 6 additions & 0 deletions doc/development/deprecations.rst
@@ -9,6 +9,12 @@ deprecations are listed below.
 Pending deprecations
 --------------------
 
+* `QNode.gradient_fn` is deprecated. Please use `QNode.diff_method` instead. `QNode.get_gradient_fn` can also be used to
+  process the diff method.
+
+  - Deprecated in v0.39
+  - Will be removed in v0.40
+
 * All of the legacy devices (any with the name ``default.qubit.{autograd,torch,tf,jax,legacy}``) are deprecated. Use ``default.qubit`` instead,
   as it supports backpropagation for the many backends the legacy devices support.

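For downstream code that still reads `QNode.gradient_fn`, a hedged sketch of
silencing the warning during the deprecation window, using the standard
`warnings` module and the `qml.PennyLaneDeprecationWarning` raised by the new
property below:

```python
import warnings
import pennylane as qml

# Opt-in suppression while migrating to QNode.diff_method /
# QNode.get_gradient_fn; the message regex matches the warning text
# emitted by the deprecated property.
warnings.filterwarnings(
    "ignore",
    message=r"QNode\.gradient_fn is deprecated",
    category=qml.PennyLaneDeprecationWarning,
)
```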
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
@@ -83,6 +83,9 @@
 
 <h3>Deprecations 👋</h3>
 
+* `QNode.gradient_fn` is deprecated. Please use `QNode.diff_method` and `QNode.get_gradient_fn` instead.
+  [(#6244)](https://github.com/PennyLaneAI/pennylane/pull/6244)
+
 <h3>Documentation 📝</h3>
 
 <h3>Bug fixes 🐛</h3>
14 changes: 10 additions & 4 deletions pennylane/resource/specs.py
@@ -208,18 +208,24 @@ def specs_qnode(*args, **kwargs) -> Union[list[dict], dict]:
             else qnode.diff_method
         )
 
-        if isinstance(qnode.gradient_fn, qml.transforms.core.TransformDispatcher):
-            info["gradient_fn"] = _get_absolute_import_path(qnode.gradient_fn)
+        gradient_fn = qml.QNode.get_gradient_fn(
+            qnode.device,
+            qnode.interface,
+            qnode.diff_method,
+            tape=tape,
+        )[0]
+        if isinstance(gradient_fn, qml.transforms.core.TransformDispatcher):
+            info["gradient_fn"] = _get_absolute_import_path(gradient_fn)
 
             try:
-                info["num_gradient_executions"] = len(qnode.gradient_fn(tape)[0])
+                info["num_gradient_executions"] = len(gradient_fn(tape)[0])
             except Exception as e:  # pylint: disable=broad-except
                 # In the case of a broad exception, we don't want the `qml.specs` transform
                 # to fail. Instead, we simply indicate that the number of gradient executions
                 # is not supported for the reason specified.
                 info["num_gradient_executions"] = f"NotSupported: {str(e)}"
         else:
-            info["gradient_fn"] = qnode.gradient_fn
+            info["gradient_fn"] = gradient_fn
 
         infos.append(info)

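With this change, `qml.specs` derives the processed gradient function from the
tape rather than reading mutated QNode state. A hedged usage sketch (the
circuit is hypothetical; the reported keys follow the code above):

```python
import pennylane as qml

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev, diff_method="parameter-shift")
def circuit(x):
    qml.RX(x, wires=0)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(1))

info = qml.specs(circuit)(0.5)
# "gradient_fn" holds the import path for transform-based methods, and
# "num_gradient_executions" counts the tapes the gradient transform
# would generate for this circuit.
print(info["gradient_fn"], info["num_gradient_executions"])
```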
35 changes: 23 additions & 12 deletions pennylane/workflow/construct_batch.py
@@ -18,7 +18,7 @@
 from collections.abc import Callable
 from contextlib import nullcontext
 from functools import wraps
-from typing import Literal, Optional, Union
+from typing import Literal, Union
 
 import pennylane as qml
 from pennylane.tape import QuantumScriptBatch
@@ -56,20 +56,24 @@ def wrapped_expand_fn(tape, *args, **kwargs):
     return qml.transform(wrapped_expand_fn)
 
 
-def _get_full_transform_program(qnode: QNode) -> "qml.transforms.core.TransformProgram":
+def _get_full_transform_program(
+    qnode: QNode, gradient_fn
+) -> "qml.transforms.core.TransformProgram":
     program = qml.transforms.core.TransformProgram(qnode.transform_program)
 
-    if getattr(qnode.gradient_fn, "expand_transform", False):
+    if getattr(gradient_fn, "expand_transform", False):
         program.add_transform(
-            qml.transform(qnode.gradient_fn.expand_transform),
+            qml.transform(gradient_fn.expand_transform),
             **qnode.gradient_kwargs,
         )
 
-    config = _make_execution_config(qnode, qnode.gradient_fn)
+    config = _make_execution_config(qnode, gradient_fn)
     return program + qnode.device.preprocess(config)[0]
 
 
-def get_transform_program(qnode: "QNode", level=None) -> "qml.transforms.core.TransformProgram":
+def get_transform_program(
+    qnode: "QNode", level=None, gradient_fn="unset"
+) -> "qml.transforms.core.TransformProgram":
     """Extract a transform program at a designated level.
 
     Args:
@@ -81,6 +85,8 @@ def get_transform_program(qnode: "QNode", level=None) -> "qml.transforms.core.TransformProgram":
             * ``int``: How many transforms to include, starting from the front of the program
             * ``slice``: a slice to select out components of the transform program.
 
+        gradient_fn (None, str, TransformDispatcher): The processed gradient fn for the workflow.
+
     Returns:
         TransformProgram: the transform program corresponding to the requested level.
@@ -174,7 +180,10 @@ def circuit():
     TransformProgram(validate_device_wires, mid_circuit_measurements, decompose, validate_measurements, validate_observables)
 
     """
-    full_transform_program = _get_full_transform_program(qnode)
+    if gradient_fn == "unset":
+        gradient_fn = QNode.get_gradient_fn(qnode.device, qnode.interface, qnode.diff_method)[0]
+
+    full_transform_program = _get_full_transform_program(qnode, gradient_fn)
 
     num_user = len(qnode.transform_program)
     if qnode.transform_program.has_final_transform:
@@ -193,7 +202,7 @@ def circuit():
     elif level == "gradient":
         readd_final_transform = True
 
-        level = num_user + 1 if getattr(qnode.gradient_fn, "expand_transform", False) else num_user
+        level = num_user + 1 if getattr(gradient_fn, "expand_transform", False) else num_user
     elif isinstance(level, str):
         raise ValueError(
             f"level {level} not recognized. Acceptable strings are 'device', 'top', 'user', and 'gradient'."
@@ -211,8 +220,8 @@ def circuit():
 
 
 def construct_batch(
-    qnode: QNode,
-    level: Optional[Union[Literal["top", "user", "device", "gradient"], int, slice]] = "user",
+    qnode: Union[QNode, "qml.qnn.KerasLayer", "qml.qnn.TorchLayer"],
+    level: Union[Literal["top", "user", "device", "gradient"], int, slice, None] = "user",
 ) -> Callable:
     """Construct the batch of tapes and post processing for a designated stage in the transform program.
@@ -350,8 +359,10 @@ def batch_constructor(*args, **kwargs) -> tuple[QuantumScriptBatch, PostprocessingFn]:
         params = initial_tape.get_parameters(trainable_only=False)
         initial_tape.trainable_params = qml.math.get_trainable_indices(params)
 
-        qnode._update_gradient_fn(tape=initial_tape)
-        program = get_transform_program(qnode, level=level)
+        gradient_fn = QNode.get_gradient_fn(
+            qnode.device, qnode.interface, qnode.diff_method, tape=initial_tape
+        )[0]
+        program = get_transform_program(qnode, level=level, gradient_fn=gradient_fn)
 
         return program((initial_tape,))

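The new `gradient_fn` argument keeps `get_transform_program` and
`construct_batch` purely functional; callers that don't pass it get the
processed diff method computed on the fly. A sketch of the unchanged public
usage (the circuit is hypothetical):

```python
import pennylane as qml
from pennylane.workflow import construct_batch, get_transform_program

dev = qml.device("default.qubit", wires=1)

@qml.qnode(dev, diff_method="parameter-shift")
def circuit(x):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))

# Resolves gradient_fn internally via QNode.get_gradient_fn when unset.
program = get_transform_program(circuit, level="gradient")

# Builds the tape batch without mutating any QNode attributes.
tapes, post_processing = construct_batch(circuit, level="gradient")(0.5)
```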
115 changes: 65 additions & 50 deletions pennylane/workflow/qnode.py
@@ -524,7 +524,7 @@ def __init__(
         # input arguments
         self.func = func
         self.device = device
-        self._interface = interface
+        self._interface = None if diff_method is None else interface
         self.diff_method = diff_method
         mcm_config = qml.devices.MCMConfig(mcm_method=mcm_method, postselect_mode=postselect_mode)
         cache = (max_diff > 1) if cache == "auto" else cache
@@ -542,14 +542,48 @@ def __init__(
         # internal data attributes
         self._tape = None
         self._qfunc_output = None
-        self._user_gradient_kwargs = gradient_kwargs
-        self.gradient_fn = None
-        self.gradient_kwargs = {}
+        self._gradient_fn = None
+        self.gradient_kwargs = gradient_kwargs
 
         self._transform_program = TransformProgram()
-        self._update_gradient_fn()
         functools.update_wrapper(self, func)
 
+        # validation check. Will raise error if bad diff_method
+        if diff_method is not None:
+            QNode.get_gradient_fn(self.device, self.interface, self.diff_method)
+
+    @property
+    def gradient_fn(self):
+        """A processed version of ``QNode.diff_method``.
+
+        .. warning::
+
+            This property is deprecated in v0.39 and will be removed in v0.40.
+            Please see ``QNode.diff_method`` instead.
+
+        """
+        warnings.warn(
+            "QNode.gradient_fn is deprecated. Please use QNode.diff_method instead.",
+            qml.PennyLaneDeprecationWarning,
+        )
+        if self.diff_method is None:
+            return None
+
+        if (
+            self.device.name == "lightning.qubit"
+            and qml.metric_tensor in self.transform_program
+            and self.diff_method == "best"
+        ):
+            return qml.gradients.param_shift
+
+        if self.tape is None and self.device.shots:
+            tape = qml.tape.QuantumScript([], [], shots=self.device.shots)
+        else:
+            tape = self.tape
+
+        return QNode.get_gradient_fn(self.device, self.interface, self.diff_method, tape=tape)[0]
+
     def __copy__(self) -> "QNode":
         copied_qnode = QNode.__new__(QNode)
         for attr, value in vars(self).items():
@@ -590,7 +624,6 @@ def interface(self, value: SupportedInterfaceUserInput):
         )
 
         self._interface = INTERFACE_MAP[value]
-        self._update_gradient_fn(shots=self.device.shots)
 
     @property
     def transform_program(self) -> TransformProgram:
@@ -605,28 +638,6 @@ def add_transform(self, transform_container: TransformContainer):
         """
         self._transform_program.push_back(transform_container=transform_container)
 
-    def _update_gradient_fn(self, shots=None, tape: Optional["qml.tape.QuantumTape"] = None):
-        if self.diff_method is None:
-            self._interface = None
-            self.gradient_fn = None
-            self.gradient_kwargs = {}
-            return
-        if tape is None and shots:
-            tape = qml.tape.QuantumScript([], [], shots=shots)
-
-        diff_method = self.diff_method
-        if (
-            self.device.name == "lightning.qubit"
-            and qml.metric_tensor in self.transform_program
-            and self.diff_method == "best"
-        ):
-            diff_method = "parameter-shift"
-
-        self.gradient_fn, self.gradient_kwargs, self.device = QNode.get_gradient_fn(
-            self.device, self.interface, diff_method, tape=tape
-        )
-        self.gradient_kwargs.update(self._user_gradient_kwargs or {})
-
     # pylint: disable=too-many-return-statements
     @staticmethod
     @debug_logger
@@ -652,6 +663,8 @@ def get_gradient_fn(
             tuple[str or .TransformDispatcher, dict, .device.Device]: Tuple containing the ``gradient_fn``,
             ``gradient_kwargs``, and the device to use when calling the execute function.
         """
+        if diff_method is None:
+            return None, {}, device
 
         config = _make_execution_config(None, diff_method)
 
@@ -859,8 +872,22 @@ def _execution_component(self, args: tuple, kwargs: dict) -> qml.typing.Result:
             Result
         """
 
+        if (
+            self.device.name == "lightning.qubit"
+            and qml.metric_tensor in self.transform_program
+            and self.diff_method == "best"
+        ):
+            gradient_fn = qml.gradients.param_shift
+        else:
+            gradient_fn = QNode.get_gradient_fn(
+                self.device, self.interface, self.diff_method, tape=self.tape
+            )[0]
         execute_kwargs = copy.copy(self.execute_kwargs)
 
+        gradient_kwargs = copy.copy(self.gradient_kwargs)
+        if gradient_fn is qml.gradients.param_shift_cv:
+            gradient_kwargs["dev"] = self.device
+
         mcm_config = copy.copy(execute_kwargs["mcm_config"])
         if not self._tape.shots:
             mcm_config.postselect_mode = None
@@ -875,7 +902,7 @@ def _execution_component(self, args: tuple, kwargs: dict) -> qml.typing.Result:
         full_transform_program = qml.transforms.core.TransformProgram(self.transform_program)
         inner_transform_program = qml.transforms.core.TransformProgram()
 
-        config = _make_execution_config(self, self.gradient_fn, mcm_config)
+        config = _make_execution_config(self, gradient_fn, mcm_config)
         device_transform_program, config = self.device.preprocess(execution_config=config)
 
         if config.use_device_gradient:
@@ -884,10 +911,10 @@ def _execution_component(self, args: tuple, kwargs: dict) -> qml.typing.Result:
             inner_transform_program += device_transform_program
 
         # Add the gradient expand to the program if necessary
-        if getattr(self.gradient_fn, "expand_transform", False):
+        if getattr(gradient_fn, "expand_transform", False):
             full_transform_program.insert_front_transform(
-                qml.transform(self.gradient_fn.expand_transform),
-                **self.gradient_kwargs,
+                qml.transform(gradient_fn.expand_transform),
+                **gradient_kwargs,
             )
 
         # Calculate the classical jacobians if necessary
@@ -900,12 +927,12 @@ def _execution_component(self, args: tuple, kwargs: dict) -> qml.typing.Result:
         res = qml.execute(
             (self._tape,),
             device=self.device,
-            gradient_fn=self.gradient_fn,
+            gradient_fn=gradient_fn,
             interface=self.interface,
             transform_program=full_transform_program,
             inner_transform=inner_transform_program,
             config=config,
-            gradient_kwargs=self.gradient_kwargs,
+            gradient_kwargs=gradient_kwargs,
             **execute_kwargs,
         )
         res = res[0]
@@ -924,6 +951,9 @@ def _execution_component(self, args: tuple, kwargs: dict) -> qml.typing.Result:
 
     def _impl_call(self, *args, **kwargs) -> qml.typing.Result:
 
+        # construct the tape
+        self.construct(args, kwargs)
+
         old_interface = self.interface
         if old_interface == "auto":
             interface = (
@@ -933,27 +963,12 @@ def _impl_call(self, *args, **kwargs) -> qml.typing.Result:
             )
             self._interface = INTERFACE_MAP[interface]
 
-        if self._qfunc_uses_shots_arg:
-            override_shots = False
-        else:
-            if "shots" not in kwargs:
-                kwargs["shots"] = self.device.shots
-            override_shots = kwargs["shots"]
-
-        # construct the tape
-        self.construct(args, kwargs)
-
-        original_grad_fn = [self.gradient_fn, self.gradient_kwargs, self.device]
-        self._update_gradient_fn(shots=override_shots, tape=self._tape)
-
         try:
             res = self._execution_component(args, kwargs)
         finally:
             if old_interface == "auto":
                 self._interface = "auto"
 
-            _, self.gradient_kwargs, self.device = original_grad_fn
-
         return res
 
     def __call__(self, *args, **kwargs) -> qml.typing.Result:
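For reference, a sketch of how the deprecated property now behaves next to its
replacements, per the property and static method above (the toy circuit is
hypothetical):

```python
import warnings
import pennylane as qml

dev = qml.device("default.qubit", wires=1)

@qml.qnode(dev, diff_method="adjoint")
def circuit(x):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))

# Deprecated accessor: still works, but warns.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _ = circuit.gradient_fn
assert any(issubclass(w.category, qml.PennyLaneDeprecationWarning) for w in caught)

# Replacements: the raw diff method, and the processed gradient function.
assert circuit.diff_method == "adjoint"
gradient_fn = qml.QNode.get_gradient_fn(
    circuit.device, circuit.interface, circuit.diff_method
)[0]
```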
@@ -139,7 +139,6 @@ def circuit(x):
     with dev.tracker:
         qml.grad(circuit)(qml.numpy.array(0.1))
 
-    assert circuit.gradient_fn == "adjoint"
    assert dev.tracker.totals["execute_and_derivative_batches"] == 1


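Downstream test suites that previously asserted on `circuit.gradient_fn`, like
the removed line above, might instead pin the deprecation itself; a
hypothetical pytest sketch:

```python
import pytest
import pennylane as qml


def test_gradient_fn_warns():
    """Accessing the deprecated property raises the new warning."""
    dev = qml.device("default.qubit", wires=1)

    @qml.qnode(dev, diff_method="adjoint")
    def circuit(x):
        qml.RX(x, wires=0)
        return qml.expval(qml.PauliZ(0))

    with pytest.warns(qml.PennyLaneDeprecationWarning, match="gradient_fn is deprecated"):
        _ = circuit.gradient_fn
```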