Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PennyLane is compatible with JAX 0.4.28 #6255

Merged
merged 76 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
76 commits
Select commit Hold shift + click to select a range
c471f5b
bring changes from prep_for_np2 branch
EmilianoG-byte Aug 1, 2024
c3e011e
change from rc to released version
EmilianoG-byte Aug 2, 2024
6ffb5e5
Merge branch 'master' into compatible-np-2.0
EmilianoG-byte Aug 2, 2024
c8b7d91
add scipy<= 1.13 on ci requirements to become np 2.0 compatible
EmilianoG-byte Aug 6, 2024
e54eeef
pin autograd to major 1.7.0 version
EmilianoG-byte Aug 27, 2024
0c0d5a6
Merge branch 'master' into compatible-np-2.0
EmilianoG-byte Aug 27, 2024
b3ab335
set numpy print options to legacy for scalars and change np.NaN to np…
EmilianoG-byte Aug 27, 2024
ef47609
Revert "set numpy print options to legacy for scalars and change np.N…
EmilianoG-byte Aug 27, 2024
9a88660
change NaN for nan
EmilianoG-byte Aug 27, 2024
f523d32
increase version of scipy used with jax
EmilianoG-byte Aug 27, 2024
236b3f5
change to use context manager for legacy and check against older vers…
EmilianoG-byte Aug 27, 2024
d94d30c
add legacy print option context manager to torch and prepselprep tests
EmilianoG-byte Aug 27, 2024
2d8db14
add legacy printing solution to data-tests
EmilianoG-byte Aug 27, 2024
f9322a7
change regex for probabilities not adding 1 from numpy
EmilianoG-byte Aug 27, 2024
50711df
use legacy context manager on default qubit legacy
EmilianoG-byte Aug 27, 2024
e0bcd0a
Merge branch 'master' into compatible-np-2.0
EmilianoG-byte Aug 28, 2024
1793a03
check numpy version to check dtype in test_apply_global_phase
EmilianoG-byte Aug 28, 2024
d55a614
move context manager to only affect asserts related to representation
EmilianoG-byte Aug 29, 2024
e728c25
ping numpy in requirements
EmilianoG-byte Aug 29, 2024
f1ff64e
Resolving CodeFactor issues
PietropaoloFrisoni Sep 6, 2024
fdca824
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Sep 6, 2024
0f04e84
Using isort:skip to resolve conflict (between CodeFactor and isort) a…
PietropaoloFrisoni Sep 6, 2024
32815d4
Replacing isort:skip with isort:skip_file
PietropaoloFrisoni Sep 6, 2024
c9c148b
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Sep 6, 2024
b229b92
Unpinning numpy in requirements-ci.txt
PietropaoloFrisoni Sep 6, 2024
18702ff
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Sep 9, 2024
61a2d42
Running torch tests with numpy 2.0 and without it
PietropaoloFrisoni Sep 9, 2024
68fbdf7
Re-pinning numpy in reqs-CI and running torch tests with numpy 2 as t…
PietropaoloFrisoni Sep 9, 2024
928fdb5
Remove pin from scipy and numpy. Running torch tests only with numpy-…
PietropaoloFrisoni Sep 9, 2024
e4abd05
Running numpy-2 tests for autograd
PietropaoloFrisoni Sep 9, 2024
0b34222
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Sep 9, 2024
6f77e21
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Sep 9, 2024
61d09d2
Correct typo in `actionyml` (numpy 2.0 was not installed as originall…
PietropaoloFrisoni Sep 10, 2024
3820765
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Sep 10, 2024
a0d37e4
Testing `torch` with 2 versions (fake test) with separate artifacts
PietropaoloFrisoni Sep 10, 2024
6a9c708
Weak pin on scipy and removed pin from autograd
PietropaoloFrisoni Sep 11, 2024
3ffcc33
Weak pin to numpy 1.x in requirements-CI (numpy 2.0 is installed sepa…
PietropaoloFrisoni Sep 11, 2024
6f5c228
Extending double numpy tests to all interfaces
PietropaoloFrisoni Sep 11, 2024
515a910
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Sep 11, 2024
0ea7f12
E.C.
PietropaoloFrisoni Sep 11, 2024
30115a6
Upgrade to `jax 0.4.26`
PietropaoloFrisoni Sep 11, 2024
50e3021
Update torch in GPU tests to 2.3.0 and increasing tolerance for faili…
PietropaoloFrisoni Sep 11, 2024
3fda59e
Removing torch upgrade and increasing tolerance even more in test
PietropaoloFrisoni Sep 11, 2024
9d0178e
Bumping to 0.4.28 (expecting several failures)
PietropaoloFrisoni Sep 12, 2024
3d06af0
Removing tests for Numpy 2.0 in some jobs (those with tensorflow and …
PietropaoloFrisoni Sep 12, 2024
3086d98
Merge branch 'master' into compatible-np-2.0
PietropaoloFrisoni Sep 12, 2024
6ce96cd
Removing empty spaces (causing issues with CI)
PietropaoloFrisoni Sep 12, 2024
0f6ec70
changing number of runners
PietropaoloFrisoni Sep 12, 2024
0649eb5
Merge branch 'compatible-np-2.0' of https://github.com/PennyLaneAI/pe…
PietropaoloFrisoni Sep 12, 2024
55cf72e
Fixing some tests
PietropaoloFrisoni Sep 12, 2024
eca800f
Running full test CI suite
PietropaoloFrisoni Sep 12, 2024
815bd44
Using pytest.xfail instead of pytest.mark.xfail
PietropaoloFrisoni Sep 12, 2024
8bdd625
Increasing tolerance in test
PietropaoloFrisoni Sep 12, 2024
81eb977
Marking tests with xfail
PietropaoloFrisoni Sep 12, 2024
0905f7a
Forgot one xfail
PietropaoloFrisoni Sep 13, 2024
331ad27
Increasing numberical tolerance in test
PietropaoloFrisoni Sep 13, 2024
77ee483
Increasing tolerance
PietropaoloFrisoni Sep 13, 2024
c6b3ed0
Testing with version 0.4.31
PietropaoloFrisoni Sep 13, 2024
26735e9
changing back to version 0.4.28 [ci skip]
PietropaoloFrisoni Sep 13, 2024
9e48ecc
Increasing core runners
PietropaoloFrisoni Sep 13, 2024
7964bc9
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Sep 13, 2024
8b09bc3
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Sep 15, 2024
3604c45
Removing double Numpy tests
PietropaoloFrisoni Sep 15, 2024
3f9567e
Updating changelog
PietropaoloFrisoni Sep 15, 2024
1b27526
fix JAX test of `fuse_rot_angles` in `tests/transforms/test_optimizat…
dwierichs Sep 16, 2024
5aa49f8
Removing pin in `requirements-ci.txt` (not useful since the pin <=2.0…
PietropaoloFrisoni Sep 16, 2024
591671d
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Sep 16, 2024
c481f4a
Merge branch 'compatible-np-2.0' of https://github.com/PennyLaneAI/pe…
PietropaoloFrisoni Sep 16, 2024
efff146
Removing double NumPy testing
PietropaoloFrisoni Sep 16, 2024
6312af6
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Sep 16, 2024
0b480d9
Updating changelog
PietropaoloFrisoni Sep 16, 2024
090504f
[ci skip]
PietropaoloFrisoni Sep 16, 2024
a6e7e13
Triggering CI
PietropaoloFrisoni Sep 16, 2024
1e6eddc
Added comment with clarification
PietropaoloFrisoni Sep 16, 2024
bd44fe6
Updating changelog
PietropaoloFrisoni Sep 17, 2024
5eed998
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Sep 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/install_deps/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ inputs:
jax_version:
description: The version of JAX to install for any job that requires JAX
required: false
default: '0.4.23'
default: '0.4.28'
install_tensorflow:
description: Indicate if TensorFlow should be installed or not
required: false
Expand Down Expand Up @@ -86,7 +86,7 @@ runs:
if: inputs.install_jax == 'true'
env:
JAX_VERSION: ${{ inputs.jax_version != '' && format('=={0}', inputs.jax_version) || '' }}
run: pip install "jax${{ env.JAX_VERSION}}" "jaxlib${{ env.JAX_VERSION }}" scipy~=1.12.0
run: pip install "jax${{ env.JAX_VERSION}}" "jaxlib${{ env.JAX_VERSION }}"

- name: Install additional PIP packages
shell: bash
Expand Down
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
[(#6061)](https://github.com/PennyLaneAI/pennylane/pull/6061)
[(#6258)](https://github.com/PennyLaneAI/pennylane/pull/6258)

* PennyLane is now compatible with Jax 0.4.28.
[(#6255)](https://github.com/PennyLaneAI/pennylane/pull/6255)
dwierichs marked this conversation as resolved.
Show resolved Hide resolved

* `qml.qchem.excitations` now optionally returns fermionic operators.
[(#6171)](https://github.com/PennyLaneAI/pennylane/pull/6171)

Expand Down
2 changes: 2 additions & 0 deletions tests/devices/default_qubit/test_default_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1864,6 +1864,7 @@ def circ_expected():
if use_jit:
import jax

pytest.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
circ_postselect = jax.jit(circ_postselect, static_argnames=["shots"])

res = circ_postselect(param, shots=shots)
Expand Down Expand Up @@ -2051,6 +2052,7 @@ def circ():
if use_jit:
import jax

pytest.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
circ = jax.jit(circ, static_argnames=["shots"])

res = circ(shots=shots)
Expand Down
2 changes: 1 addition & 1 deletion tests/devices/qutrit_mixed/test_qutrit_mixed_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ def test_jax_backprop(self, use_jit):

x = jax.numpy.array(self.x, dtype=jax.numpy.float64)
coeffs = (5.2, 6.7)
f = jax.jit(self.f, static_argnums=(1, 2, 3, 4)) if use_jit else self.f
f = jax.jit(self.f, static_argnums=(1, 2, 3)) if use_jit else self.f

out = f(x, coeffs)
expected_out = self.expected(x, coeffs)
Expand Down
2 changes: 1 addition & 1 deletion tests/devices/test_default_qutrit_mixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,7 +823,7 @@ def test_jax_backprop(self, use_jit):

x = jax.numpy.array(self.x, dtype=jax.numpy.float64)
coeffs = (5.2, 6.7)
f = jax.jit(self.f, static_argnums=(1, 2, 3, 4)) if use_jit else self.f
f = jax.jit(self.f, static_argnums=(1, 2, 3)) if use_jit else self.f

out = f(x, coeffs)
expected_out = self.expected(x, coeffs)
Expand Down
17 changes: 17 additions & 0 deletions tests/interfaces/test_jax_jit_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,6 +813,7 @@ def circuit(a, b):
res = circuit(a, b, shots=100) # pylint: disable=unexpected-keyword-arg
assert res.shape == (100, 2) # pylint:disable=comparison-with-callable

@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
def test_gradient_integration(self, interface):
"""Test that temporarily setting the shots works
for gradient computations"""
Expand Down Expand Up @@ -912,6 +913,7 @@ def circuit(x):
class TestQubitIntegration:
"""Tests that ensure various qubit circuits integrate correctly"""

@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
def test_sampling(self, dev, diff_method, grad_on_execution, device_vjp, interface):
"""Test sampling works as expected"""
if grad_on_execution:
Expand Down Expand Up @@ -941,6 +943,7 @@ def circuit():
assert isinstance(res[1], jax.Array)
assert res[1].shape == (10,)

@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
def test_counts(self, dev, diff_method, grad_on_execution, device_vjp, interface):
"""Test counts works as expected"""
if grad_on_execution:
Expand Down Expand Up @@ -2041,6 +2044,7 @@ def circ(p, U):
class TestReturn:
"""Class to test the shape of the Grad/Jacobian with different return types."""

@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
def test_grad_single_measurement_param(
self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface
):
Expand Down Expand Up @@ -2073,6 +2077,7 @@ def circuit(a):
assert isinstance(grad, jax.numpy.ndarray)
assert grad.shape == ()

@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
def test_grad_single_measurement_multiple_param(
self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface
):
Expand Down Expand Up @@ -2110,6 +2115,7 @@ def circuit(a, b):
assert grad[0].shape == ()
assert grad[1].shape == ()

@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
def test_grad_single_measurement_multiple_param_array(
self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface
):
Expand Down Expand Up @@ -2142,6 +2148,7 @@ def circuit(a):
assert isinstance(grad, jax.numpy.ndarray)
assert grad.shape == (2,)

@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
def test_jacobian_single_measurement_param_probs(
self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface
):
Expand Down Expand Up @@ -2175,6 +2182,7 @@ def circuit(a):
assert isinstance(jac, jax.numpy.ndarray)
assert jac.shape == (4,)

@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
def test_jacobian_single_measurement_probs_multiple_param(
self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface
):
Expand Down Expand Up @@ -2214,6 +2222,7 @@ def circuit(a, b):
assert isinstance(jac[1], jax.numpy.ndarray)
assert jac[1].shape == (4,)

@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
def test_jacobian_single_measurement_probs_multiple_param_single_array(
self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface
):
Expand Down Expand Up @@ -2246,6 +2255,7 @@ def circuit(a):
assert isinstance(jac, jax.numpy.ndarray)
assert jac.shape == (4, 2)

@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
def test_jacobian_expval_expval_multiple_params(
self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface
):
Expand Down Expand Up @@ -2295,6 +2305,7 @@ def circuit(x, y):
assert isinstance(jac[1][1], jax.numpy.ndarray)
assert jac[1][1].shape == ()

@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
def test_jacobian_expval_expval_multiple_params_array(
self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface
):
Expand Down Expand Up @@ -2333,6 +2344,7 @@ def circuit(a):
assert isinstance(jac[1], jax.numpy.ndarray)
assert jac[1].shape == (2,)

@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
def test_jacobian_var_var_multiple_params(
self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface
):
Expand Down Expand Up @@ -2385,6 +2397,7 @@ def circuit(x, y):
assert isinstance(jac[1][1], jax.numpy.ndarray)
assert jac[1][1].shape == ()

@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
def test_jacobian_var_var_multiple_params_array(
self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface
):
Expand Down Expand Up @@ -2425,6 +2438,7 @@ def circuit(a):
assert isinstance(jac[1], jax.numpy.ndarray)
assert jac[1].shape == (2,)

@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
def test_jacobian_multiple_measurement_single_param(
self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface
):
Expand Down Expand Up @@ -2463,6 +2477,7 @@ def circuit(a):
assert isinstance(jac[1], jax.numpy.ndarray)
assert jac[1].shape == (4,)

@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
def test_jacobian_multiple_measurement_multiple_param(
self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface
):
Expand Down Expand Up @@ -2510,6 +2525,7 @@ def circuit(a, b):
assert isinstance(jac[1][1], jax.numpy.ndarray)
assert jac[1][1].shape == (4,)

@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
def test_jacobian_multiple_measurement_multiple_param_array(
self, dev, diff_method, grad_on_execution, device_vjp, jacobian, shots, interface
):
Expand Down Expand Up @@ -2871,6 +2887,7 @@ def circuit(x):
assert hess[1].shape == (2, 2, 2)


@pytest.mark.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
@pytest.mark.parametrize("hessian", hessian_fn)
@pytest.mark.parametrize("diff_method", ["parameter-shift", "hadamard"])
def test_jax_device_hessian_shots(hessian, diff_method):
Expand Down
13 changes: 11 additions & 2 deletions tests/transforms/test_optimization/test_optimization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,17 @@ def test_jacobian_jax(self, use_jit):
special_angles = np.array(list(product(special_points, repeat=6))).reshape((-1, 2, 3))
random_angles = np.random.random((1000, 2, 3))
# Need holomorphic derivatives and complex inputs because the output matrices are complex
all_angles = jax.numpy.concatenate([special_angles, random_angles], dtype=complex)
jac_fn = lambda fn: jax.vmap(jax.jacobian(fn, holomorphic=True))
all_angles = jax.numpy.concatenate([special_angles, random_angles])

# We need to define the Jacobian function manually because fuse_rot_angles is not guaranteed to be holomorphic,
# and jax.jacobian requires real-valued outputs for non-holomorphic functions.
def jac_fn(fn):
real_fn = lambda arg: qml.math.real(fn(arg))
imag_fn = lambda arg: qml.math.imag(fn(arg))
real_jac_fn = jax.vmap(jax.jacobian(real_fn))
imag_jac_fn = jax.vmap(jax.jacobian(imag_fn))
return lambda arg: real_jac_fn(arg) + 1j * imag_jac_fn(arg)

jit_fn = jax.jit if use_jit else None
self.run_jacobian_test(all_angles, jac_fn, is_batched=True, jit_fn=jit_fn)

Expand Down
Loading