Skip to content

Commit

Permalink
Added jax support to private function _qsp_to_qsvt() which handles …
Browse files Browse the repository at this point in the history
…convention changes. (#5853)

**Context:**
Allow for `qml.qsvt()` to accept phase angles as jax arrays and perform
convention conversions.

---------

Co-authored-by: David Ittah <dime10@users.noreply.github.com>
  • Loading branch information
Jaybsoni and dime10 authored Jun 21, 2024
1 parent 96a3d56 commit e1bea42
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 6 deletions.
5 changes: 4 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@

* `qml.transforms.split_non_commuting` can now handle circuits containing measurements of multi-term observables.
[(#5729)](https://github.com/PennyLaneAI/pennylane/pull/5729)
[(#5853)](https://github.com/PennyLaneAI/pennylane/pull/5838)
[(#5838)](https://github.com/PennyLaneAI/pennylane/pull/5838)
[(#5828)](https://github.com/PennyLaneAI/pennylane/pull/5828)
[(#5869)](https://github.com/PennyLaneAI/pennylane/pull/5869)

Expand Down Expand Up @@ -373,6 +373,9 @@
to be simulated on the `default.qutrit.mixed` device.
[(#5784)](https://github.com/PennyLaneAI/pennylane/pull/5784)

* `qml.qsvt()` now supports jax arrays with angle conversions.
[(#5853)](https://github.com/PennyLaneAI/pennylane/pull/5853)

<h3>Breaking changes 💔</h3>

* Passing `shots` as a keyword argument to a `QNode` initialization now raises an error, instead of ignoring the input.
Expand Down
13 changes: 8 additions & 5 deletions pennylane/templates/subroutines/qsvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,9 +455,12 @@ def compute_matrix(*args, **kwargs):

def _qsp_to_qsvt(angles):
r"""Converts qsp angles to qsvt angles."""
new_angles = qml.math.array(copy.copy(angles))
new_angles[0] += 3 * np.pi / 4
new_angles[-1] -= np.pi / 4
num_angles = len(angles)
update_vals = np.empty(num_angles)

new_angles[1:-1] += np.pi / 2
return new_angles
update_vals[0] = 3 * np.pi / 4
update_vals[1:-1] = np.pi / 2
update_vals[-1] = -np.pi / 4
update_vals = qml.math.convert_like(update_vals, angles)

return angles + update_vals
27 changes: 27 additions & 0 deletions tests/templates/test_subroutines/test_qsvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,3 +578,30 @@ def circuit(A, phis):

for idx, result in enumerate(manual_phi_results):
assert np.isclose(result, np.real(phi_grad_results[idx]), atol=1e-6)


phase_angle_data = (
(
[0, 0, 0],
[3 * np.pi / 4, np.pi / 2, -np.pi / 4],
),
(
[1.0, 2.0, 3.0, 4.0],
[1.0 + 3 * np.pi / 4, 2.0 + np.pi / 2, 3.0 + np.pi / 2, 4.0 - np.pi / 4],
),
)


@pytest.mark.jax
@pytest.mark.parametrize("initial_angles, expected_angles", phase_angle_data)
def test_private_qsp_to_qsvt_jax(initial_angles, expected_angles):
"""Test that the _qsp_to_qsvt function is jax compatible"""
import jax.numpy as jnp

from pennylane.templates.subroutines.qsvt import _qsp_to_qsvt

initial_angles = jnp.array(initial_angles)
expected_angles = jnp.array(expected_angles)

computed_angles = _qsp_to_qsvt(initial_angles)
jnp.allclose(computed_angles, expected_angles)

0 comments on commit e1bea42

Please sign in to comment.