Add support for fused dSwiGLU-cast-transpose
Signed-off-by: Tim Moon <tmoon@nvidia.com>
timmoon10 committed Sep 27, 2024
1 parent 6db1ddb commit ada6804
Showing 6 changed files with 299 additions and 1 deletion.
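For reference, the operation being fused is the SwiGLU backward pass combined with an FP8 cast and an FP8 transpose of the resulting gradient. A minimal unfused sketch of the gradient math in plain PyTorch, mirroring the reference implementation in the new test (the helper name and shapes are illustrative):

import torch

def dswiglu_reference(dy: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Unfused SwiGLU backward: x has shape [..., 2*N], dy has shape [..., N]."""
    x1, x2 = x.chunk(2, dim=-1)          # gate half and linear half
    sig = torch.sigmoid(x1)
    silu = x1 * sig                      # silu(x1)
    dsilu = sig * (1 + x1 * (1 - sig))   # d/dx1 silu(x1)
    dx1 = dy * x2 * dsilu                # gradient w.r.t. gate half
    dx2 = dy * silu                      # gradient w.r.t. linear half
    return torch.cat([dx1, dx2], dim=-1)

The fused kernel additionally casts this gradient to FP8 and writes a transposed FP8 copy in the same pass, instead of launching separate cast and transpose kernels.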
71 changes: 71 additions & 0 deletions tests/pytorch/test_fusible_ops.py
@@ -940,6 +940,77 @@ def test_activation(
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)

@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("fp8_output", (False, True))
@pytest.mark.parametrize("fp8_grad_input", (False, True))
def test_swiglu(
self,
*,
out_shape: Iterable[int] = (16, 16),
dtype: torch.dtype,
device: torch.device = "cuda",
fp8_output: bool,
fp8_grad_input: bool,
):

# Tensor dimensions
in_shape = list(out_shape)
in_shape[-1] *= 2

# Skip invalid configurations
fp8 = fp8_output or fp8_grad_input
if fp8:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")

# FP8 recipe
fp8_recipe = None
if fp8_grad_input:
fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)

# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)

# Plain PyTorch implementation
x1, x2 = x_ref.chunk(2, dim=-1)
y_ref = torch.nn.functional.silu(x1) * x2
y_ref.backward(dy_ref)

# Implementation with fusible operation
forward = te_ops.Sequential(
te_ops.Quantize(forward=False, backward=fp8_grad_input),
te_ops.SwiGLU(),
te_ops.Quantize(forward=fp8_output, backward=False),
)
with te.fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe):
y_test = forward(x_test)
y_test.backward(dy_test)

# Expected numerical error
tols = dtype_tols(dtype)
if fp8:
tols = dtype_tols(tex.DType.kFloat8E4M3)

# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)


class TestFusedOps:
"""Tests for fused operations"""
39 changes: 39 additions & 0 deletions transformer_engine/pytorch/cpp_extensions/transpose.py
@@ -16,6 +16,7 @@
"fp8_cast_transpose_fused",
"fp8_cast_transpose_bgrad_fused",
"fp8_cast_transpose_bgrad_dgelu_fused",
"fp8_dswiglu_cast_transpose_fused",
"fp8_multi_cast_transpose_fused",
"fp8_transpose_bgrad_fused",
]
@@ -168,6 +169,44 @@ def fp8_cast_transpose_bgrad_dgelu_fused(
    )


def fp8_dswiglu_cast_transpose_fused(
    grad_output: torch.Tensor,
    inp: torch.Tensor,
    *,
    grad_input: torch.Tensor,
    grad_input_transpose: torch.Tensor,
    otype: tex.DType,
    fp8_meta: Optional[tex.FP8TensorMeta] = None,
    fp8_meta_index: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None] = None,
    scale: Optional[torch.Tensor] = None,
    amax: Optional[torch.Tensor] = None,
    scale_inv: Optional[torch.Tensor] = None,
) -> None:
    """Fused SwiGLU backward + FP8 cast + FP8 transpose"""

    # Get FP8 scaling factors
    fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
        scale=scale,
        amax=amax,
        scale_inv=scale_inv,
        fp8_meta=fp8_meta,
        fp8_meta_index=fp8_meta_index,
    )

    # Launch kernel
    return tex.fused_dswiglu_cast_transpose(
        grad_output,
        inp,
        grad_input,
        grad_input_transpose,
        fp8_scales["scale"],
        fp8_scales["amax"],
        fp8_scales["scale_inv"],
        otype,
        **fp8_scales_offsets,
    )


def fp8_multi_cast_transpose_fused(
    input_list: List[torch.Tensor],
    fp8_meta_tensor: tex.FP8TensorMeta,
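A usage sketch for the new wrapper above: the call takes an [M, N] grad output, an [M, 2*N] input, and preallocated uint8 buffers for the FP8 grad input and its transpose. The sizes, the E4M3 choice, and the standalone scale/amax/scale_inv tensors below are illustrative; the backward pass in ops/basic/activation.py instead passes an fp8_meta object and index.

import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions.transpose import fp8_dswiglu_cast_transpose_fused

M, N = 32, 64                                                   # illustrative sizes
dy = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)     # grad w.r.t. SwiGLU output
x = torch.randn(M, 2 * N, device="cuda", dtype=torch.bfloat16)  # SwiGLU input
dx = torch.empty(M, 2 * N, dtype=torch.uint8, device="cuda")    # FP8 grad input buffer
dx_t = torch.empty(2 * N, M, dtype=torch.uint8, device="cuda")  # FP8 grad input transpose buffer

# Standalone FP8 scaling factors (normally owned by an FP8 meta object)
scale = torch.ones(1, dtype=torch.float32, device="cuda")
amax = torch.zeros(1, dtype=torch.float32, device="cuda")
scale_inv = torch.ones(1, dtype=torch.float32, device="cuda")

fp8_dswiglu_cast_transpose_fused(
    dy,
    x,
    grad_input=dx,
    grad_input_transpose=dx_t,
    otype=tex.DType.kFloat8E4M3,
    scale=scale,
    amax=amax,
    scale_inv=scale_inv,
)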
12 changes: 12 additions & 0 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -208,6 +208,18 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output,
int scale_offset = 0, int amax_offset = 0,
int scale_inv_offset = 0);

void fused_dswiglu_cast_transpose(at::Tensor grad_output,
at::Tensor input,
at::Tensor grad_input,
at::Tensor grad_input_transpose,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype,
int scale_offset = 0,
int amax_offset = 0,
int scale_inv_offset = 0);

void fused_multi_cast_transpose(std::vector<at::Tensor> input_list,
std::vector<at::Tensor> scale_list,
std::vector<at::Tensor> cast_output_list,
8 changes: 8 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -88,6 +88,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("grad_output"), py::arg("gelu_input"), py::arg("scale"), py::arg("amax"),
py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0,
py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
m.def("fused_dswiglu_cast_transpose", &fused_dswiglu_cast_transpose,
"Fused SwiGLU backward + FP8 cast + FP8 transpose",
py::call_guard<py::gil_scoped_release>(),
py::arg("grad_output"), py::arg("input"),
py::arg("grad_input"), py::arg("grad_input_transpose"),
py::arg("scale"), py::arg("amax"), py::arg("scale_inv"),
py::arg("otype"), py::arg("scale_offset") = 0,
py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose,
"Fused Multi-tensor Cast + Transpose", py::call_guard<py::gil_scoped_release>());
m.def("fused_multi_cast_transpose_alloc", &fused_multi_cast_transpose_alloc,
80 changes: 80 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/transpose.cu
@@ -196,6 +196,86 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output,
return {grad_bias, dgelu, dgelu_transpose};
}

void fused_dswiglu_cast_transpose(at::Tensor grad_output,
at::Tensor input,
at::Tensor grad_input,
at::Tensor grad_input_transpose,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype,
int scale_offset,
int amax_offset,
int scale_inv_offset) {
using namespace transformer_engine;

// Tensor dimensions
auto outer_dim = [](const at::Tensor& tensor) -> size_t {
return tensor.numel() / tensor.size(-1);
};
const auto M = outer_dim(grad_output);
const auto N = static_cast<size_t>(grad_output.size(-1));

// Check tensor dims
NVTE_CHECK(grad_output.dim() == 2,
"Expected grad output tensor to have 2 dims, but found ", grad_output.dim());
NVTE_CHECK(input.dim() == 2,
"Expected input tensor to have 2 dims, but found ", input.dim());
NVTE_CHECK(outer_dim(input) == M,
"Expected input tensor to have outer dimension of ",
M, ", but found ", outer_dim(input));
NVTE_CHECK(input.size(-1) == 2*N,
"Expected input tensor to have inner dimension of ",
2*N, ", but found ", input.size(-1));
NVTE_CHECK(grad_input.dim() == 2,
"Expected grad input tensor to have 2 dims, but found ", grad_input.dim());
NVTE_CHECK(outer_dim(grad_input) == M,
"Expected grad input tensor to have outer dimension of ",
M, ", but found ", outer_dim(grad_input));
NVTE_CHECK(grad_input.size(-1) == 2*N,
"Expected grad input tensor to have inner dimension of ",
2*N, ", but found ", grad_input.size(-1));
NVTE_CHECK(grad_input_transpose.dim() == 2,
"Expected grad input transpose tensor to have 2 dims, but found ",
grad_input_transpose.dim());
NVTE_CHECK(grad_input_transpose.size(0) == 2*N,
"Expected grad input tensor to have outer dimension of ",
2*N, ", but found ", grad_input_transpose.size(0));
NVTE_CHECK(grad_input_transpose.size(1) == M,
"Expected grad input tensor to have outer dimension of ",
M, ", but found ", grad_input_transpose.size(1));

// Check tensor format
NVTE_CHECK(grad_output.is_contiguous(), "Expected grad output tensor to be contiguous");
NVTE_CHECK(input.is_contiguous(), "Expected input tensor to be contiguous");
NVTE_CHECK(grad_input.is_contiguous(), "Expected grad input tensor to be contiguous");
NVTE_CHECK(grad_input_transpose.is_contiguous(),
"Expected grad input transpose tensor to be contiguous");
NVTE_CHECK(grad_output.scalar_type() == input.scalar_type(),
"Expected grad output tensor and input tensor to have same dtype");
NVTE_CHECK(grad_input.scalar_type() == at::ScalarType::Byte,
"Expected grad input tensor to be uint8 buffer");
NVTE_CHECK(grad_input_transpose.scalar_type() == at::ScalarType::Byte,
"Expected grad input transpose tensor to be uint8 buffer");

// Get pointers for FP8 scale, amax, scale-inverse
void* scale_dptr = getDataPtr(scale, scale_offset);
void* amax_dptr = getDataPtr(amax, amax_offset);
void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset);

// Construct Transformer Engine tensors
auto dy_cu = makeTransformerEngineTensor(grad_output);
auto x_cu = makeTransformerEngineTensor(input);
auto dx_cu = makeTransformerEngineTensor(grad_input.data_ptr(), {M, 2*N}, otype, amax_dptr,
scale_dptr, scale_inv_dptr);
auto dx_t_cu = makeTransformerEngineTensor(grad_input_transpose.data_ptr(), {2*N, M},
otype, amax_dptr, scale_dptr, scale_inv_dptr);

// Launch kernel
nvte_dswiglu_cast_transpose(dy_cu.data(), x_cu.data(), dx_cu.data(), dx_t_cu.data(),
at::cuda::getCurrentCUDAStream());
}

void fused_multi_cast_transpose_base(std::vector<at::Tensor> input_list,
std::vector<void*> scale_dptr_list,
std::vector<at::Tensor> cast_output_list,
90 changes: 89 additions & 1 deletion transformer_engine/pytorch/ops/basic/activation.py
@@ -18,6 +18,7 @@
    reglu as tex_reglu,
    relu as tex_relu,
    swiglu as tex_swiglu,
    fp8_dswiglu_cast_transpose_fused,
)
from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype
from ...tensor import Float8Tensor, QuantizedTensor
@@ -91,12 +92,13 @@ def op_forward(
            x = x.contiguous()

        # Check if FP8 is enabled
        fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
        with_fp8_output = False
        output_fp8_meta = None
        output_dtype = TE_DType[dtype]
        output_fp8_scale_inv = None
        if (
            FP8GlobalStateManager.is_fp8_enabled()
            fp8_enabled
            and next_op is not None
            and next_op.num_fp8_scales("input") > 0
        ):
@@ -132,6 +134,7 @@

        # Save state for backward pass
        ctx.save_for_backward(x)
        ctx.fp8_enabled = fp8_enabled
        ctx.prev_op = prev_op

        return y
@@ -304,3 +307,88 @@ def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:

    def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
        return transformer_engine_torch.dswiglu(*args, **kwargs)

    def op_backward(
        self,
        ctx: OperationContext,
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, tuple[()]]:

        # Saved tensors from forward pass
        (x,) = ctx.saved_tensors

        # Tensor attributes
        dtype = x.dtype
        device = x.device

        # Check grad output tensor
        dy = grad_output
        if isinstance(dy, QuantizedTensor):
            dy = dy.dequantize()
        if not devices_match(dy.device, device) or dy.dtype != dtype:
            dy = dy.to(device=device, dtype=dtype)
        if not dy.is_contiguous():
            dy = dy.contiguous()

        # Check if FP8 is enabled
        with_fp8_grad_input = False
        grad_input_fp8_meta = None
        grad_input_dtype = TE_DType[dtype]
        grad_input_fp8_scale_inv = None
        if (
            ctx.fp8_enabled
            and ctx.prev_op is not None
            and ctx.prev_op.num_fp8_scales("grad_output") > 0
        ):
            with_fp8_grad_input = True
            fp8_meta = ctx.prev_op.get_fp8_meta("grad_output")
            fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False)
            grad_input_fp8_meta = fp8_meta[fp8_meta_key]
            grad_input_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False)
            grad_input_fp8_scale_inv = torch.empty([1], dtype=torch.float32, device=device)

        # Launch kernel
        if with_fp8_grad_input:
            # Fused with FP8 cast-transpose
            input_dims = x.size()
            flat_input_dims = [x.numel() // input_dims[-1], input_dims[-1]]
            flat_output_dims = [flat_input_dims[0], flat_input_dims[1] // 2]
            dx = torch.empty(input_dims, dtype=torch.uint8, device=device)
            dx_t = torch.empty(
                (flat_input_dims[1], flat_input_dims[0]),
                dtype=torch.uint8,
                device=device,
            )
            fp8_dswiglu_cast_transpose_fused(
                dy.reshape(flat_output_dims),
                x.reshape(flat_input_dims),
                grad_input=dx.reshape(flat_input_dims),
                grad_input_transpose=dx_t,
                otype=grad_input_dtype,
                fp8_meta=grad_input_fp8_meta,
                fp8_meta_index=0,
                scale_inv=grad_input_fp8_scale_inv,
            )
            dx = Float8Tensor(
                data=dx,
                fp8_meta=grad_input_fp8_meta,
                fp8_meta_forward=True,
                fp8_meta_index=0,
                fp8_dtype=grad_input_dtype,
                fp8_scale_inv=grad_input_fp8_scale_inv,
                dtype=dtype,
            )
            dx._transpose = dx_t
            dx._transpose_invalid = False
        else:
            # Standard impl
            dx = self._activation_backward_impl(dy, x, TE_DType[dtype])
            if dx.size() != x.size():
                dx = dx.reshape(x.size())

        # Note: This fails if op is preceded by an identity op like Quantize(forward=False)
        # # Clear input tensor if possible
        # if ctx.prev_op is not None:
        #     clear_tensor_data(x)

        return dx, ()

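For context on when the fused backward path above is taken: it requires ctx.fp8_enabled and a preceding op that owns FP8 scales for its grad_output. A minimal configuration that exercises it, mirroring the new test (requires an FP8-capable GPU; shapes are illustrative):

import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch import ops as te_ops
from transformer_engine.common import recipe

model = te_ops.Sequential(
    te_ops.Quantize(forward=False, backward=True),  # owns FP8 scales for grad_output
    te_ops.SwiGLU(),
)
x = torch.randn(16, 32, device="cuda", requires_grad=True)
fp8_recipe = recipe.DelayedScaling(fp8_format=recipe.Format.E4M3)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = model(x)                 # forward runs in the input dtype; y has shape [16, 16]
y.backward(torch.randn_like(y))  # backward goes through fp8_dswiglu_cast_transpose_fused

The resulting grad input is a Float8Tensor whose transpose is already cached (dx._transpose), so downstream FP8 operations can reuse it without recomputing a transpose.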