
Commit

Merge branch 'master' into tree_mcm_tests
obliviateandsurrender committed Sep 17, 2024
2 parents caf70ba + 7a4a44b commit 599b7e2
Showing 43 changed files with 570 additions and 238 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/install_deps/action.yml
@@ -15,23 +15,23 @@ inputs:
jax_version:
description: The version of JAX to install for any job that requires JAX
required: false
default: 0.4.23
default: '0.4.23'
install_tensorflow:
description: Indicate if TensorFlow should be installed or not
required: false
default: 'true'
tensorflow_version:
description: The version of TensorFlow to install for any job that requires TensorFlow
required: false
default: 2.16.0
default: '2.16.0'
install_pytorch:
description: Indicate if PyTorch should be installed or not
required: false
default: 'true'
pytorch_version:
description: The version of PyTorch to install for any job that requires PyTorch
required: false
default: 2.3.0
default: '2.3.0'
install_pennylane_lightning_master:
description: Indicate if PennyLane-Lightning should be installed from the master branch
required: false
2 changes: 1 addition & 1 deletion .github/workflows/tests-gpu.yml
@@ -15,7 +15,7 @@ concurrency:
cancel-in-progress: true

env:
TORCH_VERSION: 2.2.0
TORCH_VERSION: 2.3.0

jobs:
gpu-tests:
2 changes: 2 additions & 0 deletions .gitignore
@@ -27,3 +27,5 @@ config.toml
qml_debug.log
datasets/*
.benchmarks/*
*.h5
*.hdf5
28 changes: 19 additions & 9 deletions doc/releases/changelog-dev.md
@@ -6,28 +6,29 @@

<h3>Improvements 🛠</h3>

* PennyLane is now compatible with NumPy 2.0.
[(#6061)](https://github.com/PennyLaneAI/pennylane/pull/6061)
[(#6258)](https://github.com/PennyLaneAI/pennylane/pull/6258)

* `qml.qchem.excitations` now optionally returns fermionic operators.
[(#6171)](https://github.com/PennyLaneAI/pennylane/pull/6171)
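
  A hedged sketch of how the new option might be used (the keyword name `fermionic` and the printed form of the result are assumptions, not taken from the entry above):

  ```python
  import pennylane as qml

  # Default behaviour: lists of excitation indices
  singles, doubles = qml.qchem.excitations(electrons=2, orbitals=4)

  # Assumed keyword for the new behaviour: return fermionic operators instead of index lists
  f_singles, f_doubles = qml.qchem.excitations(electrons=2, orbitals=4, fermionic=True)
  print(f_singles[0])  # a qml.fermi operator representing the first single excitation
  ```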

* The `diagonalize_measurements` transform now uses a more efficient method of diagonalization
when possible, based on the `pauli_rep` of the relevant observables.
[(#6113)](https://github.com/PennyLaneAI/pennylane/pull/6113)
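
  As a rough sketch of where the faster path applies (the decorator usage and import path are assumptions; see PR #6113 for the actual interface), observables that expose a `pauli_rep`, such as sums of Paulis, are diagonalized via the new method:

  ```python
  import pennylane as qml
  from pennylane.transforms import diagonalize_measurements

  dev = qml.device("default.qubit", wires=2)

  @diagonalize_measurements
  @qml.qnode(dev)
  def circuit():
      qml.Hadamard(0)
      # X and Y expose a pauli_rep, so the more efficient diagonalization applies
      return qml.expval(qml.X(0) + qml.Y(1))

  print(circuit())
  ```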

<h4>Capturing and representing hybrid programs</h4>

* Differentiation of hybrid programs via `qml.grad` can now be captured into plxpr.
When evaluating a captured `qml.grad` instruction, it will dispatch to `jax.grad`,
which differs from the Autograd implementation of `qml.grad` itself.
[(#6120)](https://github.com/PennyLaneAI/pennylane/pull/6120)
* The `Hermitian` operator now has a `compute_sparse_matrix` implementation.
[(#6225)](https://github.com/PennyLaneAI/pennylane/pull/6225)
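
  A minimal sketch of the new method in use (assuming it is reached through the standard `sparse_matrix()` entry point):

  ```python
  import numpy as np
  import pennylane as qml

  A = np.array([[1.0, 1j], [-1j, 1.0]])  # Hermitian matrix
  op = qml.Hermitian(A, wires=0)

  # sparse_matrix() should now dispatch to compute_sparse_matrix and return a SciPy CSR matrix
  print(op.sparse_matrix())
  ```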

<h4>Capturing and representing hybrid programs</h4>

* Differentiation of hybrid programs via `qml.grad` and `qml.jacobian` can now be captured
into plxpr. When evaluating a captured `qml.grad` (`qml.jacobian`) instruction, it will
dispatch to `jax.grad` (`jax.jacobian`), which differs from the Autograd implementation
without capture.
without capture. Pytree inputs and outputs are supported.
[(#6120)](https://github.com/PennyLaneAI/pennylane/pull/6120)
[(#6127)](https://github.com/PennyLaneAI/pennylane/pull/6127)
[(#6134)](https://github.com/PennyLaneAI/pennylane/pull/6134)
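
  A hedged sketch of what capture looks like in practice (assuming program capture is toggled with `qml.capture.enable()`); the `qml.grad` call is recorded as a primitive in the jaxpr and only dispatches to `jax.grad` when that jaxpr is evaluated:

  ```python
  import jax
  import pennylane as qml

  qml.capture.enable()

  dev = qml.device("default.qubit", wires=1)

  @qml.qnode(dev)
  def circuit(x):
      qml.RX(x["angle"], wires=0)  # pytree (dict) inputs are supported
      return qml.expval(qml.Z(0))

  print(jax.make_jaxpr(qml.grad(circuit))({"angle": 0.5}))

  qml.capture.disable()
  ```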

* Improve unit testing for capturing of nested control flows.
[(#6111)](https://github.com/PennyLaneAI/pennylane/pull/6111)
@@ -116,13 +117,22 @@
* The ``qml.FABLE`` template now returns the correct value when JIT is enabled.
[(#6263)](https://github.com/PennyLaneAI/pennylane/pull/6263)

* <h3>Contributors ✍️</h3>
* Fixes a bug where a circuit using the `autograd` interface sometimes returned nested values that were not `autograd` tensors.
[(#6225)](https://github.com/PennyLaneAI/pennylane/pull/6225)

* Fixes a bug where a circuit with no parameters, or with only builtin/NumPy arrays as parameters, returned `autograd` tensors.
[(#6225)](https://github.com/PennyLaneAI/pennylane/pull/6225)
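
  Illustratively (a sketch of the described behaviour, not a verbatim test from the PR), a parameter-free circuit is now expected to return a plain NumPy value rather than an `autograd` tensor:

  ```python
  import pennylane as qml

  dev = qml.device("default.qubit", wires=1)

  @qml.qnode(dev)
  def circuit():
      qml.Hadamard(0)
      return qml.expval(qml.Z(0))

  print(type(circuit()))  # expected: a NumPy type, not pennylane.numpy.tensor
  ```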

<h3>Contributors ✍️</h3>

This release contains contributions from (in alphabetical order):

Guillermo Alonso,
Utkarsh Azad,
Astral Cai,
Lillian M. A. Frederiksen,
Pietropaolo Frisoni,
Emiliano Godinez,
Christina Lee,
William Maxwell,
Lee J. O'Riordan,
50 changes: 43 additions & 7 deletions pennylane/_grad.py
@@ -25,6 +25,7 @@

from pennylane.capture import enabled
from pennylane.capture.capture_diff import _get_grad_prim, _get_jacobian_prim
from pennylane.capture.flatfn import FlatFn
from pennylane.compiler import compiler
from pennylane.compiler.compiler import CompileError

@@ -33,18 +34,53 @@

def _capture_diff(func, argnum=None, diff_prim=None, method=None, h=None):
"""Capture-compatible gradient computation."""
import jax # pylint: disable=import-outside-toplevel
# pylint: disable=import-outside-toplevel
import jax
from jax.tree_util import tree_flatten, tree_leaves, tree_unflatten, treedef_tuple

if isinstance(argnum, int):
argnum = [argnum]
if argnum is None:
argnum = [0]
argnum = 0
if argnum_is_int := isinstance(argnum, int):
argnum = [argnum]

@wraps(func)
def new_func(*args, **kwargs):
jaxpr = jax.make_jaxpr(partial(func, **kwargs))(*args)
prim_kwargs = {"argnum": argnum, "jaxpr": jaxpr.jaxpr, "n_consts": len(jaxpr.consts)}
return diff_prim.bind(*jaxpr.consts, *args, **prim_kwargs, method=method, h=h)
flat_args, in_trees = zip(*(tree_flatten(arg) for arg in args))
full_in_tree = treedef_tuple(in_trees)

# Create a new input tree that only takes inputs marked by argnum into account
trainable_in_trees = (in_tree for i, in_tree in enumerate(in_trees) if i in argnum)
# If an integer was provided as argnum, unpack the arguments axis of the derivatives
if argnum_is_int:
trainable_in_tree = list(trainable_in_trees)[0]
else:
trainable_in_tree = treedef_tuple(trainable_in_trees)

# Create argnum for the flat list of input arrays. For each flattened argument,
# add a list of flat argnums if the argument is trainable and an empty list otherwise.
start = 0
flat_argnum_gen = (
(
list(range(start, (start := start + len(flat_arg))))
if i in argnum
else list(range((start := start + len(flat_arg)), start))
)
for i, flat_arg in enumerate(flat_args)
)
flat_argnum = sum(flat_argnum_gen, start=[])

# Create fully flattened function (flat inputs & outputs)
flat_fn = FlatFn(partial(func, **kwargs) if kwargs else func, full_in_tree)
flat_args = sum(flat_args, start=[])
jaxpr = jax.make_jaxpr(flat_fn)(*flat_args)
prim_kwargs = {"argnum": flat_argnum, "jaxpr": jaxpr.jaxpr, "n_consts": len(jaxpr.consts)}
out_flat = diff_prim.bind(*jaxpr.consts, *flat_args, **prim_kwargs, method=method, h=h)
# flatten once more to go from 2D derivative structure (outputs, args) to flat structure
out_flat = tree_leaves(out_flat)
assert flat_fn.out_tree is not None, "out_tree should be set after executing flat_fn"
# The derivative output tree is the composition of output tree and trainable input trees
combined_tree = flat_fn.out_tree.compose(trainable_in_tree)
return tree_unflatten(combined_tree, out_flat)

return new_func
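
The flat-argnum bookkeeping above can be illustrated in isolation (a standalone sketch, not the library code itself): every positional argument is flattened into leaves, and only the leaves of arguments listed in `argnum` contribute indices to the flat argnum list.

```python
from jax.tree_util import tree_flatten

args = ({"x": 1.0, "y": 2.0}, 3.0, (4.0, 5.0))
argnum = [0, 2]  # differentiate w.r.t. the first and third arguments

flat_args = [tree_flatten(arg)[0] for arg in args]
start, flat_argnum = 0, []
for i, leaves in enumerate(flat_args):
    if i in argnum:
        flat_argnum.extend(range(start, start + len(leaves)))
    start += len(leaves)

print(flat_argnum)  # [0, 1, 3, 4]: the leaves of arguments 0 and 2 in the flat argument list
```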

2 changes: 1 addition & 1 deletion pennylane/_version.py
@@ -16,4 +16,4 @@
Version number (major.minor.patch[-label])
"""

__version__ = "0.39.0-dev14"
__version__ = "0.39.0-dev16"
3 changes: 3 additions & 0 deletions pennylane/capture/__init__.py
@@ -34,6 +34,7 @@
~create_measurement_wires_primitive
~create_measurement_mcm_primitive
~qnode_call
~FlatFn
The ``primitives`` submodule offers easy access to objects with jax dependencies such as
@@ -154,6 +155,7 @@ def _(*args, **kwargs):
create_measurement_mcm_primitive,
)
from .capture_qnode import qnode_call
from .flatfn import FlatFn

# by defining this here, we avoid
# E0611: No name 'AbstractOperator' in module 'pennylane.capture' (no-name-in-module)
@@ -196,4 +198,5 @@ def __getattr__(key):
"AbstractOperator",
"AbstractMeasurement",
"qnode_prim",
"FlatFn",
)
2 changes: 1 addition & 1 deletion pennylane/capture/capture_diff.py
@@ -103,7 +103,7 @@ def _(*args, argnum, jaxpr, n_consts, method, h):
def func(*inner_args):
return jax.core.eval_jaxpr(jaxpr, consts, *inner_args)

return jax.jacobian(func, argnums=argnum)(*args)
return jax.tree_util.tree_leaves(jax.jacobian(func, argnums=argnum)(*args))

# pylint: disable=unused-argument
@jacobian_prim.def_abstract_eval
10 changes: 7 additions & 3 deletions pennylane/capture/capture_qnode.py
@@ -82,8 +82,12 @@ def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts):
mps = qfunc_jaxpr.outvars
return _get_shapes_for(*mps, shots=shots, num_device_wires=len(device.wires))

def _qnode_jvp(*args_and_tangents, **impl_kwargs):
return jax.jvp(partial(qnode_prim.impl, **impl_kwargs), *args_and_tangents)
def make_zero(tan, arg):
return jax.lax.zeros_like_array(arg) if isinstance(tan, ad.Zero) else tan

def _qnode_jvp(args, tangents, **impl_kwargs):
tangents = tuple(map(make_zero, tangents, args))
return jax.jvp(partial(qnode_prim.impl, **impl_kwargs), args, tangents)

ad.primitive_jvps[qnode_prim] = _qnode_jvp
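
The `make_zero` helper materializes symbolic `ad.Zero` tangents because `jax.jvp` only accepts concrete arrays as tangents. A small standalone sketch with explicit zero arrays standing in for `ad.Zero`:

```python
import jax
import jax.numpy as jnp

primals = (jnp.array(0.5), jnp.array(2.0))
# A zero tangent for the second argument means the JVP ignores variations in it,
# which is what replacing ad.Zero with a zeros-like array achieves above.
tangents = (jnp.array(1.0), jnp.zeros_like(primals[1]))

out, dout = jax.jvp(lambda x, y: x * y, primals, tangents)
print(out, dout)  # 1.0 2.0  (d(x*y)/dx * 1.0 = y = 2.0)
```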

@@ -174,7 +178,7 @@ def f(x):
qnode_kwargs = {"diff_method": qnode.diff_method, **execute_kwargs, **mcm_config}
qnode_prim = _get_qnode_prim()

flat_args, _ = jax.tree_util.tree_flatten(args)
flat_args = jax.tree_util.tree_leaves(args)
res = qnode_prim.bind(
*qfunc_jaxpr.consts,
*flat_args,
7 changes: 3 additions & 4 deletions pennylane/capture/explanations.md
@@ -159,12 +159,11 @@ You can also see the const variable `a` as argument `e:i32[]` to the inner nested
### Pytree handling

Evaluating a jaxpr requires accepting and returning a flat list of tensor-like inputs and outputs.
list of tensor-like outputs. These long lists can be hard to manage and are very
restrictive on the allowed functions, but we can take advantage of pytrees to allow handling
arbitrary functions.
These long lists can be hard to manage and are very restrictive on the allowed functions, but we
can take advantage of pytrees to allow handling arbitrary functions.

To start, we import the `FlatFn` helper. This class converts a function to one that caches
the resulting result pytree into `flat_fn.out_tree` when executed. This can be used to repack the
the result pytree into `flat_fn.out_tree` when executed. This can be used to repack the
results into the correct shape. It also returns flattened results. This does not particularly
matter for program capture, as we will only be producing jaxpr from the function, not calling
it directly.
31 changes: 28 additions & 3 deletions pennylane/capture/flatfn.py
@@ -29,23 +29,48 @@ class FlatFn:
property, so that the results can be repacked later. It also returns flattened results
instead of the original result object.
If an ``in_tree`` is provided, the function accepts flattened inputs instead of the
original inputs with tree structure given by ``in_tree``.
**Example**
>>> import jax
>>> from pennylane.capture.flatfn import FlatFn
>>> def f(x):
... return {"y": 2+x["x"]}
>>> flat_f = FlatFn(f)
>>> res = flat_f({"x": 0})
>>> arg = {"x": 0.5}
>>> res = flat_f(arg)
>>> res
[2.5]
>>> jax.tree_util.tree_unflatten(flat_f.out_tree, res)
{'y': 2.5}
If we want to use a fully flattened function that also takes flat inputs instead of
the original inputs with tree structure, we can provide the treedef for this input
structure:
>>> flat_args, in_tree = jax.tree_util.tree_flatten((arg,))
>>> flat_f = FlatFn(f, in_tree)
>>> res = flat_f(*flat_args)
>>> res
[2]
[2.5]
>>> jax.tree_util.tree_unflatten(flat_f.out_tree, res)
{'y': 2.5}
Note that the ``in_tree`` has to be created by flattening a tuple of all input
arguments, even if there is only a single argument.
"""

def __init__(self, f):
def __init__(self, f, in_tree=None):
self.f = f
self.in_tree = in_tree
self.out_tree = None
update_wrapper(self, f)

def __call__(self, *args):
if self.in_tree is not None:
args = jax.tree_util.tree_unflatten(self.in_tree, args)
out = self.f(*args)
out_flat, out_tree = jax.tree_util.tree_flatten(out)
self.out_tree = out_tree
6 changes: 3 additions & 3 deletions pennylane/devices/execution_config.py
@@ -17,7 +17,7 @@
from dataclasses import dataclass, field
from typing import Optional, Union

from pennylane.workflow import SUPPORTED_INTERFACES
from pennylane.workflow import SUPPORTED_INTERFACE_NAMES


@dataclass
@@ -110,9 +110,9 @@ def __post_init__(self):
Note that this hook is automatically called after init via the dataclass integration.
"""
if self.interface not in SUPPORTED_INTERFACES:
if self.interface not in SUPPORTED_INTERFACE_NAMES:
raise ValueError(
f"Unknown interface. interface must be in {SUPPORTED_INTERFACES}, got {self.interface} instead."
f"Unknown interface. interface must be in {SUPPORTED_INTERFACE_NAMES}, got {self.interface} instead."
)

if self.grad_on_execution not in {True, False, None}:
10 changes: 5 additions & 5 deletions pennylane/devices/legacy_facade.py
@@ -24,6 +24,7 @@
import pennylane as qml
from pennylane.measurements import MidMeasureMP, Shots
from pennylane.transforms.core.transform_program import TransformProgram
from pennylane.workflow.execution import INTERFACE_MAP

from .device_api import Device
from .execution_config import DefaultExecutionConfig
@@ -322,25 +323,24 @@ def _validate_backprop_method(self, tape):
return False
params = tape.get_parameters(trainable_only=False)
interface = qml.math.get_interface(*params)
if interface != "numpy":
interface = INTERFACE_MAP.get(interface, interface)

if tape and any(isinstance(m.obs, qml.SparseHamiltonian) for m in tape.measurements):
return False
if interface == "numpy":
interface = None
mapped_interface = qml.workflow.execution.INTERFACE_MAP.get(interface, interface)

# determine if the device supports backpropagation
backprop_interface = self._device.capabilities().get("passthru_interface", None)

if backprop_interface is not None:
# device supports backpropagation natively
return mapped_interface in [backprop_interface, "Numpy"]
return interface in [backprop_interface, "numpy"]
# determine if the device has any child devices that support backpropagation
backprop_devices = self._device.capabilities().get("passthru_devices", None)

if backprop_devices is None:
return False
return mapped_interface in backprop_devices or mapped_interface == "Numpy"
return interface in backprop_devices or interface == "numpy"

def _validate_adjoint_method(self, tape):
# The conditions below provide a minimal set of requirements that we can likely improve upon in
5 changes: 3 additions & 2 deletions pennylane/numpy/random.py
@@ -16,18 +16,19 @@
it works with the PennyLane :class:`~.tensor` class.
"""

from autograd.numpy import random as _random
# isort: skip_file
from numpy import __version__ as np_version
from numpy.random import MT19937, PCG64, SFC64, Philox # pylint: disable=unused-import
from autograd.numpy import random as _random
from packaging.specifiers import SpecifierSet
from packaging.version import Version

from .wrapper import tensor_wrapper, wrap_arrays

wrap_arrays(_random.__dict__, globals())


if Version(np_version) in SpecifierSet(">=0.17.0"):

# pylint: disable=too-few-public-methods
# pylint: disable=missing-class-docstring
class Generator(_random.Generator):
4 changes: 4 additions & 0 deletions pennylane/ops/qubit/observables.py
@@ -137,6 +137,10 @@ def compute_matrix(A: TensorLike) -> TensorLike: # pylint: disable=arguments-differ
Hermitian._validate_input(A)
return A

@staticmethod
def compute_sparse_matrix(A) -> csr_matrix: # pylint: disable=arguments-differ
return csr_matrix(Hermitian.compute_matrix(A))

@property
def eigendecomposition(self) -> dict[str, TensorLike]:
"""Return the eigendecomposition of the matrix specified by the Hermitian observable.