How to take the first i-slices<traced> of a variable in a scan loop #23087

Answered by jakevdp
ziky168 asked this question in Q&A
Thanks. The error I get when running your code is this:

TypeError: Branch index must be scalar, got Traced<ShapedArray(int32[1]):JaxprTrace(level=1/0)> of shape (1,).

This indicates that you're passing a length-1 array where a scalar is expected. I fixed it by changing this:

        return jax.lax.switch(i, cases)

to this:

        return jax.lax.switch(i[0], cases)
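A minimal sketch of the same fix inside a scan loop, assuming the loop index arrives as a length-1 array (as in the error above). All names here (`N`, `first_k_rows`, `cases`, `step`) are hypothetical illustrations, not the original poster's code. Note that `jax.lax.switch` requires every branch to return the same shape. For that reason each branch keeps the first k entries and zero-fills the rest instead of returning `x[:k]` directly.

```python
import jax
import jax.numpy as jnp

N = 4  # hypothetical: length of x; must be a static Python int

def first_k_rows(k):
    # Branch k: keep the first k entries of x and zero out the rest.
    # All switch branches must return the same shape, so we pad back to N.
    def branch(x):
        return jnp.concatenate([x[:k], jnp.zeros_like(x[k:])], axis=0)
    return branch

# One branch per possible slice length 0..N (k is static in each branch).
cases = [first_k_rows(k) for k in range(N + 1)]

def step(x, i):
    # i has shape (1,), like the traced index in the error message;
    # lax.switch needs a scalar index, so take i[0].
    out = jax.lax.switch(i[0], cases, x)
    return x, out.sum()

x = jnp.arange(1.0, N + 1.0)                              # [1., 2., 3., 4.]
idx = jnp.arange(N + 1, dtype=jnp.int32).reshape(-1, 1)   # shape (5, 1)
_, sums = jax.lax.scan(step, x, idx)
# sums[k] is the sum of the first k entries: [0., 1., 3., 6., 10.]
```

If `N` is fixed, the same "first i entries" selection can also be written without `switch` as a mask, e.g. `jnp.where(jnp.arange(N) < i, x, 0.0)`. This works with a traced scalar `i` because the output shape never depends on `i`.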

Running again, I get this error:

TypeError: Expected a callable value, got Traced<ShapedArray(float64[1]):JaxprTrace(level=1/0)>

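This happens because you're passing an array to jax.jit rather than a function: self.select_logprobs(i, logprobs) is evaluated first, so jax.jit receives its result instead of a callable. To fix this, I changed this: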

        probs = jnp.exp(jax.jit(self.select_logprobs(i,logprobs),static_argnames=0)) 
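For reference, `jax.jit` wraps a function and returns a compiled function, which you then call with arguments. Positional static arguments are marked with `static_argnums` (`static_argnames` takes argument *names*, not indices). A minimal sketch, where `select_logprobs` is a hypothetical stand-in rather than the method from the question. Because a static argument must be a concrete Python value, this pattern does not apply directly when `i` is the traced loop index inside `scan`.

```python
import jax
import jax.numpy as jnp

def select_logprobs(i, logprobs):
    # Hypothetical stand-in: keep the first i entries of logprobs.
    # Slicing with i needs a concrete Python int, hence i is static below.
    return logprobs[:i]

# Pass the *function* to jax.jit; argument 0 (i) is static.
# Equivalently: jax.jit(select_logprobs, static_argnames="i")
select_jit = jax.jit(select_logprobs, static_argnums=0)

logprobs = jnp.log(jnp.array([0.1, 0.2, 0.3, 0.4]))
probs = jnp.exp(select_jit(2, logprobs))  # approximately [0.1, 0.2]
```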

Answer selected by ziky168