Debug CUDA graph support with operation-based API
Signed-off-by: Tim Moon <tmoon@nvidia.com>
timmoon10 committed Aug 16, 2024
1 parent 941364d commit d771ca5
Showing 5 changed files with 160 additions and 27 deletions.
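
This commit wires the operation-based API (transformer_engine.pytorch.ops) into the existing CUDA graph machinery. As a rough usage sketch, not part of this commit, assuming te_ops.Linear places its parameters on the current CUDA device by default and that make_graphed_callables takes a callable plus sample arguments in the style of torch.cuda.make_graphed_callables:

import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops

# Small operation-based module, mirroring the new "linear_op" test case.
model = te_ops.Sequential(
    te_ops.Linear(64, 64, dtype=torch.bfloat16),
)

# Capture forward and backward into CUDA graphs, then use the graphed callable
# as a drop-in replacement for the original module.
x = torch.randn(32, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True)
graphed_model = te.make_graphed_callables(model, (x,))
out = graphed_model(x)
out.sum().backward()
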
27 changes: 24 additions & 3 deletions tests/pytorch/test_cuda_graphs.py
@@ -21,6 +21,7 @@
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine.pytorch.ops as te_ops


# Only run FP8 tests on H100.
@@ -48,7 +49,15 @@ class ModelConfig:

model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)}

modules = ["transformer", "layernorm_mlp", "layernorm_linear", "linear", "mha", "dpa"]
modules = [
"transformer",
"layernorm_mlp",
"layernorm_linear",
"linear",
"mha",
"dpa",
"linear_op",
]

all_boolean = [True, False]

@@ -171,7 +180,10 @@ def _test_cuda_graphs(
"""Helper function for CUDA graph test."""
reset_rng_states()
FP8GlobalStateManager.reset()

dpa = module == "dpa"
if module == "linear_op":
fp8_weight_caching = False

with fp8_model_init(enabled=fp8_params):
# Create modules.
@@ -209,18 +221,27 @@
)
for _ in range(num_layers)
]
elif dpa:
elif module == "dpa":
assert config.hidden_size % config.num_heads == 0, "Err."
assert num_layers == 1, "Err."
modules = [
DotProductAttention(config.num_heads, config.kv_channels, attention_dropout=0.0)
for _ in range(num_layers)
]
else:
elif module == "linear":
modules = [
Linear(config.hidden_size, config.hidden_size, device="cuda", params_dtype=dtype)
for _ in range(num_layers)
]
elif module == "linear_op":
modules = [
te_ops.Sequential(
te_ops.Linear(config.hidden_size, config.hidden_size, dtype=dtype),
)
for _ in range(num_layers)
]
else:
raise ValueError(f"Unknown module type ({module})")

# Initialize gradient buffers.
for module in modules:
47 changes: 38 additions & 9 deletions transformer_engine/pytorch/graph.py
@@ -3,6 +3,7 @@
# See LICENSE for license information.

"""Functions for CUDA Graphs support in FP8"""
from collections.abc import Iterable
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union

import torch
@@ -19,7 +20,6 @@
from .distributed import get_all_rng_states, graph_safe_rng_available
from .module.base import TransformerEngineBaseModule


__all__ = ["make_graphed_callables"]


@@ -483,27 +483,56 @@ def new_fwd(*user_args, **user_kwargs):
return tuple(ret)


def save_fp8_tensors(modules, amax_history_len):
def save_fp8_tensors(
modules: Iterable[torch.nn.Module],
fp8_recipe: DelayedScaling,
) -> Any:
"""
Returns the FP8 tensors for all modules
with adjusted amax history sizes.
"""
saved_fp8_meta_tensors = []
from .ops import Sequential, FusibleOperation # Avoid circular import
fp8_tensors = []
for module in modules:
for m in module.modules():
module_tensors = None
if isinstance(m, TransformerEngineBaseModule):
if m.primary_weights_in_fp8:
m.adjust_amax_history_length(amax_history_len)
saved_fp8_meta_tensors.append(m.get_fp8_meta_tensors())
return saved_fp8_meta_tensors
m.adjust_amax_history_length(fp8_recipe.amax_history_len)
module_tensors = m.get_fp8_meta_tensors()
elif isinstance(m, FusibleOperation):
if m.is_fused_op:
module_tensors = save_fp8_tensors(m.basic_ops, fp8_recipe)
else:
m.pre_forward(
fp8_enabled=True,
fp8_recipe=fp8_recipe,
)
module_tensors = m._save_fp8_metas()
elif isinstance(m, Sequential):
module_tensors = save_fp8_tensors(m, fp8_recipe)
fp8_tensors.append(module_tensors)
return fp8_tensors


def restore_fp8_tensors(modules, fp8_tensors):
def restore_fp8_tensors(
modules: Iterable[torch.nn.Module],
fp8_tensors: Any,
) -> None:
"""Restore FP8 tensors."""
from .ops import Sequential, FusibleOperation # Avoid circular import
for module in modules:
for m in module.modules():
module_tensors = fp8_tensors.pop(0)
if isinstance(m, TransformerEngineBaseModule):
m.reset_fp8_meta_tensors(fp8_tensors.pop(0))
m.reset_fp8_meta_tensors(module_tensors)
elif isinstance(m, FusibleOperation):
if m.is_fused_op:
restore_fp8_tensors(m.basic_ops, module_tensors)
else:
m._load_fp8_metas(module_tensors)
elif isinstance(m, Sequential):
restore_fp8_tensors(m, module_tensors)
assert len(fp8_tensors) == 0, "TE internal error."

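# Sketch (not part of this commit): how save_fp8_tensors and restore_fp8_tensors
# are meant to bracket graph capture so that capture does not perturb the
# modules' FP8 state. Only the save call appears in the hunk below; the capture
# step and the restore call site are assumed here.
import torch
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.common.recipe import DelayedScaling

modules = [te_ops.Sequential(te_ops.Linear(64, 64, dtype=torch.bfloat16))]
recipe = DelayedScaling(amax_history_len=16)

saved = save_fp8_tensors(modules, fp8_recipe=recipe)
# ... warmup iterations and CUDA graph capture would run here ...
restore_fp8_tensors(modules, saved)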

@@ -573,7 +602,7 @@ def make_graphed_callables(
modules = (modules,)

# Store FP8 tensors to reset later.
saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe.amax_history_len)
saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe=fp8_recipe)

# FP8 wrapper.
def wrap_autocast(block):
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/ops/basic/basic_linear.py
@@ -300,8 +300,8 @@ def reset_parameters(self) -> None:
weight = torch.nn.Parameter(weight)
self.weight = weight

def pre_forward(self) -> None:
super().pre_forward()
def pre_forward(self, *args, **kwargs) -> None:
super().pre_forward(*args, **kwargs)
if self.weight.device.type == "meta":
self.reset_parameters()

4 changes: 2 additions & 2 deletions transformer_engine/pytorch/ops/basic/bias.py
@@ -113,8 +113,8 @@ def reset_parameters(self) -> None:
bias = torch.nn.Parameter(bias)
self.bias = bias

def pre_forward(self) -> None:
super().pre_forward()
def pre_forward(self, *args, **kwargs) -> None:
super().pre_forward(*args, **kwargs)
if self.bias.device.type == "meta":
self.reset_parameters()

105 changes: 94 additions & 11 deletions transformer_engine/pytorch/ops/op.py
@@ -14,6 +14,7 @@

import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import (
DelayedScaling,
FP8GlobalStateManager,
get_default_fp8_recipe,
)
@@ -232,25 +233,39 @@ def _make_meta(
)

@classmethod
def _maybe_update_fp8_meta(cls, fp8_meta: Optional[dict[str, Any]]) -> None:
def _maybe_update_fp8_meta(
cls,
fp8_meta: Optional[dict[str, Any]],
*,
fp8_recipe: Optional[DelayedScaling] = None,
) -> None:
if fp8_meta is None:
return

# Update FP8 recipe and communication group
recipe = FP8GlobalStateManager.get_fp8_recipe()
fp8_meta["recipe"] = recipe
# Update FP8 recipe
if fp8_recipe is None:
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
if fp8_recipe is None:
fp8_recipe = get_default_fp8_recipe()
fp8_meta["recipe"] = fp8_recipe

# Update FP8 communication group
fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()

# Adjust amax history length if needed
amax_history_len = recipe.amax_history_len
amax_history_len = fp8_recipe.amax_history_len
for is_forward in (True, False):
key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward)
if key not in fp8_meta:
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward)
if fp8_meta_key not in fp8_meta:
continue
meta = fp8_meta[key]
meta = fp8_meta[fp8_meta_key]
curr_len = meta.amax_history.size(0)

# Nothing to be done if amax history is already correct
if curr_len == amax_history_len:
continue

# Reallocate amax history
with torch.no_grad():
if curr_len > amax_history_len:
meta.amax_history = meta.amax_history[:amax_history_len].clone()
@@ -260,6 +275,21 @@ def _maybe_update_fp8_meta(cls, fp8_meta: Optional[dict[str, Any]]) -> None:
pad=(0, 0, 0, amax_history_len - curr_len),
)

# Update global buffers for amax reductions
buffer_info_key = FP8GlobalStateManager.get_buffer_info()
if buffer_info_key in fp8_meta:
fwd_pos, fwd_key, bwd_pos, bwd_key = fp8_meta[buffer_info_key]
for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)):
assert (
buffer_key in FP8GlobalStateManager.global_amax_history_buffer
), "TE internal error during amax history change."
FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = fp8_meta[
fp8_meta_key
].amax_history[0]
FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = (
fp8_meta[fp8_meta_key].amax_history
)

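# Standalone illustration (not part of this commit) of the reallocation above:
# the amax history is truncated or zero-padded along dim 0 to match the
# recipe's amax_history_len. Shapes here are arbitrary.
import torch

curr = torch.rand(4, 3)  # history length 4, with 3 amax entries per step
target_len = 8
if curr.size(0) > target_len:
    resized = curr[:target_len].clone()
else:
    # pad=(0, 0, 0, n) leaves the last dim alone and appends n zero rows
    resized = torch.nn.functional.pad(curr, pad=(0, 0, 0, target_len - curr.size(0)))
assert resized.shape == (target_len, 3)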
def get_fp8_meta(self, mode: str) -> Optional[dict[str, Any]]:
"""FP8 metadata
@@ -273,11 +303,64 @@ def get_fp8_meta(self, mode: str) -> Optional[dict[str, Any]]:
self._fp8_metas = self._make_fp8_metas()
return self._fp8_metas[mode]

def pre_forward(self) -> None:
@torch.no_grad()
def _save_fp8_metas(self) -> Optional[dict[str, Any]]:
"""Create copies of tensors in FP8 metadata
Tensor copies can be loaded with _load_fp8_metas.
"""
if self._fp8_metas is None:
return None
out = {}
for mode, fp8_meta in self._fp8_metas.items():
if fp8_meta is None:
continue
out[mode] = {}
for is_forward in (True, False):
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward)
if fp8_meta_key not in fp8_meta:
continue
out[mode][fp8_meta_key] = (
fp8_meta[fp8_meta_key].scale.clone(),
fp8_meta[fp8_meta_key].scale_inv.clone(),
fp8_meta[fp8_meta_key].amax_history.clone(),
)
return out

@torch.no_grad()
def _load_fp8_metas(self, fp8_metas: Optional[dict[str, Any]]) -> None:
"""Update FP8 metadata with saved tensor copies
Tensor copies should be generated with _save_fp8_metas.
"""
assert (self._fp8_metas is None) == (fp8_metas is None), \
"Saved FP8 metadata does not match operation's FP8 metadata"
if fp8_metas is None:
return
for mode, fp8_meta in fp8_metas.items():
assert mode in self._fp8_metas, \
f"Found an unexpected key ({mode=}) in saved FP8 metadata"
for fp8_meta_key, tensors in fp8_meta.items():
assert fp8_meta_key in self._fp8_metas[mode], \
f"Found an unexpected key ({mode=}, {fp8_meta_key=}) in saved FP8 metadata"
scale, scale_inv, amax_history = tensors
self._fp8_metas[mode][fp8_meta_key].scale.copy_(scale)
self._fp8_metas[mode][fp8_meta_key].scale_inv.copy_(scale_inv)
self._fp8_metas[mode][fp8_meta_key].amax_history.copy_(amax_history)

def pre_forward(
self,
*,
fp8_enabled: Optional[bool] = None,
fp8_recipe: Optional[DelayedScaling] = None,
) -> None:
"""Preprocessing before forward pass"""

# Initialize FP8 metadata if needed
fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
if fp8_enabled is None:
fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
if fp8_enabled:

# Construct FP8 metadata if needed
@@ -286,7 +369,7 @@ def pre_forward(self) -> None:

# Make sure FP8 metadata matches FP8 autocast context
for fp8_meta in self._fp8_metas.values():
self._maybe_update_fp8_meta(fp8_meta)
self._maybe_update_fp8_meta(fp8_meta, fp8_recipe=fp8_recipe)

# Register FP8 metadata for amax and scale update
if not FP8GlobalStateManager.fp8_graph_capturing():
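# Round-trip sketch (not part of this commit) for the snapshot hooks added in
# this file. `op` is a hypothetical placeholder for a basic (non-fused)
# operation instance; fused operations delegate to their basic_ops, as
# save_fp8_tensors does.
from transformer_engine.common.recipe import DelayedScaling

recipe = DelayedScaling(amax_history_len=16)
op.pre_forward(fp8_enabled=True, fp8_recipe=recipe)  # construct FP8 metadata matching the recipe
snapshot = op._save_fp8_metas()  # clones of scale, scale_inv, amax_history per mode
# ... graph warmup/capture may mutate the FP8 metadata here ...
op._load_fp8_metas(snapshot)  # copy the saved values back in place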
