BatchNorm fixes for JAX and PyTorch workloads #798
Merged
Fixes to BatchNorm behavior in JAX and PyTorch; mainly, decouple updating the batch-norm statistics from using the running statistics.
Changes for PyTorch from @adefazio's #783
From pull/783:
There are some subtle issues with how BatchNorm is handled in the PyTorch version of the code. Currently, `workload.model_fn` has an `update_batch_norm` parameter, which in theory should allow the submission to control whether the batch-norm statistics are updated during a forward pass. The issues are the following:

1. The `update_batch_norm_fn` function stores the old momentum parameter for each BatchNorm layer in a `momentum_backup` variable, so it can be restored later, before zeroing the parameter. However, if it is called with `update_batch_norm=False` twice in a row, it overwrites the `momentum_backup` with 0 on the second call, so momentum then remains zero for the remainder of training.
2. In a standard PyTorch BatchNorm layer, a momentum of `0` indicates that the running statistics shouldn't be updated. This is the opposite of how EMA momentum is usually done (e.g. in Adam), where `1` would indicate that they shouldn't be updated and `0` means they are set to the latest batch value at every step (see the snippet after this list). The custom BatchNorm modules used in the two librispeech workloads follow this second, more standard convention instead. However, `update_batch_norm_fn` sets the momentum to zero for all three layer types, resulting in incorrect behavior for the librispeech workloads.
3. `update_batch_norm_fn` sets the BN layers to eval mode. This doesn't make sense, as it prevents the use case of normalizing with batch-computed statistics (train mode) without also updating the running statistics. The BN layers can be set to eval mode separately by passing `ForwardPassMode.EVAL` to the forward pass, so removing this `.eval()` call doesn't prevent the submission from using eval mode during a forward pass.
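To make the convention clash in issue 2 concrete, here is a minimal, standalone check of stock PyTorch behavior (the variable names are illustrative, not from the repo):

```python
import torch

# PyTorch convention: running_stat = (1 - momentum) * running_stat + momentum * batch_stat,
# so momentum=0 freezes the running statistics entirely.
bn = torch.nn.BatchNorm1d(4, momentum=0.0)
bn.train()  # train mode: normalization uses batch statistics
before = bn.running_mean.clone()
bn(torch.randn(8, 4))
assert torch.equal(bn.running_mean, before)  # running buffers unchanged

# Adam-style EMA convention (followed by the custom librispeech BN modules
# before this PR): stat = momentum * stat + (1 - momentum) * batch_stat,
# so momentum=1 freezes the statistics and momentum=0 overwrites them with
# the latest batch value. Zeroing momentum under this convention therefore
# does the exact opposite of what was intended.
```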
This PR switches the custom BN code to follow the PyTorch BN convention, so that momentum=0 doesn't update the running buffers. It also fixes the issues in the `update_batch_norm_fn` function mentioned above; a sketch of the corrected helper follows below.
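The following is a hedged sketch of what the corrected helper could look like, not the exact code merged in this PR. It only matches the stock `_BatchNorm` classes; the real helper also covers the custom librispeech BN modules, which after this change share the same momentum convention:

```python
import functools
import torch

def update_batch_norm_fn(module: torch.nn.Module, update_batch_norm: bool) -> None:
  """Illustrative fix: toggle running-stat updates without touching train/eval mode."""
  # The real helper also matches the two custom librispeech BN classes, omitted here.
  if not isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
    return
  if not update_batch_norm:
    # Back up momentum only on the first call, so a repeated call with
    # update_batch_norm=False can no longer overwrite the backup with 0 (issue 1).
    if not hasattr(module, 'momentum_backup'):
      module.momentum_backup = module.momentum
    # Standard PyTorch convention: momentum=0 freezes the running statistics (issue 2).
    module.momentum = 0.0
  elif hasattr(module, 'momentum_backup'):
    module.momentum = module.momentum_backup
    del module.momentum_backup
  # No module.eval() here (issue 3): train/eval mode is selected separately
  # by the forward-pass mode.

# Usage: applied recursively over all submodules of a model.
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.BatchNorm1d(4))
model.apply(functools.partial(update_batch_norm_fn, update_batch_norm=False))
model.apply(functools.partial(update_batch_norm_fn, update_batch_norm=False))  # safe twice
model.apply(functools.partial(update_batch_norm_fn, update_batch_norm=True))
assert model[1].momentum == 0.1  # original (default) momentum restored
```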