Replies: 3 comments
- Hi @nikitakit, good to have you here :)
- (I'll convert this to a discussion; if we decide to add any new features we can file an issue afterwards.)
- You could sync the parameters across devices and make the first device the leading one:
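  A minimal sketch of what that sync could look like, assuming the parameters live in a pytree and process 0 is the leader. `broadcast_from_first_host` is a hypothetical helper, not an existing flax API; the idea is that exactly one device in the cluster contributes the real values and a global `psum` copies them everywhere:

  ```python
  import jax
  import numpy as np

  def broadcast_from_first_host(params):
      """Hypothetical helper: make every host's params a copy of process 0's.

      Only the first local device on process 0 contributes real values; every
      other device contributes zeros, so a psum across all devices leaves
      process 0's parameters on every device of every host.
      """
      n_local = jax.local_device_count()

      def per_device(x):
          x = np.asarray(x)
          out = np.zeros((n_local,) + x.shape, x.dtype)
          if jax.process_index() == 0:   # jax.host_id() on older JAX releases
              out[0] = x                 # exactly one device holds the real value
          return out

      stacked = jax.tree_util.tree_map(per_device, params)
      # In a multi-host pmap the named axis spans every device on every host,
      # so this psum is a global all-reduce.
      synced = jax.pmap(lambda x: jax.lax.psum(x, 'sync'), axis_name='sync')(stacked)
      # All devices now hold identical values; pull one copy back to the host.
      return jax.tree_util.tree_map(lambda x: jax.device_get(x)[0], synced)
  ```

  Newer JAX versions ship essentially this pattern as `jax.experimental.multihost_utils.broadcast_one_to_all`, so depending on your JAX version the helper above may not be needed.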
- Currently `optimizer.replicate()` will replicate a model to all devices on the current host, but flax doesn't provide any means to replicate model parameters in a multi-host setting. For multi-host training to work, parameters need to be initialized identically across all hosts. This requires discipline from the programmer to use the same random seed on all hosts, and to avoid using unseeded randomness (like numpy randomness) when initializing parameters (a minimal example follows at the end of this comment).
It would be helpful if there were a standard command that takes a copy of the model from one host and replicates it to all others.
  (For anyone wondering why I ran into this: I'm loading pre-trained BERT checkpoints saved with tensorflow or pytorch, except that I also need to add a task-specific head on top. Since most of the parameter loading happens outside of flax and deals with numpy arrays, it felt natural to just add code in the same place that calls `numpy.random()` to initialize the classifier head.)