
Possibly a bug in jax.numpy.ravel_multi_index #14733

Answered by jakevdp
PgLoLo asked this question in Q&A

The error looks like this:

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(int32[])>with<BatchTrace(level=1/0)> with
  val = Array([0, 1, 2], dtype=int32)
  batch_dim = 0
The error occurred because ravel_multi_index was jit-compiled with mode='raise'. Use mode='wrap' or mode='clip' instead.
This BatchTracer with object id 140570390572544 was created on line:
  <ipython-input-2-5c78fe181e6a>:6 (<lambda>)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
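The message above already points at the fix. Below is a minimal sketch of it. It assumes the original code mapped jnp.ravel_multi_index over a batch of indices with jax.vmap, because the BatchTracer in the traceback suggests this. The question's own code isn't shown, so the shape dims = (3, 4) and the index arrays are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical 3x4 array shape and a batch of (row, col) index pairs
dims = (3, 4)
rows = jnp.array([0, 1, 2])
cols = jnp.array([3, 0, 2])

def flat_index(r, c):
    # mode='raise' needs concrete index values for its bounds check, which
    # fails under vmap/jit; mode='clip' needs no concrete values and traces fine
    return jnp.ravel_multi_index((r, c), dims, mode='clip')

# Row-major (order='C') flat index: r * 4 + c
flat = jax.vmap(flat_index)(rows, cols)
print(flat.tolist())  # [3, 4, 10]
```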

The issue is that the default mode='raise' is not compatible with JAX transforms like jit or vmap, and as mentioned by the error you sho…
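Because the fix is to switch to mode='wrap' or mode='clip', note that the two modes handle out-of-range indices differently and that neither one raises an error. The sketch below shows this under jax.jit. The dims and index values are hypothetical and chosen only to put a column index out of range.

```python
import jax
import jax.numpy as jnp

dims = (3, 4)
# Column index 5 is out of range for an axis of length 4
rows = jnp.array([1])
cols = jnp.array([5])

@jax.jit
def flat_wrap(r, c):
    # 'wrap' reduces each index modulo its dimension: 5 % 4 == 1
    return jnp.ravel_multi_index((r, c), dims, mode='wrap')

@jax.jit
def flat_clip(r, c):
    # 'clip' clamps each index into [0, dim - 1]: 5 -> 3
    return jnp.ravel_multi_index((r, c), dims, mode='clip')

print(flat_wrap(rows, cols).tolist())  # [5]  (1*4 + 1)
print(flat_clip(rows, cols).tolist())  # [7]  (1*4 + 3)
```

Out-of-range input is silently remapped instead of being reported. If you still need to detect bounds violations, one option is to check the indices outside the transformed function, where their values are concrete.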

Answer selected by PgLoLo