Skip to content

Commit

Permalink
fix tensorstore import
Browse files Browse the repository at this point in the history
Signed-off-by: jiemingz <jiemingz@nvidia.com>
  • Loading branch information
jiemingz committed Mar 22, 2024
1 parent 45adc1d commit 1ee03a3
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion nemo/collections/nlp/parts/nlp_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
try:
from megatron.core import dist_checkpointing, parallel_state
from megatron.core.dist_checkpointing.dict_utils import dict_list_map_outplace
from megatron.core.dist_checkpointing.strategies import tensorstore
from megatron.core.dist_checkpointing.optimizer import (
get_param_id_to_sharded_param_map,
make_sharded_optimizer_tensor,
Expand Down Expand Up @@ -441,7 +442,7 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
# after dist_checkpointing.load, sharded tensors will be replaced with tensors
checkpoint['state_dict'] = sharded_state_dict
checkpoint['optimizer_states'] = [self.optimizer_sharded_state_dict()]
strategy = dist_checkpointing.strategies.tensorstore.TensorStoreLoadShardedStrategy(
strategy = tensorstore.TensorStoreLoadShardedStrategy(
load_directly_on_device=True
)
checkpoint = dist_checkpointing.load(
Expand Down

0 comments on commit 1ee03a3

Please sign in to comment.