Hello,

When I tried setting `static_argnums=(0,)`, the error is now:

I feel like I'm missing something. Any ideas?

Here is a Google Colab notebook with a working example (using a Python for-loop) and the error I get when I try `fori_loop`:
The issue is that the third argument to `jnp.eye` must be static, and the arguments produced by `fori_loop` are by design non-static (i.e. traced). I would suggest finding another way to express your computation, for example via broadcasted indices. This produces the same output:

```python
import jax.numpy as jnp

i = jnp.arange(10)[:, None]  # row indices, shape (10, 1)
j = jnp.arange(5)            # diagonal offsets, shape (5,)
x = jnp.ones((10, 10))
# i + j broadcasts to shape (10, 5): row r is updated at columns r..r+4.
# Column indices >= 10 are out of bounds, and JAX drops those updates.
x = x.at[i, i + j].set(2)
```

It is compatible with `jit`, and, unlike `fori_loop`, it will perform well on accelerators.
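A minimal sketch comparing the approaches, assuming the notebook (not shown here) built a banded matrix by calling `jnp.eye` with a loop-dependent offset `k`. The function names `banded_loop`, `banded_scatter` and `banded_fori` and the sizes `N`, `K` are hypothetical. The Python loop works, even under `jax.jit` with `static_argnums`, because `k` stays a plain int. The `fori_loop` variant keeps a loop but replaces `jnp.eye` with a mask computed from the traced `k`.

```python
import jax
import jax.numpy as jnp

N, K = 10, 5  # hypothetical: matrix size and number of diagonals


def banded_loop(n, num_diags):
    """Python for-loop: each k passed to jnp.eye is a concrete int, so it is static."""
    x = jnp.ones((n, n))
    for k in range(num_diags):
        x = jnp.where(jnp.eye(n, n, k, dtype=bool), 2.0, x)
    return x


# Marking both arguments static lets jit unroll the Python loop at trace time.
banded_loop_jit = jax.jit(banded_loop, static_argnums=(0, 1))


@jax.jit
def banded_scatter():
    """Broadcasted-index version from the answer: no loop at all."""
    i = jnp.arange(N)[:, None]
    j = jnp.arange(K)
    return jnp.ones((N, N)).at[i, i + j].set(2)


@jax.jit
def banded_fori():
    """fori_loop version: k is traced, so avoid jnp.eye and build the mask with arithmetic."""
    rows = jnp.arange(N)[:, None]
    cols = jnp.arange(N)[None, :]

    def body(k, x):
        # (cols - rows) == k is True exactly on the k-th diagonal, and works for a traced k.
        return jnp.where(cols - rows == k, 2.0, x)

    return jax.lax.fori_loop(0, K, body, jnp.ones((N, N)))
```

All three give the same matrix. The scatter version is closest to what the answer recommends, since it expresses the whole update as a single vectorized operation instead of a sequential loop.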