Commit
Merge branch 'r2.0.0' into jlasek/vllm_tokenizer_bugfix
oyilmaz-nvidia authored Oct 2, 2024
2 parents 7698b0f + dea8c32 commit 425c0a9
Showing 13 changed files with 386 additions and 59 deletions.
@@ -33,12 +33,12 @@
from importlib.metadata import version
from typing import Tuple

import packaging
import torch
import torch.nn.functional as F
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.parallel_state import get_tensor_model_parallel_group
from megatron.core.transformer import TransformerConfig
from pkg_resources import packaging
from torch import Tensor
from torch.nn.modules.loss import _Loss

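The hunk above (and the matching import hunks in the files below) swaps `from pkg_resources import packaging` for a direct `import packaging`, since `pkg_resources` is deprecated. A minimal sketch of the version-gate pattern such imports typically feed — the package name and threshold are illustrative, not taken from this commit:

    # Sketch only; the explicit submodule import avoids relying on
    # packaging.version being pre-loaded by another library.
    from importlib.metadata import version

    from packaging.version import Version

    te_version = Version(version("transformer_engine"))   # hypothetical package name
    HAVE_NEW_TE = te_version >= Version("1.4")             # illustrative threshold only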
2 changes: 1 addition & 1 deletion nemo/collections/llm/peft/lora.py
@@ -50,7 +50,7 @@ def forward(self, x):
linear_output, bias, layernorm_output = linear_output
x = layernorm_output

adapter_output = self.adapter(x)
adapter_output = self.adapter(x.contiguous())
return linear_output + adapter_output, bias


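In the LoRA forward above, the adapter now receives `x.contiguous()`; the input may be a non-contiguous view (for example after a transpose or a layernorm-output split), which some kernels reject. A standalone illustration of what `.contiguous()` changes (not NeMo code):

    import torch

    x = torch.randn(4, 8).transpose(0, 1)   # transposing yields a non-contiguous view
    print(x.is_contiguous())                # False
    y = x.contiguous()                      # copy into dense, row-major memory
    print(y.is_contiguous())                # True
    print(torch.equal(x, y))                # True: same values, different layout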
@@ -18,11 +18,11 @@
from typing import Any, Optional

import numpy as np
import packaging
import torch
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from omegaconf import DictConfig, ListConfig, OmegaConf
from pkg_resources import packaging
from pytorch_lightning.trainer.trainer import Trainer
from transformers import CLIPVisionModel, SiglipVisionModel

@@ -15,8 +15,8 @@
from importlib.metadata import version
from typing import Any, Callable, Optional

import packaging
import torch
from pkg_resources import packaging

from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults
from nemo.collections.nlp.parts import utils_funcs
@@ -20,8 +20,7 @@
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules
from megatron.core.transformer.moe.shared_experts import SharedExpertMLP
from megatron.core.transformer.moe.moe_layer import MoELayer
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules

@@ -49,7 +48,6 @@ def get_gpt_layer_modelopt_spec(num_experts: int = None) -> ModuleSpec:
if not HAVE_MEGATRON_CORE:
raise IMPORT_ERROR

mlp = _get_mlp_module_spec(num_experts=num_experts)
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
@@ -67,7 +65,7 @@ def get_gpt_layer_modelopt_spec(num_experts: int = None) -> ModuleSpec:
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=TENorm,
mlp=mlp,
mlp=_get_mlp_module_spec(num_experts=num_experts),
mlp_bda=get_bias_dropout_add,
# Map TE-layernorm-fusion keys back
sharded_state_dict_keys_map={
@@ -79,7 +77,7 @@ def get_gpt_layer_modelopt_spec(num_experts: int = None) -> ModuleSpec:


# Helper function to get module spec for MLP/MoE
def _get_mlp_module_spec(num_experts: int = None) -> ModuleSpec:
def _get_mlp_module_spec(num_experts: int = None, moe_grouped_gemm: bool = False) -> ModuleSpec:
if num_experts is None:
# Dense MLP w/ or w/o TE modules.
return ModuleSpec(
@@ -93,18 +91,12 @@ def _get_mlp_module_spec(num_experts: int = None) -> ModuleSpec:
# Mixture of experts with modules in megatron core.
return ModuleSpec(
module=MoELayer,
submodules=MoESubmodules(
experts=MLPSubmodules(
submodules=(
MLPSubmodules(
linear_fc1=ColumnParallelLinear,
linear_fc2=RowParallelLinear,
),
shared_experts=ModuleSpec(
module=SharedExpertMLP,
params={"gate": False},
submodules=MLPSubmodules(
linear_fc1=ColumnParallelLinear,
linear_fc2=RowParallelLinear,
),
),
)
if not moe_grouped_gemm
else None
),
)
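After this hunk the MoE branch of `_get_mlp_module_spec` no longer wires shared experts: it passes plain per-expert `MLPSubmodules`, or `None` when `moe_grouped_gemm` is set. A small sketch of just that branch in isolation (requires megatron-core; the `ColumnParallelLinear`/`RowParallelLinear` import path is an assumption, the other imports appear in the diff above):

    from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear  # assumed path
    from megatron.core.transformer.mlp import MLPSubmodules
    from megatron.core.transformer.moe.moe_layer import MoELayer
    from megatron.core.transformer.spec_utils import ModuleSpec

    def moe_mlp_spec(moe_grouped_gemm: bool = False) -> ModuleSpec:
        # Mirrors the MoE branch above: per-expert linears, or None for the grouped-GEMM path.
        return ModuleSpec(
            module=MoELayer,
            submodules=(
                MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear)
                if not moe_grouped_gemm
                else None
            ),
        )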
48 changes: 36 additions & 12 deletions nemo/collections/nlp/modules/common/megatron/attention.py
@@ -67,10 +67,9 @@

try:
# Flash Attention Triton
import pkg_resources
from flash_attn.flash_attn_triton import flash_attn_func as flash_attn_func_triton

except (ImportError, ModuleNotFoundError, pkg_resources.DistributionNotFound):
except (ImportError, ModuleNotFoundError):

flash_attn_func_triton = None

@@ -202,7 +201,12 @@ def __init__(
else:
assert attention_type == AttnType.cross_attn
self.query = tensor_parallel.ColumnParallelLinear(
hidden_size, projection_size, config=config, gather_output=False, init_method=init_method, bias=bias,
hidden_size,
projection_size,
config=config,
gather_output=False,
init_method=init_method,
bias=bias,
)

self.key_value = tensor_parallel.ColumnParallelLinear(
@@ -336,7 +340,7 @@ def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
"""[s, b, num_splits * np * hn]
-->(view) [s, b, num_splits, np, hn]
-->(tranpose) [s, b, np, num_splits, hn]
-->(view) [s, b, np * num_splits * hn] """
-->(view) [s, b, np * num_splits * hn]"""

intermediate_shape = input_shape[:-1] + (
num_splits,
@@ -350,7 +354,7 @@ def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
"""[s, b, np * hn * num_splits]
-->(view) [s, b, np, hn, num_splits]
-->(tranpose) [s, b, np, num_splits, hn]
-->(view) [s, b, np * num_splits * hn] """
-->(view) [s, b, np * num_splits * hn]"""

intermediate_shape = input_shape[:-1] + (
self.num_attention_heads_per_partition,
@@ -535,7 +539,10 @@ def forward(
)
v = _cast_if_autocast_enabled(rearrange(value_layer, 'sk b np hn -> b sk np hn'))
context_layer = flash_attn_with_kvcache(
q=q, k_cache=k, v_cache=v, causal=self.attn_mask_type == AttnMaskType.causal,
q=q,
k_cache=k,
v_cache=v,
causal=self.attn_mask_type == AttnMaskType.causal,
)
context_layer = rearrange(context_layer, 'b sq np hn -> sq b (np hn)')

@@ -742,9 +749,9 @@ def forward(


class CoreAttention(MegatronModule):
""" Region where selective activation recomputation is applied.
See Figure 3. in Reducing Activation Recomputation in Large Transformer Models
https://arxiv.org/pdf/2205.05198.pdf for more details.
"""Region where selective activation recomputation is applied.
See Figure 3. in Reducing Activation Recomputation in Large Transformer Models
https://arxiv.org/pdf/2205.05198.pdf for more details.
"""

@@ -994,10 +1001,21 @@ def flash_attention(self, query_layer, key_layer, value_layer, attention_mask, a

if attention_bias is not None:
return self.flash_attention_triton(
query_layer, key_layer, value_layer, attention_mask, attention_bias, is_causal,
query_layer,
key_layer,
value_layer,
attention_mask,
attention_bias,
is_causal,
)
else:
return self.flash_attention_cuda(query_layer, key_layer, value_layer, attention_mask, is_causal,)
return self.flash_attention_cuda(
query_layer,
key_layer,
value_layer,
attention_mask,
is_causal,
)

def flash_attention_cuda(self, query_layer, key_layer, value_layer, attention_mask, is_causal):
batch_size, seqlen, nheads, _ = query_layer.shape
@@ -1071,7 +1089,13 @@ def flash_attention_triton(self, query_layer, key_layer, value_layer, attention_
if attention_bias.shape[3] == attention_mask_kv.shape[3]:
attention_bias = attention_bias.masked_fill(~attention_mask_kv, torch.finfo(query_layer.dtype).min)

context_layer = flash_attn_func_triton(query_layer, key_layer, value_layer, attention_bias, is_causal,)
context_layer = flash_attn_func_triton(
query_layer,
key_layer,
value_layer,
attention_bias,
is_causal,
)

# [b, sq, np, hn] -> [b, np, sq, hn]
context_layer = context_layer.permute(0, 2, 1, 3)
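Most of this file's diff reflows long call sites onto one argument per line; the substantive change is dropping the `pkg_resources` probe from the Triton flash-attention import guard. The surrounding code moves tensors between layouts with `einops.rearrange`; a toy check of the two patterns used above, with made-up sizes (s/sk/sq = sequence, b = batch, np = heads, hn = head dim):

    import torch
    from einops import rearrange

    seq, batch, heads, head_dim = 16, 2, 8, 64             # made-up sizes
    value = torch.randn(seq, batch, heads, head_dim)        # [s, b, np, hn] layout used above
    v = rearrange(value, 'sk b np hn -> b sk np hn')        # batch-first, matching the rearranges above
    print(v.shape)                                          # torch.Size([2, 16, 8, 64])
    ctx = rearrange(v, 'b sq np hn -> sq b (np hn)')        # back to [sq, b, hidden]
    print(ctx.shape)                                        # torch.Size([16, 2, 512])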
29 changes: 13 additions & 16 deletions nemo/lightning/pytorch/callbacks/peft.py
@@ -100,13 +100,16 @@ def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str)
super().setup(trainer, pl_module, stage=stage)

trainer.strategy.trainer = trainer
self.wrapped_io = WrappedAdapterIO(trainer.strategy.checkpoint_io)
self.wrapped_io = WrappedAdapterIO(trainer.strategy.checkpoint_io, self)
trainer.strategy._checkpoint_io = self.wrapped_io
trainer.strategy._init_model_parallel = False
trainer.strategy._setup_optimizers = False

def apply_transform(self, trainer):
super().apply_transform(trainer)
self.trainable_params = set(
name for name, param in trainer.lightning_module.named_parameters() if param.requires_grad
)

adapter_sharded_state_dict = {}
if self.wrapped_io.adapter_ckpt_path is not None:
@@ -137,22 +140,8 @@ def apply_transform(self, trainer):
if trainer.state.fn == TrainerFn.FITTING:
trainer.strategy.load_optimizer_state_dict(adapter_state, selective_restore=True)

def on_save_checkpoint(
self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: Dict[str, Any]
) -> None:
# Filter out non-trainable parameters
trainable_params = set(name for name, param in pl_module.named_parameters() if param.requires_grad)
filtered_state_dict = {}
for name, value in trainer.strategy.megatron_parallel.sharded_state_dict().items():
if name in trainable_params:
filtered_state_dict[name] = value
elif self.adapter_key_filter(name): # Include all adapter-related parameters
filtered_state_dict[name] = value

checkpoint['sharded_state_dict'] = filtered_state_dict

def adapter_key_filter(self, key: str) -> bool:
return ".adapter." in key or key.endswith(".adapters")
return key in self.trainable_params or ".adapter." in key or key.endswith(".adapters")


class AdapterWrapper(nn.Module):
@@ -269,13 +258,21 @@ def load_state_dict(self, state_dict, strict=True):


class WrappedAdapterIO(_WrappingCheckpointIO):
peft: Optional[PEFT] = None
model_ckpt_path: Optional[Path] = None
adapter_ckpt_path: Optional[Path] = None

def __init__(self, checkpoint_io: Optional["CheckpointIO"] = None, peft: Optional[PEFT] = None) -> None:
self.peft = peft
super().__init__(checkpoint_io)

@override
def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
assert self.checkpoint_io is not None

checkpoint['sharded_state_dict'] = dict(
filter(lambda item: self.peft.adapter_key_filter(item[0]), checkpoint['sharded_state_dict'].items())
)
self.checkpoint_io.save_checkpoint(checkpoint, path, storage_options=storage_options)

from nemo.utils.get_rank import is_global_rank_zero
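With this change the PEFT callback records the trainable parameter names during `apply_transform` and passes itself into `WrappedAdapterIO`, which now filters the sharded state dict at save time through `adapter_key_filter` instead of an `on_save_checkpoint` hook. A toy illustration of that filter (the keys are invented):

    # Invented keys; only names that are trainable or adapter-related survive the filter.
    trainable_params = {"decoder.layers.0.mlp.adapter.linear_in.weight"}

    def adapter_key_filter(key: str) -> bool:
        return key in trainable_params or ".adapter." in key or key.endswith(".adapters")

    sharded_state_dict = {
        "decoder.layers.0.mlp.linear_fc1.weight": "frozen base weight",
        "decoder.layers.0.mlp.adapter.linear_in.weight": "LoRA A matrix",
        "decoder.layers.0.self_attention.adapters": "adapter container",
        "embedding.word_embeddings.weight": "frozen base weight",
    }
    filtered = dict(filter(lambda item: adapter_key_filter(item[0]), sharded_state_dict.items()))
    print(sorted(filtered))  # only the two adapter entries remain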
8 changes: 2 additions & 6 deletions nemo/lightning/resume.py
@@ -119,10 +119,9 @@ def _try_import_model(

if model is None:
raise ValueError("Model is needed to import checkpoint from HF or other non-NeMo checkpoint format.")
try:
if '://' in path:
new_path = model.import_ckpt(path)
except (ValueError, AttributeError):
# This is reached when the model connector does not exist for the particular path.
else:
new_path = path

if adapter_path:
@@ -143,9 +142,6 @@ def _resume_peft(self, adapter_meta_path, model):
metadata = json.load(f)

assert self.restore_config, "PEFT resume requires specifying restore_config"
assert (
"://" in self.restore_config.path
), "For now PEFT resume requires specifying the import path instead of local path"
base_model_path = self._try_import_model(model, self.restore_config.path)
if base_model_path != Path(metadata['model_ckpt_path']):
raise ValueError(
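Checkpoint import in resume.py is now decided by whether the restore path carries a URI scheme, rather than by catching exceptions from `model.import_ckpt`, and PEFT resume no longer insists on an import-style path. A minimal sketch of that scheme check (the paths are examples only):

    # Example paths only; anything with "://" goes through the importer, plain paths are used as-is.
    def resolve_ckpt_path(path: str) -> str:
        if '://' in path:
            return f"import_ckpt({path})"   # stands in for model.import_ckpt(path)
        return path

    print(resolve_ckpt_path("hf://meta-llama/Meta-Llama-3-8B"))
    print(resolve_ckpt_path("/results/checkpoints/step=1000-last"))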
68 changes: 66 additions & 2 deletions nemo/lightning/run/plugins.py
@@ -13,7 +13,6 @@
# limitations under the License.

import copy
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
@@ -35,7 +34,7 @@

def _merge_callbacks(partial: run.Partial, callbacks: list[run.Config[Callback]]):
if hasattr(partial, "trainer"):
if hasattr(partial.trainer, "callbacks"):
if hasattr(partial.trainer, "callbacks") and partial.trainer.callbacks:
for callback in callbacks:
if callback not in partial.trainer.callbacks:
partial.trainer.callbacks.append(callback)
@@ -177,3 +176,68 @@ def setup(self, task: run.Partial | run.Script, executor: run.Executor):
logging.warning(
f"The {self.__class__.__name__} will have no effect as WANDB_API_KEY environment variable is not set."
)


@dataclass(kw_only=True)
class ConfigValidationPlugin(run.Plugin):
"""
A plugin for validating a NeMo task with its executor.
This plugin is used to ensure that the NeMo environment, task, and executor meet certain criteria.
The validation checks include preemption, checkpoint directory,
serialization, and Weights and Biases (wandb) integration.
Attributes:
validate_preemption (bool): Whether to validate the preemption callback. If set to True, the plugin will
assert that the task has a `PreemptionCallback`. Defaults to True.
validate_checkpoint_dir (bool): Whether to validate the checkpoint directory. If set to True and the executor
is a `SlurmExecutor`, the plugin will assert that the task's log directory exists in the mounts
specified in the `SlurmExecutor`. Defaults to True.
validate_serialization (bool): Whether to validate task serialization. If set to True, the plugin will
assert that the task can be successfully serialized and deserialized using NeMo-Run's
`ZlibJSONSerializer`. Defaults to True.
validate_wandb (bool): Whether to validate Weights and Biases integration. If set to True, the plugin will
assert that the executor's environment variables contain a `WANDB_API_KEY`
and that NeMo Logger's `wandb` is set. Defaults to False.
validate_nodes_and_devices (bool): Whether to validate the number of devices and nodes. If set to True, the plugin will assert that the task's
trainer is configured to use the same number of nodes and devices as the executor. Defaults to True.
"""

validate_preemption: bool = True
validate_checkpoint_dir: bool = True
validate_serialization: bool = True
validate_wandb: bool = False
validate_nodes_and_devices: bool = True

def setup(self, task: run.Partial | run.Script, executor: run.Executor):
assert isinstance(task, run.Partial)
logging.info(f"Validating {task.__fn_or_cls__.__qualname__} and {executor.__class__.__qualname__}.")
if self.validate_preemption:
logging.info("Validating preemption callback")
assert any(map(lambda callback: callback.__fn_or_cls__ == PreemptionCallback, task.trainer.callbacks))

if self.validate_checkpoint_dir:
if isinstance(executor, run.SlurmExecutor):
mounts = executor.container_mounts + ["/nemo_run"]
mounts = list(map(lambda m: m.split(":")[-1], mounts))
logging.info(f"Validating checkpoint dir {task.log.log_dir} exists in {mounts}")
assert task.log.log_dir
assert any(map(lambda mount: Path(mount) in Path(task.log.log_dir).parents, mounts))

if self.validate_serialization:
from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer

logging.info("Validating serialization/de-serialization of task")
serializer = ZlibJSONSerializer()
assert serializer.deserialize(serializer.serialize(task)) == task

if self.validate_wandb:
logging.info("Validating that Weights and Biases is enabled for task")
assert "WANDB_API_KEY" in executor.env_vars.keys()
assert task.log.wandb

if self.validate_nodes_and_devices:
logging.info("Validating that nodes and devices match for task and executor")
if isinstance(executor, run.SlurmExecutor):
assert task.trainer.num_nodes == executor.nodes
assert task.trainer.devices == executor.nproc_per_node()
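The checkpoint-directory validation above reduces to a pathlib containment test against the container-side half of each Slurm mount (plus the implicit /nemo_run mount). The same check in isolation, with made-up mounts and log dir:

    from pathlib import Path

    # Made-up values; in the plugin they come from the SlurmExecutor and the task's logger.
    container_mounts = ["/lustre/fs0/checkpoints:/checkpoints"]
    log_dir = "/checkpoints/llama3_8b/peft"

    mounts = container_mounts + ["/nemo_run"]
    mounts = [m.split(":")[-1] for m in mounts]                    # keep the container-side path
    assert log_dir
    assert any(Path(m) in Path(log_dir).parents for m in mounts)   # passes: /checkpoints is a parent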
2 changes: 1 addition & 1 deletion requirements/requirements.txt
@@ -6,7 +6,7 @@ onnx>=1.7.0
python-dateutil
ruamel.yaml
scikit-learn
setuptools>=65.5.1
setuptools>=70.0.0
tensorboard
text-unidecode
torch
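requirements.txt raises the setuptools floor from 65.5.1 to 70.0.0. A quick way to check an environment against the new floor (a sketch, not part of the repo; assumes the packaging library is installed):

    from importlib.metadata import version

    from packaging.version import Version

    assert Version(version("setuptools")) >= Version("70.0.0"), "setuptools is older than requirements.txt allows"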