Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added jax support to private function _qsp_to_qsvt() which handles convention changes. #5853

Merged
merged 12 commits into from
Jun 21, 2024
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# Release 0.37.0-dev (development release)

<h3>New features since last release</h3>

Check warning on line 5 in doc/releases/changelog-dev.md

View workflow job for this annotation

GitHub Actions / sphinx

Duplicate explicit target name: "(#5853)".

* The `default.tensor` device now supports the `tn` method to simulate quantum circuits using exact tensor networks.
[(#5786)](https://github.com/PennyLaneAI/pennylane/pull/5786)
Expand Down Expand Up @@ -272,6 +272,9 @@
to be simulated on the `default.qutrit.mixed` device.
[(#5784)](https://github.com/PennyLaneAI/pennylane/pull/5784)

* `qml.qsvt()` now supports jax arrays with angle conversions.
[(#5853)](https://github.com/PennyLaneAI/pennylane/pull/5853)

<h3>Breaking changes 💔</h3>

* Passing `shots` as a keyword argument to a `QNode` initialization now raises an error, instead of ignoring the input.
Expand Down
14 changes: 11 additions & 3 deletions pennylane/templates/subroutines/qsvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,9 +455,17 @@ def compute_matrix(*args, **kwargs):

def _qsp_to_qsvt(angles):
r"""Converts qsp angles to qsvt angles."""
num_angles = len(angles)
new_angles = qml.math.array(copy.copy(angles))
new_angles[0] += 3 * np.pi / 4
new_angles[-1] -= np.pi / 4

new_angles[1:-1] += np.pi / 2
for index in range(num_angles):
if index == 0:
update_val = 3 * np.pi / 4
elif index == num_angles - 1:
update_val = -np.pi / 4
else:
update_val = np.pi / 2
Jaybsoni marked this conversation as resolved.
Show resolved Hide resolved

qml.math.set_index(new_angles, index, new_angles[index] + update_val)
Jaybsoni marked this conversation as resolved.
Show resolved Hide resolved

return new_angles
27 changes: 27 additions & 0 deletions tests/templates/test_subroutines/test_qsvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,3 +577,30 @@ def circuit(A, phis):

for idx, result in enumerate(manual_phi_results):
assert np.isclose(result, np.real(phi_grad_results[idx]), atol=1e-6)


phase_angle_data = (
(
[0, 0, 0],
[3 * np.pi / 4, np.pi / 2, -np.pi / 4],
),
(
[1.0, 2.0, 3.0, 4.0],
[1.0 + 3 * np.pi / 4, 2.0 + np.pi / 2, 3.0 + np.pi / 2, 4.0 - np.pi / 4],
),
)


@pytest.mark.jax
@pytest.mark.parametrize("initial_angles, expected_angles", phase_angle_data)
def test_private_qsp_to_qsvt_jax(initial_angles, expected_angles):
"""Test that the _qsp_to_qsvt function is jax compatible"""
import jax.numpy as jnp

from pennylane.templates.subroutines.qsvt import _qsp_to_qsvt

initial_angles = jnp.array(initial_angles)
expected_angles = jnp.array(expected_angles)

computed_angles = _qsp_to_qsvt(initial_angles)
jnp.allclose(computed_angles, expected_angles)
Loading