Bugfix - conversions in _copy_and_shift_params (#4477)
* Comment qml.math.cast_like.

* Cast to float64 in _copy_and_shift_params.

* Convert/cast new_params[p_idx].

* Cast new_params to float64.

* Add test for qml.gradients.param_shift when inputting integer parameters.

* Fix import_should_record_backprop.

* Fix cast and update _copy_and_shift_params.

* Unwrap before casting.

* Skip autograd astype.

* Revert general_shift_rules.py to master.

* Skip conversion for integral types.

* Test more types in test_integer_parameters.

* Update changelog.

* Update tests/gradients/parameter_shift/test_parameter_shift.py

Co-authored-by: Christina Lee <christina@xanadu.ai>

* Update pennylane/gradients/general_shift_rules.py

Co-authored-by: Christina Lee <christina@xanadu.ai>

* Reformat

* Fix TF test.

* Revert to numbers.Integral solution.

* Add a no-cover pragma to import_should_record_backprop.

---------

Co-authored-by: Christina Lee <christina@xanadu.ai>
Committed by vincentmr and albi3ro on Aug 17, 2023
Parent: 7e0e5af · Commit: bd701b7
Showing 4 changed files with 45 additions and 20 deletions.
doc/releases/changelog-dev.md (3 changes: 3 additions & 0 deletions)
@@ -351,6 +351,9 @@ array([False, False])

<h3>Bug fixes 🐛</h3>

* `_copy_and_shift_params` no longer casts or converts integral parameter types, relying instead on the promotion rules of `+` and `*` in this case (a short illustration follows this diff excerpt).
[(#4477)](https://github.com/PennyLaneAI/pennylane/pull/4477)

* `qml.Projector` is pickle-able again.
[(#4452)](https://github.com/PennyLaneAI/pennylane/pull/4452)

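For context on the `_copy_and_shift_params` changelog entry above, here is a minimal standalone sketch (illustrative only, not part of the commit) of the arithmetic promotion the fix relies on: multiplying or adding a float to an integer parameter already yields a float, so no explicit conversion is required.

```python
import numpy as np

# Plain Python ints and NumPy integer scalars are promoted to floats by
# ordinary arithmetic, so no qml.math.convert_like / cast_like call is needed.
param = 3                                  # Python int
shifted = param * 1.0 + np.pi / 2          # multiplier * param + shift
print(type(shifted))                       # <class 'float'>

param64 = np.int64(3)
shifted64 = param64 * 1.0 + np.float64(np.pi / 2)
print(shifted64.dtype)                     # float64
```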
pennylane/gradients/general_shift_rules.py (10 changes: 6 additions & 4 deletions)
@@ -16,6 +16,7 @@
shifted parameters."""
import functools
import itertools
import numbers
import warnings

import numpy as np
@@ -396,10 +397,11 @@ def _copy_and_shift_params(tape, indices, shifts, multipliers, cast=False):

# Shift copied parameter
new_params = list(op.data)
multiplier = qml.math.convert_like(multiplier, new_params[p_idx])
multiplier = qml.math.cast_like(multiplier, new_params[p_idx])
shift = qml.math.convert_like(shift, new_params[p_idx])
shift = qml.math.cast_like(shift, new_params[p_idx])
if not isinstance(new_params[p_idx], numbers.Integral):
multiplier = qml.math.convert_like(multiplier, new_params[p_idx])
multiplier = qml.math.cast_like(multiplier, new_params[p_idx])
shift = qml.math.convert_like(shift, new_params[p_idx])
shift = qml.math.cast_like(shift, new_params[p_idx])
new_params[p_idx] = new_params[p_idx] * multiplier
new_params[p_idx] = new_params[p_idx] + shift
if cast:
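A rough standalone approximation of the updated shifting step above, with plain NumPy standing in for `qml.math` (the function name and the NumPy-based casting here are hypothetical, for illustration only):

```python
import numbers

import numpy as np


def shift_parameter(value, multiplier, shift):
    """Hypothetical sketch of the per-parameter step in _copy_and_shift_params:
    skip convert/cast for integral values and rely on arithmetic promotion."""
    if not isinstance(value, numbers.Integral):
        # For non-integral parameters, match multiplier and shift to the
        # parameter's dtype, roughly what convert_like/cast_like do.
        dtype = np.asarray(value).dtype
        multiplier = np.asarray(multiplier, dtype=dtype)
        shift = np.asarray(shift, dtype=dtype)
    return value * multiplier + shift


print(shift_parameter(2, 1.0, np.pi / 2))                # int parameter -> float result
print(shift_parameter(np.float32(0.3), 1.0, np.pi / 2))  # stays float32
```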
pennylane/math/utils.py (34 changes: 18 additions & 16 deletions)
@@ -397,6 +397,22 @@ def function(x):
return False


def import_should_record_backprop(): # pragma: no cover
"""Return should_record_backprop or an equivalent function from TensorFlow."""
import tensorflow.python as tfpy

if hasattr(tfpy.eager.tape, "should_record_backprop"):
from tensorflow.python.eager.tape import should_record_backprop
elif hasattr(tfpy.eager.tape, "should_record"):
from tensorflow.python.eager.tape import should_record as should_record_backprop
elif hasattr(tfpy.eager.record, "should_record_backprop"):
from tensorflow.python.eager.record import should_record_backprop
else:
raise ImportError("Cannot import should_record_backprop from TensorFlow.")

return should_record_backprop


def requires_grad(tensor, interface=None):
"""Returns True if the tensor is considered trainable.
@@ -454,14 +470,7 @@ def requires_grad(tensor, interface=None):
if interface == "tensorflow":
import tensorflow as tf

try:
try:
from tensorflow.python.eager.record import should_record_backprop
except ImportError: # pragma: no cover
from tensorflow.python.eager.tape import should_record_backprop
except ImportError: # pragma: no cover
from tensorflow.python.eager.tape import should_record as should_record_backprop

should_record_backprop = import_should_record_backprop()
return should_record_backprop([tf.convert_to_tensor(tensor)])

if interface == "autograd":
@@ -510,14 +519,7 @@ def in_backprop(tensor, interface=None):
if interface == "tensorflow":
import tensorflow as tf

try:
try:
from tensorflow.python.eager.record import should_record_backprop
except ImportError: # pragma: no cover
from tensorflow.python.eager.tape import should_record_backprop
except ImportError: # pragma: no cover
from tensorflow.python.eager.tape import should_record as should_record_backprop

should_record_backprop = import_should_record_backprop()
return should_record_backprop([tf.convert_to_tensor(tensor)])

if interface == "autograd":
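The `utils.py` change above consolidates the version-dependent TensorFlow import into a single `import_should_record_backprop` helper, shared by `requires_grad` and `in_backprop`. A small usage sketch of the downstream behavior (assuming TensorFlow is installed; this matches the documented TensorFlow-interface behavior of `qml.math.requires_grad`):

```python
import tensorflow as tf

import pennylane as qml

x = tf.Variable(0.4)

# Outside a GradientTape, nothing is recording backprop for x.
print(qml.math.requires_grad(x))      # False

# Inside one, the watched variable is reported as trainable.
with tf.GradientTape():
    print(qml.math.requires_grad(x))  # True
```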
tests/gradients/parameter_shift/test_parameter_shift.py (18 changes: 18 additions & 0 deletions)
@@ -1567,6 +1567,24 @@ def test_single_expectation_value(self, tol):
assert np.allclose(res[0], expected[0], atol=tol, rtol=0)
assert np.allclose(res[1], expected[1], atol=tol, rtol=0)

@pytest.mark.parametrize(
"par", [0, 1, 2, 3, np.int8(1), np.int16(1), np.int32(1), np.int64(1)]
) # integers, zero
def test_integer_parameters(self, tol, par):
"""Test that the gradient of the RY gate matches the exact analytic formula."""
dev = qml.device("default.qubit", wires=2)

tape = qml.tape.QuantumScript([qml.RY(par, wires=[0])], [qml.expval(qml.PauliX(0))])
tape.trainable_params = {0}

# gradients
exact = np.cos(par)
gtapes, fn = qml.gradients.param_shift(tape)
grad_PS = fn(qml.execute(gtapes, dev, gradient_fn=None))

# different methods must agree
assert np.allclose(grad_PS, exact, atol=tol, rtol=0)

def test_multiple_expectation_values(self, tol):
"""Tests correct output shape and evaluation for a tape
with multiple expval outputs"""
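Complementing the new test above, an illustrative sketch (not part of the commit) of what `param_shift` now produces for an integer parameter: the shifted tapes simply carry float parameters instead of triggering a casting error.

```python
import numpy as np

import pennylane as qml

# Integer RY parameter, marked trainable (mirrors the new test_integer_parameters case).
tape = qml.tape.QuantumScript([qml.RY(2, wires=0)], [qml.expval(qml.PauliX(0))])
tape.trainable_params = {0}

shifted_tapes, postprocess = qml.gradients.param_shift(tape)
print([t.get_parameters() for t in shifted_tapes])  # roughly [[2 + pi/2], [2 - pi/2]]

dev = qml.device("default.qubit", wires=1)
grad = postprocess(qml.execute(shifted_tapes, dev, gradient_fn=None))
print(np.allclose(grad, np.cos(2)))  # True: d<X>/dtheta of RY(theta) is cos(theta)
```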
