
Commit
Merge branch 'newptl_fix_validation_in_spellmapper' of github.com:bene-ges/NeMo into newptl_fix_validation_in_spellmapper
bene-ges committed Dec 4, 2023
2 parents 6507eba + 51ce14a commit f617f1f
Showing 6 changed files with 52 additions and 19 deletions.
7 changes: 2 additions & 5 deletions Jenkinsfile
@@ -9,9 +9,6 @@ pipeline {
timeout(time: 8, unit: 'HOURS')
disableConcurrentBuilds(abortPrevious: true)
}
-environment {
-    NVTE_APPLY_QK_LAYER_SCALING = 1
-}

stages {

@@ -75,8 +72,8 @@ pipeline {
steps {
sh 'git clone https://github.com/NVIDIA/Megatron-LM.git && \
cd Megatron-LM && \
-git checkout e122536b7645edcb7ebf099b5c92a443f7dbf8e7 && \
-pip install -e .'
+git checkout 973330e9c3681604703bf1eb6b5a265d1b9b9b38 && \
+pip install .'
}
}

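Note: the hunk above removes the CI-wide NVTE_APPLY_QK_LAYER_SCALING=1 export and pins Megatron-LM to a different commit, installed with a plain pip install . instead of an editable install. The environment variable is now managed at runtime in build_transformer_config (see the hunk further down). A minimal sketch of how a consumer might read such an on/off flag, assuming Transformer Engine treats the variable as a boolean switch; the helper name is hypothetical:

import os

def qk_layer_scaling_enabled() -> bool:
    # Assumption: "1" enables QK layer scaling; unset or any other value disables it.
    return os.environ.get("NVTE_APPLY_QK_LAYER_SCALING", "0") == "1"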
3 changes: 3 additions & 0 deletions examples/nlp/language_modeling/tuning/megatron_gpt_sft.py
@@ -90,6 +90,9 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
if cfg.model.get('seq_len_interpolation_factor', None) is not None:
gpt_cfg.seq_len_interpolation_factor = cfg.model.seq_len_interpolation_factor

+if cfg.model.get('rotary_base', None) is not None:
+    gpt_cfg.rotary_base = cfg.model.rotary_base
+
sft_cls = MegatronGPTSFTModel
gpt_cfg.target = f"{sft_cls.__module__}.{sft_cls.__name__}"

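The new _modify_config block copies rotary_base into the base model's config only when the fine-tuning config sets it. A minimal OmegaConf sketch of this get-then-copy pattern, with illustrative values:

from omegaconf import OmegaConf

cfg = OmegaConf.create({"model": {"rotary_base": 1000000}})  # tuning config (illustrative)
gpt_cfg = OmegaConf.create({"rotary_base": 10000})           # base-model default

if cfg.model.get("rotary_base", None) is not None:
    gpt_cfg.rotary_base = cfg.model.rotary_base

print(gpt_cfg.rotary_base)  # 1000000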
15 changes: 15 additions & 0 deletions nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -318,6 +318,7 @@ def model_provider_func(self, pre_process, post_process):
position_embedding_type=self.cfg.get('position_embedding_type', 'learned_absolute'),
rotary_percent=self.cfg.get('rotary_percentage', 1.0),
seq_len_interpolation_factor=self.cfg.get('seq_len_interpolation_factor', None),
+rotary_base=self.cfg.get('rotary_base', 10000),
)
else:
assert self.cfg.get('num_query_groups', None) is None or self.cfg.get(
@@ -1544,6 +1545,19 @@ def build_transformer_config(self) -> TransformerConfig:

attention_softmax_in_fp32 = False  # not currently used in NeMo unless apply_query_key_layer_scaling is True
apply_query_key_layer_scaling = self.cfg.get('apply_query_key_layer_scaling', False)
+
+fp16_enabled = self.trainer.precision in [16, '16', '16-mixed']
+if apply_query_key_layer_scaling:
+    if fp16_enabled:
+        os.environ["NVTE_APPLY_QK_LAYER_SCALING"] = "1"
+    else:
+        logging.warning(
+            "apply_query_key_layer_scaling is only enabled when using FP16, setting it to False "
+            "and setting NVTE_APPLY_QK_LAYER_SCALING=0"
+        )
+        os.environ["NVTE_APPLY_QK_LAYER_SCALING"] = "0"
+        apply_query_key_layer_scaling = False
+
if apply_query_key_layer_scaling:
    attention_softmax_in_fp32 = True
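Restated outside NeMo: QK-layer scaling is only supported under FP16, so the added code exports NVTE_APPLY_QK_LAYER_SCALING=1 in that case and otherwise forces both the env var and the flag off. A standalone sketch of the same decision; the function name is hypothetical, not NeMo's API:

import os

def resolve_qk_layer_scaling(precision, requested: bool) -> bool:
    # Mirrors the gating above: enable only when FP16 is in use.
    fp16_enabled = precision in [16, "16", "16-mixed"]
    if requested and not fp16_enabled:
        os.environ["NVTE_APPLY_QK_LAYER_SCALING"] = "0"
        return False
    if requested:
        os.environ["NVTE_APPLY_QK_LAYER_SCALING"] = "1"
    return requested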

@@ -1570,6 +1584,7 @@ def build_transformer_config(self) -> TransformerConfig:

# any configs that are not in the nemo model config will be added here
config_mapping = {
+'apply_query_key_layer_scaling': apply_query_key_layer_scaling,
'apply_residual_connection_post_layernorm': False, # we don't use this in NeMo
'layernorm_zero_centered_gamma': layernorm_zero_centered_gamma,
'add_bias_linear': add_bias_linear,
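rotary_base is the theta (default 10000, as in the original RoPE formulation) that sets the base of the geometric progression of rotary frequencies; larger bases give longer wavelengths, which long-context variants rely on. A sketch assuming the standard RoPE inverse-frequency formula:

import torch

def rope_inv_freq(dim: int, rotary_base: float = 10000.0) -> torch.Tensor:
    # inv_freq_i = rotary_base^(-2i/dim) for i = 0 .. dim/2 - 1
    return 1.0 / (rotary_base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

print(rope_inv_freq(8))                         # base 10000
print(rope_inv_freq(8, rotary_base=1_000_000))  # slower rotation for longer contexts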
8 changes: 2 additions & 6 deletions nemo/collections/nlp/modules/common/text_generation_utils.py
@@ -507,12 +507,8 @@ def synced_generate(

if compute_logprob:
precision = model._trainer.precision
-if precision in [16, "16"]:
-    dtype = torch.float16
-elif precision in ['bf16', 'bf16-mixed']:
-    dtype = torch.bfloat16
-else:
-    dtype = torch.float32
+dtype = torch.float32

output_logits = torch.empty(
tokens.size(0), context_length - 1, dtype=dtype, device=torch.device("cuda")
)
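The deleted branch matched the logprob buffer dtype to the trainer precision; it is now always float32, trading a little memory for numerically stable log-probabilities under fp16/bf16 inference. A small sketch of the upcast-before-normalize pattern this pins down:

import torch

logits = torch.randn(2, 5, 100, dtype=torch.bfloat16)  # toy half-precision logits
logprobs = torch.log_softmax(logits.float(), dim=-1)   # upcast, then normalize
assert logprobs.dtype == torch.float32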
36 changes: 28 additions & 8 deletions scripts/checkpoint_averaging/distributed_checkpoint_averaging.py
@@ -39,7 +39,8 @@
import logging
import os
import shutil
-
+import numpy as np
+import tensorstore  # need to import it for bf16 support
import zarr

logging.basicConfig(level=logging.INFO)
@@ -84,6 +85,7 @@ def main():
n = len(checkpoint_paths)
# initialize dict, will be used to store the weights that need to be averaged
avg_weights = {}
+chunk_info = {}

logging.info(f"Averaging {n} checkpoints ... {'at steps:' + str(args.steps) if args.steps is not None else ''}")

@@ -114,21 +116,22 @@

if item not in avg_weights:
logging.info(f"Initialized average weights dict with: {item}")
-avg_weights[item] = zarr.open(os.path.join(full_path, item), mode='r')
+array = zarr.open(os.path.join(full_path, item), mode='r')
+avg_weights[item] = array[:]
+chunk_info[item] = array.chunks
else:
logging.info(f"Updated average weights dict with weight: {item}")
array_z = zarr.open(os.path.join(full_path, item), mode='r')
-sum_array = avg_weights[item][:] + array_z[:]
-avg_weights[item] = zarr.array(sum_array, chunks=array_z.chunks, dtype=array_z.dtype)
+sum_array = avg_weights[item] + array_z[:]
+avg_weights[item] = sum_array

for k in avg_weights:
logging.info(f"Average weights dict key : {k}, dtype : {avg_weights[k].dtype}, shape : {avg_weights[k].shape}")
if str(avg_weights[k].dtype).startswith("int"):
    raise ValueError("Int type not supported")
else:
-    array_z = avg_weights[k][:]
-    array_z = array_z / n
-    avg_weights[k] = zarr.array(array_z, chunks=avg_weights[k].chunks, dtype=avg_weights[k].dtype)
+    array_z = avg_weights[k] / n
+    avg_weights[k] = array_z
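The refactor keeps the running sums as in-memory NumPy arrays (materialized once from zarr via array[:]) and divides by the checkpoint count once at the end, while chunk_info remembers the original chunk layout for the save step. The arithmetic in miniature:

import numpy as np

shards = [np.full((4,), v) for v in (1.0, 2.0, 3.0)]  # stand-ins for per-checkpoint weights

total = shards[0]
for arr in shards[1:]:
    total = total + arr    # running sum, as in the loop above
avg = total / len(shards)  # single division at the end
print(avg)                 # [2. 2. 2. 2.]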

# Save model
if args.steps is None:
@@ -140,7 +143,24 @@ def main():
# save avg_weights
for k in avg_weights:
logging.info(f"Saving {k} to {ckpt_name}")
-zarr.save(os.path.join(ckpt_name, k), avg_weights[k])
+input_arr = avg_weights[k]
+chunks = chunk_info[k]
+# create the zarr array
+output_array = zarr.create(
+    input_arr.shape,
+    dtype=input_arr.dtype,
+    store=os.path.join(ckpt_name, k),
+    chunks=chunks,
+    compressor=None,
+    fill_value=None,
+    write_empty_chunks=True,
+)
+if input_arr.dtype == np.dtype('bfloat16'):
+    arr = output_array
+    arr._dtype = input_arr.dtype
+    zarray = arr.store['.zarray']
+    arr.store['.zarray'] = zarray.replace(b'<V2', b'bfloat16')
+output_array[:] = input_arr

# copy other files
for item in copy_items:
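zarr has no built-in bfloat16 codec; importing tensorstore registers the dtype (per the comment in the imports hunk), and the new save path patches the stored .zarray metadata so the raw <V2 placeholder reads back as bfloat16. A hedged round-trip check, with a placeholder path:

import tensorstore  # noqa: F401 -- registers bfloat16 support for zarr
import zarr

ckpt_path = "averaged_checkpoint/model.weights"  # hypothetical output location
arr = zarr.open(ckpt_path, mode="r")
print(arr.dtype, arr.shape)  # expect bfloat16 after the metadata patch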
2 changes: 2 additions & 0 deletions scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py
@@ -116,6 +116,8 @@ def load_config(args, llama_config):
nemo_config['seq_len_interpolation_factor'] = llama_config['rope_scaling']['factor']
else:
    raise ValueError("Only linear rope scaling type is supported now")
+if llama_config['rope_theta'] is not None:
+    nemo_config['rotary_base'] = llama_config['rope_theta']

base = 128
while llama_config['vocab_size'] % base != 0:
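rope_theta is read straight from the Hugging Face Llama config.json and mapped onto NeMo's rotary_base. A sketch with an illustrative config slice (10000.0 is the common default; long-context variants such as CodeLlama use larger values):

import json

llama_config = json.loads('{"rope_theta": 10000.0, "vocab_size": 32000}')  # illustrative

nemo_config = {}
if llama_config["rope_theta"] is not None:
    nemo_config["rotary_base"] = llama_config["rope_theta"]
print(nemo_config)  # {'rotary_base': 10000.0}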
