From 3bf2964e7038c8caed17ab722e2c45dc3bcf46e9 Mon Sep 17 00:00:00 2001 From: David Wierichs Date: Fri, 16 Jun 2023 15:49:27 +0200 Subject: [PATCH] Raise a warning when applying a pulse gradient transform to a QNode directly (#4241) * introduce warning and tests * changelog, recommendation * docstrings * fix test * switch to raising an error instead * Apply suggestions from code review Co-authored-by: Tom Bromley <49409390+trbromley@users.noreply.github.com> * fix tests --------- Co-authored-by: Korbinian Kottmann Co-authored-by: Tom Bromley <49409390+trbromley@users.noreply.github.com> --- doc/releases/changelog-dev.md | 4 ++ .../gradients/pulse_generator_gradient.py | 22 +++++++- pennylane/gradients/pulse_gradient.py | 31 +++++++++-- .../core/test_pulse_generator_gradient.py | 24 +++++++-- tests/gradients/core/test_pulse_gradient.py | 52 ++++++++++++++++++- 5 files changed, 122 insertions(+), 11 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 4181efd47cb..078379ace37 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -63,6 +63,10 @@

Improvements 🛠

+* The pulse differentiation methods, `pulse_generator` and `stoch_pulse_grad` now raise an error when they + are applied to a `QNode` directly. Instead, use differentiation via a JAX entry point (`jax.grad`, `jax.jacobian`, ...). + [(4241)](https://github.com/PennyLaneAI/pennylane/pull/4241) + * `pulse.ParametrizedEvolution` now raises an error if the number of input parameters does not match the number of parametrized coefficients in the `ParametrizedHamiltonian` that generates it. An exception is made for `HardwareHamiltonian`s which are not checked. diff --git a/pennylane/gradients/pulse_generator_gradient.py b/pennylane/gradients/pulse_generator_gradient.py index 0a5868a2d5b..144a49d241f 100644 --- a/pennylane/gradients/pulse_generator_gradient.py +++ b/pennylane/gradients/pulse_generator_gradient.py @@ -25,7 +25,7 @@ from pennylane.measurements import Shots from .parameter_shift import _make_zero_rep -from .pulse_gradient import _assert_has_jax +from .pulse_gradient import _assert_has_jax, raise_pulse_diff_on_qnode from .gradient_transform import ( _all_zero_grad, assert_active_return, @@ -455,7 +455,14 @@ def _pulse_generator(tape, argnum=None, shots=None, atol=1e-7): This function requires the JAX interface and does not work with other autodiff interfaces commonly encountered with PennyLane. - In addition, this transform is only JIT-compatible with pulses that only have scalar parameters. + In addition, this transform is only JIT-compatible with pulses that only have scalar + parameters. + + .. warning:: + + This transform may not be applied directly to QNodes. Use JAX entrypoints + (``jax.grad``, ``jax.jacobian``, ...) instead or apply the transform on the tape + level. Also see the examples below. **Example** @@ -702,3 +709,14 @@ def expand_invalid_trainable_pulse_generator(x, *args, **kwargs): pulse_generator = gradient_transform( _pulse_generator, expand_fn=expand_invalid_trainable_pulse_generator ) + + +@pulse_generator.custom_qnode_wrapper +def pulse_generator_qnode_wrapper(self, qnode, targs, tkwargs): + """A custom QNode wrapper for the gradient transform :func:`~.pulse_generator`. + It raises an error, so that applying ``pulse_generator`` to a ``QNode`` directly + is not supported. + """ + # pylint:disable=unused-argument + transform_name = "pulse generator parameter-shift" + raise_pulse_diff_on_qnode(transform_name) diff --git a/pennylane/gradients/pulse_gradient.py b/pennylane/gradients/pulse_gradient.py index c2911522b9e..c89d710b8c0 100644 --- a/pennylane/gradients/pulse_gradient.py +++ b/pennylane/gradients/pulse_gradient.py @@ -56,6 +56,19 @@ def _assert_has_jax(transform_name): ) +def raise_pulse_diff_on_qnode(transform_name): + """Raises an error as the gradient transform with the provided name does + not support direct application to QNodes. + """ + msg = ( + f"Applying the {transform_name} gradient transform to a QNode directly is currently " + "not supported. Please use differentiation via a JAX entry point " + "(jax.grad, jax.jacobian, ...) instead.", + UserWarning, + ) + raise NotImplementedError(msg) + + def _split_evol_ops(op, ob, tau): r"""Randomly split a ``ParametrizedEvolution`` with respect to time into two operations and insert a Pauli rotation using a given Pauli word and rotation angles :math:`\pm\pi/2`. @@ -313,10 +326,11 @@ def _stoch_pulse_grad( rules when used with simple pulses (see details and examples below), potentially leading to imprecise results and/or unnecessarily large computational efforts. - .. note:: + .. warning:: - Currently this function only supports pulses for which each *parametrized* term is a - simple Pauli word. More general Hamiltonian terms are not supported yet. + This transform may not be applied directly to QNodes. Use JAX entrypoints + (``jax.grad``, ``jax.jacobian``, ...) instead or apply the transform on the tape level. + Also see the examples below. **Examples** @@ -682,3 +696,14 @@ def expand_invalid_trainable_stoch_pulse_grad(x, *args, **kwargs): stoch_pulse_grad = gradient_transform( _stoch_pulse_grad, expand_fn=expand_invalid_trainable_stoch_pulse_grad ) + + +@stoch_pulse_grad.custom_qnode_wrapper +def stoch_pulse_grad_qnode_wrapper(self, qnode, targs, tkwargs): + """A custom QNode wrapper for the gradient transform :func:`~.stoch_pulse_grad`. + It raises an error, so that applying ``pulse_generator`` to a ``QNode`` directly + is not supported. + """ + # pylint:disable=unused-argument + transform_name = "stochastic pulse parameter-shift" + raise_pulse_diff_on_qnode(transform_name) diff --git a/tests/gradients/core/test_pulse_generator_gradient.py b/tests/gradients/core/test_pulse_generator_gradient.py index 0011d4d8b1d..3692cca3b4c 100644 --- a/tests/gradients/core/test_pulse_generator_gradient.py +++ b/tests/gradients/core/test_pulse_generator_gradient.py @@ -1121,6 +1121,24 @@ def circuit(par): class TestPulseGeneratorQNode: """Test that pulse_generator integrates correctly with QNodes.""" + def test_raises_for_application_to_qnodes(self): + """Test that an error is raised when applying ``stoch_pulse_grad`` + to a QNode directly.""" + + dev = qml.device("default.qubit.jax", wires=1) + ham_single_q_const = qml.pulse.constant * qml.PauliY(0) + + @qml.qnode(dev, interface="jax") + def circuit(params): + qml.evolve(ham_single_q_const)([params], 0.2) + return qml.expval(qml.PauliZ(0)) + + _match = "pulse generator parameter-shift gradient transform to a QNode directly" + with pytest.raises(NotImplementedError, match=_match): + pulse_generator(circuit) + + # TODO: include the following tests when #4225 is resolved. + @pytest.mark.skip("Applying this gradient transform to QNodes directly is not supported.") def test_qnode_expval_single_par(self): """Test that a simple qnode that returns an expectation value can be differentiated with pulse_generator.""" @@ -1146,8 +1164,7 @@ def circuit(params): assert jnp.allclose(grad, exp_grad) assert tracker.totals["executions"] == 2 # two shifted tapes - # Applying QNode-level gradient transforms with non-scalar parameters is not supported yet - @pytest.mark.xfail + @pytest.mark.skip("Applying this gradient transform to QNodes directly is not supported.") def test_qnode_expval_probs_single_par(self): """Test that a simple qnode that returns an expectation value can be differentiated with pulse_generator.""" @@ -1179,8 +1196,7 @@ def circuit(params): for j, e in zip(jac, exp_jac): assert qml.math.allclose(j, e) - # Applying QNode-level gradient transforms with non-scalar parameters is not supported yet - @pytest.mark.xfail + @pytest.mark.skip("Applying this gradient transform to QNodes directly is not supported.") def test_qnode_probs_expval_multi_par(self): """Test that a simple qnode that returns probabilities can be differentiated with pulse_generator.""" diff --git a/tests/gradients/core/test_pulse_gradient.py b/tests/gradients/core/test_pulse_gradient.py index e9a89bac0cc..38347846fce 100644 --- a/tests/gradients/core/test_pulse_gradient.py +++ b/tests/gradients/core/test_pulse_gradient.py @@ -803,8 +803,56 @@ def test_shots_attribute(self, shots): @pytest.mark.jax -class TestStochPulseGradQNodeIntegration: - """Test that stoch_pulse_grad integrates correctly with QNodes.""" +class TestStochPulseGradQNode: + """Test that pulse_generator integrates correctly with QNodes.""" + + def test_raises_for_application_to_qnodes(self): + """Test that an error is raised when applying ``stoch_pulse_grad`` + to a QNode directly.""" + dev = qml.device("default.qubit.jax", wires=1) + ham_single_q_const = qml.pulse.constant * qml.PauliY(0) + + @qml.qnode(dev, interface="jax") + def circuit(params): + qml.evolve(ham_single_q_const)([params], 0.2) + return qml.expval(qml.PauliZ(0)) + + _match = "stochastic pulse parameter-shift gradient transform to a QNode directly" + with pytest.raises(NotImplementedError, match=_match): + stoch_pulse_grad(circuit, num_split_times=2) + + # TODO: include the following tests when #4225 is resolved. + @pytest.mark.skip("Applying this gradient transform to QNodes directly is not supported.") + def test_qnode_expval_single_par(self): + """Test that a simple qnode that returns an expectation value + can be differentiated with pulse_generator.""" + import jax + import jax.numpy as jnp + + jax.config.update("jax_enable_x64", True) + dev = qml.device("default.qubit.jax", wires=1) + T = 0.2 + ham_single_q_const = qml.pulse.constant * qml.PauliY(0) + + @qml.qnode(dev, interface="jax") + def circuit(params): + qml.evolve(ham_single_q_const)([params], T) + return qml.expval(qml.PauliZ(0)) + + params = jnp.array(0.4) + with qml.Tracker(dev) as tracker: + _match = "stochastic pulse parameter-shift .* scalar pulse parameters." + grad = stoch_pulse_grad(circuit, num_split_times=2)(params) + + p = params * T + exp_grad = -2 * jnp.sin(2 * p) * T + assert jnp.allclose(grad, exp_grad) + assert tracker.totals["executions"] == 4 # two shifted tapes, two splitting times + + +@pytest.mark.jax +class TestStochPulseGradIntegration: + """Test that stoch_pulse_grad integrates correctly with QNodes and ML interfaces.""" @pytest.mark.parametrize("shots, tol", [(None, 1e-4), (100, 0.1), ([100, 99], 0.1)]) @pytest.mark.parametrize("num_split_times", [1, 2])