Skip to content

Commit

Permalink
Using pytest fixture (hoping this is thread-safe)
Browse files Browse the repository at this point in the history
  • Loading branch information
PietropaoloFrisoni committed Sep 18, 2024
1 parent 37b37c2 commit 0527c11
Showing 1 changed file with 61 additions and 0 deletions.
61 changes: 61 additions & 0 deletions tests/interfaces/test_jax_jit_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,17 @@
H_FOR_SPSA = 0.05


@pytest.fixture
def manage_jax_precision():
"""Fixture to manage JAX precision for tests that require single precision mode."""

original_value = jax.config.read("jax_enable_x64")
jax.config.update("jax_enable_x64", False)
yield

jax.config.update("jax_enable_x64", original_value)


@pytest.mark.parametrize(
"interface,dev,diff_method,grad_on_execution,device_vjp",
interface_and_qubit_device_and_diff_method,
Expand Down Expand Up @@ -3034,3 +3045,53 @@ def circuit(a, b):
else:
assert np.allclose(jac[0], expected[0], atol=tol)
assert np.allclose(jac[1], expected[1], atol=tol)


@pytest.mark.usefixtures("manage_jax_precision")
class TestSinglePrecision:
"""Tests for compatibility with single precision mode."""

# pylint: disable=import-outside-toplevel
def test_type_conversion_fallback(self):
"""Test that if the type isn't int, float, or complex, we still have a fallback."""
from pennylane.workflow.interfaces.jax_jit import _jax_dtype

assert _jax_dtype(bool) == jax.numpy.dtype(bool)

@pytest.mark.parametrize("diff_method", ("adjoint", "parameter-shift"))
def test_float32_return(self, diff_method):
"""Test that jax jit works when float64 mode is disabled."""

@jax.jit
@qml.qnode(qml.device("default.qubit"), diff_method=diff_method)
def circuit(x):
qml.RX(x, wires=0)
return qml.expval(qml.PauliZ(0))

grad = jax.grad(circuit)(jax.numpy.array(0.1))
assert qml.math.allclose(grad, -np.sin(0.1))

@pytest.mark.parametrize("diff_method", ("adjoint", "finite-diff"))
def test_complex64_return(self, diff_method):
"""Test that jax jit works with differentiating the state."""
tol = 2e-2 if diff_method == "finite-diff" else 1e-6

@jax.jit
@qml.qnode(qml.device("default.qubit", wires=1), diff_method=diff_method)
def circuit(x):
qml.RX(x, wires=0)
return qml.state()

j = jax.jacobian(circuit, holomorphic=True)(jax.numpy.array(0.1 + 0j))
assert qml.math.allclose(j, [-np.sin(0.05) / 2, -np.cos(0.05) / 2 * 1j], atol=tol)

def test_int32_return(self):
"""Test that jax jit forward execution works with samples and int32"""

@jax.jit
@qml.qnode(qml.device("default.qubit", shots=10), diff_method=qml.gradients.param_shift)
def circuit(x):
qml.RX(x, wires=0)
return qml.sample(wires=0)

_ = circuit(jax.numpy.array(0.1))

0 comments on commit 0527c11

Please sign in to comment.