Replies: 1 comment 1 reply
-
When you wrap your code in `jit`, the compiler decides on the most effective memory layout for your particular sequence of computations. For that reason, the best advice is to use whatever layout makes the logic of your algorithm clearest.
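A minimal sketch of this point (the function and shapes are illustrative, not from the thread): under `jax.jit`, you write the computation against logical shapes, and XLA is free to pick the physical layout when it compiles the trace — the numerical result is the same either way:

```python
import jax
import jax.numpy as jnp

# A toy computation written once; XLA chooses the physical layout
# when it compiles the traced program, regardless of how you
# ordered the logical dimensions.
@jax.jit
def batched_norm(x):  # x: (N, dims), or whatever order reads best
    return jnp.sqrt(jnp.sum(x * x, axis=-1))

x = jnp.arange(12.0).reshape(4, 3)     # batch-first: (N=4, dims=3)
ref = jnp.sqrt(jnp.sum(x * x, axis=-1))  # un-jitted reference

# Results match; layout decisions happen inside the compiler.
assert jnp.allclose(batched_norm(x), ref)
```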
-
Hi!
I am trying to understand the optimal way of defining tensors, in particular the order of their dimensions when we have batches.
I remember that when running code on the CPU, the optimal approach was to have the N batch elements as the last dimension, using tensor shapes of the form (dims..., N), since the CPU operates on the batch elements one by one, and it's better if it can load and cache the data it needs while working on a specific batch instance. I then learned, while trying to optimize my code, that the opposite is true on the GPU, where shapes of the form (N, dims...) are preferred. I guess this is because the GPU loads the data for all N batch elements at once, so it's worth having them in consecutive memory? Either way, this much makes sense to me.
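The layout difference can be made concrete with array strides (a NumPy sketch with made-up sizes): in row-major order, batch-first (N, dims) keeps each example contiguous in memory, while batch-last (dims, N) spreads one example's values apart:

```python
import numpy as np

dims, N = 3, 4

batch_first = np.zeros((N, dims))  # (N, dims): each example is one contiguous row
batch_last = np.zeros((dims, N))   # (dims, N): one example's values are strided apart

# Strides are in bytes; float64 is 8 bytes.
# batch_first: consecutive values of one example are 8 bytes apart,
# and jumping to the next example skips dims * 8 = 24 bytes.
assert batch_first.strides == (dims * 8, 8)
# batch_last: consecutive values of example j are N * 8 = 32 bytes apart.
assert batch_last.strides == (N * 8, 8)
```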
What doesn't make sense to me is why, in an RNN, the shapes we use are (T, N, dims) when the time dimension is accessed step by step. Doesn't the same idea apply, i.e. that the batch dimension N should come first and the shapes should be (N, T, dims) or something?
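One concrete reason for time-major shapes in JAX (a sketch; the RNN cell below is a trivial placeholder, not a real recurrent layer): `jax.lax.scan` iterates over the leading axis of its input, so a (T, N, dims) array naturally hands the cell one whole-batch (N, dims) slice per time step:

```python
import jax
import jax.numpy as jnp

T, N, dims = 5, 4, 3
xs = jnp.ones((T, N, dims))  # time-major: scan walks axis 0

def cell(carry, x_t):
    # x_t has shape (N, dims): the whole batch for one time step.
    new_carry = carry + x_t.sum(axis=-1)  # toy "hidden state" update
    return new_carry, new_carry

init = jnp.zeros(N)
final, hs = jax.lax.scan(cell, init, xs)

# scan stacks the per-step outputs back along a leading time axis.
assert hs.shape == (T, N)
assert final.shape == (N,)
```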
Thanks for any help!
EDIT: I am even more confused now, since as far as I can tell JAX by default uses a row-major format. So isn't having N as the first dimension the opposite of optimal?