nemo1 pass metadata
Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>
jiemingz committed Sep 21, 2024
1 parent e68f981 commit e576a12
Showing 2 changed files with 8 additions and 5 deletions.
@@ -1919,7 +1919,7 @@ def on_validation_model_zero_grad(self) -> None:
         if not self.validation_param_sync_overlap:
             super().on_validation_model_zero_grad()

-    def sharded_state_dict(self, prefix: str = '') -> Dict[str, Any]:
+    def sharded_state_dict(self, prefix: str = '', metadata: Optional[Dict] = None) -> Dict[str, Any]:
         """
         Creates the sharded state dict which is used by dist_checkpoint to save the sharded tensors to disk.
         When given the sharded_state_dict, dist_checkpoint.load will load the tensors corresponding to
@@ -1934,10 +1934,10 @@ def sharded_state_dict(self, prefix: str = '') -> Dict[str, Any]:
                 if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
                     # virtual pipeline rank must be set so that GPTModel returns the correct sharded state dict
                     parallel_state.set_virtual_pipeline_model_parallel_rank(index)
-                module_sharded_state_dict = module.sharded_state_dict(prefix=module_prefix)
+                module_sharded_state_dict = module.sharded_state_dict(prefix=module_prefix, metadata=metadata)
                 sharded_state_dict[f'model_{index}'] = module_sharded_state_dict
         else:
-            module_sharded_state_dict = module.sharded_state_dict(prefix=module_prefix)
+            module_sharded_state_dict = module.sharded_state_dict(prefix=module_prefix, metadata=metadata)
             sharded_state_dict.update(module_sharded_state_dict)

         # reset vp rank
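
For context, a minimal caller-side sketch of the updated signature; the `model` and `metadata` names are illustrative and not taken from the diff, and the default `metadata=None` keeps existing call sites working unchanged:

    # Hypothetical usage sketch, not part of the commit:
    sharded_sd = model.sharded_state_dict(prefix='')                     # old-style call, metadata defaults to None
    sharded_sd = model.sharded_state_dict(prefix='', metadata=metadata)  # metadata is forwarded to each MCore module
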
nemo/collections/nlp/parts/nlp_overrides.py (7 changes: 5 additions & 2 deletions)
@@ -22,6 +22,7 @@
 from contextlib import contextmanager
 from pathlib import Path
 from typing import Any, Callable, Dict, Generator, Iterator, List, Literal, Mapping, Optional, Sized, Union
+from dataclasses import asdict

 import pytorch_lightning as pl
 import torch
@@ -104,6 +105,7 @@
     optim_state_to_sharding_state,
 )
 from megatron.core.dist_checkpointing.strategies import tensorstore
+from megatron.core.dist_checkpointing.core import maybe_load_config
 from megatron.core.tensor_parallel.layers import param_is_not_tensor_parallel_duplicate
 from megatron.core.transformer.module import Float16Module as MCoreFloat16Module
 from megatron.core.transformer.transformer_layer import TransformerLayer as MCoreTransformerLayer
@@ -1338,9 +1340,10 @@ def dummy():
             )

             checkpoint = {}
-            sharded_state_dict = instance.sharded_state_dict()
+            metadata = asdict(maybe_load_config(tmp_model_weights_dir))
+            sharded_state_dict = instance.sharded_state_dict(metadata=metadata)
             checkpoint['state_dict'] = sharded_state_dict

             checkpoint_io = DistributedCheckpointIO.from_config(conf)
             checkpoint = checkpoint_io.load_checkpoint(
                 tmp_model_weights_dir,
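
The `asdict(maybe_load_config(...))` line above builds the metadata dict from the config saved alongside the distributed checkpoint. A self-contained sketch of that step, assuming `maybe_load_config` returns Megatron-Core's `CheckpointingConfig` dataclass when the directory contains a metadata file and `None` otherwise; the `load_sharding_metadata` helper and the `None` guard are illustrative additions, not part of the commit:

    from dataclasses import asdict
    from typing import Dict, Optional

    from megatron.core.dist_checkpointing.core import maybe_load_config


    def load_sharding_metadata(checkpoint_dir: str) -> Optional[Dict]:
        # Read the distributed checkpoint's saved config and flatten the
        # dataclass into a plain dict for sharded_state_dict(metadata=...).
        config = maybe_load_config(checkpoint_dir)
        return asdict(config) if config is not None else None


    # metadata = load_sharding_metadata(tmp_model_weights_dir)
    # sharded_state_dict = instance.sharded_state_dict(metadata=metadata)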
