diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 971352772de..506504f631b 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -28,9 +28,10 @@ * Differentiation of hybrid programs via `qml.grad` and `qml.jacobian` can now be captured into plxpr. When evaluating a captured `qml.grad` (`qml.jacobian`) instruction, it will dispatch to `jax.grad` (`jax.jacobian`), which differs from the Autograd implementation - without capture. + without capture. Pytree inputs and outputs are supported. [(#6120)](https://github.com/PennyLaneAI/pennylane/pull/6120) [(#6127)](https://github.com/PennyLaneAI/pennylane/pull/6127) + [(#6134)](https://github.com/PennyLaneAI/pennylane/pull/6134) * Improve unit testing for capturing of nested control flows. [(#6111)](https://github.com/PennyLaneAI/pennylane/pull/6111) diff --git a/pennylane/_grad.py b/pennylane/_grad.py index 859ae5d9fbb..d7cb52e0d52 100644 --- a/pennylane/_grad.py +++ b/pennylane/_grad.py @@ -25,6 +25,7 @@ from pennylane.capture import enabled from pennylane.capture.capture_diff import _get_grad_prim, _get_jacobian_prim +from pennylane.capture.flatfn import FlatFn from pennylane.compiler import compiler from pennylane.compiler.compiler import CompileError @@ -33,18 +34,53 @@ def _capture_diff(func, argnum=None, diff_prim=None, method=None, h=None): """Capture-compatible gradient computation.""" - import jax # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel + import jax + from jax.tree_util import tree_flatten, tree_leaves, tree_unflatten, treedef_tuple - if isinstance(argnum, int): - argnum = [argnum] if argnum is None: - argnum = [0] + argnum = 0 + if argnum_is_int := isinstance(argnum, int): + argnum = [argnum] @wraps(func) def new_func(*args, **kwargs): - jaxpr = jax.make_jaxpr(partial(func, **kwargs))(*args) - prim_kwargs = {"argnum": argnum, "jaxpr": jaxpr.jaxpr, "n_consts": len(jaxpr.consts)} - return diff_prim.bind(*jaxpr.consts, *args, **prim_kwargs, method=method, h=h) + flat_args, in_trees = zip(*(tree_flatten(arg) for arg in args)) + full_in_tree = treedef_tuple(in_trees) + + # Create a new input tree that only takes inputs marked by argnum into account + trainable_in_trees = (in_tree for i, in_tree in enumerate(in_trees) if i in argnum) + # If an integer was provided as argnum, unpack the arguments axis of the derivatives + if argnum_is_int: + trainable_in_tree = list(trainable_in_trees)[0] + else: + trainable_in_tree = treedef_tuple(trainable_in_trees) + + # Create argnum for the flat list of input arrays. For each flattened argument, + # add a list of flat argnums if the argument is trainable and an empty list otherwise. + start = 0 + flat_argnum_gen = ( + ( + list(range(start, (start := start + len(flat_arg)))) + if i in argnum + else list(range((start := start + len(flat_arg)), start)) + ) + for i, flat_arg in enumerate(flat_args) + ) + flat_argnum = sum(flat_argnum_gen, start=[]) + + # Create fully flattened function (flat inputs & outputs) + flat_fn = FlatFn(partial(func, **kwargs) if kwargs else func, full_in_tree) + flat_args = sum(flat_args, start=[]) + jaxpr = jax.make_jaxpr(flat_fn)(*flat_args) + prim_kwargs = {"argnum": flat_argnum, "jaxpr": jaxpr.jaxpr, "n_consts": len(jaxpr.consts)} + out_flat = diff_prim.bind(*jaxpr.consts, *flat_args, **prim_kwargs, method=method, h=h) + # flatten once more to go from 2D derivative structure (outputs, args) to flat structure + out_flat = tree_leaves(out_flat) + assert flat_fn.out_tree is not None, "out_tree should be set after executing flat_fn" + # The derivative output tree is the composition of output tree and trainable input trees + combined_tree = flat_fn.out_tree.compose(trainable_in_tree) + return tree_unflatten(combined_tree, out_flat) return new_func diff --git a/pennylane/_version.py b/pennylane/_version.py index a6ead820881..77639685bc6 100644 --- a/pennylane/_version.py +++ b/pennylane/_version.py @@ -16,4 +16,4 @@ Version number (major.minor.patch[-label]) """ -__version__ = "0.39.0-dev14" +__version__ = "0.39.0-dev15" diff --git a/pennylane/capture/__init__.py b/pennylane/capture/__init__.py index 6deeef29682..2e43a246e2c 100644 --- a/pennylane/capture/__init__.py +++ b/pennylane/capture/__init__.py @@ -34,6 +34,7 @@ ~create_measurement_wires_primitive ~create_measurement_mcm_primitive ~qnode_call + ~FlatFn The ``primitives`` submodule offers easy access to objects with jax dependencies such as @@ -154,6 +155,7 @@ def _(*args, **kwargs): create_measurement_mcm_primitive, ) from .capture_qnode import qnode_call +from .flatfn import FlatFn # by defining this here, we avoid # E0611: No name 'AbstractOperator' in module 'pennylane.capture' (no-name-in-module) @@ -196,4 +198,5 @@ def __getattr__(key): "AbstractOperator", "AbstractMeasurement", "qnode_prim", + "FlatFn", ) diff --git a/pennylane/capture/capture_diff.py b/pennylane/capture/capture_diff.py index 92dde5a2956..829a9516af1 100644 --- a/pennylane/capture/capture_diff.py +++ b/pennylane/capture/capture_diff.py @@ -103,7 +103,7 @@ def _(*args, argnum, jaxpr, n_consts, method, h): def func(*inner_args): return jax.core.eval_jaxpr(jaxpr, consts, *inner_args) - return jax.jacobian(func, argnums=argnum)(*args) + return jax.tree_util.tree_leaves(jax.jacobian(func, argnums=argnum)(*args)) # pylint: disable=unused-argument @jacobian_prim.def_abstract_eval diff --git a/pennylane/capture/capture_qnode.py b/pennylane/capture/capture_qnode.py index f7f06451230..491b9f3f6a4 100644 --- a/pennylane/capture/capture_qnode.py +++ b/pennylane/capture/capture_qnode.py @@ -82,8 +82,12 @@ def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts): mps = qfunc_jaxpr.outvars return _get_shapes_for(*mps, shots=shots, num_device_wires=len(device.wires)) - def _qnode_jvp(*args_and_tangents, **impl_kwargs): - return jax.jvp(partial(qnode_prim.impl, **impl_kwargs), *args_and_tangents) + def make_zero(tan, arg): + return jax.lax.zeros_like_array(arg) if isinstance(tan, ad.Zero) else tan + + def _qnode_jvp(args, tangents, **impl_kwargs): + tangents = tuple(map(make_zero, tangents, args)) + return jax.jvp(partial(qnode_prim.impl, **impl_kwargs), args, tangents) ad.primitive_jvps[qnode_prim] = _qnode_jvp @@ -174,7 +178,7 @@ def f(x): qnode_kwargs = {"diff_method": qnode.diff_method, **execute_kwargs, **mcm_config} qnode_prim = _get_qnode_prim() - flat_args, _ = jax.tree_util.tree_flatten(args) + flat_args = jax.tree_util.tree_leaves(args) res = qnode_prim.bind( *qfunc_jaxpr.consts, *flat_args, diff --git a/pennylane/capture/explanations.md b/pennylane/capture/explanations.md index 61b88b94030..84feef9786f 100644 --- a/pennylane/capture/explanations.md +++ b/pennylane/capture/explanations.md @@ -159,12 +159,11 @@ You can also see the const variable `a` as argument `e:i32[]` to the inner neste ### Pytree handling Evaluating a jaxpr requires accepting and returning a flat list of tensor-like inputs and outputs. -list of tensor-like outputs. These long lists can be hard to manage and are very -restrictive on the allowed functions, but we can take advantage of pytrees to allow handling -arbitrary functions. +These long lists can be hard to manage and are very restrictive on the allowed functions, but we +can take advantage of pytrees to allow handling arbitrary functions. To start, we import the `FlatFn` helper. This class converts a function to one that caches -the resulting result pytree into `flat_fn.out_tree` when executed. This can be used to repack the +the result pytree into `flat_fn.out_tree` when executed. This can be used to repack the results into the correct shape. It also returns flattened results. This does not particularly matter for program capture, as we will only be producing jaxpr from the function, not calling it directly. diff --git a/pennylane/capture/flatfn.py b/pennylane/capture/flatfn.py index 4ba6005f41e..59e1c52b948 100644 --- a/pennylane/capture/flatfn.py +++ b/pennylane/capture/flatfn.py @@ -29,23 +29,48 @@ class FlatFn: property, so that the results can be repacked later. It also returns flattened results instead of the original result object. + If an ``in_tree`` is provided, the function accepts flattened inputs instead of the + original inputs with tree structure given by ``in_tree``. + + **Example** + + >>> import jax + >>> from pennylane.capture.flatfn import FlatFn >>> def f(x): ... return {"y": 2+x["x"]} >>> flat_f = FlatFn(f) - >>> res = flat_f({"x": 0}) + >>> arg = {"x": 0.5} + >>> res = flat_f(arg) + >>> res + [2.5] + >>> jax.tree_util.tree_unflatten(flat_f.out_tree, res) + {'y': 2.5} + + If we want to use a fully flattened function that also takes flat inputs instead of + the original inputs with tree structure, we can provide the treedef for this input + structure: + + >>> flat_args, in_tree = jax.tree_util.tree_flatten((arg,)) + >>> flat_f = FlatFn(f, in_tree) + >>> res = flat_f(*flat_args) >>> res - [2] + [2.5] >>> jax.tree_util.tree_unflatten(flat_f.out_tree, res) {'y': 2.5} + Note that the ``in_tree`` has to be created by flattening a tuple of all input + arguments, even if there is only a single argument. """ - def __init__(self, f): + def __init__(self, f, in_tree=None): self.f = f + self.in_tree = in_tree self.out_tree = None update_wrapper(self, f) def __call__(self, *args): + if self.in_tree is not None: + args = jax.tree_util.tree_unflatten(self.in_tree, args) out = self.f(*args) out_flat, out_tree = jax.tree_util.tree_flatten(out) self.out_tree = out_tree diff --git a/tests/capture/test_capture_diff.py b/tests/capture/test_capture_diff.py index cf7834aafb8..07596e01b47 100644 --- a/tests/capture/test_capture_diff.py +++ b/tests/capture/test_capture_diff.py @@ -99,11 +99,11 @@ def func_jax(x): assert qml.math.allclose(func_qml(x), jax_out) # Check overall jaxpr properties - if isinstance(argnum, int): - argnum = [argnum] jaxpr = jax.make_jaxpr(func_qml)(x) assert jaxpr.in_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)] assert len(jaxpr.eqns) == 3 + if isinstance(argnum, int): + argnum = [argnum] assert jaxpr.out_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)] * len(argnum) grad_eqn = jaxpr.eqns[2] @@ -260,6 +260,106 @@ def circuit(x): jax.config.update("jax_enable_x64", initial_mode) + @pytest.mark.parametrize("argnum", ([0, 1], [0], [1])) + def test_grad_pytree_input(self, x64_mode, argnum): + """Test that the qml.grad primitive can be captured with pytree inputs.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + fdtype = jax.numpy.float64 if x64_mode else jax.numpy.float32 + + def inner_func(x, y): + return jnp.prod(jnp.sin(x["a"]) * jnp.cos(y[0]["b"][1]) ** 2) + + def func_qml(x): + return qml.grad(inner_func, argnum=argnum)( + {"a": x}, ({"b": [None, 0.4 * jnp.sqrt(x)]},) + ) + + def func_jax(x): + return jax.grad(inner_func, argnums=argnum)( + {"a": x}, ({"b": [None, 0.4 * jnp.sqrt(x)]},) + ) + + x = 0.7 + jax_out = func_jax(x) + jax_out_flat, jax_out_tree = jax.tree_util.tree_flatten(jax_out) + qml_out_flat, qml_out_tree = jax.tree_util.tree_flatten(func_qml(x)) + assert jax_out_tree == qml_out_tree + assert qml.math.allclose(jax_out_flat, qml_out_flat) + + # Check overall jaxpr properties + jaxpr = jax.make_jaxpr(func_qml)(x) + assert jaxpr.in_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)] + assert len(jaxpr.eqns) == 3 + argnum = [argnum] if isinstance(argnum, int) else argnum + assert jaxpr.out_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)] * len(argnum) + + grad_eqn = jaxpr.eqns[2] + diff_eqn_assertions(grad_eqn, grad_prim, argnum=argnum) + assert [var.aval for var in grad_eqn.outvars] == jaxpr.out_avals + assert len(grad_eqn.params["jaxpr"].eqns) == 6 # 5 numeric eqns, 1 conversion eqn + + manual_out = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x) + manual_out_flat, manual_out_tree = jax.tree_util.tree_flatten(manual_out) + # Assert that the output from the manual evaluation is flat + assert manual_out_tree == jax.tree_util.tree_flatten(manual_out_flat)[1] + assert qml.math.allclose(jax_out_flat, manual_out_flat) + + jax.config.update("jax_enable_x64", initial_mode) + + @pytest.mark.parametrize("argnum", ([0, 1, 2], [0, 2], [1], 0)) + def test_grad_qnode_with_pytrees(self, argnum, x64_mode): + """Test capturing the gradient of a qnode that uses Pytrees.""" + # pylint: disable=protected-access + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + fdtype = jax.numpy.float64 if x64_mode else jax.numpy.float32 + + dev = qml.device("default.qubit", wires=2) + + @qml.qnode(dev) + def circuit(x, y, z): + qml.RX(x["a"], wires=0) + qml.RY(y, wires=0) + qml.RZ(z[1][0], wires=0) + return qml.expval(qml.X(0)) + + dcircuit = qml.grad(circuit, argnum=argnum) + x = {"a": 0.6, "b": 0.9} + y = 0.6 + z = ({"c": 0.5}, [0.2, 0.3]) + qml_out = dcircuit(x, y, z) + qml_out_flat, qml_out_tree = jax.tree_util.tree_flatten(qml_out) + jax_out = jax.grad(circuit, argnums=argnum)(x, y, z) + jax_out_flat, jax_out_tree = jax.tree_util.tree_flatten(jax_out) + assert jax_out_tree == qml_out_tree + assert qml.math.allclose(jax_out_flat, qml_out_flat) + + jaxpr = jax.make_jaxpr(dcircuit)(x, y, z) + + assert len(jaxpr.eqns) == 1 # grad equation + assert jaxpr.in_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)] * 6 + argnum = [argnum] if isinstance(argnum, int) else argnum + num_out_avals = 2 * (0 in argnum) + (1 in argnum) + 3 * (2 in argnum) + assert jaxpr.out_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)] * num_out_avals + + grad_eqn = jaxpr.eqns[0] + assert all(invar.aval == in_aval for invar, in_aval in zip(grad_eqn.invars, jaxpr.in_avals)) + flat_argnum = [0, 1] * (0 in argnum) + [2] * (1 in argnum) + [3, 4, 5] * (2 in argnum) + diff_eqn_assertions(grad_eqn, grad_prim, argnum=flat_argnum) + grad_jaxpr = grad_eqn.params["jaxpr"] + assert len(grad_jaxpr.eqns) == 1 # qnode equation + + flat_args = jax.tree_util.tree_leaves((x, y, z)) + manual_out = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *flat_args) + manual_out_flat, manual_out_tree = jax.tree_util.tree_flatten(manual_out) + # Assert that the output from the manual evaluation is flat + assert manual_out_tree == jax.tree_util.tree_flatten(manual_out_flat)[1] + assert qml.math.allclose(jax_out_flat, manual_out_flat) + + jax.config.update("jax_enable_x64", initial_mode) + def _jac_allclose(jac1, jac2, num_axes, atol=1e-8): """Test that two Jacobians, given as nested sequences of arrays, are equal.""" @@ -279,10 +379,6 @@ class TestJacobian: @pytest.mark.parametrize("argnum", ([0, 1], [0], [1], 0, 1)) def test_classical_jacobian(self, x64_mode, argnum): """Test that the qml.jacobian primitive can be captured with classical nodes.""" - if isinstance(argnum, list) and len(argnum) > 1: - # These cases will only be unlocked with Pytree support - pytest.xfail() - initial_mode = jax.config.jax_enable_x64 jax.config.update("jax_enable_x64", x64_mode) fdtype = jnp.float64 if x64_mode else jnp.float32 @@ -307,14 +403,14 @@ def inner_func(x, y): func_jax = jax.jacobian(inner_func, argnums=argnum) jax_out = func_jax(x, y) - num_axes = 1 if isinstance(argnum, int) else 2 - assert _jac_allclose(func_qml(x, y), jax_out, num_axes) + qml_out = func_qml(x, y) + num_axes = 1 if (int_argnum := isinstance(argnum, int)) else 2 + assert _jac_allclose(qml_out, jax_out, num_axes) # Check overall jaxpr properties - jaxpr = jax.make_jaxpr(func_jax)(x, y) jaxpr = jax.make_jaxpr(func_qml)(x, y) - if isinstance(argnum, int): + if int_argnum: argnum = [argnum] exp_in_avals = [shaped_array(shape) for shape in [(4,), (2, 3)]] @@ -331,6 +427,9 @@ def inner_func(x, y): diff_eqn_assertions(jac_eqn, jacobian_prim, argnum=argnum) manual_eval = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x, y) + # Evaluating jaxpr gives flat list results. Need to adapt the JAX output to that + if not int_argnum: + jax_out = sum(jax_out, start=()) assert _jac_allclose(manual_eval, jax_out, num_axes) jax.config.update("jax_enable_x64", initial_mode) @@ -354,7 +453,6 @@ def func(x): return jnp.prod(x) * jnp.sin(x), jnp.sum(x**2) x = jnp.array([0.7, -0.9, 0.6, 0.3]) - x = x[:1] dim = len(x) eye = jnp.eye(dim) @@ -382,15 +480,20 @@ def func(x): # 2nd order qml_func_2 = qml.jacobian(qml_func_1) + hyperdiag = qml.numpy.zeros((4, 4, 4)) + for i in range(4): + hyperdiag[i, i, i] = 1 expected_2 = ( prod_sin[:, None, None] / x[None, :, None] / x[None, None, :] + - jnp.tensordot(prod_sin, eye / x**2, axes=0) # Correct diagonal entries + prod_cos_e_i[:, :, None] / x[None, None, :] + prod_cos_e_i[:, None, :] / x[None, :, None] - - jnp.tensordot(prod_sin, eye + eye / x**2, axes=0), - jnp.tensordot(jnp.ones(dim), eye * 2, axes=0), + - prod_sin * hyperdiag, + eye * 2, ) # Output only has one tuple axis - assert _jac_allclose(qml_func_2(x), expected_2, 1) + atol = 1e-8 if x64_mode else 2e-7 + assert _jac_allclose(qml_func_2(x), expected_2, 1, atol=atol) jaxpr_2 = jax.make_jaxpr(qml_func_2)(x) assert jaxpr_2.in_avals == [jax.core.ShapedArray((dim,), fdtype)] @@ -406,7 +509,7 @@ def func(x): assert jac_eqn.params["jaxpr"].eqns[0].primitive == jacobian_prim manual_eval_2 = jax.core.eval_jaxpr(jaxpr_2.jaxpr, jaxpr_2.consts, x) - assert _jac_allclose(manual_eval_2, expected_2, 1) + assert _jac_allclose(manual_eval_2, expected_2, 1, atol=atol) jax.config.update("jax_enable_x64", initial_mode) @@ -473,3 +576,58 @@ def circuit(x): assert _jac_allclose(manual_res, expected_res, 1) jax.config.update("jax_enable_x64", initial_mode) + + @pytest.mark.parametrize("argnum", ([0, 1], [0], [1])) + def test_jacobian_pytrees(self, x64_mode, argnum): + """Test that the qml.jacobian primitive can be captured with + pytree inputs and outputs.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + fdtype = jax.numpy.float64 if x64_mode else jax.numpy.float32 + + def inner_func(x, y): + return { + "prod_cos": jnp.prod(jnp.sin(x["a"]) * jnp.cos(y[0]["b"][1]) ** 2), + "sum_sin": jnp.sum(jnp.sin(x["a"]) * jnp.sin(y[1]["c"]) ** 2), + } + + def func_qml(x): + return qml.jacobian(inner_func, argnum=argnum)( + {"a": x}, ({"b": [None, 0.4 * jnp.sqrt(x)]}, {"c": 0.5}) + ) + + def func_jax(x): + return jax.jacobian(inner_func, argnums=argnum)( + {"a": x}, ({"b": [None, 0.4 * jnp.sqrt(x)]}, {"c": 0.5}) + ) + + x = 0.7 + jax_out = func_jax(x) + jax_out_flat, jax_out_tree = jax.tree_util.tree_flatten(jax_out) + qml_out_flat, qml_out_tree = jax.tree_util.tree_flatten(func_qml(x)) + assert jax_out_tree == qml_out_tree + assert qml.math.allclose(jax_out_flat, qml_out_flat) + + # Check overall jaxpr properties + jaxpr = jax.make_jaxpr(func_qml)(x) + assert jaxpr.in_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)] + assert len(jaxpr.eqns) == 3 + + argnum = [argnum] if isinstance(argnum, int) else argnum + # Compute the flat argnum in order to determine the expected number of out tracers + flat_argnum = [0] * (0 in argnum) + [1, 2] * (1 in argnum) + assert jaxpr.out_avals == [jax.core.ShapedArray((), fdtype)] * (2 * len(flat_argnum)) + + jac_eqn = jaxpr.eqns[2] + + diff_eqn_assertions(jac_eqn, jacobian_prim, argnum=flat_argnum) + assert [var.aval for var in jac_eqn.outvars] == jaxpr.out_avals + + manual_out = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x) + manual_out_flat, manual_out_tree = jax.tree_util.tree_flatten(manual_out) + # Assert that the output from the manual evaluation is flat + assert manual_out_tree == jax.tree_util.tree_flatten(manual_out_flat)[1] + assert qml.math.allclose(jax_out_flat, manual_out_flat) + + jax.config.update("jax_enable_x64", initial_mode)