Skip to content

Commit

Permalink
[Program Capture] Add pytree support to captured qml.grad and `qml.…
Browse files Browse the repository at this point in the history
…jacobian` (#6134)

**Context:**
#6120 and #6127 add support to capture `qml.grad` and `qml.jacobian` in
plxpr. Once captured, they dispatch to `jax.grad` and `jax.jacobian`.

**Description of the Change:**
This PR adds support for pytree inputs and outputs of the differentiated
functions, similar to #6081.
For this, it extends the internal class `FlatFn` by the extra
functionality to turn the wrapper into a `*flat_args -> *flat_outputs`
function, instead of a `*pytree_args -> *flat_outputs` function.

**Benefits:**
Pytree support 🌳 

**Possible Drawbacks:**

**Related GitHub Issues:**

[sc-70930]
[sc-71862]

---------

Co-authored-by: Christina Lee <christina@xanadu.ai>
Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
  • Loading branch information
3 people committed Sep 18, 2024
1 parent 4a62336 commit 12c50e3
Show file tree
Hide file tree
Showing 8 changed files with 260 additions and 34 deletions.
3 changes: 2 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@
* Differentiation of hybrid programs via `qml.grad` and `qml.jacobian` can now be captured
into plxpr. When evaluating a captured `qml.grad` (`qml.jacobian`) instruction, it will
dispatch to `jax.grad` (`jax.jacobian`), which differs from the Autograd implementation
without capture.
without capture. Pytree inputs and outputs are supported.
[(#6120)](https://github.com/PennyLaneAI/pennylane/pull/6120)
[(#6127)](https://github.com/PennyLaneAI/pennylane/pull/6127)
[(#6134)](https://github.com/PennyLaneAI/pennylane/pull/6134)

* Improve unit testing for capturing of nested control flows.
[(#6111)](https://github.com/PennyLaneAI/pennylane/pull/6111)
Expand Down
50 changes: 43 additions & 7 deletions pennylane/_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from pennylane.capture import enabled
from pennylane.capture.capture_diff import _get_grad_prim, _get_jacobian_prim
from pennylane.capture.flatfn import FlatFn
from pennylane.compiler import compiler
from pennylane.compiler.compiler import CompileError

Expand All @@ -33,18 +34,53 @@

def _capture_diff(func, argnum=None, diff_prim=None, method=None, h=None):
"""Capture-compatible gradient computation."""
import jax # pylint: disable=import-outside-toplevel
# pylint: disable=import-outside-toplevel
import jax
from jax.tree_util import tree_flatten, tree_leaves, tree_unflatten, treedef_tuple

if isinstance(argnum, int):
argnum = [argnum]
if argnum is None:
argnum = [0]
argnum = 0
if argnum_is_int := isinstance(argnum, int):
argnum = [argnum]

@wraps(func)
def new_func(*args, **kwargs):
jaxpr = jax.make_jaxpr(partial(func, **kwargs))(*args)
prim_kwargs = {"argnum": argnum, "jaxpr": jaxpr.jaxpr, "n_consts": len(jaxpr.consts)}
return diff_prim.bind(*jaxpr.consts, *args, **prim_kwargs, method=method, h=h)
flat_args, in_trees = zip(*(tree_flatten(arg) for arg in args))
full_in_tree = treedef_tuple(in_trees)

# Create a new input tree that only takes inputs marked by argnum into account
trainable_in_trees = (in_tree for i, in_tree in enumerate(in_trees) if i in argnum)
# If an integer was provided as argnum, unpack the arguments axis of the derivatives
if argnum_is_int:
trainable_in_tree = list(trainable_in_trees)[0]
else:
trainable_in_tree = treedef_tuple(trainable_in_trees)

# Create argnum for the flat list of input arrays. For each flattened argument,
# add a list of flat argnums if the argument is trainable and an empty list otherwise.
start = 0
flat_argnum_gen = (
(
list(range(start, (start := start + len(flat_arg))))
if i in argnum
else list(range((start := start + len(flat_arg)), start))
)
for i, flat_arg in enumerate(flat_args)
)
flat_argnum = sum(flat_argnum_gen, start=[])

# Create fully flattened function (flat inputs & outputs)
flat_fn = FlatFn(partial(func, **kwargs) if kwargs else func, full_in_tree)
flat_args = sum(flat_args, start=[])
jaxpr = jax.make_jaxpr(flat_fn)(*flat_args)
prim_kwargs = {"argnum": flat_argnum, "jaxpr": jaxpr.jaxpr, "n_consts": len(jaxpr.consts)}
out_flat = diff_prim.bind(*jaxpr.consts, *flat_args, **prim_kwargs, method=method, h=h)
# flatten once more to go from 2D derivative structure (outputs, args) to flat structure
out_flat = tree_leaves(out_flat)
assert flat_fn.out_tree is not None, "out_tree should be set after executing flat_fn"
# The derivative output tree is the composition of output tree and trainable input trees
combined_tree = flat_fn.out_tree.compose(trainable_in_tree)
return tree_unflatten(combined_tree, out_flat)

return new_func

Expand Down
3 changes: 3 additions & 0 deletions pennylane/capture/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
~create_measurement_wires_primitive
~create_measurement_mcm_primitive
~qnode_call
~FlatFn
The ``primitives`` submodule offers easy access to objects with jax dependencies such as
Expand Down Expand Up @@ -154,6 +155,7 @@ def _(*args, **kwargs):
create_measurement_mcm_primitive,
)
from .capture_qnode import qnode_call
from .flatfn import FlatFn

# by defining this here, we avoid
# E0611: No name 'AbstractOperator' in module 'pennylane.capture' (no-name-in-module)
Expand Down Expand Up @@ -196,4 +198,5 @@ def __getattr__(key):
"AbstractOperator",
"AbstractMeasurement",
"qnode_prim",
"FlatFn",
)
2 changes: 1 addition & 1 deletion pennylane/capture/capture_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def _(*args, argnum, jaxpr, n_consts, method, h):
def func(*inner_args):
return jax.core.eval_jaxpr(jaxpr, consts, *inner_args)

return jax.jacobian(func, argnums=argnum)(*args)
return jax.tree_util.tree_leaves(jax.jacobian(func, argnums=argnum)(*args))

# pylint: disable=unused-argument
@jacobian_prim.def_abstract_eval
Expand Down
10 changes: 7 additions & 3 deletions pennylane/capture/capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,12 @@ def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts):
mps = qfunc_jaxpr.outvars
return _get_shapes_for(*mps, shots=shots, num_device_wires=len(device.wires))

def _qnode_jvp(*args_and_tangents, **impl_kwargs):
return jax.jvp(partial(qnode_prim.impl, **impl_kwargs), *args_and_tangents)
def make_zero(tan, arg):
return jax.lax.zeros_like_array(arg) if isinstance(tan, ad.Zero) else tan

def _qnode_jvp(args, tangents, **impl_kwargs):
tangents = tuple(map(make_zero, tangents, args))
return jax.jvp(partial(qnode_prim.impl, **impl_kwargs), args, tangents)

ad.primitive_jvps[qnode_prim] = _qnode_jvp

Expand Down Expand Up @@ -174,7 +178,7 @@ def f(x):
qnode_kwargs = {"diff_method": qnode.diff_method, **execute_kwargs, **mcm_config}
qnode_prim = _get_qnode_prim()

flat_args, _ = jax.tree_util.tree_flatten(args)
flat_args = jax.tree_util.tree_leaves(args)
res = qnode_prim.bind(
*qfunc_jaxpr.consts,
*flat_args,
Expand Down
7 changes: 3 additions & 4 deletions pennylane/capture/explanations.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,11 @@ You can also see the const variable `a` as argument `e:i32[]` to the inner neste
### Pytree handling

Evaluating a jaxpr requires accepting and returning a flat list of tensor-like inputs and outputs.
list of tensor-like outputs. These long lists can be hard to manage and are very
restrictive on the allowed functions, but we can take advantage of pytrees to allow handling
arbitrary functions.
These long lists can be hard to manage and are very restrictive on the allowed functions, but we
can take advantage of pytrees to allow handling arbitrary functions.

To start, we import the `FlatFn` helper. This class converts a function to one that caches
the resulting result pytree into `flat_fn.out_tree` when executed. This can be used to repack the
the result pytree into `flat_fn.out_tree` when executed. This can be used to repack the
results into the correct shape. It also returns flattened results. This does not particularly
matter for program capture, as we will only be producing jaxpr from the function, not calling
it directly.
Expand Down
31 changes: 28 additions & 3 deletions pennylane/capture/flatfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,48 @@ class FlatFn:
property, so that the results can be repacked later. It also returns flattened results
instead of the original result object.
If an ``in_tree`` is provided, the function accepts flattened inputs instead of the
original inputs with tree structure given by ``in_tree``.
**Example**
>>> import jax
>>> from pennylane.capture.flatfn import FlatFn
>>> def f(x):
... return {"y": 2+x["x"]}
>>> flat_f = FlatFn(f)
>>> res = flat_f({"x": 0})
>>> arg = {"x": 0.5}
>>> res = flat_f(arg)
>>> res
[2.5]
>>> jax.tree_util.tree_unflatten(flat_f.out_tree, res)
{'y': 2.5}
If we want to use a fully flattened function that also takes flat inputs instead of
the original inputs with tree structure, we can provide the treedef for this input
structure:
>>> flat_args, in_tree = jax.tree_util.tree_flatten((arg,))
>>> flat_f = FlatFn(f, in_tree)
>>> res = flat_f(*flat_args)
>>> res
[2]
[2.5]
>>> jax.tree_util.tree_unflatten(flat_f.out_tree, res)
{'y': 2.5}
Note that the ``in_tree`` has to be created by flattening a tuple of all input
arguments, even if there is only a single argument.
"""

def __init__(self, f):
def __init__(self, f, in_tree=None):
self.f = f
self.in_tree = in_tree
self.out_tree = None
update_wrapper(self, f)

def __call__(self, *args):
if self.in_tree is not None:
args = jax.tree_util.tree_unflatten(self.in_tree, args)
out = self.f(*args)
out_flat, out_tree = jax.tree_util.tree_flatten(out)
self.out_tree = out_tree
Expand Down
Loading

0 comments on commit 12c50e3

Please sign in to comment.