Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Program Capture] Add pytree support to captured qml.grad and qml.jacobian #6134

Merged
merged 90 commits into from
Sep 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
90 commits
Select commit Hold shift + click to select a range
206ca6a
first prototype
dwierichs Aug 21, 2024
a466a83
first tests
dwierichs Aug 21, 2024
8b402f7
cleanup
dwierichs Aug 21, 2024
9e8f4c9
-a
dwierichs Aug 21, 2024
cfae262
changelog, lint
dwierichs Aug 21, 2024
ed8f515
add tests...
dwierichs Aug 21, 2024
79fe790
move primitive
dwierichs Aug 22, 2024
b90fa2b
import
dwierichs Aug 22, 2024
b9b6536
lint
dwierichs Aug 22, 2024
ff54b86
lint
dwierichs Aug 22, 2024
fde341f
parshift test
dwierichs Aug 22, 2024
52ccec1
first port
dwierichs Aug 22, 2024
3cd33f1
prepare jac
dwierichs Aug 22, 2024
7f8a8ae
Merge branch 'capture-grad' into capture-jacobian
dwierichs Aug 22, 2024
c20e7d3
classical test
dwierichs Aug 22, 2024
a087e53
make grad differentiable
dwierichs Aug 22, 2024
393aa35
nested grad test
dwierichs Aug 22, 2024
f1a6b4d
lint
dwierichs Aug 22, 2024
b86619b
Merge branch 'master' into capture-grad
dwierichs Aug 22, 2024
f678a91
lint
dwierichs Aug 22, 2024
373790b
Merge branch 'capture-grad' into capture-jacobian
dwierichs Aug 22, 2024
cd7fe86
nested jac test
dwierichs Aug 22, 2024
9c58349
nested jac test
dwierichs Aug 22, 2024
60e4d33
fix
dwierichs Aug 22, 2024
bc055fb
jacobian qnode test
dwierichs Aug 22, 2024
cd7d128
fix
dwierichs Aug 22, 2024
3a43436
fix again
dwierichs Aug 22, 2024
15ed932
Merge branch 'master' into capture-grad
dwierichs Aug 22, 2024
ed2b22a
implement pytree support, start testing
dwierichs Aug 23, 2024
de82838
test grad conclusion
dwierichs Aug 23, 2024
de67faa
flatfn mod
dwierichs Aug 23, 2024
692db13
Merge branch 'master' into capture-grad
dwierichs Aug 23, 2024
6ad098e
Merge branch 'capture-grad' into capture-jacobian
dwierichs Aug 23, 2024
7ab7201
changelog
dwierichs Aug 23, 2024
e2acc09
changelog
dwierichs Aug 23, 2024
896431b
changelog
dwierichs Aug 23, 2024
75a26b9
Merge branch 'capture-grad' into capture-grad-pytrees
dwierichs Aug 23, 2024
0e407c3
Merge branch 'master' into capture-grad
dwierichs Aug 26, 2024
ea149a2
Merge branch 'capture-grad' into capture-jacobian
dwierichs Aug 26, 2024
f877674
Merge branch 'capture-grad' into capture-grad-pytrees
dwierichs Aug 26, 2024
f4bb23e
commentary, treedef_tuple trick
dwierichs Aug 26, 2024
93bda87
method and h allowed in capture
dwierichs Aug 26, 2024
5bb2900
higher order primitive tests
dwierichs Aug 26, 2024
10b2899
Merge branch 'master' into capture-grad
dwierichs Aug 27, 2024
c838b4a
Apply suggestions from code review
dwierichs Aug 27, 2024
67bdeb8
Merge branch 'master' into capture-grad
dwierichs Aug 27, 2024
cb43c96
Merge branch 'master' into capture-grad
dwierichs Aug 28, 2024
7ae7415
kwargs, test structure
dwierichs Aug 28, 2024
4b3ac68
merge
dwierichs Aug 28, 2024
aa3876e
merge
dwierichs Aug 29, 2024
389c22d
merge more
dwierichs Aug 29, 2024
de782b8
lint more
dwierichs Aug 29, 2024
a9a4472
add file
dwierichs Aug 29, 2024
769eb98
[skip ci]
dwierichs Aug 29, 2024
43628ab
merge [skip ci]
dwierichs Aug 29, 2024
6c60a08
format [skip ci]
dwierichs Aug 29, 2024
9269f9c
-m
dwierichs Aug 29, 2024
9a4c580
Merge branch 'master' into capture-grad
dwierichs Sep 3, 2024
e79d4a8
merge
dwierichs Sep 5, 2024
552417b
while_loop
dwierichs Sep 5, 2024
81ab600
import fix
dwierichs Sep 5, 2024
63d82f5
Merge branch 'capture-grad' into capture-jacobian
dwierichs Sep 5, 2024
1af5b97
lint
dwierichs Sep 5, 2024
c3fbd78
fix import
dwierichs Sep 5, 2024
e73f2fb
import and skip order
dwierichs Sep 6, 2024
3d2cb89
Merge branch 'master' into capture-grad
dwierichs Sep 6, 2024
724675f
[skip ci]
dwierichs Sep 6, 2024
9fda181
Merge branch 'master' into capture-grad
dwierichs Sep 9, 2024
3513487
merge
dwierichs Sep 9, 2024
db99812
lint
dwierichs Sep 9, 2024
c59da75
merge
dwierichs Sep 9, 2024
f2aead1
merge
dwierichs Sep 9, 2024
a3e5eee
merge
dwierichs Sep 9, 2024
6800365
tmp
dwierichs Sep 9, 2024
5aa57e6
merge
dwierichs Sep 9, 2024
4222774
Merge branch 'capture-jacobian' into capture-grad-pytrees
dwierichs Sep 9, 2024
4dc9582
Merge branch 'master' into capture-jacobian
dwierichs Sep 9, 2024
76f91ce
extend pytree support to Jacobian
dwierichs Sep 9, 2024
b2b9604
Merge branch 'master' into capture-jacobian
dwierichs Sep 9, 2024
9045e38
Merge branch 'capture-jacobian' into capture-grad-pytrees
dwierichs Sep 9, 2024
df7eb06
Merge branch 'master' into capture-grad-pytrees
dwierichs Sep 10, 2024
2718606
lint
dwierichs Sep 10, 2024
14657b8
Apply suggestions from code review
dwierichs Sep 10, 2024
2b86f7c
Merge branch 'master' into capture-grad-pytrees
dwierichs Sep 11, 2024
5774099
-a
dwierichs Sep 11, 2024
a56437b
bugfix
dwierichs Sep 11, 2024
2055107
-a
dwierichs Sep 11, 2024
0c5c62e
flatfn to module
dwierichs Sep 11, 2024
50093c4
Merge branch 'master' into capture-grad-pytrees
dwierichs Sep 13, 2024
c2c877e
Merge branch 'master' into capture-grad-pytrees
dwierichs Sep 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@
* Differentiation of hybrid programs via `qml.grad` and `qml.jacobian` can now be captured
into plxpr. When evaluating a captured `qml.grad` (`qml.jacobian`) instruction, it will
dispatch to `jax.grad` (`jax.jacobian`), which differs from the Autograd implementation
without capture.
without capture. Pytree inputs and outputs are supported.
[(#6120)](https://github.com/PennyLaneAI/pennylane/pull/6120)
[(#6127)](https://github.com/PennyLaneAI/pennylane/pull/6127)
[(#6134)](https://github.com/PennyLaneAI/pennylane/pull/6134)

* Improve unit testing for capturing of nested control flows.
[(#6111)](https://github.com/PennyLaneAI/pennylane/pull/6111)
Expand Down
50 changes: 43 additions & 7 deletions pennylane/_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from pennylane.capture import enabled
from pennylane.capture.capture_diff import _get_grad_prim, _get_jacobian_prim
from pennylane.capture.flatfn import FlatFn
from pennylane.compiler import compiler
from pennylane.compiler.compiler import CompileError

Expand All @@ -33,18 +34,53 @@

def _capture_diff(func, argnum=None, diff_prim=None, method=None, h=None):
"""Capture-compatible gradient computation."""
import jax # pylint: disable=import-outside-toplevel
# pylint: disable=import-outside-toplevel
import jax
from jax.tree_util import tree_flatten, tree_leaves, tree_unflatten, treedef_tuple

if isinstance(argnum, int):
argnum = [argnum]
if argnum is None:
argnum = [0]
argnum = 0
if argnum_is_int := isinstance(argnum, int):
argnum = [argnum]

@wraps(func)
def new_func(*args, **kwargs):
jaxpr = jax.make_jaxpr(partial(func, **kwargs))(*args)
prim_kwargs = {"argnum": argnum, "jaxpr": jaxpr.jaxpr, "n_consts": len(jaxpr.consts)}
return diff_prim.bind(*jaxpr.consts, *args, **prim_kwargs, method=method, h=h)
flat_args, in_trees = zip(*(tree_flatten(arg) for arg in args))
full_in_tree = treedef_tuple(in_trees)

# Create a new input tree that only takes inputs marked by argnum into account
trainable_in_trees = (in_tree for i, in_tree in enumerate(in_trees) if i in argnum)
# If an integer was provided as argnum, unpack the arguments axis of the derivatives
if argnum_is_int:
trainable_in_tree = list(trainable_in_trees)[0]
else:
trainable_in_tree = treedef_tuple(trainable_in_trees)

# Create argnum for the flat list of input arrays. For each flattened argument,
# add a list of flat argnums if the argument is trainable and an empty list otherwise.
start = 0
flat_argnum_gen = (
(
list(range(start, (start := start + len(flat_arg))))
if i in argnum
else list(range((start := start + len(flat_arg)), start))
)
for i, flat_arg in enumerate(flat_args)
)
flat_argnum = sum(flat_argnum_gen, start=[])

# Create fully flattened function (flat inputs & outputs)
flat_fn = FlatFn(partial(func, **kwargs) if kwargs else func, full_in_tree)
flat_args = sum(flat_args, start=[])
jaxpr = jax.make_jaxpr(flat_fn)(*flat_args)
prim_kwargs = {"argnum": flat_argnum, "jaxpr": jaxpr.jaxpr, "n_consts": len(jaxpr.consts)}
out_flat = diff_prim.bind(*jaxpr.consts, *flat_args, **prim_kwargs, method=method, h=h)
# flatten once more to go from 2D derivative structure (outputs, args) to flat structure
out_flat = tree_leaves(out_flat)
assert flat_fn.out_tree is not None, "out_tree should be set after executing flat_fn"
# The derivative output tree is the composition of output tree and trainable input trees
combined_tree = flat_fn.out_tree.compose(trainable_in_tree)
return tree_unflatten(combined_tree, out_flat)

return new_func

Expand Down
3 changes: 3 additions & 0 deletions pennylane/capture/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
~create_measurement_wires_primitive
~create_measurement_mcm_primitive
~qnode_call
~FlatFn


The ``primitives`` submodule offers easy access to objects with jax dependencies such as
Expand Down Expand Up @@ -154,6 +155,7 @@ def _(*args, **kwargs):
create_measurement_mcm_primitive,
)
from .capture_qnode import qnode_call
from .flatfn import FlatFn

# by defining this here, we avoid
# E0611: No name 'AbstractOperator' in module 'pennylane.capture' (no-name-in-module)
Expand Down Expand Up @@ -196,4 +198,5 @@ def __getattr__(key):
"AbstractOperator",
"AbstractMeasurement",
"qnode_prim",
"FlatFn",
)
2 changes: 1 addition & 1 deletion pennylane/capture/capture_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def _(*args, argnum, jaxpr, n_consts, method, h):
def func(*inner_args):
return jax.core.eval_jaxpr(jaxpr, consts, *inner_args)

return jax.jacobian(func, argnums=argnum)(*args)
return jax.tree_util.tree_leaves(jax.jacobian(func, argnums=argnum)(*args))

# pylint: disable=unused-argument
@jacobian_prim.def_abstract_eval
Expand Down
10 changes: 7 additions & 3 deletions pennylane/capture/capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,12 @@ def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts):
mps = qfunc_jaxpr.outvars
return _get_shapes_for(*mps, shots=shots, num_device_wires=len(device.wires))

def _qnode_jvp(*args_and_tangents, **impl_kwargs):
return jax.jvp(partial(qnode_prim.impl, **impl_kwargs), *args_and_tangents)
def make_zero(tan, arg):
return jax.lax.zeros_like_array(arg) if isinstance(tan, ad.Zero) else tan

def _qnode_jvp(args, tangents, **impl_kwargs):
tangents = tuple(map(make_zero, tangents, args))
return jax.jvp(partial(qnode_prim.impl, **impl_kwargs), args, tangents)

ad.primitive_jvps[qnode_prim] = _qnode_jvp

Expand Down Expand Up @@ -174,7 +178,7 @@ def f(x):
qnode_kwargs = {"diff_method": qnode.diff_method, **execute_kwargs, **mcm_config}
qnode_prim = _get_qnode_prim()

flat_args, _ = jax.tree_util.tree_flatten(args)
flat_args = jax.tree_util.tree_leaves(args)
res = qnode_prim.bind(
*qfunc_jaxpr.consts,
*flat_args,
Expand Down
7 changes: 3 additions & 4 deletions pennylane/capture/explanations.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,11 @@ You can also see the const variable `a` as argument `e:i32[]` to the inner neste
### Pytree handling

Evaluating a jaxpr requires accepting and returning a flat list of tensor-like inputs and outputs.
list of tensor-like outputs. These long lists can be hard to manage and are very
restrictive on the allowed functions, but we can take advantage of pytrees to allow handling
arbitrary functions.
These long lists can be hard to manage and are very restrictive on the allowed functions, but we
can take advantage of pytrees to allow handling arbitrary functions.

To start, we import the `FlatFn` helper. This class converts a function to one that caches
the resulting result pytree into `flat_fn.out_tree` when executed. This can be used to repack the
the result pytree into `flat_fn.out_tree` when executed. This can be used to repack the
results into the correct shape. It also returns flattened results. This does not particularly
matter for program capture, as we will only be producing jaxpr from the function, not calling
it directly.
Expand Down
31 changes: 28 additions & 3 deletions pennylane/capture/flatfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,48 @@ class FlatFn:
property, so that the results can be repacked later. It also returns flattened results
instead of the original result object.

If an ``in_tree`` is provided, the function accepts flattened inputs instead of the
original inputs with tree structure given by ``in_tree``.

**Example**

>>> import jax
>>> from pennylane.capture.flatfn import FlatFn
>>> def f(x):
... return {"y": 2+x["x"]}
>>> flat_f = FlatFn(f)
>>> res = flat_f({"x": 0})
>>> arg = {"x": 0.5}
>>> res = flat_f(arg)
>>> res
[2.5]
>>> jax.tree_util.tree_unflatten(flat_f.out_tree, res)
{'y': 2.5}

If we want to use a fully flattened function that also takes flat inputs instead of
the original inputs with tree structure, we can provide the treedef for this input
structure:

>>> flat_args, in_tree = jax.tree_util.tree_flatten((arg,))
>>> flat_f = FlatFn(f, in_tree)
>>> res = flat_f(*flat_args)
>>> res
[2]
[2.5]
>>> jax.tree_util.tree_unflatten(flat_f.out_tree, res)
{'y': 2.5}

Note that the ``in_tree`` has to be created by flattening a tuple of all input
arguments, even if there is only a single argument.
"""

def __init__(self, f):
def __init__(self, f, in_tree=None):
self.f = f
self.in_tree = in_tree
self.out_tree = None
update_wrapper(self, f)

def __call__(self, *args):
if self.in_tree is not None:
args = jax.tree_util.tree_unflatten(self.in_tree, args)
out = self.f(*args)
out_flat, out_tree = jax.tree_util.tree_flatten(out)
self.out_tree = out_tree
Expand Down
Loading
Loading