From 7a4a44bdda8105e33074c41b8284b81cbd2005ed Mon Sep 17 00:00:00 2001 From: ringo-but-quantum Date: Tue, 17 Sep 2024 09:51:42 +0000 Subject: [PATCH 1/4] [no ci] bump nightly version --- pennylane/_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pennylane/_version.py b/pennylane/_version.py index 77639685bc6..0c39c922ce2 100644 --- a/pennylane/_version.py +++ b/pennylane/_version.py @@ -16,4 +16,4 @@ Version number (major.minor.patch[-label]) """ -__version__ = "0.39.0-dev15" +__version__ = "0.39.0-dev16" From 805d4cf0abe85a5b79445411f6cb8b29195ba14f Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Tue, 17 Sep 2024 09:58:23 -0400 Subject: [PATCH 2/4] Add `reference.qubit` for testing and reference (#6181) Created from https://github.com/PennyLaneAI/pennylane/pull/5445 [sc-65558] --------- Co-authored-by: dwierichs Co-authored-by: Christina Lee Co-authored-by: lillian542 <38584660+lillian542@users.noreply.github.com> --- doc/releases/changelog-dev.md | 3 + pennylane/devices/__init__.py | 3 + pennylane/devices/reference_qubit.py | 154 ++++++++++++++++++++++ setup.py | 1 + tests/interfaces/test_autograd.py | 110 +++++++++------- tests/interfaces/test_autograd_qnode.py | 6 + tests/interfaces/test_jax.py | 20 ++- tests/interfaces/test_jax_jit_qnode.py | 13 ++ tests/interfaces/test_jax_qnode.py | 19 ++- tests/interfaces/test_tensorflow.py | 10 ++ tests/interfaces/test_tensorflow_qnode.py | 6 + tests/interfaces/test_torch.py | 15 +++ tests/interfaces/test_torch_qnode.py | 6 + 13 files changed, 311 insertions(+), 55 deletions(-) create mode 100644 pennylane/devices/reference_qubit.py diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 045f0b4528c..91194ac1e19 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -51,6 +51,9 @@ unique representation of the object. [(#6167)](https://github.com/PennyLaneAI/pennylane/pull/6167) +* A `ReferenceQubit` is introduced for testing purposes and as a reference for future plugin development. + [(#6181)](https://github.com/PennyLaneAI/pennylane/pull/6181) + * The `to_mat` methods for `FermiWord` and `FermiSentence` now optionally return a sparse matrix. [(#6173)](https://github.com/PennyLaneAI/pennylane/pull/6173) diff --git a/pennylane/devices/__init__.py b/pennylane/devices/__init__.py index a542ba7df1d..ac9581ede40 100644 --- a/pennylane/devices/__init__.py +++ b/pennylane/devices/__init__.py @@ -37,6 +37,7 @@ _qubit_device _qutrit_device null_qubit + reference_qubit tests Next generation devices @@ -58,6 +59,7 @@ DefaultQubit DefaultTensor NullQubit + ReferenceQubit DefaultQutritMixed LegacyDeviceFacade @@ -160,6 +162,7 @@ def execute(self, circuits, execution_config = qml.devices.DefaultExecutionConfi from .default_clifford import DefaultClifford from .default_tensor import DefaultTensor from .null_qubit import NullQubit +from .reference_qubit import ReferenceQubit from .default_qutrit import DefaultQutrit from .default_qutrit_mixed import DefaultQutritMixed from ._legacy_device import Device as LegacyDevice diff --git a/pennylane/devices/reference_qubit.py b/pennylane/devices/reference_qubit.py new file mode 100644 index 00000000000..49537d71a6e --- /dev/null +++ b/pennylane/devices/reference_qubit.py @@ -0,0 +1,154 @@ +# Copyright 2018-2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Contains the ReferenceQubit device, a minimal device that can be used for testing +and plugin development purposes. +""" + +import numpy as np + +import pennylane as qml + +from .device_api import Device +from .execution_config import DefaultExecutionConfig +from .modifiers import simulator_tracking, single_tape_support +from .preprocess import decompose, validate_device_wires, validate_measurements + + +def sample_state(state: np.ndarray, shots: int, seed=None): + """Generate samples from the provided state and number of shots.""" + + probs = np.imag(state) ** 2 + np.real(state) ** 2 + basis_states = np.arange(len(probs)) + + num_wires = int(np.log2(len(probs))) + + rng = np.random.default_rng(seed) + basis_samples = rng.choice(basis_states, shots, p=probs) + + # convert basis state integers to array of booleans + bin_strings = (format(s, f"0{num_wires}b") for s in basis_samples) + return np.array([[int(val) for val in s] for s in bin_strings]) + + +def simulate(tape: qml.tape.QuantumTape, seed=None) -> qml.typing.Result: + """Simulate a tape and turn it into results. + + Args: + tape (.QuantumTape): a representation of a circuit + seed (Any): A seed to use to control the generation of samples. + + """ + # 1) create the initial state + state = np.zeros(2 ** len(tape.wires)) + state[0] = 1.0 + + # 2) apply all the operations + for op in tape.operations: + op_mat = op.matrix(wire_order=tape.wires) + state = qml.math.matmul(op_mat, state) + + # 3) perform measurements + # note that shots are pulled from the tape, not from the device + if tape.shots: + samples = sample_state(state, shots=tape.shots.total_shots, seed=seed) + # Shot vector support + results = [] + for lower, upper in tape.shots.bins(): + sub_samples = samples[lower:upper] + results.append( + tuple(mp.process_samples(sub_samples, tape.wires) for mp in tape.measurements) + ) + if len(tape.measurements) == 1: + results = [res[0] for res in results] + if not tape.shots.has_partitioned_shots: + results = results[0] + else: + results = tuple(results) + else: + results = tuple(mp.process_state(state, tape.wires) for mp in tape.measurements) + if len(tape.measurements) == 1: + results = results[0] + + return results + + +operations = frozenset({"PauliX", "PauliY", "PauliZ", "Hadamard", "CNOT", "CZ", "RX", "RY", "RZ"}) + + +def supports_operation(op: qml.operation.Operator) -> bool: + """This function used by preprocessing determines what operations + are natively supported by the device. + + While in theory ``simulate`` can support any operation with a matrix, we limit the target + gate set for improved testing and reference purposes. + + """ + return getattr(op, "name", None) in operations + + +@simulator_tracking # update device.tracker with some relevant information +@single_tape_support # add support for device.execute(tape) in addition to device.execute((tape,)) +class ReferenceQubit(Device): + """A slimmed down numpy-based simulator for reference and testing purposes. + + Args: + wires (int, Iterable[Number, str]): Number of wires present on the device, or iterable that + contains unique labels for the wires as numbers (i.e., ``[-1, 0, 2]``) or strings + (``['aux', 'q1', 'q2']``). Default ``None`` if not specified. While this device allows + for ``wires`` to be unspecified at construction time, other devices may make this argument + mandatory. Devices can also implement additional restrictions on the possible wires. + shots (int, Sequence[int], Sequence[Union[int, Sequence[int]]]): The default number of shots + to use in executions involving this device. Note that during execution, shots + are pulled from the circuit, not from the device. + seed (Union[str, None, int, array_like[int], SeedSequence, BitGenerator, Generator, jax.random.PRNGKey]): A + seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng``. This is an optional + keyword argument added to follow recommend NumPy best practices. Other devices do not need + this parameter if it does not make sense for them. + + """ + + name = "reference.qubit" + + def __init__(self, wires=None, shots=None, seed=None): + super().__init__(wires=wires, shots=shots) + + # seed and rng not necessary for a device, but part of recommended + # numpy practices to use a local random number generator + self._rng = np.random.default_rng(seed) + + def preprocess(self, execution_config=DefaultExecutionConfig): + + # Here we convert an arbitrary tape into one natively supported by the device + program = qml.transforms.core.TransformProgram() + program.add_transform(validate_device_wires, wires=self.wires, name="reference.qubit") + program.add_transform(qml.defer_measurements) + program.add_transform(qml.transforms.split_non_commuting) + program.add_transform(qml.transforms.diagonalize_measurements) + program.add_transform( + decompose, + stopping_condition=supports_operation, + skip_initial_state_prep=False, + name="reference.qubit", + ) + program.add_transform(validate_measurements, name="reference.qubit") + program.add_transform(qml.transforms.broadcast_expand) + + # no need to preprocess the config as the device does not support derivatives + return program, execution_config + + def execute(self, circuits, execution_config=DefaultExecutionConfig): + for tape in circuits: + assert all(supports_operation(op) for op in tape.operations) + return tuple(simulate(tape, seed=self._rng) for tape in circuits) diff --git a/setup.py b/setup.py index 41ae9775027..4db98cdca25 100644 --- a/setup.py +++ b/setup.py @@ -52,6 +52,7 @@ "default.qubit.legacy = pennylane.devices:DefaultQubitLegacy", "default.gaussian = pennylane.devices:DefaultGaussian", "default.mixed = pennylane.devices.default_mixed:DefaultMixed", + "reference.qubit = pennylane.devices.reference_qubit:ReferenceQubit", "null.qubit = pennylane.devices.null_qubit:NullQubit", "default.qutrit = pennylane.devices.default_qutrit:DefaultQutrit", "default.clifford = pennylane.devices.default_clifford:DefaultClifford", diff --git a/tests/interfaces/test_autograd.py b/tests/interfaces/test_autograd.py index 2a6ee306508..d206f1758d3 100644 --- a/tests/interfaces/test_autograd.py +++ b/tests/interfaces/test_autograd.py @@ -13,12 +13,13 @@ # limitations under the License. """Autograd specific tests for execute and default qubit 2.""" import autograd +import numpy as np import pytest from param_shift_dev import ParamShiftDerivativesDevice import pennylane as qml from pennylane import execute -from pennylane import numpy as np +from pennylane import numpy as pnp from pennylane.devices import DefaultQubit from pennylane.gradients import param_shift from pennylane.measurements import Shots @@ -36,7 +37,7 @@ def test_caching_param_shift_hessian(self, num_params): caching reduces the number of evaluations to their optimum when computing Hessians.""" dev = DefaultQubit() - params = np.arange(1, num_params + 1) / 10 + params = pnp.arange(1, num_params + 1) / 10 N = len(params) @@ -125,8 +126,8 @@ def f(x): # add tests for lightning 2 when possible # set rng for device when possible test_matrix = [ - ({"gradient_fn": param_shift}, Shots(100000), DefaultQubit(seed=42)), - ({"gradient_fn": param_shift}, Shots((100000, 100000)), DefaultQubit(seed=42)), + ({"gradient_fn": param_shift}, Shots(50000), DefaultQubit(seed=42)), + ({"gradient_fn": param_shift}, Shots((50000, 50000)), DefaultQubit(seed=42)), ({"gradient_fn": param_shift}, Shots(None), DefaultQubit()), ({"gradient_fn": "backprop"}, Shots(None), DefaultQubit()), ( @@ -146,7 +147,7 @@ def f(x): ({"gradient_fn": "adjoint", "device_vjp": True}, Shots(None), DefaultQubit()), ( {"gradient_fn": "device", "device_vjp": False}, - Shots((100000, 100000)), + Shots((50000, 50000)), ParamShiftDerivativesDevice(seed=904747894), ), ( @@ -154,12 +155,27 @@ def f(x): Shots((100000, 100000)), ParamShiftDerivativesDevice(seed=10490244), ), + ( + {"gradient_fn": param_shift}, + Shots(None), + qml.device("reference.qubit"), + ), + ( + {"gradient_fn": param_shift}, + Shots(50000), + qml.device("reference.qubit", seed=8743274), + ), + ( + {"gradient_fn": param_shift}, + Shots((50000, 50000)), + qml.device("reference.qubit", seed=8743274), + ), ] def atol_for_shots(shots): """Return higher tolerance if finite shots.""" - return 1e-2 if shots else 1e-6 + return 5e-2 if shots else 1e-6 @pytest.mark.parametrize("execute_kwargs, shots, device", test_matrix) @@ -179,8 +195,8 @@ def cost(a, b): return execute([tape1, tape2], device, **execute_kwargs) - a = np.array(0.1, requires_grad=True) - b = np.array(0.2, requires_grad=False) + a = pnp.array(0.1, requires_grad=True) + b = pnp.array(0.2, requires_grad=False) with device.tracker: res = cost(a, b) @@ -200,7 +216,7 @@ def cost(a, b): def test_scalar_jacobian(self, execute_kwargs, shots, device): """Test scalar jacobian calculation""" - a = np.array(0.1, requires_grad=True) + a = pnp.array(0.1, requires_grad=True) def cost(a): tape = qml.tape.QuantumScript([qml.RY(a, 0)], [qml.expval(qml.PauliZ(0))], shots=shots) @@ -224,8 +240,8 @@ def cost(a): def test_jacobian(self, execute_kwargs, shots, device): """Test jacobian calculation""" - a = np.array(0.1, requires_grad=True) - b = np.array(0.2, requires_grad=True) + a = pnp.array(0.1, requires_grad=True) + b = pnp.array(0.2, requires_grad=True) def cost(a, b): ops = [qml.RY(a, wires=0), qml.RX(b, wires=1), qml.CNOT(wires=[0, 1])] @@ -270,7 +286,7 @@ def cost(params): ) tape2 = qml.tape.QuantumScript( - [qml.RY(np.array(0.5, requires_grad=False), wires=0)], + [qml.RY(pnp.array(0.5, requires_grad=False), wires=0)], [qml.expval(qml.PauliZ(0))], shots=shots, ) @@ -282,7 +298,7 @@ def cost(params): ) tape4 = qml.tape.QuantumScript( - [qml.RY(np.array(0.5, requires_grad=False), 0)], + [qml.RY(pnp.array(0.5, requires_grad=False), 0)], [qml.probs(wires=[0, 1])], shots=shots, ) @@ -291,7 +307,7 @@ def cost(params): res = tuple(i for r in res for i in r) return sum(autograd.numpy.hstack(res)) - params = np.array([0.1, 0.2], requires_grad=True) + params = pnp.array([0.1, 0.2], requires_grad=True) x, y = params res = cost(params) @@ -321,7 +337,7 @@ def cost(params): ) tape2 = qml.tape.QuantumScript( - [qml.RY(np.array(0.5, requires_grad=False), 0)], + [qml.RY(pnp.array(0.5, requires_grad=False), 0)], [qml.expval(qml.PauliZ(0))], shots=shots, ) @@ -336,7 +352,7 @@ def cost(params): res = tuple(i for r in res for i in r) return autograd.numpy.hstack(res) - params = np.array([0.1, 0.2], requires_grad=True) + params = pnp.array([0.1, 0.2], requires_grad=True) x, y = params res = cost(params) @@ -392,8 +408,8 @@ def cost(params): def test_reusing_quantum_tape(self, execute_kwargs, shots, device): """Test re-using a quantum tape by passing new parameters""" - a = np.array(0.1, requires_grad=True) - b = np.array(0.2, requires_grad=True) + a = pnp.array(0.1, requires_grad=True) + b = pnp.array(0.2, requires_grad=True) tape = qml.tape.QuantumScript( [qml.RY(a, 0), qml.RX(b, 1), qml.CNOT((0, 1))], @@ -408,8 +424,8 @@ def cost(a, b): jac_fn = qml.jacobian(cost) jac = jac_fn(a, b) - a = np.array(0.54, requires_grad=True) - b = np.array(0.8, requires_grad=True) + a = pnp.array(0.54, requires_grad=True) + b = pnp.array(0.8, requires_grad=True) # check that the cost function continues to depend on the # values of the parameters for subsequent calls @@ -429,15 +445,15 @@ def cost(a, b): def test_classical_processing(self, execute_kwargs, device, shots): """Test classical processing within the quantum tape""" - a = np.array(0.1, requires_grad=True) - b = np.array(0.2, requires_grad=False) - c = np.array(0.3, requires_grad=True) + a = pnp.array(0.1, requires_grad=True) + b = pnp.array(0.2, requires_grad=False) + c = pnp.array(0.3, requires_grad=True) def cost(a, b, c): ops = [ qml.RY(a * c, wires=0), qml.RZ(b, wires=0), - qml.RX(c + c**2 + np.sin(a), wires=0), + qml.RX(c + c**2 + pnp.sin(a), wires=0), ] tape = qml.tape.QuantumScript(ops, [qml.expval(qml.PauliZ(0))], shots=shots) @@ -457,8 +473,8 @@ def cost(a, b, c): def test_no_trainable_parameters(self, execute_kwargs, shots, device): """Test evaluation and Jacobian if there are no trainable parameters""" - a = np.array(0.1, requires_grad=False) - b = np.array(0.2, requires_grad=False) + a = pnp.array(0.1, requires_grad=False) + b = pnp.array(0.2, requires_grad=False) def cost(a, b): ops = [qml.RY(a, 0), qml.RX(b, 0), qml.CNOT((0, 1))] @@ -484,8 +500,8 @@ def loss(a, b): def test_matrix_parameter(self, execute_kwargs, device, shots): """Test that the autograd interface works correctly with a matrix parameter""" - U = np.array([[0, 1], [1, 0]], requires_grad=False) - a = np.array(0.1, requires_grad=True) + U = pnp.array([[0, 1], [1, 0]], requires_grad=False) + a = pnp.array(0.1, requires_grad=True) def cost(a, U): ops = [qml.QubitUnitary(U, wires=0), qml.RY(a, wires=0)] @@ -535,8 +551,8 @@ def cost_fn(a, p): program, _ = device.preprocess(execution_config=config) return execute([tape], device, **execute_kwargs, transform_program=program)[0] - a = np.array(0.1, requires_grad=False) - p = np.array([0.1, 0.2, 0.3], requires_grad=True) + a = pnp.array(0.1, requires_grad=False) + p = pnp.array([0.1, 0.2, 0.3], requires_grad=True) res = cost_fn(a, p) expected = np.cos(a) * np.cos(p[1]) * np.sin(p[0]) + np.sin(a) * ( @@ -568,8 +584,8 @@ def cost(x, y): tape = qml.tape.QuantumScript(ops, m) return autograd.numpy.hstack(execute([tape], device, **execute_kwargs)[0]) - x = np.array(0.543, requires_grad=True) - y = np.array(-0.654, requires_grad=True) + x = pnp.array(0.543, requires_grad=True) + y = pnp.array(-0.654, requires_grad=True) res = cost(x, y) expected = np.array( @@ -621,8 +637,8 @@ def cost(x, y): tape = qml.tape.QuantumScript(ops, m) return autograd.numpy.hstack(execute([tape], device, **execute_kwargs)[0]) - x = np.array(0.543, requires_grad=True) - y = np.array(-0.654, requires_grad=True) + x = pnp.array(0.543, requires_grad=True) + y = pnp.array(-0.654, requires_grad=True) res = cost(x, y) expected = np.array( @@ -650,9 +666,9 @@ class TestHigherOrderDerivatives: @pytest.mark.parametrize( "params", [ - np.array([0.543, -0.654], requires_grad=True), - np.array([0, -0.654], requires_grad=True), - np.array([-2.0, 0], requires_grad=True), + pnp.array([0.543, -0.654], requires_grad=True), + pnp.array([0, -0.654], requires_grad=True), + pnp.array([-2.0, 0], requires_grad=True), ], ) def test_parameter_shift_hessian(self, params, tol): @@ -693,7 +709,7 @@ def test_max_diff(self, tol): """Test that setting the max_diff parameter blocks higher-order derivatives""" dev = DefaultQubit() - params = np.array([0.543, -0.654], requires_grad=True) + params = pnp.array([0.543, -0.654], requires_grad=True) def cost_fn(x): ops = [qml.RX(x[0], 0), qml.RY(x[1], 1), qml.CNOT((0, 1))] @@ -788,11 +804,11 @@ def test_multiple_hamiltonians_not_trainable(self, execute_kwargs, cost_fn, shot """Test hamiltonian with no trainable parameters.""" if execute_kwargs["gradient_fn"] == "adjoint" and not qml.operation.active_new_opmath(): - pytest.skip("adjoint differentiation does not suppport hamiltonians.") + pytest.skip("adjoint differentiation does not support hamiltonians.") - coeffs1 = np.array([0.1, 0.2, 0.3], requires_grad=False) - coeffs2 = np.array([0.7], requires_grad=False) - weights = np.array([0.4, 0.5], requires_grad=True) + coeffs1 = pnp.array([0.1, 0.2, 0.3], requires_grad=False) + coeffs2 = pnp.array([0.7], requires_grad=False) + weights = pnp.array([0.4, 0.5], requires_grad=True) res = cost_fn(weights, coeffs1, coeffs2) expected = self.cost_fn_expected(weights, coeffs1, coeffs2) @@ -817,9 +833,9 @@ def test_multiple_hamiltonians_trainable(self, execute_kwargs, cost_fn, shots): if qml.operation.active_new_opmath(): pytest.skip("parameter shift derivatives do not yet support sums.") - coeffs1 = np.array([0.1, 0.2, 0.3], requires_grad=True) - coeffs2 = np.array([0.7], requires_grad=True) - weights = np.array([0.4, 0.5], requires_grad=True) + coeffs1 = pnp.array([0.1, 0.2, 0.3], requires_grad=True) + coeffs2 = pnp.array([0.7], requires_grad=True) + weights = pnp.array([0.4, 0.5], requires_grad=True) res = cost_fn(weights, coeffs1, coeffs2) expected = self.cost_fn_expected(weights, coeffs1, coeffs2) @@ -829,11 +845,11 @@ def test_multiple_hamiltonians_trainable(self, execute_kwargs, cost_fn, shots): else: assert np.allclose(res, expected, atol=atol_for_shots(shots), rtol=0) - res = np.hstack(qml.jacobian(cost_fn)(weights, coeffs1, coeffs2)) + res = pnp.hstack(qml.jacobian(cost_fn)(weights, coeffs1, coeffs2)) expected = self.cost_fn_jacobian(weights, coeffs1, coeffs2) if shots.has_partitioned_shots: pytest.xfail( "multiple hamiltonians with shot vectors does not seem to be differentiable." ) else: - assert np.allclose(res, expected, atol=atol_for_shots(shots), rtol=0) + assert qml.math.allclose(res, expected, atol=atol_for_shots(shots), rtol=0) diff --git a/tests/interfaces/test_autograd_qnode.py b/tests/interfaces/test_autograd_qnode.py index 129ab56dfe8..1d6dcfe397b 100644 --- a/tests/interfaces/test_autograd_qnode.py +++ b/tests/interfaces/test_autograd_qnode.py @@ -37,6 +37,7 @@ [ParamShiftDerivativesDevice(), "best", False, False], [ParamShiftDerivativesDevice(), "parameter-shift", True, False], [ParamShiftDerivativesDevice(), "parameter-shift", False, True], + [qml.device("reference.qubit"), "parameter-shift", False, False], ] interface_qubit_device_and_diff_method = [ @@ -62,6 +63,7 @@ ["auto", DefaultQubit(), "hadamard", False, False], ["auto", qml.device("lightning.qubit", wires=5), "adjoint", False, False], ["auto", qml.device("lightning.qubit", wires=5), "adjoint", True, False], + ["auto", qml.device("reference.qubit"), "parameter-shift", False, False], ] pytestmark = pytest.mark.autograd @@ -1378,6 +1380,8 @@ def test_projector( """Test that the variance of a projector is correctly returned""" if diff_method == "adjoint": pytest.skip("adjoint supports either expvals or diagonal measurements.") + if dev.name == "reference.qubit": + pytest.xfail("diagonalize_measurements do not support projectors (sc-72911)") kwargs = dict( diff_method=diff_method, interface=interface, @@ -1435,6 +1439,8 @@ def test_postselection_differentiation( if diff_method in ["adjoint", "spsa", "hadamard"]: pytest.skip("Diff method does not support postselection.") + if dev.name == "reference.qubit": + pytest.skip("reference.qubit does not support postselection.") @qml.qnode( dev, diff --git a/tests/interfaces/test_jax.py b/tests/interfaces/test_jax.py index 519c0daa028..1c12ba0b524 100644 --- a/tests/interfaces/test_jax.py +++ b/tests/interfaces/test_jax.py @@ -122,16 +122,22 @@ def cost(x, cache): # add tests for lightning 2 when possible # set rng for device when possible no_shots = Shots(None) +shots_10k = Shots(10000) shots_2_10k = Shots((10000, 10000)) -dev_def = DefaultQubit() +dev_def = DefaultQubit(seed=42) dev_ps = ParamShiftDerivativesDevice(seed=54353453) +dev_ref = qml.device("reference.qubit") test_matrix = [ - ({"gradient_fn": param_shift}, Shots(100000), DefaultQubit(seed=42)), # 0 - ({"gradient_fn": param_shift}, no_shots, dev_def), # 1 - ({"gradient_fn": "backprop"}, no_shots, dev_def), # 2 - ({"gradient_fn": "adjoint"}, no_shots, dev_def), # 3 - ({"gradient_fn": "adjoint", "device_vjp": True}, no_shots, dev_def), # 4 - ({"gradient_fn": "device"}, shots_2_10k, dev_ps), # 5 + ({"gradient_fn": param_shift}, shots_10k, dev_def), # 0 + ({"gradient_fn": param_shift}, shots_2_10k, dev_def), # 1 + ({"gradient_fn": param_shift}, no_shots, dev_def), # 2 + ({"gradient_fn": "backprop"}, no_shots, dev_def), # 3 + ({"gradient_fn": "adjoint"}, no_shots, dev_def), # 4 + ({"gradient_fn": "adjoint", "device_vjp": True}, no_shots, dev_def), # 5 + ({"gradient_fn": "device"}, shots_2_10k, dev_ps), # 6 + ({"gradient_fn": param_shift}, no_shots, dev_ref), # 7 + ({"gradient_fn": param_shift}, shots_10k, dev_ref), # 8 + ({"gradient_fn": param_shift}, shots_2_10k, dev_ref), # 9 ] diff --git a/tests/interfaces/test_jax_jit_qnode.py b/tests/interfaces/test_jax_jit_qnode.py index cce76a83b9e..cafa9c47fa1 100644 --- a/tests/interfaces/test_jax_jit_qnode.py +++ b/tests/interfaces/test_jax_jit_qnode.py @@ -41,6 +41,7 @@ [qml.device("lightning.qubit", wires=5), "adjoint", False, False], [qml.device("lightning.qubit", wires=5), "adjoint", True, True], [qml.device("lightning.qubit", wires=5), "parameter-shift", False, False], + [qml.device("reference.qubit"), "parameter-shift", False, False], ] interface_and_qubit_device_and_diff_method = [ ["auto"] + inner_list for inner_list in qubit_device_and_diff_method @@ -1040,6 +1041,8 @@ def test_postselection_differentiation( pytest.xfail("gradient transforms have a different vjp shape convention") elif dev.name == "lightning.qubit": pytest.xfail("lightning qubit does not support postselection.") + if dev.name == "reference.qubit": + pytest.skip("reference.qubit does not support postselection.") @qml.qnode( dev, diff_method=diff_method, interface=interface, grad_on_execution=grad_on_execution @@ -1431,6 +1434,8 @@ def test_projector( elif diff_method == "spsa": gradient_kwargs = {"h": H_FOR_SPSA, "sampler_rng": np.random.default_rng(SEED_FOR_SPSA)} tol = TOL_FOR_SPSA + if dev.name == "reference.qubit": + pytest.xfail("diagonalize_measurements do not support projectors (sc-72911)") P = jax.numpy.array(state) x, y = 0.765, -0.654 @@ -1514,6 +1519,11 @@ def test_hamiltonian_expansion_analytic( are non-commuting groups and the number of shots is None and the first and second order gradients are correctly evaluated""" gradient_kwargs = {} + if dev.name == "reference.qubit": + pytest.skip( + "Cannot add transform to the transform program in preprocessing" + "when using mocker.spy on it." + ) if dev.name == "param_shift.qubit": pytest.xfail("gradients transforms have a different vjp shape convention.") if diff_method == "adjoint": @@ -1840,6 +1850,9 @@ def test_hermitian( to different reasons, hence the parametrization in the test. """ # pylint: disable=unused-argument + if dev.name == "reference.qubit": + pytest.xfail("diagonalize_measurements do not support Hermitians (sc-72911)") + if diff_method == "backprop": pytest.skip("Backpropagation is unsupported if shots > 0.") diff --git a/tests/interfaces/test_jax_qnode.py b/tests/interfaces/test_jax_qnode.py index d24dec3383d..4b612da8e25 100644 --- a/tests/interfaces/test_jax_qnode.py +++ b/tests/interfaces/test_jax_qnode.py @@ -40,6 +40,7 @@ [qml.device("lightning.qubit", wires=5), "adjoint", True, True], [qml.device("lightning.qubit", wires=5), "adjoint", False, False], [qml.device("lightning.qubit", wires=5), "adjoint", True, False], + [qml.device("reference.qubit"), "parameter-shift", False, False], ] interface_and_device_and_diff_method = [ @@ -911,6 +912,8 @@ def test_postselection_differentiation(self, dev, diff_method, grad_on_execution if diff_method in ["adjoint", "spsa", "hadamard"]: pytest.skip("Diff method does not support postselection.") + if dev.name == "reference.qubit": + pytest.xfail("reference.qubit does not support postselection.") @qml.qnode( dev, @@ -1298,6 +1301,8 @@ def test_projector( elif diff_method == "spsa": gradient_kwargs = {"h": H_FOR_SPSA, "sampler_rng": np.random.default_rng(SEED_FOR_SPSA)} tol = TOL_FOR_SPSA + if dev.name == "reference.qubit": + pytest.xfail("diagonalize_measurements do not support projectors (sc-72911)") P = jax.numpy.array(state) x, y = 0.765, -0.654 @@ -1373,7 +1378,7 @@ def circuit(x, y): jax.grad(circuit, argnums=[0])(x, y) @pytest.mark.parametrize("max_diff", [1, 2]) - def test_hamiltonian_expansion_analytic( + def test_split_non_commuting_analytic( self, dev, diff_method, grad_on_execution, max_diff, interface, device_vjp, mocker, tol ): """Test that the Hamiltonian is not expanded if there @@ -1391,6 +1396,11 @@ def test_hamiltonian_expansion_analytic( "sampler_rng": np.random.default_rng(SEED_FOR_SPSA), } tol = TOL_FOR_SPSA + if dev.name == "reference.qubit": + pytest.skip( + "Cannot add transform to the transform program in preprocessing" + "when using mocker.spy on it." + ) spy = mocker.spy(qml.transforms, "split_non_commuting") obs = [qml.PauliX(0), qml.PauliX(0) @ qml.PauliZ(1), qml.PauliZ(0) @ qml.PauliZ(1)] @@ -1451,6 +1461,13 @@ def test_hamiltonian_finite_shots( """Test that the Hamiltonian is correctly measured (and not expanded) if there are non-commuting groups and the number of shots is finite and the first and second order gradients are correctly evaluated""" + + if dev.name == "reference.qubit": + pytest.skip( + "Cannot added to a transform to the transform program in " + "preprocessing when using mocker.spy on it." + ) + gradient_kwargs = {} tol = 0.3 if diff_method in ("adjoint", "backprop", "finite-diff"): diff --git a/tests/interfaces/test_tensorflow.py b/tests/interfaces/test_tensorflow.py index b2329cd27c1..0abe82c1942 100644 --- a/tests/interfaces/test_tensorflow.py +++ b/tests/interfaces/test_tensorflow.py @@ -118,6 +118,16 @@ def cost(x, cache): ({"gradient_fn": "backprop", "interface": "tf-autograph"}, None, DefaultQubit()), # 6 ({"gradient_fn": "adjoint", "interface": "tf-autograph"}, None, DefaultQubit()), # 7 ({"gradient_fn": "adjoint", "interface": "tf", "device_vjp": True}, None, DefaultQubit()), # 8 + ( + {"gradient_fn": param_shift, "interface": "tensorflow"}, + None, + qml.device("reference.qubit"), + ), # 9 + ( + {"gradient_fn": param_shift, "interface": "tensorflow"}, + 100000, + qml.device("reference.qubit"), + ), # 10 ] diff --git a/tests/interfaces/test_tensorflow_qnode.py b/tests/interfaces/test_tensorflow_qnode.py index c09f1632202..c01d32091c6 100644 --- a/tests/interfaces/test_tensorflow_qnode.py +++ b/tests/interfaces/test_tensorflow_qnode.py @@ -38,6 +38,7 @@ [qml.device("lightning.qubit", wires=4), "adjoint", False, False], [qml.device("lightning.qubit", wires=4), "adjoint", True, True], [qml.device("lightning.qubit", wires=4), "adjoint", True, False], + [qml.device("reference.qubit"), "parameter-shift", False, False], ] TOL_FOR_SPSA = 1.0 @@ -980,6 +981,8 @@ def test_projector( kwargs["sampler_rng"] = np.random.default_rng(SEED_FOR_SPSA) kwargs["num_directions"] = 20 tol = TOL_FOR_SPSA + if dev.name == "reference.qubit": + pytest.xfail("diagonalize_measurements do not support projectors (sc-72911)") P = tf.constant(state, dtype=dtype) @@ -1014,6 +1017,9 @@ def test_postselection_differentiation( if diff_method in ["adjoint", "spsa", "hadamard"]: pytest.skip("Diff method does not support postselection.") + if dev.name == "reference.qubit": + pytest.skip("reference.qubit does not support postselection.") + @qml.qnode( dev, diff_method=diff_method, diff --git a/tests/interfaces/test_torch.py b/tests/interfaces/test_torch.py index 3cdcf5eae30..3640d31de9c 100644 --- a/tests/interfaces/test_torch.py +++ b/tests/interfaces/test_torch.py @@ -159,6 +159,21 @@ def cost_cache(x): Shots((100000, 100000)), ParamShiftDerivativesDevice(), ), + ( + {"gradient_fn": param_shift}, + Shots(None), + qml.device("reference.qubit"), + ), + ( + {"gradient_fn": param_shift}, + Shots(100000), + qml.device("reference.qubit"), + ), + ( + {"gradient_fn": param_shift}, + Shots((100000, 100000)), + qml.device("reference.qubit"), + ), ] diff --git a/tests/interfaces/test_torch_qnode.py b/tests/interfaces/test_torch_qnode.py index 82dbda669d4..5ecf181d343 100644 --- a/tests/interfaces/test_torch_qnode.py +++ b/tests/interfaces/test_torch_qnode.py @@ -47,6 +47,7 @@ [ParamShiftDerivativesDevice(), "best", False, False], [ParamShiftDerivativesDevice(), "parameter-shift", True, False], [ParamShiftDerivativesDevice(), "parameter-shift", False, True], + [qml.device("reference.qubit"), "parameter-shift", False, False], ] interface_and_qubit_device_and_diff_method = [ @@ -1131,6 +1132,8 @@ def test_projector( tol = TOL_FOR_SPSA elif diff_method == "hadamard": pytest.skip("Hadamard does not support variances.") + if dev.name == "reference.qubit": + pytest.xfail("diagonalize_measurements do not support projectors (sc-72911)") P = torch.tensor(state, requires_grad=False) @@ -1167,6 +1170,9 @@ def test_postselection_differentiation( if diff_method in ["adjoint", "spsa", "hadamard"]: pytest.skip("Diff method does not support postselection.") + if dev.name == "reference.qubit": + pytest.skip("reference.qubit does not support postselection.") + @qml.qnode( dev, diff_method=diff_method, From fa97f16ecce36bdf2cba3ede228f6ecb49eb449f Mon Sep 17 00:00:00 2001 From: Guillermo Alonso-Linaje <65235481+KetpuntoG@users.noreply.github.com> Date: Tue, 17 Sep 2024 10:52:31 -0400 Subject: [PATCH 3/4] Deprecation QubitStateVector (#6172) [sc-72046] --- doc/development/deprecations.rst | 6 ++++++ doc/releases/changelog-dev.md | 4 ++++ pennylane/ops/qubit/state_preparation.py | 16 ++++++++++++++-- tests/drawer/test_drawable_layers.py | 2 +- tests/ops/functions/conftest.py | 2 +- tests/ops/qubit/test_state_prep.py | 6 ++++++ 6 files changed, 32 insertions(+), 4 deletions(-) diff --git a/doc/development/deprecations.rst b/doc/development/deprecations.rst index c91979963a9..22b478ab3a5 100644 --- a/doc/development/deprecations.rst +++ b/doc/development/deprecations.rst @@ -39,6 +39,12 @@ Pending deprecations - Deprecated in v0.37 - Will be removed in v0.39 +* The ``QubitStateVector`` template is deprecated. + Instead, use ``StatePrep``. + + - Deprecated in v0.39 + - Will be removed in v0.40 + New operator arithmetic deprecations ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 91194ac1e19..3f3a3c04326 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -90,6 +90,10 @@

Deprecations 👋

+* The ``QubitStateVector`` template is deprecated. + Instead, use ``StatePrep``. + [(#6172)](https://github.com/PennyLaneAI/pennylane/pull/6172) + * `Device`, `QubitDevice`, and `QutritDevice` will no longer be accessible via top-level import in v0.40. They will still be accessible as `qml.devices.LegacyDevice`, `qml.devices.QubitDevice`, and `qml.devices.QutritDevice` respectively. diff --git a/pennylane/ops/qubit/state_preparation.py b/pennylane/ops/qubit/state_preparation.py index 7ef84be70fd..6bfb4179e00 100644 --- a/pennylane/ops/qubit/state_preparation.py +++ b/pennylane/ops/qubit/state_preparation.py @@ -15,6 +15,8 @@ This submodule contains the discrete-variable quantum operations concerned with preparing a certain state on the device. """ +import warnings + # pylint:disable=too-many-branches,abstract-method,arguments-differ,protected-access,no-member from typing import Optional @@ -442,9 +444,19 @@ def _preprocess(state, wires, pad_with, normalize, validate_norm): return state -# pylint: disable=missing-class-docstring class QubitStateVector(StatePrep): - pass # QSV is still available + r""" + ``QubitStateVector`` is deprecated and will be removed in version 0.40. Instead, please use ``StatePrep``. + """ + + # pylint: disable=too-many-arguments + def __init__(self, state, wires, pad_with=None, normalize=False, validate_norm=True): + warnings.warn( + "QubitStateVector is deprecated and will be removed in version 0.40. " + "Instead, please use StatePrep.", + qml.PennyLaneDeprecationWarning, + ) + super().__init__(state, wires, pad_with, normalize, validate_norm) class QubitDensityMatrix(Operation): diff --git a/tests/drawer/test_drawable_layers.py b/tests/drawer/test_drawable_layers.py index 7e1bdb1d0bd..729e0f74609 100644 --- a/tests/drawer/test_drawable_layers.py +++ b/tests/drawer/test_drawable_layers.py @@ -196,7 +196,7 @@ def test_mid_measure_custom_wires(self): m1 = qml.measurements.MeasurementValue([mp1], lambda v: v) def teleport(state): - qml.QubitStateVector(state, wires=["A"]) + qml.StatePrep(state, wires=["A"]) qml.Hadamard(wires="a") qml.CNOT(wires=["a", "B"]) qml.CNOT(wires=["A", "a"]) diff --git a/tests/ops/functions/conftest.py b/tests/ops/functions/conftest.py index 92863eb7ab1..692745b48b6 100644 --- a/tests/ops/functions/conftest.py +++ b/tests/ops/functions/conftest.py @@ -38,7 +38,6 @@ qml.sum(qml.X(0), qml.X(0), qml.Z(0), qml.Z(0)), qml.BasisState([1], wires=[0]), qml.ControlledQubitUnitary(np.eye(2), control_wires=1, wires=0), - qml.QubitStateVector([0, 1], wires=0), qml.QubitChannel([np.array([[1, 0], [0, 0.8]]), np.array([[0, 0.6], [0, 0]])], wires=0), qml.MultiControlledX(wires=[0, 1]), qml.Projector([1], 0), # the state-vector version is already tested @@ -137,6 +136,7 @@ PowOpObs, PowOperation, PowObs, + qml.QubitStateVector, } """Types that should not have actual instances created.""" diff --git a/tests/ops/qubit/test_state_prep.py b/tests/ops/qubit/test_state_prep.py index 342aaff5df0..e6da832a8eb 100644 --- a/tests/ops/qubit/test_state_prep.py +++ b/tests/ops/qubit/test_state_prep.py @@ -36,6 +36,12 @@ def test_adjoint_error_exception(op): op.adjoint() +def test_QubitStateVector_is_deprecated(): + """Test that QubitStateVector is deprecated.""" + with pytest.warns(qml.PennyLaneDeprecationWarning, match="QubitStateVector is deprecated"): + _ = qml.QubitStateVector([1, 0, 0, 0], wires=[0, 1]) + + @pytest.mark.parametrize( "op, mat, base", [ From 7e21a762b65ed4f21f7e048ed199deac6a057646 Mon Sep 17 00:00:00 2001 From: Pietropaolo Frisoni Date: Tue, 17 Sep 2024 11:43:21 -0400 Subject: [PATCH 4/4] PennyLane is compatible with JAX 0.4.28 (#6255) **Context:** As part of the effort to make PL compatible with Numpy 2.0 (see #6061), we need to upgrade JAX to 0.4.26+ since such a version introduced the support for Numpy 2.0. We opted for JAX 0.4.28 since it is the same version used by Catalyst. **Description of the Change:** As above. **Benefits:** PL is compatible with Numpy 2.0 and Jax 0.4.28. **Possible Drawbacks:** - From JAX 0.4.27, in `jax.jit`, passing invalid static_argnums or static_argnames now leads to an error rather than a warning. In PL, this breaks every test where we set `shots` in the `QNode` call with `static_argnames=["shots"]`. At this stage, we decided to mark such tests with `pytest.xfail` to allow the upgrade. **Related GitHub Issues:** None. **Related Shortcut Stories**: [sc-61389] --------- Co-authored-by: dwierichs --- .github/workflows/install_deps/action.yml | 4 ++-- doc/releases/changelog-dev.md | 3 +++ .../devices/default_qubit/test_default_qubit.py | 2 ++ .../qutrit_mixed/test_qutrit_mixed_measure.py | 2 +- tests/devices/test_default_qutrit_mixed.py | 2 +- tests/interfaces/test_jax_jit_qnode.py | 17 +++++++++++++++++ .../test_optimization_utils.py | 13 +++++++++++-- 7 files changed, 37 insertions(+), 6 deletions(-) diff --git a/.github/workflows/install_deps/action.yml b/.github/workflows/install_deps/action.yml index 99b77dc8157..e1dc636bb84 100644 --- a/.github/workflows/install_deps/action.yml +++ b/.github/workflows/install_deps/action.yml @@ -15,7 +15,7 @@ inputs: jax_version: description: The version of JAX to install for any job that requires JAX required: false - default: '0.4.23' + default: '0.4.28' install_tensorflow: description: Indicate if TensorFlow should be installed or not required: false @@ -86,7 +86,7 @@ runs: if: inputs.install_jax == 'true' env: JAX_VERSION: ${{ inputs.jax_version != '' && format('=={0}', inputs.jax_version) || '' }} - run: pip install "jax${{ env.JAX_VERSION}}" "jaxlib${{ env.JAX_VERSION }}" scipy~=1.12.0 + run: pip install "jax${{ env.JAX_VERSION}}" "jaxlib${{ env.JAX_VERSION }}" - name: Install additional PIP packages shell: bash diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 3f3a3c04326..b03ff7b4f72 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -10,6 +10,9 @@ [(#6061)](https://github.com/PennyLaneAI/pennylane/pull/6061) [(#6258)](https://github.com/PennyLaneAI/pennylane/pull/6258) +* PennyLane is now compatible with Jax 0.4.28. + [(#6255)](https://github.com/PennyLaneAI/pennylane/pull/6255) + * `qml.qchem.excitations` now optionally returns fermionic operators. [(#6171)](https://github.com/PennyLaneAI/pennylane/pull/6171) diff --git a/tests/devices/default_qubit/test_default_qubit.py b/tests/devices/default_qubit/test_default_qubit.py index d3049d90eae..6820b0afdcb 100644 --- a/tests/devices/default_qubit/test_default_qubit.py +++ b/tests/devices/default_qubit/test_default_qubit.py @@ -1864,6 +1864,7 @@ def circ_expected(): if use_jit: import jax + pytest.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28") circ_postselect = jax.jit(circ_postselect, static_argnames=["shots"]) res = circ_postselect(param, shots=shots) @@ -2051,6 +2052,7 @@ def circ(): if use_jit: import jax + pytest.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28") circ = jax.jit(circ, static_argnames=["shots"]) res = circ(shots=shots) diff --git a/tests/devices/qutrit_mixed/test_qutrit_mixed_measure.py b/tests/devices/qutrit_mixed/test_qutrit_mixed_measure.py index 8a8de382f57..56d61fa339b 100644 --- a/tests/devices/qutrit_mixed/test_qutrit_mixed_measure.py +++ b/tests/devices/qutrit_mixed/test_qutrit_mixed_measure.py @@ -499,7 +499,7 @@ def test_jax_backprop(self, use_jit): x = jax.numpy.array(self.x, dtype=jax.numpy.float64) coeffs = (5.2, 6.7) - f = jax.jit(self.f, static_argnums=(1, 2, 3, 4)) if use_jit else self.f + f = jax.jit(self.f, static_argnums=(1, 2, 3)) if use_jit else self.f out = f(x, coeffs) expected_out = self.expected(x, coeffs) diff --git a/tests/devices/test_default_qutrit_mixed.py b/tests/devices/test_default_qutrit_mixed.py index 5178e1c800a..13f3d744bb1 100644 --- a/tests/devices/test_default_qutrit_mixed.py +++ b/tests/devices/test_default_qutrit_mixed.py @@ -823,7 +823,7 @@ def test_jax_backprop(self, use_jit): x = jax.numpy.array(self.x, dtype=jax.numpy.float64) coeffs = (5.2, 6.7) - f = jax.jit(self.f, static_argnums=(1, 2, 3, 4)) if use_jit else self.f + f = jax.jit(self.f, static_argnums=(1, 2, 3)) if use_jit else self.f out = f(x, coeffs) expected_out = self.expected(x, coeffs) diff --git a/tests/interfaces/test_jax_jit_qnode.py b/tests/interfaces/test_jax_jit_qnode.py index cafa9c47fa1..99ac281ef8f 100644 --- a/tests/interfaces/test_jax_jit_qnode.py +++ b/tests/interfaces/test_jax_jit_qnode.py @@ -813,6 +813,7 @@ def circuit(a, b): res = circuit(a, b, shots=100) # pylint: disable=unexpected-keyword-arg assert res.shape == (100, 2) # pylint:disable=comparison-with-callable + @pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28") def test_gradient_integration(self, interface): """Test that temporarily setting the shots works for gradient computations""" @@ -912,6 +913,7 @@ def circuit(x): class TestQubitIntegration: """Tests that ensure various qubit circuits integrate correctly""" + @pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28") def test_sampling(self, dev, diff_method, grad_on_execution, device_vjp, interface): """Test sampling works as expected""" if grad_on_execution: @@ -941,6 +943,7 @@ def circuit(): assert isinstance(res[1], jax.Array) assert res[1].shape == (10,) + @pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28") def test_counts(self, dev, diff_method, grad_on_execution, device_vjp, interface): """Test counts works as expected""" if grad_on_execution: @@ -2041,6 +2044,7 @@ def circ(p, U): class TestReturn: """Class to test the shape of the Grad/Jacobian with different return types.""" + @pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28") def test_grad_single_measurement_param( self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface ): @@ -2073,6 +2077,7 @@ def circuit(a): assert isinstance(grad, jax.numpy.ndarray) assert grad.shape == () + @pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28") def test_grad_single_measurement_multiple_param( self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface ): @@ -2110,6 +2115,7 @@ def circuit(a, b): assert grad[0].shape == () assert grad[1].shape == () + @pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28") def test_grad_single_measurement_multiple_param_array( self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface ): @@ -2142,6 +2148,7 @@ def circuit(a): assert isinstance(grad, jax.numpy.ndarray) assert grad.shape == (2,) + @pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28") def test_jacobian_single_measurement_param_probs( self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface ): @@ -2175,6 +2182,7 @@ def circuit(a): assert isinstance(jac, jax.numpy.ndarray) assert jac.shape == (4,) + @pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28") def test_jacobian_single_measurement_probs_multiple_param( self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface ): @@ -2214,6 +2222,7 @@ def circuit(a, b): assert isinstance(jac[1], jax.numpy.ndarray) assert jac[1].shape == (4,) + @pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28") def test_jacobian_single_measurement_probs_multiple_param_single_array( self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface ): @@ -2246,6 +2255,7 @@ def circuit(a): assert isinstance(jac, jax.numpy.ndarray) assert jac.shape == (4, 2) + @pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28") def test_jacobian_expval_expval_multiple_params( self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface ): @@ -2295,6 +2305,7 @@ def circuit(x, y): assert isinstance(jac[1][1], jax.numpy.ndarray) assert jac[1][1].shape == () + @pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28") def test_jacobian_expval_expval_multiple_params_array( self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface ): @@ -2333,6 +2344,7 @@ def circuit(a): assert isinstance(jac[1], jax.numpy.ndarray) assert jac[1].shape == (2,) + @pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28") def test_jacobian_var_var_multiple_params( self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface ): @@ -2385,6 +2397,7 @@ def circuit(x, y): assert isinstance(jac[1][1], jax.numpy.ndarray) assert jac[1][1].shape == () + @pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28") def test_jacobian_var_var_multiple_params_array( self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface ): @@ -2425,6 +2438,7 @@ def circuit(a): assert isinstance(jac[1], jax.numpy.ndarray) assert jac[1].shape == (2,) + @pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28") def test_jacobian_multiple_measurement_single_param( self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface ): @@ -2463,6 +2477,7 @@ def circuit(a): assert isinstance(jac[1], jax.numpy.ndarray) assert jac[1].shape == (4,) + @pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28") def test_jacobian_multiple_measurement_multiple_param( self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface ): @@ -2510,6 +2525,7 @@ def circuit(a, b): assert isinstance(jac[1][1], jax.numpy.ndarray) assert jac[1][1].shape == (4,) + @pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28") def test_jacobian_multiple_measurement_multiple_param_array( self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface ): @@ -2871,6 +2887,7 @@ def circuit(x): assert hess[1].shape == (2, 2, 2) +@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28") @pytest.mark.parametrize("hessian", hessian_fn) @pytest.mark.parametrize("diff_method", ["parameter-shift", "hadamard"]) def test_jax_device_hessian_shots(hessian, diff_method): diff --git a/tests/transforms/test_optimization/test_optimization_utils.py b/tests/transforms/test_optimization/test_optimization_utils.py index ff31cd999c6..19145b7ce53 100644 --- a/tests/transforms/test_optimization/test_optimization_utils.py +++ b/tests/transforms/test_optimization/test_optimization_utils.py @@ -238,8 +238,17 @@ def test_jacobian_jax(self, use_jit): special_angles = np.array(list(product(special_points, repeat=6))).reshape((-1, 2, 3)) random_angles = np.random.random((1000, 2, 3)) # Need holomorphic derivatives and complex inputs because the output matrices are complex - all_angles = jax.numpy.concatenate([special_angles, random_angles], dtype=complex) - jac_fn = lambda fn: jax.vmap(jax.jacobian(fn, holomorphic=True)) + all_angles = jax.numpy.concatenate([special_angles, random_angles]) + + # We need to define the Jacobian function manually because fuse_rot_angles is not guaranteed to be holomorphic, + # and jax.jacobian requires real-valued outputs for non-holomorphic functions. + def jac_fn(fn): + real_fn = lambda arg: qml.math.real(fn(arg)) + imag_fn = lambda arg: qml.math.imag(fn(arg)) + real_jac_fn = jax.vmap(jax.jacobian(real_fn)) + imag_jac_fn = jax.vmap(jax.jacobian(imag_fn)) + return lambda arg: real_jac_fn(arg) + 1j * imag_jac_fn(arg) + jit_fn = jax.jit if use_jit else None self.run_jacobian_test(all_angles, jac_fn, is_batched=True, jit_fn=jit_fn)