diff --git a/nemo/export/trt_llm/nemo/nemo_ckpt_convert.py b/nemo/export/trt_llm/nemo/nemo_ckpt_convert.py index 0d8392ef23b5..a43642c82898 100644 --- a/nemo/export/trt_llm/nemo/nemo_ckpt_convert.py +++ b/nemo/export/trt_llm/nemo/nemo_ckpt_convert.py @@ -347,8 +347,8 @@ def convert_nemo_model( pp_last_rank = parallel_state.get_pipeline_model_parallel_last_rank() pp_size = parallel_state.get_pipeline_model_parallel_world_size() pp_group = parallel_state.get_pipeline_model_parallel_group() - pp_is_last = parallel_state.is_pipeline_last_stage() - pp_is_first = parallel_state.is_pipeline_first_stage() + pp_is_last = parallel_state.is_pipeline_last_stage(ignore_virtual=True) + pp_is_first = parallel_state.is_pipeline_first_stage(ignore_virtual=True) vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size() if not vp_size: vp_size = 1