Skip to content

Commit

Permalink
[Program Capture] Capture & execute qml.jacobian in plxpr (#6127)
Browse files Browse the repository at this point in the history
**Context:**
We're adding support for differentiation in plxpr, also see #6120.

**Description of the Change:**
This PR adds support for `qml.jacobian`, similar to the support for
`qml.grad`.
Note that Pytree support will be needed to allow for multi-argument
derivatives.

**Benefits:**
Capture derivatives of non-scalar functions.

**Possible Drawbacks:**
See discussion around `qml.grad` in #6120.

**Related GitHub Issues:**

[sc-71860]

---------

Co-authored-by: Christina Lee <christina@xanadu.ai>
  • Loading branch information
dwierichs and albi3ro committed Sep 9, 2024
1 parent 48b9dc7 commit 38ee38e
Show file tree
Hide file tree
Showing 6 changed files with 473 additions and 203 deletions.
9 changes: 9 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@
which differs from the Autograd implementation of `qml.grad` itself.
[(#6120)](https://github.com/PennyLaneAI/pennylane/pull/6120)

<h4>Capturing and representing hybrid programs</h4>

* Differentiation of hybrid programs via `qml.grad` and `qml.jacobian` can now be captured
into plxpr. When evaluating a captured `qml.grad` (`qml.jacobian`) instruction, it will
dispatch to `jax.grad` (`jax.jacobian`), which differs from the Autograd implementation
without capture.
[(#6120)](https://github.com/PennyLaneAI/pennylane/pull/6120)
[(#6127)](https://github.com/PennyLaneAI/pennylane/pull/6127)

* Improve unit testing for capturing of nested control flows.
[(#6111)](https://github.com/PennyLaneAI/pennylane/pull/6111)

Expand Down
7 changes: 5 additions & 2 deletions pennylane/_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from autograd.wrap_util import unary_to_nary

from pennylane.capture import enabled
from pennylane.capture.capture_diff import _get_grad_prim
from pennylane.capture.capture_diff import _get_grad_prim, _get_jacobian_prim
from pennylane.compiler import compiler
from pennylane.compiler.compiler import CompileError

Expand Down Expand Up @@ -434,8 +434,11 @@ def circuit(x):
ops_loader = available_eps[active_jit]["ops"].load()
return ops_loader.jacobian(func, method=method, h=h, argnums=argnum)

if enabled():
return _capture_diff(func, argnum, _get_jacobian_prim(), method=method, h=h)

if method or h:
raise ValueError(f"Invalid values for 'method={method}' and 'h={h}' in interpreted mode")
raise ValueError(f"Invalid values '{method=}' and '{h=}' without QJIT.")

def _get_argnum(args):
"""Inspect the arguments for differentiability and return the
Expand Down
35 changes: 35 additions & 0 deletions pennylane/capture/capture_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,38 @@ def _(*args, argnum, jaxpr, n_consts, method, h):
return tuple(jaxpr.invars[i].aval for i in argnum)

return grad_prim


@lru_cache
def _get_jacobian_prim():
"""Create a primitive for Jacobian computations.
This primitive is used when capturing ``qml.jacobian``.
"""
jacobian_prim = create_non_jvp_primitive()("jacobian")
jacobian_prim.multiple_results = True # pylint: disable=attribute-defined-outside-init

# pylint: disable=too-many-arguments
@jacobian_prim.def_impl
def _(*args, argnum, jaxpr, n_consts, method, h):
if method or h: # pragma: no cover
raise ValueError(f"Invalid values '{method=}' and '{h=}' without QJIT.")
consts = args[:n_consts]
args = args[n_consts:]

def func(*inner_args):
return jax.core.eval_jaxpr(jaxpr, consts, *inner_args)

return jax.jacobian(func, argnums=argnum)(*args)

# pylint: disable=unused-argument
@jacobian_prim.def_abstract_eval
def _(*args, argnum, jaxpr, n_consts, method, h):
in_avals = [jaxpr.invars[i].aval for i in argnum]
out_shapes = (outvar.aval.shape for outvar in jaxpr.outvars)
return [
jax.core.ShapedArray(out_shape + in_aval.shape, in_aval.dtype)
for out_shape in out_shapes
for in_aval in in_avals
]

return jacobian_prim
4 changes: 3 additions & 1 deletion pennylane/capture/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from pennylane.ops.op_math.condition import _get_cond_qfunc_prim
from pennylane.ops.op_math.controlled import _get_ctrl_qfunc_prim

from .capture_diff import _get_grad_prim
from .capture_diff import _get_grad_prim, _get_jacobian_prim
from .capture_measurements import _get_abstract_measurement
from .capture_operators import _get_abstract_operator
from .capture_qnode import _get_qnode_prim
Expand All @@ -32,6 +32,7 @@
adjoint_transform_prim = _get_adjoint_qfunc_prim()
ctrl_transform_prim = _get_ctrl_qfunc_prim()
grad_prim = _get_grad_prim()
jacobian_prim = _get_jacobian_prim()
qnode_prim = _get_qnode_prim()
cond_prim = _get_cond_qfunc_prim()
for_loop_prim = _get_for_loop_qfunc_prim()
Expand All @@ -44,6 +45,7 @@
"adjoint_transform_prim",
"ctrl_transform_prim",
"grad_prim",
"jacobian_prim",
"qnode_prim",
"cond_prim",
"for_loop_prim",
Expand Down
Loading

0 comments on commit 38ee38e

Please sign in to comment.