[fix auto-microbatch] FSDP reshard and cleanup after OOM to fix the CUDA memory leak #3030

Merged (22 commits) on Feb 22, 2024
Changes from 8 commits
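For context (this snapshot of the PR carries no description): Composer's automatic microbatching catches CUDA OOM errors, shrinks device_train_microbatch_size, and retries the batch. When the OOM is raised mid-backward under FSDP, the registered final backward callback is skipped, so unsharded flat parameters are never freed and the retry tends to OOM again. Below is a minimal sketch of that retry pattern, not the actual Trainer code: train_batch_with_retry and its arguments are hypothetical, while _is_cuda_oom, _adjust_device_train_microbatch_size, and _fsdp_reshard_and_cleanup are the real helpers touched by the diff.

import torch

# Hypothetical stand-in for the retry loop inside Trainer; only the helper names
# mentioned in the comments come from composer/trainer/trainer.py.
def train_batch_with_retry(model, loss_fn, optimizer, batch, microbatch_size):
    while True:
        try:
            optimizer.zero_grad(set_to_none=True)
            for micro in torch.split(batch, microbatch_size):
                loss_fn(model(micro)).backward()
            optimizer.step()
            return microbatch_size  # the size that fit in memory
        except RuntimeError as e:
            if 'CUDA out of memory' not in str(e):  # roughly what _is_cuda_oom checks
                raise
            if microbatch_size == 1:
                raise  # cannot shrink further; surface the OOM
            # This is the role of _adjust_device_train_microbatch_size, which this PR
            # extends with _fsdp_reshard_and_cleanup so that flat parameters left
            # unsharded by the aborted backward are freed before the retry.
            optimizer.zero_grad(set_to_none=True)
            torch.cuda.empty_cache()
            microbatch_size //= 2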
24 changes: 24 additions & 0 deletions composer/trainer/trainer.py
@@ -30,6 +30,8 @@
import torch.utils.data
from torch._dynamo import OptimizedModule
from torch.cuda.amp.grad_scaler import GradScaler, _refresh_per_optimizer_state
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.distributed.fsdp._runtime_utils import _post_backward_final_callback
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import LRScheduler
@@ -232,6 +234,25 @@ def _is_cuda_oom(e: RuntimeError):
return False


def _fsdp_reshard_and_cleanup(model: torch.nn.Module):
"""
This function manually reshards and cleans-up FSDP model by calling
_post_backward_final_callback on FSDP root module. In normal case, _post_backward_final_callback
is registered as a backward callback. But when exception happens, like OOM, that callback
is skipped. There will be memory leak if we don't call that callback.
Find more information here: https://github.com/mosaicml/composer/pull/3030
"""
for name, module in model.named_modules():
if isinstance(module, FullyShardedDataParallel):
if module.check_is_root():
"""
Only call _post_backward_final_callback on root module. It will
traverse all the FSDP sub-modules and do the reshard and cleanup
on all sub-modules
"""
_post_backward_final_callback(module, module)


def _adjust_device_train_microbatch_size(state: State):
"""Adjust device_train_microbatch_size if we encounter OOM.

@@ -259,6 +280,7 @@ def _adjust_device_train_microbatch_size(state: State):
optimizer.zero_grad(set_to_none=True)
if state.scaler is not None:
state.scaler._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
_fsdp_reshard_and_cleanup(state.model)
torch.cuda.empty_cache()


@@ -2479,6 +2501,8 @@ def _train_microbatch(self, use_grad_scaling: bool, current_batch_size: int,

with _get_precision_context(self.state.precision, self.state.precision_config,
self.state.deepspeed_enabled):
self.state.outputs = self.state.model(self.state.batch)

self.engine.run_event(Event.AFTER_FORWARD)
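A note before the test changes: the helper above and the new tests below poke at private attributes of torch.distributed.fsdp (_flat_param, _local_shard, _full_param_padded). A hedged sketch of how one can tell that a module was left unsharded after a failed backward; these internals may change between PyTorch releases, and the function name is illustrative, not part of the PR.

from torch.distributed.fsdp import FullyShardedDataParallel

def is_left_unsharded(fsdp_module: FullyShardedDataParallel) -> bool:
    # Illustrative only; assumes a sharded (FULL_SHARD-style) configuration in which
    # _full_param_padded exists on the flat parameter.
    flat_param = fsdp_module._flat_param
    if flat_param is None:
        return False
    # After a proper reshard, the flat parameter aliases its local shard and the
    # padded full-parameter buffer has storage size 0.
    still_full = flat_param.data_ptr() != flat_param._local_shard.data_ptr()
    full_buffer_alive = flat_param._full_param_padded._typed_storage()._size() > 0
    return still_full or full_buffer_alive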
123 changes: 120 additions & 3 deletions tests/trainer/test_fsdp.py
@@ -1,19 +1,24 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

import copy
import pathlib
from unittest.mock import MagicMock

import pytest
import torch
from packaging import version
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.utils.data import DataLoader

from composer.models import ComposerClassifier, ComposerModel
-from composer.trainer.trainer import Trainer
+from composer.trainer.trainer import Trainer, _fsdp_reshard_and_cleanup
from composer.utils import dist
from tests.common import (EmbeddedWeightTiedModel, RandomClassificationDataset, SimpleModel, SimpleWeightTiedModel,
world_size)
from tests.trainer.test_fsdp_checkpoint import (_compare_model_params_between_state_dicts,
_compare_optims_between_state_dicts)

_INIT_DEVICES = ['cpu', 'meta', 'mixed', 'cuda']
_MIXED_PRECISION_TYPES = ['FULL', 'DEFAULT', 'PURE']
@@ -218,14 +223,15 @@ def test_fsdp_process_group(world_size: int):

class SimpleMLP(ComposerModel):

-    def __init__(self, num_features: int = 128, device: str = 'cuda'):
+    def __init__(self, num_features: int = 2, device: str = 'cuda'):
super().__init__()
self.fc1 = torch.nn.Linear(num_features, num_features, device=device, bias=False)
self.fc2 = torch.nn.Linear(num_features, num_features, device=device, bias=False)
+        self.relu = torch.nn.ReLU()

def forward(self, x):
x = self.fc1(x)
-        x = torch.nn.ReLU(x)
+        x = self.relu(x)
x = self.fc2(x)
return x

@@ -272,3 +278,114 @@ def test_fsdp_act_ckpt_offload(
assert isinstance(trainer.state.model.fc1._fsdp_wrapped_module, OffloadWrapper)
else:
assert not isinstance(trainer.state.model.fc1._fsdp_wrapped_module, CheckpointWrapper)


@pytest.mark.gpu
@world_size(2)
def test_fsdp_reshard_after_oom(world_size: int):
model = SimpleMLP(num_features=128)


def oom_hook(*args):
raise RuntimeError('CUDA out of memory.')

model.fc2.register_full_backward_hook(oom_hook)

trainer = Trainer(
model=model,
fsdp_config={},
max_duration='3ba',
dist_timeout=20,
)
fsdp_model = trainer.state.model

x = torch.rand([2, 128])
output = fsdp_model(x)
with pytest.raises(Exception):
# Backward triggers the fake OOM exception,
# which prevents fsdp reshard and cleanup
torch.sum(output).backward()

fc2_flat_param = fsdp_model.fc2._flat_param

    # Without cleanup, model.fc2._flat_param is still in the unsharded state:
    # it points at the padded full parameter instead of its local shard, and the
    # full-parameter buffer has not been freed.
    assert fc2_flat_param.data_ptr() != fc2_flat_param._local_shard.data_ptr()
assert fc2_flat_param._full_param_padded.numel() > 0

    _fsdp_reshard_and_cleanup(fsdp_model)
    # After manual cleanup, the flat param aliases its local shard again and the
    # padded full-parameter buffer's storage has been freed.
    assert fc2_flat_param.data_ptr() == fc2_flat_param._local_shard.data_ptr()
    assert fc2_flat_param._full_param_padded._typed_storage()._size() == 0


@pytest.mark.gpu
@world_size(2)
def test_fsdp_same_state_after_oom_reshard(world_size: int, tmp_path: pathlib.Path):
"""
Test the numerical correctness after we continue to train with
smaller batch size after OOM.
"""
model = SimpleMLP()
model.fc1._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues]
model.fc2._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues]
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

trainer = Trainer(
model=model,
fsdp_config={},
dist_timeout=20,
optimizers=optimizer,
seed=1,
)
fsdp_model = trainer.state.model

state_dict = fsdp_model.state_dict()

oom_model = SimpleMLP()
oom_model.fc1._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues]
oom_model.fc2._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues]
oom_model_optimizer = torch.optim.SGD(oom_model.parameters(), lr=0.1)

    def oom_hook(module, grad_input, grad_output):
        # Raise a fake OOM when the full batch (size 4) reaches fc2's backward pass.
        if grad_output[0].shape[0] >= 4:
            raise RuntimeError('CUDA out of memory.')

oom_handle = oom_model.fc2.register_full_backward_hook(oom_hook)
oom_trainer = Trainer(
model=oom_model,
fsdp_config={},
dist_timeout=20,
optimizers=oom_model_optimizer,
seed=1,
)

fsdp_oom_model = oom_trainer.state.model
fsdp_oom_model.load_state_dict(state_dict)

x = torch.rand([4, 2])

# Run fwd + bwd + optimizer on normal model
output_0 = fsdp_model(x)
torch.sum(output_0).backward()
optimizer.step()

# Run fwd + bwd + optimizer on OOM model
output = fsdp_oom_model(x)
with pytest.raises(Exception):
torch.sum(output).backward()
# Cleanup after OOM
_fsdp_reshard_and_cleanup(fsdp_oom_model)
oom_model_optimizer.zero_grad(set_to_none=True)

oom_handle.remove()
output = fsdp_oom_model(x)
torch.sum(output).backward()
oom_model_optimizer.step()

    # Run another forward pass on both models and check that the outputs match.
output_1 = fsdp_model(x)
output_2 = fsdp_oom_model(x)

assert torch.equal(output_1, output_2)
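
A note on the two imports from tests/trainer/test_fsdp_checkpoint.py (_compare_model_params_between_state_dicts and _compare_optims_between_state_dicts): they are not exercised in this revision of the test. A hypothetical stricter ending, assuming those helpers accept full Trainer state dicts the way test_fsdp_checkpoint.py uses them, would compare parameters and optimizer state directly rather than only the forward outputs:

    # Hypothetical extension (not in this revision of the PR): compare the clean
    # trainer and the OOM-recovered trainer state dicts directly.
    state_dict_clean = trainer.state.state_dict()
    state_dict_oom = oom_trainer.state.state_dict()
    _compare_model_params_between_state_dicts(state_dict_clean, state_dict_oom)
    _compare_optims_between_state_dicts(state_dict_clean, state_dict_oom)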