Skip to content

Commit

Permalink
Experiment (removing tests for single precision in file to see if it …
Browse files Browse the repository at this point in the history
…still fails)
  • Loading branch information
PietropaoloFrisoni committed Sep 18, 2024
1 parent 02fb066 commit 37b37c2
Showing 1 changed file with 0 additions and 69 deletions.
69 changes: 0 additions & 69 deletions tests/interfaces/test_jax_jit_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -3034,72 +3034,3 @@ def circuit(a, b):
else:
assert np.allclose(jac[0], expected[0], atol=tol)
assert np.allclose(jac[1], expected[1], atol=tol)


class TestSinglePrecision:
"""Tests for compatibility with single precision mode."""

# pylint: disable=import-outside-toplevel
def test_type_conversion_fallback(self):
"""Test that if the type isn't int, float, or complex, we still have a fallback."""
from pennylane.workflow.interfaces.jax_jit import _jax_dtype

assert _jax_dtype(bool) == jax.numpy.dtype(bool)

@pytest.mark.parametrize("diff_method", ("adjoint", "parameter-shift"))
def test_float32_return(self, diff_method):
"""Test that jax jit works when float64 mode is disabled."""
jax.config.update("jax_enable_x64", False)

try:

@jax.jit
@qml.qnode(qml.device("default.qubit"), diff_method=diff_method)
def circuit(x):
qml.RX(x, wires=0)
return qml.expval(qml.PauliZ(0))

grad = jax.grad(circuit)(jax.numpy.array(0.1))
assert qml.math.allclose(grad, -np.sin(0.1))
finally:
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

@pytest.mark.parametrize("diff_method", ("adjoint", "finite-diff"))
def test_complex64_return(self, diff_method):
"""Test that jax jit works with differentiating the state."""
jax.config.update("jax_enable_x64", False)

try:
tol = 2e-2 if diff_method == "finite-diff" else 1e-6

@jax.jit
@qml.qnode(qml.device("default.qubit", wires=1), diff_method=diff_method)
def circuit(x):
qml.RX(x, wires=0)
return qml.state()

j = jax.jacobian(circuit, holomorphic=True)(jax.numpy.array(0.1 + 0j))
assert qml.math.allclose(j, [-np.sin(0.05) / 2, -np.cos(0.05) / 2 * 1j], atol=tol)

finally:
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

def test_int32_return(self):
"""Test that jax jit forward execution works with samples and int32"""

jax.config.update("jax_enable_x64", False)

try:

@jax.jit
@qml.qnode(qml.device("default.qubit", shots=10), diff_method=qml.gradients.param_shift)
def circuit(x):
qml.RX(x, wires=0)
return qml.sample(wires=0)

_ = circuit(jax.numpy.array(0.1))
finally:
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

0 comments on commit 37b37c2

Please sign in to comment.