diff --git a/examples/nlp/gpt/train_gpt_sft.py b/examples/nlp/gpt/train_gpt_sft.py index f52445637..dd4c48ba4 100644 --- a/examples/nlp/gpt/train_gpt_sft.py +++ b/examples/nlp/gpt/train_gpt_sft.py @@ -102,6 +102,9 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False): if cfg.model.get("seq_len_interpolation_factor", None) is not None: gpt_cfg.seq_len_interpolation_factor = cfg.model.seq_len_interpolation_factor + if cfg.model.get("dist_ckpt_load_strictness", None) is not None: + gpt_cfg.dist_ckpt_load_strictness = cfg.model.dist_ckpt_load_strictness + gpt_cfg.inference = cfg.model.get("inference", {}) # This is needed when modifying a hparam file directly to load `.ckpt` files. diff --git a/examples/nlp/gpt/train_steerlm2.py b/examples/nlp/gpt/train_steerlm2.py index 305588c46..62a01b3f4 100644 --- a/examples/nlp/gpt/train_steerlm2.py +++ b/examples/nlp/gpt/train_steerlm2.py @@ -140,6 +140,9 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False): if cfg.model.get("seq_len_interpolation_factor", None) is not None: gpt_cfg.seq_len_interpolation_factor = cfg.model.seq_len_interpolation_factor + if cfg.model.get("dist_ckpt_load_strictness", None) is not None: + gpt_cfg.dist_ckpt_load_strictness = cfg.model.dist_ckpt_load_strictness + gpt_cfg.inference = cfg.model.get("inference", {}) # This is needed when modifying a hparam file directly to load `.ckpt` files. diff --git a/nemo_aligner/utils/utils.py b/nemo_aligner/utils/utils.py index d67d2e952..7701b5e78 100644 --- a/nemo_aligner/utils/utils.py +++ b/nemo_aligner/utils/utils.py @@ -123,7 +123,7 @@ def load_checkpoint_model_config(restore_path): with tempfile.TemporaryDirectory() as tmpdir: # Extracts only model config - members = NLPSaveRestoreConnector._filtered_tar_info(restore_path, filter_fn=lambda name: '.yaml' in name) + members = NLPSaveRestoreConnector._filtered_tar_info(restore_path, filter_fn=lambda name: ".yaml" in name) NLPSaveRestoreConnector._unpack_nemo_file(restore_path, tmpdir, members=members) cfg = OmegaConf.load(os.path.join(tmpdir, config_name_in_ckpt))