From 69253e5c223648eb86d80abd1b17ab0ee923db5a Mon Sep 17 00:00:00 2001 From: Ning Wang Date: Sun, 18 Feb 2024 06:16:40 +0000 Subject: [PATCH 01/17] reshard and cleanup --- composer/trainer/trainer.py | 13 ++++ tests/trainer/test_fsdp.py | 134 +++++++++++++++++++++++++++++++++++- 2 files changed, 145 insertions(+), 2 deletions(-) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 01cd0fcc9b..e4f2423ca8 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -30,6 +30,8 @@ import torch.utils.data from torch._dynamo import OptimizedModule from torch.cuda.amp.grad_scaler import GradScaler, _refresh_per_optimizer_state +from torch.distributed.fsdp import FullyShardedDataParallel +from torch.distributed.fsdp._runtime_utils import _post_backward_final_callback from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler from torch.nn.parallel import DistributedDataParallel from torch.optim.lr_scheduler import LRScheduler @@ -232,6 +234,16 @@ def _is_cuda_oom(e: RuntimeError): return False +def _fsdp_reshard_and_cleanup(model: torch.nn.Module): + for name, module in model.named_modules(): + if isinstance(module, FullyShardedDataParallel): + if module.check_is_root(): + try: + _post_backward_final_callback(module, module) + except Exception as e: + log.warning(f'Failed to reshard fsdp after oom, error: {e}') + + def _adjust_device_train_microbatch_size(state: State): """Adjust device_train_microbatch_size if we encounter OOM. @@ -259,6 +271,7 @@ def _adjust_device_train_microbatch_size(state: State): optimizer.zero_grad(set_to_none=True) if state.scaler is not None: state.scaler._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) + _fsdp_reshard_and_cleanup() torch.cuda.empty_cache() diff --git a/tests/trainer/test_fsdp.py b/tests/trainer/test_fsdp.py index c6f5258c49..2c42c247d8 100644 --- a/tests/trainer/test_fsdp.py +++ b/tests/trainer/test_fsdp.py @@ -7,10 +7,10 @@ import torch from packaging import version from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, Dataset from composer.models import ComposerClassifier, ComposerModel -from composer.trainer.trainer import Trainer +from composer.trainer.trainer import Trainer, _fsdp_reshard_and_cleanup from composer.utils import dist from tests.common import (EmbeddedWeightTiedModel, RandomClassificationDataset, SimpleModel, SimpleWeightTiedModel, world_size) @@ -191,6 +191,80 @@ def test_fsdp_prefetch_limit(forward_prefetch_limit: int, backward_prefetch_limi trainer.fit() +class SimpleDataset(Dataset): + + def __init__(self, size: int = 256, batch_size: int = 32, feature_size: int = 1, num_classes: int = 2): + self.size = size + #self.batch_size = batch_size + self.feature_size = feature_size + self.num_classes = num_classes + self.x = None + self.y = None + + def __len__(self): + return self.size + + def __getitem__(self, index: int): + # Note: lazily generate data so it runs after Composer seeds everything, giving the same + # dataset across multiple calls when using the same seed. + if self.x is None: + self.x = torch.randn(self.size, self.feature_size) + if self.y is None: + self.y = torch.randint(0, self.num_classes, size=(self.size,), dtype=torch.long) + return self.x[index] + + +class SimpleMLPForTestingOOM(ComposerModel): + + def __init__(self, num_features: int = 128, device: str = 'cuda'): + super().__init__() + self.device = device + self.fc1 = torch.nn.Linear(num_features, num_features, device=device, bias=False) + self.fc2 = torch.nn.Linear(num_features, num_features, device=device, bias=False) + self.fc3 = torch.nn.Linear(num_features, num_features, device=device, bias=False) + self.rank = dist.get_global_rank() + self.iter = 0 + + def forward(self, x): + x = self.fc1(x) + if self.rank == 0 and x.shape[0] >= 64: + raise RuntimeError('CUDA out of memory') + x = self.fc2(x) + x = self.fc3(x) + self.iter += 1 + return x + + def loss(self, outputs, batch): + return torch.sum(outputs) + + +@pytest.mark.gpu +@world_size(2) +def test_fsdp_auto_microbatch(world_size: int): + model = SimpleMLPForTestingOOM() + model.fc1._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] + model.fc2._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] + dataset = SimpleDataset(size=256, feature_size=128) + dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset), batch_size=64) + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + + trainer = Trainer( + model=model, + optimizers=optimizer, + train_dataloader=dataloader, + fsdp_config={ + 'forward_prefetch_limit': 1, + 'backward_prefetch_limit': 1, + }, + max_duration='3ba', + device_train_microbatch_size='auto', + dist_timeout=20, + ) + + trainer.fit() + assert False + + @pytest.mark.gpu @world_size(2) @pytest.mark.filterwarnings('ignore:Instantiating FSDP with custom process groups.*:UserWarning') @@ -272,3 +346,59 @@ def test_fsdp_act_ckpt_offload( assert isinstance(trainer.state.model.fc1._fsdp_wrapped_module, OffloadWrapper) else: assert not isinstance(trainer.state.model.fc1._fsdp_wrapped_module, CheckpointWrapper) + + +class SimpleMLPForTestingOOM(ComposerModel): + + def __init__(self, num_features: int = 128, device: str = 'cuda'): + super().__init__() + self.device = device + self.fc1 = torch.nn.Linear(num_features, num_features, device=device, bias=False) + self.fc2 = torch.nn.Linear(num_features, num_features, device=device, bias=False) + self.fc3 = torch.nn.Linear(num_features, num_features, device=device, bias=False) + self.rank = dist.get_global_rank() + + def oom_hook(*args): + raise RuntimeError('CUDA out of memory.') + self.fc2.register_full_backward_hook(oom_hook) + + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + x = self.fc3(x) + return x + + def loss(self, outputs, batch): + return torch.sum(outputs) + +@pytest.mark.gpu +@world_size(2) +def test_fsdp_reshard_after_oom(world_size: int): + model = SimpleMLPForTestingOOM() + + trainer = Trainer( + model=model, + fsdp_config={ + }, + max_duration='3ba', + dist_timeout=20, + ) + fsdp_model = trainer.state.model + + x = torch.rand([2, 128]) + output = fsdp_model(x) + with pytest.raises(Exception): + # Backward triggers the fake OOM exception, + # which prevents fsdp reshard and cleanup + torch.sum(output).backward() + + fc2_flat_param = fsdp_model.fc2._flat_param + + # without cleanup, model.fc2.flat_params is still in unshard state + # the full param is not freed + assert fc2_flat_param.data_ptr() != fc2_flat_param._local_shard.data_ptr() + assert fc2_flat_param._full_param_padded.numel() > 0 + + _fsdp_reshard_and_cleanup(fsdp_model) + assert fc2_flat_param.data_ptr() == fc2_flat_param._local_shard.data_ptr() From cafa4889b8a4e9d078584028635b7e8dcd87aeab Mon Sep 17 00:00:00 2001 From: Ning Wang Date: Sun, 18 Feb 2024 06:16:57 +0000 Subject: [PATCH 02/17] format --- tests/trainer/test_fsdp.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/trainer/test_fsdp.py b/tests/trainer/test_fsdp.py index 2c42c247d8..b6eaaba8ac 100644 --- a/tests/trainer/test_fsdp.py +++ b/tests/trainer/test_fsdp.py @@ -360,8 +360,8 @@ def __init__(self, num_features: int = 128, device: str = 'cuda'): def oom_hook(*args): raise RuntimeError('CUDA out of memory.') + self.fc2.register_full_backward_hook(oom_hook) - def forward(self, x): x = self.fc1(x) @@ -372,6 +372,7 @@ def forward(self, x): def loss(self, outputs, batch): return torch.sum(outputs) + @pytest.mark.gpu @world_size(2) def test_fsdp_reshard_after_oom(world_size: int): @@ -379,8 +380,7 @@ def test_fsdp_reshard_after_oom(world_size: int): trainer = Trainer( model=model, - fsdp_config={ - }, + fsdp_config={}, max_duration='3ba', dist_timeout=20, ) @@ -389,10 +389,10 @@ def test_fsdp_reshard_after_oom(world_size: int): x = torch.rand([2, 128]) output = fsdp_model(x) with pytest.raises(Exception): - # Backward triggers the fake OOM exception, + # Backward triggers the fake OOM exception, # which prevents fsdp reshard and cleanup torch.sum(output).backward() - + fc2_flat_param = fsdp_model.fc2._flat_param # without cleanup, model.fc2.flat_params is still in unshard state @@ -400,5 +400,5 @@ def test_fsdp_reshard_after_oom(world_size: int): assert fc2_flat_param.data_ptr() != fc2_flat_param._local_shard.data_ptr() assert fc2_flat_param._full_param_padded.numel() > 0 - _fsdp_reshard_and_cleanup(fsdp_model) + _fsdp_reshard_and_cleanup(fsdp_model) assert fc2_flat_param.data_ptr() == fc2_flat_param._local_shard.data_ptr() From a84a22c0d867d69fc9473919a976cd18b07cd83a Mon Sep 17 00:00:00 2001 From: Ning Wang Date: Sun, 18 Feb 2024 06:31:03 +0000 Subject: [PATCH 03/17] fix --- composer/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index e4f2423ca8..b80a7d8604 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -271,7 +271,7 @@ def _adjust_device_train_microbatch_size(state: State): optimizer.zero_grad(set_to_none=True) if state.scaler is not None: state.scaler._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) - _fsdp_reshard_and_cleanup() + _fsdp_reshard_and_cleanup(state.model) torch.cuda.empty_cache() From 8ad8d305b6ddb59d83982be20ed7d4586a9b0557 Mon Sep 17 00:00:00 2001 From: Ning Wang Date: Sun, 18 Feb 2024 07:53:29 +0000 Subject: [PATCH 04/17] cleanup unit test --- tests/trainer/test_fsdp.py | 76 +------------------------------------- 1 file changed, 1 insertion(+), 75 deletions(-) diff --git a/tests/trainer/test_fsdp.py b/tests/trainer/test_fsdp.py index b6eaaba8ac..f3dfe68cdb 100644 --- a/tests/trainer/test_fsdp.py +++ b/tests/trainer/test_fsdp.py @@ -7,7 +7,7 @@ import torch from packaging import version from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import DataLoader from composer.models import ComposerClassifier, ComposerModel from composer.trainer.trainer import Trainer, _fsdp_reshard_and_cleanup @@ -191,80 +191,6 @@ def test_fsdp_prefetch_limit(forward_prefetch_limit: int, backward_prefetch_limi trainer.fit() -class SimpleDataset(Dataset): - - def __init__(self, size: int = 256, batch_size: int = 32, feature_size: int = 1, num_classes: int = 2): - self.size = size - #self.batch_size = batch_size - self.feature_size = feature_size - self.num_classes = num_classes - self.x = None - self.y = None - - def __len__(self): - return self.size - - def __getitem__(self, index: int): - # Note: lazily generate data so it runs after Composer seeds everything, giving the same - # dataset across multiple calls when using the same seed. - if self.x is None: - self.x = torch.randn(self.size, self.feature_size) - if self.y is None: - self.y = torch.randint(0, self.num_classes, size=(self.size,), dtype=torch.long) - return self.x[index] - - -class SimpleMLPForTestingOOM(ComposerModel): - - def __init__(self, num_features: int = 128, device: str = 'cuda'): - super().__init__() - self.device = device - self.fc1 = torch.nn.Linear(num_features, num_features, device=device, bias=False) - self.fc2 = torch.nn.Linear(num_features, num_features, device=device, bias=False) - self.fc3 = torch.nn.Linear(num_features, num_features, device=device, bias=False) - self.rank = dist.get_global_rank() - self.iter = 0 - - def forward(self, x): - x = self.fc1(x) - if self.rank == 0 and x.shape[0] >= 64: - raise RuntimeError('CUDA out of memory') - x = self.fc2(x) - x = self.fc3(x) - self.iter += 1 - return x - - def loss(self, outputs, batch): - return torch.sum(outputs) - - -@pytest.mark.gpu -@world_size(2) -def test_fsdp_auto_microbatch(world_size: int): - model = SimpleMLPForTestingOOM() - model.fc1._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] - model.fc2._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] - dataset = SimpleDataset(size=256, feature_size=128) - dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset), batch_size=64) - optimizer = torch.optim.SGD(model.parameters(), lr=0.01) - - trainer = Trainer( - model=model, - optimizers=optimizer, - train_dataloader=dataloader, - fsdp_config={ - 'forward_prefetch_limit': 1, - 'backward_prefetch_limit': 1, - }, - max_duration='3ba', - device_train_microbatch_size='auto', - dist_timeout=20, - ) - - trainer.fit() - assert False - - @pytest.mark.gpu @world_size(2) @pytest.mark.filterwarnings('ignore:Instantiating FSDP with custom process groups.*:UserWarning') From defce2e5648aef73246ef222addd56a7c48c31c2 Mon Sep 17 00:00:00 2001 From: Ning Wang Date: Tue, 20 Feb 2024 21:12:09 +0000 Subject: [PATCH 05/17] comments --- composer/trainer/trainer.py | 17 +++++++++++++---- tests/trainer/test_fsdp.py | 37 +++++++++---------------------------- 2 files changed, 22 insertions(+), 32 deletions(-) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index b80a7d8604..13d9984d29 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -235,13 +235,22 @@ def _is_cuda_oom(e: RuntimeError): def _fsdp_reshard_and_cleanup(model: torch.nn.Module): + """ + This function manually reshards and cleans-up FSDP model by calling + _post_backward_final_callback on FSDP root module. In normal case, _post_backward_final_callback + is registered as a backward callback. But when exception happens, like OOM, that callback + is skipped. There will be memory leak if we don't call that callback. + Find more information here: https://github.com/mosaicml/composer/pull/3030 + """ for name, module in model.named_modules(): if isinstance(module, FullyShardedDataParallel): if module.check_is_root(): - try: - _post_backward_final_callback(module, module) - except Exception as e: - log.warning(f'Failed to reshard fsdp after oom, error: {e}') + """ + Only call _post_backward_final_callback on root module. It will + traverse all the FSDP sub-modules and do the reshard and cleanup + on all sub-modules + """ + _post_backward_final_callback(module, module) def _adjust_device_train_microbatch_size(state: State): diff --git a/tests/trainer/test_fsdp.py b/tests/trainer/test_fsdp.py index f3dfe68cdb..2cf7226926 100644 --- a/tests/trainer/test_fsdp.py +++ b/tests/trainer/test_fsdp.py @@ -222,10 +222,11 @@ def __init__(self, num_features: int = 128, device: str = 'cuda'): super().__init__() self.fc1 = torch.nn.Linear(num_features, num_features, device=device, bias=False) self.fc2 = torch.nn.Linear(num_features, num_features, device=device, bias=False) + self.relu = torch.nn.ReLU() def forward(self, x): x = self.fc1(x) - x = torch.nn.ReLU(x) + x = self.relu(x) x = self.fc2(x) return x @@ -274,35 +275,15 @@ def test_fsdp_act_ckpt_offload( assert not isinstance(trainer.state.model.fc1._fsdp_wrapped_module, CheckpointWrapper) -class SimpleMLPForTestingOOM(ComposerModel): - - def __init__(self, num_features: int = 128, device: str = 'cuda'): - super().__init__() - self.device = device - self.fc1 = torch.nn.Linear(num_features, num_features, device=device, bias=False) - self.fc2 = torch.nn.Linear(num_features, num_features, device=device, bias=False) - self.fc3 = torch.nn.Linear(num_features, num_features, device=device, bias=False) - self.rank = dist.get_global_rank() - - def oom_hook(*args): - raise RuntimeError('CUDA out of memory.') - - self.fc2.register_full_backward_hook(oom_hook) - - def forward(self, x): - x = self.fc1(x) - x = self.fc2(x) - x = self.fc3(x) - return x - - def loss(self, outputs, batch): - return torch.sum(outputs) - - @pytest.mark.gpu @world_size(2) def test_fsdp_reshard_after_oom(world_size: int): - model = SimpleMLPForTestingOOM() + model = SimpleMLP(num_features=128) + + def oom_hook(*args): + raise RuntimeError('CUDA out of memory.') + + model.fc2.register_full_backward_hook(oom_hook) trainer = Trainer( model=model, @@ -321,7 +302,7 @@ def test_fsdp_reshard_after_oom(world_size: int): fc2_flat_param = fsdp_model.fc2._flat_param - # without cleanup, model.fc2.flat_params is still in unshard state + # Without cleanup, model.fc2.flat_params is still in unshard state # the full param is not freed assert fc2_flat_param.data_ptr() != fc2_flat_param._local_shard.data_ptr() assert fc2_flat_param._full_param_padded.numel() > 0 From e863adb565522508e4f4f90d189fe770b088c1bf Mon Sep 17 00:00:00 2001 From: Ning Wang Date: Tue, 20 Feb 2024 21:17:47 +0000 Subject: [PATCH 06/17] more test --- tests/trainer/test_fsdp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/trainer/test_fsdp.py b/tests/trainer/test_fsdp.py index 2cf7226926..ac8b01587c 100644 --- a/tests/trainer/test_fsdp.py +++ b/tests/trainer/test_fsdp.py @@ -309,3 +309,4 @@ def oom_hook(*args): _fsdp_reshard_and_cleanup(fsdp_model) assert fc2_flat_param.data_ptr() == fc2_flat_param._local_shard.data_ptr() + assert fc2_flat_param._full_param_padded._typed_storage()._size() == 0 From 23871797e2fab2e0c51f381cd433fe5395b7aa28 Mon Sep 17 00:00:00 2001 From: Ning Wang Date: Tue, 20 Feb 2024 22:49:33 +0000 Subject: [PATCH 07/17] fix the warning --- tests/trainer/test_fsdp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/trainer/test_fsdp.py b/tests/trainer/test_fsdp.py index ac8b01587c..fb4e4ffb3f 100644 --- a/tests/trainer/test_fsdp.py +++ b/tests/trainer/test_fsdp.py @@ -279,6 +279,7 @@ def test_fsdp_act_ckpt_offload( @world_size(2) def test_fsdp_reshard_after_oom(world_size: int): model = SimpleMLP(num_features=128) + #model.relu._fsdp_wrap = False def oom_hook(*args): raise RuntimeError('CUDA out of memory.') From 15e9eee209196b92ae0c367f4f4ac3628e83b0f3 Mon Sep 17 00:00:00 2001 From: Ning Wang Date: Wed, 21 Feb 2024 20:11:37 +0000 Subject: [PATCH 08/17] add numerical correctness test --- composer/trainer/trainer.py | 2 + tests/trainer/test_fsdp.py | 80 ++++++++++++++++++++++++++++++++++++- 2 files changed, 81 insertions(+), 1 deletion(-) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 13d9984d29..6e6219baf2 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -2501,6 +2501,8 @@ def _train_microbatch(self, use_grad_scaling: bool, current_batch_size: int, with _get_precision_context(self.state.precision, self.state.precision_config, self.state.deepspeed_enabled): + print( + f'bigning debug batch shape: {self.state.batch[0].shape}, label shape: {self.state.batch[1].shape}') self.state.outputs = self.state.model(self.state.batch) self.engine.run_event(Event.AFTER_FORWARD) diff --git a/tests/trainer/test_fsdp.py b/tests/trainer/test_fsdp.py index fb4e4ffb3f..b832f196ee 100644 --- a/tests/trainer/test_fsdp.py +++ b/tests/trainer/test_fsdp.py @@ -1,12 +1,15 @@ # Copyright 2022 MosaicML Composer authors # SPDX-License-Identifier: Apache-2.0 +import copy +import pathlib from unittest.mock import MagicMock import pytest import torch from packaging import version from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper +from torch.distributed.fsdp import FullyShardedDataParallel from torch.utils.data import DataLoader from composer.models import ComposerClassifier, ComposerModel @@ -14,6 +17,8 @@ from composer.utils import dist from tests.common import (EmbeddedWeightTiedModel, RandomClassificationDataset, SimpleModel, SimpleWeightTiedModel, world_size) +from tests.trainer.test_fsdp_checkpoint import (_compare_model_params_between_state_dicts, + _compare_optims_between_state_dicts) _INIT_DEVICES = ['cpu', 'meta', 'mixed', 'cuda'] _MIXED_PRECISION_TYPES = ['FULL', 'DEFAULT', 'PURE'] @@ -218,7 +223,7 @@ def test_fsdp_process_group(world_size: int): class SimpleMLP(ComposerModel): - def __init__(self, num_features: int = 128, device: str = 'cuda'): + def __init__(self, num_features: int = 2, device: str = 'cuda'): super().__init__() self.fc1 = torch.nn.Linear(num_features, num_features, device=device, bias=False) self.fc2 = torch.nn.Linear(num_features, num_features, device=device, bias=False) @@ -279,6 +284,7 @@ def test_fsdp_act_ckpt_offload( @world_size(2) def test_fsdp_reshard_after_oom(world_size: int): model = SimpleMLP(num_features=128) + #model.relu._fsdp_wrap = False def oom_hook(*args): @@ -311,3 +317,75 @@ def oom_hook(*args): _fsdp_reshard_and_cleanup(fsdp_model) assert fc2_flat_param.data_ptr() == fc2_flat_param._local_shard.data_ptr() assert fc2_flat_param._full_param_padded._typed_storage()._size() == 0 + + +@pytest.mark.gpu +@world_size(2) +def test_fsdp_same_state_after_oom_reshard(world_size: int, tmp_path: pathlib.Path): + """ + Test the numerical correctness after we continue to train with + smaller batch size after OOM. + """ + model = SimpleMLP() + model.fc1._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] + model.fc2._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + + trainer = Trainer( + model=model, + fsdp_config={}, + dist_timeout=20, + optimizers=optimizer, + seed=1, + ) + fsdp_model = trainer.state.model + + state_dict = fsdp_model.state_dict() + + oom_model = SimpleMLP() + oom_model.fc1._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] + oom_model.fc2._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] + oom_model_optimizer = torch.optim.SGD(oom_model.parameters(), lr=0.1) + + def oom_hook(module, grad_input, grad_ouput): + if grad_ouput[0].shape[0] >= 4: + raise RuntimeError('CUDA out of memory.') + + oom_handle = oom_model.fc2.register_full_backward_hook(oom_hook) + oom_trainer = Trainer( + model=oom_model, + fsdp_config={}, + dist_timeout=20, + optimizers=oom_model_optimizer, + seed=1, + ) + + fsdp_oom_model = oom_trainer.state.model + fsdp_oom_model.load_state_dict(state_dict) + + x = torch.rand([4, 2]) + + # Run fwd + bwd + optimizer on normal model + output_0 = fsdp_model(x) + torch.sum(output_0).backward() + optimizer.step() + + # Run fwd + bwd + optimizer on OOM model + output = fsdp_oom_model(x) + with pytest.raises(Exception): + torch.sum(output).backward() + # Cleanup after OOM + _fsdp_reshard_and_cleanup(fsdp_oom_model) + oom_model_optimizer.zero_grad(set_to_none=True) + + oom_handle.remove() + output = fsdp_oom_model(x) + torch.sum(output).backward() + oom_model_optimizer.step() + + # Run another fwd on both model and check + # if output is the same + output_1 = fsdp_model(x) + output_2 = fsdp_oom_model(x) + + assert torch.equal(output_1, output_2) From 8db614b8bf750a8789e830181d737c9e1c4bbb86 Mon Sep 17 00:00:00 2001 From: bigning Date: Wed, 21 Feb 2024 13:09:27 -0800 Subject: [PATCH 09/17] Apply suggestions from code review Co-authored-by: Mihir Patel --- composer/trainer/trainer.py | 20 +++++++------------- tests/trainer/test_fsdp.py | 8 +------- 2 files changed, 8 insertions(+), 20 deletions(-) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index f8e83b7533..42c0cf72e2 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -235,21 +235,17 @@ def _is_cuda_oom(e: RuntimeError): def _fsdp_reshard_and_cleanup(model: torch.nn.Module): - """ - This function manually reshards and cleans-up FSDP model by calling - _post_backward_final_callback on FSDP root module. In normal case, _post_backward_final_callback - is registered as a backward callback. But when exception happens, like OOM, that callback - is skipped. There will be memory leak if we don't call that callback. - Find more information here: https://github.com/mosaicml/composer/pull/3030 + """Manually reshard and clean up FSDP model. + + When an exception like OOM happens, _post_backward_final_callback, which + is registered as a backward callback, will not run. We manually call it to cleanup + loose memory. """ for name, module in model.named_modules(): if isinstance(module, FullyShardedDataParallel): if module.check_is_root(): - """ - Only call _post_backward_final_callback on root module. It will - traverse all the FSDP sub-modules and do the reshard and cleanup - on all sub-modules - """ + # Only call _post_backward_final_callback on root module. It will + # traverse and reshard all FSDP sub-modules _post_backward_final_callback(module, module) @@ -2504,8 +2500,6 @@ def _train_microbatch(self, use_grad_scaling: bool, current_batch_size: int, with _get_precision_context(self.state.precision, self.state.precision_config, self.state.deepspeed_enabled): - print( - f'bigning debug batch shape: {self.state.batch[0].shape}, label shape: {self.state.batch[1].shape}') self.state.outputs = self.state.model(self.state.batch) self.engine.run_event(Event.AFTER_FORWARD) diff --git a/tests/trainer/test_fsdp.py b/tests/trainer/test_fsdp.py index 2335dc0628..4cb1969d6b 100644 --- a/tests/trainer/test_fsdp.py +++ b/tests/trainer/test_fsdp.py @@ -295,8 +295,6 @@ def test_fsdp_act_ckpt_offload( def test_fsdp_reshard_after_oom(world_size: int): model = SimpleMLP(num_features=128) - #model.relu._fsdp_wrap = False - def oom_hook(*args): raise RuntimeError('CUDA out of memory.') @@ -306,7 +304,6 @@ def oom_hook(*args): model=model, fsdp_config={}, max_duration='3ba', - dist_timeout=20, ) fsdp_model = trainer.state.model @@ -332,10 +329,7 @@ def oom_hook(*args): @pytest.mark.gpu @world_size(2) def test_fsdp_same_state_after_oom_reshard(world_size: int, tmp_path: pathlib.Path): - """ - Test the numerical correctness after we continue to train with - smaller batch size after OOM. - """ + #Test numerical correctness after continuing to train with smaller batch size after OOM. model = SimpleMLP() model.fc1._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] model.fc2._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] From d72fe1905cc23a759ed798e8bc4f43d07405ce84 Mon Sep 17 00:00:00 2001 From: Ning Wang Date: Wed, 21 Feb 2024 21:19:04 +0000 Subject: [PATCH 10/17] lint --- tests/trainer/test_fsdp.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/trainer/test_fsdp.py b/tests/trainer/test_fsdp.py index 4cb1969d6b..dd2d0bbc0f 100644 --- a/tests/trainer/test_fsdp.py +++ b/tests/trainer/test_fsdp.py @@ -1,8 +1,6 @@ # Copyright 2022 MosaicML Composer authors # SPDX-License-Identifier: Apache-2.0 -import copy -import pathlib from unittest.mock import MagicMock import pytest @@ -17,8 +15,6 @@ from composer.utils import dist from tests.common import (EmbeddedWeightTiedModel, RandomClassificationDataset, SimpleModel, SimpleWeightTiedModel, world_size) -from tests.trainer.test_fsdp_checkpoint import (_compare_model_params_between_state_dicts, - _compare_optims_between_state_dicts) _INIT_DEVICES = ['cpu', 'meta', 'mixed', 'cuda'] _MIXED_PRECISION_TYPES = ['FULL', 'DEFAULT', 'PURE'] @@ -328,7 +324,7 @@ def oom_hook(*args): @pytest.mark.gpu @world_size(2) -def test_fsdp_same_state_after_oom_reshard(world_size: int, tmp_path: pathlib.Path): +def test_fsdp_same_state_after_oom_reshard(world_size: int): #Test numerical correctness after continuing to train with smaller batch size after OOM. model = SimpleMLP() model.fc1._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] From 25f3d992d55a521e3907da81b0ad0590f0515faa Mon Sep 17 00:00:00 2001 From: Ning Wang Date: Wed, 21 Feb 2024 22:42:32 +0000 Subject: [PATCH 11/17] fix test warnning --- tests/trainer/test_fsdp.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/trainer/test_fsdp.py b/tests/trainer/test_fsdp.py index dd2d0bbc0f..2e88d9a044 100644 --- a/tests/trainer/test_fsdp.py +++ b/tests/trainer/test_fsdp.py @@ -290,6 +290,7 @@ def test_fsdp_act_ckpt_offload( @world_size(2) def test_fsdp_reshard_after_oom(world_size: int): model = SimpleMLP(num_features=128) + model.relu._fsdp_wrap = False def oom_hook(*args): raise RuntimeError('CUDA out of memory.') @@ -329,6 +330,7 @@ def test_fsdp_same_state_after_oom_reshard(world_size: int): model = SimpleMLP() model.fc1._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] model.fc2._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] + model.relu._fsdp_wrap = False optimizer = torch.optim.SGD(model.parameters(), lr=0.1) trainer = Trainer( @@ -345,6 +347,7 @@ def test_fsdp_same_state_after_oom_reshard(world_size: int): oom_model = SimpleMLP() oom_model.fc1._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] oom_model.fc2._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] + oom_model.relu._fsdp_wrap = False oom_model_optimizer = torch.optim.SGD(oom_model.parameters(), lr=0.1) def oom_hook(module, grad_input, grad_ouput): From 921f92d8284940fb7574253d15b4e79ed79d7168 Mon Sep 17 00:00:00 2001 From: Ning Wang Date: Wed, 21 Feb 2024 22:45:36 +0000 Subject: [PATCH 12/17] revert irrelevant change --- tests/trainer/test_fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/test_fsdp.py b/tests/trainer/test_fsdp.py index 2e88d9a044..df76289a4f 100644 --- a/tests/trainer/test_fsdp.py +++ b/tests/trainer/test_fsdp.py @@ -229,7 +229,7 @@ def test_wrong_size_device_mesh_error(world_size: int): class SimpleMLP(ComposerModel): - def __init__(self, num_features: int = 2, device: str = 'cuda'): + def __init__(self, num_features: int = 128, device: str = 'cuda'): super().__init__() self.fc1 = torch.nn.Linear(num_features, num_features, device=device, bias=False) self.fc2 = torch.nn.Linear(num_features, num_features, device=device, bias=False) From 174dc8356a5d54924c69c8eb4ec24efa180ba30e Mon Sep 17 00:00:00 2001 From: Ning Wang Date: Wed, 21 Feb 2024 23:37:04 +0000 Subject: [PATCH 13/17] lint --- composer/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 42c0cf72e2..20ad57cf1f 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -236,7 +236,7 @@ def _is_cuda_oom(e: RuntimeError): def _fsdp_reshard_and_cleanup(model: torch.nn.Module): """Manually reshard and clean up FSDP model. - + When an exception like OOM happens, _post_backward_final_callback, which is registered as a backward callback, will not run. We manually call it to cleanup loose memory. From 3c254cd3c441e15e6bda8ed020e60ab2f7f4e5bb Mon Sep 17 00:00:00 2001 From: Ning Wang Date: Wed, 21 Feb 2024 23:47:08 +0000 Subject: [PATCH 14/17] fix test --- tests/trainer/test_fsdp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/trainer/test_fsdp.py b/tests/trainer/test_fsdp.py index df76289a4f..5176fcc4f2 100644 --- a/tests/trainer/test_fsdp.py +++ b/tests/trainer/test_fsdp.py @@ -327,7 +327,7 @@ def oom_hook(*args): @world_size(2) def test_fsdp_same_state_after_oom_reshard(world_size: int): #Test numerical correctness after continuing to train with smaller batch size after OOM. - model = SimpleMLP() + model = SimpleMLP(num_features=2) model.fc1._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] model.fc2._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] model.relu._fsdp_wrap = False @@ -344,7 +344,7 @@ def test_fsdp_same_state_after_oom_reshard(world_size: int): state_dict = fsdp_model.state_dict() - oom_model = SimpleMLP() + oom_model = SimpleMLP(num_features=2) oom_model.fc1._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] oom_model.fc2._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] oom_model.relu._fsdp_wrap = False From 76b92c0f588bdf39cb10e54a85eb4510c2c0970d Mon Sep 17 00:00:00 2001 From: bigning Date: Wed, 21 Feb 2024 16:48:21 -0800 Subject: [PATCH 15/17] Update tests/trainer/test_fsdp.py Co-authored-by: Mihir Patel --- tests/trainer/test_fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/test_fsdp.py b/tests/trainer/test_fsdp.py index 5176fcc4f2..73842dd981 100644 --- a/tests/trainer/test_fsdp.py +++ b/tests/trainer/test_fsdp.py @@ -326,7 +326,7 @@ def oom_hook(*args): @pytest.mark.gpu @world_size(2) def test_fsdp_same_state_after_oom_reshard(world_size: int): - #Test numerical correctness after continuing to train with smaller batch size after OOM. + # Test numerical correctness after continuing to train with smaller batch size after OOM. model = SimpleMLP(num_features=2) model.fc1._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] model.fc2._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] From a7aa3dc35f7a1dce3cca1666c2542db5faf6b2b8 Mon Sep 17 00:00:00 2001 From: Ning Wang Date: Thu, 22 Feb 2024 00:59:51 +0000 Subject: [PATCH 16/17] fix lint --- composer/utils/object_store/oci_object_store.py | 3 ++- tests/trainer/test_fsdp.py | 7 +++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/composer/utils/object_store/oci_object_store.py b/composer/utils/object_store/oci_object_store.py index 72898464cc..cf4b78bdbd 100644 --- a/composer/utils/object_store/oci_object_store.py +++ b/composer/utils/object_store/oci_object_store.py @@ -152,7 +152,8 @@ def download_object( object_size = 0 try: head_object_response = self.client.head_object(self.namespace, self.bucket, object_name) - object_size = int(head_object_response.headers['content-length']) # pyright: ignore[reportOptionalMemberAccess] + object_size = int( + head_object_response.headers['content-length']) # pyright: ignore[reportOptionalMemberAccess] except Exception as e: _reraise_oci_errors(self.get_uri(object_name), e) diff --git a/tests/trainer/test_fsdp.py b/tests/trainer/test_fsdp.py index 5176fcc4f2..475fe3a614 100644 --- a/tests/trainer/test_fsdp.py +++ b/tests/trainer/test_fsdp.py @@ -7,7 +7,6 @@ import torch from packaging import version from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper -from torch.distributed.fsdp import FullyShardedDataParallel from torch.utils.data import DataLoader from composer.models import ComposerClassifier, ComposerModel @@ -290,7 +289,7 @@ def test_fsdp_act_ckpt_offload( @world_size(2) def test_fsdp_reshard_after_oom(world_size: int): model = SimpleMLP(num_features=128) - model.relu._fsdp_wrap = False + model.relu._fsdp_wrap = False # pyright: ignore[reportGeneralTypeIssues] def oom_hook(*args): raise RuntimeError('CUDA out of memory.') @@ -330,7 +329,7 @@ def test_fsdp_same_state_after_oom_reshard(world_size: int): model = SimpleMLP(num_features=2) model.fc1._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] model.fc2._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] - model.relu._fsdp_wrap = False + model.relu._fsdp_wrap = False # pyright: ignore[reportGeneralTypeIssues] optimizer = torch.optim.SGD(model.parameters(), lr=0.1) trainer = Trainer( @@ -347,7 +346,7 @@ def test_fsdp_same_state_after_oom_reshard(world_size: int): oom_model = SimpleMLP(num_features=2) oom_model.fc1._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] oom_model.fc2._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] - oom_model.relu._fsdp_wrap = False + oom_model.relu._fsdp_wrap = False # pyright: ignore[reportGeneralTypeIssues] oom_model_optimizer = torch.optim.SGD(oom_model.parameters(), lr=0.1) def oom_hook(module, grad_input, grad_ouput): From 059005238a870d0e2830c3a7b1da439a7176ee14 Mon Sep 17 00:00:00 2001 From: Ning Wang Date: Thu, 22 Feb 2024 01:11:14 +0000 Subject: [PATCH 17/17] lint --- composer/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 20ad57cf1f..7411dc4393 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -241,7 +241,7 @@ def _fsdp_reshard_and_cleanup(model: torch.nn.Module): is registered as a backward callback, will not run. We manually call it to cleanup loose memory. """ - for name, module in model.named_modules(): + for __, module in model.named_modules(): if isinstance(module, FullyShardedDataParallel): if module.check_is_root(): # Only call _post_backward_final_callback on root module. It will