Hi all, is there a way to have dynamic shapes, slices, and other objects with dynamic properties in JAX, at least for trivial cases? I would appreciate it if anyone could point me to the right solution. Thanks!

A minimal (non-)working example is below. This case seems trivial to me, so I was quite surprised that it doesn't work in JAX. The index …

```python
import jax
import jax.numpy as jnp
from jax import lax

def cond_fn(acc):
    i, _ = acc
    return i < 5

def body_fn(acc: tuple[int, jax.typing.ArrayLike]):
    i, var = acc
    update = jnp.ones([2, i + 1]) * i  # << Breaks here, and even jnp.repeat is required to be static...
    var_new = lax.dynamic_update_slice(var, update, [i, 0])
    return i + 1, var_new

# `arr` is the initial buffer; its definition is not shown in the post.
lax.while_loop(cond_fn, body_fn, (0, arr))
```

PS: it seems that even …
No, dynamic shapes are not yet supported, even trivial ones like the one you mention (in practice, this is not so trivial to implement!). You linked to #14634; I think that's probably the place to follow for updates.
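Until then, a common workaround (not something stated in this thread) is to keep every array shape static and move the dynamic part into the *values* with a mask. The sketch below reproduces the loop from the question that way. It assumes a hypothetical upper bound `MAX_WIDTH = 5` on the update width and a hypothetical `(6, 5)` zero buffer for `arr`, which the question does not define.

```python
import jax
import jax.numpy as jnp
from jax import lax

N_STEPS = 5          # loop runs for i = 0, 1, ..., 4, as in the original cond_fn
MAX_WIDTH = 5        # hypothetical static bound: the widest update is i + 1 = 5 columns

def cond_fn(acc):
    i, _ = acc
    return i < N_STEPS

def body_fn(acc):
    i, var = acc
    # Read a (2, MAX_WIDTH) window at the dynamic offset (i, 0):
    # the index is traced, but the slice size is static.
    window = lax.dynamic_slice(var, (i, 0), (2, MAX_WIDTH))
    # The mask has a static shape; only its values depend on the traced i.
    keep_cols = jnp.arange(MAX_WIDTH) < i + 1
    # Columns < i + 1 get the value i; the rest keep their current contents.
    update = jnp.where(keep_cols[None, :], jnp.asarray(i, var.dtype), window)
    # Write the window back: dynamic index, static update shape.
    var_new = lax.dynamic_update_slice(var, update, (i, 0))
    return i + 1, var_new

@jax.jit
def run(arr):
    _, out = lax.while_loop(cond_fn, body_fn, (0, arr))
    return out

arr = jnp.zeros((N_STEPS + 1, MAX_WIDTH))  # hypothetical initial buffer
out = run(arr)
```

Here `out` has rows `[0,0,0,0,0]`, `[1,1,0,0,0]`, `[2,2,2,0,0]`, `[3,3,3,3,0]`, `[4,4,4,4,4]`, `[4,4,4,4,4]`, which is what the original `jnp.ones([2, i + 1]) * i` update would have produced. The cost is that every iteration processes the full `MAX_WIDTH` columns. Similarly, `jnp.repeat` can take a traced `repeats` under `jit` if you pass a static `total_repeat_length`.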
Regarding `lax.dynamic_update_slice`, it does support dynamic indices, but not dynamic array shapes. This is in contrast to `lax.slice`, which does not support dynamic indices at all, e.g.