diff --git a/examples/flava/native/train.py b/examples/flava/native/train.py
index 6e2afa10..88b64b6b 100644
--- a/examples/flava/native/train.py
+++ b/examples/flava/native/train.py
@@ -117,7 +117,11 @@ def __init__(self, config: DictConfig):
             else torch.float16
         )
 
-        self.scaler = ShardedGradScaler() if config.training.enable_amp else None
+        self.scaler = (
+            ShardedGradScaler()
+            if config.training.enable_amp and self.half_dtype == torch.float16
+            else None
+        )
 
     def log(
         self,
@@ -144,25 +148,6 @@ def create_model(self) -> torch.nn.Module:
             f"size: {get_model_size_gb(model):.3} GB"
         )
 
-        if self.config.training.activation_checkpointing:
-            check_fn = lambda submodule: isinstance(submodule, TransformerEncoderLayer)
-            checkpoint_impl = CheckpointImpl.REENTRANT
-
-            # DDP gradient hooks have compatibility issues with REENTRANT autograd
-            if strategy == "ddp":
-                checkpoint_impl = CheckpointImpl.NO_REENTRANT
-
-            checkpoint_wrapper_fn = partial(
-                checkpoint_wrapper,
-                offload_to_cpu=False,
-                checkpoint_impl=checkpoint_impl,
-            )
-            apply_activation_checkpointing(
-                model,
-                checkpoint_wrapper_fn=checkpoint_wrapper_fn,
-                check_fn=check_fn,
-            )
-
         if strategy == "ddp":
             # TODO do we have to do this in FSDP too? see https://github.com/pytorch/pytorch/issues/75478
             model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -206,6 +191,20 @@ def create_model(self) -> torch.nn.Module:
 
             print0(f"after FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB")
 
+            if self.config.training.activation_checkpointing:
+                check_fn = lambda submodule: isinstance(submodule, TransformerEncoderLayer)
+                checkpoint_impl = CheckpointImpl.NO_REENTRANT
+
+                checkpoint_wrapper_fn = partial(
+                    checkpoint_wrapper,
+                    checkpoint_impl=checkpoint_impl,
+                )
+                apply_activation_checkpointing(
+                    model,
+                    checkpoint_wrapper_fn=checkpoint_wrapper_fn,
+                    check_fn=check_fn,
+                )
+
         else:
             raise ValueError(f"unknown strategy: {strategy}")
 
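
Note on the scaler change: gradient scaling exists to prevent fp16 gradient underflow, and bf16 keeps fp32's exponent range, so a `ShardedGradScaler` is only useful when `half_dtype` is `torch.float16`. A minimal sketch of a training step that handles both paths (the `amp_train_step` helper and its arguments are illustrative, not part of this file):

```python
import torch
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler


def amp_train_step(model, batch, optimizer, half_dtype, scaler=None):
    """One AMP step; `scaler` is only passed when half_dtype is fp16."""
    with torch.autocast("cuda", dtype=half_dtype):
        loss = model(batch)  # assumes the model returns a scalar loss

    if scaler is not None:
        scaler.scale(loss).backward()  # scale up to avoid fp16 grad underflow
        scaler.step(optimizer)         # unscales grads, skips step on inf/nan
        scaler.update()
    else:
        # bf16 shares fp32's exponent range, so no loss scaling is needed
        loss.backward()
        optimizer.step()
    optimizer.zero_grad(set_to_none=True)
```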
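Note on the activation-checkpointing move: the block now runs only on the FSDP path, after the model has been wrapped, always uses `CheckpointImpl.NO_REENTRANT` (making the old DDP-specific REENTRANT workaround moot), and drops the explicit `offload_to_cpu=False` (the default). A minimal sketch of the resulting order of operations, using the same PyTorch helpers as the diff; the `wrap_with_fsdp_then_checkpoint` name and the bare `FSDP(model)` call are illustrative:

```python
from functools import partial

import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn import TransformerEncoderLayer


def wrap_with_fsdp_then_checkpoint(model: torch.nn.Module) -> torch.nn.Module:
    # Shard the model first; checkpoint wrappers are then inserted in place
    # around the transformer layers inside the already-created FSDP units.
    model = FSDP(model)
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=partial(
            checkpoint_wrapper,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        ),
        check_fn=lambda m: isinstance(m, TransformerEncoderLayer),
    )
    return model
```

One plausible reading of the reordering is that applying `checkpoint_wrapper` before FSDP would change the module tree that FSDP's wrapping sees, whereas applying it afterwards leaves the sharding structure untouched.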