From a7c8acc0fdb6ed31bdbad10cf56dfa59aa402582 Mon Sep 17 00:00:00 2001
From: rvarm1
Date: Thu, 6 Jul 2023 17:12:36 -0700
Subject: [PATCH 1/2] Don't use scaler for bf16

Summary: ShardedGradScaler is not needed for bf16 training.

Differential Revision: D47218367

fbshipit-source-id: 143269cc8dc05805b3caebe73035ce0f39c7e8d4
---
 examples/flava/native/train.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/examples/flava/native/train.py b/examples/flava/native/train.py
index 6e2afa10..6fe29e73 100644
--- a/examples/flava/native/train.py
+++ b/examples/flava/native/train.py
@@ -117,7 +117,11 @@ def __init__(self, config: DictConfig):
             else torch.float16
         )
 
-        self.scaler = ShardedGradScaler() if config.training.enable_amp else None
+        self.scaler = (
+            ShardedGradScaler()
+            if config.training.enable_amp and self.half_dtype == torch.float16
+            else None
+        )
 
     def log(
         self,

From aefeaeacef0d5b75520a11add0fe648ade4dba15 Mon Sep 17 00:00:00 2001
From: Rohan Varma
Date: Thu, 6 Jul 2023 17:12:53 -0700
Subject: [PATCH 2/2] Fix AC order for FLAVA example

Summary: There have been fixes to non-reentrant activation checkpointing (AC),
and the FLAVA model now works with it. Apply AC after FSDP wrapping and always
use the non-reentrant implementation.

Differential Revision: D47218368

fbshipit-source-id: 448969eb42e6df224dcd7ea82b4fabe63dcecf96
---
 examples/flava/native/train.py | 33 ++++++++++++++-------------------
 1 file changed, 14 insertions(+), 19 deletions(-)

diff --git a/examples/flava/native/train.py b/examples/flava/native/train.py
index 6fe29e73..88b64b6b 100644
--- a/examples/flava/native/train.py
+++ b/examples/flava/native/train.py
@@ -148,25 +148,6 @@ def create_model(self) -> torch.nn.Module:
             f"size: {get_model_size_gb(model):.3} GB"
         )
 
-        if self.config.training.activation_checkpointing:
-            check_fn = lambda submodule: isinstance(submodule, TransformerEncoderLayer)
-            checkpoint_impl = CheckpointImpl.REENTRANT
-
-            # DDP gradient hooks have compatibility issues with REENTRANT autograd
-            if strategy == "ddp":
-                checkpoint_impl = CheckpointImpl.NO_REENTRANT
-
-            checkpoint_wrapper_fn = partial(
-                checkpoint_wrapper,
-                offload_to_cpu=False,
-                checkpoint_impl=checkpoint_impl,
-            )
-            apply_activation_checkpointing(
-                model,
-                checkpoint_wrapper_fn=checkpoint_wrapper_fn,
-                check_fn=check_fn,
-            )
-
         if strategy == "ddp":
             # TODO do we have to do this in FSDP too? see https://github.com/pytorch/pytorch/issues/75478
             model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -210,6 +191,20 @@ def create_model(self) -> torch.nn.Module:
 
             print0(f"after FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB")
 
+            if self.config.training.activation_checkpointing:
+                check_fn = lambda submodule: isinstance(submodule, TransformerEncoderLayer)
+                checkpoint_impl = CheckpointImpl.NO_REENTRANT
+
+                checkpoint_wrapper_fn = partial(
+                    checkpoint_wrapper,
+                    checkpoint_impl=checkpoint_impl,
+                )
+                apply_activation_checkpointing(
+                    model,
+                    checkpoint_wrapper_fn=checkpoint_wrapper_fn,
+                    check_fn=check_fn,
+                )
+
         else:
             raise ValueError(f"unknown strategy: {strategy}")
 
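
For context on the first patch: ShardedGradScaler is only constructed when AMP
runs in float16. bfloat16 shares float32's exponent range, so the gradient
underflow that loss scaling guards against does not occur, and the scaler can
be skipped entirely. The snippet below is a minimal standalone sketch of that
logic, not code from the example; make_scaler and backward_and_step are
hypothetical helpers used only for illustration.

import torch
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler


def make_scaler(enable_amp: bool, half_dtype: torch.dtype):
    # Mirror the patched logic: only fp16 AMP gets a ShardedGradScaler;
    # bf16 and full-precision runs train without one.
    if enable_amp and half_dtype == torch.float16:
        return ShardedGradScaler()
    return None


def backward_and_step(loss: torch.Tensor, optimizer: torch.optim.Optimizer, scaler) -> None:
    # With a scaler (fp16 AMP): scale the loss, step through the scaler,
    # then update the scale factor. Without one (bf16 or fp32): plain backward/step.
    if scaler is not None:
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        optimizer.step()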
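
For context on the second patch: activation checkpointing is now applied after
the model is wrapped in FSDP, and always uses the non-reentrant implementation.
The sketch below shows that ordering under stated assumptions: wrap_model is a
hypothetical helper, the default process group is assumed to be initialized,
the model is assumed to already sit on the right device, and
torch.nn.TransformerEncoderLayer stands in for the example's own transformer
layer class.

from functools import partial

import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def wrap_model(model: torch.nn.Module) -> torch.nn.Module:
    # 1) Shard the model first (assumes torch.distributed is initialized and
    #    the model is on the correct CUDA device).
    model = FSDP(model)

    # 2) Then wrap each transformer layer with non-reentrant activation
    #    checkpointing, matching the order the patch establishes.
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=partial(
            checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
        ),
        check_fn=lambda submodule: isinstance(submodule, torch.nn.TransformerEncoderLayer),
    )
    return model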