Is there anything equivalent to …? My use case is that I would like to selectively access a different index of a nested pytree inside a static argument of my pmapped function, depending on the corresponding device index.
Answered by minqi on Apr 15, 2023
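Answering my own question: just realized that's what `jax.lax.axis_index` is for! Called inside a `pmap`ped function, it returns the index of the current replica along the named mapped axis.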
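A minimal sketch of how this could look, assuming the static argument is a tuple with one entry per device (it must be hashable to be passed via `static_broadcasted_argnums`). The names `make_configs` and `per_device_step` and the config layout are hypothetical. One catch: the value from `axis_index` is a tracer, so plain Python indexing like `configs[idx]` will fail; the sketch stacks the per-device entries leaf-by-leaf into arrays first and then indexes those with the traced index.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


def make_configs(n):
    # Hypothetical per-device config: entry i is a nested tuple of floats.
    # Tuples (not lists/dicts) keep it hashable for static_broadcasted_argnums.
    return tuple(((float(i), float(10 * i)), float(100 * i)) for i in range(n))


@partial(jax.pmap, axis_name="devices", static_broadcasted_argnums=(1,))
def per_device_step(x, configs):
    # Stack matching leaves across the per-device entries, so each leaf
    # becomes an array with a leading axis of length n_devices.
    stacked = jax.tree_util.tree_map(lambda *leaves: jnp.asarray(leaves), *configs)
    # Index of this replica along the "devices" axis (a traced integer).
    idx = jax.lax.axis_index("devices")
    # Array indexing accepts a traced index, so this picks this device's entry.
    (a, b), c = jax.tree_util.tree_map(lambda leaf: leaf[idx], stacked)
    return x * a + b + c


n = jax.local_device_count()
out = per_device_step(jnp.ones((n, 3)), make_configs(n))
print(out)  # row i holds 111 * i
```

If the pytree is really per-device data rather than configuration, an alternative is to pre-stack it outside and pass it as an ordinary mapped argument, so `pmap` hands each device its own slice without needing `axis_index` at all.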