Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added jax support to private function _qsp_to_qsvt() which handles convention changes. #5853

Merged
merged 12 commits into from
Jun 21, 2024
5 changes: 4 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@

* `qml.transforms.split_non_commuting` can now handle circuits containing measurements of multi-term observables.
[(#5729)](https://github.com/PennyLaneAI/pennylane/pull/5729)
[(#5853)](https://github.com/PennyLaneAI/pennylane/pull/5838)
[(#5838)](https://github.com/PennyLaneAI/pennylane/pull/5838)
[(#5828)](https://github.com/PennyLaneAI/pennylane/pull/5828)
[(#5869)](https://github.com/PennyLaneAI/pennylane/pull/5869)

Expand Down Expand Up @@ -367,6 +367,9 @@
to be simulated on the `default.qutrit.mixed` device.
[(#5784)](https://github.com/PennyLaneAI/pennylane/pull/5784)

* `qml.qsvt()` now supports jax arrays with angle conversions.
[(#5853)](https://github.com/PennyLaneAI/pennylane/pull/5853)

<h3>Breaking changes 💔</h3>

* Passing `shots` as a keyword argument to a `QNode` initialization now raises an error, instead of ignoring the input.
Expand Down
1 change: 1 addition & 0 deletions pennylane/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
scatter,
scatter_element_add,
set_index,
add_index,
stack,
tensordot,
unwrap,
Expand Down
41 changes: 41 additions & 0 deletions pennylane/math/multi_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1065,3 +1065,44 @@ def set_index(array, idx, val, like=None):

array[idx] = val
return array


@multi_dispatch(tensor_list=[1])
def add_index(array, idx, val, like=None):
"""Add the value at a specified index in an array.
Calls ``array[idx]+=val`` and returns the updated array unless JAX.

If `idx` and `val` are TensorLike, then they must be the same length

Args:
array (tensor_like): array to be modified (unless JAX)
idx (int, tensor_like): index to modify
val (Union[int, float, complex, tensor_like]): value to update with

Returns:
a new copy of the array with the specified index updated by ``val``.

Whether the original array is modified is interface-dependent.

.. note:: TensorFlow EagerTensor does not support item assignment
"""
if like == "jax":
from jax import numpy as jnp

# ensure array is jax array (interface may be jax because of idx or val and not array)
jax_array = jnp.array(array)
Jaybsoni marked this conversation as resolved.
Show resolved Hide resolved

return jax_array.at[idx].add(val)

try:
if len(idx) and len(idx) == len(val):
idx_array = idx
val_array = val

except TypeError:
idx_array = [idx]
val_array = [val]

for id, v in zip(idx_array, val_array):
array[id] += v
return array
13 changes: 10 additions & 3 deletions pennylane/templates/subroutines/qsvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,9 +455,16 @@ def compute_matrix(*args, **kwargs):

def _qsp_to_qsvt(angles):
r"""Converts qsp angles to qsvt angles."""
num_angles = len(angles)
new_angles = qml.math.array(copy.copy(angles))
new_angles[0] += 3 * np.pi / 4
new_angles[-1] -= np.pi / 4
indicies = qml.math.convert_like(np.arange(num_angles), angles)

new_angles[1:-1] += np.pi / 2
update_vals = np.zeros(num_angles)
Jaybsoni marked this conversation as resolved.
Show resolved Hide resolved
update_vals[0] = 3 * np.pi / 4
update_vals[1:-1] = np.pi / 2
update_vals[-1] = -np.pi / 4

update_vals = qml.math.convert_like(update_vals, angles)

new_angles = qml.math.add_index(new_angles, indicies, update_vals)
Jaybsoni marked this conversation as resolved.
Show resolved Hide resolved
return new_angles
85 changes: 85 additions & 0 deletions tests/math/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2953,3 +2953,88 @@ def jitted_function(y):

assert qml.math.allclose(array2, jnp.array([[7, 2, 3, 4]]))
assert isinstance(array2, jnp.ndarray)


class TestAddIndex:
"""Test the add_index method."""

@pytest.mark.parametrize(
"array", [qml.numpy.ones((2, 2)), torch.ones((2, 2)), jnp.ones((2, 2))]
)
def test_add_index_jax_2d_array(self, array):
"""Test that an array can be created that is a copy of the
original array, with the value at the specified index updated"""

array2 = qml.math.add_index(array, (1, 1), 3)
assert qml.math.allclose(array2, np.array([[1, 1], [1, 4]]))
# since idx and val have no interface, we expect the returned array type to match initial type
assert isinstance(array2, type(array))

@pytest.mark.parametrize("array", [qml.numpy.ones((4)), torch.ones((4)), jnp.ones((4))])
def test_add_index_jax_1d_array(self, array):
"""Test that an array can be created that is a copy of the
original array, with the value at the specified index updated"""

array2 = qml.math.add_index(array, 3, 3)
assert qml.math.allclose(array2, np.array([[1, 1, 1, 4]]))
# since idx and val have no interface, we expect the returned array type to match initial type
assert isinstance(array2, type(array))

@pytest.mark.parametrize(
"array",
[jnp.array([[1, 2], [3, 4]]), onp.array([[1, 2], [3, 4]]), np.array([[1, 2], [3, 4]])],
)
def test_add_index_with_val_tracer(self, array):
"""Test that for both jax and numpy arrays, if the val to add is a tracer,
the add_index function succeeds and returns an updated jax array"""
from jax.interpreters.partial_eval import DynamicJaxprTracer

@jax.jit
def jitted_function(x):
assert isinstance(x, DynamicJaxprTracer)
return qml.math.add_index(array, (0, 0), x)

val = jnp.array(7)
array2 = jitted_function(val)

assert qml.math.allclose(array2, jnp.array([[8, 2], [3, 4]]))
assert isinstance(array2, jnp.ndarray)

@pytest.mark.parametrize(
"array",
[jnp.array([[1, 2], [3, 4]]), onp.array([[1, 2], [3, 4]]), np.array([[1, 2], [3, 4]])],
)
def test_add_index_with_idx_tracer_2D_array(self, array):
"""Test that for both jax and numpy 2d arrays, if the idx to add is a tracer,
the add_index function succeeds and returns an updated jax array"""
from jax.interpreters.partial_eval import DynamicJaxprTracer

@jax.jit
def jitted_function(y):
assert isinstance(y, DynamicJaxprTracer)
return qml.math.add_index(array, (1 + y, y), 7)

val = jnp.array(0)
array2 = jitted_function(val)

assert qml.math.allclose(array2, jnp.array([[1, 2], [10, 4]]))
assert isinstance(array2, jnp.ndarray)

@pytest.mark.parametrize(
"array", [jnp.array([1, 2, 3, 4]), onp.array([1, 2, 3, 4]), np.array([1, 2, 3, 4])]
)
def test_add_index_with_idx_tracer_1D_array(self, array):
"""Test that for both jax and numpy 1d arrays, if the idx to add is a tracer,
the add_index function succeeds and returns an updated jax array"""
from jax.interpreters.partial_eval import DynamicJaxprTracer

@jax.jit
def jitted_function(y):
assert isinstance(y, DynamicJaxprTracer)
return qml.math.add_index(array, y, 7)

val = jnp.array(0)
array2 = jitted_function(val)

assert qml.math.allclose(array2, jnp.array([[8, 2, 3, 4]]))
assert isinstance(array2, jnp.ndarray)
27 changes: 27 additions & 0 deletions tests/templates/test_subroutines/test_qsvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,3 +578,30 @@ def circuit(A, phis):

for idx, result in enumerate(manual_phi_results):
assert np.isclose(result, np.real(phi_grad_results[idx]), atol=1e-6)


phase_angle_data = (
(
[0, 0, 0],
[3 * np.pi / 4, np.pi / 2, -np.pi / 4],
),
(
[1.0, 2.0, 3.0, 4.0],
[1.0 + 3 * np.pi / 4, 2.0 + np.pi / 2, 3.0 + np.pi / 2, 4.0 - np.pi / 4],
),
)


@pytest.mark.jax
@pytest.mark.parametrize("initial_angles, expected_angles", phase_angle_data)
def test_private_qsp_to_qsvt_jax(initial_angles, expected_angles):
"""Test that the _qsp_to_qsvt function is jax compatible"""
import jax.numpy as jnp

from pennylane.templates.subroutines.qsvt import _qsp_to_qsvt

initial_angles = jnp.array(initial_angles)
expected_angles = jnp.array(expected_angles)

computed_angles = _qsp_to_qsvt(initial_angles)
jnp.allclose(computed_angles, expected_angles)
Loading