
Tricky NonConcreteBooleanIndexError #22639

Answered by jakevdp
harryjulian asked this question in Q&A

The problem is in this line:

indices = jnp.arange(0, N_CARD_TYPES, dtype=DTYPE)[agent.hand.under > 0]

Since agent.hand.under is traced, there is no way to know at trace time how many of its values are greater than zero. The result indices would therefore need a value-dependent (dynamic) size, and dynamically shaped arrays are not supported inside JAX transformations such as jit (see JAX sharp bits: Dynamic Shapes).
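A minimal sketch that reproduces the failure and shows one fixed-shape alternative. It assumes a hypothetical N_CARD_TYPES = 8 and a stand-in array `under` in place of agent.hand.under (the question's DTYPE is dropped). jnp.nonzero accepts a static `size` argument, so its output always has N_CARD_TYPES entries padded with `fill_value`. This keeps the shape static, at the cost of the caller having to handle the padding.

```python
import jax
import jax.numpy as jnp

N_CARD_TYPES = 8  # hypothetical value standing in for the question's constant

@jax.jit
def positive_indices_dynamic(under):
    # Boolean-mask indexing: the output length depends on the *values* of
    # `under`, which are unknown while tracing, so this fails under jit.
    return jnp.arange(N_CARD_TYPES)[under > 0]

@jax.jit
def positive_indices_padded(under):
    # With a static `size`, nonzero always returns N_CARD_TYPES indices;
    # unused slots are filled with -1, so the shape is known at trace time.
    (idx,) = jnp.nonzero(under > 0, size=N_CARD_TYPES, fill_value=-1)
    return idx

under = jnp.array([0, 2, 0, 1, 0, 0, 3, 0])  # stand-in for agent.hand.under

try:
    positive_indices_dynamic(under)
    dynamic_ok = True
except jax.errors.NonConcreteBooleanIndexError:
    dynamic_ok = False  # expected: boolean index built from a tracer

print(positive_indices_padded(under))
```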

You could probably work around this by using jax.random.randint instead of jax.random.choice and avoiding building the explicit indices array altogether.
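One possible reading of this suggestion, sketched under assumptions: the helper name `sample_positive_index` is hypothetical, and `under` again stands in for agent.hand.under. The idea is to draw a random rank r with jax.random.randint, whose maxval may be a traced value, and then locate the r-th positive entry with a cumulative sum. Every intermediate has a static shape, so no dynamically sized indices array is ever created.

```python
import jax
import jax.numpy as jnp

@jax.jit
def sample_positive_index(key, under):
    """Hypothetical helper: uniformly pick one index i with under[i] > 0.

    Returns (index, valid). When no entry is positive, valid is False and
    index is meaningless and should be ignored.
    """
    mask = (under > 0).astype(jnp.int32)
    count = mask.sum()  # traced scalar: how many entries are eligible
    # randint accepts a traced maxval: the bound's value may differ per call,
    # but the output shape () is static, so this traces fine under jit.
    r = jax.random.randint(key, (), 0, count)
    # The r-th (0-based) eligible entry is the first position where the
    # running count of eligible entries exceeds r; argmax on a boolean
    # array returns the position of the first True.
    idx = jnp.argmax(jnp.cumsum(mask) > r)
    return idx, count > 0

key = jax.random.PRNGKey(0)
under = jnp.array([0, 2, 0, 1, 0, 0, 3, 0])  # stand-in for agent.hand.under
idx, valid = sample_positive_index(key, under)
```

The validity flag is returned rather than raising, because a jitted function cannot branch on a traced value with ordinary Python control flow.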

Replies: 1 comment 5 replies

5 replies: @harryjulian, @jakevdp, @harryjulian, @jakevdp, @harryjulian

Answer selected by harryjulian

Category: Q&A
Labels: None yet
2 participants