Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Program Capture] Add pytree support to captured qml.grad and qml.jacobian #6134

Merged
merged 90 commits into from
Sep 15, 2024
Merged
Show file tree
Hide file tree
Changes from 83 commits
Commits
Show all changes
90 commits
Select commit Hold shift + click to select a range
206ca6a
first prototype
dwierichs Aug 21, 2024
a466a83
first tests
dwierichs Aug 21, 2024
8b402f7
cleanup
dwierichs Aug 21, 2024
9e8f4c9
-a
dwierichs Aug 21, 2024
cfae262
changelog, lint
dwierichs Aug 21, 2024
ed8f515
add tests...
dwierichs Aug 21, 2024
79fe790
move primitive
dwierichs Aug 22, 2024
b90fa2b
import
dwierichs Aug 22, 2024
b9b6536
lint
dwierichs Aug 22, 2024
ff54b86
lint
dwierichs Aug 22, 2024
fde341f
parshift test
dwierichs Aug 22, 2024
52ccec1
first port
dwierichs Aug 22, 2024
3cd33f1
prepare jac
dwierichs Aug 22, 2024
7f8a8ae
Merge branch 'capture-grad' into capture-jacobian
dwierichs Aug 22, 2024
c20e7d3
classical test
dwierichs Aug 22, 2024
a087e53
make grad differentiable
dwierichs Aug 22, 2024
393aa35
nested grad test
dwierichs Aug 22, 2024
f1a6b4d
lint
dwierichs Aug 22, 2024
b86619b
Merge branch 'master' into capture-grad
dwierichs Aug 22, 2024
f678a91
lint
dwierichs Aug 22, 2024
373790b
Merge branch 'capture-grad' into capture-jacobian
dwierichs Aug 22, 2024
cd7fe86
nested jac test
dwierichs Aug 22, 2024
9c58349
nested jac test
dwierichs Aug 22, 2024
60e4d33
fix
dwierichs Aug 22, 2024
bc055fb
jacobian qnode test
dwierichs Aug 22, 2024
cd7d128
fix
dwierichs Aug 22, 2024
3a43436
fix again
dwierichs Aug 22, 2024
15ed932
Merge branch 'master' into capture-grad
dwierichs Aug 22, 2024
ed2b22a
implement pytree support, start testing
dwierichs Aug 23, 2024
de82838
test grad conclusion
dwierichs Aug 23, 2024
de67faa
flatfn mod
dwierichs Aug 23, 2024
692db13
Merge branch 'master' into capture-grad
dwierichs Aug 23, 2024
6ad098e
Merge branch 'capture-grad' into capture-jacobian
dwierichs Aug 23, 2024
7ab7201
changelog
dwierichs Aug 23, 2024
e2acc09
changelog
dwierichs Aug 23, 2024
896431b
changelog
dwierichs Aug 23, 2024
75a26b9
Merge branch 'capture-grad' into capture-grad-pytrees
dwierichs Aug 23, 2024
0e407c3
Merge branch 'master' into capture-grad
dwierichs Aug 26, 2024
ea149a2
Merge branch 'capture-grad' into capture-jacobian
dwierichs Aug 26, 2024
f877674
Merge branch 'capture-grad' into capture-grad-pytrees
dwierichs Aug 26, 2024
f4bb23e
commentary, treedef_tuple trick
dwierichs Aug 26, 2024
93bda87
method and h allowed in capture
dwierichs Aug 26, 2024
5bb2900
higher order primitive tests
dwierichs Aug 26, 2024
10b2899
Merge branch 'master' into capture-grad
dwierichs Aug 27, 2024
c838b4a
Apply suggestions from code review
dwierichs Aug 27, 2024
67bdeb8
Merge branch 'master' into capture-grad
dwierichs Aug 27, 2024
cb43c96
Merge branch 'master' into capture-grad
dwierichs Aug 28, 2024
7ae7415
kwargs, test structure
dwierichs Aug 28, 2024
4b3ac68
merge
dwierichs Aug 28, 2024
aa3876e
merge
dwierichs Aug 29, 2024
389c22d
merge more
dwierichs Aug 29, 2024
de782b8
lint more
dwierichs Aug 29, 2024
a9a4472
add file
dwierichs Aug 29, 2024
769eb98
[skip ci]
dwierichs Aug 29, 2024
43628ab
merge [skip ci]
dwierichs Aug 29, 2024
6c60a08
format [skip ci]
dwierichs Aug 29, 2024
9269f9c
-m
dwierichs Aug 29, 2024
9a4c580
Merge branch 'master' into capture-grad
dwierichs Sep 3, 2024
e79d4a8
merge
dwierichs Sep 5, 2024
552417b
while_loop
dwierichs Sep 5, 2024
81ab600
import fix
dwierichs Sep 5, 2024
63d82f5
Merge branch 'capture-grad' into capture-jacobian
dwierichs Sep 5, 2024
1af5b97
lint
dwierichs Sep 5, 2024
c3fbd78
fix import
dwierichs Sep 5, 2024
e73f2fb
import and skip order
dwierichs Sep 6, 2024
3d2cb89
Merge branch 'master' into capture-grad
dwierichs Sep 6, 2024
724675f
[skip ci]
dwierichs Sep 6, 2024
9fda181
Merge branch 'master' into capture-grad
dwierichs Sep 9, 2024
3513487
merge
dwierichs Sep 9, 2024
db99812
lint
dwierichs Sep 9, 2024
c59da75
merge
dwierichs Sep 9, 2024
f2aead1
merge
dwierichs Sep 9, 2024
a3e5eee
merge
dwierichs Sep 9, 2024
6800365
tmp
dwierichs Sep 9, 2024
5aa57e6
merge
dwierichs Sep 9, 2024
4222774
Merge branch 'capture-jacobian' into capture-grad-pytrees
dwierichs Sep 9, 2024
4dc9582
Merge branch 'master' into capture-jacobian
dwierichs Sep 9, 2024
76f91ce
extend pytree support to Jacobian
dwierichs Sep 9, 2024
b2b9604
Merge branch 'master' into capture-jacobian
dwierichs Sep 9, 2024
9045e38
Merge branch 'capture-jacobian' into capture-grad-pytrees
dwierichs Sep 9, 2024
df7eb06
Merge branch 'master' into capture-grad-pytrees
dwierichs Sep 10, 2024
2718606
lint
dwierichs Sep 10, 2024
14657b8
Apply suggestions from code review
dwierichs Sep 10, 2024
2b86f7c
Merge branch 'master' into capture-grad-pytrees
dwierichs Sep 11, 2024
5774099
-a
dwierichs Sep 11, 2024
a56437b
bugfix
dwierichs Sep 11, 2024
2055107
-a
dwierichs Sep 11, 2024
0c5c62e
flatfn to module
dwierichs Sep 11, 2024
50093c4
Merge branch 'master' into capture-grad-pytrees
dwierichs Sep 13, 2024
c2c877e
Merge branch 'master' into capture-grad-pytrees
dwierichs Sep 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@
* Differentiation of hybrid programs via `qml.grad` and `qml.jacobian` can now be captured
into plxpr. When evaluating a captured `qml.grad` (`qml.jacobian`) instruction, it will
dispatch to `jax.grad` (`jax.jacobian`), which differs from the Autograd implementation
without capture.
without capture. Pytree inputs and outputs are supported.
[(#6120)](https://github.com/PennyLaneAI/pennylane/pull/6120)
[(#6127)](https://github.com/PennyLaneAI/pennylane/pull/6127)
[(#6134)](https://github.com/PennyLaneAI/pennylane/pull/6134)

* Improve unit testing for capturing of nested control flows.
[(#6111)](https://github.com/PennyLaneAI/pennylane/pull/6111)
Expand Down
40 changes: 36 additions & 4 deletions pennylane/_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from pennylane.capture import enabled
from pennylane.capture.capture_diff import _get_grad_prim, _get_jacobian_prim
from pennylane.capture.flatfn import FlatFn
from pennylane.compiler import compiler
from pennylane.compiler.compiler import CompileError

Expand All @@ -33,7 +34,9 @@

def _capture_diff(func, argnum=None, diff_prim=None, method=None, h=None):
"""Capture-compatible gradient computation."""
import jax # pylint: disable=import-outside-toplevel
# pylint: disable=import-outside-toplevel
import jax
from jax.tree_util import tree_flatten, tree_unflatten, treedef_tuple

if isinstance(argnum, int):
argnum = [argnum]
Expand All @@ -42,9 +45,38 @@ def _capture_diff(func, argnum=None, diff_prim=None, method=None, h=None):

@wraps(func)
def new_func(*args, **kwargs):
jaxpr = jax.make_jaxpr(partial(func, **kwargs))(*args)
prim_kwargs = {"argnum": argnum, "jaxpr": jaxpr.jaxpr, "n_consts": len(jaxpr.consts)}
return diff_prim.bind(*jaxpr.consts, *args, **prim_kwargs, method=method, h=h)
flat_args, in_trees = zip(*(tree_flatten(arg) for arg in args))
full_in_tree = treedef_tuple(in_trees)

# Create a new input tree that only takes inputs marked by argnum into account
trainable_in_trees = (in_tree for i, in_tree in enumerate(in_trees) if i in argnum)
trainable_in_tree = treedef_tuple(trainable_in_trees)

# Create argnum for the flat list of input arrays. For each flattened argument,
# add a list of flat argnums if the argument is trainable and an empty list otherwise.
start = 0
flat_argnum_gen = (
(
list(range(start, (start := start + len(flat_arg))))
if i in argnum
else list(range((start := start + len(flat_arg)), start))
)
for i, flat_arg in enumerate(flat_args)
)
flat_argnum = sum(flat_argnum_gen, start=[])

# Create fully flattened function (flat inputs & outputs)
flat_fn = FlatFn(partial(func, **kwargs) if kwargs else func, full_in_tree)
flat_args = sum(flat_args, start=[])
jaxpr = jax.make_jaxpr(flat_fn)(*flat_args)
prim_kwargs = {"argnum": flat_argnum, "jaxpr": jaxpr.jaxpr, "n_consts": len(jaxpr.consts)}
out_flat = diff_prim.bind(*jaxpr.consts, *flat_args, **prim_kwargs, method=method, h=h)
# flatten once more to go from 2D derivative structure (outputs, args) to flat structure
out_flat, _ = tree_flatten(out_flat)
assert flat_fn.out_tree is not None, "out_tree should be set after executing flat_fn"
# The derivative output tree is the composition of output tree and trainable input trees
combined_tree = flat_fn.out_tree.compose(trainable_in_tree)
return tree_unflatten(combined_tree, out_flat)

return new_func

Expand Down
2 changes: 1 addition & 1 deletion pennylane/capture/capture_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def _(*args, argnum, jaxpr, n_consts, method, h):
def func(*inner_args):
return jax.core.eval_jaxpr(jaxpr, consts, *inner_args)

return jax.jacobian(func, argnums=argnum)(*args)
return jax.tree_util.tree_flatten(jax.jacobian(func, argnums=argnum)(*args))[0]
dwierichs marked this conversation as resolved.
Show resolved Hide resolved

# pylint: disable=unused-argument
@jacobian_prim.def_abstract_eval
Expand Down
7 changes: 3 additions & 4 deletions pennylane/capture/explanations.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,11 @@ You can also see the const variable `a` as argument `e:i32[]` to the inner neste
### Pytree handling

Evaluating a jaxpr requires accepting and returning a flat list of tensor-like inputs and outputs.
list of tensor-like outputs. These long lists can be hard to manage and are very
restrictive on the allowed functions, but we can take advantage of pytrees to allow handling
arbitrary functions.
These long lists can be hard to manage and are very restrictive on the allowed functions, but we
can take advantage of pytrees to allow handling arbitrary functions.

To start, we import the `FlatFn` helper. This class converts a function to one that caches
the resulting result pytree into `flat_fn.out_tree` when executed. This can be used to repack the
the result pytree into `flat_fn.out_tree` when executed. This can be used to repack the
results into the correct shape. It also returns flattened results. This does not particularly
matter for program capture, as we will only be producing jaxpr from the function, not calling
it directly.
Expand Down
31 changes: 28 additions & 3 deletions pennylane/capture/flatfn.py
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,48 @@ class FlatFn:
property, so that the results can be repacked later. It also returns flattened results
instead of the original result object.

If an ``in_tree`` is provided, the function accepts flattened inputs instead of the
original inputs with tree structure given by ``in_tree``.

**Example**

>>> import jax
>>> from pennylane.capture.flatfn import FlatFn
>>> def f(x):
... return {"y": 2+x["x"]}
>>> flat_f = FlatFn(f)
>>> res = flat_f({"x": 0})
>>> arg = {"x": 0.5}
>>> res = flat_f(arg)
>>> res
[2.5]
>>> jax.tree_util.tree_unflatten(flat_f.out_tree, res)
{'y': 2.5}

If we want to use a fully flattened function that also takes flat inputs instead of
the original inputs with tree structure, we can provide the treedef for this input
structure:

>>> flat_args, in_tree = jax.tree_util.tree_flatten((arg,))
>>> flat_f = FlatFn(f, in_tree)
>>> res = flat_f(*flat_args)
>>> res
[2]
[2.5]
>>> jax.tree_util.tree_unflatten(flat_f.out_tree, res)
{'y': 2.5}

Note that the ``in_tree`` has to be created by flattening a tuple of all input
arguments, even if there is only a single argument.
"""

def __init__(self, f):
def __init__(self, f, in_tree=None):
self.f = f
self.in_tree = in_tree
self.out_tree = None
update_wrapper(self, f)

def __call__(self, *args):
if self.in_tree is not None:
args = jax.tree_util.tree_unflatten(self.in_tree, args)
out = self.f(*args)
out_flat, out_tree = jax.tree_util.tree_flatten(out)
self.out_tree = out_tree
Expand Down
120 changes: 114 additions & 6 deletions tests/capture/test_capture_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,55 @@ def circuit(x):

jax.config.update("jax_enable_x64", initial_mode)

@pytest.mark.parametrize("argnum", ([0, 1], [0], [1]))
def test_grad_pytree_input(self, x64_mode, argnum):
"""Test that the qml.grad primitive can be captured with pytree inputs."""

initial_mode = jax.config.jax_enable_x64
jax.config.update("jax_enable_x64", x64_mode)
fdtype = jax.numpy.float64 if x64_mode else jax.numpy.float32

def inner_func(x, y):
return jnp.prod(jnp.sin(x["a"]) * jnp.cos(y[0]["b"][1]) ** 2)

def func_qml(x):
return qml.grad(inner_func, argnum=argnum)(
{"a": x}, ({"b": [None, 0.4 * jnp.sqrt(x)]},)
)

def func_jax(x):
return jax.grad(inner_func, argnums=argnum)(
{"a": x}, ({"b": [None, 0.4 * jnp.sqrt(x)]},)
)

x = 0.7
jax_out = func_jax(x)
jax_out_flat, jax_out_tree = jax.tree_util.tree_flatten(jax_out)
qml_out_flat, qml_out_tree = jax.tree_util.tree_flatten(func_qml(x))
assert jax_out_tree == qml_out_tree
assert qml.math.allclose(jax_out_flat, qml_out_flat)

# Check overall jaxpr properties
if isinstance(argnum, int):
argnum = [argnum]
jaxpr = jax.make_jaxpr(func_qml)(x)
assert jaxpr.in_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)]
assert len(jaxpr.eqns) == 3
assert jaxpr.out_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)] * len(argnum)

grad_eqn = jaxpr.eqns[2]
diff_eqn_assertions(grad_eqn, grad_prim, argnum=argnum)
assert [var.aval for var in grad_eqn.outvars] == jaxpr.out_avals
assert len(grad_eqn.params["jaxpr"].eqns) == 6 # 5 numeric eqns, 1 conversion eqn

manual_out = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x)
manual_out_flat, manual_out_tree = jax.tree_util.tree_flatten(manual_out)
# Assert that the output from the manual evaluation is flat
assert manual_out_tree == jax.tree_util.tree_flatten(manual_out_flat)[1]
assert qml.math.allclose(jax_out_flat, manual_out_flat)

jax.config.update("jax_enable_x64", initial_mode)


def _jac_allclose(jac1, jac2, num_axes, atol=1e-8):
"""Test that two Jacobians, given as nested sequences of arrays, are equal."""
Expand All @@ -279,10 +328,6 @@ class TestJacobian:
@pytest.mark.parametrize("argnum", ([0, 1], [0], [1], 0, 1))
def test_classical_jacobian(self, x64_mode, argnum):
"""Test that the qml.jacobian primitive can be captured with classical nodes."""
if isinstance(argnum, list) and len(argnum) > 1:
# These cases will only be unlocked with Pytree support
pytest.xfail()

initial_mode = jax.config.jax_enable_x64
jax.config.update("jax_enable_x64", x64_mode)
fdtype = jnp.float64 if x64_mode else jnp.float32
Expand All @@ -307,14 +352,14 @@ def inner_func(x, y):
func_jax = jax.jacobian(inner_func, argnums=argnum)

jax_out = func_jax(x, y)
num_axes = 1 if isinstance(argnum, int) else 2
num_axes = 1 if (int_argnum := isinstance(argnum, int)) else 2
assert _jac_allclose(func_qml(x, y), jax_out, num_axes)

# Check overall jaxpr properties
jaxpr = jax.make_jaxpr(func_jax)(x, y)
jaxpr = jax.make_jaxpr(func_qml)(x, y)

if isinstance(argnum, int):
if int_argnum:
argnum = [argnum]

exp_in_avals = [shaped_array(shape) for shape in [(4,), (2, 3)]]
Expand All @@ -331,6 +376,9 @@ def inner_func(x, y):
diff_eqn_assertions(jac_eqn, jacobian_prim, argnum=argnum)

manual_eval = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x, y)
# Evaluating jaxpr gives flat list results. Need to adapt the JAX output to that
if not int_argnum:
jax_out = sum(jax_out, start=())
assert _jac_allclose(manual_eval, jax_out, num_axes)

jax.config.update("jax_enable_x64", initial_mode)
Expand Down Expand Up @@ -473,3 +521,63 @@ def circuit(x):
assert _jac_allclose(manual_res, expected_res, 1)

jax.config.update("jax_enable_x64", initial_mode)

@pytest.mark.parametrize("argnum", ([0, 1], [0], [1]))
def test_jacobian_pytrees(self, x64_mode, argnum):
"""Test that the qml.jacobian primitive can be captured with
pytree inputs and outputs."""

initial_mode = jax.config.jax_enable_x64
jax.config.update("jax_enable_x64", x64_mode)
fdtype = jax.numpy.float64 if x64_mode else jax.numpy.float32

def inner_func(x, y):
return {
"prod_cos": jnp.prod(jnp.sin(x["a"]) * jnp.cos(y[0]["b"][1]) ** 2),
"sum_sin": jnp.sum(jnp.sin(x["a"]) * jnp.sin(y[1]["c"]) ** 2),
}

def func_qml(x):
return qml.jacobian(inner_func, argnum=argnum)(
{"a": x}, ({"b": [None, 0.4 * jnp.sqrt(x)]}, {"c": 0.5})
)
dwierichs marked this conversation as resolved.
Show resolved Hide resolved

def func_jax(x):
return jax.jacobian(inner_func, argnums=argnum)(
{"a": x}, ({"b": [None, 0.4 * jnp.sqrt(x)]}, {"c": 0.5})
)

x = 0.7
jax_out = func_jax(x)
jax_out_flat, jax_out_tree = jax.tree_util.tree_flatten(jax_out)
qml_out_flat, qml_out_tree = jax.tree_util.tree_flatten(func_qml(x))
assert jax_out_tree == qml_out_tree
assert qml.math.allclose(jax_out_flat, qml_out_flat)

# Check overall jaxpr properties
if isinstance(argnum, int):
argnum = [argnum]
jaxpr = jax.make_jaxpr(func_qml)(x)
assert jaxpr.in_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)]
assert len(jaxpr.eqns) == 3

# Compute the flat argnum in order to determine the expected number of out tracers
flat_argnum = []
if 0 in argnum:
flat_argnum.append(0)
if 1 in argnum:
flat_argnum.extend([1, 2])
assert jaxpr.out_avals == [jax.core.ShapedArray((), fdtype)] * (2 * len(flat_argnum))

jac_eqn = jaxpr.eqns[2]

diff_eqn_assertions(jac_eqn, jacobian_prim, argnum=flat_argnum)
assert [var.aval for var in jac_eqn.outvars] == jaxpr.out_avals

manual_out = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x)
manual_out_flat, manual_out_tree = jax.tree_util.tree_flatten(manual_out)
# Assert that the output from the manual evaluation is flat
assert manual_out_tree == jax.tree_util.tree_flatten(manual_out_flat)[1]
assert qml.math.allclose(jax_out_flat, manual_out_flat)

jax.config.update("jax_enable_x64", initial_mode)
Loading