diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 4181efd47cb..078379ace37 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -63,6 +63,10 @@
Improvements ðŸ›
+* The pulse differentiation methods, `pulse_generator` and `stoch_pulse_grad` now raise an error when they
+ are applied to a `QNode` directly. Instead, use differentiation via a JAX entry point (`jax.grad`, `jax.jacobian`, ...).
+ [(4241)](https://github.com/PennyLaneAI/pennylane/pull/4241)
+
* `pulse.ParametrizedEvolution` now raises an error if the number of input parameters does not match the number
of parametrized coefficients in the `ParametrizedHamiltonian` that generates it. An exception is made for
`HardwareHamiltonian`s which are not checked.
diff --git a/pennylane/gradients/pulse_generator_gradient.py b/pennylane/gradients/pulse_generator_gradient.py
index 0a5868a2d5b..144a49d241f 100644
--- a/pennylane/gradients/pulse_generator_gradient.py
+++ b/pennylane/gradients/pulse_generator_gradient.py
@@ -25,7 +25,7 @@
from pennylane.measurements import Shots
from .parameter_shift import _make_zero_rep
-from .pulse_gradient import _assert_has_jax
+from .pulse_gradient import _assert_has_jax, raise_pulse_diff_on_qnode
from .gradient_transform import (
_all_zero_grad,
assert_active_return,
@@ -455,7 +455,14 @@ def _pulse_generator(tape, argnum=None, shots=None, atol=1e-7):
This function requires the JAX interface and does not work with other autodiff interfaces
commonly encountered with PennyLane.
- In addition, this transform is only JIT-compatible with pulses that only have scalar parameters.
+ In addition, this transform is only JIT-compatible with pulses that only have scalar
+ parameters.
+
+ .. warning::
+
+ This transform may not be applied directly to QNodes. Use JAX entrypoints
+ (``jax.grad``, ``jax.jacobian``, ...) instead or apply the transform on the tape
+ level. Also see the examples below.
**Example**
@@ -702,3 +709,14 @@ def expand_invalid_trainable_pulse_generator(x, *args, **kwargs):
pulse_generator = gradient_transform(
_pulse_generator, expand_fn=expand_invalid_trainable_pulse_generator
)
+
+
+@pulse_generator.custom_qnode_wrapper
+def pulse_generator_qnode_wrapper(self, qnode, targs, tkwargs):
+ """A custom QNode wrapper for the gradient transform :func:`~.pulse_generator`.
+ It raises an error, so that applying ``pulse_generator`` to a ``QNode`` directly
+ is not supported.
+ """
+ # pylint:disable=unused-argument
+ transform_name = "pulse generator parameter-shift"
+ raise_pulse_diff_on_qnode(transform_name)
diff --git a/pennylane/gradients/pulse_gradient.py b/pennylane/gradients/pulse_gradient.py
index c2911522b9e..c89d710b8c0 100644
--- a/pennylane/gradients/pulse_gradient.py
+++ b/pennylane/gradients/pulse_gradient.py
@@ -56,6 +56,19 @@ def _assert_has_jax(transform_name):
)
+def raise_pulse_diff_on_qnode(transform_name):
+ """Raises an error as the gradient transform with the provided name does
+ not support direct application to QNodes.
+ """
+ msg = (
+ f"Applying the {transform_name} gradient transform to a QNode directly is currently "
+ "not supported. Please use differentiation via a JAX entry point "
+ "(jax.grad, jax.jacobian, ...) instead.",
+ UserWarning,
+ )
+ raise NotImplementedError(msg)
+
+
def _split_evol_ops(op, ob, tau):
r"""Randomly split a ``ParametrizedEvolution`` with respect to time into two operations and
insert a Pauli rotation using a given Pauli word and rotation angles :math:`\pm\pi/2`.
@@ -313,10 +326,11 @@ def _stoch_pulse_grad(
rules when used with simple pulses (see details and examples below), potentially leading
to imprecise results and/or unnecessarily large computational efforts.
- .. note::
+ .. warning::
- Currently this function only supports pulses for which each *parametrized* term is a
- simple Pauli word. More general Hamiltonian terms are not supported yet.
+ This transform may not be applied directly to QNodes. Use JAX entrypoints
+ (``jax.grad``, ``jax.jacobian``, ...) instead or apply the transform on the tape level.
+ Also see the examples below.
**Examples**
@@ -682,3 +696,14 @@ def expand_invalid_trainable_stoch_pulse_grad(x, *args, **kwargs):
stoch_pulse_grad = gradient_transform(
_stoch_pulse_grad, expand_fn=expand_invalid_trainable_stoch_pulse_grad
)
+
+
+@stoch_pulse_grad.custom_qnode_wrapper
+def stoch_pulse_grad_qnode_wrapper(self, qnode, targs, tkwargs):
+ """A custom QNode wrapper for the gradient transform :func:`~.stoch_pulse_grad`.
+ It raises an error, so that applying ``pulse_generator`` to a ``QNode`` directly
+ is not supported.
+ """
+ # pylint:disable=unused-argument
+ transform_name = "stochastic pulse parameter-shift"
+ raise_pulse_diff_on_qnode(transform_name)
diff --git a/tests/gradients/core/test_pulse_generator_gradient.py b/tests/gradients/core/test_pulse_generator_gradient.py
index 0011d4d8b1d..3692cca3b4c 100644
--- a/tests/gradients/core/test_pulse_generator_gradient.py
+++ b/tests/gradients/core/test_pulse_generator_gradient.py
@@ -1121,6 +1121,24 @@ def circuit(par):
class TestPulseGeneratorQNode:
"""Test that pulse_generator integrates correctly with QNodes."""
+ def test_raises_for_application_to_qnodes(self):
+ """Test that an error is raised when applying ``stoch_pulse_grad``
+ to a QNode directly."""
+
+ dev = qml.device("default.qubit.jax", wires=1)
+ ham_single_q_const = qml.pulse.constant * qml.PauliY(0)
+
+ @qml.qnode(dev, interface="jax")
+ def circuit(params):
+ qml.evolve(ham_single_q_const)([params], 0.2)
+ return qml.expval(qml.PauliZ(0))
+
+ _match = "pulse generator parameter-shift gradient transform to a QNode directly"
+ with pytest.raises(NotImplementedError, match=_match):
+ pulse_generator(circuit)
+
+ # TODO: include the following tests when #4225 is resolved.
+ @pytest.mark.skip("Applying this gradient transform to QNodes directly is not supported.")
def test_qnode_expval_single_par(self):
"""Test that a simple qnode that returns an expectation value
can be differentiated with pulse_generator."""
@@ -1146,8 +1164,7 @@ def circuit(params):
assert jnp.allclose(grad, exp_grad)
assert tracker.totals["executions"] == 2 # two shifted tapes
- # Applying QNode-level gradient transforms with non-scalar parameters is not supported yet
- @pytest.mark.xfail
+ @pytest.mark.skip("Applying this gradient transform to QNodes directly is not supported.")
def test_qnode_expval_probs_single_par(self):
"""Test that a simple qnode that returns an expectation value
can be differentiated with pulse_generator."""
@@ -1179,8 +1196,7 @@ def circuit(params):
for j, e in zip(jac, exp_jac):
assert qml.math.allclose(j, e)
- # Applying QNode-level gradient transforms with non-scalar parameters is not supported yet
- @pytest.mark.xfail
+ @pytest.mark.skip("Applying this gradient transform to QNodes directly is not supported.")
def test_qnode_probs_expval_multi_par(self):
"""Test that a simple qnode that returns probabilities
can be differentiated with pulse_generator."""
diff --git a/tests/gradients/core/test_pulse_gradient.py b/tests/gradients/core/test_pulse_gradient.py
index e9a89bac0cc..38347846fce 100644
--- a/tests/gradients/core/test_pulse_gradient.py
+++ b/tests/gradients/core/test_pulse_gradient.py
@@ -803,8 +803,56 @@ def test_shots_attribute(self, shots):
@pytest.mark.jax
-class TestStochPulseGradQNodeIntegration:
- """Test that stoch_pulse_grad integrates correctly with QNodes."""
+class TestStochPulseGradQNode:
+ """Test that pulse_generator integrates correctly with QNodes."""
+
+ def test_raises_for_application_to_qnodes(self):
+ """Test that an error is raised when applying ``stoch_pulse_grad``
+ to a QNode directly."""
+ dev = qml.device("default.qubit.jax", wires=1)
+ ham_single_q_const = qml.pulse.constant * qml.PauliY(0)
+
+ @qml.qnode(dev, interface="jax")
+ def circuit(params):
+ qml.evolve(ham_single_q_const)([params], 0.2)
+ return qml.expval(qml.PauliZ(0))
+
+ _match = "stochastic pulse parameter-shift gradient transform to a QNode directly"
+ with pytest.raises(NotImplementedError, match=_match):
+ stoch_pulse_grad(circuit, num_split_times=2)
+
+ # TODO: include the following tests when #4225 is resolved.
+ @pytest.mark.skip("Applying this gradient transform to QNodes directly is not supported.")
+ def test_qnode_expval_single_par(self):
+ """Test that a simple qnode that returns an expectation value
+ can be differentiated with pulse_generator."""
+ import jax
+ import jax.numpy as jnp
+
+ jax.config.update("jax_enable_x64", True)
+ dev = qml.device("default.qubit.jax", wires=1)
+ T = 0.2
+ ham_single_q_const = qml.pulse.constant * qml.PauliY(0)
+
+ @qml.qnode(dev, interface="jax")
+ def circuit(params):
+ qml.evolve(ham_single_q_const)([params], T)
+ return qml.expval(qml.PauliZ(0))
+
+ params = jnp.array(0.4)
+ with qml.Tracker(dev) as tracker:
+ _match = "stochastic pulse parameter-shift .* scalar pulse parameters."
+ grad = stoch_pulse_grad(circuit, num_split_times=2)(params)
+
+ p = params * T
+ exp_grad = -2 * jnp.sin(2 * p) * T
+ assert jnp.allclose(grad, exp_grad)
+ assert tracker.totals["executions"] == 4 # two shifted tapes, two splitting times
+
+
+@pytest.mark.jax
+class TestStochPulseGradIntegration:
+ """Test that stoch_pulse_grad integrates correctly with QNodes and ML interfaces."""
@pytest.mark.parametrize("shots, tol", [(None, 1e-4), (100, 0.1), ([100, 99], 0.1)])
@pytest.mark.parametrize("num_split_times", [1, 2])