Recomputation fixes with native fp8 (#646)
* fixes for recomputation

Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>

* lint

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix onnx export [wip]

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* register op; fixes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Jimmy Zhang <jiemingz@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
3 people authored Feb 3, 2024
1 parent 2aee059 commit 5b155fb
Showing 8 changed files with 78 additions and 19 deletions.
1 change: 1 addition & 0 deletions transformer_engine/pytorch/__init__.py
@@ -21,6 +21,7 @@
 # Register custom op symbolic ONNX functions
 from .te_onnx_extensions import (
     onnx_cast_to_fp8,
+    onnx_cast_to_fp8_noalloc,
     onnx_cast_from_fp8,
     onnx_fp8_gelu,
     onnx_fp8_relu,
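
The new symbolic name is pulled into the package namespace so that registration happens as a side effect of importing Transformer Engine; te_onnx_extensions calls register_custom_op_symbolic at module level (see the last file in this commit). A minimal sketch of that assumption:

    # Sketch: importing the package is what registers the tex_ts ONNX symbolics,
    # including the new onnx_cast_to_fp8_noalloc added here.
    import transformer_engine.pytorch  # noqa: F401  (side effect: symbolic registration)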
9 changes: 5 additions & 4 deletions transformer_engine/pytorch/cpp_extensions/cast.py
@@ -22,12 +22,13 @@ def cast_to_fp8(
     """Cast input to FP8"""

     if out is not None:
-        tex.cast_to_fp8_noalloc(
+        torch.ops.tex_ts.cast_to_fp8_noalloc_ts(
             inp,
-            fp8_meta_tensor.scale[fp8_tensor],
+            fp8_meta_tensor.scale,
             out,
-            fp8_meta_tensor.amax_history[0][fp8_tensor],
-            fp8_meta_tensor.scale_inv[fp8_tensor],
+            fp8_meta_tensor.amax_history,
+            fp8_meta_tensor.scale_inv,
+            fp8_tensor,
             otype
         )
         return None
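
The wrapper now passes the whole scale, amax_history and scale_inv tensors plus the fp8_tensor index to the registered op, instead of pre-indexed views, so the per-tensor indexing happens inside the op itself. A rough Python mock of how those arguments are consumed (the real indexing is in ts_fp8_op.cpp below; the CUDA kernel call is elided, and this is an illustration rather than TE code):

    def cast_to_fp8_noalloc_mock(inp, scale, out, amax_history, scale_inv, fp8_tensor, otype):
        s = scale[fp8_tensor]               # per-tensor scale selected inside the op
        amax = amax_history[0][fp8_tensor]  # current-step amax slot the kernel updates
        s_inv = scale_inv[fp8_tensor]       # inverse scale written back by the kernel
        # ... quantize `inp` with `s` into the preallocated `out`, update amax / s_inv ...
        return out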
11 changes: 6 additions & 5 deletions transformer_engine/pytorch/csrc/extensions/cast.cu
@@ -6,6 +6,7 @@

 #include "extensions.h"

+
 at::Tensor cast_to_fp8(const at::Tensor &input,
                        const at::Tensor &scale,
                        at::Tensor amax,
@@ -31,11 +32,11 @@ at::Tensor cast_to_fp8(const at::Tensor &input,


 void cast_to_fp8_noalloc(const at::Tensor &input,
-                        const at::Tensor &scale,
-                        at::Tensor output,
-                        at::Tensor amax,
-                        at::Tensor scale_inv,
-                        transformer_engine::DType otype
+                         const at::Tensor &scale,
+                         at::Tensor output,
+                         at::Tensor amax,
+                         at::Tensor scale_inv,
+                         transformer_engine::DType otype
 ) {
   using namespace transformer_engine;
   size_t N = static_cast<size_t>(input.size(0));
36 changes: 34 additions & 2 deletions transformer_engine/pytorch/csrc/ts_fp8_op.cpp
@@ -7,6 +7,7 @@
 #include <torch/script.h>
 #include "extensions.h"

+
 namespace {
 transformer_engine::DType reverse_map_dtype(int64_t dtype) {
   if (dtype >= 0 && dtype < static_cast<int64_t>(transformer_engine::DType::kNumTypes)) {
@@ -20,8 +21,8 @@ namespace {

 at::Tensor cast_to_fp8_ts(const at::Tensor &input,
                           const at::Tensor &scale,
-                          const at::Tensor &amax,
-                          const at::Tensor &scale_inv,
+                          at::Tensor amax,
+                          at::Tensor scale_inv,
                           int64_t fp8_tensor,
                           int64_t otype) {
   transformer_engine::DType otype_arg = reverse_map_dtype(otype);
@@ -33,6 +34,25 @@ at::Tensor cast_to_fp8_ts(const at::Tensor &input,
   return output;
 }

+
+at::Tensor cast_to_fp8_noalloc_ts(const at::Tensor &input,
+                                  const at::Tensor &scale,
+                                  at::Tensor output,
+                                  at::Tensor amax,
+                                  at::Tensor scale_inv,
+                                  int64_t fp8_tensor,
+                                  int64_t otype) {
+  transformer_engine::DType otype_arg = reverse_map_dtype(otype);
+  cast_to_fp8_noalloc(input,
+                      scale[fp8_tensor],
+                      output,
+                      amax[0][fp8_tensor],
+                      scale_inv[fp8_tensor],
+                      otype_arg);
+  return output;
+}
+
+
 at::Tensor cast_from_fp8_ts(const at::Tensor &input,
                             const at::Tensor &scale_inv,
                             int64_t fp8_tensor,
@@ -47,6 +67,7 @@ at::Tensor cast_from_fp8_ts(const at::Tensor &input,
   return output;
 }

+
 at::Tensor gelu_ts(at::Tensor input,
                    at::Tensor scale,
                    at::Tensor amax,
@@ -82,6 +103,7 @@ at::Tensor gelu_ts(at::Tensor input,
   return output;
 }

+
 at::Tensor relu_ts(at::Tensor input,
                    at::Tensor scale,
                    at::Tensor amax,
@@ -117,6 +139,7 @@ at::Tensor relu_ts(at::Tensor input,
   return output;
 }

+
 at::Tensor reglu_ts(at::Tensor input,
                     at::Tensor scale,
                     at::Tensor amax,
@@ -152,6 +175,7 @@ at::Tensor reglu_ts(at::Tensor input,
   return output;
 }

+
 at::Tensor geglu_ts(at::Tensor input,
                     at::Tensor scale,
                     at::Tensor amax,
@@ -187,6 +211,7 @@ at::Tensor geglu_ts(at::Tensor input,
   return output;
 }

+
 at::Tensor swiglu_ts(at::Tensor input,
                      at::Tensor scale,
                      at::Tensor amax,
@@ -222,6 +247,7 @@ at::Tensor swiglu_ts(at::Tensor input,
   return output;
 }

+
 at::Tensor te_gemm_ts(at::Tensor A,
                       at::Tensor A_scale_inverse,
                       int64_t A_fp8_tensor,
@@ -286,6 +312,7 @@ at::Tensor te_gemm_ts(at::Tensor A,
   return D;
 }

+
 at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input,
                                     const at::Tensor &weight,
                                     const at::Tensor &bias,
@@ -312,6 +339,7 @@ at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input,
   return output;
 }

+
 at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input,
                                 const at::Tensor &weight,
                                 const at::Tensor &bias,
@@ -328,6 +356,7 @@ at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input,
   return output;
 }

+
 at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input,
                                   const at::Tensor &weight,
                                   double eps,
@@ -352,6 +381,7 @@ at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input,
   return output;
 }

+
 at::Tensor rmsnorm_fwd_inf_ts(const at::Tensor &input,
                               const at::Tensor &weight,
                               double eps,
@@ -366,8 +396,10 @@ at::Tensor rmsnorm_fwd_inf_ts(const at::Tensor &input,
   return output;
 }

+
 TORCH_LIBRARY(tex_ts, m) {
   m.def("cast_to_fp8_ts", &cast_to_fp8_ts);
+  m.def("cast_to_fp8_noalloc_ts", &cast_to_fp8_noalloc_ts);
   m.def("cast_from_fp8_ts", &cast_from_fp8_ts);
   m.def("gelu_ts", &gelu_ts);
   m.def("relu_ts", &relu_ts);
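
Registering the function under TORCH_LIBRARY(tex_ts, ...) is what makes it reachable as torch.ops.tex_ts.cast_to_fp8_noalloc_ts, which is how cast.py above dispatches to it and how the ONNX exporter can see it. A small sketch, assuming transformer_engine is installed with its compiled extensions:

    import torch
    import transformer_engine.pytorch  # noqa: F401  loads the library that runs TORCH_LIBRARY(tex_ts, ...)

    # Argument order, as used by cast.py in this commit:
    #   (input, scale, output, amax_history, scale_inv, fp8_tensor, otype)
    noalloc_cast = torch.ops.tex_ts.cast_to_fp8_noalloc_ts
    print(noalloc_cast)  # prints the resolved custom op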
9 changes: 7 additions & 2 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -36,6 +36,8 @@
     allreduce,
     reduce_scatter_along_first_dim,
     gather_along_first_dim,
+    is_fp8_activation_recompute_enabled,
+    in_fp8_activation_recompute_phase,
 )
 from ..constants import GemmParallelModes, dist_group_type, TE_DType
 from ..jit import no_torch_dynamo
@@ -173,7 +175,9 @@ def forward(
                 fp8_meta=fp8_meta,
                 fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
             )
-            if is_grad_enabled:
+            if (is_grad_enabled
+                    or (is_fp8_activation_recompute_enabled()
+                        and not in_fp8_activation_recompute_phase())):
                 tex.fp8_cast_transpose_fused(
                     weight,
                     fp8_meta["scaling_fwd"],
@@ -183,11 +187,12 @@ def forward(
                     transpose_out=weight_t_fp8._data,
                 )
             else:
-                weight_fp8._data = tex.cast_to_fp8(
+                tex.cast_to_fp8(
                     weight,
                     fp8_meta["scaling_fwd"],
                     tex.FP8FwdTensors.GEMM1_WEIGHT,
                     fp8_dtype_forward,
+                    out=weight_fp8._data,
                 )
                 weight_t_fp8 = None

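
With FP8 activation recompute enabled, the fused cast-transpose now also runs on the initial forward pass (which executes under no_grad when the layer is checkpointed), while the recompute phase itself is excluded from that extra trigger, presumably so the transposed FP8 weight buffer is already populated when backward needs it. A self-contained restatement of the predicate (the function name is illustrative; the real checks live in ..distributed):

    def needs_fused_cast_transpose(is_grad_enabled: bool,
                                   recompute_enabled: bool,
                                   in_recompute_phase: bool) -> bool:
        # Mirrors the branch added in this commit: produce weight_t_fp8 whenever grads
        # are on, or on the checkpointing pass of FP8 activation recompute.
        return is_grad_enabled or (recompute_enabled and not in_recompute_phase)

    # Checkpointing pass runs under torch.no_grad() but still fills the transpose:
    assert needs_fused_cast_transpose(False, True, False)
    # In the recompute phase the extra trigger does not apply (grad mode decides alone):
    assert not needs_fused_cast_transpose(False, True, True)

The same condition and the same out= cast pattern are applied to fc1/fc2 in layernorm_mlp.py and to the plain Linear module below.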
12 changes: 9 additions & 3 deletions transformer_engine/pytorch/module/layernorm_mlp.py
@@ -41,6 +41,8 @@
     allreduce,
     reduce_scatter_along_first_dim,
     gather_along_first_dim,
+    is_fp8_activation_recompute_enabled,
+    in_fp8_activation_recompute_phase,
 )

 from .. import cpp_extensions as tex
@@ -219,7 +221,9 @@ def forward(
                 fp8_meta=fp8_meta,
                 fp8_meta_index=tex.FP8FwdTensors.GEMM2_WEIGHT,
             )
-            if is_grad_enabled:
+            if (is_grad_enabled
+                    or (is_fp8_activation_recompute_enabled()
+                        and not in_fp8_activation_recompute_phase())):
                 # Fused cast-transpose kernels
                 tex.fp8_cast_transpose_fused(
                     fc1_weight,
@@ -238,18 +242,20 @@ def forward(
                     transpose_out=fc2_weight_t_fp8._data,
                 )
             else:
-                fc1_weight_fp8._data = tex.cast_to_fp8(
+                tex.cast_to_fp8(
                     fc1_weight,
                     fp8_meta["scaling_fwd"],
                     tex.FP8FwdTensors.GEMM1_WEIGHT,
                     fp8_dtype_forward,
+                    out=fc1_weight_fp8._data,
                 )
                 fc1_weight_t_fp8 = None
-                fc2_weight_fp8._data = tex.cast_to_fp8(
+                tex.cast_to_fp8(
                     fc2_weight,
                     fp8_meta["scaling_fwd"],
                     tex.FP8FwdTensors.GEMM2_WEIGHT,
                     fp8_dtype_forward,
+                    out=fc2_weight_fp8._data,
                 )
                 fc2_weight_t_fp8 = None

9 changes: 7 additions & 2 deletions transformer_engine/pytorch/module/linear.py
@@ -34,6 +34,8 @@
     allreduce,
     reduce_scatter_along_first_dim,
     gather_along_first_dim,
+    is_fp8_activation_recompute_enabled,
+    in_fp8_activation_recompute_phase,
 )
 from ..cpp_extensions import (
     fp8_gemm,
@@ -155,7 +157,9 @@ def forward(
                 fp8_meta=fp8_meta,
                 fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
             )
-            if is_grad_enabled:
+            if (is_grad_enabled
+                    or (is_fp8_activation_recompute_enabled()
+                        and not in_fp8_activation_recompute_phase())):
                 fp8_cast_transpose_fused(
                     weight,
                     fp8_meta["scaling_fwd"],
@@ -165,11 +169,12 @@ def forward(
                     transpose_out=weight_t_fp8._data,
                 )
             else:
-                weight_fp8._data = cast_to_fp8(
+                cast_to_fp8(
                     weight,
                     fp8_meta["scaling_fwd"],
                     tex.FP8FwdTensors.GEMM1_WEIGHT,
                     fp8_dtype_forward,
+                    out=weight_fp8._data,
                 )
                 weight_t_fp8 = None

10 changes: 9 additions & 1 deletion transformer_engine/pytorch/te_onnx_extensions.py
@@ -130,6 +130,13 @@ def onnx_cast_to_fp8(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
     return quantize(g, inputs, scale_inv, fp8_tensor)


+@symbolic_helper.parse_args("v", "v", "v", "v", "fs", "i", "i")
+def onnx_cast_to_fp8_noalloc(g, inputs, scale, output, amax, scale_inv, fp8_tensor, otype):
+    """ONNX graph for cast_to_fp8_noalloc"""
+    # pylint: disable=unused-argument
+    return quantize(g, inputs, scale_inv, fp8_tensor)
+
+
 @symbolic_helper.parse_args("v", "fs", "i", "i", "i")
 def onnx_cast_from_fp8(g, inputs, scale_inv, fp8_tensor, itype, otype):
     """ONNX graph for cast_from_fp8"""
@@ -393,10 +400,11 @@ def onnx_rmsnorm_fwd(g, inputs, weight, eps, zero_centered_gamma):
     result = g.op("Mul", weight, normalized_input)
     result = g.op("Cast", result, to_i=get_TensorProtoDataType(inputs))
-
+
     return result


 register_custom_op_symbolic('tex_ts::cast_to_fp8_ts', onnx_cast_to_fp8, VER)
+register_custom_op_symbolic('tex_ts::cast_to_fp8_noalloc_ts', onnx_cast_to_fp8_noalloc, VER)
 register_custom_op_symbolic('tex_ts::cast_from_fp8_ts', onnx_cast_from_fp8, VER)
 register_custom_op_symbolic('tex_ts::gelu_ts', onnx_fp8_gelu, VER)
 register_custom_op_symbolic('tex_ts::relu_ts', onnx_fp8_relu, VER)
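
As with the allocating variant, the symbolic ignores the preallocated output and the amax argument and lowers to the same quantize pattern, forwarding scale_inv (presumably because the ONNX quantize convention divides by its scale, whereas TE's cast kernels multiply by theirs). A rough numerical sketch of that quantize semantics, assuming E4M3 with a saturating maximum of 448 and a PyTorch build that has torch.float8_e4m3fn; this is not the graph TE actually emits:

    import torch

    E4M3_MAX = 448.0  # largest finite float8_e4m3fn magnitude

    def fp8_quantize_reference(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        # TE convention: FP8 data = saturate(x * scale); dequantization later multiplies
        # by scale_inv = 1 / scale, the value handed to the ONNX graph above.
        return (x * scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)

    print(fp8_quantize_reference(torch.randn(4, 4), torch.tensor(1.0)))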
