Skip to content

Commit

Permalink
[host_callback] Fix type promotion error
Browse files Browse the repository at this point in the history
Fix a type error that arises when we try to run the host callback tests with JAX_HOST_CALLBACK_LEGACY=False (in the process of deprecating jax.experimental.host_callback).

PiperOrigin-RevId: 671825020
  • Loading branch information
gnecula authored and jax authors committed Sep 6, 2024
1 parent 878b6b5 commit fc6b22e
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/host_callback_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2240,7 +2240,7 @@ def f_outside(arg):
def test_call_cond(self):
def f_outside(args):
x, y = args
return x * y
return x * y.astype(np.float32)

def loop(x, use_outside=True):
def body(i, acc):
Expand All @@ -2253,8 +2253,8 @@ def body(i, acc):

return lax.fori_loop(0, 18, body, x)

res_inside = loop(1.2, use_outside=False)
self.assertAllClose(res_inside, jax.jit(loop)(1.2))
res_inside = loop(np.float32(1.2), use_outside=False)
self.assertAllClose(res_inside, jax.jit(loop)(np.float32(1.2)))

def test_call_jit_scan_call(self):
def f_outside(x):
Expand Down

0 comments on commit fc6b22e

Please sign in to comment.