Skip to content

Commit

Permalink
Update gradients module to stop mutating operators in-place (#4220)
Browse files Browse the repository at this point in the history
* Adding changes for shift rules

* Testing changes

* Fixed op indices

* Fixed indexing

* Updated `bind_new_parameters`

* Updated `tape.get_operation`

* Updated tests

* Test updates; multishifting works

* Fixing interface

* Updated shifting; added dispatch for templates

* Updated shifting function

* Fixed index error

* Removed commented code

* [skip ci] Reverted changes to `bind_new_parameters

* Update to remove copying

* Update changelog

* Apply suggestions from code review

Co-authored-by: Matthew Silverman <matthews@xanadu.ai>

* Removed unused import

* Roll back suggested change

* Remove state vector support from `math/quantum.py` (#4322)

* Remove statevector support for qinfo functions

* remove unused funcs

* Fix some tests

* pylint

* more pylint

* Remove unnecessary log

* Changelog and deprecations entry

* trigger ci

* Enable CI and pre-commit hook to lint tests (#4335)

* rename all legacy test files to match pylint pattern

* pylint all tests in CI and pre-commit

* lint legacy/qnn/conftest

* remove the custom pylint test handling

* run black before pylint

* changelog

---------

Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>

* Update broadcasting transforms to use `bind_new_parameters` (#4288)

* use bind_new_parameters

* pylint

* remove batching in measurements and add tests

* More coverage issues

* Add uncopied tests

* Helper function

* pylint

* Make tape._ops public

* use private methods

* pylint

* Add more docs to split_operations

* pylint

* Deprecate X and P (#4330)

* update in docs

* update changelog

* Test warning is raised

* update default gaussian device

* update tests

* update more tests

* fix legacy tests

* Specify v0.33 removal in docstring

* Render X and P in docs

* move deprecation warning in docstring

* fix sphinx linking

* add warning box

* Support wire labels in `qinfo` transforms (#4331)

* Support wire labels in qinfo transforms

* changelog

* pylint and update test

* add bugfix entry

* Deprecations for 0.32 from me! (#4316)

* deprecate the old return system

* deprecate the mode kwarg for QNode

* changelog

* PR feedback

* update notice to avoid wrongly suggesting action needed

* update docstrings, docs and warnings

* add link to qnode returns doc

* change the mode warning depending on return system active

* also add disclaimer to docstring

---------

Co-authored-by: Matthew Silverman <matthews@xanadu.ai>
Co-authored-by: Edward Jiang <34989448+eddddddy@users.noreply.github.com>
Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>
Co-authored-by: lillian542 <38584660+lillian542@users.noreply.github.com>
  • Loading branch information
5 people committed Jul 11, 2023
1 parent f4537a9 commit 58c01e5
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 34 deletions.
5 changes: 5 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@

<h3>Improvements 🛠</h3>

* The `qml.gradients` module no longer mutates operators in-place for any gradient transforms.
Instead, operators that need to be mutated are copied with new parameters.
[(#4220)](https://github.com/PennyLaneAI/pennylane/pull/4220)

* `PauliWord` sparse matrices are much faster, which directly improves `PauliSentence`.
[(#4272)](https://github.com/PennyLaneAI/pennylane/pull/4272)

Expand Down Expand Up @@ -81,5 +85,6 @@ This release contains contributions from (in alphabetical order):

Edward Jiang,
Christina Lee,
Mudit Pandey,
Borja Requena,
Matthew Silverman
76 changes: 47 additions & 29 deletions pennylane/gradients/general_shift_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@

import numpy as np
import pennylane as qml
from pennylane.measurements import MeasurementProcess
from pennylane.ops.functions import bind_new_parameters
from pennylane.tape import QuantumScript


def process_shifts(rule, tol=1e-10, batch_duplicates=True):
Expand Down Expand Up @@ -378,6 +381,44 @@ def generate_multi_shift_rule(frequencies, shifts=None, orders=None):
return _combine_shift_rules(rules)


def _copy_and_shift_params(tape, indices, shifts, multipliers, cast=False):
"""Create a copy of a tape and of parameters, and set the new tape to the parameters
rescaled and shifted as indicated by ``indices``, ``multipliers`` and ``shifts``."""
all_ops = tape.circuit

for idx, shift, multiplier in zip(indices, shifts, multipliers):
_, op_idx, p_idx = tape.get_operation(idx)
op = (
all_ops[op_idx].obs
if isinstance(all_ops[op_idx], MeasurementProcess)
else all_ops[op_idx]
)

# Shift copied parameter
new_params = list(op.data)
new_params[p_idx] = new_params[p_idx] * qml.math.convert_like(multiplier, new_params[p_idx])
new_params[p_idx] = new_params[p_idx] + qml.math.convert_like(shift, new_params[p_idx])
if cast:
dtype = getattr(new_params[p_idx], "dtype", float)
new_params[p_idx] = qml.math.cast(new_params[p_idx], dtype)

# Create operator with shifted parameter and put into shifted tape
shifted_op = bind_new_parameters(op, new_params)
if op_idx < len(tape.operations):
all_ops[op_idx] = shifted_op
else:
mp = all_ops[op_idx].__class__
all_ops[op_idx] = mp(obs=shifted_op)

# pylint: disable=protected-access
prep = all_ops[: len(tape._prep)]
ops = all_ops[len(tape._prep) : len(tape.operations)]
meas = all_ops[len(tape.operations) :]
shifted_tape = QuantumScript(ops=ops, measurements=meas, prep=prep, shots=tape.shots)

return shifted_tape


def generate_shifted_tapes(tape, index, shifts, multipliers=None, broadcast=False):
r"""Generate a list of tapes or a single broadcasted tape, where one marked
trainable parameter has been shifted by the provided shift values.
Expand All @@ -403,27 +444,14 @@ def generate_shifted_tapes(tape, index, shifts, multipliers=None, broadcast=Fals
the ``batch_size`` of the returned tape matches the length of ``shifts``.
"""

def _copy_and_shift_params(tape, params, idx, shift, mult):
"""Create a copy of a tape and of parameters, and set the new tape to the parameters
rescaled and shifted as indicated by ``idx``, ``mult`` and ``shift``."""
new_params = params.copy()
new_params[idx] = new_params[idx] * qml.math.convert_like(
mult, new_params[idx]
) + qml.math.convert_like(shift, new_params[idx])

shifted_tape = tape.copy(copy_operations=True)
shifted_tape.set_parameters(new_params)
return shifted_tape

params = list(tape.get_parameters())
if multipliers is None:
multipliers = np.ones_like(shifts)

if broadcast:
return (_copy_and_shift_params(tape, params, index, shifts, multipliers),)
return (_copy_and_shift_params(tape, [index], [shifts], [multipliers]),)

return tuple(
_copy_and_shift_params(tape, params, index, shift, multiplier)
_copy_and_shift_params(tape, [index], [shift], [multiplier])
for shift, multiplier in zip(shifts, multipliers)
)

Expand All @@ -450,22 +478,12 @@ def generate_multishifted_tapes(tape, indices, shifts, multipliers=None):
of tapes will match the summed lengths of all inner sequences in ``shifts``
and ``multipliers`` (if provided).
"""
params = list(tape.get_parameters())
if multipliers is None:
multipliers = np.ones_like(shifts)

tapes = []

for _shifts, _multipliers in zip(shifts, multipliers):
new_params = params.copy()
shifted_tape = tape.copy(copy_operations=True)
for idx, shift, multiplier in zip(indices, _shifts, _multipliers):
dtype = getattr(new_params[idx], "dtype", float)
new_params[idx] = new_params[idx] * qml.math.convert_like(multiplier, new_params[idx])
new_params[idx] = new_params[idx] + qml.math.convert_like(shift, new_params[idx])
new_params[idx] = qml.math.cast(new_params[idx], dtype)

shifted_tape.set_parameters(new_params)
tapes.append(shifted_tape)
tapes = [
_copy_and_shift_params(tape, indices, _shifts, _multipliers, cast=True)
for _shifts, _multipliers in zip(shifts, multipliers)
]

return tapes
4 changes: 3 additions & 1 deletion pennylane/tape/qscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,10 +428,12 @@ def _update_par_info(self):
{"op": op, "op_idx": idx, "p_idx": i} for i, d in enumerate(op.data)
)

n_ops = len(self.operations)
for idx, m in enumerate(self.measurements):
if m.obs is not None:
self._par_info.extend(
{"op": m.obs, "op_idx": idx, "p_idx": i} for i, d in enumerate(m.obs.data)
{"op": m.obs, "op_idx": idx + n_ops, "p_idx": i}
for i, d in enumerate(m.obs.data)
)

def _update_trainable_params(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/legacy/test_legacy_qscript_old.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def test_update_par_info_update_trainable_params(self):
assert p_i[4] == {"op": ops[2], "op_idx": 2, "p_idx": 0}
assert p_i[5] == {"op": ops[3], "op_idx": 3, "p_idx": 0}
assert p_i[6] == {"op": ops[3], "op_idx": 3, "p_idx": 1}
assert p_i[7] == {"op": m[0].obs, "op_idx": 0, "p_idx": 0}
assert p_i[7] == {"op": m[0].obs, "op_idx": 4, "p_idx": 0}

assert qs._trainable_params == list(range(8))

Expand Down Expand Up @@ -224,7 +224,7 @@ def test_get_operation(self):
assert op_6 == ops[4] and op_id_6 == 4 and p_id_6 == 1

_, obs_id_0, p_id_0 = qs.get_operation(7)
assert obs_id_0 == 0 and p_id_0 == 0
assert obs_id_0 == 5 and p_id_0 == 0

def test_update_observables(self):
"""This method needs to be more thoroughly tested, and probably even reconsidered in
Expand Down
4 changes: 2 additions & 2 deletions tests/tape/test_qscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def test_update_par_info_update_trainable_params(self):
assert p_i[4] == {"op": ops[2], "op_idx": 2, "p_idx": 0}
assert p_i[5] == {"op": ops[3], "op_idx": 3, "p_idx": 0}
assert p_i[6] == {"op": ops[3], "op_idx": 3, "p_idx": 1}
assert p_i[7] == {"op": m[0].obs, "op_idx": 0, "p_idx": 0}
assert p_i[7] == {"op": m[0].obs, "op_idx": 4, "p_idx": 0}

assert qs._trainable_params == list(range(8))

Expand Down Expand Up @@ -222,7 +222,7 @@ def test_get_operation(self):
assert op_6 == ops[4] and op_id_6 == 4 and p_id_6 == 1

_, obs_id_0, p_id_0 = qs.get_operation(7)
assert obs_id_0 == 0 and p_id_0 == 0
assert obs_id_0 == 5 and p_id_0 == 0

def test_update_observables(self):
"""This method needs to be more thoroughly tested, and probably even reconsidered in
Expand Down

0 comments on commit 58c01e5

Please sign in to comment.