Commit 700fa9a: Merge branch 'main' into vllm_0.6.0_integration_test

oyilmaz-nvidia authored Oct 1, 2024
2 parents 7303c70 + 32503fd
Showing 21 changed files with 467 additions and 55 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/_test_template.yml
@@ -59,7 +59,7 @@ jobs:
(
set -e
-      docker run --rm --device=/dev/nvidia0 --gpus all --shm-size=64g --env TRANSFORMERS_OFFLINE=0 --env HYDRA_FULL_ERROR=1 --volume /mnt/datadrive/TestData:/home/TestData nemoci.azurecr.io/nemo_container_${{ github.run_id }} bash -c '${{ inputs.SCRIPT }}'
+      docker run --rm --runtime=nvidia --gpus all --shm-size=64g --env TRANSFORMERS_OFFLINE=0 --env HYDRA_FULL_ERROR=1 --volume /mnt/datadrive/TestData:/home/TestData nemoci.azurecr.io/nemo_container_${{ github.run_id }} bash -c '${{ inputs.SCRIPT }}'
) 2> >(tee err.log)
EXIT_CODE=$?
@@ -73,4 +73,4 @@ jobs:
- name: after_script
if: always() && inputs.AFTER_SCRIPT != ':'
run: |
-      docker run --rm --device=/dev/nvidia0 --gpus all --shm-size=64g --env TRANSFORMERS_OFFLINE=0 --env HYDRA_FULL_ERROR=1 --volume /mnt/datadrive/TestData:/home/TestData nemoci.azurecr.io/nemo_container_${{ github.run_id }} bash -c '${{ inputs.AFTER_SCRIPT }}'
+      docker run --rm --runtime=nvidia --gpus all --shm-size=64g --env TRANSFORMERS_OFFLINE=0 --env HYDRA_FULL_ERROR=1 --volume /mnt/datadrive/TestData:/home/TestData nemoci.azurecr.io/nemo_container_${{ github.run_id }} bash -c '${{ inputs.AFTER_SCRIPT }}'
17 changes: 12 additions & 5 deletions .github/workflows/cicd-main.yml
@@ -112,6 +112,7 @@ jobs:
# Basic Import Checks
python -c "import nemo.collections.asr as nemo_asr"
python -c "import nemo.collections.nlp as nemo_nlp"
python -c "import nemo.collections.nlp as nemo_nlp; nemo_nlp.modules.get_tokenizer_list()"
python -c "import nemo.collections.tts as nemo_tts"
python setup.py style
@@ -5245,13 +5246,19 @@ jobs:
NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python tests/collections/llm/megatron_t5_pretraining.py \
--devices=2 \
--max-steps=3 \
-      --experiment-dir=tests/collections/llm/t5_pretrain_results \
+      --experiment-dir=tests/collections/llm/t5_pretrain_results/${{ github.run_id }} \
--data-path=/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document \
-      --index-mapping-dir=tests/collections/llm/t5_index_mappings
+      --index-mapping-dir=tests/collections/llm/t5_index_mappings/${{ github.run_id }}
+      NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python tests/collections/llm/megatron_t5_pretraining.py \
+      --devices=2 \
+      --max-steps=6 \
+      --experiment-dir=tests/collections/llm/t5_pretrain_results/${{ github.run_id }} \
+      --data-path=/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document \
+      --index-mapping-dir=tests/collections/llm/t5_index_mappings/${{ github.run_id }}
AFTER_SCRIPT: |
-      rm -rf tests/collections/llm/t5_pretrain_results
-      rm -rf tests/collections/llm/t5_index_mappings
+      rm -rf tests/collections/llm/t5_pretrain_results/${{ github.run_id }}
+      rm -rf tests/collections/llm/t5_index_mappings/${{ github.run_id }}
Nemo_CICD_Test:
needs:
@@ -5402,7 +5409,7 @@ jobs:
echo "FAILED=$FAILED" >> $GITHUB_OUTPUT
# Mark as successful if no job was cancelled:
-          SUCCESS=${{ !contains(needs.*.result, 'cancelled') }}
+          SUCCESS=${{ !contains(needs.*.result, 'cancelled') && !contains(needs.*.result, 'skipped') }}
echo "SUCCESS=$SUCCESS" >> $GITHUB_OUTPUT
# This should depend on all the tests so we block/unblock based on all tests passing
25 changes: 21 additions & 4 deletions examples/asr/transcribe_speech_parallel.py
@@ -1,4 +1,4 @@
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -82,6 +82,7 @@
from nemo.collections.asr.data.audio_to_text_dataset import ASRPredictionWriter
from nemo.collections.asr.metrics.wer import word_error_rate
from nemo.collections.asr.models import ASRModel, EncDecHybridRNNTCTCModel
+ from nemo.collections.asr.models.aed_multitask_models import EncDecMultiTaskModel
from nemo.collections.asr.models.configs.asr_models_config import ASRDatasetConfig
from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig
from nemo.collections.asr.parts.submodules.rnnt_greedy_decoding import GreedyBatchedRNNTInferConfig
@@ -93,7 +94,9 @@
@dataclass
class ParallelTranscriptionConfig:
model: Optional[str] = None # name
-     predict_ds: ASRDatasetConfig = ASRDatasetConfig(return_sample_id=True, num_workers=4)
+     predict_ds: ASRDatasetConfig = ASRDatasetConfig(
+         return_sample_id=True, num_workers=4, min_duration=0, max_duration=40
+     )
output_path: str = MISSING

# when return_predictions is enabled, the prediction call would keep all the predictions in memory and return them when prediction is done
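Note: the new min_duration/max_duration fields bound which utterances the prediction dataloader keeps. A minimal sketch of the filtering idea (the field names follow the config above; the helper itself is hypothetical):

    def keep_utterance(duration_sec: float, min_duration: float = 0.0, max_duration: float = 40.0) -> bool:
        # Drop utterances outside the configured duration window; overly long
        # recordings would otherwise blow up memory during batched inference.
        return min_duration <= duration_sec <= max_duration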
@@ -157,10 +160,24 @@ def main(cfg: ParallelTranscriptionConfig):
if isinstance(model, EncDecHybridRNNTCTCModel) and cfg.decoder_type is not None:
model.change_decoding_strategy(decoder_type=cfg.decoder_type)

-     trainer = ptl.Trainer(**cfg.trainer)

cfg.predict_ds.return_sample_id = True
cfg.predict_ds = match_train_config(predict_ds=cfg.predict_ds, train_ds=model.cfg.train_ds)

+     if isinstance(model, EncDecMultiTaskModel):
+         cfg.trainer.use_distributed_sampler = False
+         OmegaConf.set_struct(cfg.predict_ds, False)
+         cfg.predict_ds.use_lhotse = True
+         cfg.predict_ds.lang_field = "target_lang"
+         OmegaConf.set_struct(cfg.predict_ds, True)
+
+     trainer = ptl.Trainer(**cfg.trainer)
+
+     if isinstance(model, EncDecMultiTaskModel):
+         OmegaConf.set_struct(cfg.predict_ds, False)
+         cfg.predict_ds.global_rank = trainer.global_rank
+         cfg.predict_ds.world_size = trainer.world_size
+         OmegaConf.set_struct(cfg.predict_ds, True)

data_loader = model._setup_dataloader_from_config(cfg.predict_ds)

os.makedirs(cfg.output_path, exist_ok=True)
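Note: the paired OmegaConf.set_struct calls above are the usual OmegaConf idiom for adding keys that a structured config does not declare. A minimal standalone sketch of the pattern (the config contents here are invented):

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({"use_lhotse": False})
    OmegaConf.set_struct(cfg, True)   # struct mode: assigning unknown keys raises
    OmegaConf.set_struct(cfg, False)  # temporarily relax the config
    cfg.lang_field = "target_lang"    # new key is now accepted
    OmegaConf.set_struct(cfg, True)   # re-lock before handing the config onward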

@@ -33,12 +33,12 @@
from importlib.metadata import version
from typing import Tuple

+ import packaging
import torch
import torch.nn.functional as F
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.parallel_state import get_tensor_model_parallel_group
from megatron.core.transformer import TransformerConfig
- from pkg_resources import packaging
from torch import Tensor
from torch.nn.modules.loss import _Loss
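Note: this hunk swaps the pkg_resources re-export of packaging (pkg_resources itself is deprecated) for the standalone distribution; version comparisons read the same either way. A minimal sketch using importlib.metadata, which this file already imports (the package name here is an invented example):

    import packaging.version
    from importlib.metadata import version

    # Gate a code path on an installed package's version.
    installed = packaging.version.Version(version("transformer-engine"))
    if installed >= packaging.version.Version("1.0"):
        pass  # safe to use the newer API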

30 changes: 22 additions & 8 deletions nemo/collections/asr/data/audio_to_text_dataset.py
@@ -859,16 +859,30 @@ def write_on_batch_end(
batch_idx: int,
dataloader_idx: int,
):
+         import lhotse
+
for sample_id, transcribed_text in prediction:
item = {}
-             sample = self.dataset.get_manifest_sample(sample_id)
-             item["audio_filepath"] = sample.audio_file
-             item["offset"] = sample.offset
-             item["duration"] = sample.duration
-             item["text"] = sample.text_raw
-             item["pred_text"] = transcribed_text
-             self.outf.write(json.dumps(item) + "\n")
-             self.samples_num += 1
+             if isinstance(sample_id, lhotse.cut.Cut):
+                 sample = sample_id
+                 if isinstance(sample, lhotse.cut.MixedCut):
+                     sample = sample.first_non_padding_cut
+                 item["audio_filepath"] = sample.recording.sources[0].source
+                 item["offset"] = sample.start
+                 item["duration"] = sample.duration
+                 item["text"] = sample.supervisions[0].text
+                 item["pred_text"] = transcribed_text
+                 self.outf.write(json.dumps(item) + "\n")
+                 self.samples_num += 1
+             else:
+                 sample = self.dataset.get_manifest_sample(sample_id)
+                 item["audio_filepath"] = sample.audio_file
+                 item["offset"] = sample.offset
+                 item["duration"] = sample.duration
+                 item["text"] = sample.text_raw
+                 item["pred_text"] = transcribed_text
+                 self.outf.write(json.dumps(item) + "\n")
+                 self.samples_num += 1
return

def close_output_file(self):
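Note: both branches above emit one JSON line per utterance; only where the metadata comes from differs (a Lhotse cut vs. the dataset manifest). A hypothetical output record, shown for shape only (all values invented):

    record = {
        "audio_filepath": "/data/utt1.wav",
        "offset": 0.0,
        "duration": 3.2,
        "text": "reference transcript",
        "pred_text": "predicted transcript",
    }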
4 changes: 3 additions & 1 deletion nemo/collections/asr/data/audio_to_text_lhotse_prompted.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
- from typing import Callable, Union
+ from typing import Callable, Optional, Union

import torch.utils.data
from lhotse import CutSet
@@ -32,6 +32,7 @@ class PromptedAudioToTextMiniBatch:
prompt_lens: torch.Tensor
prompted_transcript: torch.Tensor
prompted_transcript_lens: torch.Tensor
+     cuts: Optional[CutSet] = None

def get_decoder_inputs_outputs(self) -> tuple[torch.Tensor, torch.Tensor]:
"""
@@ -98,6 +99,7 @@ def __getitem__(self, cuts: CutSet) -> PromptedAudioToTextMiniBatch:
prompt_lens=prompt_lens,
prompted_transcript=prompts_with_answers,
prompted_transcript_lens=prompts_with_answers_lens,
+             cuts=cuts.drop_in_memory_data(),
)

def _collate_tokens(self, tokens: list[Union[list[int], torch.Tensor]]) -> tuple[torch.Tensor, torch.Tensor]:
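Note: carrying cuts in the batch is what lets predict_step pair each hypothesis with its source cut later, and drop_in_memory_data() (used above) strips any audio arrays held in memory first, since the waveforms already travel in the collated tensors. A rough sketch of the intent (the helper name is invented):

    from lhotse import CutSet

    def attach_cuts(batch: dict, cuts: CutSet) -> dict:
        # Keep only cut metadata (ids, supervisions, source paths);
        # the audio itself is already collated into tensors.
        batch["cuts"] = cuts.drop_in_memory_data()
        return batch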
9 changes: 5 additions & 4 deletions nemo/collections/asr/models/aed_multitask_models.py
@@ -504,10 +504,12 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]):
"Multi-task model only supports dataloading with Lhotse. "
"Please set config.{train,validation,test}_ds.use_lhotse=True"
)
+         global_rank = config.get("global_rank", self.global_rank)
+         world_size = config.get("world_size", self.world_size)
return get_lhotse_dataloader_from_config(
config,
-             global_rank=self.global_rank,
-             world_size=self.world_size,
+             global_rank=global_rank,
+             world_size=world_size,
dataset=PromptedAudioToTextLhotseDataset(
tokenizer=self.tokenizer,
prompt_format_fn=get_prompt_format_fn(self.prompt_format),
@@ -1042,8 +1044,7 @@ def predict_step(
decoder_input_ids=batch.prompt,
return_hypotheses=False,
)[0]

-         return text
+         return list(zip(batch.cuts, text))

@property
def adapter_module_names(self) -> List[str]:
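Note: predict_step now returns (cut, hypothesis) pairs rather than bare strings, which is what the updated ASRPredictionWriter above branches on. A hypothetical consumer, to show the shape of the output:

    import lhotse

    def dump_predictions(prediction, out):
        # prediction: a list of (lhotse Cut, transcribed text) pairs,
        # as predict_step now returns them.
        for cut, text in prediction:
            if isinstance(cut, lhotse.cut.MixedCut):
                cut = cut.first_non_padding_cut  # skip padding tracks
            out.write(f"{cut.id}\t{text}\n")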

@@ -18,11 +18,11 @@
from typing import Any, Optional

import numpy as np
+ import packaging
import torch
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from omegaconf import DictConfig, ListConfig, OmegaConf
- from pkg_resources import packaging
from pytorch_lightning.trainer.trainer import Trainer
from transformers import CLIPVisionModel, SiglipVisionModel


@@ -15,8 +15,8 @@
from importlib.metadata import version
from typing import Any, Callable, Optional

+ import packaging
import torch
- from pkg_resources import packaging

from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults
from nemo.collections.nlp.parts import utils_funcs
14 changes: 12 additions & 2 deletions nemo/collections/nlp/modules/common/lm_utils.py
@@ -102,6 +102,16 @@ def get_lm_model(
pretrain_model_name = ''
if cfg.get('language_model') and cfg.language_model.get('pretrained_model_name', ''):
pretrain_model_name = cfg.language_model.get('pretrained_model_name', '')

+     from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel
+
+     def get_megatron_pretrained_bert_models() -> List[str]:
+
+         all_pretrained_megatron_bert_models = [
+             model.pretrained_model_name for model in MegatronBertModel.list_available_models()
+         ]
+         return all_pretrained_megatron_bert_models
+
all_pretrained_megatron_bert_models = get_megatron_pretrained_bert_models()
if (
cfg.tokenizer is not None
@@ -175,13 +185,13 @@ def get_transformer(
config_dict={
'_target_': 'transformers.BertConfig',
'hidden_size': 1536
-                 })
+                 })
Args:
library (str, optional): Can be 'nemo', 'huggingface', or 'megatron'. Defaults to 'nemo'.
model_name (Optional[str], optional): Named model architecture from the chosen library. Defaults to None.
-         pretrained (bool, optional): Use True to get pretrained weights.
+         pretrained (bool, optional): Use True to get pretrained weights.
False will use the same architecture but with randomly initialized weights.
Defaults to False.
config_dict (Optional[dict], optional): Use for custom configuration of transformer. Defaults to None.
5 changes: 1 addition & 4 deletions nemo/collections/nlp/modules/common/megatron/__init__.py
@@ -12,7 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

- from nemo.collections.nlp.modules.common.megatron.megatron_utils import (
-     get_megatron_checkpoint,
-     get_megatron_lm_models_list,
- )
+ from .megatron_utils import get_megatron_checkpoint, get_megatron_lm_models_list
48 changes: 36 additions & 12 deletions nemo/collections/nlp/modules/common/megatron/attention.py
@@ -67,10 +67,9 @@

try:
# Flash Attention Triton
-     import pkg_resources
from flash_attn.flash_attn_triton import flash_attn_func as flash_attn_func_triton

- except (ImportError, ModuleNotFoundError, pkg_resources.DistributionNotFound):
+ except (ImportError, ModuleNotFoundError):

flash_attn_func_triton = None

@@ -202,7 +201,12 @@ def __init__(
else:
assert attention_type == AttnType.cross_attn
self.query = tensor_parallel.ColumnParallelLinear(
-                 hidden_size, projection_size, config=config, gather_output=False, init_method=init_method, bias=bias,
+                 hidden_size,
+                 projection_size,
+                 config=config,
+                 gather_output=False,
+                 init_method=init_method,
+                 bias=bias,
)

self.key_value = tensor_parallel.ColumnParallelLinear(
@@ -336,7 +340,7 @@ def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
"""[s, b, num_splits * np * hn]
-->(view) [s, b, num_splits, np, hn]
-->(tranpose) [s, b, np, num_splits, hn]
-         -->(view) [s, b, np * num_splits * hn] """
+         -->(view) [s, b, np * num_splits * hn]"""

intermediate_shape = input_shape[:-1] + (
num_splits,
@@ -350,7 +354,7 @@ def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
"""[s, b, np * hn * num_splits]
-->(view) [s, b, np, hn, num_splits]
-->(tranpose) [s, b, np, num_splits, hn]
-         -->(view) [s, b, np * num_splits * hn] """
+         -->(view) [s, b, np * num_splits * hn]"""

intermediate_shape = input_shape[:-1] + (
self.num_attention_heads_per_partition,
@@ -535,7 +539,10 @@ def forward(
)
v = _cast_if_autocast_enabled(rearrange(value_layer, 'sk b np hn -> b sk np hn'))
context_layer = flash_attn_with_kvcache(
-             q=q, k_cache=k, v_cache=v, causal=self.attn_mask_type == AttnMaskType.causal,
+             q=q,
+             k_cache=k,
+             v_cache=v,
+             causal=self.attn_mask_type == AttnMaskType.causal,
)
context_layer = rearrange(context_layer, 'b sq np hn -> sq b (np hn)')

@@ -742,9 +749,9 @@ def forward(


class CoreAttention(MegatronModule):
""" Region where selective activation recomputation is applied.
See Figure 3. in Reducing Activation Recomputation in Large Transformer Models
https://arxiv.org/pdf/2205.05198.pdf for more details.
"""Region where selective activation recomputation is applied.
See Figure 3. in Reducing Activation Recomputation in Large Transformer Models
https://arxiv.org/pdf/2205.05198.pdf for more details.
"""

@@ -994,10 +1001,21 @@ def flash_attention(self, query_layer, key_layer, value_layer, attention_mask, a

if attention_bias is not None:
return self.flash_attention_triton(
-                 query_layer, key_layer, value_layer, attention_mask, attention_bias, is_causal,
+                 query_layer,
+                 key_layer,
+                 value_layer,
+                 attention_mask,
+                 attention_bias,
+                 is_causal,
)
else:
-             return self.flash_attention_cuda(query_layer, key_layer, value_layer, attention_mask, is_causal,)
+             return self.flash_attention_cuda(
+                 query_layer,
+                 key_layer,
+                 value_layer,
+                 attention_mask,
+                 is_causal,
+             )

def flash_attention_cuda(self, query_layer, key_layer, value_layer, attention_mask, is_causal):
batch_size, seqlen, nheads, _ = query_layer.shape
@@ -1071,7 +1089,13 @@ def flash_attention_triton(self, query_layer, key_layer, value_layer, attention_
if attention_bias.shape[3] == attention_mask_kv.shape[3]:
attention_bias = attention_bias.masked_fill(~attention_mask_kv, torch.finfo(query_layer.dtype).min)

-         context_layer = flash_attn_func_triton(query_layer, key_layer, value_layer, attention_bias, is_causal,)
+         context_layer = flash_attn_func_triton(
+             query_layer,
+             key_layer,
+             value_layer,
+             attention_bias,
+             is_causal,
+         )

# [b, sq, np, hn] -> [b, np, sq, hn]
context_layer = context_layer.permute(0, 2, 1, 3)
(Diffs for the remaining changed files were not loaded.)