Allow gradient transforms to be applied to broadcasted tapes if the broadcasted parameter(s) are not trainable (#5452)

**Context:**
In #4462 we discussed the differentiation of broadcasted tapes/QNodes
with gradient transforms. Due to technical debt in the JVP code and
unclear desired behaviour, applying gradient transforms to such QNodes
was disallowed in #4480.
However, this also prevents applying gradient transforms to broadcasted
tapes whose broadcasted parameter(s) are not trainable, even though the
problems discussed in #4462 do not apply in this case.

**Description of the Change:**
This PR lifts the restriction on applying gradient transforms to
broadcasted tapes in the case where no trainable parameters are
broadcasted.
This resolves, for example, the problem described in [this forum
post](https://discuss.pennylane.ai/t/qml-training-using-shot-based-simulator/4206).

**Benefits:**
Unlocks gradient transforms for broadcasted tapes/QNodes as long as the
broadcasting is in non-trainable parameters. This is a common scenario in
QML applications, where one broadcasts across data points rather than
across trainable parameters; see the sketch below.
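
As an illustration, here is a minimal sketch mirroring the new tests and assuming the tape-level `qml.gradients` interface (this snippet is not taken verbatim from the PR):

```python
import pennylane as qml

dev = qml.device("default.qubit")

# The RX rotation is broadcasted over two data points but is not trainable;
# only the RY angle (trainable_params=[0]) is differentiated.
x = [0.4, 0.2]
tape = qml.tape.QuantumScript(
    [qml.RY(0.6, 0), qml.RX(x, 0)],
    [qml.expval(qml.PauliZ(0))],
    trainable_params=[0],
)

# Previously this raised a NotImplementedError; now it returns one gradient
# entry per broadcasted data point.
grad_tapes, post_processing = qml.gradients.param_shift(tape)
grad = post_processing(dev.execute(grad_tapes))
```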

**Possible Drawbacks:**
Relies on `tape.trainable_params`, a property we would prefer not to
depend on too much in the future.
dwierichs committed Apr 17, 2024
1 parent 6e1f810 commit 9a03cce
Showing 15 changed files with 203 additions and 41 deletions.
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
@@ -73,6 +73,10 @@

<h3>Improvements 🛠</h3>

* Gradient transforms may now be applied to batched/broadcasted QNodes, as long as the
broadcasting is in non-trainable parameters.
[(#5452)](https://github.com/PennyLaneAI/pennylane/pull/5452)

* Improve the performance of computing the matrix of `qml.QFT`
[(#5351)](https://github.com/PennyLaneAI/pennylane/pull/5351)

4 changes: 2 additions & 2 deletions pennylane/gradients/finite_difference.py
@@ -34,7 +34,7 @@
from .general_shift_rules import generate_shifted_tapes
from .gradient_transform import (
_all_zero_grad,
assert_no_tape_batching,
assert_no_trainable_tape_batching,
choose_trainable_params,
find_and_validate_gradient_methods,
_no_trainable_grad,
@@ -369,7 +369,7 @@ def finite_diff(
"""

transform_name = "finite difference"
assert_no_tape_batching(tape, transform_name)
assert_no_trainable_tape_batching(tape, transform_name)

if any(qml.math.get_dtype_name(p) == "float32" for p in tape.get_parameters()):
warn(
18 changes: 12 additions & 6 deletions pennylane/gradients/gradient_transform.py
@@ -99,18 +99,24 @@ def assert_no_variance(measurements, transform_name):
)


def assert_no_tape_batching(tape, transform_name):
def assert_no_trainable_tape_batching(tape, transform_name):
"""Check whether a tape is broadcasted and raise an error if this is the case.
Args:
tape (`~.QuantumScript`): measurements to analyze
transform_name (str): Name of the gradient transform that queries the tape
"""
if tape.batch_size is not None:
raise NotImplementedError(
f"Computing the gradient of broadcasted tapes with the {transform_name} "
"gradient transform is currently not supported. See #4462 for details."
)
if tape.batch_size is None:
return

# Iterate over trainable parameters and check the affiliated operations for batching
for idx in range(len(tape.trainable_params)):
if tape.get_operation(idx)[0].batch_size is not None:
raise NotImplementedError(
"Computing the gradient of broadcasted tapes with respect to the broadcasted "
f"parameters using the {transform_name} gradient transform is currently not "
"supported. See #4462 for details."
)


def choose_trainable_params(tape, argnum=None):
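
A minimal sketch (not part of the committed diff) of how the updated check behaves, reusing the `QuantumScript` pattern from the new tests further below:

```python
import pennylane as qml
from pennylane.gradients.gradient_transform import assert_no_trainable_tape_batching

ops = [qml.RY(0.6, 0), qml.RX([0.4, 0.2], 0)]  # RX is broadcasted over two values
measurements = [qml.expval(qml.PauliZ(0))]

# With the broadcasted RX marked trainable, the check still raises.
try:
    trainable_rx = qml.tape.QuantumScript(ops, measurements, trainable_params=[1])
    assert_no_trainable_tape_batching(trainable_rx, "parameter-shift rule")
except NotImplementedError as exc:
    print(exc)  # broadcasted *trainable* parameters remain unsupported

# Only RY is trainable, so the broadcasted RX no longer triggers the error.
nontrainable_rx = qml.tape.QuantumScript(ops, measurements, trainable_params=[0])
assert_no_trainable_tape_batching(nontrainable_rx, "parameter-shift rule")  # passes silently
```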
4 changes: 2 additions & 2 deletions pennylane/gradients/hadamard_gradient.py
@@ -28,7 +28,7 @@
from .gradient_transform import (
_all_zero_grad,
assert_no_state_returns,
assert_no_tape_batching,
assert_no_trainable_tape_batching,
assert_no_variance,
choose_trainable_params,
find_and_validate_gradient_methods,
@@ -234,7 +234,7 @@ def hadamard_grad(
transform_name = "Hadamard test"
assert_no_state_returns(tape.measurements, transform_name)
assert_no_variance(tape.measurements, transform_name)
assert_no_tape_batching(tape, transform_name)
assert_no_trainable_tape_batching(tape, transform_name)
if len(tape.measurements) > 1 and tape.shots.has_partitioned_shots:
raise NotImplementedError(
"hadamard gradient does not support multiple measurements with partitioned shots."
5 changes: 2 additions & 3 deletions pennylane/gradients/parameter_shift.py
@@ -37,7 +37,7 @@
from .gradient_transform import (
_all_zero_grad,
assert_no_state_returns,
assert_no_tape_batching,
assert_no_trainable_tape_batching,
assert_multimeasure_not_broadcasted,
choose_trainable_params,
find_and_validate_gradient_methods,
@@ -727,7 +727,7 @@ def var_param_shift(tape, argnum, shifts=None, gradient_recipes=None, f0=None, b

pdA2_fn = None
if non_involutory_indices:

new_measurements = list(tape.measurements)
for i in non_involutory_indices:
# We need to calculate d<A^2>/dp; to do so, we replace the
@@ -1078,7 +1077,7 @@ def param_shift(
transform_name = "parameter-shift rule"
assert_no_state_returns(tape.measurements, transform_name)
assert_multimeasure_not_broadcasted(tape.measurements, broadcast)
assert_no_tape_batching(tape, transform_name)
assert_no_trainable_tape_batching(tape, transform_name)

if argnum is None and not tape.trainable_params:
return _no_trainable_grad(tape)
4 changes: 2 additions & 2 deletions pennylane/gradients/pulse_gradient.py
@@ -29,7 +29,7 @@
from .gradient_transform import (
_all_zero_grad,
assert_no_state_returns,
assert_no_tape_batching,
assert_no_trainable_tape_batching,
assert_no_variance,
choose_trainable_params,
find_and_validate_gradient_methods,
@@ -608,7 +608,7 @@ def ansatz(params):
_assert_has_jax(transform_name)
assert_no_state_returns(tape.measurements, transform_name)
assert_no_variance(tape.measurements, transform_name)
assert_no_tape_batching(tape, transform_name)
assert_no_trainable_tape_batching(tape, transform_name)

if num_split_times < 1:
raise ValueError(
4 changes: 2 additions & 2 deletions pennylane/gradients/pulse_gradient_odegen.py
@@ -30,7 +30,7 @@
from .gradient_transform import (
_all_zero_grad,
assert_no_state_returns,
assert_no_tape_batching,
assert_no_trainable_tape_batching,
assert_no_variance,
choose_trainable_params,
find_and_validate_gradient_methods,
@@ -681,7 +681,7 @@ def circuit(params):
_assert_has_jax(transform_name)
assert_no_state_returns(tape.measurements, transform_name)
assert_no_variance(tape.measurements, transform_name)
assert_no_tape_batching(tape, transform_name)
assert_no_trainable_tape_batching(tape, transform_name)

if argnum is None and not tape.trainable_params:
return _no_trainable_grad(tape)
4 changes: 2 additions & 2 deletions pennylane/gradients/spsa_gradient.py
@@ -29,7 +29,7 @@
from .finite_difference import _processing_fn, finite_diff_coeffs
from .gradient_transform import (
_all_zero_grad,
assert_no_tape_batching,
assert_no_trainable_tape_batching,
choose_trainable_params,
find_and_validate_gradient_methods,
_no_trainable_grad,
@@ -292,7 +292,7 @@ def spsa_grad(
"""

transform_name = "SPSA"
assert_no_tape_batching(tape, transform_name)
assert_no_trainable_tape_batching(tape, transform_name)

if argnum is None and not tape.trainable_params:
return _no_trainable_grad(tape)
2 changes: 1 addition & 1 deletion pennylane/transforms/core/transform.py
@@ -180,7 +180,7 @@ def qnode_circuit(a):
"The expand transform must have the same signature as the transform"
)

# 3: CHeck the classical co-transform
# 3: Check the classical co-transform
if classical_cotransform is not None and not callable(classical_cotransform):
raise TransformError("The classical co-transform must be a valid Python function.")

28 changes: 25 additions & 3 deletions tests/gradients/core/test_hadamard_gradient.py
@@ -66,13 +66,35 @@ def cost6(x):
class TestHadamardGrad:
"""Unit tests for the hadamard_grad function"""

def test_batched_tape_raises(self):
"""Test that an error is raised for a broadcasted/batched tape."""
def test_trainable_batched_tape_raises(self):
"""Test that an error is raised for a broadcasted/batched tape if the broadcasted
parameter is differentiated."""
tape = qml.tape.QuantumScript([qml.RX([0.4, 0.2], 0)], [qml.expval(qml.PauliZ(0))])
_match = "Computing the gradient of broadcasted tapes with the Hadamard test gradient"
_match = r"Computing the gradient of broadcasted tapes .* using the Hadamard test gradient"
with pytest.raises(NotImplementedError, match=_match):
qml.gradients.hadamard_grad(tape)

def test_nontrainable_batched_tape(self):
"""Test that no error is raised for a broadcasted/batched tape if the broadcasted
parameter is not differentiated, and that the results correspond to the stacked
results of the single-tape derivatives."""
dev = qml.device("default.qubit")
x = [0.4, 0.2]
tape = qml.tape.QuantumScript(
[qml.RY(0.6, 0), qml.RX(x, 0)], [qml.expval(qml.PauliZ(0))], trainable_params=[0]
)
batched_tapes, batched_fn = qml.gradients.hadamard_grad(tape)
batched_grad = batched_fn(dev.execute(batched_tapes))
separate_tapes = [
qml.tape.QuantumScript(
[qml.RY(0.6, 0), qml.RX(_x, 0)], [qml.expval(qml.PauliZ(0))], trainable_params=[0]
)
for _x in x
]
separate_tapes_and_fns = [qml.gradients.hadamard_grad(t) for t in separate_tapes]
separate_grad = [_fn(dev.execute(_tapes)) for _tapes, _fn in separate_tapes_and_fns]
assert np.allclose(batched_grad, separate_grad)

def test_tape_with_partitioned_shots_multiple_measurements_raises(self):
"""Test that an error is raised with multiple measurements and partitioned shots."""
tape = qml.tape.QuantumScript(
35 changes: 32 additions & 3 deletions tests/gradients/core/test_pulse_gradient.py
@@ -752,13 +752,42 @@ def test_raises_for_less_than_one_sample(self, num_split_times):
with pytest.raises(ValueError, match="Expected a positive number of samples"):
stoch_pulse_grad(tape, num_split_times=num_split_times)

def test_batched_tape_raises(self):
"""Test that an error is raised for a broadcasted/batched tape."""
def test_trainable_batched_tape_raises(self):
"""Test that an error is raised for a broadcasted/batched tape if the broadcasted
parameter is differentiated."""
tape = qml.tape.QuantumScript([qml.RX([0.4, 0.2], 0)], [qml.expval(qml.PauliZ(0))])
_match = "Computing the gradient of broadcasted tapes with the stochastic pulse"
_match = r"Computing the gradient of broadcasted tapes .* using the stochastic pulse"
with pytest.raises(NotImplementedError, match=_match):
stoch_pulse_grad(tape)

def test_nontrainable_batched_tape(self):
"""Test that no error is raised for a broadcasted/batched tape if the broadcasted
parameter is not differentiated, and that the results correspond to the stacked
results of the single-tape derivatives."""
import jax.numpy as jnp

dev = qml.device("default.qubit")
x = [0.4, 0.2]
params = [jnp.array(0.14)]
ham_single_q_const = qml.pulse.constant * qml.PauliY(0)
op = qml.evolve(ham_single_q_const)(params, 0.1)
tape = qml.tape.QuantumScript(
[qml.RX(x, 0), op], [qml.expval(qml.PauliZ(0))], trainable_params=[1]
)
batched_tapes, batched_fn = stoch_pulse_grad(tape, argnum=0, num_split_times=1)
batched_grad = batched_fn(dev.execute(batched_tapes))
separate_tapes = [
qml.tape.QuantumScript(
[qml.RX(_x, 0), op], [qml.expval(qml.PauliZ(0))], trainable_params=[1]
)
for _x in x
]
separate_tapes_and_fns = [
stoch_pulse_grad(t, argnum=0, num_split_times=1) for t in separate_tapes
]
separate_grad = [_fn(dev.execute(_tapes)) for _tapes, _fn in separate_tapes_and_fns]
assert np.allclose(batched_grad, separate_grad)

@pytest.mark.parametrize("num_meas", [0, 1, 2])
def test_warning_no_trainable_params(self, num_meas):
"""Test that an empty gradient is returned when there are no trainable parameters."""
33 changes: 30 additions & 3 deletions tests/gradients/core/test_pulse_odegen.py
@@ -824,13 +824,40 @@ def test_raises_with_invalid_op(self):
with pytest.raises(ValueError, match=_match):
pulse_odegen(tape)

def test_batched_tape_raises(self):
"""Test that an error is raised for a broadcasted/batched tape."""
def test_trainable_batched_tape_raises(self):
"""Test that an error is raised for a broadcasted/batched tape if the broadcasted
parameter is differentiated."""
tape = qml.tape.QuantumScript([qml.RX([0.4, 0.2], 0)], [qml.expval(qml.PauliZ(0))])
_match = "Computing the gradient of broadcasted tapes with the pulse generator"
_match = r"Computing the gradient of broadcasted tapes .* using the pulse generator"
with pytest.raises(NotImplementedError, match=_match):
pulse_odegen(tape)

def test_nontrainable_batched_tape(self):
"""Test that no error is raised for a broadcasted/batched tape if the broadcasted
parameter is not differentiated, and that the results correspond to the stacked
results of the single-tape derivatives."""
import jax.numpy as jnp

dev = qml.device("default.qubit")
x = [0.4, 0.2]
params = [jnp.array(0.14)]
ham_single_q_const = qml.pulse.constant * qml.PauliY(0)
op = qml.evolve(ham_single_q_const)(params, 0.1)
tape = qml.tape.QuantumScript(
[qml.RX(x, 0), op], [qml.expval(qml.PauliZ(0))], trainable_params=[1]
)
batched_tapes, batched_fn = pulse_odegen(tape, argnum=0)
batched_grad = batched_fn(dev.execute(batched_tapes))
separate_tapes = [
qml.tape.QuantumScript(
[qml.RX(_x, 0), op], [qml.expval(qml.PauliZ(0))], trainable_params=[1]
)
for _x in x
]
separate_tapes_and_fns = [pulse_odegen(t, argnum=0) for t in separate_tapes]
separate_grad = [_fn(dev.execute(_tapes)) for _tapes, _fn in separate_tapes_and_fns]
assert np.allclose(batched_grad, separate_grad)

def test_no_trainable_params_tape(self):
"""Test that the correct ouput and warning is generated in the absence of any trainable
parameters"""
28 changes: 25 additions & 3 deletions tests/gradients/finite_diff/test_finite_difference.py
@@ -99,13 +99,35 @@ def test_correct_second_derivative_center_order4(self):
class TestFiniteDiff:
"""Tests for the finite difference gradient transform"""

def test_batched_tape_raises(self):
"""Test that an error is raised for a broadcasted/batched tape."""
def test_trainable_batched_tape_raises(self):
"""Test that an error is raised for a broadcasted/batched tape if the broadcasted
parameter is differentiated."""
tape = qml.tape.QuantumScript([qml.RX([0.4, 0.2], 0)], [qml.expval(qml.PauliZ(0))])
_match = "Computing the gradient of broadcasted tapes with the finite difference"
_match = r"Computing the gradient of broadcasted tapes .* using the finite difference"
with pytest.raises(NotImplementedError, match=_match):
finite_diff(tape)

def test_nontrainable_batched_tape(self):
"""Test that no error is raised for a broadcasted/batched tape if the broadcasted
parameter is not differentiated, and that the results correspond to the stacked
results of the single-tape derivatives."""
dev = qml.device("default.qubit")
x = [0.4, 0.2]
tape = qml.tape.QuantumScript(
[qml.RY(0.6, 0), qml.RX(x, 0)], [qml.expval(qml.PauliZ(0))], trainable_params=[0]
)
batched_tapes, batched_fn = finite_diff(tape)
batched_grad = batched_fn(dev.execute(batched_tapes))
separate_tapes = [
qml.tape.QuantumScript(
[qml.RY(0.6, 0), qml.RX(_x, 0)], [qml.expval(qml.PauliZ(0))], trainable_params=[0]
)
for _x in x
]
separate_tapes_and_fns = [finite_diff(t) for t in separate_tapes]
separate_grad = [_fn(dev.execute(_tapes)) for _tapes, _fn in separate_tapes_and_fns]
assert np.allclose(batched_grad, separate_grad)

def test_non_differentiable_error(self):
"""Test error raised if attempting to differentiate with
respect to a non-differentiable argument"""
28 changes: 25 additions & 3 deletions tests/gradients/finite_diff/test_spsa_gradient.py
@@ -170,13 +170,35 @@ def circuit(param):
with pytest.raises(ValueError, match=expected_message):
qml.grad(circuit)(np.array(1.0))

def test_batched_tape_raises(self):
"""Test that an error is raised for a broadcasted/batched tape."""
def test_trainable_batched_tape_raises(self):
"""Test that an error is raised for a broadcasted/batched tape if the broadcasted
parameter is differentiated."""
tape = qml.tape.QuantumScript([qml.RX([0.4, 0.2], 0)], [qml.expval(qml.PauliZ(0))])
_match = "Computing the gradient of broadcasted tapes with the SPSA gradient transform"
_match = r"Computing the gradient of broadcasted tapes .* using the SPSA gradient transform"
with pytest.raises(NotImplementedError, match=_match):
spsa_grad(tape)

def test_nontrainable_batched_tape(self):
"""Test that no error is raised for a broadcasted/batched tape if the broadcasted
parameter is not differentiated, and that the results correspond to the stacked
results of the single-tape derivatives."""
dev = qml.device("default.qubit")
x = [0.4, 0.2]
tape = qml.tape.QuantumScript(
[qml.RY(0.6, 0), qml.RX(x, 0)], [qml.expval(qml.PauliZ(0))], trainable_params=[0]
)
batched_tapes, batched_fn = spsa_grad(tape)
batched_grad = batched_fn(dev.execute(batched_tapes))
separate_tapes = [
qml.tape.QuantumScript(
[qml.RY(0.6, 0), qml.RX(_x, 0)], [qml.expval(qml.PauliZ(0))], trainable_params=[0]
)
for _x in x
]
separate_tapes_and_fns = [spsa_grad(t) for t in separate_tapes]
separate_grad = [_fn(dev.execute(_tapes)) for _tapes, _fn in separate_tapes_and_fns]
assert np.allclose(batched_grad, separate_grad)

def test_non_differentiable_error(self):
"""Test error raised if attempting to differentiate with
respect to a non-differentiable argument"""