
Commit 8cf1133
Store module extra state in tensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>
timmoon10 committed Nov 14, 2024
1 parent 28aa41a commit 8cf1133
Showing 2 changed files with 68 additions and 25 deletions.
transformer_engine/pytorch/module/base.py (91 changes: 67 additions & 24 deletions)
@@ -588,20 +588,50 @@ def reset(key):

 def get_extra_state(self) -> torch.Tensor:
     """Save before checkpointing."""
-    state = None

+    # This implementation is working around a few issues:
+    #
+    # (1) PyTorch's "extra state" infrastructure might be able to
+    #     support any picklable type, but they make no guarantees.
+    #     We have experienced problems (e.g. in ONNX export) with
+    #     non-tensor extra state.
+    # (2) PyTorch's checkpointing infrastructure does not remap
+    #     devices for "extra state" like it does for "state dict".
+    #     Thus, we want to avoid putting extra state on the GPU
+    #     since it may be loaded on the wrong device.
+    # (3) The extra state consists of many small tensors. If we
+    #     want to copy them all to CPU, then we need to avoid the
+    #     overhead of many GPU-CPU memory transfers.
+    #
+    # See: https://github.com/NVIDIA/TransformerEngine/pull/351
+    # See: https://github.com/NVIDIA/TransformerEngine/pull/363
+
+    def to_cpu(src: torch.Tensor) -> torch.Tensor:
+        """Helper function to make a CPU copy of a tensor
+
+        The memory transfer is asynchronous w.r.t. the host, so the
+        GPU should be synchronized before using the result.
+        """
+        dst = torch.empty_like(src, device="cpu")
+        dst.copy_(src, non_blocking=True)
+        return dst
+
+    # Store FP8 state if needed
+    state = None
     fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration
-
     if fp8_checkpoint:
+
+        # Copy tensors to CPU and store
         state = {}
-        state["scale_fwd"] = self.fp8_meta["scaling_fwd"].scale
-        state["scale_inv_fwd"] = self.fp8_meta["scaling_fwd"].scale_inv
-        state["amax_history_fwd"] = self.fp8_meta["scaling_fwd"].amax_history
-        state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale
-        state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv
-        state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history
-
-        # Store other picklable values.
+        state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale)
+        state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history)
+        state["scale_inv_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale_inv)
+        state["scale_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale)
+        state["amax_history_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].amax_history)
+        state["scale_inv_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale_inv)
+
+        # Store other picklable values
         extra = {}
         for k, v in self.fp8_meta.items():
             if k != "buffer_index_and_autocast_key" and isinstance(
@@ -610,22 +640,23 @@ def get_extra_state(self) -> torch.Tensor:
                 extra[k] = v
         state["extra_fp8_variables"] = extra

-    if is_in_onnx_export_mode():
-        state_serialized = torch.frombuffer(pickle.dumps(state), dtype=torch.uint8)
-    else:
-        state_serialized = io.BytesIO()
-        torch.save(state, state_serialized)
-
+    # Serialize state into byte tensor
+    torch.cuda.synchronize()
+    state_serialized = bytearray(pickle.dumps(state))
+    state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)
     return state_serialized

 def set_extra_state(self, state: torch.Tensor) -> None:
     """Load previous state."""
     if state is None:
         return

+    # Load state
     if isinstance(state, torch.Tensor):
+        # Default format: byte tensor with pickled data
         state = pickle.loads(state.detach().cpu().numpy().tobytes())
     elif isinstance(state, io.BytesIO):
+        # Deprecated format with io.BytesIO
         state.seek(0)
         state = torch.load(state, map_location="cuda")
     else:
@@ -634,20 +665,32 @@ def set_extra_state(self, state: torch.Tensor) -> None:
     if state is None:
         return

-    # Load extra items.
+    # Load extra items
     self.fp8_meta.update(state["extra_fp8_variables"])
     self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0]
     if "global_fp8_buffer_pos_fwd_recompute" in self.fp8_meta:
         del self.fp8_meta["global_fp8_buffer_pos_fwd_recompute"]

-    # Initialize before loading.
+    # Initialize before loading
     self.init_fp8_meta_tensors()
-    self.fp8_meta["scaling_fwd"].scale.copy_(state["scale_fwd"])
-    self.fp8_meta["scaling_fwd"].amax_history.copy_(state["amax_history_fwd"])
-    self.fp8_meta["scaling_bwd"].scale.copy_(state["scale_bwd"])
-    self.fp8_meta["scaling_bwd"].amax_history.copy_(state["amax_history_bwd"])
-    self.fp8_meta["scaling_fwd"].scale_inv.copy_(state["scale_inv_fwd"])
-    self.fp8_meta["scaling_bwd"].scale_inv.copy_(state["scale_inv_bwd"])
+
+    def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None:
+        """Helper function to copy a tensor from CPU
+
+        The memory transfer is asynchronous w.r.t. the host, so the
+        GPU should be synchronized before using the result.
+        """
+        dst.copy_(src, non_blocking=True)
+
+    # Load tensors
+    copy_tensor(state["scale_fwd"], self.fp8_meta["scaling_fwd"].scale)
+    copy_tensor(state["amax_history_fwd"], self.fp8_meta["scaling_fwd"].amax_history)
+    copy_tensor(state["scale_inv_fwd"], self.fp8_meta["scaling_fwd"].scale_inv)
+    copy_tensor(state["scale_bwd"], self.fp8_meta["scaling_bwd"].scale)
+    copy_tensor(state["amax_history_bwd"], self.fp8_meta["scaling_bwd"].amax_history)
+    copy_tensor(state["scale_inv_bwd"], self.fp8_meta["scaling_bwd"].scale_inv)
+    torch.cuda.synchronize()

 def set_activation_dtype(self, inp: torch.Tensor) -> None:
     """Get activation data type for AMP."""
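Note: in isolation, the save/load round trip that this commit standardizes on looks like the following sketch (illustrative only, not part of the diff; the state dict contents are hypothetical). A picklable dict is flattened into a plain CPU torch.uint8 tensor, which sidesteps issues (1) and (2) from the comment above: the extra state is an ordinary tensor, and it carries no GPU device to remap.

    import pickle
    import torch

    # Hypothetical stand-in for the FP8 metadata saved by get_extra_state
    state = {"scale_fwd": torch.ones(4), "extra_fp8_variables": {"amax_history_len": 16}}

    # Save path: pickle the dict, then view the bytes as a uint8 tensor.
    # bytearray() provides the writable buffer that torch.frombuffer expects.
    state_tensor = torch.frombuffer(bytearray(pickle.dumps(state)), dtype=torch.uint8)

    # Load path: recover the raw bytes and unpickle, as in set_extra_state
    restored = pickle.loads(state_tensor.detach().cpu().numpy().tobytes())
    assert torch.equal(restored["scale_fwd"], state["scale_fwd"])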
transformer_engine/pytorch/ops/op.py (2 changes: 1 addition & 1 deletion)
@@ -514,7 +514,7 @@ def get_extra_state(self) -> torch.Tensor:
     #
     # (1) PyTorch's "extra state" infrastructure might be able to
     #     support any picklable type, but they make no guarantees.
-    #     It seems that ONNX export experiences issues with
+    #     We have experienced problems (e.g. in ONNX export) with
     #     non-tensor extra state.
     # (2) PyTorch's checkpointing infrastructure does not remap
     #     devices for "extra state" like it does for "state dict".
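Note: issue (3) in the comment above is the reason both helpers in base.py use non_blocking copies capped by a single torch.cuda.synchronize(). A minimal sketch of that pattern, under the assumption of a CUDA device (tensor names and sizes are made up):

    import torch

    def to_cpu(src: torch.Tensor) -> torch.Tensor:
        # Queue an asynchronous device-to-host copy; the result is only
        # valid after the GPU has been synchronized.
        dst = torch.empty_like(src, device="cpu")
        dst.copy_(src, non_blocking=True)
        return dst

    if torch.cuda.is_available():
        # Stand-ins for the many small FP8 scaling-factor tensors
        device_tensors = [torch.rand(16, device="cuda") for _ in range(64)]

        # Queue all copies first, then synchronize once, instead of paying
        # a blocking GPU-CPU round trip per tensor
        host_tensors = [to_cpu(t) for t in device_tensors]
        torch.cuda.synchronize()
        # host_tensors are now safe to read or pickle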
