diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py
index 16bd128734..e3abfa00fc 100644
--- a/transformer_engine/pytorch/__init__.py
+++ b/transformer_engine/pytorch/__init__.py
@@ -21,6 +21,7 @@
 # Register custom op symbolic ONNX functions
 from .te_onnx_extensions import (
     onnx_cast_to_fp8,
+    onnx_cast_to_fp8_noalloc,
     onnx_cast_from_fp8,
     onnx_fp8_gelu,
     onnx_fp8_relu,
diff --git a/transformer_engine/pytorch/cpp_extensions/cast.py b/transformer_engine/pytorch/cpp_extensions/cast.py
index 9e9e0384c2..3c80beff97 100644
--- a/transformer_engine/pytorch/cpp_extensions/cast.py
+++ b/transformer_engine/pytorch/cpp_extensions/cast.py
@@ -22,12 +22,13 @@ def cast_to_fp8(
     """Cast input to FP8"""
     if out is not None:
-        tex.cast_to_fp8_noalloc(
+        torch.ops.tex_ts.cast_to_fp8_noalloc_ts(
             inp,
-            fp8_meta_tensor.scale[fp8_tensor],
+            fp8_meta_tensor.scale,
             out,
-            fp8_meta_tensor.amax_history[0][fp8_tensor],
-            fp8_meta_tensor.scale_inv[fp8_tensor],
+            fp8_meta_tensor.amax_history,
+            fp8_meta_tensor.scale_inv,
+            fp8_tensor,
             otype
         )
         return None
diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cu b/transformer_engine/pytorch/csrc/extensions/cast.cu
index 860b97a56a..80975069de 100644
--- a/transformer_engine/pytorch/csrc/extensions/cast.cu
+++ b/transformer_engine/pytorch/csrc/extensions/cast.cu
@@ -6,6 +6,7 @@
 #include "extensions.h"
 
+
 at::Tensor cast_to_fp8(const at::Tensor &input,
                        const at::Tensor &scale,
                        at::Tensor amax,
@@ -31,11 +32,11 @@ at::Tensor cast_to_fp8(const at::Tensor &input,
 
 
 void cast_to_fp8_noalloc(const at::Tensor &input,
-                          const at::Tensor &scale,
-                          at::Tensor output,
-                          at::Tensor amax,
-                          at::Tensor scale_inv,
-                          transformer_engine::DType otype
+                         const at::Tensor &scale,
+                         at::Tensor output,
+                         at::Tensor amax,
+                         at::Tensor scale_inv,
+                         transformer_engine::DType otype
 ) {
   using namespace transformer_engine;
   size_t N = static_cast<size_t>(input.size(0));
diff --git a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp
index 7d362d0709..a9659e1b7a 100755
--- a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp
+++ b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp
@@ -7,6 +7,7 @@
 #include <torch/script.h>
 #include "extensions.h"
 
+
 namespace {
 transformer_engine::DType reverse_map_dtype(int64_t dtype) {
   if (dtype >= 0 && dtype < static_cast<int64_t>(transformer_engine::DType::kNumTypes)) {
@@ -20,8 +21,8 @@ namespace {
 
 at::Tensor cast_to_fp8_ts(const at::Tensor &input,
                           const at::Tensor &scale,
-                          const at::Tensor &amax,
-                          const at::Tensor &scale_inv,
+                          at::Tensor amax,
+                          at::Tensor scale_inv,
                           int64_t fp8_tensor,
                           int64_t otype) {
   transformer_engine::DType otype_arg = reverse_map_dtype(otype);
@@ -33,6 +34,25 @@ at::Tensor cast_to_fp8_ts(const at::Tensor &input,
   return output;
 }
 
+
+at::Tensor cast_to_fp8_noalloc_ts(const at::Tensor &input,
+                                  const at::Tensor &scale,
+                                  at::Tensor output,
+                                  at::Tensor amax,
+                                  at::Tensor scale_inv,
+                                  int64_t fp8_tensor,
+                                  int64_t otype) {
+  transformer_engine::DType otype_arg = reverse_map_dtype(otype);
+  cast_to_fp8_noalloc(input,
+                      scale[fp8_tensor],
+                      output,
+                      amax[0][fp8_tensor],
+                      scale_inv[fp8_tensor],
+                      otype_arg);
+  return output;
+}
+
+
 at::Tensor cast_from_fp8_ts(const at::Tensor &input,
                             const at::Tensor &scale_inv,
                             int64_t fp8_tensor,
@@ -47,6 +67,7 @@ at::Tensor cast_from_fp8_ts(const at::Tensor &input,
   return output;
 }
 
+
 at::Tensor gelu_ts(at::Tensor input,
                    at::Tensor scale,
                    at::Tensor amax,
@@ -82,6 +103,7 @@ at::Tensor gelu_ts(at::Tensor input,
   return output;
 }
 
+
 at::Tensor relu_ts(at::Tensor input,
                    at::Tensor scale,
                    at::Tensor amax,
@@ -117,6 +139,7 @@ at::Tensor relu_ts(at::Tensor input,
   return output;
 }
 
+
 at::Tensor reglu_ts(at::Tensor input,
                     at::Tensor scale,
                     at::Tensor amax,
@@ -152,6 +175,7 @@ at::Tensor reglu_ts(at::Tensor input,
   return output;
 }
 
+
 at::Tensor geglu_ts(at::Tensor input,
                     at::Tensor scale,
                     at::Tensor amax,
@@ -187,6 +211,7 @@ at::Tensor geglu_ts(at::Tensor input,
   return output;
 }
 
+
 at::Tensor swiglu_ts(at::Tensor input,
                      at::Tensor scale,
                      at::Tensor amax,
@@ -222,6 +247,7 @@ at::Tensor swiglu_ts(at::Tensor input,
   return output;
 }
 
+
 at::Tensor te_gemm_ts(at::Tensor A,
                       at::Tensor A_scale_inverse,
                       int64_t A_fp8_tensor,
@@ -286,6 +312,7 @@ at::Tensor te_gemm_ts(at::Tensor A,
   return D;
 }
 
+
 at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input,
                                     const at::Tensor &weight,
                                     const at::Tensor &bias,
@@ -312,6 +339,7 @@ at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input,
   return output;
 }
 
+
 at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input,
                                 const at::Tensor &weight,
                                 const at::Tensor &bias,
@@ -328,6 +356,7 @@ at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input,
   return output;
 }
 
+
 at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input,
                                   const at::Tensor &weight,
                                   double eps,
@@ -352,6 +381,7 @@ at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input,
   return output;
 }
 
+
 at::Tensor rmsnorm_fwd_inf_ts(const at::Tensor &input,
                               const at::Tensor &weight,
                               double eps,
@@ -366,8 +396,10 @@ at::Tensor rmsnorm_fwd_inf_ts(const at::Tensor &input,
   return output;
 }
 
+
 TORCH_LIBRARY(tex_ts, m) {
   m.def("cast_to_fp8_ts", &cast_to_fp8_ts);
+  m.def("cast_to_fp8_noalloc_ts", &cast_to_fp8_noalloc_ts);
   m.def("cast_from_fp8_ts", &cast_from_fp8_ts);
   m.def("gelu_ts", &gelu_ts);
   m.def("relu_ts", &relu_ts);
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py
index 6836ef6d22..687b1ef88a 100644
--- a/transformer_engine/pytorch/module/layernorm_linear.py
+++ b/transformer_engine/pytorch/module/layernorm_linear.py
@@ -36,6 +36,8 @@
     allreduce,
     reduce_scatter_along_first_dim,
     gather_along_first_dim,
+    is_fp8_activation_recompute_enabled,
+    in_fp8_activation_recompute_phase,
 )
 from ..constants import GemmParallelModes, dist_group_type, TE_DType
 from ..jit import no_torch_dynamo
@@ -173,7 +175,9 @@ def forward(
                     fp8_meta=fp8_meta,
                     fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
                 )
-                if is_grad_enabled:
+                if (is_grad_enabled
+                    or (is_fp8_activation_recompute_enabled()
+                        and not in_fp8_activation_recompute_phase())):
                     tex.fp8_cast_transpose_fused(
                         weight,
                         fp8_meta["scaling_fwd"],
@@ -183,11 +187,12 @@ def forward(
                         transpose_out=weight_t_fp8._data,
                     )
                 else:
-                    weight_fp8._data = tex.cast_to_fp8(
+                    tex.cast_to_fp8(
                         weight,
                         fp8_meta["scaling_fwd"],
                         tex.FP8FwdTensors.GEMM1_WEIGHT,
                         fp8_dtype_forward,
+                        out=weight_fp8._data,
                     )
                     weight_t_fp8 = None
diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py
index 3a0e5cb559..8b92728506 100644
--- a/transformer_engine/pytorch/module/layernorm_mlp.py
+++ b/transformer_engine/pytorch/module/layernorm_mlp.py
@@ -41,6 +41,8 @@
     allreduce,
     reduce_scatter_along_first_dim,
     gather_along_first_dim,
+    is_fp8_activation_recompute_enabled,
+    in_fp8_activation_recompute_phase,
 )
 from .. import cpp_extensions as tex
@@ -219,7 +221,9 @@ def forward(
                     fp8_meta=fp8_meta,
                     fp8_meta_index=tex.FP8FwdTensors.GEMM2_WEIGHT,
                 )
-                if is_grad_enabled:
+                if (is_grad_enabled
+                    or (is_fp8_activation_recompute_enabled()
+                        and not in_fp8_activation_recompute_phase())):
                     # Fused cast-transpose kernels
                     tex.fp8_cast_transpose_fused(
                         fc1_weight,
@@ -238,18 +242,20 @@ def forward(
                         transpose_out=fc2_weight_t_fp8._data,
                     )
                 else:
-                    fc1_weight_fp8._data = tex.cast_to_fp8(
+                    tex.cast_to_fp8(
                         fc1_weight,
                         fp8_meta["scaling_fwd"],
                         tex.FP8FwdTensors.GEMM1_WEIGHT,
                         fp8_dtype_forward,
+                        out=fc1_weight_fp8._data,
                     )
                     fc1_weight_t_fp8 = None
-                    fc2_weight_fp8._data = tex.cast_to_fp8(
+                    tex.cast_to_fp8(
                         fc2_weight,
                         fp8_meta["scaling_fwd"],
                         tex.FP8FwdTensors.GEMM2_WEIGHT,
                         fp8_dtype_forward,
+                        out=fc2_weight_fp8._data,
                     )
                     fc2_weight_t_fp8 = None
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index f2c955bfc0..6c4f13c685 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -34,6 +34,8 @@
     allreduce,
     reduce_scatter_along_first_dim,
     gather_along_first_dim,
+    is_fp8_activation_recompute_enabled,
+    in_fp8_activation_recompute_phase,
 )
 from ..cpp_extensions import (
     fp8_gemm,
@@ -155,7 +157,9 @@ def forward(
                     fp8_meta=fp8_meta,
                     fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
                 )
-                if is_grad_enabled:
+                if (is_grad_enabled
+                    or (is_fp8_activation_recompute_enabled()
+                        and not in_fp8_activation_recompute_phase())):
                     fp8_cast_transpose_fused(
                         weight,
                         fp8_meta["scaling_fwd"],
@@ -165,11 +169,12 @@ def forward(
                         transpose_out=weight_t_fp8._data,
                     )
                 else:
-                    weight_fp8._data = cast_to_fp8(
+                    cast_to_fp8(
                         weight,
                         fp8_meta["scaling_fwd"],
                         tex.FP8FwdTensors.GEMM1_WEIGHT,
                         fp8_dtype_forward,
+                        out=weight_fp8._data,
                     )
                     weight_t_fp8 = None
diff --git a/transformer_engine/pytorch/te_onnx_extensions.py b/transformer_engine/pytorch/te_onnx_extensions.py
index 81c9d71c74..67ff4ce161 100755
--- a/transformer_engine/pytorch/te_onnx_extensions.py
+++ b/transformer_engine/pytorch/te_onnx_extensions.py
@@ -130,6 +130,13 @@ def onnx_cast_to_fp8(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
     return quantize(g, inputs, scale_inv, fp8_tensor)
 
 
+@symbolic_helper.parse_args("v", "v", "v", "v", "fs", "i", "i")
+def onnx_cast_to_fp8_noalloc(g, inputs, scale, output, amax, scale_inv, fp8_tensor, otype):
+    """ONNX graph for cast_to_fp8_noalloc"""
+    # pylint: disable=unused-argument
+    return quantize(g, inputs, scale_inv, fp8_tensor)
+
+
 @symbolic_helper.parse_args("v", "fs", "i", "i", "i")
 def onnx_cast_from_fp8(g, inputs, scale_inv, fp8_tensor, itype, otype):
     """ONNX graph for cast_from_fp8"""
@@ -393,10 +400,11 @@ def onnx_rmsnorm_fwd(g, inputs, weight, eps, zero_centered_gamma):
     result = g.op("Mul", weight, normalized_input)
     result = g.op("Cast", result, to_i=get_TensorProtoDataType(inputs))
-
     return result
+
 
 
 register_custom_op_symbolic('tex_ts::cast_to_fp8_ts', onnx_cast_to_fp8, VER)
+register_custom_op_symbolic('tex_ts::cast_to_fp8_noalloc_ts', onnx_cast_to_fp8_noalloc, VER)
 register_custom_op_symbolic('tex_ts::cast_from_fp8_ts', onnx_cast_from_fp8, VER)
 register_custom_op_symbolic('tex_ts::gelu_ts', onnx_fp8_gelu, VER)
 register_custom_op_symbolic('tex_ts::relu_ts', onnx_fp8_relu, VER)
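
For context, the op registered above can also be reached directly through torch.ops once the TE extension library is loaded; the sketch below shows the argument order that cast.py now forwards (input, scale, output, amax_history, scale_inv, fp8_tensor, otype). It is illustrative only: the buffer shapes, the number of FP8 meta slots, and the transformer_engine_extensions module name are assumptions, and within the library this path is exercised via cast_to_fp8(..., out=...) exactly as in the module changes above.

    # Sketch only (not part of the diff).  Assumes a CUDA device and that importing
    # transformer_engine.pytorch loads the extension that registers torch.ops.tex_ts.
    import torch
    import transformer_engine.pytorch  # noqa: F401
    import transformer_engine_extensions as tex  # assumed name of the compiled bindings

    inp = torch.randn(128, 256, device="cuda", dtype=torch.float32)
    out = torch.empty(128, 256, device="cuda", dtype=torch.uint8)  # FP8 payload buffer, like a Float8 weight's _data

    # Per-tensor scaling state; the op reads scale[fp8_tensor], amax[0][fp8_tensor]
    # and scale_inv[fp8_tensor], matching the indexing in cast_to_fp8_noalloc_ts.
    num_slots = 3  # assumed number of FP8 meta slots
    scale = torch.ones(num_slots, device="cuda")
    amax_history = torch.zeros(1, num_slots, device="cuda")
    scale_inv = torch.ones(num_slots, device="cuda")

    fp8_tensor = 0                      # meta slot index (GEMM1_WEIGHT in the modules above)
    otype = int(tex.DType.kFloat8E4M3)  # integer code for the target FP8 format

    torch.ops.tex_ts.cast_to_fp8_noalloc_ts(
        inp, scale, out, amax_history, scale_inv, fp8_tensor, otype)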