
Commit

[PyTorch] Fix pipeline parallel execution by using cloned scale inverse tensors (#659)

Use cloned scale_inv for fp8 cast

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
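
Background note on the fix: in pipeline-parallel schedules several microbatches are in flight at once, so a later microbatch's forward pass can update the shared fp8_meta["scaling_fwd"] scale-inverse buffer in place before an earlier microbatch has run its backward pass. The fix is to snapshot (clone) the scale inverses at forward time and dequantize with that snapshot in backward. Below is a minimal, self-contained sketch of the hazard and the clone-based pattern in plain PyTorch; it does not use Transformer Engine, and every name in it is illustrative.

import torch

# Shared buffer, analogous to fp8_meta["scaling_fwd"].scale_inv: with pipeline
# parallelism, a later microbatch's forward can overwrite it in place before an
# earlier microbatch has run its backward pass.
scale_inv = torch.tensor([0.5])

class CastWithClonedScale(torch.autograd.Function):
    """Snapshot the scale inverse at forward time; backward uses the snapshot."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(scale_inv.clone())  # the fix: keep a cloned copy
        return x * scale_inv                      # stand-in for the FP8 dequant

    @staticmethod
    def backward(ctx, grad_out):
        (saved_scale_inv,) = ctx.saved_tensors
        return grad_out * saved_scale_inv         # immune to later overwrites

x = torch.ones(2, requires_grad=True)
y = CastWithClonedScale.apply(x)

scale_inv.fill_(2.0)   # the "next microbatch" updates the shared buffer in place
y.sum().backward()
print(x.grad)          # tensor([0.5000, 0.5000]): the forward-time value was used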
ksivaman authored Feb 8, 2024
1 parent d9eb199 commit 91d52ac
Showing 2 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -484,9 +484,9 @@ def backward(
                 )
                 clear_tensor_data(ln_out_total_t, grad_output_t)
             else:
-                ln_out_total_c = tex.cast_from_fp8(
+                ln_out_total_c = torch.ops.tex_ts.cast_from_fp8_ts(
                     ln_out_total,
-                    ctx.fp8_meta["scaling_fwd"],
+                    fwd_scale_inverses,
                     tex.FP8FwdTensors.GEMM1_INPUT,
                     fp8_dtype_forward,
                     TE_DType[ctx.activation_dtype],
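
The hunks in both files only show the consumer side of the fix: the backward pass now dequantizes with torch.ops.tex_ts.cast_from_fp8_ts, which takes a plain scale-inverse tensor, instead of tex.cast_from_fp8, which takes the whole scaling_fwd metadata struct. For fwd_scale_inverses to exist in backward, the forward pass has to save a cloned copy of the forward scale inverses; that code is not shown in this diff, but the pattern is presumably along these lines (a hedged sketch with stand-in objects, not the modules' actual code):

import torch

class _Meta:
    """Stand-in for the scaling_fwd metadata object; only the field used here."""
    def __init__(self):
        self.scale_inv = torch.ones(3)  # one scale inverse per FP8 tensor

fp8_meta = {"scaling_fwd": _Meta()}

# Forward: snapshot the live buffer (what backward later sees as fwd_scale_inverses).
fwd_scale_inverses = fp8_meta["scaling_fwd"].scale_inv.clone()

# Another microbatch's amax/scale update mutates the shared buffer in place...
fp8_meta["scaling_fwd"].scale_inv.mul_(2.0)

# ...but the snapshot this microbatch's backward will use is unchanged.
assert torch.equal(fwd_scale_inverses, torch.ones(3))
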
8 changes: 4 additions & 4 deletions transformer_engine/pytorch/module/layernorm_mlp.py
@@ -676,9 +676,9 @@ def backward(
                 clear_tensor_data(fc1_out)
             else:
                 if fc2_weight.requires_grad:
-                    gelu_out_c = tex.cast_from_fp8(
+                    gelu_out_c = torch.ops.tex_ts.cast_from_fp8_ts(
                         gelu_out,
-                        ctx.fp8_meta["scaling_fwd"],
+                        fwd_scale_inverses,
                         tex.FP8FwdTensors.GEMM2_INPUT,
                         fp8_dtype_forward,
                         TE_DType[ctx.activation_dtype],
@@ -875,9 +875,9 @@ def backward(
                 )
                 clear_tensor_data(ln_out_total_t, dgelu_t)
             else:
-                ln_out_total_c = tex.cast_from_fp8(
+                ln_out_total_c = torch.ops.tex_ts.cast_from_fp8_ts(
                     ln_out_total,
-                    ctx.fp8_meta["scaling_fwd"],
+                    fwd_scale_inverses,
                     tex.FP8FwdTensors.GEMM1_INPUT,
                     fp8_dtype_forward,
                     TE_DType[ctx.activation_dtype],
