Replies: 1 comment
-
I guess what I'm really asking is what's the canonical way to create an array that is replicated across multiple processes. It turns out I can actually populate such an array with different values in different processes successfully. Consider the following code:
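Something along these lines (a minimal sketch rather than the exact snippet; it assumes 4 single-device processes that have already been launched for `jax.distributed.initialize()`, and uses the process index as the per-process value):

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

jax.distributed.initialize()  # assumes coordinator/process info comes from the launcher

# One mesh axis over all devices; an empty PartitionSpec means "fully replicated".
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
replicated = NamedSharding(mesh, P())

# Each process fills its local copy with its own process index.
local_value = np.full((8,), jax.process_index(), dtype=np.float32)
local_arrays = [jax.device_put(local_value, d) for d in mesh.local_devices]

arr = jax.make_array_from_single_device_arrays((8,), replicated, local_arrays)

print(f"process {jax.process_index()}: "
      f"fully_replicated={arr.sharding.is_fully_replicated}, "
      f"local values={arr.addressable_shards[0].data}")
```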
This code outputs the following:
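With the sketch above, each process reports a fully replicated sharding while holding its own values, roughly:

```
process 0: fully_replicated=True, local values=[0. 0. 0. 0. 0. 0. 0. 0.]
process 1: fully_replicated=True, local values=[1. 1. 1. 1. 1. 1. 1. 1.]
process 2: fully_replicated=True, local values=[2. 2. 2. 2. 2. 2. 2. 2.]
process 3: fully_replicated=True, local values=[3. 3. 3. 3. 3. 3. 3. 3.]
```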
So even though the sharding describes the array as fully replicated, each process can end up holding different values.
-
Let's say I have a distributed training setup that combines pjit-based model and data parallelism and spans multiple processes. I'm struggling to figure out the recommended way to load a batch of data in this setup.
For simplicity, let's assume that I have 4 devices, each managed by a separate process, and that my batch is sharded into 2 chunks, so that processes 0 & 1 should process chunk #0 and processes 2 & 3 should process chunk #1.
How should I load the data in this case? Should each process just load the data it's going to need (meaning that 0 and 1 will each load chunk #0, while 2 and 3 will load chunk #1) and then call `make_array_from_single_device_arrays`? I assume not, because that goes against the whole idea of the batch being replicated over some devices: what if processes 0 and 1 load different data into it? Perhaps I should pick a single process within each replica to load the data? Or load it all from a single process and then replicate it across the whole mesh? If so, I'm not quite sure how to achieve that correctly. For example, can I call `make_array_from_single_device_arrays` on a subset of the mesh devices without causing synchronization locks? Any suggestions or pointers to code that does something similar would be much appreciated!
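To make the first option concrete, here is a rough sketch of per-process loading (assumptions: 4 single-device processes, a (2, 2) mesh with axes ("data", "model"), a global batch of shape (8, 128) sharded over "data" and replicated over "model", device order following process indices, and `np.full` standing in for the actual loader):

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

jax.distributed.initialize()

# Assumed mesh layout: rows are the "data" axis, columns the "model" axis,
# with jax.devices() ordered by process index (so row 0 = processes 0 & 1).
devices = np.array(jax.devices()).reshape(2, 2)
mesh = Mesh(devices, axis_names=("data", "model"))
batch_sharding = NamedSharding(mesh, P("data"))  # shard dim 0 over "data", replicate over "model"

global_shape = (8, 128)
local_device = jax.local_devices()[0]

# Ask the sharding which slice of the global batch this device should hold,
# instead of hard-coding the process -> chunk mapping.
index = batch_sharding.addressable_devices_indices_map(global_shape)[local_device]
row_ids = np.arange(global_shape[0])[index[0]]        # global row indices for this device
chunk_id = int(row_ids[0]) // (global_shape[0] // 2)  # 0 for processes 0 & 1, 1 for 2 & 3

# Load only those rows; np.full is a placeholder for the real data loader.
local_chunk = np.full((len(row_ids), global_shape[1]), chunk_id, dtype=np.float32)

batch = jax.make_array_from_single_device_arrays(
    global_shape, batch_sharding, [jax.device_put(local_chunk, local_device)]
)

print(f"process {jax.process_index()}: chunk {chunk_id}, "
      f"local shard {batch.addressable_shards[0].data.shape}, global {batch.shape}")
```

`jax.make_array_from_callback` wraps the same pattern: it calls a user callback with exactly these per-device indices and assembles the global array for you.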