diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index 6a4c8475c4..910b85b26e 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -208,17 +208,11 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output,
                                                          int scale_offset = 0, int amax_offset = 0,
                                                          int scale_inv_offset = 0);
 
-void fused_dswiglu_cast_transpose(at::Tensor grad_output,
-                                  at::Tensor input,
-                                  at::Tensor grad_input,
-                                  at::Tensor grad_input_transpose,
-                                  at::Tensor scale,
-                                  at::Tensor amax,
-                                  at::Tensor scale_inv,
-                                  transformer_engine::DType otype,
-                                  int scale_offset = 0,
-                                  int amax_offset = 0,
-                                  int scale_inv_offset = 0);
+void fused_dswiglu_cast_transpose(at::Tensor grad_output, at::Tensor input, at::Tensor grad_input,
+                                  at::Tensor grad_input_transpose, at::Tensor scale,
+                                  at::Tensor amax, at::Tensor scale_inv,
+                                  transformer_engine::DType otype, int scale_offset = 0,
+                                  int amax_offset = 0, int scale_inv_offset = 0);
 
 void fused_multi_cast_transpose(std::vector<at::Tensor> input_list,
                                 std::vector<at::Tensor> scale_list,
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index 2aaca7e68f..4c93526461 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -90,11 +90,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
   m.def("fused_dswiglu_cast_transpose", &fused_dswiglu_cast_transpose,
         "Fused SwiGLU backward + FP8 cast + FP8 transpose",
-        py::call_guard(),
-        py::arg("grad_output"), py::arg("input"),
-        py::arg("grad_input"), py::arg("grad_input_transpose"),
-        py::arg("scale"), py::arg("amax"), py::arg("scale_inv"),
-        py::arg("otype"), py::arg("scale_offset") = 0,
+        py::call_guard(), py::arg("grad_output"), py::arg("input"),
+        py::arg("grad_input"), py::arg("grad_input_transpose"), py::arg("scale"), py::arg("amax"),
+        py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0,
         py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
   m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose,
         "Fused Multi-tensor Cast + Transpose", py::call_guard());
diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cu b/transformer_engine/pytorch/csrc/extensions/transpose.cu
index fb0d105345..f373cdf83a 100644
--- a/transformer_engine/pytorch/csrc/extensions/transpose.cu
+++ b/transformer_engine/pytorch/csrc/extensions/transpose.cu
@@ -196,17 +196,11 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output,
   return {grad_bias, dgelu, dgelu_transpose};
 }
 
-void fused_dswiglu_cast_transpose(at::Tensor grad_output,
-                                  at::Tensor input,
-                                  at::Tensor grad_input,
-                                  at::Tensor grad_input_transpose,
-                                  at::Tensor scale,
-                                  at::Tensor amax,
-                                  at::Tensor scale_inv,
-                                  transformer_engine::DType otype,
-                                  int scale_offset,
-                                  int amax_offset,
-                                  int scale_inv_offset) {
+void fused_dswiglu_cast_transpose(at::Tensor grad_output, at::Tensor input, at::Tensor grad_input,
+                                  at::Tensor grad_input_transpose, at::Tensor scale,
+                                  at::Tensor amax, at::Tensor scale_inv,
+                                  transformer_engine::DType otype, int scale_offset,
+                                  int amax_offset, int scale_inv_offset) {
   using namespace transformer_engine;
 
   // Tensor dimensions
@@ -217,33 +211,28 @@ void fused_dswiglu_cast_transpose(at::Tensor grad_output,
   const auto N = static_cast<size_t>(grad_output.size(-1));
 
   // Check tensor dims
-  NVTE_CHECK(grad_output.dim() == 2,
-             "Expected grad output tensor to have 2 dims, but found ", grad_output.dim());
-  NVTE_CHECK(input.dim() == 2,
-             "Expected input tensor to have 2 dims, but found ", input.dim());
-  NVTE_CHECK(outer_dim(input) == M,
-             "Expected input tensor to have outer dimension of ",
-             M, ", but found ", outer_dim(input));
-  NVTE_CHECK(input.size(-1) == 2*N,
-             "Expected input tensor to have inner dimension of ",
-             2*N, ", but found ", input.size(-1));
-  NVTE_CHECK(grad_input.dim() == 2,
-             "Expected grad input tensor to have 2 dims, but found ", grad_input.dim());
-  NVTE_CHECK(outer_dim(grad_input) == M,
-             "Expected grad input tensor to have outer dimension of ",
+  NVTE_CHECK(grad_output.dim() == 2, "Expected grad output tensor to have 2 dims, but found ",
+             grad_output.dim());
+  NVTE_CHECK(input.dim() == 2, "Expected input tensor to have 2 dims, but found ", input.dim());
+  NVTE_CHECK(outer_dim(input) == M, "Expected input tensor to have outer dimension of ", M,
+             ", but found ", outer_dim(input));
+  NVTE_CHECK(input.size(-1) == 2 * N, "Expected input tensor to have inner dimension of ", 2 * N,
+             ", but found ", input.size(-1));
+  NVTE_CHECK(grad_input.dim() == 2, "Expected grad input tensor to have 2 dims, but found ",
+             grad_input.dim());
+  NVTE_CHECK(outer_dim(grad_input) == M, "Expected grad input tensor to have outer dimension of ",
              M, ", but found ", outer_dim(grad_input));
-  NVTE_CHECK(grad_input.size(-1) == 2*N,
-             "Expected grad input tensor to have inner dimension of ",
-             2*N, ", but found ", grad_input.size(-1));
+  NVTE_CHECK(grad_input.size(-1) == 2 * N, "Expected grad input tensor to have inner dimension of ",
+             2 * N, ", but found ", grad_input.size(-1));
   NVTE_CHECK(grad_input_transpose.dim() == 2,
              "Expected grad input transpose tensor to have 2 dims, but found ",
              grad_input_transpose.dim());
-  NVTE_CHECK(grad_input_transpose.size(0) == 2*N,
-             "Expected grad input tensor to have outer dimension of ",
-             2*N, ", but found ", grad_input_transpose.size(0));
+  NVTE_CHECK(grad_input_transpose.size(0) == 2 * N,
+             "Expected grad input tensor to have outer dimension of ", 2 * N, ", but found ",
+             grad_input_transpose.size(0));
   NVTE_CHECK(grad_input_transpose.size(1) == M,
-             "Expected grad input tensor to have outer dimension of ",
-             M, ", but found ", grad_input_transpose.size(1));
+             "Expected grad input tensor to have outer dimension of ", M, ", but found ",
+             grad_input_transpose.size(1));
 
   // Check tensor format
   NVTE_CHECK(grad_output.is_contiguous(), "Expected grad output tensor to be contiguous");
@@ -266,10 +255,10 @@ void fused_dswiglu_cast_transpose(at::Tensor grad_output,
   // Construct Transformer Engine tensors
   auto dy_cu = makeTransformerEngineTensor(grad_output);
   auto x_cu = makeTransformerEngineTensor(input);
-  auto dx_cu = makeTransformerEngineTensor(grad_input.data_ptr(), {M, 2*N}, otype, amax_dptr,
+  auto dx_cu = makeTransformerEngineTensor(grad_input.data_ptr(), {M, 2 * N}, otype, amax_dptr,
                                            scale_dptr, scale_inv_dptr);
-  auto dx_t_cu = makeTransformerEngineTensor(grad_input_transpose.data_ptr(), {2*N, M},
-                                             otype, amax_dptr, scale_dptr, scale_inv_dptr);
+  auto dx_t_cu = makeTransformerEngineTensor(grad_input_transpose.data_ptr(), {2 * N, M}, otype,
+                                             amax_dptr, scale_dptr, scale_inv_dptr);
 
   // Launch kernel
   nvte_dswiglu_cast_transpose(dy_cu.data(), x_cu.data(), dx_cu.data(), dx_t_cu.data(),
diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py
index a933619c34..a2e5a24a85 100644
--- a/transformer_engine/pytorch/ops/basic/activation.py
+++ b/transformer_engine/pytorch/ops/basic/activation.py
@@ -97,11 +97,7 @@ def op_forward(
         output_fp8_meta = None
         output_dtype = TE_DType[dtype]
         output_fp8_scale_inv = None
-        if (
-            fp8_enabled
-            and next_op is not None
-            and next_op.num_fp8_scales("input") > 0
-        ):
+        if fp8_enabled and next_op is not None and next_op.num_fp8_scales("input") > 0:
             with_fp8_output = True
             fp8_meta = next_op.get_fp8_meta("input")
             fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)