Skip to content

Commit

Permalink
Merge branch 'master' into lattice_models
Browse files Browse the repository at this point in the history
  • Loading branch information
ddhawan11 committed Sep 16, 2024
2 parents 2a19d6f + 79fc6d3 commit efb7a55
Show file tree
Hide file tree
Showing 104 changed files with 1,165 additions and 18,529 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/interface-unit-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -379,8 +379,7 @@ jobs:
# catalyst requires the latest version of pennylane that is about to be released.
# Installing catalyst after pennylane to make sure that the latest catalyst is used.
install_catalyst_nightly: true
# using lightning master does not work for the tests with external libraries
install_pennylane_lightning_master: false
install_pennylane_lightning_master: true
pytest_coverage_flags: ${{ inputs.pytest_coverage_flags }}
pytest_markers: external
additional_pip_packages: pyzx matplotlib stim quimb mitiq pennylane-qiskit ply
Expand Down
2 changes: 0 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,10 @@ clean-docs:

test:
$(PYTHON) $(TESTRUNNER)
$(PYTHON) $(PLUGIN_TESTRUNNER) --device=default.qubit.autograd

coverage:
@echo "Generating coverage report..."
$(PYTHON) $(TESTRUNNER) $(COVERAGE)
$(PYTHON) $(PLUGIN_TESTRUNNER) --device=default.qubit.autograd $(COVERAGE) --cov-append

.PHONY:format
format:
Expand Down
34 changes: 26 additions & 8 deletions doc/development/deprecations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,18 @@ deprecations are listed below.
Pending deprecations
--------------------

* All of the legacy devices (any with the name ``default.qubit.{autograd,torch,tf,jax,legacy}``) are deprecated. Use ``default.qubit`` instead,
as it supports backpropagation for the many backends the legacy devices support.
* ``Device``, ``QubitDevice``, and ``QutritDevice`` will no longer be imported top level in v0.40. They instead
we be available as ``qml.devices.LegacyDevice``, ``qml.devices.QubitDevice``, and ``qml.devices.QutritDevice``
respectively.

- Deprecated in v0.38
- Will be removed in v0.39
- Deprecated top level access in v0.39
- Top level access removed in v0.40

* The logic for internally switching a device for a different backpropagation
compatible device is now deprecated, as it was in place for the deprecated ``default.qubit.legacy``.
* `QNode.gradient_fn` is deprecated. Please use `QNode.diff_method` instead. `QNode.get_gradient_fn` can also be used to
process the diff method.

- Deprecated in v0.38
- Will be removed in v0.39
- Deprecated in v0.39
- Will be removed in v0.40

* The ``decomp_depth`` argument in ``qml.device`` is deprecated.

Expand Down Expand Up @@ -82,6 +83,23 @@ Other deprecations
Completed deprecation cycles
----------------------------

* All of the legacy devices (any with the name ``default.qubit.{autograd,torch,tf,jax,legacy}``) are removed. Use ``default.qubit`` instead,
as it supports backpropagation for the many backends the legacy devices support.

- Deprecated in v0.38
- Removed in v0.39

* The logic for internally switching a device for a different backpropagation
compatible device is removed, as it was in place for removed ``default.qubit.legacy``.

- Deprecated in v0.38
- Removed in v0.39

* `Operator.expand` is now removed. Use `qml.tape.QuantumScript(op.deocomposition())` instead.

- Deprecated in v0.38
- Removed in v0.39

* The ``expansion_strategy`` attribute of ``qml.QNode`` is removed.
Users should make use of ``qml.workflow.construct_batch``, should they require fine control over the output tape(s).

Expand Down
2 changes: 1 addition & 1 deletion doc/development/guide/logging.rst
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ process, and surrounding operations:
# Get logger for use by this script only.
logger = logging.getLogger(__name__)
dev_name = "default.qubit.jax"
dev_name = "default.qubit"
num_wires = 2
num_shots = None
Expand Down
2 changes: 2 additions & 0 deletions doc/development/release_notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ This page contains the release notes for PennyLane.

.. mdinclude:: ../releases/changelog-dev.md

.. mdinclude:: ../releases/changelog-0.38.1.md

.. mdinclude:: ../releases/changelog-0.38.0.md

.. mdinclude:: ../releases/changelog-0.37.0.md
Expand Down
8 changes: 4 additions & 4 deletions doc/introduction/interfaces/jax.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ a JAX-capable QNode in PennyLane. Simply specify the ``interface='jax'`` keyword

.. code-block:: python
dev = qml.device('default.qubit.jax', wires=2)
dev = qml.device('default.qubit', wires=2)
@qml.qnode(dev, interface='jax')
def circuit1(phi, theta):
Expand Down Expand Up @@ -85,7 +85,7 @@ For example:

.. code-block:: python
dev = qml.device('default.qubit.jax', wires=2)
dev = qml.device('default.qubit', wires=2)
@qml.qnode(dev, interface='jax')
def circuit3(phi, theta):
Expand Down Expand Up @@ -119,7 +119,7 @@ the ``@jax.jit`` decorator can be directly applied to the QNode.

.. code-block:: python
dev = qml.device('default.qubit.jax', wires=2)
dev = qml.device('default.qubit', wires=2)
@jax.jit # QNode calls will now be jitted, and should run faster.
@qml.qnode(dev, interface='jax')
Expand Down Expand Up @@ -176,7 +176,7 @@ Example:
# Device construction should happen inside a `jax.jit` decorated
# method when using a PRNGKey.
dev = qml.device('default.qubit.jax', wires=2, prng_key=key, shots=100)
dev = qml.device('default.qubit', wires=2, prng_key=key, shots=100)
@qml.qnode(dev, interface='jax', diff_method=None)
Expand Down
2 changes: 1 addition & 1 deletion doc/releases/changelog-0.38.0.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
:orphan:

# Release 0.38.0 (current release)
# Release 0.38.0

<h3>New features since last release</h3>

Expand Down
14 changes: 14 additions & 0 deletions doc/releases/changelog-0.38.1.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
:orphan:

# Release 0.38.1 (current release)

<h3>Bug fixes 🐛</h3>

* Fix float-to-complex casting in various places across PennyLane.
[(#6260)](https://github.com/PennyLaneAI/pennylane/pull/6260)

<h3>Contributors ✍️</h3>

This release contains contributions from (in alphabetical order):

Mudit Pandey
23 changes: 21 additions & 2 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@
* Differentiation of hybrid programs via `qml.grad` and `qml.jacobian` can now be captured
into plxpr. When evaluating a captured `qml.grad` (`qml.jacobian`) instruction, it will
dispatch to `jax.grad` (`jax.jacobian`), which differs from the Autograd implementation
without capture.
without capture. Pytree inputs and outputs are supported.
[(#6120)](https://github.com/PennyLaneAI/pennylane/pull/6120)
[(#6127)](https://github.com/PennyLaneAI/pennylane/pull/6127)
[(#6134)](https://github.com/PennyLaneAI/pennylane/pull/6134)

* Improve unit testing for capturing of nested control flows.
[(#6111)](https://github.com/PennyLaneAI/pennylane/pull/6111)
Expand Down Expand Up @@ -62,9 +63,12 @@
* Remove support for Python 3.9.
[(#6223)](https://github.com/PennyLaneAI/pennylane/pull/6223)

* `DefaultQubitTF` and `DefaultQubitTorch` are removed. Please use `default.qubit` for all interfaces.
* `DefaultQubitTF`, `DefaultQubitTorch`, `DefaultQubitJax`, and `DefaultQubitAutograd` are removed.
Please use `default.qubit` for all interfaces.
[(#6207)](https://github.com/PennyLaneAI/pennylane/pull/6207)
[(#6208)](https://github.com/PennyLaneAI/pennylane/pull/6208)
[(#6209)](https://github.com/PennyLaneAI/pennylane/pull/6209)
[(#6210)](https://github.com/PennyLaneAI/pennylane/pull/6210)

* `expand_fn`, `max_expansion`, `override_shots`, and `device_batch_transform` are removed from the
signature of `qml.execute`.
Expand All @@ -80,8 +84,20 @@
Please use `qml.transforms.split_non_commuting` instead.
[(#6204)](https://github.com/PennyLaneAI/pennylane/pull/6204)

* `Operator.expand` is now removed. Use `qml.tape.QuantumScript(op.deocomposition())` instead.
[(#6227)](https://github.com/PennyLaneAI/pennylane/pull/6227)


<h3>Deprecations 👋</h3>

* `Device`, `QubitDevice`, and `QutritDevice` will no longer be accessible via top-level import in v0.40.
They will still be accessible as `qml.devices.LegacyDevice`, `qml.devices.QubitDevice`, and `qml.devices.QutritDevice`
respectively.
[(#6238)](https://github.com/PennyLaneAI/pennylane/pull/6238/)

* `QNode.gradient_fn` is deprecated. Please use `QNode.diff_method` and `QNode.get_gradient_fn` instead.
[(#6244)](https://github.com/PennyLaneAI/pennylane/pull/6244)

<h3>Documentation 📝</h3>

<h3>Bug fixes 🐛</h3>
Expand All @@ -101,6 +117,9 @@
* The ``qml.Qubitization`` template now orders the ``control`` wires first and the ``hamiltonian`` wires second, which is the expected according to other templates.
[(#6229)](https://github.com/PennyLaneAI/pennylane/pull/6229)

* The ``qml.FABLE`` template now returns the correct value when JIT is enabled.
[(#6263)](https://github.com/PennyLaneAI/pennylane/pull/6263)

* <h3>Contributors ✍️</h3>

This release contains contributions from (in alphabetical order):
Expand Down
19 changes: 17 additions & 2 deletions pennylane/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
PennyLane can be directly imported.
"""

import numpy as _np


from pennylane.boolean_fn import BooleanFn
import pennylane.numpy
Expand Down Expand Up @@ -180,13 +178,30 @@ def __getattr__(name):
if name == "plugin_devices":
return pennylane.devices.device_constructor.plugin_devices

from warnings import warn # pylint: disable=import-outside-toplevel

if name == "QubitDevice":
warn(
"QubitDevice will no longer be accessible top level. Please access "
" the class as pennylane.devices.QubitDevice",
PennyLaneDeprecationWarning,
)
return pennylane.devices._qubit_device.QubitDevice # pylint:disable=protected-access

if name == "QutritDevice":
warn(
"QutritDevice will no longer be accessible top level. Please access "
" the class as pennylane.devices.QutritDevice",
PennyLaneDeprecationWarning,
)
return pennylane.devices._qutrit_device.QutritDevice # pylint:disable=protected-access

if name == "Device":
warn(
"Device will no longer be accessible top level. Please access "
" the class as pennylane.devices.LegacyDevice",
PennyLaneDeprecationWarning,
)
return pennylane.devices._legacy_device.Device # pylint:disable=protected-access

raise AttributeError(f"module 'pennylane' has no attribute '{name}'")
Expand Down
50 changes: 43 additions & 7 deletions pennylane/_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from pennylane.capture import enabled
from pennylane.capture.capture_diff import _get_grad_prim, _get_jacobian_prim
from pennylane.capture.flatfn import FlatFn
from pennylane.compiler import compiler
from pennylane.compiler.compiler import CompileError

Expand All @@ -33,18 +34,53 @@

def _capture_diff(func, argnum=None, diff_prim=None, method=None, h=None):
"""Capture-compatible gradient computation."""
import jax # pylint: disable=import-outside-toplevel
# pylint: disable=import-outside-toplevel
import jax
from jax.tree_util import tree_flatten, tree_leaves, tree_unflatten, treedef_tuple

if isinstance(argnum, int):
argnum = [argnum]
if argnum is None:
argnum = [0]
argnum = 0
if argnum_is_int := isinstance(argnum, int):
argnum = [argnum]

@wraps(func)
def new_func(*args, **kwargs):
jaxpr = jax.make_jaxpr(partial(func, **kwargs))(*args)
prim_kwargs = {"argnum": argnum, "jaxpr": jaxpr.jaxpr, "n_consts": len(jaxpr.consts)}
return diff_prim.bind(*jaxpr.consts, *args, **prim_kwargs, method=method, h=h)
flat_args, in_trees = zip(*(tree_flatten(arg) for arg in args))
full_in_tree = treedef_tuple(in_trees)

# Create a new input tree that only takes inputs marked by argnum into account
trainable_in_trees = (in_tree for i, in_tree in enumerate(in_trees) if i in argnum)
# If an integer was provided as argnum, unpack the arguments axis of the derivatives
if argnum_is_int:
trainable_in_tree = list(trainable_in_trees)[0]
else:
trainable_in_tree = treedef_tuple(trainable_in_trees)

# Create argnum for the flat list of input arrays. For each flattened argument,
# add a list of flat argnums if the argument is trainable and an empty list otherwise.
start = 0
flat_argnum_gen = (
(
list(range(start, (start := start + len(flat_arg))))
if i in argnum
else list(range((start := start + len(flat_arg)), start))
)
for i, flat_arg in enumerate(flat_args)
)
flat_argnum = sum(flat_argnum_gen, start=[])

# Create fully flattened function (flat inputs & outputs)
flat_fn = FlatFn(partial(func, **kwargs) if kwargs else func, full_in_tree)
flat_args = sum(flat_args, start=[])
jaxpr = jax.make_jaxpr(flat_fn)(*flat_args)
prim_kwargs = {"argnum": flat_argnum, "jaxpr": jaxpr.jaxpr, "n_consts": len(jaxpr.consts)}
out_flat = diff_prim.bind(*jaxpr.consts, *flat_args, **prim_kwargs, method=method, h=h)
# flatten once more to go from 2D derivative structure (outputs, args) to flat structure
out_flat = tree_leaves(out_flat)
assert flat_fn.out_tree is not None, "out_tree should be set after executing flat_fn"
# The derivative output tree is the composition of output tree and trainable input trees
combined_tree = flat_fn.out_tree.compose(trainable_in_tree)
return tree_unflatten(combined_tree, out_flat)

return new_func

Expand Down
2 changes: 1 addition & 1 deletion pennylane/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
Version number (major.minor.patch[-label])
"""

__version__ = "0.39.0-dev11"
__version__ = "0.39.0-dev15"
3 changes: 3 additions & 0 deletions pennylane/capture/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
~create_measurement_wires_primitive
~create_measurement_mcm_primitive
~qnode_call
~FlatFn
The ``primitives`` submodule offers easy access to objects with jax dependencies such as
Expand Down Expand Up @@ -154,6 +155,7 @@ def _(*args, **kwargs):
create_measurement_mcm_primitive,
)
from .capture_qnode import qnode_call
from .flatfn import FlatFn

# by defining this here, we avoid
# E0611: No name 'AbstractOperator' in module 'pennylane.capture' (no-name-in-module)
Expand Down Expand Up @@ -196,4 +198,5 @@ def __getattr__(key):
"AbstractOperator",
"AbstractMeasurement",
"qnode_prim",
"FlatFn",
)
2 changes: 1 addition & 1 deletion pennylane/capture/capture_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def _(*args, argnum, jaxpr, n_consts, method, h):
def func(*inner_args):
return jax.core.eval_jaxpr(jaxpr, consts, *inner_args)

return jax.jacobian(func, argnums=argnum)(*args)
return jax.tree_util.tree_leaves(jax.jacobian(func, argnums=argnum)(*args))

# pylint: disable=unused-argument
@jacobian_prim.def_abstract_eval
Expand Down
10 changes: 7 additions & 3 deletions pennylane/capture/capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,12 @@ def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts):
mps = qfunc_jaxpr.outvars
return _get_shapes_for(*mps, shots=shots, num_device_wires=len(device.wires))

def _qnode_jvp(*args_and_tangents, **impl_kwargs):
return jax.jvp(partial(qnode_prim.impl, **impl_kwargs), *args_and_tangents)
def make_zero(tan, arg):
return jax.lax.zeros_like_array(arg) if isinstance(tan, ad.Zero) else tan

def _qnode_jvp(args, tangents, **impl_kwargs):
tangents = tuple(map(make_zero, tangents, args))
return jax.jvp(partial(qnode_prim.impl, **impl_kwargs), args, tangents)

ad.primitive_jvps[qnode_prim] = _qnode_jvp

Expand Down Expand Up @@ -174,7 +178,7 @@ def f(x):
qnode_kwargs = {"diff_method": qnode.diff_method, **execute_kwargs, **mcm_config}
qnode_prim = _get_qnode_prim()

flat_args, _ = jax.tree_util.tree_flatten(args)
flat_args = jax.tree_util.tree_leaves(args)
res = qnode_prim.bind(
*qfunc_jaxpr.consts,
*flat_args,
Expand Down
7 changes: 3 additions & 4 deletions pennylane/capture/explanations.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,11 @@ You can also see the const variable `a` as argument `e:i32[]` to the inner neste
### Pytree handling

Evaluating a jaxpr requires accepting and returning a flat list of tensor-like inputs and outputs.
list of tensor-like outputs. These long lists can be hard to manage and are very
restrictive on the allowed functions, but we can take advantage of pytrees to allow handling
arbitrary functions.
These long lists can be hard to manage and are very restrictive on the allowed functions, but we
can take advantage of pytrees to allow handling arbitrary functions.

To start, we import the `FlatFn` helper. This class converts a function to one that caches
the resulting result pytree into `flat_fn.out_tree` when executed. This can be used to repack the
the result pytree into `flat_fn.out_tree` when executed. This can be used to repack the
results into the correct shape. It also returns flattened results. This does not particularly
matter for program capture, as we will only be producing jaxpr from the function, not calling
it directly.
Expand Down
Loading

0 comments on commit efb7a55

Please sign in to comment.