diff --git a/alf/utils/checkpoint_utils.py b/alf/utils/checkpoint_utils.py index 5e9874ef7..42f549847 100644 --- a/alf/utils/checkpoint_utils.py +++ b/alf/utils/checkpoint_utils.py @@ -61,10 +61,9 @@ def extract_sub_state_dict_from_checkpoint(checkpoint_prefix, checkpoint_path): by ALF, e.g. "/path_to_experiment/train/algorithm/ckpt-100". """ - map_location = None - if not torch.cuda.is_available(): - map_location = torch.device('cpu') - + # use ``cpu`` as the map location to avoid GPU RAM surge when loading a + # model checkpoint. + map_location = torch.device('cpu') checkpoint = torch.load(checkpoint_path, map_location=map_location) if checkpoint_prefix != '': @@ -249,9 +248,9 @@ def _merge_checkpoint(merged, new): "Checkpoint '%s' does not exist. Train from scratch." % f_path) return self._global_step - map_location = None - if not torch.cuda.is_available(): - map_location = torch.device('cpu') + # use ``cpu`` as the map location to avoid GPU RAM surge when loading a + # model checkpoint. + map_location = torch.device('cpu') checkpoint = torch.load(f_path, map_location=map_location) if including_optimizer: diff --git a/alf/utils/checkpoint_utils_test.py b/alf/utils/checkpoint_utils_test.py index f8ca3155e..fa2fc17a5 100644 --- a/alf/utils/checkpoint_utils_test.py +++ b/alf/utils/checkpoint_utils_test.py @@ -36,6 +36,7 @@ from alf.networks.preprocessors import EmbeddingPreprocessor from alf.tensor_specs import TensorSpec from alf.utils import common +import unittest class Net(nn.Module): @@ -627,5 +628,35 @@ def test_checkpoint_structure(self): self.assertEqual(model_structure, expected_model_structure) +class TestCheckpointMapLocation(alf.test.TestCase): + @unittest.skipIf(not torch.cuda.is_available(), "gpu is unavailable") + def test_map_location(self): + # test that that loading checkpoint file using cpu as the map_location + # works in the subsequent ``load_state_dict`` both for cpu and gpu models + model_on_cpu = Net() + model_on_gpu = Net().cuda() + + with tempfile.TemporaryDirectory() as ckpt_dir: + ckpt_path = ckpt_dir + '/test.ckpt' + torch.save(model_on_cpu.state_dict(), ckpt_path) + + # load checkpoint file on cpu + device = torch.device('cpu') + state_dict = torch.load(ckpt_path, map_location=device) + + # test load state-dict for cpu model + model_on_cpu.load_state_dict(state_dict) + + # test load state-dict for cpu model + model_on_gpu.load_state_dict(state_dict) + + for p1, p2 in zip(model_on_cpu.parameters(), + model_on_gpu.parameters()): + self.assertTrue('cpu' == str(p1.device)) + self.assertTrue(p2.is_cuda) + + self.assertTensorClose(p1, p2.cpu()) + + if __name__ == '__main__': alf.test.main()