Data Parallel - Managing State Variables #3264
peterdavidfagan asked this question in Q&A
Hi Flax Community,
I am getting started with converting my Flax training pipeline so that it can run across multiple devices, and I've been reading the following guide to accomplish this. I have some clarifying questions, which I am posting here in case the answers are also useful to other users.
Managing State Variables
For layers that maintain state variables, such as `nn.BatchNorm` (API docs), is it sufficient to pass the logical axis name to the `axis_name` parameter for those variables to be tracked correctly when training across multiple devices? The API docs for BatchNorm reference pmap but don't mention partitioning with jit. Does this parameter also apply to partitioning with jit as outlined in the guide? I presume yes, but I haven't delved into the jit codebase yet to verify.