Skip to content

Commit

Permalink
PennyLane is compatible with JAX 0.4.28 (#6255)
Browse files Browse the repository at this point in the history
**Context:** As part of the effort to make PL compatible with Numpy 2.0
(see #6061), we need to upgrade JAX to 0.4.26+ since such a version
introduced the support for Numpy 2.0. We opted for JAX 0.4.28 since it
is the same version used by Catalyst.

**Description of the Change:** As above.

**Benefits:** PL is compatible with Numpy 2.0 and Jax 0.4.28.

**Possible Drawbacks:** 

- From JAX 0.4.27, in `jax.jit`, passing invalid static_argnums or
static_argnames now leads to an error rather than a warning. In PL, this
breaks every test where we set `shots` in the `QNode` call with
`static_argnames=["shots"]`. At this stage, we decided to mark such
tests with `pytest.xfail` to allow the upgrade.

**Related GitHub Issues:** None.

**Related Shortcut Stories**: [sc-61389]

---------

Co-authored-by: dwierichs <david.wierichs@xanadu.ai>
  • Loading branch information
2 people authored and mudit2812 committed Sep 18, 2024
1 parent f5d1807 commit 02f7efa
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 6 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/install_deps/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ inputs:
jax_version:
description: The version of JAX to install for any job that requires JAX
required: false
default: '0.4.23'
default: '0.4.28'
install_tensorflow:
description: Indicate if TensorFlow should be installed or not
required: false
Expand Down Expand Up @@ -86,7 +86,7 @@ runs:
if: inputs.install_jax == 'true'
env:
JAX_VERSION: ${{ inputs.jax_version != '' && format('=={0}', inputs.jax_version) || '' }}
run: pip install "jax${{ env.JAX_VERSION}}" "jaxlib${{ env.JAX_VERSION }}" scipy~=1.12.0
run: pip install "jax${{ env.JAX_VERSION}}" "jaxlib${{ env.JAX_VERSION }}"

- name: Install additional PIP packages
shell: bash
Expand Down
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
[(#6061)](https://github.com/PennyLaneAI/pennylane/pull/6061)
[(#6258)](https://github.com/PennyLaneAI/pennylane/pull/6258)

* PennyLane is now compatible with Jax 0.4.28.
[(#6255)](https://github.com/PennyLaneAI/pennylane/pull/6255)

* `qml.qchem.excitations` now optionally returns fermionic operators.
[(#6171)](https://github.com/PennyLaneAI/pennylane/pull/6171)

Expand Down
2 changes: 2 additions & 0 deletions tests/devices/default_qubit/test_default_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1864,6 +1864,7 @@ def circ_expected():
if use_jit:
import jax

pytest.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
circ_postselect = jax.jit(circ_postselect, static_argnames=["shots"])

res = circ_postselect(param, shots=shots)
Expand Down Expand Up @@ -2051,6 +2052,7 @@ def circ():
if use_jit:
import jax

pytest.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
circ = jax.jit(circ, static_argnames=["shots"])

res = circ(shots=shots)
Expand Down
2 changes: 1 addition & 1 deletion tests/devices/qutrit_mixed/test_qutrit_mixed_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ def test_jax_backprop(self, use_jit):

x = jax.numpy.array(self.x, dtype=jax.numpy.float64)
coeffs = (5.2, 6.7)
f = jax.jit(self.f, static_argnums=(1, 2, 3, 4)) if use_jit else self.f
f = jax.jit(self.f, static_argnums=(1, 2, 3)) if use_jit else self.f

out = f(x, coeffs)
expected_out = self.expected(x, coeffs)
Expand Down
2 changes: 1 addition & 1 deletion tests/devices/test_default_qutrit_mixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,7 +823,7 @@ def test_jax_backprop(self, use_jit):

x = jax.numpy.array(self.x, dtype=jax.numpy.float64)
coeffs = (5.2, 6.7)
f = jax.jit(self.f, static_argnums=(1, 2, 3, 4)) if use_jit else self.f
f = jax.jit(self.f, static_argnums=(1, 2, 3)) if use_jit else self.f

out = f(x, coeffs)
expected_out = self.expected(x, coeffs)
Expand Down
17 changes: 17 additions & 0 deletions tests/interfaces/test_jax_jit_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,6 +813,7 @@ def circuit(a, b):
res = circuit(a, b, shots=100) # pylint: disable=unexpected-keyword-arg
assert res.shape == (100, 2) # pylint:disable=comparison-with-callable

@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
def test_gradient_integration(self, interface):
"""Test that temporarily setting the shots works
for gradient computations"""
Expand Down Expand Up @@ -912,6 +913,7 @@ def circuit(x):
class TestQubitIntegration:
"""Tests that ensure various qubit circuits integrate correctly"""

@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
def test_sampling(self, dev, diff_method, grad_on_execution, device_vjp, interface):
"""Test sampling works as expected"""
if grad_on_execution:
Expand Down Expand Up @@ -941,6 +943,7 @@ def circuit():
assert isinstance(res[1], jax.Array)
assert res[1].shape == (10,)

@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
def test_counts(self, dev, diff_method, grad_on_execution, device_vjp, interface):
"""Test counts works as expected"""
if grad_on_execution:
Expand Down Expand Up @@ -2041,6 +2044,7 @@ def circ(p, U):
class TestReturn:
"""Class to test the shape of the Grad/Jacobian with different return types."""

@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
def test_grad_single_measurement_param(
self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface
):
Expand Down Expand Up @@ -2073,6 +2077,7 @@ def circuit(a):
assert isinstance(grad, jax.numpy.ndarray)
assert grad.shape == ()

@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
def test_grad_single_measurement_multiple_param(
self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface
):
Expand Down Expand Up @@ -2110,6 +2115,7 @@ def circuit(a, b):
assert grad[0].shape == ()
assert grad[1].shape == ()

@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
def test_grad_single_measurement_multiple_param_array(
self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface
):
Expand Down Expand Up @@ -2142,6 +2148,7 @@ def circuit(a):
assert isinstance(grad, jax.numpy.ndarray)
assert grad.shape == (2,)

@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
def test_jacobian_single_measurement_param_probs(
self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface
):
Expand Down Expand Up @@ -2175,6 +2182,7 @@ def circuit(a):
assert isinstance(jac, jax.numpy.ndarray)
assert jac.shape == (4,)

@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
def test_jacobian_single_measurement_probs_multiple_param(
self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface
):
Expand Down Expand Up @@ -2214,6 +2222,7 @@ def circuit(a, b):
assert isinstance(jac[1], jax.numpy.ndarray)
assert jac[1].shape == (4,)

@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
def test_jacobian_single_measurement_probs_multiple_param_single_array(
self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface
):
Expand Down Expand Up @@ -2246,6 +2255,7 @@ def circuit(a):
assert isinstance(jac, jax.numpy.ndarray)
assert jac.shape == (4, 2)

@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
def test_jacobian_expval_expval_multiple_params(
self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface
):
Expand Down Expand Up @@ -2295,6 +2305,7 @@ def circuit(x, y):
assert isinstance(jac[1][1], jax.numpy.ndarray)
assert jac[1][1].shape == ()

@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
def test_jacobian_expval_expval_multiple_params_array(
self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface
):
Expand Down Expand Up @@ -2333,6 +2344,7 @@ def circuit(a):
assert isinstance(jac[1], jax.numpy.ndarray)
assert jac[1].shape == (2,)

@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
def test_jacobian_var_var_multiple_params(
self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface
):
Expand Down Expand Up @@ -2385,6 +2397,7 @@ def circuit(x, y):
assert isinstance(jac[1][1], jax.numpy.ndarray)
assert jac[1][1].shape == ()

@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
def test_jacobian_var_var_multiple_params_array(
self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface
):
Expand Down Expand Up @@ -2425,6 +2438,7 @@ def circuit(a):
assert isinstance(jac[1], jax.numpy.ndarray)
assert jac[1].shape == (2,)

@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
def test_jacobian_multiple_measurement_single_param(
self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface
):
Expand Down Expand Up @@ -2463,6 +2477,7 @@ def circuit(a):
assert isinstance(jac[1], jax.numpy.ndarray)
assert jac[1].shape == (4,)

@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
def test_jacobian_multiple_measurement_multiple_param(
self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface
):
Expand Down Expand Up @@ -2510,6 +2525,7 @@ def circuit(a, b):
assert isinstance(jac[1][1], jax.numpy.ndarray)
assert jac[1][1].shape == (4,)

@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
def test_jacobian_multiple_measurement_multiple_param_array(
self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface
):
Expand Down Expand Up @@ -2871,6 +2887,7 @@ def circuit(x):
assert hess[1].shape == (2, 2, 2)


@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
@pytest.mark.parametrize("hessian", hessian_fn)
@pytest.mark.parametrize("diff_method", ["parameter-shift", "hadamard"])
def test_jax_device_hessian_shots(hessian, diff_method):
Expand Down
13 changes: 11 additions & 2 deletions tests/transforms/test_optimization/test_optimization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,17 @@ def test_jacobian_jax(self, use_jit):
special_angles = np.array(list(product(special_points, repeat=6))).reshape((-1, 2, 3))
random_angles = np.random.random((1000, 2, 3))
# Need holomorphic derivatives and complex inputs because the output matrices are complex
all_angles = jax.numpy.concatenate([special_angles, random_angles], dtype=complex)
jac_fn = lambda fn: jax.vmap(jax.jacobian(fn, holomorphic=True))
all_angles = jax.numpy.concatenate([special_angles, random_angles])

# We need to define the Jacobian function manually because fuse_rot_angles is not guaranteed to be holomorphic,
# and jax.jacobian requires real-valued outputs for non-holomorphic functions.
def jac_fn(fn):
real_fn = lambda arg: qml.math.real(fn(arg))
imag_fn = lambda arg: qml.math.imag(fn(arg))
real_jac_fn = jax.vmap(jax.jacobian(real_fn))
imag_jac_fn = jax.vmap(jax.jacobian(imag_fn))
return lambda arg: real_jac_fn(arg) + 1j * imag_jac_fn(arg)

jit_fn = jax.jit if use_jit else None
self.run_jacobian_test(all_angles, jac_fn, is_batched=True, jit_fn=jit_fn)

Expand Down

0 comments on commit 02f7efa

Please sign in to comment.