Commit

Merge branch 'master' into sc-73711-onboard-to-large-runners
rashidnhm authored Sep 16, 2024
2 parents b030d4a + 228fdaf commit c2c1ede
Showing 79 changed files with 841 additions and 7,996 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/install_deps/action.yml
@@ -15,23 +15,23 @@ inputs:
   jax_version:
     description: The version of JAX to install for any job that requires JAX
     required: false
-    default: 0.4.23
+    default: '0.4.23'
   install_tensorflow:
     description: Indicate if TensorFlow should be installed or not
     required: false
     default: 'true'
   tensorflow_version:
     description: The version of TensorFlow to install for any job that requires TensorFlow
     required: false
-    default: 2.16.0
+    default: '2.16.0'
   install_pytorch:
     description: Indicate if PyTorch should be installed or not
     required: false
     default: 'true'
   pytorch_version:
     description: The version of PyTorch to install for any job that requires PyTorch
     required: false
-    default: 2.3.0
+    default: '2.3.0'
   install_pennylane_lightning_master:
     description: Indicate if PennyLane-Lightning should be installed from the master branch
     required: false
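
The quoting changes above guard against YAML's implicit typing, where an unquoted scalar such as `2.0` parses as a float. A minimal illustration using PyYAML as a stand-in parser (an assumption; GitHub Actions has its own YAML engine, but the scalar rules match):

```python
# Sketch: why version defaults are safer quoted.
import yaml  # assumes PyYAML is installed

print(yaml.safe_load("v: 2.16.0"))  # {'v': '2.16.0'} -- two dots, stays a string
print(yaml.safe_load("v: 2.0"))     # {'v': 2.0}      -- parsed as a float
print(yaml.safe_load("v: '2.0'"))   # {'v': '2.0'}    -- quoting keeps the string
```
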
3 changes: 1 addition & 2 deletions .github/workflows/interface-unit-tests.yml
@@ -400,8 +400,7 @@ jobs:
       # catalyst requires the latest version of pennylane that is about to be released.
       # Installing catalyst after pennylane to make sure that the latest catalyst is used.
       install_catalyst_nightly: true
-      # using lightning master does not work for the tests with external libraries
-      install_pennylane_lightning_master: false
+      install_pennylane_lightning_master: true
       pytest_coverage_flags: ${{ inputs.pytest_coverage_flags }}
       pytest_markers: external
       additional_pip_packages: pyzx matplotlib stim quimb mitiq pennylane-qiskit ply
2 changes: 2 additions & 0 deletions .gitignore
@@ -27,3 +27,5 @@ config.toml
 qml_debug.log
 datasets/*
 .benchmarks/*
+*.h5
+*.hdf5
2 changes: 0 additions & 2 deletions Makefile
@@ -60,12 +60,10 @@ clean-docs:
 
 test:
 	$(PYTHON) $(TESTRUNNER)
-	$(PYTHON) $(PLUGIN_TESTRUNNER) --device=default.qubit.autograd
 
 coverage:
 	@echo "Generating coverage report..."
 	$(PYTHON) $(TESTRUNNER) $(COVERAGE)
-	$(PYTHON) $(PLUGIN_TESTRUNNER) --device=default.qubit.autograd $(COVERAGE) --cov-append
 
 .PHONY:format
 format:
31 changes: 19 additions & 12 deletions doc/development/deprecations.rst
@@ -9,24 +9,19 @@ deprecations are listed below.
Pending deprecations
--------------------

* ``Device``, ``QubitDevice``, and ``QutritDevice`` will no longer be imported top level in v0.40. They will
instead be available as ``qml.devices.LegacyDevice``, ``qml.devices.QubitDevice``, and ``qml.devices.QutritDevice``,
respectively.

- Deprecated top level access in v0.39
- Top level access removed in v0.40
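
A minimal migration sketch for this pending deprecation, using the new import paths documented above:

```python
# Sketch: forward-compatible access to the device base classes.
# Deprecated top-level access (warns in v0.39, slated for removal in v0.40):
#     qml.Device, qml.QubitDevice, qml.QutritDevice
from pennylane.devices import LegacyDevice, QubitDevice, QutritDevice
```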

* ``QNode.gradient_fn`` is deprecated. Please use ``QNode.diff_method`` instead. ``QNode.get_gradient_fn`` can also be used to
process the diff method.

- Deprecated in v0.39
- Will be removed in v0.40

* All of the legacy devices (any with the name ``default.qubit.{autograd,torch,tf,jax,legacy}``) are deprecated. Use ``default.qubit`` instead,
as it supports backpropagation for the many backends the legacy devices support.

- Deprecated in v0.38
- Will be removed in v0.39

* The logic for internally switching a device for a different backpropagation
compatible device is now deprecated, as it was in place for the deprecated ``default.qubit.legacy``.

- Deprecated in v0.38
- Will be removed in v0.39

* The ``decomp_depth`` argument in ``qml.device`` is deprecated.

- Deprecated in v0.38
@@ -88,6 +83,18 @@ Other deprecations
Completed deprecation cycles
----------------------------

* All of the legacy devices (any with the name ``default.qubit.{autograd,torch,tf,jax,legacy}``) are removed. Use ``default.qubit`` instead,
as it supports backpropagation for the many backends the legacy devices support.

- Deprecated in v0.38
- Removed in v0.39
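
A hedged sketch of the replacement pattern; ``interface`` is the standard QNode keyword, shown here as an illustration:

```python
# Sketch: replacing a removed interface device with default.qubit.
import pennylane as qml

# Before (removed in v0.39): dev = qml.device("default.qubit.torch", wires=1)
dev = qml.device("default.qubit", wires=1)

@qml.qnode(dev, interface="torch")  # backprop now runs through default.qubit
def circuit(x):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))
```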

* The logic for internally switching a device for a different backpropagation
compatible device is removed, as it was only in place for the removed ``default.qubit.legacy``.

- Deprecated in v0.38
- Removed in v0.39

* ``Operator.expand`` is now removed. Use ``qml.tape.QuantumScript(op.decomposition())`` instead.

- Deprecated in v0.38
37 changes: 27 additions & 10 deletions doc/releases/changelog-dev.md
@@ -6,28 +6,28 @@

<h3>Improvements 🛠</h3>

* PennyLane is now compatible with NumPy 2.0.
[(#6061)](https://github.com/PennyLaneAI/pennylane/pull/6061)

* `qml.qchem.excitations` now optionally returns fermionic operators.
[(#6171)](https://github.com/PennyLaneAI/pennylane/pull/6171)
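
A hedged usage sketch; the keyword name `fermionic` is an assumption based on the linked PR:

```python
# Sketch: requesting fermionic operators from qml.qchem.excitations.
import pennylane as qml

singles, doubles = qml.qchem.excitations(electrons=2, orbitals=4)  # index lists

# Assumed new keyword: return fermionic operators instead of index lists.
f_singles, f_doubles = qml.qchem.excitations(2, 4, fermionic=True)
```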

* The `diagonalize_measurements` transform now uses a more efficient method of diagonalization
when possible, based on the `pauli_rep` of the relevant observables.
[(#6113)](https://github.com/PennyLaneAI/pennylane/pull/6113)
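
A hedged sketch of applying the transform, assuming it is exposed under `qml.transforms`:

```python
# Sketch: rotate non-Z observables to the computational basis before measuring.
import pennylane as qml

dev = qml.device("default.qubit", wires=2)

@qml.transforms.diagonalize_measurements
@qml.qnode(dev)
def circuit(x):
    qml.RY(x, wires=0)
    return qml.expval(qml.PauliX(0) @ qml.PauliZ(1))
```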

<h4>Capturing and representing hybrid programs</h4>

* Differentiation of hybrid programs via `qml.grad` can now be captured into plxpr.
When evaluating a captured `qml.grad` instruction, it will dispatch to `jax.grad`,
which differs from the Autograd implementation of `qml.grad` itself.
[(#6120)](https://github.com/PennyLaneAI/pennylane/pull/6120)
* The `Hermitian` operator now has a `compute_sparse_matrix` implementation.
[(#6225)](https://github.com/PennyLaneAI/pennylane/pull/6225)
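
A minimal sketch, assuming the standard `Operator.sparse_matrix()` entry point dispatches to the new implementation:

```python
# Sketch: obtaining a SciPy sparse representation of a Hermitian observable.
import numpy as np
import pennylane as qml

A = np.array([[1.0, 0.5], [0.5, -1.0]])
H = qml.Hermitian(A, wires=0)
print(H.sparse_matrix())  # scipy.sparse matrix via compute_sparse_matrix
```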

<h4>Capturing and representing hybrid programs</h4>

* Differentiation of hybrid programs via `qml.grad` and `qml.jacobian` can now be captured
into plxpr. When evaluating a captured `qml.grad` (`qml.jacobian`) instruction, it will
dispatch to `jax.grad` (`jax.jacobian`), which differs from the Autograd implementation
without capture.
without capture. Pytree inputs and outputs are supported.
[(#6120)](https://github.com/PennyLaneAI/pennylane/pull/6120)
[(#6127)](https://github.com/PennyLaneAI/pennylane/pull/6127)
[(#6134)](https://github.com/PennyLaneAI/pennylane/pull/6134)
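
A hedged end-to-end sketch, assuming the experimental `qml.capture.enable()` toggle:

```python
# Sketch: capturing qml.grad into plxpr; evaluation dispatches to jax.grad.
import jax
import pennylane as qml

qml.capture.enable()

def cost(params):  # pytree (dict) input, supported per the entry above
    return params["a"] ** 2 + params["b"]

print(jax.make_jaxpr(qml.grad(cost))({"a": 1.5, "b": 2.0}))

qml.capture.disable()
```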

* Improve unit testing for capturing of nested control flows.
[(#6111)](https://github.com/PennyLaneAI/pennylane/pull/6111)
@@ -59,10 +59,12 @@
* Remove support for Python 3.9.
[(#6223)](https://github.com/PennyLaneAI/pennylane/pull/6223)

* `DefaultQubitTF`, `DefaultQubitTorch`, and `DefaultQubitJax` are removed. Please use `default.qubit` for all interfaces.
* `DefaultQubitTF`, `DefaultQubitTorch`, `DefaultQubitJax`, and `DefaultQubitAutograd` are removed.
Please use `default.qubit` for all interfaces.
[(#6207)](https://github.com/PennyLaneAI/pennylane/pull/6207)
[(#6208)](https://github.com/PennyLaneAI/pennylane/pull/6208)
[(#6209)](https://github.com/PennyLaneAI/pennylane/pull/6209)
[(#6210)](https://github.com/PennyLaneAI/pennylane/pull/6210)

* `expand_fn`, `max_expansion`, `override_shots`, and `device_batch_transform` are removed from the
signature of `qml.execute`.
@@ -81,8 +83,14 @@
* `Operator.expand` is now removed. Use `qml.tape.QuantumScript(op.decomposition())` instead.
[(#6227)](https://github.com/PennyLaneAI/pennylane/pull/6227)


<h3>Deprecations 👋</h3>

* `Device`, `QubitDevice`, and `QutritDevice` will no longer be accessible via top-level import in v0.40.
They will still be accessible as `qml.devices.LegacyDevice`, `qml.devices.QubitDevice`, and `qml.devices.QutritDevice`
respectively.
[(#6238)](https://github.com/PennyLaneAI/pennylane/pull/6238/)

* `QNode.gradient_fn` is deprecated. Please use `QNode.diff_method` and `QNode.get_gradient_fn` instead.
[(#6244)](https://github.com/PennyLaneAI/pennylane/pull/6244)
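
A short sketch of the recommended replacement:

```python
# Sketch: read QNode.diff_method instead of the deprecated QNode.gradient_fn.
import pennylane as qml

dev = qml.device("default.qubit", wires=1)

@qml.qnode(dev, diff_method="parameter-shift")
def circuit(x):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))

print(circuit.diff_method)  # "parameter-shift"
# circuit.gradient_fn       # deprecated: emits a PennyLaneDeprecationWarning
```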

@@ -108,13 +116,22 @@
* The ``qml.FABLE`` template now returns the correct value when JIT is enabled.
[(#6263)](https://github.com/PennyLaneAI/pennylane/pull/6263)

* <h3>Contributors ✍️</h3>
* Fixes a bug where a circuit using the `autograd` interface sometimes returns nested values that are not `autograd` tensors.
[(#6225)](https://github.com/PennyLaneAI/pennylane/pull/6225)

* Fixes a bug where a simple circuit with no parameters, or with only built-in/NumPy arrays as parameters, returns `autograd` tensors.
[(#6225)](https://github.com/PennyLaneAI/pennylane/pull/6225)

<h3>Contributors ✍️</h3>

This release contains contributions from (in alphabetical order):

Guillermo Alonso,
Utkarsh Azad,
Astral Cai,
Lillian M. A. Frederiksen,
Pietropaolo Frisoni,
Emiliano Godinez,
Christina Lee,
William Maxwell,
Lee J. O'Riordan,
19 changes: 17 additions & 2 deletions pennylane/__init__.py
@@ -16,8 +16,6 @@
 PennyLane can be directly imported.
 """
 
-import numpy as _np
-
 
 from pennylane.boolean_fn import BooleanFn
 import pennylane.numpy
@@ -180,13 +178,30 @@ def __getattr__(name):
     if name == "plugin_devices":
         return pennylane.devices.device_constructor.plugin_devices
 
+    from warnings import warn  # pylint: disable=import-outside-toplevel
+
+    if name == "QubitDevice":
+        warn(
+            "QubitDevice will no longer be accessible top level. Please access "
+            "the class as pennylane.devices.QubitDevice",
+            PennyLaneDeprecationWarning,
+        )
+        return pennylane.devices._qubit_device.QubitDevice  # pylint:disable=protected-access
+
+    if name == "QutritDevice":
+        warn(
+            "QutritDevice will no longer be accessible top level. Please access "
+            "the class as pennylane.devices.QutritDevice",
+            PennyLaneDeprecationWarning,
+        )
+        return pennylane.devices._qutrit_device.QutritDevice  # pylint:disable=protected-access
+
+    if name == "Device":
+        warn(
+            "Device will no longer be accessible top level. Please access "
+            "the class as pennylane.devices.LegacyDevice",
+            PennyLaneDeprecationWarning,
+        )
+        return pennylane.devices._legacy_device.Device  # pylint:disable=protected-access
+
     raise AttributeError(f"module 'pennylane' has no attribute '{name}'")
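
A hedged sketch of what this module-level `__getattr__` shim does in practice:

```python
# Sketch: deprecated top-level names still resolve, but emit a warning.
import warnings
import pennylane as qml

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _ = qml.QubitDevice  # routed through the __getattr__ hook above
assert any("QubitDevice" in str(w.message) for w in caught)

from pennylane.devices import QubitDevice  # warning-free access
```
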
50 changes: 43 additions & 7 deletions pennylane/_grad.py
@@ -25,6 +25,7 @@
 
 from pennylane.capture import enabled
 from pennylane.capture.capture_diff import _get_grad_prim, _get_jacobian_prim
+from pennylane.capture.flatfn import FlatFn
 from pennylane.compiler import compiler
 from pennylane.compiler.compiler import CompileError
 
@@ -33,18 +34,53 @@
 
 def _capture_diff(func, argnum=None, diff_prim=None, method=None, h=None):
     """Capture-compatible gradient computation."""
-    import jax  # pylint: disable=import-outside-toplevel
+    # pylint: disable=import-outside-toplevel
+    import jax
+    from jax.tree_util import tree_flatten, tree_leaves, tree_unflatten, treedef_tuple
 
-    if isinstance(argnum, int):
-        argnum = [argnum]
     if argnum is None:
-        argnum = [0]
+        argnum = 0
+    if argnum_is_int := isinstance(argnum, int):
+        argnum = [argnum]
 
     @wraps(func)
     def new_func(*args, **kwargs):
-        jaxpr = jax.make_jaxpr(partial(func, **kwargs))(*args)
-        prim_kwargs = {"argnum": argnum, "jaxpr": jaxpr.jaxpr, "n_consts": len(jaxpr.consts)}
-        return diff_prim.bind(*jaxpr.consts, *args, **prim_kwargs, method=method, h=h)
+        flat_args, in_trees = zip(*(tree_flatten(arg) for arg in args))
+        full_in_tree = treedef_tuple(in_trees)
+
+        # Create a new input tree that only takes inputs marked by argnum into account
+        trainable_in_trees = (in_tree for i, in_tree in enumerate(in_trees) if i in argnum)
+        # If an integer was provided as argnum, unpack the arguments axis of the derivatives
+        if argnum_is_int:
+            trainable_in_tree = list(trainable_in_trees)[0]
+        else:
+            trainable_in_tree = treedef_tuple(trainable_in_trees)
+
+        # Create argnum for the flat list of input arrays. For each flattened argument,
+        # add a list of flat argnums if the argument is trainable and an empty list otherwise.
+        start = 0
+        flat_argnum_gen = (
+            (
+                list(range(start, (start := start + len(flat_arg))))
+                if i in argnum
+                else list(range((start := start + len(flat_arg)), start))
+            )
+            for i, flat_arg in enumerate(flat_args)
+        )
+        flat_argnum = sum(flat_argnum_gen, start=[])
+
+        # Create fully flattened function (flat inputs & outputs)
+        flat_fn = FlatFn(partial(func, **kwargs) if kwargs else func, full_in_tree)
+        flat_args = sum(flat_args, start=[])
+        jaxpr = jax.make_jaxpr(flat_fn)(*flat_args)
+        prim_kwargs = {"argnum": flat_argnum, "jaxpr": jaxpr.jaxpr, "n_consts": len(jaxpr.consts)}
+        out_flat = diff_prim.bind(*jaxpr.consts, *flat_args, **prim_kwargs, method=method, h=h)
+        # flatten once more to go from 2D derivative structure (outputs, args) to flat structure
+        out_flat = tree_leaves(out_flat)
+        assert flat_fn.out_tree is not None, "out_tree should be set after executing flat_fn"
+        # The derivative output tree is the composition of output tree and trainable input trees
+        combined_tree = flat_fn.out_tree.compose(trainable_in_tree)
+        return tree_unflatten(combined_tree, out_flat)
 
     return new_func
 
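The flat-argnum bookkeeping above maps per-argument `argnum` entries onto positions in the flattened leaf list. A standalone sketch of the same idea in plain JAX, with hypothetical inputs:

```python
# Sketch: translating argnum over pytree arguments into flat leaf indices.
from jax.tree_util import tree_flatten

args = ({"a": 1.0, "b": 2.0}, 3.0)  # two pytree arguments
argnum = [0]                        # only the first argument is trainable

flat_args, in_trees = zip(*(tree_flatten(arg) for arg in args))
start, flat_argnum = 0, []
for i, flat_arg in enumerate(flat_args):
    stop = start + len(flat_arg)
    if i in argnum:
        flat_argnum.extend(range(start, stop))  # leaves of a trainable arg
    start = stop

print(flat_argnum)  # [0, 1]: both dict leaves, but not the bare float
```
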
2 changes: 1 addition & 1 deletion pennylane/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
Version number (major.minor.patch[-label])
"""

__version__ = "0.39.0-dev14"
__version__ = "0.39.0-dev15"
3 changes: 3 additions & 0 deletions pennylane/capture/__init__.py
@@ -34,6 +34,7 @@
     ~create_measurement_wires_primitive
     ~create_measurement_mcm_primitive
     ~qnode_call
+    ~FlatFn
 
 The ``primitives`` submodule offers easy access to objects with jax dependencies such as
@@ -154,6 +155,7 @@ def _(*args, **kwargs):
     create_measurement_mcm_primitive,
 )
 from .capture_qnode import qnode_call
+from .flatfn import FlatFn
 
 # by defining this here, we avoid
 # E0611: No name 'AbstractOperator' in module 'pennylane.capture' (no-name-in-module)
@@ -196,4 +198,5 @@ def __getattr__(key):
     "AbstractOperator",
     "AbstractMeasurement",
     "qnode_prim",
+    "FlatFn",
 )
2 changes: 1 addition & 1 deletion pennylane/capture/capture_diff.py
@@ -103,7 +103,7 @@ def _(*args, argnum, jaxpr, n_consts, method, h):
         def func(*inner_args):
             return jax.core.eval_jaxpr(jaxpr, consts, *inner_args)
 
-        return jax.jacobian(func, argnums=argnum)(*args)
+        return jax.tree_util.tree_leaves(jax.jacobian(func, argnums=argnum)(*args))
 
     # pylint: disable=unused-argument
     @jacobian_prim.def_abstract_eval
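
A small illustration of why the flattening is needed: `jax.jacobian` with multiple `argnums` returns a nested structure, while the primitive's callers expect a flat list of arrays. A sketch:

```python
# Sketch: jax.jacobian returns nested pytrees; tree_leaves flattens them.
import jax
import jax.numpy as jnp

def f(x, y):
    return {"out": x * y}

jac = jax.jacobian(f, argnums=(0, 1))(jnp.ones(2), jnp.ones(2))
print(jac)                             # {'out': (d/dx, d/dy)} -- nested
print(jax.tree_util.tree_leaves(jac))  # [Array(2, 2), Array(2, 2)] -- flat
```
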
10 changes: 7 additions & 3 deletions pennylane/capture/capture_qnode.py
@@ -82,8 +82,12 @@ def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts):
         mps = qfunc_jaxpr.outvars
         return _get_shapes_for(*mps, shots=shots, num_device_wires=len(device.wires))
 
-    def _qnode_jvp(*args_and_tangents, **impl_kwargs):
-        return jax.jvp(partial(qnode_prim.impl, **impl_kwargs), *args_and_tangents)
+    def make_zero(tan, arg):
+        return jax.lax.zeros_like_array(arg) if isinstance(tan, ad.Zero) else tan
+
+    def _qnode_jvp(args, tangents, **impl_kwargs):
+        tangents = tuple(map(make_zero, tangents, args))
+        return jax.jvp(partial(qnode_prim.impl, **impl_kwargs), args, tangents)
 
     ad.primitive_jvps[qnode_prim] = _qnode_jvp

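The `make_zero` helper exists because JAX may pass symbolic `ad.Zero` tangents into a custom JVP rule, while `jax.jvp` itself only accepts concrete arrays. A hedged restatement of the pattern:

```python
# Sketch: materialize symbolic ad.Zero tangents before re-entering jax.jvp.
import jax
from jax.interpreters import ad

def make_zero(tan, arg):
    # ad.Zero carries no buffer; substitute an explicit array of zeros
    return jax.lax.zeros_like_array(arg) if isinstance(tan, ad.Zero) else tan

def jvp_of(impl, args, tangents):
    tangents = tuple(map(make_zero, tangents, args))
    return jax.jvp(impl, args, tangents)
```
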
@@ -174,7 +178,7 @@ def f(x):
     qnode_kwargs = {"diff_method": qnode.diff_method, **execute_kwargs, **mcm_config}
     qnode_prim = _get_qnode_prim()
 
-    flat_args, _ = jax.tree_util.tree_flatten(args)
+    flat_args = jax.tree_util.tree_leaves(args)
     res = qnode_prim.bind(
         *qfunc_jaxpr.consts,
         *flat_args,
7 changes: 3 additions & 4 deletions pennylane/capture/explanations.md
@@ -159,12 +159,11 @@ You can also see the const variable `a` as argument `e:i32[]` to the inner nested
 ### Pytree handling
 
 Evaluating a jaxpr requires accepting and returning a flat list of tensor-like inputs and outputs.
-list of tensor-like outputs. These long lists can be hard to manage and are very
-restrictive on the allowed functions, but we can take advantage of pytrees to allow handling
-arbitrary functions.
+These long lists can be hard to manage and are very restrictive on the allowed functions, but we
+can take advantage of pytrees to allow handling arbitrary functions.
 
 To start, we import the `FlatFn` helper. This class converts a function to one that caches
-the resulting result pytree into `flat_fn.out_tree` when executed. This can be used to repack the
+the result pytree into `flat_fn.out_tree` when executed. This can be used to repack the
 results into the correct shape. It also returns flattened results. This does not particularly
 matter for program capture, as we will only be producing jaxpr from the function, not calling
 it directly.
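
A minimal sketch of a `FlatFn`-style wrapper, assuming only the behaviour described above (flattened results, with the output pytree cached on `out_tree`):

```python
# Sketch: wrap a pytree-valued function so it returns flat leaves and
# remembers the output tree structure for later repacking.
import jax

class FlatFnSketch:  # hypothetical stand-in for pennylane.capture.FlatFn
    def __init__(self, fn, in_tree=None):
        self.fn = fn
        self.in_tree = in_tree  # optional treedef to unflatten flat inputs
        self.out_tree = None    # cached after the first call

    def __call__(self, *flat_args):
        args = (
            jax.tree_util.tree_unflatten(self.in_tree, flat_args)
            if self.in_tree is not None
            else flat_args
        )
        out = self.fn(*args)
        out_flat, self.out_tree = jax.tree_util.tree_flatten(out)
        return out_flat
```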