diff --git a/README.rst b/README.rst
index d07b07434b20..fba4aaf04f09 100644
--- a/README.rst
+++ b/README.rst
@@ -38,6 +38,22 @@
 **NVIDIA NeMo**
 ===============
 
+Latest News
+-----------
+
+- 2023/12/06 `New NVIDIA NeMo Framework Features and NVIDIA H200 <https://developer.nvidia.com/blog/new-nvidia-nemo-framework-features-and-nvidia-h200-supercharge-llm-training-performance-and-versatility>`_
+
+.. image:: https://github.com/sbhavani/TransformerEngine/blob/main/docs/examples/H200-NeMo-performance.png
+  :target: https://developer.nvidia.com/blog/new-nvidia-nemo-framework-features-and-nvidia-h200-supercharge-llm-training-performance-and-versatility
+  :alt: H200-NeMo-performance
+  :width: 600
+
+NeMo Framework has been updated with state-of-the-art features,
+such as FSDP, Mixture-of-Experts, and RLHF with TensorRT-LLM, providing speedups of up to 4.2x for Llama-2 pre-training on H200.
+**All of these features will be available in an upcoming release.**
+
+
 Introduction
 ------------
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
index 6579e837b1a6..ccdd2e8725db 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
@@ -103,7 +103,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):
         self.tokenizer = None
 
         with open_dict(cfg):
-            if cfg.get('precision', None) is None and trainer is not None:
+            if cfg.get('precision', None) is None:
                 cfg.precision = trainer.precision
 
         super().__init__(cfg, trainer=trainer, no_lm_init=no_lm_init)
@@ -773,7 +773,6 @@ def build_model_parallel_config(self) -> ModelParallelConfig:
         cfg = OmegaConf.to_container(self.cfg, resolve=True)
 
         # map precision related configs
-        precision = cfg.get('precision', 32)  # PTL trainer precision
         megatron_amp_O2 = cfg.get('megatron_amp_O2', False)
 
         # dtype used in p2p communication
@@ -791,7 +790,7 @@ def build_model_parallel_config(self) -> ModelParallelConfig:
             and not self.cfg.get('sequence_parallel', False),
             "pipeline_dtype": pipeline_dtype,
             "grad_scale_func": self.trainer.precision_plugin.scaler.scale
-            if self.torch_dtype == torch.float16
+            if self.trainer.precision in ["16", "16-mixed"]
             else None,
             "enable_autocast": not megatron_amp_O2 and self.torch_dtype in [torch.bfloat16, torch.float16],
             "autocast_dtype": self.autocast_dtype,
diff --git a/nemo/collections/nlp/parts/utils_funcs.py b/nemo/collections/nlp/parts/utils_funcs.py
index 5185c6cf9b5a..2ec77faf91f5 100644
--- a/nemo/collections/nlp/parts/utils_funcs.py
+++ b/nemo/collections/nlp/parts/utils_funcs.py
@@ -12,7 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__all__ = ['list2str', 'tensor2list', 'plot_confusion_matrix', 'get_classification_report']
+__all__ = [
+    'torch_dtype_from_precision',
+    'list2str',
+    'tensor2list',
+    'plot_confusion_matrix',
+    'get_classification_report',
+]
 
 import os
 import time
diff --git a/scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py b/scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py
index d1453aeee972..d6007aa771c0 100644
--- a/scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py
+++ b/scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py
@@ -17,8 +17,7 @@
 Example to run this conversion script:
     python convert_hf_llama_to_nemo.py \
      --in-file <path_to_hf_checkpoints_folder> \
-     --out-file <path_to_output_nemo_file> \
-     [--fast-swiglu\
+     --out-file <path_to_output_nemo_file>
 """
 
 import os
@@ -41,6 +40,7 @@
     NLPSaveRestoreConnector,
     PipelineMixedPrecisionPlugin,
 )
+from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision
 from nemo.utils import logging
 
 
@@ -50,7 +50,7 @@ def get_args():
         "--in-file", type=str, default=None, required=True, help="Path to Huggingface LLaMA checkpoints",
     )
     parser.add_argument("--out-file", type=str, default=None, required=True, help="Path to output .nemo file.")
-    parser.add_argument("--precision", type=str, default="32", help="Model precision")
+    parser.add_argument("--precision", type=str, default="16", help="Model precision")
     args = parser.parse_args()
     return args
 
@@ -94,7 +94,7 @@ def load_model(cls, checkpoint, strict, **kwargs):
     return model
 
 
-def load_config(args, llama_config):
+def load_config(llama_config):
     nemo_config = OmegaConf.load(
         os.path.join(os.path.dirname(__file__), '../../examples/nlp/language_modeling/conf/megatron_llama_config.yaml')
     ).model
@@ -138,7 +138,7 @@ def convert(args):
     for name, param in model.named_parameters():
         print(f"- {name}")
 
-    nemo_config = load_config(args, hf_config)
+    nemo_config = load_config(hf_config)
 
     if args.precision in ["32", "16"]:
         precision = int(float(args.precision))
@@ -170,15 +170,6 @@ def convert(args):
     else:
         plugins.append(PipelineMixedPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler))
 
-    if precision == 32:
-        dtype = torch.float32
-    elif precision in [16, "16", "16-mixed"]:
-        dtype = torch.float16
-    elif precision in ["bf16", "bf16-mixed"]:
-        dtype = torch.bfloat16
-    else:
-        dtype = torch.float32  # fallback
-
     nemo_config.precision = precision
     print(f"nemo_config: {nemo_config}")
 
@@ -315,6 +306,7 @@ def convert(args):
     model._save_restore_connector = NLPSaveRestoreConnector()
 
     # cast to target precision and disable cpu init
+    dtype = torch_dtype_from_precision(precision)
     model = model.to(dtype=dtype)
     model.cfg.use_cpu_initialization = False
 
diff --git a/scripts/nlp_language_modeling/convert_mpt_7b_hf_to_nemo.py b/scripts/nlp_language_modeling/convert_mpt_7b_hf_to_nemo.py
index fd761b6b20c2..2261f70ea928 100644
--- a/scripts/nlp_language_modeling/convert_mpt_7b_hf_to_nemo.py
+++ b/scripts/nlp_language_modeling/convert_mpt_7b_hf_to_nemo.py
@@ -148,7 +148,6 @@
     'precision': 'bf16',
    'logger': False,  # logger provided by exp_manager
    'enable_checkpointing': False,
-   'replace_sampler_ddp': False,
    'max_epochs': -1,  # PTL default. In practice, max_steps will be reached first.
    'max_steps': 100000,  # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
    'log_every_n_steps': 10,
diff --git a/tutorials/asr/ASR_with_NeMo.ipynb b/tutorials/asr/ASR_with_NeMo.ipynb
index 74cd0f739e84..afda092b8ecc 100644
--- a/tutorials/asr/ASR_with_NeMo.ipynb
+++ b/tutorials/asr/ASR_with_NeMo.ipynb
@@ -267,7 +267,7 @@
         "plt.title('Waveform of Audio Example')\n",
         "plt.ylabel('Amplitude')\n",
         "\n",
-        "_ = librosa.display.waveshow(audio)"
+        "_ = librosa.display.waveshow(audio, color='blue')"
       ],
       "execution_count": null,
       "outputs": []
@@ -330,7 +330,7 @@
       },
       "source": [
         "# Plot the mel spectrogram of our sample\n",
-        "mel_spec = librosa.feature.melspectrogram(audio, sr=sample_rate)\n",
+        "mel_spec = librosa.feature.melspectrogram(y=audio, sr=sample_rate)\n",
         "mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)\n",
         "\n",
         "librosa.display.specshow(\n",
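
A note on the precision refactor above: the inline if/elif ladder deleted from convert_hf_llama_to_nemo.py is replaced by the torch_dtype_from_precision helper that utils_funcs.py now exports (it is added to __all__ and imported by the converter). As a minimal sketch of the mapping involved, mirroring the removed inline code; the helper's actual signature in NeMo may take additional arguments:

    import torch

    def torch_dtype_from_precision(precision) -> torch.dtype:
        """Map a PyTorch Lightning precision flag (int or str) to a torch dtype."""
        if precision in [16, '16', '16-mixed']:
            return torch.float16
        if precision in ['bf16', 'bf16-mixed']:
            return torch.bfloat16
        # 32 and anything unrecognized fall back to fp32, as in the removed code.
        return torch.float32

Centralizing the mapping keeps every conversion script's model.to(dtype=...) cast consistent with the grad-scaler condition in megatron_base_model.py, which now keys off the trainer's precision string ("16"/"16-mixed") rather than a locally computed dtype.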
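
The dropped 'replace_sampler_ddp': False entry in convert_mpt_7b_hf_to_nemo.py tracks PyTorch Lightning 2.x, where that Trainer argument no longer exists (its successor is use_distributed_sampler) and passing the old key raises a TypeError. Purely as an illustration of the PTL 2.x spelling, not what the patch does (the patch simply drops the key):

    from pytorch_lightning import Trainer

    # PTL 1.x accepted Trainer(replace_sampler_ddp=False, ...); in PTL 2.x the
    # option is spelled use_distributed_sampler and the old name is rejected.
    trainer = Trainer(
        precision='bf16-mixed',
        logger=False,
        enable_checkpointing=False,
        use_distributed_sampler=False,  # PTL 2.x successor to replace_sampler_ddp
        max_epochs=-1,
        max_steps=100000,
        log_every_n_steps=10,
    )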
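
The ASR notebook edits track librosa's keyword-only API: since librosa 0.10, librosa.feature.melspectrogram rejects positional audio input, and librosa.display.waveshow (successor to the removed waveplot) forwards styling keyword arguments such as color to matplotlib. A short sketch of the updated calls, assuming librosa >= 0.10 (the bundled trumpet clip is downloaded on first use):

    import librosa
    import librosa.display
    import numpy as np

    audio, sample_rate = librosa.load(librosa.example('trumpet'))

    # The signal must now be passed by keyword; melspectrogram(audio, ...) raises a TypeError.
    mel_spec = librosa.feature.melspectrogram(y=audio, sr=sample_rate)
    mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

    # waveshow still takes the signal positionally and passes extra kwargs to matplotlib.
    _ = librosa.display.waveshow(audio, sr=sample_rate, color='blue')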