Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Capture the adjoint transform into jaxpr #5966

Merged
merged 22 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@
* `QuantumScript.hash` is now cached, leading to performance improvements.
[(#5919)](https://github.com/PennyLaneAI/pennylane/pull/5919)

* Applying `adjoint` to a quantum function can now be captured into plxpr.
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
[(#5966)](https://github.com/PennyLaneAI/pennylane/pull/5966)

* Set operations are now supported by Wires.
[(#5983)](https://github.com/PennyLaneAI/pennylane/pull/5983)

Expand Down
127 changes: 126 additions & 1 deletion pennylane/capture/explanations.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
This documentation explains the principles behind `qml.capture.CaptureMeta`.
This documentation explains the principles behind `qml.capture.CaptureMeta` and higher order primitives.


```python
Expand Down Expand Up @@ -32,6 +32,131 @@ def my_func(x):
[a:f32[1] = my_func b]
```

# Higher Order Primitives and nested jaxpr

Higher order primitives are essentially function transforms. They are functions that accept other
functions. Our higher order primitives will include `adjoint`, `ctrl`, `for_loop`, `while_loop`, `cond`, `grad`,
and `jacobian`.

Jax describes two separate ways of defining higher order derivatives:

1) *On the fly processing*: the primitive binds the function itself as metadata

2) *Staged processing*: the primitive binds the function's jaxpr as metadata.

Jax also has a [`CallPrimitive`](https://github.com/google/jax/blob/23ad313817f20345c60281fbf727cf4f8dc83181/jax/_src/core.py#L2366)
but using this seems to be more trouble than its worth so far. Notably, this class is rather private and undocumented.

We will proceed with using *staged processing* for now. This choice is more straightforward to implement, follows Catalyst's choice of representation, and is more
explicit in the contents. On the fly isn't as much "program capture" as deferring capture till later. We want to immediately capture all aspects of the jaxpr.


Suppose we have a transform that repeats a function n times

```python
def repeat(func: Callable, n: int) -> Callable:
def new_func(*args, **kwargs):
for _ in range(n):
args = func(*args, **kwargs)
return args
return new_func
```

We can start by creating the primitive itself:

```python
repeat_prim = jax.core.Primitive("repeat")
repeat_prim.multiple_results = True
```

Instead of starting with the implementation and abstract evaluation, let's write out the function that will
bind the primitive first. This will showcase what the args and keyword args for our bind call will look like:

```python
from functools import partial
from typing import Callable

def repeat(func: Callable, n: int) -> Callable:
def new_func(*args, **kwargs):
func_bound_kwargs = partial(func, **kwargs)
jaxpr = jax.make_jaxpr(func_bound_kwargs)(*args)
n_consts = len(jaxpr.consts)
return repeat_prim.bind(n, *jaxpr.consts, *args, jaxpr=jaxpr.jaxpr, n_consts=n_consts)
return new_func
```

Several things to notice about this code.

First, we have to make the jaxpr from a function with any keyword arguments
already bound. `jax.make_jaxpr` does not currently accept keyword arguments for the function, so we need to pre-bind them.

Next, we decided to make the integer `n` a traceable parameter instead of metadata. We could have chosen to make
`n` metadata instead. This way, we can compile our function once for different integers `n`, and it is in line with how
catalyst treats `for_loop` and `while_loop`. If the function produced outputs of different types and shapes for different `n`,
we would have to treat `n` like metadata and re-compile for different integers `n`.

Finally, we promote the `jaxpr.consts` to being actual positional arguments. The consts
contain any closure variables that the function implicitly depends on that are not present
in the actual call signature. For example: `def f(x): return x+y`. `y` here would be a
`const` pulled from the global environment. `f` implicitly depends on it, and it is
required to reproduce the full behavior of `f`. To separate the normal positional
arguments from the consts, we then also need a `n_consts` keyword argument.

albi3ro marked this conversation as resolved.
Show resolved Hide resolved
Now we can define the implementation for our primitive.

```python
@repeat_prim.def_impl
def _(n, *args, jaxpr, n_consts):
consts = args[:n_consts]
args = args[n_consts:]
for _ in range(n):
args = jax.core.eval_jaxpr(jaxpr, consts, *args)
return args
```

Here we use `jax.core.eval_jaxpr` to execute the jaxpr with concrete arguments. If we had instead used
*on the fly processing*, we could have simply executed the stored function, but when using *staged processing*, we need
to directly evaluate the jaxpr instead.

In addition, we need to define the abstract evaluation. As the function in our case returns outputs that match the inputs in number, shape and type, we can simply extract the abstract values of the `args`.
albi3ro marked this conversation as resolved.
Show resolved Hide resolved

```python
@repeat_prim.def_abstract_eval
def _(n, *args, jaxpr, n_consts):
return args[n_consts:]
```

Now that we have all the parts, we can see it in action:

```pycon
>>> a = jax.numpy.array(1)
>>> def func(x, y, y_coeff=1):
... return (x + a, y_coeff * y)
>>> def workflow(x):
... return repeat(func, 2)(x, 2.0, y_coeff=2.0)
>>> workflow(0.5)
[Array(2.5, dtype=float32, weak_type=True),
Array(8., dtype=float32, weak_type=True)]
>>> jax.make_jaxpr(workflow)(0.5)
{ lambda a:i32[]; b:f32[]. let
c:f32[] d:f32[] = repeat[
jaxpr={ lambda e:i32[]; f:f32[] g:f32[]. let
h:f32[] = convert_element_type[new_dtype=float32 weak_type=True] e
i:f32[] = add f h
j:f32[] = mul 2.0 g
in (i, j) }
n_consts=1
] 2 a b 2.0
in (c, d) }
>>> jax.make_jaxpr(workflow)(0.5).consts
[Array(1, dtype=int32, weak_type=True)]
```

Some notes here about how read this. `a:i32[]` is the global integer variable `a` that is
a constant. The arguments to the `repeat` primitive are `n (const a) x (hardcoded 2.0=y)`.
You can also see the const variable `a` as argument `e:i32[]` to the inner nested jaxpr as well.


## Metaprogramming
dime10 marked this conversation as resolved.
Show resolved Hide resolved
albi3ro marked this conversation as resolved.
Show resolved Hide resolved


Expand Down
78 changes: 66 additions & 12 deletions pennylane/ops/op_math/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
"""
This submodule defines the symbolic operation that indicates the adjoint of an operator.
"""
from functools import wraps
from functools import lru_cache, partial, wraps
from typing import Callable, overload

import pennylane as qml
from pennylane.compiler import compiler
Expand All @@ -26,7 +27,10 @@
from .symbolicop import SymbolicOp


# pylint: disable=no-member
@overload
def adjoint(fn: Operator, lazy: bool = True) -> Operator: ...
@overload
def adjoint(fn: Callable, lazy: bool = True) -> Callable: ...
def adjoint(fn, lazy=True):
"""Create the adjoint of an Operator or a function that applies the adjoint of the provided function.
:func:`~.qjit` compatible.
Expand Down Expand Up @@ -162,25 +166,75 @@ def loop_fn(i):
available_eps = compiler.AvailableCompilers.names_entrypoints
ops_loader = available_eps[active_jit]["ops"].load()
return ops_loader.adjoint(fn, lazy=lazy)
if qml.math.is_abstract(fn):
return Adjoint(fn)
return create_adjoint_op(fn, lazy)


def create_adjoint_op(fn, lazy):
"""Main logic for qml.adjoint, but allows bypassing the compiler dispatch if needed."""
if qml.math.is_abstract(fn):
return Adjoint(fn)
if isinstance(fn, Operator):
return Adjoint(fn) if lazy else _single_op_eager(fn, update_queue=True)
if not callable(fn):
raise ValueError(
f"The object {fn} of type {type(fn)} is not callable. "
"This error might occur if you apply adjoint to a list "
"of operations instead of a function or template."
if callable(fn):
if qml.capture.enabled():
return _capture_adjoint_transform(fn, lazy=lazy)
return _adjoint_transform(fn, lazy=lazy)
raise ValueError(
f"The object {fn} of type {type(fn)} is not callable. "
"This error might occur if you apply adjoint to a list "
"of operations instead of a function or template."
)


@lru_cache # only create the first time requested
def _get_adjoint_qfunc_prim():
"""See capture/explanations.md : Higher Order primitives for more information on this code."""
# if capture is enabled, jax should be installed
import jax # pylint: disable=import-outside-toplevel

adjoint_prim = jax.core.Primitive("adjoint_transform")
adjoint_prim.multiple_results = True

@adjoint_prim.def_impl
def _(*args, jaxpr, lazy, n_consts):
consts = args[:n_consts]
args = args[n_consts:]
with qml.queuing.AnnotatedQueue() as q:
jax.core.eval_jaxpr(jaxpr, consts, *args)
ops, _ = qml.queuing.process_queue(q)
for op in reversed(ops):
adjoint(op, lazy=lazy)
return []
albi3ro marked this conversation as resolved.
Show resolved Hide resolved

@adjoint_prim.def_abstract_eval
def _(*_, **__):
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
return []

return adjoint_prim


def _capture_adjoint_transform(qfunc: Callable, lazy=True) -> Callable:
"""Capture compatible way of performing an adjoint transform."""
# note that this logic is tested in `tests/capture/test_nested_plxpr.py`
import jax # pylint: disable=import-outside-toplevel

adjoint_prim = _get_adjoint_qfunc_prim()

@wraps(qfunc)
def new_qfunc(*args, **kwargs):
jaxpr = jax.make_jaxpr(partial(qfunc, **kwargs))(*args)
adjoint_prim.bind(
*jaxpr.consts, *args, jaxpr=jaxpr.jaxpr, lazy=lazy, n_consts=len(jaxpr.consts)
)

@wraps(fn)
return new_qfunc


def _adjoint_transform(qfunc: Callable, lazy=True) -> Callable:
# default adjoint transform when capture is not enabled.
@wraps(qfunc)
def wrapper(*args, **kwargs):
qscript = make_qscript(fn)(*args, **kwargs)
qscript = make_qscript(qfunc)(*args, **kwargs)
if lazy:
adjoint_ops = [Adjoint(op) for op in reversed(qscript.operations)]
else:
Expand All @@ -191,7 +245,7 @@ def wrapper(*args, **kwargs):
return wrapper


def _single_op_eager(op, update_queue=False):
def _single_op_eager(op: Operator, update_queue: bool = False) -> Operator:
if op.has_adjoint:
adj = op.adjoint()
if update_queue:
Expand Down
12 changes: 12 additions & 0 deletions tests/capture/test_capture_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,18 @@

import pennylane as qml

jax = pytest.importorskip("jax")

pytestmark = pytest.mark.jax


@pytest.fixture(autouse=True)
def enable_disable_plxpr():
"""enable and disable capture around each test."""
qml.capture.enable()
yield
qml.capture.disable()


def test_no_attribute_available():
"""Test that if we try and access an attribute that doesn't exist, we get an attribute error."""
Expand Down
Loading
Loading