Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added jax support to private function _qsp_to_qsvt() which handles convention changes. #5853

Merged
merged 12 commits into from
Jun 21, 2024
5 changes: 4 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@

* `qml.transforms.split_non_commuting` can now handle circuits containing measurements of multi-term observables.
[(#5729)](https://github.com/PennyLaneAI/pennylane/pull/5729)
[(#5853)](https://github.com/PennyLaneAI/pennylane/pull/5838)
[(#5838)](https://github.com/PennyLaneAI/pennylane/pull/5838)
[(#5828)](https://github.com/PennyLaneAI/pennylane/pull/5828)
[(#5869)](https://github.com/PennyLaneAI/pennylane/pull/5869)

Expand Down Expand Up @@ -373,6 +373,9 @@
to be simulated on the `default.qutrit.mixed` device.
[(#5784)](https://github.com/PennyLaneAI/pennylane/pull/5784)

* `qml.qsvt()` now supports jax arrays with angle conversions.
[(#5853)](https://github.com/PennyLaneAI/pennylane/pull/5853)

<h3>Breaking changes 💔</h3>

* Passing `shots` as a keyword argument to a `QNode` initialization now raises an error, instead of ignoring the input.
Expand Down
13 changes: 8 additions & 5 deletions pennylane/templates/subroutines/qsvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,9 +455,12 @@ def compute_matrix(*args, **kwargs):

def _qsp_to_qsvt(angles):
r"""Converts qsp angles to qsvt angles."""
new_angles = qml.math.array(copy.copy(angles))
new_angles[0] += 3 * np.pi / 4
new_angles[-1] -= np.pi / 4
num_angles = len(angles)
update_vals = np.empty(num_angles)

new_angles[1:-1] += np.pi / 2
return new_angles
update_vals[0] = 3 * np.pi / 4
update_vals[1:-1] = np.pi / 2
update_vals[-1] = -np.pi / 4
update_vals = qml.math.convert_like(update_vals, angles)

return angles + update_vals
27 changes: 27 additions & 0 deletions tests/templates/test_subroutines/test_qsvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,3 +578,30 @@ def circuit(A, phis):

for idx, result in enumerate(manual_phi_results):
assert np.isclose(result, np.real(phi_grad_results[idx]), atol=1e-6)


phase_angle_data = (
(
[0, 0, 0],
[3 * np.pi / 4, np.pi / 2, -np.pi / 4],
),
(
[1.0, 2.0, 3.0, 4.0],
[1.0 + 3 * np.pi / 4, 2.0 + np.pi / 2, 3.0 + np.pi / 2, 4.0 - np.pi / 4],
),
)


@pytest.mark.jax
@pytest.mark.parametrize("initial_angles, expected_angles", phase_angle_data)
def test_private_qsp_to_qsvt_jax(initial_angles, expected_angles):
"""Test that the _qsp_to_qsvt function is jax compatible"""
import jax.numpy as jnp

from pennylane.templates.subroutines.qsvt import _qsp_to_qsvt

initial_angles = jnp.array(initial_angles)
expected_angles = jnp.array(expected_angles)

computed_angles = _qsp_to_qsvt(initial_angles)
jnp.allclose(computed_angles, expected_angles)
Loading