diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 6881aa71e71..6256d2116af 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -73,6 +73,10 @@
Improvements ðŸ›
+* Gradient transforms may now be applied to batched/broadcasted QNodes, as long as the
+ broadcasting is in non-trainable parameters.
+ [(#5452)](https://github.com/PennyLaneAI/pennylane/pull/5452)
+
* Improve the performance of computing the matrix of `qml.QFT`
[(#5351)](https://github.com/PennyLaneAI/pennylane/pull/5351)
diff --git a/pennylane/gradients/finite_difference.py b/pennylane/gradients/finite_difference.py
index 2b33831b8df..615c92ccac0 100644
--- a/pennylane/gradients/finite_difference.py
+++ b/pennylane/gradients/finite_difference.py
@@ -34,7 +34,7 @@
from .general_shift_rules import generate_shifted_tapes
from .gradient_transform import (
_all_zero_grad,
- assert_no_tape_batching,
+ assert_no_trainable_tape_batching,
choose_trainable_params,
find_and_validate_gradient_methods,
_no_trainable_grad,
@@ -369,7 +369,7 @@ def finite_diff(
"""
transform_name = "finite difference"
- assert_no_tape_batching(tape, transform_name)
+ assert_no_trainable_tape_batching(tape, transform_name)
if any(qml.math.get_dtype_name(p) == "float32" for p in tape.get_parameters()):
warn(
diff --git a/pennylane/gradients/gradient_transform.py b/pennylane/gradients/gradient_transform.py
index cfb50b3f2c4..691f28dccd0 100644
--- a/pennylane/gradients/gradient_transform.py
+++ b/pennylane/gradients/gradient_transform.py
@@ -99,18 +99,24 @@ def assert_no_variance(measurements, transform_name):
)
-def assert_no_tape_batching(tape, transform_name):
+def assert_no_trainable_tape_batching(tape, transform_name):
"""Check whether a tape is broadcasted and raise an error if this is the case.
Args:
tape (`~.QuantumScript`): measurements to analyze
transform_name (str): Name of the gradient transform that queries the tape
"""
- if tape.batch_size is not None:
- raise NotImplementedError(
- f"Computing the gradient of broadcasted tapes with the {transform_name} "
- "gradient transform is currently not supported. See #4462 for details."
- )
+ if tape.batch_size is None:
+ return
+
+ # Iterate over trainable parameters and check the affiliated operations for batching
+ for idx in range(len(tape.trainable_params)):
+ if tape.get_operation(idx)[0].batch_size is not None:
+ raise NotImplementedError(
+ "Computing the gradient of broadcasted tapes with respect to the broadcasted "
+ f"parameters using the {transform_name} gradient transform is currently not "
+ "supported. See #4462 for details."
+ )
def choose_trainable_params(tape, argnum=None):
diff --git a/pennylane/gradients/hadamard_gradient.py b/pennylane/gradients/hadamard_gradient.py
index dea644caf71..229347636e6 100644
--- a/pennylane/gradients/hadamard_gradient.py
+++ b/pennylane/gradients/hadamard_gradient.py
@@ -28,7 +28,7 @@
from .gradient_transform import (
_all_zero_grad,
assert_no_state_returns,
- assert_no_tape_batching,
+ assert_no_trainable_tape_batching,
assert_no_variance,
choose_trainable_params,
find_and_validate_gradient_methods,
@@ -234,7 +234,7 @@ def hadamard_grad(
transform_name = "Hadamard test"
assert_no_state_returns(tape.measurements, transform_name)
assert_no_variance(tape.measurements, transform_name)
- assert_no_tape_batching(tape, transform_name)
+ assert_no_trainable_tape_batching(tape, transform_name)
if len(tape.measurements) > 1 and tape.shots.has_partitioned_shots:
raise NotImplementedError(
"hadamard gradient does not support multiple measurements with partitioned shots."
diff --git a/pennylane/gradients/parameter_shift.py b/pennylane/gradients/parameter_shift.py
index 739c92cdc61..1db25ba0d11 100644
--- a/pennylane/gradients/parameter_shift.py
+++ b/pennylane/gradients/parameter_shift.py
@@ -37,7 +37,7 @@
from .gradient_transform import (
_all_zero_grad,
assert_no_state_returns,
- assert_no_tape_batching,
+ assert_no_trainable_tape_batching,
assert_multimeasure_not_broadcasted,
choose_trainable_params,
find_and_validate_gradient_methods,
@@ -727,7 +727,6 @@ def var_param_shift(tape, argnum, shifts=None, gradient_recipes=None, f0=None, b
pdA2_fn = None
if non_involutory_indices:
-
new_measurements = list(tape.measurements)
for i in non_involutory_indices:
# We need to calculate d/dp; to do so, we replace the
@@ -1078,7 +1077,7 @@ def param_shift(
transform_name = "parameter-shift rule"
assert_no_state_returns(tape.measurements, transform_name)
assert_multimeasure_not_broadcasted(tape.measurements, broadcast)
- assert_no_tape_batching(tape, transform_name)
+ assert_no_trainable_tape_batching(tape, transform_name)
if argnum is None and not tape.trainable_params:
return _no_trainable_grad(tape)
diff --git a/pennylane/gradients/pulse_gradient.py b/pennylane/gradients/pulse_gradient.py
index ee3f32546ce..ed52cd3503b 100644
--- a/pennylane/gradients/pulse_gradient.py
+++ b/pennylane/gradients/pulse_gradient.py
@@ -29,7 +29,7 @@
from .gradient_transform import (
_all_zero_grad,
assert_no_state_returns,
- assert_no_tape_batching,
+ assert_no_trainable_tape_batching,
assert_no_variance,
choose_trainable_params,
find_and_validate_gradient_methods,
@@ -608,7 +608,7 @@ def ansatz(params):
_assert_has_jax(transform_name)
assert_no_state_returns(tape.measurements, transform_name)
assert_no_variance(tape.measurements, transform_name)
- assert_no_tape_batching(tape, transform_name)
+ assert_no_trainable_tape_batching(tape, transform_name)
if num_split_times < 1:
raise ValueError(
diff --git a/pennylane/gradients/pulse_gradient_odegen.py b/pennylane/gradients/pulse_gradient_odegen.py
index 0c2384ca3ca..da2dbf0586a 100644
--- a/pennylane/gradients/pulse_gradient_odegen.py
+++ b/pennylane/gradients/pulse_gradient_odegen.py
@@ -30,7 +30,7 @@
from .gradient_transform import (
_all_zero_grad,
assert_no_state_returns,
- assert_no_tape_batching,
+ assert_no_trainable_tape_batching,
assert_no_variance,
choose_trainable_params,
find_and_validate_gradient_methods,
@@ -681,7 +681,7 @@ def circuit(params):
_assert_has_jax(transform_name)
assert_no_state_returns(tape.measurements, transform_name)
assert_no_variance(tape.measurements, transform_name)
- assert_no_tape_batching(tape, transform_name)
+ assert_no_trainable_tape_batching(tape, transform_name)
if argnum is None and not tape.trainable_params:
return _no_trainable_grad(tape)
diff --git a/pennylane/gradients/spsa_gradient.py b/pennylane/gradients/spsa_gradient.py
index c7d25ccaff7..d5b84629874 100644
--- a/pennylane/gradients/spsa_gradient.py
+++ b/pennylane/gradients/spsa_gradient.py
@@ -29,7 +29,7 @@
from .finite_difference import _processing_fn, finite_diff_coeffs
from .gradient_transform import (
_all_zero_grad,
- assert_no_tape_batching,
+ assert_no_trainable_tape_batching,
choose_trainable_params,
find_and_validate_gradient_methods,
_no_trainable_grad,
@@ -292,7 +292,7 @@ def spsa_grad(
"""
transform_name = "SPSA"
- assert_no_tape_batching(tape, transform_name)
+ assert_no_trainable_tape_batching(tape, transform_name)
if argnum is None and not tape.trainable_params:
return _no_trainable_grad(tape)
diff --git a/pennylane/transforms/core/transform.py b/pennylane/transforms/core/transform.py
index 35c19f96f70..00a9f91503c 100644
--- a/pennylane/transforms/core/transform.py
+++ b/pennylane/transforms/core/transform.py
@@ -180,7 +180,7 @@ def qnode_circuit(a):
"The expand transform must have the same signature as the transform"
)
- # 3: CHeck the classical co-transform
+ # 3: Check the classical co-transform
if classical_cotransform is not None and not callable(classical_cotransform):
raise TransformError("The classical co-transform must be a valid Python function.")
diff --git a/tests/gradients/core/test_hadamard_gradient.py b/tests/gradients/core/test_hadamard_gradient.py
index e04c6abead2..72e4c41eabc 100644
--- a/tests/gradients/core/test_hadamard_gradient.py
+++ b/tests/gradients/core/test_hadamard_gradient.py
@@ -66,13 +66,35 @@ def cost6(x):
class TestHadamardGrad:
"""Unit tests for the hadamard_grad function"""
- def test_batched_tape_raises(self):
- """Test that an error is raised for a broadcasted/batched tape."""
+ def test_trainable_batched_tape_raises(self):
+ """Test that an error is raised for a broadcasted/batched tape if the broadcasted
+ parameter is differentiated."""
tape = qml.tape.QuantumScript([qml.RX([0.4, 0.2], 0)], [qml.expval(qml.PauliZ(0))])
- _match = "Computing the gradient of broadcasted tapes with the Hadamard test gradient"
+ _match = r"Computing the gradient of broadcasted tapes .* using the Hadamard test gradient"
with pytest.raises(NotImplementedError, match=_match):
qml.gradients.hadamard_grad(tape)
+ def test_nontrainable_batched_tape(self):
+ """Test that no error is raised for a broadcasted/batched tape if the broadcasted
+ parameter is not differentiated, and that the results correspond to the stacked
+ results of the single-tape derivatives."""
+ dev = qml.device("default.qubit")
+ x = [0.4, 0.2]
+ tape = qml.tape.QuantumScript(
+ [qml.RY(0.6, 0), qml.RX(x, 0)], [qml.expval(qml.PauliZ(0))], trainable_params=[0]
+ )
+ batched_tapes, batched_fn = qml.gradients.hadamard_grad(tape)
+ batched_grad = batched_fn(dev.execute(batched_tapes))
+ separate_tapes = [
+ qml.tape.QuantumScript(
+ [qml.RY(0.6, 0), qml.RX(_x, 0)], [qml.expval(qml.PauliZ(0))], trainable_params=[0]
+ )
+ for _x in x
+ ]
+ separate_tapes_and_fns = [qml.gradients.hadamard_grad(t) for t in separate_tapes]
+ separate_grad = [_fn(dev.execute(_tapes)) for _tapes, _fn in separate_tapes_and_fns]
+ assert np.allclose(batched_grad, separate_grad)
+
def test_tape_with_partitioned_shots_multiple_measurements_raises(self):
"""Test that an error is raised with multiple measurements and partitioned shots."""
tape = qml.tape.QuantumScript(
diff --git a/tests/gradients/core/test_pulse_gradient.py b/tests/gradients/core/test_pulse_gradient.py
index d3661e9cdf1..b95a36d67f9 100644
--- a/tests/gradients/core/test_pulse_gradient.py
+++ b/tests/gradients/core/test_pulse_gradient.py
@@ -752,13 +752,42 @@ def test_raises_for_less_than_one_sample(self, num_split_times):
with pytest.raises(ValueError, match="Expected a positive number of samples"):
stoch_pulse_grad(tape, num_split_times=num_split_times)
- def test_batched_tape_raises(self):
- """Test that an error is raised for a broadcasted/batched tape."""
+ def test_trainable_batched_tape_raises(self):
+ """Test that an error is raised for a broadcasted/batched tape if the broadcasted
+ parameter is differentiated."""
tape = qml.tape.QuantumScript([qml.RX([0.4, 0.2], 0)], [qml.expval(qml.PauliZ(0))])
- _match = "Computing the gradient of broadcasted tapes with the stochastic pulse"
+ _match = r"Computing the gradient of broadcasted tapes .* using the stochastic pulse"
with pytest.raises(NotImplementedError, match=_match):
stoch_pulse_grad(tape)
+ def test_nontrainable_batched_tape(self):
+ """Test that no error is raised for a broadcasted/batched tape if the broadcasted
+ parameter is not differentiated, and that the results correspond to the stacked
+ results of the single-tape derivatives."""
+ import jax.numpy as jnp
+
+ dev = qml.device("default.qubit")
+ x = [0.4, 0.2]
+ params = [jnp.array(0.14)]
+ ham_single_q_const = qml.pulse.constant * qml.PauliY(0)
+ op = qml.evolve(ham_single_q_const)(params, 0.1)
+ tape = qml.tape.QuantumScript(
+ [qml.RX(x, 0), op], [qml.expval(qml.PauliZ(0))], trainable_params=[1]
+ )
+ batched_tapes, batched_fn = stoch_pulse_grad(tape, argnum=0, num_split_times=1)
+ batched_grad = batched_fn(dev.execute(batched_tapes))
+ separate_tapes = [
+ qml.tape.QuantumScript(
+ [qml.RX(_x, 0), op], [qml.expval(qml.PauliZ(0))], trainable_params=[1]
+ )
+ for _x in x
+ ]
+ separate_tapes_and_fns = [
+ stoch_pulse_grad(t, argnum=0, num_split_times=1) for t in separate_tapes
+ ]
+ separate_grad = [_fn(dev.execute(_tapes)) for _tapes, _fn in separate_tapes_and_fns]
+ assert np.allclose(batched_grad, separate_grad)
+
@pytest.mark.parametrize("num_meas", [0, 1, 2])
def test_warning_no_trainable_params(self, num_meas):
"""Test that an empty gradient is returned when there are no trainable parameters."""
diff --git a/tests/gradients/core/test_pulse_odegen.py b/tests/gradients/core/test_pulse_odegen.py
index f13fddefff4..faac48f9e02 100644
--- a/tests/gradients/core/test_pulse_odegen.py
+++ b/tests/gradients/core/test_pulse_odegen.py
@@ -824,13 +824,40 @@ def test_raises_with_invalid_op(self):
with pytest.raises(ValueError, match=_match):
pulse_odegen(tape)
- def test_batched_tape_raises(self):
- """Test that an error is raised for a broadcasted/batched tape."""
+ def test_trainable_batched_tape_raises(self):
+ """Test that an error is raised for a broadcasted/batched tape if the broadcasted
+ parameter is differentiated."""
tape = qml.tape.QuantumScript([qml.RX([0.4, 0.2], 0)], [qml.expval(qml.PauliZ(0))])
- _match = "Computing the gradient of broadcasted tapes with the pulse generator"
+ _match = r"Computing the gradient of broadcasted tapes .* using the pulse generator"
with pytest.raises(NotImplementedError, match=_match):
pulse_odegen(tape)
+ def test_nontrainable_batched_tape(self):
+ """Test that no error is raised for a broadcasted/batched tape if the broadcasted
+ parameter is not differentiated, and that the results correspond to the stacked
+ results of the single-tape derivatives."""
+ import jax.numpy as jnp
+
+ dev = qml.device("default.qubit")
+ x = [0.4, 0.2]
+ params = [jnp.array(0.14)]
+ ham_single_q_const = qml.pulse.constant * qml.PauliY(0)
+ op = qml.evolve(ham_single_q_const)(params, 0.1)
+ tape = qml.tape.QuantumScript(
+ [qml.RX(x, 0), op], [qml.expval(qml.PauliZ(0))], trainable_params=[1]
+ )
+ batched_tapes, batched_fn = pulse_odegen(tape, argnum=0)
+ batched_grad = batched_fn(dev.execute(batched_tapes))
+ separate_tapes = [
+ qml.tape.QuantumScript(
+ [qml.RX(_x, 0), op], [qml.expval(qml.PauliZ(0))], trainable_params=[1]
+ )
+ for _x in x
+ ]
+ separate_tapes_and_fns = [pulse_odegen(t, argnum=0) for t in separate_tapes]
+ separate_grad = [_fn(dev.execute(_tapes)) for _tapes, _fn in separate_tapes_and_fns]
+ assert np.allclose(batched_grad, separate_grad)
+
def test_no_trainable_params_tape(self):
"""Test that the correct ouput and warning is generated in the absence of any trainable
parameters"""
diff --git a/tests/gradients/finite_diff/test_finite_difference.py b/tests/gradients/finite_diff/test_finite_difference.py
index 25eca0648ef..43dd856170d 100644
--- a/tests/gradients/finite_diff/test_finite_difference.py
+++ b/tests/gradients/finite_diff/test_finite_difference.py
@@ -99,13 +99,35 @@ def test_correct_second_derivative_center_order4(self):
class TestFiniteDiff:
"""Tests for the finite difference gradient transform"""
- def test_batched_tape_raises(self):
- """Test that an error is raised for a broadcasted/batched tape."""
+ def test_trainable_batched_tape_raises(self):
+ """Test that an error is raised for a broadcasted/batched tape if the broadcasted
+ parameter is differentiated."""
tape = qml.tape.QuantumScript([qml.RX([0.4, 0.2], 0)], [qml.expval(qml.PauliZ(0))])
- _match = "Computing the gradient of broadcasted tapes with the finite difference"
+ _match = r"Computing the gradient of broadcasted tapes .* using the finite difference"
with pytest.raises(NotImplementedError, match=_match):
finite_diff(tape)
+ def test_nontrainable_batched_tape(self):
+ """Test that no error is raised for a broadcasted/batched tape if the broadcasted
+ parameter is not differentiated, and that the results correspond to the stacked
+ results of the single-tape derivatives."""
+ dev = qml.device("default.qubit")
+ x = [0.4, 0.2]
+ tape = qml.tape.QuantumScript(
+ [qml.RY(0.6, 0), qml.RX(x, 0)], [qml.expval(qml.PauliZ(0))], trainable_params=[0]
+ )
+ batched_tapes, batched_fn = finite_diff(tape)
+ batched_grad = batched_fn(dev.execute(batched_tapes))
+ separate_tapes = [
+ qml.tape.QuantumScript(
+ [qml.RY(0.6, 0), qml.RX(_x, 0)], [qml.expval(qml.PauliZ(0))], trainable_params=[0]
+ )
+ for _x in x
+ ]
+ separate_tapes_and_fns = [finite_diff(t) for t in separate_tapes]
+ separate_grad = [_fn(dev.execute(_tapes)) for _tapes, _fn in separate_tapes_and_fns]
+ assert np.allclose(batched_grad, separate_grad)
+
def test_non_differentiable_error(self):
"""Test error raised if attempting to differentiate with
respect to a non-differentiable argument"""
diff --git a/tests/gradients/finite_diff/test_spsa_gradient.py b/tests/gradients/finite_diff/test_spsa_gradient.py
index 5f0950b7667..91ea1c89842 100644
--- a/tests/gradients/finite_diff/test_spsa_gradient.py
+++ b/tests/gradients/finite_diff/test_spsa_gradient.py
@@ -170,13 +170,35 @@ def circuit(param):
with pytest.raises(ValueError, match=expected_message):
qml.grad(circuit)(np.array(1.0))
- def test_batched_tape_raises(self):
- """Test that an error is raised for a broadcasted/batched tape."""
+ def test_trainable_batched_tape_raises(self):
+ """Test that an error is raised for a broadcasted/batched tape if the broadcasted
+ parameter is differentiated."""
tape = qml.tape.QuantumScript([qml.RX([0.4, 0.2], 0)], [qml.expval(qml.PauliZ(0))])
- _match = "Computing the gradient of broadcasted tapes with the SPSA gradient transform"
+ _match = r"Computing the gradient of broadcasted tapes .* using the SPSA gradient transform"
with pytest.raises(NotImplementedError, match=_match):
spsa_grad(tape)
+ def test_nontrainable_batched_tape(self):
+ """Test that no error is raised for a broadcasted/batched tape if the broadcasted
+ parameter is not differentiated, and that the results correspond to the stacked
+ results of the single-tape derivatives."""
+ dev = qml.device("default.qubit")
+ x = [0.4, 0.2]
+ tape = qml.tape.QuantumScript(
+ [qml.RY(0.6, 0), qml.RX(x, 0)], [qml.expval(qml.PauliZ(0))], trainable_params=[0]
+ )
+ batched_tapes, batched_fn = spsa_grad(tape)
+ batched_grad = batched_fn(dev.execute(batched_tapes))
+ separate_tapes = [
+ qml.tape.QuantumScript(
+ [qml.RY(0.6, 0), qml.RX(_x, 0)], [qml.expval(qml.PauliZ(0))], trainable_params=[0]
+ )
+ for _x in x
+ ]
+ separate_tapes_and_fns = [spsa_grad(t) for t in separate_tapes]
+ separate_grad = [_fn(dev.execute(_tapes)) for _tapes, _fn in separate_tapes_and_fns]
+ assert np.allclose(batched_grad, separate_grad)
+
def test_non_differentiable_error(self):
"""Test error raised if attempting to differentiate with
respect to a non-differentiable argument"""
diff --git a/tests/gradients/parameter_shift/test_parameter_shift.py b/tests/gradients/parameter_shift/test_parameter_shift.py
index 847f95961ec..8a7e9bc675c 100644
--- a/tests/gradients/parameter_shift/test_parameter_shift.py
+++ b/tests/gradients/parameter_shift/test_parameter_shift.py
@@ -830,24 +830,25 @@ class TestParamShiftRaisesWithBroadcasted:
"""Test that an error is raised with broadcasted tapes."""
def test_batched_tape_raises(self):
- """Test that an error is raised for a broadcasted/batched tape."""
+ """Test that an error is raised for a broadcasted/batched tape if the broadcasted
+ parameter is differentiated."""
tape = qml.tape.QuantumScript([qml.RX([0.4, 0.2], 0)], [qml.expval(qml.PauliZ(0))])
- _match = "Computing the gradient of broadcasted tapes with the parameter-shift rule"
+ _match = r"Computing the gradient of broadcasted tapes .* using the parameter-shift rule"
with pytest.raises(NotImplementedError, match=_match):
qml.gradients.param_shift(tape)
-# Revert the following skip once broadcasted tapes are fully supported with gradient transforms.
-# See #4462 for details.
-@pytest.mark.skip(reason="Applying gradient transforms to broadcasted tapes is disallowed")
class TestParamShiftWithBroadcasted:
"""Tests for the `param_shift` transform on already broadcasted tapes.
The tests for `param_shift` using broadcasting itself can be found
further below."""
+ # Revert the following skip once broadcasted tapes are fully supported with gradient transforms.
+ # See #4462 for details.
+ @pytest.mark.skip(reason="Applying gradient transforms to broadcasted tapes is disallowed")
@pytest.mark.parametrize("dim", [1, 3])
@pytest.mark.parametrize("pos", [0, 1])
- def test_with_single_parameter_broadcasted(self, dim, pos):
+ def test_with_single_trainable_parameter_broadcasted(self, dim, pos):
"""Test that the parameter-shift transform works with a tape that has
one of its parameters broadcasted already."""
x = np.array([0.23, 9.1, 2.3])
@@ -875,6 +876,9 @@ def test_with_single_parameter_broadcasted(self, dim, pos):
assert res[0].shape == (dim,)
assert res[1].shape == (dim,)
+ # Revert the following skip once broadcasted tapes are fully supported with gradient transforms.
+ # See #4462 for details.
+ @pytest.mark.skip(reason="Applying gradient transforms to broadcasted tapes is disallowed")
@pytest.mark.parametrize("argnum", [(0, 2), (0, 1), (1,), (2,)])
@pytest.mark.parametrize("dim", [1, 3])
def test_with_multiple_parameters_broadcasted(self, dim, argnum):
@@ -902,6 +906,33 @@ def test_with_multiple_parameters_broadcasted(self, dim, argnum):
assert res[0].shape == res[1].shape == res[2].shape == (dim,)
+ @pytest.mark.parametrize("dim", [1, 3])
+ @pytest.mark.parametrize("pos", [0, 1])
+ def test_with_single_nontrainable_parameter_broadcasted(self, dim, pos):
+ """Test that the parameter-shift transform works with a tape that has
+ one of its nontrainable parameters broadcasted."""
+ x = np.array([0.23, 9.1, 2.3])
+ x = x[:dim]
+ y = -0.654
+ if pos == 1:
+ x, y = y, x
+
+ with qml.queuing.AnnotatedQueue() as q:
+ qml.RX(x, wires=[0])
+ qml.RY(y, wires=[0]) # does not have any impact on the expval
+ qml.expval(qml.PauliZ(0))
+
+ tape = qml.tape.QuantumScript.from_queue(q)
+ tape.trainable_params = [1 - pos]
+ assert tape.batch_size == dim
+ tapes, fn = qml.gradients.param_shift(tape, argnum=[0])
+ assert len(tapes) == 2
+ assert np.allclose([t.batch_size for t in tapes], dim)
+
+ dev = qml.device("default.qubit", wires=2)
+ res = fn(dev.execute(tapes))
+ assert res.shape == (dim,)
+
class TestParamShiftUsingBroadcasting:
"""Tests for the `param_shift` function using broadcasting.