From 1a77f9f8f744d43393437489cc04f1a3c312c111 Mon Sep 17 00:00:00 2001 From: Dong Hyuk Chang Date: Mon, 23 Sep 2024 11:49:23 -0400 Subject: [PATCH] Update dist_ckpt_load_strictness for sft and steerlm2 (#293) Signed-off-by: Dong Hyuk Chang Co-authored-by: Dong Hyuk Chang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- examples/nlp/gpt/train_gpt_sft.py | 3 +++ examples/nlp/gpt/train_steerlm2.py | 3 +++ nemo_aligner/utils/utils.py | 2 +- 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/examples/nlp/gpt/train_gpt_sft.py b/examples/nlp/gpt/train_gpt_sft.py index f52445637..dd4c48ba4 100644 --- a/examples/nlp/gpt/train_gpt_sft.py +++ b/examples/nlp/gpt/train_gpt_sft.py @@ -102,6 +102,9 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False): if cfg.model.get("seq_len_interpolation_factor", None) is not None: gpt_cfg.seq_len_interpolation_factor = cfg.model.seq_len_interpolation_factor + if cfg.model.get("dist_ckpt_load_strictness", None) is not None: + gpt_cfg.dist_ckpt_load_strictness = cfg.model.dist_ckpt_load_strictness + gpt_cfg.inference = cfg.model.get("inference", {}) # This is needed when modifying a hparam file directly to load `.ckpt` files. diff --git a/examples/nlp/gpt/train_steerlm2.py b/examples/nlp/gpt/train_steerlm2.py index 305588c46..62a01b3f4 100644 --- a/examples/nlp/gpt/train_steerlm2.py +++ b/examples/nlp/gpt/train_steerlm2.py @@ -140,6 +140,9 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False): if cfg.model.get("seq_len_interpolation_factor", None) is not None: gpt_cfg.seq_len_interpolation_factor = cfg.model.seq_len_interpolation_factor + if cfg.model.get("dist_ckpt_load_strictness", None) is not None: + gpt_cfg.dist_ckpt_load_strictness = cfg.model.dist_ckpt_load_strictness + gpt_cfg.inference = cfg.model.get("inference", {}) # This is needed when modifying a hparam file directly to load `.ckpt` files. diff --git a/nemo_aligner/utils/utils.py b/nemo_aligner/utils/utils.py index d67d2e952..7701b5e78 100644 --- a/nemo_aligner/utils/utils.py +++ b/nemo_aligner/utils/utils.py @@ -123,7 +123,7 @@ def load_checkpoint_model_config(restore_path): with tempfile.TemporaryDirectory() as tmpdir: # Extracts only model config - members = NLPSaveRestoreConnector._filtered_tar_info(restore_path, filter_fn=lambda name: '.yaml' in name) + members = NLPSaveRestoreConnector._filtered_tar_info(restore_path, filter_fn=lambda name: ".yaml" in name) NLPSaveRestoreConnector._unpack_nemo_file(restore_path, tmpdir, members=members) cfg = OmegaConf.load(os.path.join(tmpdir, config_name_in_ckpt))