Skip to content

Commit

Permalink
Avoid out of GPU memory issue when loading checkpoints (#1532)
Browse files Browse the repository at this point in the history
* Avoid out of GPU memory issue when loading checkpoints

* Add test
  • Loading branch information
Haichao-Zhang authored Sep 5, 2023
1 parent 36947f5 commit 2485b77
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 7 deletions.
13 changes: 6 additions & 7 deletions alf/utils/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,9 @@ def extract_sub_state_dict_from_checkpoint(checkpoint_prefix, checkpoint_path):
by ALF, e.g. "/path_to_experiment/train/algorithm/ckpt-100".
"""

map_location = None
if not torch.cuda.is_available():
map_location = torch.device('cpu')

# use ``cpu`` as the map location to avoid GPU RAM surge when loading a
# model checkpoint.
map_location = torch.device('cpu')
checkpoint = torch.load(checkpoint_path, map_location=map_location)

if checkpoint_prefix != '':
Expand Down Expand Up @@ -249,9 +248,9 @@ def _merge_checkpoint(merged, new):
"Checkpoint '%s' does not exist. Train from scratch." % f_path)
return self._global_step

map_location = None
if not torch.cuda.is_available():
map_location = torch.device('cpu')
# use ``cpu`` as the map location to avoid GPU RAM surge when loading a
# model checkpoint.
map_location = torch.device('cpu')

checkpoint = torch.load(f_path, map_location=map_location)
if including_optimizer:
Expand Down
31 changes: 31 additions & 0 deletions alf/utils/checkpoint_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from alf.networks.preprocessors import EmbeddingPreprocessor
from alf.tensor_specs import TensorSpec
from alf.utils import common
import unittest


class Net(nn.Module):
Expand Down Expand Up @@ -627,5 +628,35 @@ def test_checkpoint_structure(self):
self.assertEqual(model_structure, expected_model_structure)


class TestCheckpointMapLocation(alf.test.TestCase):
    """Verify that a checkpoint loaded with ``map_location='cpu'`` can be
    consumed by both CPU and GPU models via ``load_state_dict``.

    This guards the fix that always maps checkpoints to CPU on load to
    avoid a GPU-RAM surge, by confirming that doing so does not break
    loading into a CUDA model.
    """

    @unittest.skipIf(not torch.cuda.is_available(), "gpu is unavailable")
    def test_map_location(self):
        # Test that loading a checkpoint file using cpu as the map_location
        # works in the subsequent ``load_state_dict`` for both cpu and gpu
        # models.
        model_on_cpu = Net()
        model_on_gpu = Net().cuda()

        with tempfile.TemporaryDirectory() as ckpt_dir:
            ckpt_path = ckpt_dir + '/test.ckpt'
            # Save the CPU model's weights as the checkpoint under test.
            torch.save(model_on_cpu.state_dict(), ckpt_path)

            # Load the checkpoint file onto cpu regardless of where the
            # consuming model lives.
            device = torch.device('cpu')
            state_dict = torch.load(ckpt_path, map_location=device)

            # Test loading the cpu-mapped state-dict into the cpu model.
            model_on_cpu.load_state_dict(state_dict)

            # Test loading the cpu-mapped state-dict into the gpu model.
            model_on_gpu.load_state_dict(state_dict)

            for p1, p2 in zip(model_on_cpu.parameters(),
                              model_on_gpu.parameters()):
                # Each model's parameters should stay on its own device ...
                self.assertTrue('cpu' == str(p1.device))
                self.assertTrue(p2.is_cuda)

                # ... while holding numerically identical values.
                self.assertTensorClose(p1, p2.cpu())


# Run the test suite when this file is executed directly.
if __name__ == '__main__':
    alf.test.main()

0 comments on commit 2485b77

Please sign in to comment.