diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 6881aa71e71..6256d2116af 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -73,6 +73,10 @@

Improvements 🛠

+* Gradient transforms may now be applied to batched/broadcasted QNodes, as long as the
+  broadcasting is in non-trainable parameters.
+  [(#5452)](https://github.com/PennyLaneAI/pennylane/pull/5452)
+
 * Improve the performance of computing the matrix of `qml.QFT`
   [(#5351)](https://github.com/PennyLaneAI/pennylane/pull/5351)
diff --git a/pennylane/gradients/finite_difference.py b/pennylane/gradients/finite_difference.py
index 2b33831b8df..615c92ccac0 100644
--- a/pennylane/gradients/finite_difference.py
+++ b/pennylane/gradients/finite_difference.py
@@ -34,7 +34,7 @@
 from .general_shift_rules import generate_shifted_tapes
 from .gradient_transform import (
     _all_zero_grad,
-    assert_no_tape_batching,
+    assert_no_trainable_tape_batching,
     choose_trainable_params,
     find_and_validate_gradient_methods,
     _no_trainable_grad,
@@ -369,7 +369,7 @@ def finite_diff(
     """
     transform_name = "finite difference"
-    assert_no_tape_batching(tape, transform_name)
+    assert_no_trainable_tape_batching(tape, transform_name)

     if any(qml.math.get_dtype_name(p) == "float32" for p in tape.get_parameters()):
         warn(
diff --git a/pennylane/gradients/gradient_transform.py b/pennylane/gradients/gradient_transform.py
index cfb50b3f2c4..691f28dccd0 100644
--- a/pennylane/gradients/gradient_transform.py
+++ b/pennylane/gradients/gradient_transform.py
@@ -99,18 +99,24 @@ def assert_no_variance(measurements, transform_name):
         )


-def assert_no_tape_batching(tape, transform_name):
+def assert_no_trainable_tape_batching(tape, transform_name):
     """Check whether a tape is broadcasted and raise an error if this is the case.

     Args:
         tape (`~.QuantumScript`): measurements to analyze
         transform_name (str): Name of the gradient transform that queries the tape
     """
-    if tape.batch_size is not None:
-        raise NotImplementedError(
-            f"Computing the gradient of broadcasted tapes with the {transform_name} "
-            "gradient transform is currently not supported. See #4462 for details."
-        )
+    if tape.batch_size is None:
+        return
+
+    # Iterate over trainable parameters and check the affiliated operations for batching
+    for idx in range(len(tape.trainable_params)):
+        if tape.get_operation(idx)[0].batch_size is not None:
+            raise NotImplementedError(
+                "Computing the gradient of broadcasted tapes with respect to the broadcasted "
+                f"parameters using the {transform_name} gradient transform is currently not "
+                "supported. See #4462 for details."
+            )


 def choose_trainable_params(tape, argnum=None):
diff --git a/pennylane/gradients/hadamard_gradient.py b/pennylane/gradients/hadamard_gradient.py
index dea644caf71..229347636e6 100644
--- a/pennylane/gradients/hadamard_gradient.py
+++ b/pennylane/gradients/hadamard_gradient.py
@@ -28,7 +28,7 @@
 from .gradient_transform import (
     _all_zero_grad,
     assert_no_state_returns,
-    assert_no_tape_batching,
+    assert_no_trainable_tape_batching,
     assert_no_variance,
     choose_trainable_params,
     find_and_validate_gradient_methods,
@@ -234,7 +234,7 @@ def hadamard_grad(
     transform_name = "Hadamard test"
     assert_no_state_returns(tape.measurements, transform_name)
     assert_no_variance(tape.measurements, transform_name)
-    assert_no_tape_batching(tape, transform_name)
+    assert_no_trainable_tape_batching(tape, transform_name)

     if len(tape.measurements) > 1 and tape.shots.has_partitioned_shots:
         raise NotImplementedError(
             "hadamard gradient does not support multiple measurements with partitioned shots."
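For orientation, here is an editorial sketch (not part of the patch) of how the renamed `assert_no_trainable_tape_batching` check behaves: it only rejects tapes whose *trainable* parameters are broadcasted, using `tape.get_operation(idx)` (as in the hunk above) to look up the operation owning each trainable parameter. The concrete angles below are illustrative.

```python
import pennylane as qml

# Broadcasting lives in RX, but only RY's angle is trainable (trainable_params=[0]),
# so the new check passes and gradient transforms may be applied to this tape.
tape = qml.tape.QuantumScript(
    [qml.RY(0.6, 0), qml.RX([0.4, 0.2], 0)],
    [qml.expval(qml.PauliZ(0))],
    trainable_params=[0],
)
assert tape.batch_size == 2
assert tape.get_operation(0)[0].batch_size is None  # operation owning trainable param 0 is RY

# If instead the broadcasted RX angle is marked trainable, the renamed error is raised.
tape.trainable_params = [1]
try:
    qml.gradients.param_shift(tape)
except NotImplementedError as exc:
    print(exc)  # "Computing the gradient of broadcasted tapes with respect to the broadcasted ..."
```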
diff --git a/pennylane/gradients/parameter_shift.py b/pennylane/gradients/parameter_shift.py
index 739c92cdc61..1db25ba0d11 100644
--- a/pennylane/gradients/parameter_shift.py
+++ b/pennylane/gradients/parameter_shift.py
@@ -37,7 +37,7 @@
 from .gradient_transform import (
     _all_zero_grad,
     assert_no_state_returns,
-    assert_no_tape_batching,
+    assert_no_trainable_tape_batching,
     assert_multimeasure_not_broadcasted,
     choose_trainable_params,
     find_and_validate_gradient_methods,
@@ -727,7 +727,6 @@ def var_param_shift(tape, argnum, shifts=None, gradient_recipes=None, f0=None, b
     pdA2_fn = None

     if non_involutory_indices:
-        new_measurements = list(tape.measurements)

         for i in non_involutory_indices:
             # We need to calculate d/dp; to do so, we replace the
@@ -1078,7 +1077,7 @@ def param_shift(
     transform_name = "parameter-shift rule"
     assert_no_state_returns(tape.measurements, transform_name)
     assert_multimeasure_not_broadcasted(tape.measurements, broadcast)
-    assert_no_tape_batching(tape, transform_name)
+    assert_no_trainable_tape_batching(tape, transform_name)

     if argnum is None and not tape.trainable_params:
         return _no_trainable_grad(tape)
diff --git a/pennylane/gradients/pulse_gradient.py b/pennylane/gradients/pulse_gradient.py
index ee3f32546ce..ed52cd3503b 100644
--- a/pennylane/gradients/pulse_gradient.py
+++ b/pennylane/gradients/pulse_gradient.py
@@ -29,7 +29,7 @@
 from .gradient_transform import (
     _all_zero_grad,
     assert_no_state_returns,
-    assert_no_tape_batching,
+    assert_no_trainable_tape_batching,
     assert_no_variance,
     choose_trainable_params,
     find_and_validate_gradient_methods,
@@ -608,7 +608,7 @@ def ansatz(params):
     _assert_has_jax(transform_name)
     assert_no_state_returns(tape.measurements, transform_name)
     assert_no_variance(tape.measurements, transform_name)
-    assert_no_tape_batching(tape, transform_name)
+    assert_no_trainable_tape_batching(tape, transform_name)

     if num_split_times < 1:
         raise ValueError(
diff --git a/pennylane/gradients/pulse_gradient_odegen.py b/pennylane/gradients/pulse_gradient_odegen.py
index 0c2384ca3ca..da2dbf0586a 100644
--- a/pennylane/gradients/pulse_gradient_odegen.py
+++ b/pennylane/gradients/pulse_gradient_odegen.py
@@ -30,7 +30,7 @@
 from .gradient_transform import (
     _all_zero_grad,
     assert_no_state_returns,
-    assert_no_tape_batching,
+    assert_no_trainable_tape_batching,
     assert_no_variance,
     choose_trainable_params,
     find_and_validate_gradient_methods,
@@ -681,7 +681,7 @@ def circuit(params):
     _assert_has_jax(transform_name)
     assert_no_state_returns(tape.measurements, transform_name)
     assert_no_variance(tape.measurements, transform_name)
-    assert_no_tape_batching(tape, transform_name)
+    assert_no_trainable_tape_batching(tape, transform_name)

     if argnum is None and not tape.trainable_params:
         return _no_trainable_grad(tape)
diff --git a/pennylane/gradients/spsa_gradient.py b/pennylane/gradients/spsa_gradient.py
index c7d25ccaff7..d5b84629874 100644
--- a/pennylane/gradients/spsa_gradient.py
+++ b/pennylane/gradients/spsa_gradient.py
@@ -29,7 +29,7 @@
 from .finite_difference import _processing_fn, finite_diff_coeffs
 from .gradient_transform import (
     _all_zero_grad,
-    assert_no_tape_batching,
+    assert_no_trainable_tape_batching,
     choose_trainable_params,
     find_and_validate_gradient_methods,
     _no_trainable_grad,
@@ -292,7 +292,7 @@ def spsa_grad(
     """
     transform_name = "SPSA"
-    assert_no_tape_batching(tape, transform_name)
+    assert_no_trainable_tape_batching(tape, transform_name)

     if argnum is None and not tape.trainable_params:
         return _no_trainable_grad(tape)
diff --git a/pennylane/transforms/core/transform.py b/pennylane/transforms/core/transform.py
index 35c19f96f70..00a9f91503c 100644
--- a/pennylane/transforms/core/transform.py
+++ b/pennylane/transforms/core/transform.py
@@ -180,7 +180,7 @@ def qnode_circuit(a):
             "The expand transform must have the same signature as the transform"
         )

-    # 3: CHeck the classical co-transform
+    # 3: Check the classical co-transform
     if classical_cotransform is not None and not callable(classical_cotransform):
         raise TransformError("The classical co-transform must be a valid Python function.")

diff --git a/tests/gradients/core/test_hadamard_gradient.py b/tests/gradients/core/test_hadamard_gradient.py
index e04c6abead2..72e4c41eabc 100644
--- a/tests/gradients/core/test_hadamard_gradient.py
+++ b/tests/gradients/core/test_hadamard_gradient.py
@@ -66,13 +66,35 @@ def cost6(x):
 class TestHadamardGrad:
     """Unit tests for the hadamard_grad function"""

-    def test_batched_tape_raises(self):
-        """Test that an error is raised for a broadcasted/batched tape."""
+    def test_trainable_batched_tape_raises(self):
+        """Test that an error is raised for a broadcasted/batched tape if the broadcasted
+        parameter is differentiated."""
         tape = qml.tape.QuantumScript([qml.RX([0.4, 0.2], 0)], [qml.expval(qml.PauliZ(0))])
-        _match = "Computing the gradient of broadcasted tapes with the Hadamard test gradient"
+        _match = r"Computing the gradient of broadcasted tapes .* using the Hadamard test gradient"
         with pytest.raises(NotImplementedError, match=_match):
             qml.gradients.hadamard_grad(tape)

+    def test_nontrainable_batched_tape(self):
+        """Test that no error is raised for a broadcasted/batched tape if the broadcasted
+        parameter is not differentiated, and that the results correspond to the stacked
+        results of the single-tape derivatives."""
+        dev = qml.device("default.qubit")
+        x = [0.4, 0.2]
+        tape = qml.tape.QuantumScript(
+            [qml.RY(0.6, 0), qml.RX(x, 0)], [qml.expval(qml.PauliZ(0))], trainable_params=[0]
+        )
+        batched_tapes, batched_fn = qml.gradients.hadamard_grad(tape)
+        batched_grad = batched_fn(dev.execute(batched_tapes))
+        separate_tapes = [
+            qml.tape.QuantumScript(
+                [qml.RY(0.6, 0), qml.RX(_x, 0)], [qml.expval(qml.PauliZ(0))], trainable_params=[0]
+            )
+            for _x in x
+        ]
+        separate_tapes_and_fns = [qml.gradients.hadamard_grad(t) for t in separate_tapes]
+        separate_grad = [_fn(dev.execute(_tapes)) for _tapes, _fn in separate_tapes_and_fns]
+        assert np.allclose(batched_grad, separate_grad)
+
     def test_tape_with_partitioned_shots_multiple_measurements_raises(self):
         """Test that an error is raised with multiple measurements and partitioned shots."""
         tape = qml.tape.QuantumScript(
diff --git a/tests/gradients/core/test_pulse_gradient.py b/tests/gradients/core/test_pulse_gradient.py
index d3661e9cdf1..b95a36d67f9 100644
--- a/tests/gradients/core/test_pulse_gradient.py
+++ b/tests/gradients/core/test_pulse_gradient.py
@@ -752,13 +752,42 @@ def test_raises_for_less_than_one_sample(self, num_split_times):
         with pytest.raises(ValueError, match="Expected a positive number of samples"):
             stoch_pulse_grad(tape, num_split_times=num_split_times)

-    def test_batched_tape_raises(self):
-        """Test that an error is raised for a broadcasted/batched tape."""
+    def test_trainable_batched_tape_raises(self):
+        """Test that an error is raised for a broadcasted/batched tape if the broadcasted
+        parameter is differentiated."""
         tape = qml.tape.QuantumScript([qml.RX([0.4, 0.2], 0)], [qml.expval(qml.PauliZ(0))])
stochastic pulse" + _match = r"Computing the gradient of broadcasted tapes .* using the stochastic pulse" with pytest.raises(NotImplementedError, match=_match): stoch_pulse_grad(tape) + def test_nontrainable_batched_tape(self): + """Test that no error is raised for a broadcasted/batched tape if the broadcasted + parameter is not differentiated, and that the results correspond to the stacked + results of the single-tape derivatives.""" + import jax.numpy as jnp + + dev = qml.device("default.qubit") + x = [0.4, 0.2] + params = [jnp.array(0.14)] + ham_single_q_const = qml.pulse.constant * qml.PauliY(0) + op = qml.evolve(ham_single_q_const)(params, 0.1) + tape = qml.tape.QuantumScript( + [qml.RX(x, 0), op], [qml.expval(qml.PauliZ(0))], trainable_params=[1] + ) + batched_tapes, batched_fn = stoch_pulse_grad(tape, argnum=0, num_split_times=1) + batched_grad = batched_fn(dev.execute(batched_tapes)) + separate_tapes = [ + qml.tape.QuantumScript( + [qml.RX(_x, 0), op], [qml.expval(qml.PauliZ(0))], trainable_params=[1] + ) + for _x in x + ] + separate_tapes_and_fns = [ + stoch_pulse_grad(t, argnum=0, num_split_times=1) for t in separate_tapes + ] + separate_grad = [_fn(dev.execute(_tapes)) for _tapes, _fn in separate_tapes_and_fns] + assert np.allclose(batched_grad, separate_grad) + @pytest.mark.parametrize("num_meas", [0, 1, 2]) def test_warning_no_trainable_params(self, num_meas): """Test that an empty gradient is returned when there are no trainable parameters.""" diff --git a/tests/gradients/core/test_pulse_odegen.py b/tests/gradients/core/test_pulse_odegen.py index f13fddefff4..faac48f9e02 100644 --- a/tests/gradients/core/test_pulse_odegen.py +++ b/tests/gradients/core/test_pulse_odegen.py @@ -824,13 +824,40 @@ def test_raises_with_invalid_op(self): with pytest.raises(ValueError, match=_match): pulse_odegen(tape) - def test_batched_tape_raises(self): - """Test that an error is raised for a broadcasted/batched tape.""" + def test_trainable_batched_tape_raises(self): + """Test that an error is raised for a broadcasted/batched tape if the broadcasted + parameter is differentiated.""" tape = qml.tape.QuantumScript([qml.RX([0.4, 0.2], 0)], [qml.expval(qml.PauliZ(0))]) - _match = "Computing the gradient of broadcasted tapes with the pulse generator" + _match = r"Computing the gradient of broadcasted tapes .* using the pulse generator" with pytest.raises(NotImplementedError, match=_match): pulse_odegen(tape) + def test_nontrainable_batched_tape(self): + """Test that no error is raised for a broadcasted/batched tape if the broadcasted + parameter is not differentiated, and that the results correspond to the stacked + results of the single-tape derivatives.""" + import jax.numpy as jnp + + dev = qml.device("default.qubit") + x = [0.4, 0.2] + params = [jnp.array(0.14)] + ham_single_q_const = qml.pulse.constant * qml.PauliY(0) + op = qml.evolve(ham_single_q_const)(params, 0.1) + tape = qml.tape.QuantumScript( + [qml.RX(x, 0), op], [qml.expval(qml.PauliZ(0))], trainable_params=[1] + ) + batched_tapes, batched_fn = pulse_odegen(tape, argnum=0) + batched_grad = batched_fn(dev.execute(batched_tapes)) + separate_tapes = [ + qml.tape.QuantumScript( + [qml.RX(_x, 0), op], [qml.expval(qml.PauliZ(0))], trainable_params=[1] + ) + for _x in x + ] + separate_tapes_and_fns = [pulse_odegen(t, argnum=0) for t in separate_tapes] + separate_grad = [_fn(dev.execute(_tapes)) for _tapes, _fn in separate_tapes_and_fns] + assert np.allclose(batched_grad, separate_grad) + def test_no_trainable_params_tape(self): """Test that 
         """Test that the correct ouput and warning is generated in the absence of any trainable parameters"""
diff --git a/tests/gradients/finite_diff/test_finite_difference.py b/tests/gradients/finite_diff/test_finite_difference.py
index 25eca0648ef..43dd856170d 100644
--- a/tests/gradients/finite_diff/test_finite_difference.py
+++ b/tests/gradients/finite_diff/test_finite_difference.py
@@ -99,13 +99,35 @@ def test_correct_second_derivative_center_order4(self):
 class TestFiniteDiff:
     """Tests for the finite difference gradient transform"""

-    def test_batched_tape_raises(self):
-        """Test that an error is raised for a broadcasted/batched tape."""
+    def test_trainable_batched_tape_raises(self):
+        """Test that an error is raised for a broadcasted/batched tape if the broadcasted
+        parameter is differentiated."""
         tape = qml.tape.QuantumScript([qml.RX([0.4, 0.2], 0)], [qml.expval(qml.PauliZ(0))])
-        _match = "Computing the gradient of broadcasted tapes with the finite difference"
+        _match = r"Computing the gradient of broadcasted tapes .* using the finite difference"
         with pytest.raises(NotImplementedError, match=_match):
             finite_diff(tape)

+    def test_nontrainable_batched_tape(self):
+        """Test that no error is raised for a broadcasted/batched tape if the broadcasted
+        parameter is not differentiated, and that the results correspond to the stacked
+        results of the single-tape derivatives."""
+        dev = qml.device("default.qubit")
+        x = [0.4, 0.2]
+        tape = qml.tape.QuantumScript(
+            [qml.RY(0.6, 0), qml.RX(x, 0)], [qml.expval(qml.PauliZ(0))], trainable_params=[0]
+        )
+        batched_tapes, batched_fn = finite_diff(tape)
+        batched_grad = batched_fn(dev.execute(batched_tapes))
+        separate_tapes = [
+            qml.tape.QuantumScript(
+                [qml.RY(0.6, 0), qml.RX(_x, 0)], [qml.expval(qml.PauliZ(0))], trainable_params=[0]
+            )
+            for _x in x
+        ]
+        separate_tapes_and_fns = [finite_diff(t) for t in separate_tapes]
+        separate_grad = [_fn(dev.execute(_tapes)) for _tapes, _fn in separate_tapes_and_fns]
+        assert np.allclose(batched_grad, separate_grad)
+
     def test_non_differentiable_error(self):
         """Test error raised if attempting to differentiate with respect to
         a non-differentiable argument"""
diff --git a/tests/gradients/finite_diff/test_spsa_gradient.py b/tests/gradients/finite_diff/test_spsa_gradient.py
index 5f0950b7667..91ea1c89842 100644
--- a/tests/gradients/finite_diff/test_spsa_gradient.py
+++ b/tests/gradients/finite_diff/test_spsa_gradient.py
@@ -170,13 +170,35 @@ def circuit(param):
         with pytest.raises(ValueError, match=expected_message):
             qml.grad(circuit)(np.array(1.0))

-    def test_batched_tape_raises(self):
-        """Test that an error is raised for a broadcasted/batched tape."""
+    def test_trainable_batched_tape_raises(self):
+        """Test that an error is raised for a broadcasted/batched tape if the broadcasted
+        parameter is differentiated."""
         tape = qml.tape.QuantumScript([qml.RX([0.4, 0.2], 0)], [qml.expval(qml.PauliZ(0))])
-        _match = "Computing the gradient of broadcasted tapes with the SPSA gradient transform"
+        _match = r"Computing the gradient of broadcasted tapes .* using the SPSA gradient transform"
         with pytest.raises(NotImplementedError, match=_match):
             spsa_grad(tape)

+    def test_nontrainable_batched_tape(self):
+        """Test that no error is raised for a broadcasted/batched tape if the broadcasted
+        parameter is not differentiated, and that the results correspond to the stacked
+        results of the single-tape derivatives."""
+        dev = qml.device("default.qubit")
+        x = [0.4, 0.2]
+        tape = qml.tape.QuantumScript(
+            [qml.RY(0.6, 0), qml.RX(x, 0)], [qml.expval(qml.PauliZ(0))], trainable_params=[0]
+        )
+        batched_tapes, batched_fn = spsa_grad(tape)
+        batched_grad = batched_fn(dev.execute(batched_tapes))
+        separate_tapes = [
+            qml.tape.QuantumScript(
+                [qml.RY(0.6, 0), qml.RX(_x, 0)], [qml.expval(qml.PauliZ(0))], trainable_params=[0]
+            )
+            for _x in x
+        ]
+        separate_tapes_and_fns = [spsa_grad(t) for t in separate_tapes]
+        separate_grad = [_fn(dev.execute(_tapes)) for _tapes, _fn in separate_tapes_and_fns]
+        assert np.allclose(batched_grad, separate_grad)
+
     def test_non_differentiable_error(self):
         """Test error raised if attempting to differentiate with respect to
         a non-differentiable argument"""
diff --git a/tests/gradients/parameter_shift/test_parameter_shift.py b/tests/gradients/parameter_shift/test_parameter_shift.py
index 847f95961ec..8a7e9bc675c 100644
--- a/tests/gradients/parameter_shift/test_parameter_shift.py
+++ b/tests/gradients/parameter_shift/test_parameter_shift.py
@@ -830,24 +830,25 @@
 class TestParamShiftRaisesWithBroadcasted:
     """Test that an error is raised with broadcasted tapes."""

     def test_batched_tape_raises(self):
-        """Test that an error is raised for a broadcasted/batched tape."""
+        """Test that an error is raised for a broadcasted/batched tape if the broadcasted
+        parameter is differentiated."""
         tape = qml.tape.QuantumScript([qml.RX([0.4, 0.2], 0)], [qml.expval(qml.PauliZ(0))])
-        _match = "Computing the gradient of broadcasted tapes with the parameter-shift rule"
+        _match = r"Computing the gradient of broadcasted tapes .* using the parameter-shift rule"
         with pytest.raises(NotImplementedError, match=_match):
             qml.gradients.param_shift(tape)


-# Revert the following skip once broadcasted tapes are fully supported with gradient transforms.
-# See #4462 for details.
-@pytest.mark.skip(reason="Applying gradient transforms to broadcasted tapes is disallowed")
 class TestParamShiftWithBroadcasted:
     """Tests for the `param_shift` transform on already broadcasted tapes.
     The tests for `param_shift` using broadcasting itself can be found further below."""

+    # Revert the following skip once broadcasted tapes are fully supported with gradient transforms.
+    # See #4462 for details.
+    @pytest.mark.skip(reason="Applying gradient transforms to broadcasted tapes is disallowed")
     @pytest.mark.parametrize("dim", [1, 3])
     @pytest.mark.parametrize("pos", [0, 1])
-    def test_with_single_parameter_broadcasted(self, dim, pos):
+    def test_with_single_trainable_parameter_broadcasted(self, dim, pos):
         """Test that the parameter-shift transform works with a tape that has
         one of its parameters broadcasted already."""
         x = np.array([0.23, 9.1, 2.3])
@@ -875,6 +876,9 @@ def test_with_single_parameter_broadcasted(self, dim, pos):
         assert res[0].shape == (dim,)
         assert res[1].shape == (dim,)

+    # Revert the following skip once broadcasted tapes are fully supported with gradient transforms.
+    # See #4462 for details.
+    @pytest.mark.skip(reason="Applying gradient transforms to broadcasted tapes is disallowed")
     @pytest.mark.parametrize("argnum", [(0, 2), (0, 1), (1,), (2,)])
     @pytest.mark.parametrize("dim", [1, 3])
     def test_with_multiple_parameters_broadcasted(self, dim, argnum):
@@ -902,6 +906,33 @@ def test_with_multiple_parameters_broadcasted(self, dim, argnum):

         assert res[0].shape == res[1].shape == res[2].shape == (dim,)

+    @pytest.mark.parametrize("dim", [1, 3])
+    @pytest.mark.parametrize("pos", [0, 1])
+    def test_with_single_nontrainable_parameter_broadcasted(self, dim, pos):
+        """Test that the parameter-shift transform works with a tape that has
+        one of its nontrainable parameters broadcasted."""
+        x = np.array([0.23, 9.1, 2.3])
+        x = x[:dim]
+        y = -0.654
+        if pos == 1:
+            x, y = y, x
+
+        with qml.queuing.AnnotatedQueue() as q:
+            qml.RX(x, wires=[0])
+            qml.RY(y, wires=[0])  # does not have any impact on the expval
+            qml.expval(qml.PauliZ(0))
+
+        tape = qml.tape.QuantumScript.from_queue(q)
+        tape.trainable_params = [1 - pos]
+        assert tape.batch_size == dim
+        tapes, fn = qml.gradients.param_shift(tape, argnum=[0])
+        assert len(tapes) == 2
+        assert np.allclose([t.batch_size for t in tapes], dim)
+
+        dev = qml.device("default.qubit", wires=2)
+        res = fn(dev.execute(tapes))
+        assert res.shape == (dim,)
+

 class TestParamShiftUsingBroadcasting:
     """Tests for the `param_shift` function using broadcasting.
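To close the loop on the changelog entry at the top, here is an editorial sketch (not part of the diff) of the user-facing behaviour this enables at the QNode level. The device choice, the `data`/`weight` names, and the use of the autograd interface's `requires_grad` flags to mark trainability are illustrative assumptions.

```python
import pennylane as qml
from pennylane import numpy as pnp

dev = qml.device("default.qubit", wires=1)

@qml.qnode(dev)
def circuit(data, weight):
    qml.RX(data, wires=0)    # `data` is broadcasted but not differentiated
    qml.RY(weight, wires=0)  # `weight` is the trainable parameter
    return qml.expval(qml.PauliZ(0))

data = pnp.array([0.4, 0.2], requires_grad=False)  # batched, non-trainable
weight = pnp.array(0.6, requires_grad=True)        # trainable

# Previously this raised NotImplementedError; with this change the gradient transform
# is applied to the broadcasted QNode directly, giving one derivative per batch entry.
grad = qml.gradients.param_shift(circuit)(data, weight)
print(grad)  # expected to have one entry per batch element, i.e. shape (2,)
```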