
How to dynamic_slice with a DeviceArray #650

Answered by jheek
L-Hugh asked this question in Q&A

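The issue here is that b determines the shape of the result returned by dynamic_slice. Inside jax.jit, array shapes must be known when the function is traced, so b cannot be a traced value. You can make this work by declaring b as a static argument and passing it as a plain Python int, since static arguments must be hashable: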

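import jax
import jax.numpy as jnp

def test(a, b):
    return jax.lax.dynamic_slice(a, (0,), (b,))
a = jnp.array([1, 2, 3])
b = 1  # a hashable Python int; a jnp.array is not hashable, so recent JAX versions reject it as a static argument
jax.jit(test, static_argnums=(1,))(a, b)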

Note, however, that test will be recompiled every time you pass in a different value for b.
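A minimal sketch of that recompilation behavior, assuming a recent JAX version. The function take and the trace_count counter are hypothetical names introduced for illustration. The counter is a Python side effect, so it only increments when JAX traces (and therefore compiles) the function, not on cached calls. The start index can stay an ordinary traced argument, because dynamic_slice only needs the slice size to be static. Note also that dynamic_slice clamps the start index so that the slice stays in bounds.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter: incremented only while JAX traces take

def take(a, start, size):
    global trace_count
    trace_count += 1  # Python side effect, runs at trace time only
    # start may be a traced value; size must be a concrete Python int
    return jax.lax.dynamic_slice(a, (start,), (size,))

# size (argument index 2) is static, start (index 1) is traced
take_jit = jax.jit(take, static_argnums=(2,))

a = jnp.array([1, 2, 3])
print(take_jit(a, 0, 1))  # first call with size=1: traces and compiles
print(take_jit(a, 1, 1))  # same size, new start: reuses the compiled version
print(take_jit(a, 0, 2))  # new size: traces and compiles again
print(trace_count)  # 2
```

Only the first and third calls trigger a trace, because only they introduce a new value of the static size argument. Changing the traced start index reuses the existing compiled version.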

Answer selected by L-Hugh