[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Sep 27, 2024
1 parent ada6804 commit 01327c4
Showing 4 changed files with 34 additions and 57 deletions.
16 changes: 5 additions & 11 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -208,17 +208,11 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output,
 int scale_offset = 0, int amax_offset = 0,
 int scale_inv_offset = 0);

-void fused_dswiglu_cast_transpose(at::Tensor grad_output,
-at::Tensor input,
-at::Tensor grad_input,
-at::Tensor grad_input_transpose,
-at::Tensor scale,
-at::Tensor amax,
-at::Tensor scale_inv,
-transformer_engine::DType otype,
-int scale_offset = 0,
-int amax_offset = 0,
-int scale_inv_offset = 0);
+void fused_dswiglu_cast_transpose(at::Tensor grad_output, at::Tensor input, at::Tensor grad_input,
+at::Tensor grad_input_transpose, at::Tensor scale,
+at::Tensor amax, at::Tensor scale_inv,
+transformer_engine::DType otype, int scale_offset = 0,
+int amax_offset = 0, int scale_inv_offset = 0);

 void fused_multi_cast_transpose(std::vector<at::Tensor> input_list,
 std::vector<at::Tensor> scale_list,
8 changes: 3 additions & 5 deletions transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -90,11 +90,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
 m.def("fused_dswiglu_cast_transpose", &fused_dswiglu_cast_transpose,
 "Fused SwiGLU backward + FP8 cast + FP8 transpose",
-py::call_guard<py::gil_scoped_release>(),
-py::arg("grad_output"), py::arg("input"),
-py::arg("grad_input"), py::arg("grad_input_transpose"),
-py::arg("scale"), py::arg("amax"), py::arg("scale_inv"),
-py::arg("otype"), py::arg("scale_offset") = 0,
+py::call_guard<py::gil_scoped_release>(), py::arg("grad_output"), py::arg("input"),
+py::arg("grad_input"), py::arg("grad_input_transpose"), py::arg("scale"), py::arg("amax"),
+py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0,
 py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
 m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose,
 "Fused Multi-tensor Cast + Transpose", py::call_guard<py::gil_scoped_release>());
61 changes: 25 additions & 36 deletions transformer_engine/pytorch/csrc/extensions/transpose.cu
@@ -196,17 +196,11 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output,
 return {grad_bias, dgelu, dgelu_transpose};
 }

-void fused_dswiglu_cast_transpose(at::Tensor grad_output,
-at::Tensor input,
-at::Tensor grad_input,
-at::Tensor grad_input_transpose,
-at::Tensor scale,
-at::Tensor amax,
-at::Tensor scale_inv,
-transformer_engine::DType otype,
-int scale_offset,
-int amax_offset,
-int scale_inv_offset) {
+void fused_dswiglu_cast_transpose(at::Tensor grad_output, at::Tensor input, at::Tensor grad_input,
+at::Tensor grad_input_transpose, at::Tensor scale,
+at::Tensor amax, at::Tensor scale_inv,
+transformer_engine::DType otype, int scale_offset,
+int amax_offset, int scale_inv_offset) {
 using namespace transformer_engine;

 // Tensor dimensions
@@ -217,33 +211,28 @@ void fused_dswiglu_cast_transpose(at::Tensor grad_output,
 const auto N = static_cast<size_t>(grad_output.size(-1));

 // Check tensor dims
-NVTE_CHECK(grad_output.dim() == 2,
-"Expected grad output tensor to have 2 dims, but found ", grad_output.dim());
-NVTE_CHECK(input.dim() == 2,
-"Expected input tensor to have 2 dims, but found ", input.dim());
-NVTE_CHECK(outer_dim(input) == M,
-"Expected input tensor to have outer dimension of ",
-M, ", but found ", outer_dim(input));
-NVTE_CHECK(input.size(-1) == 2*N,
-"Expected input tensor to have inner dimension of ",
-2*N, ", but found ", input.size(-1));
-NVTE_CHECK(grad_input.dim() == 2,
-"Expected grad input tensor to have 2 dims, but found ", grad_input.dim());
-NVTE_CHECK(outer_dim(grad_input) == M,
-"Expected grad input tensor to have outer dimension of ",
+NVTE_CHECK(grad_output.dim() == 2, "Expected grad output tensor to have 2 dims, but found ",
+grad_output.dim());
+NVTE_CHECK(input.dim() == 2, "Expected input tensor to have 2 dims, but found ", input.dim());
+NVTE_CHECK(outer_dim(input) == M, "Expected input tensor to have outer dimension of ", M,
+", but found ", outer_dim(input));
+NVTE_CHECK(input.size(-1) == 2 * N, "Expected input tensor to have inner dimension of ", 2 * N,
+", but found ", input.size(-1));
+NVTE_CHECK(grad_input.dim() == 2, "Expected grad input tensor to have 2 dims, but found ",
+grad_input.dim());
+NVTE_CHECK(outer_dim(grad_input) == M, "Expected grad input tensor to have outer dimension of ",
 M, ", but found ", outer_dim(grad_input));
-NVTE_CHECK(grad_input.size(-1) == 2*N,
-"Expected grad input tensor to have inner dimension of ",
-2*N, ", but found ", grad_input.size(-1));
+NVTE_CHECK(grad_input.size(-1) == 2 * N, "Expected grad input tensor to have inner dimension of ",
+2 * N, ", but found ", grad_input.size(-1));
 NVTE_CHECK(grad_input_transpose.dim() == 2,
 "Expected grad input transpose tensor to have 2 dims, but found ",
 grad_input_transpose.dim());
-NVTE_CHECK(grad_input_transpose.size(0) == 2*N,
-"Expected grad input tensor to have outer dimension of ",
-2*N, ", but found ", grad_input_transpose.size(0));
+NVTE_CHECK(grad_input_transpose.size(0) == 2 * N,
+"Expected grad input tensor to have outer dimension of ", 2 * N, ", but found ",
+grad_input_transpose.size(0));
 NVTE_CHECK(grad_input_transpose.size(1) == M,
-"Expected grad input tensor to have outer dimension of ",
-M, ", but found ", grad_input_transpose.size(1));
+"Expected grad input tensor to have outer dimension of ", M, ", but found ",
+grad_input_transpose.size(1));

 // Check tensor format
 NVTE_CHECK(grad_output.is_contiguous(), "Expected grad output tensor to be contiguous");
@@ -266,10 +255,10 @@ void fused_dswiglu_cast_transpose(at::Tensor grad_output,
 // Construct Transformer Engine tensors
 auto dy_cu = makeTransformerEngineTensor(grad_output);
 auto x_cu = makeTransformerEngineTensor(input);
-auto dx_cu = makeTransformerEngineTensor(grad_input.data_ptr(), {M, 2*N}, otype, amax_dptr,
+auto dx_cu = makeTransformerEngineTensor(grad_input.data_ptr(), {M, 2 * N}, otype, amax_dptr,
 scale_dptr, scale_inv_dptr);
-auto dx_t_cu = makeTransformerEngineTensor(grad_input_transpose.data_ptr(), {2*N, M},
-otype, amax_dptr, scale_dptr, scale_inv_dptr);
+auto dx_t_cu = makeTransformerEngineTensor(grad_input_transpose.data_ptr(), {2 * N, M}, otype,
+amax_dptr, scale_dptr, scale_inv_dptr);

 // Launch kernel
 nvte_dswiglu_cast_transpose(dy_cu.data(), x_cu.data(), dx_cu.data(), dx_t_cu.data(),
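
Aside: the shape checks in this file pin down the calling convention for the binding registered above — grad_output is [M, N], input and grad_input are [M, 2*N], grad_input_transpose is [2*N, M], and scale/amax/scale_inv carry the FP8 scaling state. A minimal Python usage sketch follows; the import name transformer_engine_torch, the DType enum value, and uint8 storage for the FP8 outputs are assumptions about the build, not part of this commit.

# Hypothetical sketch only: module name, dtype enum, and FP8 byte storage are assumed.
import torch
import transformer_engine_torch as tex  # assumed import name for the pybind11 extension

M, N = 128, 256
grad_output = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")       # upstream gradient
swiglu_input = torch.randn(M, 2 * N, dtype=torch.bfloat16, device="cuda")  # forward input (two gates side by side)

# FP8 outputs plus scaling metadata, shaped to satisfy the NVTE_CHECKs above.
grad_input = torch.empty(M, 2 * N, dtype=torch.uint8, device="cuda")
grad_input_t = torch.empty(2 * N, M, dtype=torch.uint8, device="cuda")
scale = torch.ones(1, dtype=torch.float32, device="cuda")
amax = torch.zeros(1, dtype=torch.float32, device="cuda")
scale_inv = torch.ones(1, dtype=torch.float32, device="cuda")

tex.fused_dswiglu_cast_transpose(grad_output, swiglu_input, grad_input, grad_input_t,
                                 scale, amax, scale_inv, tex.DType.kFloat8E4M3)

The three trailing offset arguments are left at their defaults (0), matching the py::arg defaults registered in pybind.cpp.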
6 changes: 1 addition & 5 deletions transformer_engine/pytorch/ops/basic/activation.py
@@ -97,11 +97,7 @@ def op_forward(
 output_fp8_meta = None
 output_dtype = TE_DType[dtype]
 output_fp8_scale_inv = None
-if (
-fp8_enabled
-and next_op is not None
-and next_op.num_fp8_scales("input") > 0
-):
+if fp8_enabled and next_op is not None and next_op.num_fp8_scales("input") > 0:
 with_fp8_output = True
 fp8_meta = next_op.get_fp8_meta("input")
 fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)
