From d771ca545298c954b6d17b39352cc48996c0d935 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 15 Aug 2024 17:22:05 -0700 Subject: [PATCH] Debug CUDA graph support with operation-based API Signed-off-by: Tim Moon --- tests/pytorch/test_cuda_graphs.py | 27 ++++- transformer_engine/pytorch/graph.py | 47 ++++++-- .../pytorch/ops/basic/basic_linear.py | 4 +- transformer_engine/pytorch/ops/basic/bias.py | 4 +- transformer_engine/pytorch/ops/op.py | 105 ++++++++++++++++-- 5 files changed, 160 insertions(+), 27 deletions(-) diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 60a5a1ea99..97f18b037b 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -21,6 +21,7 @@ ) from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.utils import is_bf16_compatible +import transformer_engine.pytorch.ops as te_ops # Only run FP8 tests on H100. @@ -48,7 +49,15 @@ class ModelConfig: model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)} -modules = ["transformer", "layernorm_mlp", "layernorm_linear", "linear", "mha", "dpa"] +modules = [ + "transformer", + "layernorm_mlp", + "layernorm_linear", + "linear", + "mha", + "dpa", + "linear_op", +] all_boolean = [True, False] @@ -171,7 +180,10 @@ def _test_cuda_graphs( """Helper function for CUDA graph test.""" reset_rng_states() FP8GlobalStateManager.reset() + dpa = module == "dpa" + if module == "linear_op": + fp8_weight_caching = False with fp8_model_init(enabled=fp8_params): # Create modules. @@ -209,18 +221,27 @@ def _test_cuda_graphs( ) for _ in range(num_layers) ] - elif dpa: + elif module == "dpa": assert config.hidden_size % config.num_heads == 0, "Err." assert num_layers == 1, "Err." modules = [ DotProductAttention(config.num_heads, config.kv_channels, attention_dropout=0.0) for _ in range(num_layers) ] - else: + elif module == "linear": modules = [ Linear(config.hidden_size, config.hidden_size, device="cuda", params_dtype=dtype) for _ in range(num_layers) ] + elif module == "linear_op": + modules = [ + te_ops.Sequential( + te_ops.Linear(config.hidden_size, config.hidden_size, dtype=dtype), + ) + for _ in range(num_layers) + ] + else: + raise ValueError(f"Unknown module type ({module})") # Initialize gradient buffers. for module in modules: diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index e2642bc360..b8b383ad6e 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Functions for CUDA Graphs support in FP8""" +from collections.abc import Iterable from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union import torch @@ -19,7 +20,6 @@ from .distributed import get_all_rng_states, graph_safe_rng_available from .module.base import TransformerEngineBaseModule - __all__ = ["make_graphed_callables"] @@ -483,27 +483,56 @@ def new_fwd(*user_args, **user_kwargs): return tuple(ret) -def save_fp8_tensors(modules, amax_history_len): +def save_fp8_tensors( + modules: Iterable[torch.nn.Module], + fp8_recipe: DelayedScaling, +) -> Any: """ Returns the FP8 tensors for all modules with adjusted amax history sizes. 
""" - saved_fp8_meta_tensors = [] + from .ops import Sequential, FusibleOperation # Avoid circular import + fp8_tensors = [] for module in modules: for m in module.modules(): + module_tensors = None if isinstance(m, TransformerEngineBaseModule): if m.primary_weights_in_fp8: - m.adjust_amax_history_length(amax_history_len) - saved_fp8_meta_tensors.append(m.get_fp8_meta_tensors()) - return saved_fp8_meta_tensors + m.adjust_amax_history_length(fp8_recipe.amax_history_len) + module_tensors = m.get_fp8_meta_tensors() + elif isinstance(m, FusibleOperation): + if m.is_fused_op: + module_tensors = save_fp8_tensors(m.basic_ops, fp8_recipe) + else: + m.pre_forward( + fp8_enabled=True, + fp8_recipe=fp8_recipe, + ) + module_tensors = m._save_fp8_metas() + elif isinstance(m, Sequential): + module_tensors = save_fp8_tensors(m, fp8_recipe) + fp8_tensors.append(module_tensors) + return fp8_tensors -def restore_fp8_tensors(modules, fp8_tensors): +def restore_fp8_tensors( + modules: Iterable[torch.nn.Module], + fp8_tensors: Any, +) -> None: """Restore FP8 tensors.""" + from .ops import Sequential, FusibleOperation # Avoid circular import for module in modules: for m in module.modules(): + module_tensors = fp8_tensors.pop(0) if isinstance(m, TransformerEngineBaseModule): - m.reset_fp8_meta_tensors(fp8_tensors.pop(0)) + m.reset_fp8_meta_tensors(module_tensors) + elif isinstance(m, FusibleOperation): + if m.is_fused_op: + restore_fp8_tensors(m.basic_ops, module_tensors) + else: + m._load_fp8_metas(module_tensors) + elif isinstance(m, Sequential): + restore_fp8_tensors(m, module_tensors) assert len(fp8_tensors) == 0, "TE internal error." @@ -573,7 +602,7 @@ def make_graphed_callables( modules = (modules,) # Store FP8 tensors to reset later. - saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe.amax_history_len) + saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe=fp8_recipe) # FP8 wrapper. 
def wrap_autocast(block): diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 826807d1c0..3c9f3b3bc8 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -300,8 +300,8 @@ def reset_parameters(self) -> None: weight = torch.nn.Parameter(weight) self.weight = weight - def pre_forward(self) -> None: - super().pre_forward() + def pre_forward(self, *args, **kwargs) -> None: + super().pre_forward(*args, **kwargs) if self.weight.device.type == "meta": self.reset_parameters() diff --git a/transformer_engine/pytorch/ops/basic/bias.py b/transformer_engine/pytorch/ops/basic/bias.py index b8e8cc5e56..7688aa2ea1 100644 --- a/transformer_engine/pytorch/ops/basic/bias.py +++ b/transformer_engine/pytorch/ops/basic/bias.py @@ -113,8 +113,8 @@ def reset_parameters(self) -> None: bias = torch.nn.Parameter(bias) self.bias = bias - def pre_forward(self) -> None: - super().pre_forward() + def pre_forward(self, *args, **kwargs) -> None: + super().pre_forward(*args, **kwargs) if self.bias.device.type == "meta": self.reset_parameters() diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 47c6567056..d1f5f2c719 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -14,6 +14,7 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.fp8 import ( + DelayedScaling, FP8GlobalStateManager, get_default_fp8_recipe, ) @@ -232,25 +233,39 @@ def _make_meta( ) @classmethod - def _maybe_update_fp8_meta(cls, fp8_meta: Optional[dict[str, Any]]) -> None: + def _maybe_update_fp8_meta( + cls, + fp8_meta: Optional[dict[str, Any]], + *, + fp8_recipe: Optional[DelayedScaling] = None, + ) -> None: if fp8_meta is None: return - # Update FP8 recipe and communication group - recipe = FP8GlobalStateManager.get_fp8_recipe() - fp8_meta["recipe"] = recipe + # Update FP8 recipe + if fp8_recipe is None: + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_recipe is None: + fp8_recipe = get_default_fp8_recipe() + fp8_meta["recipe"] = fp8_recipe + + # Update FP8 communication group fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() # Adjust amax history length if needed - amax_history_len = recipe.amax_history_len + amax_history_len = fp8_recipe.amax_history_len for is_forward in (True, False): - key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward) - if key not in fp8_meta: + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward) + if fp8_meta_key not in fp8_meta: continue - meta = fp8_meta[key] + meta = fp8_meta[fp8_meta_key] curr_len = meta.amax_history.size(0) + + # Nothing to be done if amax history is already correct if curr_len == amax_history_len: continue + + # Reallocate amax history with torch.no_grad(): if curr_len > amax_history_len: meta.amax_history = meta.amax_history[:amax_history_len].clone() @@ -260,6 +275,21 @@ def _maybe_update_fp8_meta(cls, fp8_meta: Optional[dict[str, Any]]) -> None: pad=(0, 0, 0, amax_history_len - curr_len), ) + # Update global buffers for amax reductions + buffer_info_key = FP8GlobalStateManager.get_buffer_info() + if buffer_info_key in fp8_meta: + fwd_pos, fwd_key, bwd_pos, bwd_key = fp8_meta[buffer_info_key] + for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)): + assert ( + buffer_key in FP8GlobalStateManager.global_amax_history_buffer + ), "TE internal error during amax history 
change." + FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = fp8_meta[ + fp8_meta_key + ].amax_history[0] + FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = ( + fp8_meta[fp8_meta_key].amax_history + ) + def get_fp8_meta(self, mode: str) -> Optional[dict[str, Any]]: """FP8 metadata @@ -273,11 +303,64 @@ def get_fp8_meta(self, mode: str) -> Optional[dict[str, Any]]: self._fp8_metas = self._make_fp8_metas() return self._fp8_metas[mode] - def pre_forward(self) -> None: + @torch.no_grad() + def _save_fp8_metas(self) -> Optional[dict[str, Any]]: + """Create copies of tensors in FP8 metadata + + Tensor copies can be loaded with _load_fp8_metas. + + """ + if self._fp8_metas is None: + return None + out = {} + for mode, fp8_meta in self._fp8_metas.items(): + if fp8_meta is None: + continue + out[mode] = {} + for is_forward in (True, False): + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward) + if fp8_meta_key not in fp8_meta: + continue + out[mode][fp8_meta_key] = ( + fp8_meta[fp8_meta_key].scale.clone(), + fp8_meta[fp8_meta_key].scale_inv.clone(), + fp8_meta[fp8_meta_key].amax_history.clone(), + ) + return out + + @torch.no_grad() + def _load_fp8_metas(self, fp8_metas: Optional[dict[str, Any]]) -> None: + """Update FP8 metadata with saved tensor copies + + Tensor copies should be generated with _save_fp8_metas. + + """ + assert (self._fp8_metas is None) == (fp8_metas is None), \ + "Saved FP8 metadata does not match operation's FP8 metadata" + if fp8_metas is None: + return + for mode, fp8_meta in fp8_metas.items(): + assert mode in self._fp8_metas, \ + f"Found an unexpected key ({mode=}) in saved FP8 metadata" + for fp8_meta_key, tensors in fp8_meta.items(): + assert fp8_meta_key in self._fp8_metas[mode], \ + f"Found an unexpected key ({mode=}, {fp8_meta_key=}) in saved FP8 metadata" + scale, scale_inv, amax_history = tensors + self._fp8_metas[mode][fp8_meta_key].scale.copy_(scale) + self._fp8_metas[mode][fp8_meta_key].scale_inv.copy_(scale_inv) + self._fp8_metas[mode][fp8_meta_key].amax_history.copy_(amax_history) + + def pre_forward( + self, + *, + fp8_enabled: Optional[bool] = None, + fp8_recipe: Optional[DelayedScaling] = None, + ) -> None: """Preprocessing before forward pass""" # Initialize FP8 metadata if needed - fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() + if fp8_enabled is None: + fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() if fp8_enabled: # Construct FP8 metadata if needed @@ -286,7 +369,7 @@ def pre_forward(self) -> None: # Make sure FP8 metadata matches FP8 autocast context for fp8_meta in self._fp8_metas.values(): - self._maybe_update_fp8_meta(fp8_meta) + self._maybe_update_fp8_meta(fp8_meta, fp8_recipe=fp8_recipe) # Register FP8 metadata for amax and scale update if not FP8GlobalStateManager.fp8_graph_capturing():
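
For context, the snippet below is a rough usage sketch (not part of this patch) of how the operation-based API exercised by the new "linear_op" test case above might be captured with CUDA graphs. The tensor shapes, dtype, and num_warmup_iters value are illustrative assumptions; the call pattern follows _test_cuda_graphs in tests/pytorch/test_cuda_graphs.py, and the FP8 path assumes an FP8-capable GPU.

    import torch
    import transformer_engine.pytorch as te
    import transformer_engine.pytorch.ops as te_ops

    # Model built from fusible operations (the operation-based API).
    model = te_ops.Sequential(
        te_ops.Linear(64, 64, dtype=torch.float32),
    )

    # Static sample input used while warming up and capturing the graph
    # (illustrative shape: sequence x batch x hidden).
    x = torch.randn(32, 2, 64, device="cuda", requires_grad=True)

    # Capture forward and backward into a CUDA graph. With fp8_enabled=True,
    # the save_fp8_tensors/restore_fp8_tensors and _save_fp8_metas/_load_fp8_metas
    # changes above snapshot the operations' FP8 scaling state before warmup
    # and restore it afterwards.
    graphed_model = te.make_graphed_callables(
        model,
        (x,),
        num_warmup_iters=3,
        fp8_enabled=True,
    )

    # Replay the captured graph like a normal module.
    with te.fp8_autocast(enabled=True):
        y = graphed_model(x)
    y.sum().backward()

This mirrors the "linear_op" branch added to the test's module construction; a non-FP8 capture works the same way with fp8_enabled=False.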