From 58c01e5c43909fc6c546b65278096e00b9c3dcc4 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Tue, 11 Jul 2023 09:43:51 -0400 Subject: [PATCH] Update gradients module to stop mutating operators in-place (#4220) * Adding changes for shift rules * Testing changes * Fixed op indices * Fixed indexing * Updated `bind_new_parameters` * Updated `tape.get_operation` * Updated tests * Test updates; multishifting works * Fixing interface * Updated shifting; added dispatch for templates * Updated shifting function * Fixed index error * Removed commented code * [skip ci] Reverted changes to `bind_new_parameters * Update to remove copying * Update changelog * Apply suggestions from code review Co-authored-by: Matthew Silverman * Removed unused import * Roll back suggested change * Remove state vector support from `math/quantum.py` (#4322) * Remove statevector support for qinfo functions * remove unused funcs * Fix some tests * pylint * more pylint * Remove unnecessary log * Changelog and deprecations entry * trigger ci * Enable CI and pre-commit hook to lint tests (#4335) * rename all legacy test files to match pylint pattern * pylint all tests in CI and pre-commit * lint legacy/qnn/conftest * remove the custom pylint test handling * run black before pylint * changelog --------- Co-authored-by: David Wierichs * Update broadcasting transforms to use `bind_new_parameters` (#4288) * use bind_new_parameters * pylint * remove batching in measurements and add tests * More coverage issues * Add uncopied tests * Helper function * pylint * Make tape._ops public * use private methods * pylint * Add more docs to split_operations * pylint * Deprecate X and P (#4330) * update in docs * update changelog * Test warning is raised * update default gaussian device * update tests * update more tests * fix legacy tests * Specify v0.33 removal in docstring * Render X and P in docs * move deprecation warning in docstring * fix sphinx linking * add warning box * Support wire labels in `qinfo` transforms (#4331) * Support wire labels in qinfo transforms * changelog * pylint and update test * add bugfix entry * Deprecations for 0.32 from me! (#4316) * deprecate the old return system * deprecate the mode kwarg for QNode * changelog * PR feedback * update notice to avoid wrongly suggesting action needed * update docstrings, docs and warnings * add link to qnode returns doc * change the mode warning depending on return system active * also add disclaimer to docstring --------- Co-authored-by: Matthew Silverman Co-authored-by: Edward Jiang <34989448+eddddddy@users.noreply.github.com> Co-authored-by: David Wierichs Co-authored-by: lillian542 <38584660+lillian542@users.noreply.github.com> --- doc/releases/changelog-dev.md | 5 ++ pennylane/gradients/general_shift_rules.py | 76 +++++++++++++--------- pennylane/tape/qscript.py | 4 +- tests/legacy/test_legacy_qscript_old.py | 4 +- tests/tape/test_qscript.py | 4 +- 5 files changed, 59 insertions(+), 34 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 24c3bdb1032..f0e4f66368a 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -6,6 +6,10 @@

Improvements 🛠

+* The `qml.gradients` module no longer mutates operators in-place for any gradient transforms. + Instead, operators that need to be mutated are copied with new parameters. + [(#4220)](https://github.com/PennyLaneAI/pennylane/pull/4220) + * `PauliWord` sparse matrices are much faster, which directly improves `PauliSentence`. [(#4272)](https://github.com/PennyLaneAI/pennylane/pull/4272) @@ -81,5 +85,6 @@ This release contains contributions from (in alphabetical order): Edward Jiang, Christina Lee, +Mudit Pandey, Borja Requena, Matthew Silverman diff --git a/pennylane/gradients/general_shift_rules.py b/pennylane/gradients/general_shift_rules.py index 4e73824a973..efb5c6a65e0 100644 --- a/pennylane/gradients/general_shift_rules.py +++ b/pennylane/gradients/general_shift_rules.py @@ -20,6 +20,9 @@ import numpy as np import pennylane as qml +from pennylane.measurements import MeasurementProcess +from pennylane.ops.functions import bind_new_parameters +from pennylane.tape import QuantumScript def process_shifts(rule, tol=1e-10, batch_duplicates=True): @@ -378,6 +381,44 @@ def generate_multi_shift_rule(frequencies, shifts=None, orders=None): return _combine_shift_rules(rules) +def _copy_and_shift_params(tape, indices, shifts, multipliers, cast=False): + """Create a copy of a tape and of parameters, and set the new tape to the parameters + rescaled and shifted as indicated by ``indices``, ``multipliers`` and ``shifts``.""" + all_ops = tape.circuit + + for idx, shift, multiplier in zip(indices, shifts, multipliers): + _, op_idx, p_idx = tape.get_operation(idx) + op = ( + all_ops[op_idx].obs + if isinstance(all_ops[op_idx], MeasurementProcess) + else all_ops[op_idx] + ) + + # Shift copied parameter + new_params = list(op.data) + new_params[p_idx] = new_params[p_idx] * qml.math.convert_like(multiplier, new_params[p_idx]) + new_params[p_idx] = new_params[p_idx] + qml.math.convert_like(shift, new_params[p_idx]) + if cast: + dtype = getattr(new_params[p_idx], "dtype", float) + new_params[p_idx] = qml.math.cast(new_params[p_idx], dtype) + + # Create operator with shifted parameter and put into shifted tape + shifted_op = bind_new_parameters(op, new_params) + if op_idx < len(tape.operations): + all_ops[op_idx] = shifted_op + else: + mp = all_ops[op_idx].__class__ + all_ops[op_idx] = mp(obs=shifted_op) + + # pylint: disable=protected-access + prep = all_ops[: len(tape._prep)] + ops = all_ops[len(tape._prep) : len(tape.operations)] + meas = all_ops[len(tape.operations) :] + shifted_tape = QuantumScript(ops=ops, measurements=meas, prep=prep, shots=tape.shots) + + return shifted_tape + + def generate_shifted_tapes(tape, index, shifts, multipliers=None, broadcast=False): r"""Generate a list of tapes or a single broadcasted tape, where one marked trainable parameter has been shifted by the provided shift values. @@ -403,27 +444,14 @@ def generate_shifted_tapes(tape, index, shifts, multipliers=None, broadcast=Fals the ``batch_size`` of the returned tape matches the length of ``shifts``. """ - def _copy_and_shift_params(tape, params, idx, shift, mult): - """Create a copy of a tape and of parameters, and set the new tape to the parameters - rescaled and shifted as indicated by ``idx``, ``mult`` and ``shift``.""" - new_params = params.copy() - new_params[idx] = new_params[idx] * qml.math.convert_like( - mult, new_params[idx] - ) + qml.math.convert_like(shift, new_params[idx]) - - shifted_tape = tape.copy(copy_operations=True) - shifted_tape.set_parameters(new_params) - return shifted_tape - - params = list(tape.get_parameters()) if multipliers is None: multipliers = np.ones_like(shifts) if broadcast: - return (_copy_and_shift_params(tape, params, index, shifts, multipliers),) + return (_copy_and_shift_params(tape, [index], [shifts], [multipliers]),) return tuple( - _copy_and_shift_params(tape, params, index, shift, multiplier) + _copy_and_shift_params(tape, [index], [shift], [multiplier]) for shift, multiplier in zip(shifts, multipliers) ) @@ -450,22 +478,12 @@ def generate_multishifted_tapes(tape, indices, shifts, multipliers=None): of tapes will match the summed lengths of all inner sequences in ``shifts`` and ``multipliers`` (if provided). """ - params = list(tape.get_parameters()) if multipliers is None: multipliers = np.ones_like(shifts) - tapes = [] - - for _shifts, _multipliers in zip(shifts, multipliers): - new_params = params.copy() - shifted_tape = tape.copy(copy_operations=True) - for idx, shift, multiplier in zip(indices, _shifts, _multipliers): - dtype = getattr(new_params[idx], "dtype", float) - new_params[idx] = new_params[idx] * qml.math.convert_like(multiplier, new_params[idx]) - new_params[idx] = new_params[idx] + qml.math.convert_like(shift, new_params[idx]) - new_params[idx] = qml.math.cast(new_params[idx], dtype) - - shifted_tape.set_parameters(new_params) - tapes.append(shifted_tape) + tapes = [ + _copy_and_shift_params(tape, indices, _shifts, _multipliers, cast=True) + for _shifts, _multipliers in zip(shifts, multipliers) + ] return tapes diff --git a/pennylane/tape/qscript.py b/pennylane/tape/qscript.py index 5bdea78390c..25173f61713 100644 --- a/pennylane/tape/qscript.py +++ b/pennylane/tape/qscript.py @@ -428,10 +428,12 @@ def _update_par_info(self): {"op": op, "op_idx": idx, "p_idx": i} for i, d in enumerate(op.data) ) + n_ops = len(self.operations) for idx, m in enumerate(self.measurements): if m.obs is not None: self._par_info.extend( - {"op": m.obs, "op_idx": idx, "p_idx": i} for i, d in enumerate(m.obs.data) + {"op": m.obs, "op_idx": idx + n_ops, "p_idx": i} + for i, d in enumerate(m.obs.data) ) def _update_trainable_params(self): diff --git a/tests/legacy/test_legacy_qscript_old.py b/tests/legacy/test_legacy_qscript_old.py index 73fd31e7b5d..096be217fe8 100644 --- a/tests/legacy/test_legacy_qscript_old.py +++ b/tests/legacy/test_legacy_qscript_old.py @@ -186,7 +186,7 @@ def test_update_par_info_update_trainable_params(self): assert p_i[4] == {"op": ops[2], "op_idx": 2, "p_idx": 0} assert p_i[5] == {"op": ops[3], "op_idx": 3, "p_idx": 0} assert p_i[6] == {"op": ops[3], "op_idx": 3, "p_idx": 1} - assert p_i[7] == {"op": m[0].obs, "op_idx": 0, "p_idx": 0} + assert p_i[7] == {"op": m[0].obs, "op_idx": 4, "p_idx": 0} assert qs._trainable_params == list(range(8)) @@ -224,7 +224,7 @@ def test_get_operation(self): assert op_6 == ops[4] and op_id_6 == 4 and p_id_6 == 1 _, obs_id_0, p_id_0 = qs.get_operation(7) - assert obs_id_0 == 0 and p_id_0 == 0 + assert obs_id_0 == 5 and p_id_0 == 0 def test_update_observables(self): """This method needs to be more thoroughly tested, and probably even reconsidered in diff --git a/tests/tape/test_qscript.py b/tests/tape/test_qscript.py index ca5eaa19bb4..4fa8548763c 100644 --- a/tests/tape/test_qscript.py +++ b/tests/tape/test_qscript.py @@ -183,7 +183,7 @@ def test_update_par_info_update_trainable_params(self): assert p_i[4] == {"op": ops[2], "op_idx": 2, "p_idx": 0} assert p_i[5] == {"op": ops[3], "op_idx": 3, "p_idx": 0} assert p_i[6] == {"op": ops[3], "op_idx": 3, "p_idx": 1} - assert p_i[7] == {"op": m[0].obs, "op_idx": 0, "p_idx": 0} + assert p_i[7] == {"op": m[0].obs, "op_idx": 4, "p_idx": 0} assert qs._trainable_params == list(range(8)) @@ -222,7 +222,7 @@ def test_get_operation(self): assert op_6 == ops[4] and op_id_6 == 4 and p_id_6 == 1 _, obs_id_0, p_id_0 = qs.get_operation(7) - assert obs_id_0 == 0 and p_id_0 == 0 + assert obs_id_0 == 5 and p_id_0 == 0 def test_update_observables(self): """This method needs to be more thoroughly tested, and probably even reconsidered in