
Commit

Set TP overlap flag in ModelParallelConfig, fix TP overlap for LoRA (NVIDIA#8839)

* Set TP overlap flag in ModelParallelConfig, fix TP overlap for LoRA

Signed-off-by: Jaemin Choi <jaeminc@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Jaemin Choi <jaeminc@nvidia.com>
Co-authored-by: Jaemin Choi <jaeminc@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Harper <complex451@gmail.com>
4 people authored Apr 6, 2024
1 parent 91349ab commit 35e400f
Showing 3 changed files with 17 additions and 5 deletions.
@@ -1155,6 +1155,7 @@ def build_model_parallel_config(self) -> ModelParallelConfig:
"no_sync_func": None, # set dynamically during training
"grad_sync_func": None, # set dynamically during training
"param_sync_func": None, # set dynamically during training
"tp_comm_overlap": self.cfg.get('ub_tp_comm_overlap', False),
}

         # instantiate ModelParallelConfig from this dict
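
For reference, a minimal sketch of how a mapping like the one above can be turned into a Megatron-core ModelParallelConfig. This is not part of the diff; it assumes megatron-core and omegaconf are installed, and the OmegaConf stand-in plus the dataclass-field filtering are illustrative. Only the 'ub_tp_comm_overlap' to 'tp_comm_overlap' mapping comes from this commit.

import dataclasses

from megatron.core import ModelParallelConfig
from omegaconf import OmegaConf

# stand-in for the model cfg; 'ub_tp_comm_overlap' is the user-facing YAML knob
cfg = OmegaConf.create({"ub_tp_comm_overlap": True})

config_mapping = {
    "no_sync_func": None,  # set dynamically during training
    "grad_sync_func": None,  # set dynamically during training
    "param_sync_func": None,  # set dynamically during training
    "tp_comm_overlap": cfg.get("ub_tp_comm_overlap", False),
}

# keep only keys that ModelParallelConfig actually declares, then instantiate it
valid_fields = {f.name for f in dataclasses.fields(ModelParallelConfig)}
model_parallel_config = ModelParallelConfig(
    **{k: v for k, v in config_mapping.items() if k in valid_fields}
)
print(model_parallel_config.tp_comm_overlap)  # True with the cfg above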

@@ -68,9 +68,14 @@ def mcore_register_adapters(self):
             [LoraKQVAdapterConfig._target_, LoraDenseAttentionAdapterConfig._target_, InfusedAdapterConfig._target_]
         )
         self.linear_qkv.return_layernorm_output = True  # need layernorm output for lora mlp
-        if self.config.sequence_parallel and hasattr(self.linear_qkv, "return_layernorm_output_gathered"):
+        if (
+            self.config.sequence_parallel
+            and hasattr(self.linear_qkv, "return_layernorm_output_gathered")
+            and not self.config.tp_comm_overlap
+        ):
             # for LoRA SP, TE v1.5 can return layernorm output gathered so there is no need
-            # to perform the redundant gather in the adapter module.
+            # to perform the redundant gather in the adapter module, unless TP communication
+            # overlap is used.
             self.linear_qkv.return_layernorm_output_gathered = True

     def get_query_key_value_tensors(self, hidden_states, key_value_states=None):

@@ -249,9 +254,14 @@ def mcore_register_adapters(self):
             [LoraHto4HAdapterConfig._target_, Lora4HtoHAdapterConfig._target_, MLPInfusedAdapterConfig._target_]
         )  # only self attn (packed qkv) for now
         self.linear_fc1.return_layernorm_output = True  # need layernorm output for lora mlp
-        if self.config.sequence_parallel and hasattr(self.linear_fc1, "return_layernorm_output_gathered"):
+        if (
+            self.config.sequence_parallel
+            and hasattr(self.linear_fc1, "return_layernorm_output_gathered")
+            and not self.config.tp_comm_overlap
+        ):
             # for LoRA SP, TE v1.5 can return layernorm output gathered so there is no need
-            # to perform the redundant gather in the adapter module.
+            # to perform the redundant gather in the adapter module, unless TP communication
+            # overlap is used.
             self.linear_fc1.return_layernorm_output_gathered = True

     def forward(self, hidden_states):
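
Both adapter hunks above apply the same three-part guard. Below is a minimal standalone sketch of that predicate; the _Config dataclass, _FakeTELayer, and the helper name are illustrative stand-ins, not NeMo or TE classes.

from dataclasses import dataclass


@dataclass
class _Config:
    sequence_parallel: bool
    tp_comm_overlap: bool


def can_return_gathered_layernorm_output(config: _Config, te_layer: object) -> bool:
    # True only when sequence parallelism is on, the TE layer is new enough
    # (TE >= 1.5) to expose the attribute, and TP communication overlap is off.
    return (
        config.sequence_parallel
        and hasattr(te_layer, "return_layernorm_output_gathered")
        and not config.tp_comm_overlap
    )


class _FakeTELayer:
    return_layernorm_output_gathered = False  # attribute exists only in TE >= 1.5


assert can_return_gathered_layernorm_output(_Config(True, False), _FakeTELayer())
assert not can_return_gathered_layernorm_output(_Config(True, True), _FakeTELayer())

As the updated comments note, when TP communication overlap is in use the layernorm output is not returned gathered, so the adapter keeps performing its own gather; that is why the new `and not self.config.tp_comm_overlap` clause appears in both hunks.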

@@ -243,9 +243,10 @@ def __init__(
             from pkg_resources import packaging

             te_version = packaging.version.Version(version("transformer-engine"))
-            if te_version >= packaging.version.Version("1.5.0dev"):
+            if te_version >= packaging.version.Version("1.5.0dev") and not model_parallel_config.tp_comm_overlap:
                 # TE 1.5 introduces the option `return_layernorm_output_gathered`, so the all gather
                 # in the forward method is not needed, so set self._sequence_parallel to False
+                # unless TP communication overlap is used
                 self._sequence_parallel = False

     def _get_init_fn(self, init_method: str):
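
A small sketch of the version-and-overlap gate in this last hunk, assuming transformer-engine is installed. The pkg_resources import mirrors the diff; the importlib.metadata import and the helper name are assumptions for illustration only.

from importlib.metadata import version

from pkg_resources import packaging


def te_returns_gathered_layernorm_output(tp_comm_overlap: bool) -> bool:
    # True when TE itself can supply the gathered layernorm output (TE >= 1.5)
    # and TP communication overlap is not in use; only then can the adapter
    # skip its own sequence-parallel gather (self._sequence_parallel = False).
    te_version = packaging.version.Version(version("transformer-engine"))
    return te_version >= packaging.version.Version("1.5.0dev") and not tp_comm_overlap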
