From 0527c1144c9f724016205247560124174184f0ed Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Wed, 18 Sep 2024 12:48:28 -0400 Subject: [PATCH] Using pytest fixture (hoping this is thread-safe) --- tests/interfaces/test_jax_jit_qnode.py | 61 ++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/tests/interfaces/test_jax_jit_qnode.py b/tests/interfaces/test_jax_jit_qnode.py index d0712eb96de..92e60f32111 100644 --- a/tests/interfaces/test_jax_jit_qnode.py +++ b/tests/interfaces/test_jax_jit_qnode.py @@ -57,6 +57,17 @@ H_FOR_SPSA = 0.05 +@pytest.fixture +def manage_jax_precision(): + """Fixture to manage JAX precision for tests that require single precision mode.""" + + original_value = jax.config.read("jax_enable_x64") + jax.config.update("jax_enable_x64", False) + yield + + jax.config.update("jax_enable_x64", original_value) + + @pytest.mark.parametrize( "interface,dev,diff_method,grad_on_execution,device_vjp", interface_and_qubit_device_and_diff_method, @@ -3034,3 +3045,53 @@ def circuit(a, b): else: assert np.allclose(jac[0], expected[0], atol=tol) assert np.allclose(jac[1], expected[1], atol=tol) + + +@pytest.mark.usefixtures("manage_jax_precision") +class TestSinglePrecision: + """Tests for compatibility with single precision mode.""" + + # pylint: disable=import-outside-toplevel + def test_type_conversion_fallback(self): + """Test that if the type isn't int, float, or complex, we still have a fallback.""" + from pennylane.workflow.interfaces.jax_jit import _jax_dtype + + assert _jax_dtype(bool) == jax.numpy.dtype(bool) + + @pytest.mark.parametrize("diff_method", ("adjoint", "parameter-shift")) + def test_float32_return(self, diff_method): + """Test that jax jit works when float64 mode is disabled.""" + + @jax.jit + @qml.qnode(qml.device("default.qubit"), diff_method=diff_method) + def circuit(x): + qml.RX(x, wires=0) + return qml.expval(qml.PauliZ(0)) + + grad = jax.grad(circuit)(jax.numpy.array(0.1)) + assert qml.math.allclose(grad, -np.sin(0.1)) + + @pytest.mark.parametrize("diff_method", ("adjoint", "finite-diff")) + def test_complex64_return(self, diff_method): + """Test that jax jit works with differentiating the state.""" + tol = 2e-2 if diff_method == "finite-diff" else 1e-6 + + @jax.jit + @qml.qnode(qml.device("default.qubit", wires=1), diff_method=diff_method) + def circuit(x): + qml.RX(x, wires=0) + return qml.state() + + j = jax.jacobian(circuit, holomorphic=True)(jax.numpy.array(0.1 + 0j)) + assert qml.math.allclose(j, [-np.sin(0.05) / 2, -np.cos(0.05) / 2 * 1j], atol=tol) + + def test_int32_return(self): + """Test that jax jit forward execution works with samples and int32""" + + @jax.jit + @qml.qnode(qml.device("default.qubit", shots=10), diff_method=qml.gradients.param_shift) + def circuit(x): + qml.RX(x, wires=0) + return qml.sample(wires=0) + + _ = circuit(jax.numpy.array(0.1))