From d0344b02309fee6b2d01c7a9a433e89af4ad5a78 Mon Sep 17 00:00:00 2001 From: David Wierichs Date: Sun, 15 Sep 2024 21:47:02 +0200 Subject: [PATCH 1/6] [Program Capture] Add pytree support to captured `qml.grad` and `qml.jacobian` (#6134) **Context:** #6120 and #6127 add support to capture `qml.grad` and `qml.jacobian` in plxpr. Once captured, they dispatch to `jax.grad` and `jax.jacobian`. **Description of the Change:** This PR adds support for pytree inputs and outputs of the differentiated functions, similar to #6081. For this, it extends the internal class `FlatFn` by the extra functionality to turn the wrapper into a `*flat_args -> *flat_outputs` function, instead of a `*pytree_args -> *flat_outputs` function. **Benefits:** Pytree support :deciduous_tree: **Possible Drawbacks:** **Related GitHub Issues:** [sc-70930] [sc-71862] --------- Co-authored-by: Christina Lee Co-authored-by: Mudit Pandey --- doc/releases/changelog-dev.md | 3 +- pennylane/_grad.py | 50 ++++++-- pennylane/capture/__init__.py | 3 + pennylane/capture/capture_diff.py | 2 +- pennylane/capture/capture_qnode.py | 10 +- pennylane/capture/explanations.md | 7 +- pennylane/capture/flatfn.py | 31 ++++- tests/capture/test_capture_diff.py | 188 ++++++++++++++++++++++++++--- 8 files changed, 260 insertions(+), 34 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 488fd8d37e8..0a93727855d 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -25,9 +25,10 @@ * Differentiation of hybrid programs via `qml.grad` and `qml.jacobian` can now be captured into plxpr. When evaluating a captured `qml.grad` (`qml.jacobian`) instruction, it will dispatch to `jax.grad` (`jax.jacobian`), which differs from the Autograd implementation - without capture. + without capture. Pytree inputs and outputs are supported. [(#6120)](https://github.com/PennyLaneAI/pennylane/pull/6120) [(#6127)](https://github.com/PennyLaneAI/pennylane/pull/6127) + [(#6134)](https://github.com/PennyLaneAI/pennylane/pull/6134) * Improve unit testing for capturing of nested control flows. 
[(#6111)](https://github.com/PennyLaneAI/pennylane/pull/6111) diff --git a/pennylane/_grad.py b/pennylane/_grad.py index 859ae5d9fbb..d7cb52e0d52 100644 --- a/pennylane/_grad.py +++ b/pennylane/_grad.py @@ -25,6 +25,7 @@ from pennylane.capture import enabled from pennylane.capture.capture_diff import _get_grad_prim, _get_jacobian_prim +from pennylane.capture.flatfn import FlatFn from pennylane.compiler import compiler from pennylane.compiler.compiler import CompileError @@ -33,18 +34,53 @@ def _capture_diff(func, argnum=None, diff_prim=None, method=None, h=None): """Capture-compatible gradient computation.""" - import jax # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel + import jax + from jax.tree_util import tree_flatten, tree_leaves, tree_unflatten, treedef_tuple - if isinstance(argnum, int): - argnum = [argnum] if argnum is None: - argnum = [0] + argnum = 0 + if argnum_is_int := isinstance(argnum, int): + argnum = [argnum] @wraps(func) def new_func(*args, **kwargs): - jaxpr = jax.make_jaxpr(partial(func, **kwargs))(*args) - prim_kwargs = {"argnum": argnum, "jaxpr": jaxpr.jaxpr, "n_consts": len(jaxpr.consts)} - return diff_prim.bind(*jaxpr.consts, *args, **prim_kwargs, method=method, h=h) + flat_args, in_trees = zip(*(tree_flatten(arg) for arg in args)) + full_in_tree = treedef_tuple(in_trees) + + # Create a new input tree that only takes inputs marked by argnum into account + trainable_in_trees = (in_tree for i, in_tree in enumerate(in_trees) if i in argnum) + # If an integer was provided as argnum, unpack the arguments axis of the derivatives + if argnum_is_int: + trainable_in_tree = list(trainable_in_trees)[0] + else: + trainable_in_tree = treedef_tuple(trainable_in_trees) + + # Create argnum for the flat list of input arrays. For each flattened argument, + # add a list of flat argnums if the argument is trainable and an empty list otherwise. 
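+        # For example, with args = (x, {"a": y, "b": z}) and argnum = [1], the
+        # flattened args are [x] and [y, z], so the flat argnum list is [1, 2].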
+ start = 0 + flat_argnum_gen = ( + ( + list(range(start, (start := start + len(flat_arg)))) + if i in argnum + else list(range((start := start + len(flat_arg)), start)) + ) + for i, flat_arg in enumerate(flat_args) + ) + flat_argnum = sum(flat_argnum_gen, start=[]) + + # Create fully flattened function (flat inputs & outputs) + flat_fn = FlatFn(partial(func, **kwargs) if kwargs else func, full_in_tree) + flat_args = sum(flat_args, start=[]) + jaxpr = jax.make_jaxpr(flat_fn)(*flat_args) + prim_kwargs = {"argnum": flat_argnum, "jaxpr": jaxpr.jaxpr, "n_consts": len(jaxpr.consts)} + out_flat = diff_prim.bind(*jaxpr.consts, *flat_args, **prim_kwargs, method=method, h=h) + # flatten once more to go from 2D derivative structure (outputs, args) to flat structure + out_flat = tree_leaves(out_flat) + assert flat_fn.out_tree is not None, "out_tree should be set after executing flat_fn" + # The derivative output tree is the composition of output tree and trainable input trees + combined_tree = flat_fn.out_tree.compose(trainable_in_tree) + return tree_unflatten(combined_tree, out_flat) return new_func diff --git a/pennylane/capture/__init__.py b/pennylane/capture/__init__.py index 6deeef29682..2e43a246e2c 100644 --- a/pennylane/capture/__init__.py +++ b/pennylane/capture/__init__.py @@ -34,6 +34,7 @@ ~create_measurement_wires_primitive ~create_measurement_mcm_primitive ~qnode_call + ~FlatFn The ``primitives`` submodule offers easy access to objects with jax dependencies such as @@ -154,6 +155,7 @@ def _(*args, **kwargs): create_measurement_mcm_primitive, ) from .capture_qnode import qnode_call +from .flatfn import FlatFn # by defining this here, we avoid # E0611: No name 'AbstractOperator' in module 'pennylane.capture' (no-name-in-module) @@ -196,4 +198,5 @@ def __getattr__(key): "AbstractOperator", "AbstractMeasurement", "qnode_prim", + "FlatFn", ) diff --git a/pennylane/capture/capture_diff.py b/pennylane/capture/capture_diff.py index 92dde5a2956..829a9516af1 100644 --- a/pennylane/capture/capture_diff.py +++ b/pennylane/capture/capture_diff.py @@ -103,7 +103,7 @@ def _(*args, argnum, jaxpr, n_consts, method, h): def func(*inner_args): return jax.core.eval_jaxpr(jaxpr, consts, *inner_args) - return jax.jacobian(func, argnums=argnum)(*args) + return jax.tree_util.tree_leaves(jax.jacobian(func, argnums=argnum)(*args)) # pylint: disable=unused-argument @jacobian_prim.def_abstract_eval diff --git a/pennylane/capture/capture_qnode.py b/pennylane/capture/capture_qnode.py index f7f06451230..491b9f3f6a4 100644 --- a/pennylane/capture/capture_qnode.py +++ b/pennylane/capture/capture_qnode.py @@ -82,8 +82,12 @@ def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts): mps = qfunc_jaxpr.outvars return _get_shapes_for(*mps, shots=shots, num_device_wires=len(device.wires)) - def _qnode_jvp(*args_and_tangents, **impl_kwargs): - return jax.jvp(partial(qnode_prim.impl, **impl_kwargs), *args_and_tangents) + def make_zero(tan, arg): + return jax.lax.zeros_like_array(arg) if isinstance(tan, ad.Zero) else tan + + def _qnode_jvp(args, tangents, **impl_kwargs): + tangents = tuple(map(make_zero, tangents, args)) + return jax.jvp(partial(qnode_prim.impl, **impl_kwargs), args, tangents) ad.primitive_jvps[qnode_prim] = _qnode_jvp @@ -174,7 +178,7 @@ def f(x): qnode_kwargs = {"diff_method": qnode.diff_method, **execute_kwargs, **mcm_config} qnode_prim = _get_qnode_prim() - flat_args, _ = jax.tree_util.tree_flatten(args) + flat_args = jax.tree_util.tree_leaves(args) res = qnode_prim.bind( *qfunc_jaxpr.consts, 
*flat_args, diff --git a/pennylane/capture/explanations.md b/pennylane/capture/explanations.md index 61b88b94030..84feef9786f 100644 --- a/pennylane/capture/explanations.md +++ b/pennylane/capture/explanations.md @@ -159,12 +159,11 @@ You can also see the const variable `a` as argument `e:i32[]` to the inner neste ### Pytree handling Evaluating a jaxpr requires accepting and returning a flat list of tensor-like inputs and outputs. -list of tensor-like outputs. These long lists can be hard to manage and are very -restrictive on the allowed functions, but we can take advantage of pytrees to allow handling -arbitrary functions. +These long lists can be hard to manage and are very restrictive on the allowed functions, but we +can take advantage of pytrees to allow handling arbitrary functions. To start, we import the `FlatFn` helper. This class converts a function to one that caches -the resulting result pytree into `flat_fn.out_tree` when executed. This can be used to repack the +the result pytree into `flat_fn.out_tree` when executed. This can be used to repack the results into the correct shape. It also returns flattened results. This does not particularly matter for program capture, as we will only be producing jaxpr from the function, not calling it directly. diff --git a/pennylane/capture/flatfn.py b/pennylane/capture/flatfn.py index 4ba6005f41e..59e1c52b948 100644 --- a/pennylane/capture/flatfn.py +++ b/pennylane/capture/flatfn.py @@ -29,23 +29,48 @@ class FlatFn: property, so that the results can be repacked later. It also returns flattened results instead of the original result object. + If an ``in_tree`` is provided, the function accepts flattened inputs instead of the + original inputs with tree structure given by ``in_tree``. + + **Example** + + >>> import jax + >>> from pennylane.capture.flatfn import FlatFn >>> def f(x): ... return {"y": 2+x["x"]} >>> flat_f = FlatFn(f) - >>> res = flat_f({"x": 0}) + >>> arg = {"x": 0.5} + >>> res = flat_f(arg) + >>> res + [2.5] + >>> jax.tree_util.tree_unflatten(flat_f.out_tree, res) + {'y': 2.5} + + If we want to use a fully flattened function that also takes flat inputs instead of + the original inputs with tree structure, we can provide the treedef for this input + structure: + + >>> flat_args, in_tree = jax.tree_util.tree_flatten((arg,)) + >>> flat_f = FlatFn(f, in_tree) + >>> res = flat_f(*flat_args) >>> res - [2] + [2.5] >>> jax.tree_util.tree_unflatten(flat_f.out_tree, res) {'y': 2.5} + Note that the ``in_tree`` has to be created by flattening a tuple of all input + arguments, even if there is only a single argument. 
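+
+    As an illustrative sketch (``g`` here is a hypothetical two-argument function,
+    not part of the API):
+
+    >>> def g(x, y):
+    ...     return x["a"] + y
+    >>> flat_args, in_tree = jax.tree_util.tree_flatten(({"a": 1.0}, 2.0))
+    >>> FlatFn(g, in_tree)(*flat_args)
+    [3.0]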
""" - def __init__(self, f): + def __init__(self, f, in_tree=None): self.f = f + self.in_tree = in_tree self.out_tree = None update_wrapper(self, f) def __call__(self, *args): + if self.in_tree is not None: + args = jax.tree_util.tree_unflatten(self.in_tree, args) out = self.f(*args) out_flat, out_tree = jax.tree_util.tree_flatten(out) self.out_tree = out_tree diff --git a/tests/capture/test_capture_diff.py b/tests/capture/test_capture_diff.py index cf7834aafb8..07596e01b47 100644 --- a/tests/capture/test_capture_diff.py +++ b/tests/capture/test_capture_diff.py @@ -99,11 +99,11 @@ def func_jax(x): assert qml.math.allclose(func_qml(x), jax_out) # Check overall jaxpr properties - if isinstance(argnum, int): - argnum = [argnum] jaxpr = jax.make_jaxpr(func_qml)(x) assert jaxpr.in_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)] assert len(jaxpr.eqns) == 3 + if isinstance(argnum, int): + argnum = [argnum] assert jaxpr.out_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)] * len(argnum) grad_eqn = jaxpr.eqns[2] @@ -260,6 +260,106 @@ def circuit(x): jax.config.update("jax_enable_x64", initial_mode) + @pytest.mark.parametrize("argnum", ([0, 1], [0], [1])) + def test_grad_pytree_input(self, x64_mode, argnum): + """Test that the qml.grad primitive can be captured with pytree inputs.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + fdtype = jax.numpy.float64 if x64_mode else jax.numpy.float32 + + def inner_func(x, y): + return jnp.prod(jnp.sin(x["a"]) * jnp.cos(y[0]["b"][1]) ** 2) + + def func_qml(x): + return qml.grad(inner_func, argnum=argnum)( + {"a": x}, ({"b": [None, 0.4 * jnp.sqrt(x)]},) + ) + + def func_jax(x): + return jax.grad(inner_func, argnums=argnum)( + {"a": x}, ({"b": [None, 0.4 * jnp.sqrt(x)]},) + ) + + x = 0.7 + jax_out = func_jax(x) + jax_out_flat, jax_out_tree = jax.tree_util.tree_flatten(jax_out) + qml_out_flat, qml_out_tree = jax.tree_util.tree_flatten(func_qml(x)) + assert jax_out_tree == qml_out_tree + assert qml.math.allclose(jax_out_flat, qml_out_flat) + + # Check overall jaxpr properties + jaxpr = jax.make_jaxpr(func_qml)(x) + assert jaxpr.in_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)] + assert len(jaxpr.eqns) == 3 + argnum = [argnum] if isinstance(argnum, int) else argnum + assert jaxpr.out_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)] * len(argnum) + + grad_eqn = jaxpr.eqns[2] + diff_eqn_assertions(grad_eqn, grad_prim, argnum=argnum) + assert [var.aval for var in grad_eqn.outvars] == jaxpr.out_avals + assert len(grad_eqn.params["jaxpr"].eqns) == 6 # 5 numeric eqns, 1 conversion eqn + + manual_out = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x) + manual_out_flat, manual_out_tree = jax.tree_util.tree_flatten(manual_out) + # Assert that the output from the manual evaluation is flat + assert manual_out_tree == jax.tree_util.tree_flatten(manual_out_flat)[1] + assert qml.math.allclose(jax_out_flat, manual_out_flat) + + jax.config.update("jax_enable_x64", initial_mode) + + @pytest.mark.parametrize("argnum", ([0, 1, 2], [0, 2], [1], 0)) + def test_grad_qnode_with_pytrees(self, argnum, x64_mode): + """Test capturing the gradient of a qnode that uses Pytrees.""" + # pylint: disable=protected-access + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + fdtype = jax.numpy.float64 if x64_mode else jax.numpy.float32 + + dev = qml.device("default.qubit", wires=2) + + @qml.qnode(dev) + def circuit(x, y, z): + qml.RX(x["a"], wires=0) + qml.RY(y, 
wires=0) + qml.RZ(z[1][0], wires=0) + return qml.expval(qml.X(0)) + + dcircuit = qml.grad(circuit, argnum=argnum) + x = {"a": 0.6, "b": 0.9} + y = 0.6 + z = ({"c": 0.5}, [0.2, 0.3]) + qml_out = dcircuit(x, y, z) + qml_out_flat, qml_out_tree = jax.tree_util.tree_flatten(qml_out) + jax_out = jax.grad(circuit, argnums=argnum)(x, y, z) + jax_out_flat, jax_out_tree = jax.tree_util.tree_flatten(jax_out) + assert jax_out_tree == qml_out_tree + assert qml.math.allclose(jax_out_flat, qml_out_flat) + + jaxpr = jax.make_jaxpr(dcircuit)(x, y, z) + + assert len(jaxpr.eqns) == 1 # grad equation + assert jaxpr.in_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)] * 6 + argnum = [argnum] if isinstance(argnum, int) else argnum + num_out_avals = 2 * (0 in argnum) + (1 in argnum) + 3 * (2 in argnum) + assert jaxpr.out_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)] * num_out_avals + + grad_eqn = jaxpr.eqns[0] + assert all(invar.aval == in_aval for invar, in_aval in zip(grad_eqn.invars, jaxpr.in_avals)) + flat_argnum = [0, 1] * (0 in argnum) + [2] * (1 in argnum) + [3, 4, 5] * (2 in argnum) + diff_eqn_assertions(grad_eqn, grad_prim, argnum=flat_argnum) + grad_jaxpr = grad_eqn.params["jaxpr"] + assert len(grad_jaxpr.eqns) == 1 # qnode equation + + flat_args = jax.tree_util.tree_leaves((x, y, z)) + manual_out = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *flat_args) + manual_out_flat, manual_out_tree = jax.tree_util.tree_flatten(manual_out) + # Assert that the output from the manual evaluation is flat + assert manual_out_tree == jax.tree_util.tree_flatten(manual_out_flat)[1] + assert qml.math.allclose(jax_out_flat, manual_out_flat) + + jax.config.update("jax_enable_x64", initial_mode) + def _jac_allclose(jac1, jac2, num_axes, atol=1e-8): """Test that two Jacobians, given as nested sequences of arrays, are equal.""" @@ -279,10 +379,6 @@ class TestJacobian: @pytest.mark.parametrize("argnum", ([0, 1], [0], [1], 0, 1)) def test_classical_jacobian(self, x64_mode, argnum): """Test that the qml.jacobian primitive can be captured with classical nodes.""" - if isinstance(argnum, list) and len(argnum) > 1: - # These cases will only be unlocked with Pytree support - pytest.xfail() - initial_mode = jax.config.jax_enable_x64 jax.config.update("jax_enable_x64", x64_mode) fdtype = jnp.float64 if x64_mode else jnp.float32 @@ -307,14 +403,14 @@ def inner_func(x, y): func_jax = jax.jacobian(inner_func, argnums=argnum) jax_out = func_jax(x, y) - num_axes = 1 if isinstance(argnum, int) else 2 - assert _jac_allclose(func_qml(x, y), jax_out, num_axes) + qml_out = func_qml(x, y) + num_axes = 1 if (int_argnum := isinstance(argnum, int)) else 2 + assert _jac_allclose(qml_out, jax_out, num_axes) # Check overall jaxpr properties - jaxpr = jax.make_jaxpr(func_jax)(x, y) jaxpr = jax.make_jaxpr(func_qml)(x, y) - if isinstance(argnum, int): + if int_argnum: argnum = [argnum] exp_in_avals = [shaped_array(shape) for shape in [(4,), (2, 3)]] @@ -331,6 +427,9 @@ def inner_func(x, y): diff_eqn_assertions(jac_eqn, jacobian_prim, argnum=argnum) manual_eval = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x, y) + # Evaluating jaxpr gives flat list results. 
Need to adapt the JAX output to that + if not int_argnum: + jax_out = sum(jax_out, start=()) assert _jac_allclose(manual_eval, jax_out, num_axes) jax.config.update("jax_enable_x64", initial_mode) @@ -354,7 +453,6 @@ def func(x): return jnp.prod(x) * jnp.sin(x), jnp.sum(x**2) x = jnp.array([0.7, -0.9, 0.6, 0.3]) - x = x[:1] dim = len(x) eye = jnp.eye(dim) @@ -382,15 +480,20 @@ def func(x): # 2nd order qml_func_2 = qml.jacobian(qml_func_1) + hyperdiag = qml.numpy.zeros((4, 4, 4)) + for i in range(4): + hyperdiag[i, i, i] = 1 expected_2 = ( prod_sin[:, None, None] / x[None, :, None] / x[None, None, :] + - jnp.tensordot(prod_sin, eye / x**2, axes=0) # Correct diagonal entries + prod_cos_e_i[:, :, None] / x[None, None, :] + prod_cos_e_i[:, None, :] / x[None, :, None] - - jnp.tensordot(prod_sin, eye + eye / x**2, axes=0), - jnp.tensordot(jnp.ones(dim), eye * 2, axes=0), + - prod_sin * hyperdiag, + eye * 2, ) # Output only has one tuple axis - assert _jac_allclose(qml_func_2(x), expected_2, 1) + atol = 1e-8 if x64_mode else 2e-7 + assert _jac_allclose(qml_func_2(x), expected_2, 1, atol=atol) jaxpr_2 = jax.make_jaxpr(qml_func_2)(x) assert jaxpr_2.in_avals == [jax.core.ShapedArray((dim,), fdtype)] @@ -406,7 +509,7 @@ def func(x): assert jac_eqn.params["jaxpr"].eqns[0].primitive == jacobian_prim manual_eval_2 = jax.core.eval_jaxpr(jaxpr_2.jaxpr, jaxpr_2.consts, x) - assert _jac_allclose(manual_eval_2, expected_2, 1) + assert _jac_allclose(manual_eval_2, expected_2, 1, atol=atol) jax.config.update("jax_enable_x64", initial_mode) @@ -473,3 +576,58 @@ def circuit(x): assert _jac_allclose(manual_res, expected_res, 1) jax.config.update("jax_enable_x64", initial_mode) + + @pytest.mark.parametrize("argnum", ([0, 1], [0], [1])) + def test_jacobian_pytrees(self, x64_mode, argnum): + """Test that the qml.jacobian primitive can be captured with + pytree inputs and outputs.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + fdtype = jax.numpy.float64 if x64_mode else jax.numpy.float32 + + def inner_func(x, y): + return { + "prod_cos": jnp.prod(jnp.sin(x["a"]) * jnp.cos(y[0]["b"][1]) ** 2), + "sum_sin": jnp.sum(jnp.sin(x["a"]) * jnp.sin(y[1]["c"]) ** 2), + } + + def func_qml(x): + return qml.jacobian(inner_func, argnum=argnum)( + {"a": x}, ({"b": [None, 0.4 * jnp.sqrt(x)]}, {"c": 0.5}) + ) + + def func_jax(x): + return jax.jacobian(inner_func, argnums=argnum)( + {"a": x}, ({"b": [None, 0.4 * jnp.sqrt(x)]}, {"c": 0.5}) + ) + + x = 0.7 + jax_out = func_jax(x) + jax_out_flat, jax_out_tree = jax.tree_util.tree_flatten(jax_out) + qml_out_flat, qml_out_tree = jax.tree_util.tree_flatten(func_qml(x)) + assert jax_out_tree == qml_out_tree + assert qml.math.allclose(jax_out_flat, qml_out_flat) + + # Check overall jaxpr properties + jaxpr = jax.make_jaxpr(func_qml)(x) + assert jaxpr.in_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)] + assert len(jaxpr.eqns) == 3 + + argnum = [argnum] if isinstance(argnum, int) else argnum + # Compute the flat argnum in order to determine the expected number of out tracers + flat_argnum = [0] * (0 in argnum) + [1, 2] * (1 in argnum) + assert jaxpr.out_avals == [jax.core.ShapedArray((), fdtype)] * (2 * len(flat_argnum)) + + jac_eqn = jaxpr.eqns[2] + + diff_eqn_assertions(jac_eqn, jacobian_prim, argnum=flat_argnum) + assert [var.aval for var in jac_eqn.outvars] == jaxpr.out_avals + + manual_out = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x) + manual_out_flat, manual_out_tree = jax.tree_util.tree_flatten(manual_out) + # Assert 
that the output from the manual evaluation is flat + assert manual_out_tree == jax.tree_util.tree_flatten(manual_out_flat)[1] + assert qml.math.allclose(jax_out_flat, manual_out_flat) + + jax.config.update("jax_enable_x64", initial_mode) From 79fc6d3dadb4a1dcd435f28e0d6e26aceb72ae3f Mon Sep 17 00:00:00 2001 From: ringo-but-quantum Date: Mon, 16 Sep 2024 09:51:46 +0000 Subject: [PATCH 2/6] [no ci] bump nightly version --- pennylane/_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pennylane/_version.py b/pennylane/_version.py index a6ead820881..77639685bc6 100644 --- a/pennylane/_version.py +++ b/pennylane/_version.py @@ -16,4 +16,4 @@ Version number (major.minor.patch[-label]) """ -__version__ = "0.39.0-dev14" +__version__ = "0.39.0-dev15" From 94f067a257d1e2a96812d30b576bc90da5381318 Mon Sep 17 00:00:00 2001 From: Cristian Emiliano Godinez Ramirez <57567043+EmilianoG-byte@users.noreply.github.com> Date: Mon, 16 Sep 2024 15:08:41 +0200 Subject: [PATCH 3/6] PennyLane is compatible with NumPy 2.0 (#6061) **Context:** We want to make PennyLane compatible with NumPy 2.0. After several discussions, we decided to test NumPy 2.0 on the CI by default in every PR (testing both Python versions would have been too slow). Some jobs still downgrade automatically to NumPy 1.x, since some interfaces (such as TensorFlow) still do not support NumPy 2.0. **Description of the Change:** We can divide the changes into 3 main categories: *Changes to workflows* - None in the final version *Changes to requirements and setup files* - Unpin the NumPy version in `setup.py` (now we also allow NumPy 2.0). - Update `requirements-ci.txt` to include SciPy 1.13 (this adds support for NumPy 2.0). - Pin NumPy in `requirements-ci.txt` to 2.0. *Changes to the source code* - Change `np.NaN` to `np.nan`. - Use the legacy printing representation in tests, since NumPy 2.0 represents scalars as e.g. `np.float64(3.0)` rather than just `3.0`. - Update the probabilities warning check to be case-insensitive and to allow a partial match, since the wording of this warning changed in NumPy 2.0. - Check the datatype of `np.exp` of the global phase only for NumPy 1.x, since it gets promoted to `complex128` in NumPy 2.x. https://numpy.org/neps/nep-0050-scalar-promotion.html#schema-of-the-new-proposed-promotion-rules. **Benefits:** Make PennyLane compatible with NumPy 2.0. **Possible Drawbacks:** - We need to create a separate workflow to keep testing PennyLane with NumPy 1.x, since we still want to maintain compatibility with previous NumPy versions. This will be done in a separate PR. - We are not testing NumPy 2.x for the interfaces that implicitly require NumPy 1.x. These currently seem to be `tensorflow` and `openfermionpyscf` (notice that `tensorflow` is required in some code sections like qcut). In particular, `openfermionpyscf` causes an error: ``` AttributeError: np.string_ was removed in the NumPy 2.0 release. Use np.bytes_ instead. ``` in the qchem tests. The attribute `np.string_` is not used in the PL source code, so it is a problem with the package itself.
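As a minimal sketch of the scalar-`repr` change behind the `np.printoptions(legacy="1.21")` guards added throughout the tests (illustrative snippet, not taken from the diff):

```python
import numpy as np

x = np.float64(3.0)
print(repr(x))  # NumPy 2.0 prints the dtype wrapper: np.float64(3.0)
with np.printoptions(legacy="1.21"):
    print(repr(x))  # legacy mode restores the NumPy 1.x form: 3.0
```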
[sc-61399] [sc-66548] --------- Co-authored-by: PietropaoloFrisoni Co-authored-by: Pietropaolo Frisoni --- .github/workflows/install_deps/action.yml | 6 +-- .gitignore | 2 + doc/releases/changelog-dev.md | 5 +++ pennylane/numpy/random.py | 5 ++- requirements-ci.txt | 2 +- setup.py | 2 +- .../data/attributes/operator/test_operator.py | 11 +++-- tests/data/attributes/test_dict.py | 10 +++-- tests/data/attributes/test_list.py | 12 +++-- tests/data/base/test_attribute.py | 4 +- tests/devices/qubit/test_measure.py | 8 ++-- tests/devices/qubit/test_sampling.py | 14 +++--- .../test_qutrit_mixed_sampling.py | 2 +- tests/devices/test_default_qubit_legacy.py | 44 ++++++++++--------- tests/devices/test_qubit_device.py | 6 +-- tests/measurements/test_counts.py | 6 +-- .../test_subroutines/test_prepselprep.py | 9 ++-- 17 files changed, 88 insertions(+), 60 deletions(-) diff --git a/.github/workflows/install_deps/action.yml b/.github/workflows/install_deps/action.yml index 809358ce745..99b77dc8157 100644 --- a/.github/workflows/install_deps/action.yml +++ b/.github/workflows/install_deps/action.yml @@ -15,7 +15,7 @@ inputs: jax_version: description: The version of JAX to install for any job that requires JAX required: false - default: 0.4.23 + default: '0.4.23' install_tensorflow: description: Indicate if TensorFlow should be installed or not required: false @@ -23,7 +23,7 @@ inputs: tensorflow_version: description: The version of TensorFlow to install for any job that requires TensorFlow required: false - default: 2.16.0 + default: '2.16.0' install_pytorch: description: Indicate if PyTorch should be installed or not required: false @@ -31,7 +31,7 @@ inputs: pytorch_version: description: The version of PyTorch to install for any job that requires PyTorch required: false - default: 2.3.0 + default: '2.3.0' install_pennylane_lightning_master: description: Indicate if PennyLane-Lightning should be installed from the master branch required: false diff --git a/.gitignore b/.gitignore index fc69a5281fd..d9da3038c69 100644 --- a/.gitignore +++ b/.gitignore @@ -27,3 +27,5 @@ config.toml qml_debug.log datasets/* .benchmarks/* +*.h5 +*.hdf5 diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 0a93727855d..506504f631b 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -6,6 +6,9 @@

<h3>Improvements 🛠</h3>

+* PennyLane is now compatible with NumPy 2.0. + [(#6061)](https://github.com/PennyLaneAI/pennylane/pull/6061) + * `qml.qchem.excitations` now optionally returns fermionic operators. [(#6171)](https://github.com/PennyLaneAI/pennylane/pull/6171) @@ -124,6 +127,8 @@ This release contains contributions from (in alphabetical order): Guillermo Alonso, Utkarsh Azad, Lillian M. A. Frederiksen, +Pietropaolo Frisoni, +Emiliano Godinez, Christina Lee, William Maxwell, Lee J. O'Riordan, diff --git a/pennylane/numpy/random.py b/pennylane/numpy/random.py index eae1511cf4f..12a00e798a5 100644 --- a/pennylane/numpy/random.py +++ b/pennylane/numpy/random.py @@ -16,9 +16,10 @@ it works with the PennyLane :class:`~.tensor` class. """ -from autograd.numpy import random as _random +# isort: skip_file from numpy import __version__ as np_version from numpy.random import MT19937, PCG64, SFC64, Philox # pylint: disable=unused-import +from autograd.numpy import random as _random from packaging.specifiers import SpecifierSet from packaging.version import Version @@ -26,8 +27,8 @@ wrap_arrays(_random.__dict__, globals()) - if Version(np_version) in SpecifierSet(">=0.17.0"): + # pylint: disable=too-few-public-methods # pylint: disable=missing-class-docstring class Generator(_random.Generator): diff --git a/requirements-ci.txt b/requirements-ci.txt index 083beaae25f..d552b95e904 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -1,5 +1,5 @@ numpy -scipy<1.13.0 +scipy<=1.13.0 cvxpy cvxopt networkx diff --git a/setup.py b/setup.py index e13673fb1fa..41ae9775027 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ version = f.readlines()[-1].split()[-1].strip("\"'") requirements = [ - "numpy<2.0", + "numpy<=2.0", "scipy", "networkx", "rustworkx>=0.14.0", diff --git a/tests/data/attributes/operator/test_operator.py b/tests/data/attributes/operator/test_operator.py index 83e0c658ab6..8079b720f89 100644 --- a/tests/data/attributes/operator/test_operator.py +++ b/tests/data/attributes/operator/test_operator.py @@ -174,6 +174,7 @@ def test_value_init(self, attribute_cls, op_in): """Test that a DatasetOperator can be value-initialized from an operator, and that the deserialized operator is equivalent.""" + if not qml.operation.active_new_opmath() and isinstance(op_in, qml.ops.LinearCombination): op_in = qml.operation.convert_to_legacy_H(op_in) @@ -183,7 +184,8 @@ def test_value_init(self, attribute_cls, op_in): assert dset_op.info["py_type"] == get_type_str(type(op_in)) op_out = dset_op.get_value() - assert repr(op_out) == repr(op_in) + with np.printoptions(legacy="1.21"): + assert repr(op_out) == repr(op_in) assert op_in.data == op_out.data @pytest.mark.parametrize( @@ -199,6 +201,7 @@ def test_bind_init(self, attribute_cls, op_in): """Test that a DatasetOperator can be bind-initialized from an operator, and that the deserialized operator is equivalent.""" + if not qml.operation.active_new_opmath() and isinstance(op_in, qml.ops.LinearCombination): op_in = qml.operation.convert_to_legacy_H(op_in) @@ -210,10 +213,12 @@ def test_bind_init(self, attribute_cls, op_in): assert dset_op.info["py_type"] == get_type_str(type(op_in)) op_out = dset_op.get_value() - assert repr(op_out) == repr(op_in) + with np.printoptions(legacy="1.21"): + assert repr(op_out) == repr(op_in) assert op_in.data == op_out.data assert op_in.wires == op_out.wires - assert repr(op_in) == repr(op_out) + with np.printoptions(legacy="1.21"): + assert repr(op_in) == repr(op_out) @pytest.mark.parametrize("attribute_cls", [DatasetOperator, 
DatasetPyTree]) diff --git a/tests/data/attributes/test_dict.py b/tests/data/attributes/test_dict.py index 6bf6e202fd6..3e3a2a9d4b2 100644 --- a/tests/data/attributes/test_dict.py +++ b/tests/data/attributes/test_dict.py @@ -15,6 +15,7 @@ Tests for the ``DatasetDict`` attribute type. """ +import numpy as np import pytest from pennylane.data.attributes import DatasetDict @@ -45,7 +46,8 @@ def test_value_init(self, value): assert dset_dict.info.py_type == "dict" assert dset_dict.bind.keys() == value.keys() assert len(dset_dict) == len(value) - assert repr(value) == repr(dset_dict) + with np.printoptions(legacy="1.21"): + assert repr(value) == repr(dset_dict) @pytest.mark.parametrize( "value", [{"a": 1, "b": 2}, {}, {"a": 1, "b": {"x": "y", "z": [1, 2]}}] @@ -93,7 +95,8 @@ def test_copy(self, value): assert builtin_dict.keys() == value.keys() assert len(builtin_dict) == len(value) - assert repr(builtin_dict) == repr(value) + with np.printoptions(legacy="1.21"): + assert repr(builtin_dict) == repr(value) @pytest.mark.parametrize( "value", [{"a": 1, "b": 2}, {}, {"a": 1, "b": {"x": "y", "z": [1, 2]}}] @@ -121,4 +124,5 @@ def test_equality_same_length(self): ) def test_string_conversion(self, value): dset_dict = DatasetDict(value) - assert str(dset_dict) == str(value) + with np.printoptions(legacy="1.21"): + assert str(dset_dict) == str(value) diff --git a/tests/data/attributes/test_list.py b/tests/data/attributes/test_list.py index eef27057616..2f4c937d178 100644 --- a/tests/data/attributes/test_list.py +++ b/tests/data/attributes/test_list.py @@ -18,6 +18,7 @@ from itertools import combinations +import numpy as np import pytest from pennylane.data import DatasetList @@ -56,8 +57,9 @@ def test_value_init(self, input_type, value): lst = DatasetList(input_type(value)) assert lst == value - assert repr(lst) == repr(value) assert len(lst) == len(value) + with np.printoptions(legacy="1.21"): + assert repr(lst) == repr(value) @pytest.mark.parametrize("input_type", (list, tuple)) @pytest.mark.parametrize("value", [[], [1], [1, 2, 3], ["a", "b", "c"], [{"a": 1}]]) @@ -148,12 +150,14 @@ def test_setitem_out_of_range(self, index): @pytest.mark.parametrize("value", [[], [1], [1, 2, 3], ["a", "b", "c"], [{"a": 1}]]) def test_copy(self, input_type, value): """Test that a `DatasetList` can be copied.""" + ds = DatasetList(input_type(value)) ds_copy = ds.copy() assert ds_copy == value - assert repr(ds_copy) == repr(value) assert len(ds_copy) == len(value) + with np.printoptions(legacy="1.21"): + assert repr(ds_copy) == repr(value) @pytest.mark.parametrize("input_type", (list, tuple)) @pytest.mark.parametrize("value", [[], [1], [1, 2, 3], ["a", "b", "c"], [{"a": 1}]]) @@ -169,8 +173,10 @@ def test_equality(self, input_type, value): @pytest.mark.parametrize("value", [[], [1], [1, 2, 3], ["a", "b", "c"], [{"a": 1}]]) def test_string_conversion(self, value): """Test that a `DatasetList` is converted to a string correctly.""" + dset_dict = DatasetList(value) - assert str(dset_dict) == str(value) + with np.printoptions(legacy="1.21"): + assert str(dset_dict) == str(value) @pytest.mark.parametrize("value", [[1], [1, 2, 3], ["a", "b", "c"], [{"a": 1}]]) def test_deleting_elements(self, value): diff --git a/tests/data/base/test_attribute.py b/tests/data/base/test_attribute.py index d38249c1672..da500db48e5 100644 --- a/tests/data/base/test_attribute.py +++ b/tests/data/base/test_attribute.py @@ -285,8 +285,8 @@ def test_bind_init_from_other_bind(self): ) def test_repr(self, val, attribute_type): """Test that __repr__ 
has the expected format.""" - - assert repr(attribute(val)) == f"{attribute_type.__name__}({repr(val)})" + with np.printoptions(legacy="1.21"): + assert repr(attribute(val)) == f"{attribute_type.__name__}({repr(val)})" @pytest.mark.parametrize( "val", diff --git a/tests/devices/qubit/test_measure.py b/tests/devices/qubit/test_measure.py index d0c618311cf..47e4d8c2a31 100644 --- a/tests/devices/qubit/test_measure.py +++ b/tests/devices/qubit/test_measure.py @@ -302,7 +302,7 @@ class TestNaNMeasurements: def test_nan_float_result(self, mp, interface): """Test that the result of circuits with 0 probability postselections is NaN with the expected shape.""" - state = qml.math.full((2, 2), np.NaN, like=interface) + state = qml.math.full((2, 2), np.nan, like=interface) res = measure(mp, state, is_state_batched=False) assert qml.math.ndim(res) == 0 @@ -339,7 +339,7 @@ def test_nan_float_result(self, mp, interface): def test_nan_float_result_jax(self, mp, use_jit): """Test that the result of circuits with 0 probability postselections is NaN with the expected shape.""" - state = qml.math.full((2, 2), np.NaN, like="jax") + state = qml.math.full((2, 2), np.nan, like="jax") if use_jit: import jax @@ -360,7 +360,7 @@ def test_nan_float_result_jax(self, mp, use_jit): def test_nan_probs(self, mp, interface): """Test that the result of circuits with 0 probability postselections is NaN with the expected shape.""" - state = qml.math.full((2, 2), np.NaN, like=interface) + state = qml.math.full((2, 2), np.nan, like=interface) res = measure(mp, state, is_state_batched=False) assert qml.math.shape(res) == (2 ** len(mp.wires),) @@ -375,7 +375,7 @@ def test_nan_probs(self, mp, interface): def test_nan_probs_jax(self, mp, use_jit): """Test that the result of circuits with 0 probability postselections is NaN with the expected shape.""" - state = qml.math.full((2, 2), np.NaN, like="jax") + state = qml.math.full((2, 2), np.nan, like="jax") if use_jit: import jax diff --git a/tests/devices/qubit/test_sampling.py b/tests/devices/qubit/test_sampling.py index 4174ed63aae..e36c69c26a3 100644 --- a/tests/devices/qubit/test_sampling.py +++ b/tests/devices/qubit/test_sampling.py @@ -591,7 +591,7 @@ def test_only_catch_nan_errors(self, shots): mp = qml.expval(qml.PauliZ(0)) _shots = Shots(shots) - with pytest.raises(ValueError, match="probabilities do not sum to 1"): + with pytest.raises(ValueError, match=r"(?i)probabilities do not sum to 1"): _ = measure_with_samples([mp], state, _shots) @pytest.mark.all_interfaces @@ -619,7 +619,7 @@ def test_only_catch_nan_errors(self, shots): def test_nan_float_result(self, mp, interface, shots): """Test that the result of circuits with 0 probability postselections is NaN with the expected shape.""" - state = qml.math.full((2, 2), np.NaN, like=interface) + state = qml.math.full((2, 2), np.nan, like=interface) res = measure_with_samples((mp,), state, _FlexShots(shots), is_state_batched=False) if not isinstance(shots, list): @@ -646,7 +646,7 @@ def test_nan_float_result(self, mp, interface, shots): def test_nan_samples(self, mp, interface, shots): """Test that the result of circuits with 0 probability postselections is NaN with the expected shape.""" - state = qml.math.full((2, 2), np.NaN, like=interface) + state = qml.math.full((2, 2), np.nan, like=interface) res = measure_with_samples((mp,), state, _FlexShots(shots), is_state_batched=False) if not isinstance(shots, list): @@ -672,7 +672,7 @@ def test_nan_samples(self, mp, interface, shots): def test_nan_classical_shadows(self, interface, 
shots): """Test that classical_shadows returns an empty array when the state has NaN values""" - state = qml.math.full((2, 2), np.NaN, like=interface) + state = qml.math.full((2, 2), np.nan, like=interface) res = measure_with_samples( (qml.classical_shadow([0]),), state, _FlexShots(shots), is_state_batched=False ) @@ -699,7 +699,7 @@ def test_nan_classical_shadows(self, interface, shots): def test_nan_shadow_expval(self, H, interface, shots): """Test that shadow_expval returns an empty array when the state has NaN values""" - state = qml.math.full((2, 2), np.NaN, like=interface) + state = qml.math.full((2, 2), np.nan, like=interface) res = measure_with_samples( (qml.shadow_expval(H),), state, _FlexShots(shots), is_state_batched=False ) @@ -757,7 +757,7 @@ def test_sample_state_renorm_error(self, interface): """Test that renormalization does not occur if the error is too large.""" state = qml.math.array(two_qubit_state_not_normalized, like=interface) - with pytest.raises(ValueError, match="probabilities do not sum to 1"): + with pytest.raises(ValueError, match=r"(?i)probabilities do not sum to 1"): _ = sample_state(state, 10) @pytest.mark.all_interfaces @@ -775,7 +775,7 @@ def test_sample_batched_state_renorm_error(self, interface): """Test that renormalization does not occur if the error is too large.""" state = qml.math.array(batched_state_not_normalized, like=interface) - with pytest.raises(ValueError, match="probabilities do not sum to 1"): + with pytest.raises(ValueError, match=r"(?i)probabilities do not sum to 1"): _ = sample_state(state, 10, is_state_batched=True) diff --git a/tests/devices/qutrit_mixed/test_qutrit_mixed_sampling.py b/tests/devices/qutrit_mixed/test_qutrit_mixed_sampling.py index eb3383ed5a6..ecd2fbbcca8 100644 --- a/tests/devices/qutrit_mixed/test_qutrit_mixed_sampling.py +++ b/tests/devices/qutrit_mixed/test_qutrit_mixed_sampling.py @@ -402,7 +402,7 @@ def test_only_catch_nan_errors(self, shots): mp = qml.sample(wires=range(2)) _shots = Shots(shots) - with pytest.raises(ValueError, match="probabilities do not sum to 1"): + with pytest.raises(ValueError, match=r"(?i)probabilities do not sum to 1"): _ = measure_with_samples(mp, state, _shots) @pytest.mark.parametrize("mp", [qml.probs(0), qml.probs(op=qml.GellMann(0, 1))]) diff --git a/tests/devices/test_default_qubit_legacy.py b/tests/devices/test_default_qubit_legacy.py index 11ca082441c..9b67d18b5c6 100644 --- a/tests/devices/test_default_qubit_legacy.py +++ b/tests/devices/test_default_qubit_legacy.py @@ -18,6 +18,7 @@ # pylint: disable=protected-access,cell-var-from-loop import cmath import math +from importlib.metadata import version import pytest @@ -628,7 +629,8 @@ def test_apply_global_phase(self, qubit_device_3_wires, tol, wire, input_state): expected_output = np.array(input_state) * np.exp(-1j * phase) assert np.allclose(qubit_device_3_wires._state, np.array(expected_output), atol=tol, rtol=0) - assert qubit_device_3_wires._state.dtype == qubit_device_3_wires.C_DTYPE + if version("numpy") < "2.0.0": + assert qubit_device_3_wires._state.dtype == qubit_device_3_wires.C_DTYPE def test_apply_errors_qubit_state_vector(self, qubit_device_2_wires): """Test that apply fails for incorrect state preparation, and > 2 qubit gates""" @@ -650,26 +652,26 @@ def test_apply_errors_qubit_state_vector(self, qubit_device_2_wires): ) def test_apply_errors_basis_state(self, qubit_device_2_wires): - - with pytest.raises( - ValueError, match=r"Basis state must only consist of 0s and 1s; got \[-0\.2, 4\.2\]" - ): - 
qubit_device_2_wires.apply([qml.BasisState(np.array([-0.2, 4.2]), wires=[0, 1])]) - - with pytest.raises( - ValueError, match=r"State must be of length 1; got length 2 \(state=\[0 1\]\)\." - ): - qubit_device_2_wires.apply([qml.BasisState(np.array([0, 1]), wires=[0])]) - - with pytest.raises( - qml.DeviceError, - match="Operation BasisState cannot be used after other Operations have already been applied " - "on a default.qubit.legacy device.", - ): - qubit_device_2_wires.reset() - qubit_device_2_wires.apply( - [qml.RZ(0.5, wires=[0]), qml.BasisState(np.array([1, 1]), wires=[0, 1])] - ) + with np.printoptions(legacy="1.21"): + with pytest.raises( + ValueError, match=r"Basis state must only consist of 0s and 1s; got \[-0\.2, 4\.2\]" + ): + qubit_device_2_wires.apply([qml.BasisState(np.array([-0.2, 4.2]), wires=[0, 1])]) + + with pytest.raises( + ValueError, match=r"State must be of length 1; got length 2 \(state=\[0 1\]\)\." + ): + qubit_device_2_wires.apply([qml.BasisState(np.array([0, 1]), wires=[0])]) + + with pytest.raises( + qml.DeviceError, + match="Operation BasisState cannot be used after other Operations have already been applied " + "on a default.qubit.legacy device.", + ): + qubit_device_2_wires.reset() + qubit_device_2_wires.apply( + [qml.RZ(0.5, wires=[0]), qml.BasisState(np.array([1, 1]), wires=[0, 1])] + ) class TestExpval: diff --git a/tests/devices/test_qubit_device.py b/tests/devices/test_qubit_device.py index 9edc522a408..8f50bb65329 100644 --- a/tests/devices/test_qubit_device.py +++ b/tests/devices/test_qubit_device.py @@ -1605,9 +1605,9 @@ def test_samples_to_counts_with_nan(self): # imitate hardware return with NaNs (requires dtype float) samples = qml.math.cast_like(samples, np.array([1.2])) - samples[0][0] = np.NaN - samples[17][1] = np.NaN - samples[850][0] = np.NaN + samples[0][0] = np.nan + samples[17][1] = np.nan + samples[850][0] = np.nan result = device._samples_to_counts(samples, mp=qml.measurements.CountsMP(), num_wires=2) diff --git a/tests/measurements/test_counts.py b/tests/measurements/test_counts.py index 3f3badb5c0e..08da35015c9 100644 --- a/tests/measurements/test_counts.py +++ b/tests/measurements/test_counts.py @@ -135,9 +135,9 @@ def test_counts_with_nan_samples(self): rng = np.random.default_rng(123) samples = rng.choice([0, 1], size=(shots, 2)).astype(np.float64) - samples[0][0] = np.NaN - samples[17][1] = np.NaN - samples[850][0] = np.NaN + samples[0][0] = np.nan + samples[17][1] = np.nan + samples[850][0] = np.nan result = qml.counts(wires=[0, 1]).process_samples(samples, wire_order=[0, 1]) diff --git a/tests/templates/test_subroutines/test_prepselprep.py b/tests/templates/test_subroutines/test_prepselprep.py index 95e7f771ef7..7e3b7966c28 100644 --- a/tests/templates/test_subroutines/test_prepselprep.py +++ b/tests/templates/test_subroutines/test_prepselprep.py @@ -47,13 +47,16 @@ def test_standard_checks(lcu, control): def test_repr(): """Test the repr method.""" + lcu = qml.dot([0.25, 0.75], [qml.Z(2), qml.X(1) @ qml.X(2)]) control = [0] op = qml.PrepSelPrep(lcu, control) - assert ( - repr(op) == "PrepSelPrep(coeffs=(0.25, 0.75), ops=(Z(2), X(1) @ X(2)), control=Wires([0]))" - ) + with np.printoptions(legacy="1.21"): + assert ( + repr(op) + == "PrepSelPrep(coeffs=(0.25, 0.75), ops=(Z(2), X(1) @ X(2)), control=Wires([0]))" + ) def _get_new_terms(lcu): From 228fdaf815df25ade861636d294d8e5ef51d1aa2 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Mon, 16 Sep 2024 09:57:22 -0400 Subject: [PATCH 4/6] Clean up how `interface` is handled in `QNode` 
and `qml.execute` (#6225) Regarding `numpy` and `autograd`: - When the parameters are of the `numpy` interface, internally treat it as `interface=None`. - Does not change the behaviour of treating user-specified `interface="numpy"` as using autograd. Regarding interfaces in general: - The set of canonical interface names in `INTERFACE_MAP` is expanded to include more specific names such as `jax-jit` and `tf-autograph`. `_convert_to_interfaces` in `qnode.py` uses a separate `interface_conversion_map` to further map the specific interfaces to their corresponding general interface names that can be passed to the `like` argument of `qml.math.asarray` (e.g. "tf" to "tensorflow", "jax-jit" to "jax"). - In `QNode` and `qml.execute`, every time we get an interface from user input or `qml.math.get_interface`, we map it to a canonical interface name using `INTERFACE_MAP`. Aside from these two scenarios, we assume that the interface name is one of the canonical interface names everywhere else. `QNode.interface` is now assumed to be one of the canonical interface names. - User input of `interface=None` gets mapped to `numpy` immediately. Internally, `QNode.interface` will never be `None`; it will be `numpy` when there is no interface. - If `qml.math.get_interface` returns `numpy`, we do not map it to anything. We keep `numpy`. Collateral bug fix included as well: - Fixes a bug where a circuit of the `autograd` interface sometimes returns results that are not `autograd`. - Adds `compute_sparse_matrix` to `Hermitian` [sc-73144] --------- Co-authored-by: Christina Lee --- doc/releases/changelog-dev.md | 19 +++-- pennylane/devices/execution_config.py | 6 +- pennylane/devices/legacy_facade.py | 10 +-- pennylane/devices/qubit/simulate.py | 4 +- pennylane/ops/qubit/observables.py | 4 + pennylane/workflow/__init__.py | 2 +- pennylane/workflow/execution.py | 75 ++++++++++-------- pennylane/workflow/interfaces/autograd.py | 17 ++++- pennylane/workflow/qnode.py | 40 +++++++--- .../default_qubit/test_default_qubit.py | 2 +- tests/devices/qubit/test_simulate.py | 2 +- .../finite_diff/test_spsa_gradient.py | 76 ++++++++++--------- .../test_spsa_gradient_shot_vec.py | 70 ++++++++--------- tests/interfaces/test_jax_jit.py | 2 +- tests/measurements/test_sample.py | 4 +- tests/qnn/test_keras.py | 6 +- tests/qnn/test_qnn_torch.py | 6 +- tests/test_qnode.py | 18 ++++- tests/test_qnode_legacy.py | 2 +- 19 files changed, 221 insertions(+), 144 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 506504f631b..559f1f20b6e 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -10,18 +10,14 @@ [(#6061)](https://github.com/PennyLaneAI/pennylane/pull/6061) * `qml.qchem.excitations` now optionally returns fermionic operators. - [(#6171)](https://github.com/PennyLaneAI/pennylane/pull/6171) + [(#6171)](https://github.com/PennyLaneAI/pennylane/pull/6171) * The `diagonalize_measurements` transform now uses a more efficient method of diagonalization when possible, based on the `pauli_rep` of the relevant observables. [#6113](https://github.com/PennyLaneAI/pennylane/pull/6113/) -

<h4>Capturing and representing hybrid programs</h4>

- -* Differentiation of hybrid programs via `qml.grad` can now be captured into plxpr. - When evaluating a captured `qml.grad` instruction, it will dispatch to `jax.grad`, - which differs from the Autograd implementation of `qml.grad` itself. - [(#6120)](https://github.com/PennyLaneAI/pennylane/pull/6120) +* The `Hermitian` operator now has a `compute_sparse_matrix` implementation. + [(#6225)](https://github.com/PennyLaneAI/pennylane/pull/6225)
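+
+  As an illustrative sketch (hypothetical usage, assuming this patch is applied;
+  the matrix is arbitrary):
+
+  ```python
+  import numpy as np
+  import pennylane as qml
+
+  op = qml.Hermitian(np.array([[1.0, 0.5], [0.5, -1.0]]), wires=0)
+  sparse = op.sparse_matrix()  # a scipy.sparse csr_matrix built from the dense matrix
+  ```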

<h4>Capturing and representing hybrid programs</h4>

@@ -120,12 +116,19 @@ * The ``qml.FABLE`` template now returns the correct value when JIT is enabled. [(#6263)](https://github.com/PennyLaneAI/pennylane/pull/6263) -*

<h3>Contributors ✍️</h3>

+* Fixes a bug where a circuit using the `autograd` interface sometimes returns nested values that are not of the `autograd` interface. + [(#6225)](https://github.com/PennyLaneAI/pennylane/pull/6225) + +* Fixes a bug where a simple circuit with no parameters or only builtin/numpy arrays as parameters returns autograd tensors. + [(#6225)](https://github.com/PennyLaneAI/pennylane/pull/6225) + +

<h3>Contributors ✍️</h3>

This release contains contributions from (in alphabetical order): Guillermo Alonso, Utkarsh Azad, +Astral Cai, Lillian M. A. Frederiksen, Pietropaolo Frisoni, Emiliano Godinez, diff --git a/pennylane/devices/execution_config.py b/pennylane/devices/execution_config.py index 5b7af096d81..7f3866d9e86 100644 --- a/pennylane/devices/execution_config.py +++ b/pennylane/devices/execution_config.py @@ -17,7 +17,7 @@ from dataclasses import dataclass, field from typing import Optional, Union -from pennylane.workflow import SUPPORTED_INTERFACES +from pennylane.workflow import SUPPORTED_INTERFACE_NAMES @dataclass @@ -110,9 +110,9 @@ def __post_init__(self): Note that this hook is automatically called after init via the dataclass integration. """ - if self.interface not in SUPPORTED_INTERFACES: + if self.interface not in SUPPORTED_INTERFACE_NAMES: raise ValueError( - f"Unknown interface. interface must be in {SUPPORTED_INTERFACES}, got {self.interface} instead." + f"Unknown interface. interface must be in {SUPPORTED_INTERFACE_NAMES}, got {self.interface} instead." ) if self.grad_on_execution not in {True, False, None}: diff --git a/pennylane/devices/legacy_facade.py b/pennylane/devices/legacy_facade.py index 41c1e0dea2c..bd2190f0fe1 100644 --- a/pennylane/devices/legacy_facade.py +++ b/pennylane/devices/legacy_facade.py @@ -24,6 +24,7 @@ import pennylane as qml from pennylane.measurements import MidMeasureMP, Shots from pennylane.transforms.core.transform_program import TransformProgram +from pennylane.workflow.execution import INTERFACE_MAP from .device_api import Device from .execution_config import DefaultExecutionConfig @@ -322,25 +323,24 @@ def _validate_backprop_method(self, tape): return False params = tape.get_parameters(trainable_only=False) interface = qml.math.get_interface(*params) + if interface != "numpy": + interface = INTERFACE_MAP.get(interface, interface) if tape and any(isinstance(m.obs, qml.SparseHamiltonian) for m in tape.measurements): return False - if interface == "numpy": - interface = None - mapped_interface = qml.workflow.execution.INTERFACE_MAP.get(interface, interface) # determine if the device supports backpropagation backprop_interface = self._device.capabilities().get("passthru_interface", None) if backprop_interface is not None: # device supports backpropagation natively - return mapped_interface in [backprop_interface, "Numpy"] + return interface in [backprop_interface, "numpy"] # determine if the device has any child devices that support backpropagation backprop_devices = self._device.capabilities().get("passthru_devices", None) if backprop_devices is None: return False - return mapped_interface in backprop_devices or mapped_interface == "Numpy" + return interface in backprop_devices or interface == "numpy" def _validate_adjoint_method(self, tape): # The conditions below provide a minimal set of requirements that we can likely improve upon in diff --git a/pennylane/devices/qubit/simulate.py b/pennylane/devices/qubit/simulate.py index 56e4a8f1a48..89c041b8f3e 100644 --- a/pennylane/devices/qubit/simulate.py +++ b/pennylane/devices/qubit/simulate.py @@ -922,7 +922,7 @@ def _(original_measurement: ExpectationMP, measures): # pylint: disable=unused- for v in measures.values(): if not v[0] or v[1] is tuple(): continue - cum_value += v[0] * v[1] + cum_value += qml.math.multiply(v[0], v[1]) total_counts += v[0] return cum_value / total_counts @@ -935,7 +935,7 @@ def _(original_measurement: ProbabilityMP, measures): # pylint: disable=unused- for v in measures.values(): if 
not v[0] or v[1] is tuple(): continue - cum_value += v[0] * v[1] + cum_value += qml.math.multiply(v[0], v[1]) total_counts += v[0] return cum_value / total_counts diff --git a/pennylane/ops/qubit/observables.py b/pennylane/ops/qubit/observables.py index 8f992c81bc2..4fc4a98c092 100644 --- a/pennylane/ops/qubit/observables.py +++ b/pennylane/ops/qubit/observables.py @@ -137,6 +137,10 @@ def compute_matrix(A: TensorLike) -> TensorLike: # pylint: disable=arguments-di Hermitian._validate_input(A) return A + @staticmethod + def compute_sparse_matrix(A) -> csr_matrix: # pylint: disable=arguments-differ + return csr_matrix(Hermitian.compute_matrix(A)) + @property def eigendecomposition(self) -> dict[str, TensorLike]: """Return the eigendecomposition of the matrix specified by the Hermitian observable. diff --git a/pennylane/workflow/__init__.py b/pennylane/workflow/__init__.py index 55068804b68..b41c031e8a4 100644 --- a/pennylane/workflow/__init__.py +++ b/pennylane/workflow/__init__.py @@ -56,6 +56,6 @@ """ from .construct_batch import construct_batch, get_transform_program -from .execution import INTERFACE_MAP, SUPPORTED_INTERFACES, execute +from .execution import INTERFACE_MAP, SUPPORTED_INTERFACE_NAMES, execute from .qnode import QNode, qnode from .set_shots import set_shots diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index 7445bcea2b7..8d8f0adb9ef 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -51,12 +51,9 @@ "autograd", "numpy", "torch", - "pytorch", "jax", - "jax-python", "jax-jit", "tf", - "tensorflow", } SupportedInterfaceUserInput = Literal[ @@ -78,30 +75,29 @@ ] _mapping_output = ( - "Numpy", + "numpy", "auto", "autograd", "autograd", "numpy", "jax", - "jax", + "jax-jit", "jax", "jax", "torch", "torch", "tf", "tf", - "tf", - "tf", + "tf-autograph", + "tf-autograph", ) + INTERFACE_MAP = dict(zip(get_args(SupportedInterfaceUserInput), _mapping_output)) """dict[str, str]: maps an allowed interface specification to its canonical name.""" -#: list[str]: allowed interface strings -SUPPORTED_INTERFACES = list(INTERFACE_MAP) +SUPPORTED_INTERFACE_NAMES = list(INTERFACE_MAP) """list[str]: allowed interface strings""" - _CACHED_EXECUTION_WITH_FINITE_SHOTS_WARNINGS = ( "Cached execution with finite shots detected!\n" "Note that samples as well as all noisy quantities computed via sampling " @@ -135,23 +131,21 @@ def _get_ml_boundary_execute( pennylane.QuantumFunctionError if the required package is not installed. 
""" - mapped_interface = INTERFACE_MAP[interface] try: - if mapped_interface == "autograd": + if interface == "autograd": from .interfaces.autograd import autograd_execute as ml_boundary - elif mapped_interface == "tf": - if "autograph" in interface: - from .interfaces.tensorflow_autograph import execute as ml_boundary + elif interface == "tf-autograph": + from .interfaces.tensorflow_autograph import execute as ml_boundary - ml_boundary = partial(ml_boundary, grad_on_execution=grad_on_execution) + ml_boundary = partial(ml_boundary, grad_on_execution=grad_on_execution) - else: - from .interfaces.tensorflow import tf_execute as full_ml_boundary + elif interface == "tf": + from .interfaces.tensorflow import tf_execute as full_ml_boundary - ml_boundary = partial(full_ml_boundary, differentiable=differentiable) + ml_boundary = partial(full_ml_boundary, differentiable=differentiable) - elif mapped_interface == "torch": + elif interface == "torch": from .interfaces.torch import execute as ml_boundary elif interface == "jax-jit": @@ -159,7 +153,8 @@ def _get_ml_boundary_execute( from .interfaces.jax_jit import jax_jit_vjp_execute as ml_boundary else: from .interfaces.jax_jit import jax_jit_jvp_execute as ml_boundary - else: # interface in {"jax", "jax-python", "JAX"}: + + else: # interface is jax if device_vjp: from .interfaces.jax_jit import jax_jit_vjp_execute as ml_boundary else: @@ -167,9 +162,10 @@ def _get_ml_boundary_execute( except ImportError as e: # pragma: no cover raise qml.QuantumFunctionError( - f"{mapped_interface} not found. Please install the latest " - f"version of {mapped_interface} to enable the '{mapped_interface}' interface." + f"{interface} not found. Please install the latest " + f"version of {interface} to enable the '{interface}' interface." ) from e + return ml_boundary @@ -263,12 +259,22 @@ def _get_interface_name(tapes, interface): Returns: str: Interface name""" + + if interface not in SUPPORTED_INTERFACE_NAMES: + raise qml.QuantumFunctionError( + f"Unknown interface {interface}. Interface must be one of {SUPPORTED_INTERFACE_NAMES}." + ) + + interface = INTERFACE_MAP[interface] + if interface == "auto": params = [] for tape in tapes: params.extend(tape.get_parameters(trainable_only=False)) interface = qml.math.get_interface(*params) - if INTERFACE_MAP.get(interface, "") == "tf" and _use_tensorflow_autograph(): + if interface != "numpy": + interface = INTERFACE_MAP[interface] + if interface == "tf" and _use_tensorflow_autograph(): interface = "tf-autograph" if interface == "jax": try: # pragma: no cover @@ -439,6 +445,7 @@ def cost_fn(params, x): ### Specifying and preprocessing variables #### + _interface_user_input = interface interface = _get_interface_name(tapes, interface) # Only need to calculate derivatives with jax when we know it will be executed later. if interface in {"jax", "jax-jit"}: @@ -460,7 +467,11 @@ def cost_fn(params, x): ) # Mid-circuit measurement configuration validation - mcm_interface = interface or _get_interface_name(tapes, "auto") + # If the user specifies `interface=None`, regular execution considers it numpy, but the mcm + # workflow still needs to know if jax-jit is used + mcm_interface = ( + _get_interface_name(tapes, "auto") if _interface_user_input is None else interface + ) finite_shots = any(tape.shots for tape in tapes) _update_mcm_config(config.mcm_config, mcm_interface, finite_shots) @@ -479,12 +490,12 @@ def cost_fn(params, x): cache = None # changing this set of conditions causes a bunch of tests to break. 
-    no_interface_boundary_required = interface is None or config.gradient_method in {
+    no_interface_boundary_required = interface == "numpy" or config.gradient_method in {
         None,
         "backprop",
     }
     device_supports_interface_data = no_interface_boundary_required and (
-        interface is None
+        interface == "numpy"
         or config.gradient_method == "backprop"
         or getattr(device, "short_name", "") == "default.mixed"
     )
@@ -497,9 +508,9 @@ def cost_fn(params, x):
         numpy_only=not device_supports_interface_data,
     )
 
-    # moved to its own explicit step so it will be easier to remove
+    # moved to its own explicit step so that it will be easier to remove
     def inner_execute_with_empty_jac(tapes, **_):
-        return (inner_execute(tapes), [])
+        return inner_execute(tapes), []
 
     if interface in jpc_interfaces:
         execute_fn = inner_execute
@@ -522,7 +533,7 @@ def inner_execute_with_empty_jac(tapes, **_):
         and getattr(device, "short_name", "") in ("lightning.gpu", "lightning.kokkos")
         and interface in jpc_interfaces
     ):  # pragma: no cover
-        if INTERFACE_MAP[interface] == "jax" and "use_device_state" in gradient_kwargs:
+        if "jax" in interface and "use_device_state" in gradient_kwargs:
             gradient_kwargs["use_device_state"] = False
         jpc = LightningVJPs(device, gradient_kwargs=gradient_kwargs)
 
@@ -563,7 +574,7 @@ def execute_fn(internal_tapes) -> tuple[ResultBatch, tuple]:
                 config: the ExecutionConfig that specifies how to perform the simulations.
             """
             numpy_tapes, _ = qml.transforms.convert_to_numpy_parameters(internal_tapes)
-            return (device.execute(numpy_tapes, config), tuple())
+            return device.execute(numpy_tapes, config), tuple()
 
         def gradient_fn(internal_tapes):
             """A partial function that wraps compute_derivatives method of the device.
@@ -612,7 +623,7 @@ def gradient_fn(internal_tapes):
 
     # trainable parameters can only be set on the first pass for jax
     # not higher order passes for higher order derivatives
-    if interface in {"jax", "jax-python", "jax-jit"}:
+    if "jax" in interface:
         for tape in tapes:
             params = tape.get_parameters(trainable_only=False)
             tape.trainable_params = qml.math.get_trainable_indices(params)
diff --git a/pennylane/workflow/interfaces/autograd.py b/pennylane/workflow/interfaces/autograd.py
index 9452af31854..cb5731ddc8b 100644
--- a/pennylane/workflow/interfaces/autograd.py
+++ b/pennylane/workflow/interfaces/autograd.py
@@ -147,6 +147,21 @@ def autograd_execute(
     return _execute(parameters, tuple(tapes), execute_fn, jpc)
 
 
+def _to_autograd(result: qml.typing.ResultBatch) -> qml.typing.ResultBatch:
+    """Converts an arbitrary result batch to one with autograd arrays.
+    Args:
+        result (ResultBatch): a nested structure of lists, tuples, dicts, and numpy arrays
+    Returns:
+        ResultBatch: a nested structure of tuples, dicts, and autograd arrays
+    """
+    if isinstance(result, dict):
+        return result
+    # pylint: disable=no-member
+    if isinstance(result, (list, tuple, autograd.builtins.tuple, autograd.builtins.list)):
+        return tuple(_to_autograd(r) for r in result)
+    return autograd.numpy.array(result)
+
+
 @autograd.extend.primitive
 def _execute(
     parameters,
@@ -165,7 +180,7 @@ def _execute(
         for the input tapes.
""" - return execute_fn(tapes) + return _to_autograd(execute_fn(tapes)) # pylint: disable=unused-argument diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index 408a0794674..ab68a9ad147 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -32,7 +32,7 @@ from pennylane.tape import QuantumScript, QuantumTape from pennylane.transforms.core import TransformContainer, TransformDispatcher, TransformProgram -from .execution import INTERFACE_MAP, SUPPORTED_INTERFACES, SupportedInterfaceUserInput +from .execution import INTERFACE_MAP, SUPPORTED_INTERFACE_NAMES, SupportedInterfaceUserInput logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) @@ -56,9 +56,8 @@ def _convert_to_interface(res, interface): """ Recursively convert res to the given interface. """ - interface = INTERFACE_MAP[interface] - if interface in ["Numpy"]: + if interface == "numpy": return res if isinstance(res, (list, tuple)): @@ -67,7 +66,18 @@ def _convert_to_interface(res, interface): if isinstance(res, dict): return {k: _convert_to_interface(v, interface) for k, v in res.items()} - return qml.math.asarray(res, like=interface if interface != "tf" else "tensorflow") + interface_conversion_map = { + "autograd": "autograd", + "jax": "jax", + "jax-jit": "jax", + "torch": "torch", + "tf": "tensorflow", + "tf-autograph": "tensorflow", + } + + interface_name = interface_conversion_map[interface] + + return qml.math.asarray(res, like=interface_name) def _make_execution_config( @@ -495,10 +505,10 @@ def __init__( gradient_kwargs, ) - if interface not in SUPPORTED_INTERFACES: + if interface not in SUPPORTED_INTERFACE_NAMES: raise qml.QuantumFunctionError( f"Unknown interface {interface}. Interface must be " - f"one of {SUPPORTED_INTERFACES}." + f"one of {SUPPORTED_INTERFACE_NAMES}." ) if not isinstance(device, (qml.devices.LegacyDevice, qml.devices.Device)): @@ -524,7 +534,7 @@ def __init__( # input arguments self.func = func self.device = device - self._interface = None if diff_method is None else interface + self._interface = "numpy" if diff_method is None else INTERFACE_MAP[interface] self.diff_method = diff_method mcm_config = qml.devices.MCMConfig(mcm_method=mcm_method, postselect_mode=postselect_mode) cache = (max_diff > 1) if cache == "auto" else cache @@ -617,10 +627,10 @@ def interface(self) -> str: @interface.setter def interface(self, value: SupportedInterfaceUserInput): - if value not in SUPPORTED_INTERFACES: + if value not in SUPPORTED_INTERFACE_NAMES: raise qml.QuantumFunctionError( - f"Unknown interface {value}. Interface must be one of {SUPPORTED_INTERFACES}." + f"Unknown interface {value}. Interface must be one of {SUPPORTED_INTERFACE_NAMES}." ) self._interface = INTERFACE_MAP[value] @@ -923,12 +933,18 @@ def _execution_component(self, args: tuple, kwargs: dict) -> qml.typing.Result: execute_kwargs["mcm_config"] = mcm_config + # Mapping numpy to None here because `qml.execute` will map None back into + # numpy. If we do not do this, numpy will become autograd in `qml.execute`. + # If the user specified interface="numpy", it would've already been converted to + # "autograd", and it wouldn't be affected. 
+ interface = None if self.interface == "numpy" else self.interface + # pylint: disable=unexpected-keyword-arg res = qml.execute( (self._tape,), device=self.device, gradient_fn=gradient_fn, - interface=self.interface, + interface=interface, transform_program=full_transform_program, inner_transform=inner_transform_program, config=config, @@ -961,7 +977,9 @@ def _impl_call(self, *args, **kwargs) -> qml.typing.Result: if qml.capture.enabled() else qml.math.get_interface(*args, *list(kwargs.values())) ) - self._interface = INTERFACE_MAP[interface] + if interface != "numpy": + interface = INTERFACE_MAP[interface] + self._interface = interface try: res = self._execution_component(args, kwargs) diff --git a/tests/devices/default_qubit/test_default_qubit.py b/tests/devices/default_qubit/test_default_qubit.py index 8b3a1e257dd..d3049d90eae 100644 --- a/tests/devices/default_qubit/test_default_qubit.py +++ b/tests/devices/default_qubit/test_default_qubit.py @@ -1960,7 +1960,7 @@ def test_postselection_invalid_analytic( dev = qml.device("default.qubit") @qml.defer_measurements - @qml.qnode(dev, interface=interface) + @qml.qnode(dev, interface=None if interface == "numpy" else interface) def circ(): qml.RX(np.pi, 0) qml.CNOT([0, 1]) diff --git a/tests/devices/qubit/test_simulate.py b/tests/devices/qubit/test_simulate.py index dbe9573b8df..4dce5afd4c5 100644 --- a/tests/devices/qubit/test_simulate.py +++ b/tests/devices/qubit/test_simulate.py @@ -205,7 +205,7 @@ def test_result_has_correct_interface(self, op): def test_expand_state_keeps_autograd_interface(self): """Test that expand_state doesn't convert autograd to numpy.""" - @qml.qnode(qml.device("default.qubit", wires=2)) + @qml.qnode(qml.device("default.qubit", wires=2), interface="autograd") def circuit(x): qml.RX(x, 0) return qml.probs(wires=[0, 1]) diff --git a/tests/gradients/finite_diff/test_spsa_gradient.py b/tests/gradients/finite_diff/test_spsa_gradient.py index d8f19dcf826..2730cd53d00 100644 --- a/tests/gradients/finite_diff/test_spsa_gradient.py +++ b/tests/gradients/finite_diff/test_spsa_gradient.py @@ -14,11 +14,11 @@ """ Tests for the gradients.spsa_gradient module. 
""" -import numpy +import numpy as np import pytest import pennylane as qml -from pennylane import numpy as np +from pennylane import numpy as pnp from pennylane.devices import DefaultQubitLegacy from pennylane.gradients import spsa_grad from pennylane.gradients.spsa_gradient import _rademacher_sampler @@ -168,7 +168,7 @@ def circuit(param): expected_message = "The argument sampler_rng is expected to be a NumPy PRNG" with pytest.raises(ValueError, match=expected_message): - qml.grad(circuit)(np.array(1.0)) + qml.grad(circuit)(pnp.array(1.0)) def test_trainable_batched_tape_raises(self): """Test that an error is raised for a broadcasted/batched tape if the broadcasted @@ -202,7 +202,7 @@ def test_nontrainable_batched_tape(self): def test_non_differentiable_error(self): """Test error raised if attempting to differentiate with respect to a non-differentiable argument""" - psi = np.array([1, 0, 1, 0], requires_grad=False) / np.sqrt(2) + psi = pnp.array([1, 0, 1, 0], requires_grad=False) / np.sqrt(2) with qml.queuing.AnnotatedQueue() as q: qml.StatePrep(psi, wires=[0, 1]) @@ -227,10 +227,10 @@ def test_non_differentiable_error(self): assert isinstance(res, tuple) assert len(res) == 2 - assert isinstance(res[0], numpy.ndarray) + assert isinstance(res[0], np.ndarray) assert res[0].shape == (4,) - assert isinstance(res[1], numpy.ndarray) + assert isinstance(res[1], np.ndarray) assert res[1].shape == (4,) @pytest.mark.parametrize("num_directions", [1, 10]) @@ -252,8 +252,8 @@ def test_independent_parameter(self, num_directions, mocker): assert isinstance(res, tuple) assert len(res) == 2 - assert isinstance(res[0], numpy.ndarray) - assert isinstance(res[1], numpy.ndarray) + assert isinstance(res[0], np.ndarray) + assert isinstance(res[1], np.ndarray) # 2 tapes per direction because the default strategy for SPSA is "center" assert len(spy.call_args_list) == num_directions @@ -282,7 +282,7 @@ def test_no_trainable_params_tape(self): res = post_processing(qml.execute(g_tapes, dev, None)) assert g_tapes == [] - assert isinstance(res, numpy.ndarray) + assert isinstance(res, np.ndarray) assert res.shape == (0,) def test_no_trainable_params_multiple_return_tape(self): @@ -383,7 +383,7 @@ def circuit(params): qml.Rot(*params, wires=0) return qml.probs([2, 3]) - params = np.array([0.5, 0.5, 0.5], requires_grad=True) + params = pnp.array([0.5, 0.5, 0.5], requires_grad=True) result = spsa_grad(circuit)(params) @@ -402,7 +402,7 @@ def circuit(params): qml.Rot(*params, wires=0) return qml.expval(qml.PauliZ(wires=2)), qml.probs([2, 3]) - params = np.array([0.5, 0.5, 0.5], requires_grad=True) + params = pnp.array([0.5, 0.5, 0.5], requires_grad=True) result = spsa_grad(circuit)(params) @@ -514,7 +514,7 @@ def cost6(x): qml.Rot(*x, wires=0) return qml.probs([0, 1]), qml.probs([2, 3]) - x = np.random.rand(3) + x = pnp.random.rand(3) circuits = [qml.QNode(cost, dev) for cost in (cost1, cost2, cost3, cost4, cost5, cost6)] transform = [qml.math.shape(spsa_grad(c)(x)) for c in circuits] @@ -576,7 +576,7 @@ class DeviceSupportingSpecialObservable(DefaultQubitLegacy): @staticmethod def _asarray(arr, dtype=None): - return np.array(arr, dtype=dtype) + return pnp.array(arr, dtype=dtype) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -603,9 +603,11 @@ def reference_qnode(x): qml.RY(x, wires=0) return qml.expval(qml.PauliZ(wires=0)) - par = np.array(0.2, requires_grad=True) - assert np.isclose(qnode(par).item().val, reference_qnode(par)) - assert np.isclose(qml.jacobian(qnode)(par).item().val, 
qml.jacobian(reference_qnode)(par)) + par = pnp.array(0.2, requires_grad=True) + assert np.isclose(qnode(par).item().val, reference_qnode(par).item()) + assert np.isclose( + qml.jacobian(qnode)(par).item().val, qml.jacobian(reference_qnode)(par).item() + ) @pytest.mark.parametrize("approx_order", [2, 4]) @@ -684,10 +686,10 @@ def test_single_expectation_value(self, approx_order, strategy, validate, tol): # 1 / num_params here. res = tuple(qml.math.convert_like(r * 2, r) for r in res) - assert isinstance(res[0], numpy.ndarray) + assert isinstance(res[0], np.ndarray) assert res[0].shape == () - assert isinstance(res[1], numpy.ndarray) + assert isinstance(res[1], np.ndarray) assert res[1].shape == () expected = np.array([[-np.sin(y) * np.sin(x), np.cos(y) * np.cos(x)]]) @@ -728,10 +730,10 @@ def test_single_expectation_value_with_argnum_all(self, approx_order, strategy, # 1 / num_params here. res = tuple(qml.math.convert_like(r * 2, r) for r in res) - assert isinstance(res[0], numpy.ndarray) + assert isinstance(res[0], np.ndarray) assert res[0].shape == () - assert isinstance(res[1], numpy.ndarray) + assert isinstance(res[1], np.ndarray) assert res[1].shape == () expected = np.array([[-np.sin(y) * np.sin(x), np.cos(y) * np.cos(x)]]) @@ -772,10 +774,10 @@ def test_single_expectation_value_with_argnum_one(self, approx_order, strategy, assert isinstance(res, tuple) assert len(res) == 2 - assert isinstance(res[0], numpy.ndarray) + assert isinstance(res[0], np.ndarray) assert res[0].shape == () - assert isinstance(res[1], numpy.ndarray) + assert isinstance(res[1], np.ndarray) assert res[1].shape == () expected = [0, np.cos(y) * np.cos(x)] @@ -856,14 +858,14 @@ def test_multiple_expectation_values(self, approx_order, strategy, validate, tol assert isinstance(res[0], tuple) assert len(res[0]) == 2 assert np.allclose(res[0], [-np.sin(x), 0], atol=tol, rtol=0) - assert isinstance(res[0][0], numpy.ndarray) - assert isinstance(res[0][1], numpy.ndarray) + assert isinstance(res[0][0], np.ndarray) + assert isinstance(res[0][1], np.ndarray) assert isinstance(res[1], tuple) assert len(res[1]) == 2 assert np.allclose(res[1], [0, np.cos(y)], atol=tol, rtol=0) - assert isinstance(res[1][0], numpy.ndarray) - assert isinstance(res[1][1], numpy.ndarray) + assert isinstance(res[1][0], np.ndarray) + assert isinstance(res[1][1], np.ndarray) def test_var_expectation_values(self, approx_order, strategy, validate, tol): """Tests correct output shape and evaluation for a tape @@ -901,14 +903,14 @@ def test_var_expectation_values(self, approx_order, strategy, validate, tol): assert isinstance(res[0], tuple) assert len(res[0]) == 2 assert np.allclose(res[0], [-np.sin(x), 0], atol=tol, rtol=0) - assert isinstance(res[0][0], numpy.ndarray) - assert isinstance(res[0][1], numpy.ndarray) + assert isinstance(res[0][0], np.ndarray) + assert isinstance(res[0][1], np.ndarray) assert isinstance(res[1], tuple) assert len(res[1]) == 2 assert np.allclose(res[1], [0, -2 * np.cos(y) * np.sin(y)], atol=tol, rtol=0) - assert isinstance(res[1][0], numpy.ndarray) - assert isinstance(res[1][1], numpy.ndarray) + assert isinstance(res[1][0], np.ndarray) + assert isinstance(res[1][1], np.ndarray) def test_prob_expectation_values(self, approx_order, strategy, validate, tol): """Tests correct output shape and evaluation for a tape @@ -946,9 +948,9 @@ def test_prob_expectation_values(self, approx_order, strategy, validate, tol): assert isinstance(res[0], tuple) assert len(res[0]) == 2 assert np.allclose(res[0][0], -np.sin(x), atol=tol, rtol=0) - 
assert isinstance(res[0][0], numpy.ndarray) + assert isinstance(res[0][0], np.ndarray) assert np.allclose(res[0][1], 0, atol=tol, rtol=0) - assert isinstance(res[0][1], numpy.ndarray) + assert isinstance(res[0][1], np.ndarray) assert isinstance(res[1], tuple) assert len(res[1]) == 2 @@ -963,7 +965,7 @@ def test_prob_expectation_values(self, approx_order, strategy, validate, tol): atol=tol, rtol=0, ) - assert isinstance(res[1][0], numpy.ndarray) + assert isinstance(res[1][0], np.ndarray) assert np.allclose( res[1][1], [ @@ -975,7 +977,7 @@ def test_prob_expectation_values(self, approx_order, strategy, validate, tol): atol=tol, rtol=0, ) - assert isinstance(res[1][1], numpy.ndarray) + assert isinstance(res[1][1], np.ndarray) @pytest.mark.parametrize( @@ -989,7 +991,7 @@ def test_autograd(self, sampler, num_directions, atol): """Tests that the output of the SPSA gradient transform can be differentiated using autograd, yielding second derivatives.""" dev = qml.device("default.qubit", wires=2) - params = np.array([0.543, -0.654], requires_grad=True) + params = pnp.array([0.543, -0.654], requires_grad=True) rng = np.random.default_rng(42) def cost_fn(x): @@ -1004,7 +1006,7 @@ def cost_fn(x): tapes, fn = spsa_grad( tape, n=1, num_directions=num_directions, sampler=sampler, sampler_rng=rng ) - jac = np.array(fn(dev.execute(tapes))) + jac = pnp.array(fn(dev.execute(tapes))) if sampler is coordinate_sampler: jac *= 2 return jac @@ -1025,7 +1027,7 @@ def test_autograd_ragged(self, sampler, num_directions, atol): """Tests that the output of the SPSA gradient transform of a ragged tape can be differentiated using autograd, yielding second derivatives.""" dev = qml.device("default.qubit", wires=2) - params = np.array([0.543, -0.654], requires_grad=True) + params = pnp.array([0.543, -0.654], requires_grad=True) rng = np.random.default_rng(42) def cost_fn(x): diff --git a/tests/gradients/finite_diff/test_spsa_gradient_shot_vec.py b/tests/gradients/finite_diff/test_spsa_gradient_shot_vec.py index 46f8aa1288e..2c771dc2832 100644 --- a/tests/gradients/finite_diff/test_spsa_gradient_shot_vec.py +++ b/tests/gradients/finite_diff/test_spsa_gradient_shot_vec.py @@ -14,11 +14,11 @@ """ Tests for the gradients.spsa_gradient module using shot vectors. 
""" -import numpy +import numpy as np import pytest import pennylane as qml -from pennylane import numpy as np +from pennylane import numpy as pnp from pennylane.devices import DefaultQubitLegacy from pennylane.gradients import spsa_grad from pennylane.measurements import Shots @@ -49,7 +49,7 @@ class TestSpsaGradient: def test_non_differentiable_error(self): """Test error raised if attempting to differentiate with respect to a non-differentiable argument""" - psi = np.array([1, 0, 1, 0], requires_grad=False) / np.sqrt(2) + psi = pnp.array([1, 0, 1, 0], requires_grad=False) / np.sqrt(2) with qml.queuing.AnnotatedQueue() as q: qml.StatePrep(psi, wires=[0, 1]) @@ -78,10 +78,10 @@ def test_non_differentiable_error(self): for res in all_res: assert isinstance(res, tuple) - assert isinstance(res[0], numpy.ndarray) + assert isinstance(res[0], np.ndarray) assert res[0].shape == (4,) - assert isinstance(res[1], numpy.ndarray) + assert isinstance(res[1], np.ndarray) assert res[1].shape == (4,) @pytest.mark.parametrize("num_directions", [1, 6]) @@ -107,8 +107,8 @@ def test_independent_parameter(self, num_directions, mocker): assert isinstance(res, tuple) assert len(res) == 2 - assert isinstance(res[0], numpy.ndarray) - assert isinstance(res[1], numpy.ndarray) + assert isinstance(res[0], np.ndarray) + assert isinstance(res[1], np.ndarray) # 2 tapes per direction because the default strategy for SPSA is "center" assert len(spy.call_args_list) == num_directions @@ -139,7 +139,7 @@ def test_no_trainable_params_tape(self): for res in all_res: assert g_tapes == [] - assert isinstance(res, numpy.ndarray) + assert isinstance(res, np.ndarray) assert res.shape == (0,) def test_no_trainable_params_multiple_return_tape(self): @@ -244,7 +244,7 @@ def circuit(params): qml.Rot(*params, wires=0) return qml.probs([2, 3]) - params = np.array([0.5, 0.5, 0.5], requires_grad=True) + params = pnp.array([0.5, 0.5, 0.5], requires_grad=True) grad_fn = spsa_grad(circuit, h=h_val, sampler_rng=rng) all_result = grad_fn(params) @@ -269,7 +269,7 @@ def circuit(params): qml.Rot(*params, wires=0) return qml.expval(qml.PauliZ(wires=2)), qml.probs([2, 3]) - params = np.array([0.5, 0.5, 0.5], requires_grad=True) + params = pnp.array([0.5, 0.5, 0.5], requires_grad=True) grad_fn = spsa_grad(circuit, h=h_val, sampler_rng=rng) all_result = grad_fn(params) @@ -416,7 +416,7 @@ def cost6(x): qml.Rot(*x, wires=0) return qml.probs([0, 1]), qml.probs([2, 3]) - x = np.random.rand(3) + x = pnp.random.rand(3) circuits = [qml.QNode(cost, dev) for cost in (cost1, cost2, cost3, cost4, cost5, cost6)] transform = [qml.math.shape(spsa_grad(c, h=h_val)(x)) for c in circuits] @@ -498,9 +498,11 @@ def reference_qnode(x): qml.RY(x, wires=0) return qml.expval(qml.PauliZ(wires=0)) - par = np.array(0.2, requires_grad=True) - assert np.isclose(qnode(par).item().val, reference_qnode(par)) - assert np.isclose(qml.jacobian(qnode)(par).item().val, qml.jacobian(reference_qnode)(par)) + par = pnp.array(0.2, requires_grad=True) + assert np.isclose(qnode(par).item().val, reference_qnode(par).item()) + assert np.isclose( + qml.jacobian(qnode)(par).item().val, qml.jacobian(reference_qnode)(par).item() + ) @pytest.mark.parametrize("approx_order", [2, 4]) @@ -586,10 +588,10 @@ def test_single_expectation_value(self, approx_order, strategy, validate): assert isinstance(res, tuple) assert len(res) == 2 - assert isinstance(res[0], numpy.ndarray) + assert isinstance(res[0], np.ndarray) assert res[0].shape == () - assert isinstance(res[1], numpy.ndarray) + assert 
isinstance(res[1], np.ndarray) assert res[1].shape == () # The coordinate_sampler produces the right evaluation points, but the tape execution @@ -635,10 +637,10 @@ def test_single_expectation_value_with_argnum_all(self, approx_order, strategy, assert isinstance(res, tuple) assert len(res) == 2 - assert isinstance(res[0], numpy.ndarray) + assert isinstance(res[0], np.ndarray) assert res[0].shape == () - assert isinstance(res[1], numpy.ndarray) + assert isinstance(res[1], np.ndarray) assert res[1].shape == () # The coordinate_sampler produces the right evaluation points, but the tape execution @@ -689,10 +691,10 @@ def test_single_expectation_value_with_argnum_one(self, approx_order, strategy, assert isinstance(res, tuple) assert len(res) == 2 - assert isinstance(res[0], numpy.ndarray) + assert isinstance(res[0], np.ndarray) assert res[0].shape == () - assert isinstance(res[1], numpy.ndarray) + assert isinstance(res[1], np.ndarray) assert res[1].shape == () # The coordinate_sampler produces the right evaluation points and there is just one @@ -783,13 +785,13 @@ def test_multiple_expectation_values(self, approx_order, strategy, validate): assert isinstance(res[0], tuple) assert len(res[0]) == 2 - assert isinstance(res[0][0], numpy.ndarray) - assert isinstance(res[0][1], numpy.ndarray) + assert isinstance(res[0][0], np.ndarray) + assert isinstance(res[0][1], np.ndarray) assert isinstance(res[1], tuple) assert len(res[1]) == 2 - assert isinstance(res[1][0], numpy.ndarray) - assert isinstance(res[1][1], numpy.ndarray) + assert isinstance(res[1][0], np.ndarray) + assert isinstance(res[1][1], np.ndarray) # The coordinate_sampler produces the right evaluation points, but the tape execution # results are averaged instead of added, so that we need to revert the prefactor @@ -837,13 +839,13 @@ def test_var_expectation_values(self, approx_order, strategy, validate): assert isinstance(res[0], tuple) assert len(res[0]) == 2 - assert isinstance(res[0][0], numpy.ndarray) - assert isinstance(res[0][1], numpy.ndarray) + assert isinstance(res[0][0], np.ndarray) + assert isinstance(res[0][1], np.ndarray) assert isinstance(res[1], tuple) assert len(res[1]) == 2 - assert isinstance(res[1][0], numpy.ndarray) - assert isinstance(res[1][1], numpy.ndarray) + assert isinstance(res[1][0], np.ndarray) + assert isinstance(res[1][1], np.ndarray) # The coordinate_sampler produces the right evaluation points, but the tape execution # results are averaged instead of added, so that we need to revert the prefactor @@ -892,13 +894,13 @@ def test_prob_expectation_values(self, approx_order, strategy, validate): assert isinstance(res[0], tuple) assert len(res[0]) == 2 - assert isinstance(res[0][0], numpy.ndarray) - assert isinstance(res[0][1], numpy.ndarray) + assert isinstance(res[0][0], np.ndarray) + assert isinstance(res[0][1], np.ndarray) assert isinstance(res[1], tuple) assert len(res[1]) == 2 - assert isinstance(res[1][0], numpy.ndarray) - assert isinstance(res[1][1], numpy.ndarray) + assert isinstance(res[1][0], np.ndarray) + assert isinstance(res[1][1], np.ndarray) # The coordinate_sampler produces the right evaluation points, but the tape execution # results are averaged instead of added, so that we need to revert the prefactor @@ -943,7 +945,7 @@ def test_autograd(self, approx_order, strategy): """Tests that the output of the SPSA gradient transform can be differentiated using autograd, yielding second derivatives.""" dev = qml.device("default.qubit", wires=2, shots=many_shots_shot_vector) - params = np.array([0.543, 
-0.654], requires_grad=True) + params = pnp.array([0.543, -0.654], requires_grad=True) rng = np.random.default_rng(42) def cost_fn(x): @@ -986,7 +988,7 @@ def test_autograd_ragged(self, approx_order, strategy): """Tests that the output of the SPSA gradient transform of a ragged tape can be differentiated using autograd, yielding second derivatives.""" dev = qml.device("default.qubit", wires=2, shots=many_shots_shot_vector) - params = np.array([0.543, -0.654], requires_grad=True) + params = pnp.array([0.543, -0.654], requires_grad=True) rng = np.random.default_rng(42) def cost_fn(x): diff --git a/tests/interfaces/test_jax_jit.py b/tests/interfaces/test_jax_jit.py index a9927dad7fb..eea7b6be52a 100644 --- a/tests/interfaces/test_jax_jit.py +++ b/tests/interfaces/test_jax_jit.py @@ -107,7 +107,7 @@ def cost(a, device): interface="None", )[0] - with pytest.raises(ValueError, match="Unknown interface"): + with pytest.raises(qml.QuantumFunctionError, match="Unknown interface"): cost(a, device=dev) def test_grad_on_execution(self, mocker): diff --git a/tests/measurements/test_sample.py b/tests/measurements/test_sample.py index e0d4ec25724..d31ce97d4a5 100644 --- a/tests/measurements/test_sample.py +++ b/tests/measurements/test_sample.py @@ -121,8 +121,8 @@ def circuit(): # If all the dimensions are equal the result will end up to be a proper rectangular array assert len(result) == 3 - assert isinstance(result[0], np.ndarray) - assert isinstance(result[1], np.ndarray) + assert isinstance(result[0], float) + assert isinstance(result[1], float) assert result[2].dtype == np.dtype("float") assert np.array_equal(result[2].shape, (n_sample,)) diff --git a/tests/qnn/test_keras.py b/tests/qnn/test_keras.py index f4f9769edc2..1115460922d 100644 --- a/tests/qnn/test_keras.py +++ b/tests/qnn/test_keras.py @@ -588,7 +588,11 @@ def circuit(inputs, w1): return qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliZ(1)) qlayer = KerasLayer(circuit, weight_shapes, output_dim=2) - assert qlayer.qnode.interface == circuit.interface == interface + assert ( + qlayer.qnode.interface + == circuit.interface + == qml.workflow.execution.INTERFACE_MAP[interface] + ) @pytest.mark.tf diff --git a/tests/qnn/test_qnn_torch.py b/tests/qnn/test_qnn_torch.py index 64aeb9b1a9c..e2642df0e4b 100644 --- a/tests/qnn/test_qnn_torch.py +++ b/tests/qnn/test_qnn_torch.py @@ -632,7 +632,11 @@ def circuit(inputs, w1): return qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliZ(1)) qlayer = TorchLayer(circuit, weight_shapes) - assert qlayer.qnode.interface == circuit.interface == interface + assert ( + qlayer.qnode.interface + == circuit.interface + == qml.workflow.execution.INTERFACE_MAP[interface] + ) @pytest.mark.torch diff --git a/tests/test_qnode.py b/tests/test_qnode.py index 38b8847106d..1322ca62c16 100644 --- a/tests/test_qnode.py +++ b/tests/test_qnode.py @@ -434,7 +434,7 @@ def circuit(x): qml.RX(x, wires=0) return qml.expval(qml.PauliZ(0)) - assert circuit.interface is None + assert circuit.interface == "numpy" with pytest.warns( qml.PennyLaneDeprecationWarning, match=r"QNode.gradient_fn is deprecated" ): @@ -1139,6 +1139,20 @@ def circuit(): assert q.queue == [] # pylint: disable=use-implicit-booleaness-not-comparison assert len(circuit.tape.operations) == 1 + def test_qnode_preserves_inferred_numpy_interface(self): + """Tests that the QNode respects the inferred numpy interface.""" + + dev = qml.device("default.qubit", wires=1) + + @qml.qnode(dev) + def circuit(x): + qml.RX(x, wires=0) + return qml.expval(qml.PauliZ(0)) + + x = 
np.array(0.8)
+        res = circuit(x)
+        assert qml.math.get_interface(res) == "numpy"
+
 
 class TestShots:
     """Unit tests for specifying shots per call."""
 
@@ -1899,7 +1913,7 @@ def circuit(x):
         else:
             spy = mocker.spy(circuit.device, "execute")
 
-        x = np.array(0.5)
+        x = pnp.array(0.5)
         circuit(x)
 
         tape = spy.call_args[0][0][0]
diff --git a/tests/test_qnode_legacy.py b/tests/test_qnode_legacy.py
index 3ee36d99bdb..73eaf29b302 100644
--- a/tests/test_qnode_legacy.py
+++ b/tests/test_qnode_legacy.py
@@ -1488,7 +1488,7 @@ def circuit(x):
         else:
             spy = mocker.spy(circuit.device, "execute")
 
-        x = np.array(0.5)
+        x = pnp.array(0.5)
         circuit(x)
 
         tape = spy.call_args[0][0][0]

From b78565cb119e926388375898093aa2387a4e6480 Mon Sep 17 00:00:00 2001
From: Pietropaolo Frisoni
Date: Mon, 16 Sep 2024 10:47:02 -0400
Subject: [PATCH 5/6] Updating torch to 2.3.0 (#6258)

**Context:** We want to make PL compatible with torch 2.3.0 (including GPU tests) after updating NumPy from 1.x to 2.x.

**Description of the Change:** Bumps the pinned `TORCH_VERSION` in the GPU test workflow from 2.2.0 to 2.3.0 and records the change in the changelog.

**Benefits:** Confirms that PL is compatible with torch 2.3.0.

**Possible Drawbacks:** None identified.

**Related GitHub Issues:** None.

**Related Shortcut Stories:** [sc-61391]
---
 .github/workflows/tests-gpu.yml | 2 +-
 doc/releases/changelog-dev.md   | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/tests-gpu.yml b/.github/workflows/tests-gpu.yml
index 846261ff898..bd8cd5dae67 100644
--- a/.github/workflows/tests-gpu.yml
+++ b/.github/workflows/tests-gpu.yml
@@ -15,7 +15,7 @@ concurrency:
   cancel-in-progress: true
 
 env:
-  TORCH_VERSION: 2.2.0
+  TORCH_VERSION: 2.3.0
 
 jobs:
   gpu-tests:
diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 559f1f20b6e..045f0b4528c 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -7,7 +7,8 @@

<h3>Improvements 🛠</h3>

* PennyLane is now compatible with NumPy 2.0. - [(#6061)](https://github.com/PennyLaneAI/pennylane/pull/6061) + [(#6061)](https://github.com/PennyLaneAI/pennylane/pull/6061) + [(#6258)](https://github.com/PennyLaneAI/pennylane/pull/6258) * `qml.qchem.excitations` now optionally returns fermionic operators. [(#6171)](https://github.com/PennyLaneAI/pennylane/pull/6171) From 7a4a44bdda8105e33074c41b8284b81cbd2005ed Mon Sep 17 00:00:00 2001 From: ringo-but-quantum Date: Tue, 17 Sep 2024 09:51:42 +0000 Subject: [PATCH 6/6] [no ci] bump nightly version --- pennylane/_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pennylane/_version.py b/pennylane/_version.py index 77639685bc6..0c39c922ce2 100644 --- a/pennylane/_version.py +++ b/pennylane/_version.py @@ -16,4 +16,4 @@ Version number (major.minor.patch[-label]) """ -__version__ = "0.39.0-dev15" +__version__ = "0.39.0-dev16"
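As a closing illustration of the interface canonicalization in the workflow patches above, the sketch below (hypothetical user code; `default.qubit` and a dev build containing these patches are assumed) shows the behavior that the updated `tests/test_qnode.py` asserts:

```python
import pennylane as qml

dev = qml.device("default.qubit", wires=1)

# Without a differentiation method, the QNode reports the canonical
# "numpy" interface (previously None) and returns plain NumPy data.
@qml.qnode(dev, diff_method=None)
def circuit(x):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))

print(circuit.interface)                     # "numpy"
print(qml.math.get_interface(circuit(0.8)))  # "numpy"
```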