
Error: Non-hashable static arguments for fori_loop #5859

Answered by jakevdp
sokrypton asked this question in Q&A

The issue is that the third argument to jnp.eye (the diagonal offset k) must be static, while the loop index that fori_loop passes to its body is, by design, non-static (i.e. traced). I would suggest expressing your computation another way, for example with broadcasted indices. This produces the same output:

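```python
import jax.numpy as jnp

i = jnp.arange(10)[:, None]  # row indices, shape (10, 1)
j = jnp.arange(5)            # diagonal offsets 0..4, shape (5,)
x = jnp.ones((10, 10))
# i + j broadcasts to shape (10, 5): row r gets columns r, r+1, ..., r+4.
# Column indices past 9 are out of bounds; JAX drops out-of-bounds updates in .at[].set().
x = x.at[i, i + j].set(2)
```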

It is compatible with jit and, unlike fori_loop, performs well on accelerators, because all the updates happen in a single vectorized scatter rather than in a sequential loop.
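A minimal sketch of both approaches, assuming the original goal (not shown in this excerpt) was to set the main diagonal and the next four superdiagonals of a 10×10 matrix of ones to 2, as the snippet above does. The function names band_scatter and band_fori are hypothetical. band_scatter is the broadcasted-index version wrapped in jax.jit, with mode="drop" passed explicitly so the out-of-bounds behavior is visible. band_fori is a swapped-in alternative, not from this thread: it keeps fori_loop but replaces jnp.eye with a comparison of index grids against the traced loop counter k, which needs no static offset. Both are checked against a NumPy reference in which k is an ordinary Python int.

```python
import numpy as np
import jax
import jax.numpy as jnp

N, BAND = 10, 5  # matrix size and number of diagonals, matching the example above


@jax.jit
def band_scatter(x):
    # Broadcasted indices: row r is updated at columns r, r+1, ..., r+BAND-1.
    i = jnp.arange(N)[:, None]  # shape (N, 1)
    j = jnp.arange(BAND)        # shape (BAND,)
    # Columns past N-1 are out of bounds; mode="drop" discards those updates.
    return x.at[i, i + j].set(2, mode="drop")


@jax.jit
def band_fori(x):
    rows = jnp.arange(N)[:, None]
    cols = jnp.arange(N)[None, :]

    def body(k, acc):
        # k is traced here, so jnp.eye(N, N, k) is avoided; comparing the
        # index grids selects the k-th diagonal without a static offset.
        return jnp.where(cols - rows == k, 2.0, acc)

    return jax.lax.fori_loop(0, BAND, body, x)


# Reference built with NumPy, where k is a plain Python int (static).
expected = np.ones((N, N))
for k in range(BAND):
    expected[np.eye(N, N, k, dtype=bool)] = 2

x = jnp.ones((N, N))
print(np.array_equal(np.asarray(band_scatter(x)), expected))
print(np.array_equal(np.asarray(band_fori(x)), expected))
```

The fori_loop variant still runs its iterations one after another, each touching the full matrix, so the single scatter in band_scatter is likely the better choice on accelerators.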

Answer selected by sokrypton