Fixed offloading for PyT version/ Added Attention activation offloading support/ Native FP8 support (#632)

* Fixed offloading for PyT version/ Added Attention activation offloading support/ Native FP8 support

Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Removed activation offloading for fused attention

Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Fixed the illegal memory access issue for activation offloading of attention

Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Removed the version guard

Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Pipeline failures fix

Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Fixed lint errors

Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Lint error fix

Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

---------

Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>
Co-authored-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>
sanandaraj5597 and Selvaraj Anandaraj authored Jan 30, 2024
1 parent 4077ccc commit 44574de
Showing 5 changed files with 59 additions and 18 deletions.
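
For orientation before the per-file diffs: a hedged sketch of how the offloading machinery touched in this commit is typically driven from a training step. `get_cpu_offload_context` lives in cpu_offload.py; its exact signature and the two-value return used here are assumptions, so verify them against that file.

```python
import torch
import transformer_engine.pytorch as te

# Hedged sketch, not a verified recipe: obtain the offload context plus the
# synchronization function, run the layer under the context, and pass the
# layer output through the sync function so offload of the tagged activations
# is committed per group. Argument names are assumptions.
offload_ctx, sync_fn = te.get_cpu_offload_context(enabled=True, num_layers=1)

layer = te.Linear(1024, 1024)
inp = torch.randn(8, 1024, device="cuda", requires_grad=True)

with offload_ctx:
    out = layer(inp)
out = sync_fn(out)
out.sum().backward()
```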
24 changes: 24 additions & 0 deletions transformer_engine/pytorch/attention.py
@@ -1752,6 +1752,14 @@ def forward(
deterministic=self.deterministic
)
else:

from .cpu_offload import CPUOffloadEnabled
if CPUOffloadEnabled:
tensor_list = [query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv]
for tensor in tensor_list:
if tensor is not None:
tensor.activation_offloading = True

with self.attention_dropout_ctx():
fa_optional_forward_kwargs = {}
if _flash_attn_2_3_plus:
@@ -1938,6 +1946,15 @@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql
attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
rng_gen)

from .cpu_offload import CPUOffloadEnabled
if CPUOffloadEnabled:
tensor_list = [q, k, v, out, cu_seqlens_q, cu_seqlens_kv]
qkv_layout = 'sbhd_sbhd_sbhd'
for tensor in tensor_list:
if tensor is not None:
tensor.activation_offloading = True


ctx.save_for_backward(q, k, v, out, cu_seqlens_q, cu_seqlens_kv)
ctx.aux_ctx_tensors = aux_ctx_tensors
ctx.max_seqlen_q = max_seqlen_q
@@ -2818,6 +2835,13 @@ def forward(
assert (not context_parallel), \
"Context parallelism is only implemented with Flash Attention and Fused Attention!"

from .cpu_offload import CPUOffloadEnabled
if CPUOffloadEnabled:
warnings.warn(
"Attention activation Offloading is only implemented"
"with Flash Attention and Fused Attention!"
)

if _NVTE_DEBUG:
print("[DotProductAttention]: using unfused DPA")
if use_unfused_attention:
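
The hunks above tag q/k/v and the cu_seqlens tensors with `activation_offloading = True` whenever `CPUOffloadEnabled` is set, and fall back with a warning for unfused attention. A minimal, self-contained illustration of that tagging pattern (plain PyTorch, not the Transformer Engine handler; `needs_activation_offloading` and `offload_to_cpu` are illustrative names):

```python
import torch

# Only tensors explicitly tagged by the attention forward pass are moved to
# (optionally pinned) host memory; everything else stays on the device.
def needs_activation_offloading(tensor: torch.Tensor) -> bool:
    return getattr(tensor, "activation_offloading", False)

def offload_to_cpu(tensor: torch.Tensor) -> torch.Tensor:
    if not needs_activation_offloading(tensor):
        return tensor
    cpu_buf = torch.empty(tensor.size(), dtype=tensor.dtype, device="cpu",
                          pin_memory=torch.cuda.is_available())
    cpu_buf.copy_(tensor, non_blocking=True)
    return cpu_buf

# Usage: tag an activation the same way the flash-attention branch does.
activation = torch.randn(4, 8)
activation.activation_offloading = True
print(offload_to_cpu(activation).device)  # cpu
```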
46 changes: 31 additions & 15 deletions transformer_engine/pytorch/cpu_offload.py
@@ -184,6 +184,7 @@ def groupid_reset(self):
# the tensor back to gpu and deletes the cpu tensor.
# These will increment whenever `group_commit()` is invoked
self.current_group, self.tensor_count_current_group = (0, 0)
self.torch_tensor_count = 0
self.tensor_tag_to_state = {}

def on_group_commit_forward(self):
@@ -310,24 +311,35 @@ def get_tensor_buf_for_offloaded_tensor(self, tensor, tensor_tag):


def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:
# obtain a unique tensor tag
tensor_tag = (self.current_group, self.tensor_count_current_group)
self.tensor_count_current_group += 1
assert tensor_tag not in self.tensor_tag_to_state

if (self.current_group < self.num_offload_group
and self.tensor_need_offloading_checker(tensor)):
# first copy the tensor to tensorbuf, so that the original tensor will not be deleted
tensor_buf = self.get_tensor_buf_for_offloaded_tensor(tensor, tensor_tag)
tensor_buf.copy_(tensor)
if hasattr(tensor,"weight_offloading"):
tensor_buf.weight_offloading = True
if hasattr(tensor,"activation_offloading"):
tensor_buf.activation_offloading = True
# Here we just save it, and at commit, bulk_offload_group will handle it
self.tensor_tag_to_state[tensor_tag] = tensor_buf
torch_stray_tensor = isinstance(tensor,(torch._subclasses.fake_tensor.FakeTensor,
torch._subclasses.functional_tensor.FunctionalTensor))

if not torch_stray_tensor:
# obtain a unique tensor tag
tensor_tag = (self.current_group, self.tensor_count_current_group)
self.tensor_count_current_group += 1
assert tensor_tag not in self.tensor_tag_to_state

if (self.current_group < self.num_offload_group
and self.tensor_need_offloading_checker(tensor)):
# first copy the tensor to tensorbuf,
# so that the original tensor will not be deleted
tensor_buf = self.get_tensor_buf_for_offloaded_tensor(tensor, tensor_tag)
tensor_buf.copy_(tensor)
if hasattr(tensor,"weight_offloading"):
tensor_buf.weight_offloading = True
if hasattr(tensor,"activation_offloading"):
tensor_buf.activation_offloading = True
# Here we just save it, and at commit, bulk_offload_group will handle it
self.tensor_tag_to_state[tensor_tag] = tensor_buf
else:
self.tensor_tag_to_state[tensor_tag] = tensor
else:
tensor_tag = (-1,self.torch_tensor_count)
self.torch_tensor_count += 1
self.tensor_tag_to_state[tensor_tag] = tensor

return tensor_tag

def tensor_pop(self, tensor_tag, **kwargs):
@@ -350,6 +362,10 @@ def bulk_offload_group(self, group_to_offload):

# if offload, return the reference to cpu copy
if self.tensor_need_offloading_checker(tensor_on_device):
if hasattr(tensor_on_device,"weight_offloading"):
delattr(tensor_on_device,"weight_offloading")
if hasattr(tensor_on_device,"activation_offloading"):
delattr(tensor_on_device,"activation_offloading")
state = SynchronizedGroupOffloadHandler.offload(tensor_on_device)
self.tensor_tag_to_state[tensor_tag] = state

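
The cpu_offload.py hunks above add two behaviours: stray tensors produced by tracing/functionalization are bookkept under a separate `(-1, n)` tag and never offloaded, and the `weight_offloading`/`activation_offloading` attributes are dropped from the GPU tensor before its CPU copy is made. A minimal sketch of the new tagging logic, assuming the same private torch classes referenced in the diff are importable in your torch build (class and method names below are hypothetical):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import FunctionalTensor  # private torch API, as in the diff

# Regular tensors get (group, index) tags that drive bulk offloading;
# fake/functional tensors get (-1, n) tags and are stored untouched.
class TagSketch:
    def __init__(self):
        self.current_group = 0
        self.tensor_count_current_group = 0
        self.torch_tensor_count = 0
        self.tensor_tag_to_state = {}

    def push(self, tensor: torch.Tensor):
        if isinstance(tensor, (FakeTensor, FunctionalTensor)):
            tag = (-1, self.torch_tensor_count)
            self.torch_tensor_count += 1
        else:
            tag = (self.current_group, self.tensor_count_current_group)
            self.tensor_count_current_group += 1
        assert tag not in self.tensor_tag_to_state
        self.tensor_tag_to_state[tag] = tensor
        return tag
```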
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/module/layernorm_linear.py
@@ -242,7 +242,7 @@ def forward(
if cpu_offloading:
if fuse_wgrad_accumulation:
weight.main_grad.weight_offloading = True
if fp8:
if fp8 and weight_t_fp8 is not None:
weight_t_fp8.weight_offloading = True
ln_weight.weight_offloading = True
weight.weight_offloading = True
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/module/layernorm_mlp.py
@@ -424,8 +424,9 @@ def forward(
if fuse_wgrad_accumulation:
fc1_weight.main_grad.weight_offloading = True
fc2_weight.main_grad.weight_offloading = True
if fp8:
if fp8 and fc1_weight_t_fp8 is not None:
fc1_weight_t_fp8.weight_offloading = True
if fp8 and fc2_weight_t_fp8 is not None:
fc2_weight_t_fp8.weight_offloading = True
ln_weight.weight_offloading = True
fc1_weight.weight_offloading = True
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/module/linear.py
@@ -275,7 +275,7 @@ def forward(
if cpu_offloading:
if fuse_wgrad_accumulation:
weight.main_grad.weight_offloading = True
if fp8:
if fp8 and weight_t_fp8 is not None:
weight_t_fp8.weight_offloading = True
weight.weight_offloading = True

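
The three module diffs above (linear.py, layernorm_linear.py, layernorm_mlp.py) share one fix: with native FP8 support the transposed FP8 weight buffer can legitimately be `None`, so it is only tagged for weight offloading when it exists. A minimal sketch of that guard (the helper name is illustrative):

```python
# Only tag FP8 weight-transpose buffers that actually exist; skip Nones.
def tag_fp8_weights_for_offloading(fp8: bool, *weight_t_fp8_buffers):
    if not fp8:
        return
    for buf in weight_t_fp8_buffers:
        if buf is not None:
            buf.weight_offloading = True
```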
