[PyTorch] Activation operations #1164

Merged · 17 commits · Nov 15, 2024
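
This PR adds fusible activation operations to the PyTorch operation fuser: elementwise GELU and ReLU, plus the gated variants GEGLU, ReGLU, and SwiGLU. It also adds a fused SwiGLU-backward + FP8 cast + FP8 transpose extension and corresponding tests. A minimal usage sketch, mirroring the pattern exercised in the tests below (it assumes the usual aliases te = transformer_engine.pytorch and te_ops = transformer_engine.pytorch.ops, an FP8-capable GPU, and illustrative shapes and dtypes):

import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops

# SwiGLU halves the last dimension: input [*, 2N] -> output [*, N]
x = torch.randn(16, 32, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# Compose fusible ops; Quantize(forward=True) requests an FP8 output
model = te_ops.Sequential(
    te_ops.SwiGLU(),
    te_ops.Quantize(forward=True, backward=False),
)
with te.fp8_autocast(enabled=True):
    y = model(x)  # shape [16, 16]
y.backward(torch.randn(16, 16, device="cuda", dtype=torch.bfloat16))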
160 changes: 160 additions & 0 deletions tests/pytorch/test_fusible_ops.py
@@ -1362,6 +1362,166 @@ def test_make_extra_output(
torch.testing.assert_close(y2_test, y2_ref, rtol=0, atol=0)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)

@pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu"))
@pytest.mark.parametrize("out_shape", ((37,), (2, 13), (4, 1, 16)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("fp8_input", (False, True))
@pytest.mark.parametrize("fp8_output", (False, True))
def test_activation(
self,
*,
activation: str,
out_shape: Iterable[int],
dtype: torch.dtype,
device: torch.device = "cuda",
fp8_input: bool,
fp8_output: bool,
) -> None:
"""Activation functions"""

# Tensor dimensions
in_shape = list(out_shape)
if activation in ("geglu", "reglu", "swiglu"):
in_shape[-1] *= 2

# Skip invalid configurations
if fp8_input or fp8_output:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")

# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8_input,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)

# Plain PyTorch implementation
y_ref: torch.Tensor
if activation == "gelu":
y_ref = torch.nn.functional.gelu(x_ref, approximate="tanh")
elif activation == "relu":
y_ref = torch.nn.functional.relu(x_ref)
elif activation == "geglu":
x1, x2 = x_ref.chunk(2, dim=-1)
y_ref = torch.nn.functional.gelu(x1, approximate="tanh") * x2
elif activation == "reglu":
x1, x2 = x_ref.chunk(2, dim=-1)
y_ref = torch.nn.functional.relu(x1) * x2
elif activation == "swiglu":
x1, x2 = x_ref.chunk(2, dim=-1)
y_ref = torch.nn.functional.silu(x1) * x2
else:
raise ValueError(f"Unexpected activation function ({activation})")
y_ref.backward(dy_ref)

# Implementation with fusible operation
make_op = dict(
gelu=te_ops.GELU,
relu=te_ops.ReLU,
geglu=te_ops.GEGLU,
reglu=te_ops.ReGLU,
swiglu=te_ops.SwiGLU,
)[activation]
forward = te_ops.Sequential(
make_op(),
te_ops.Quantize(forward=fp8_output, backward=False),
)
with te.fp8_autocast(enabled=fp8_output):
y_test = forward(x_test)
y_test.backward(dy_test)

# Expected numerical error
tols = dtype_tols(dtype)
if fp8_output:
tols = dtype_tols(tex.DType.kFloat8E4M3)

# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)

@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("fp8_output", (False, True))
@pytest.mark.parametrize("fp8_grad_input", (False, True))
def test_swiglu(
self,
*,
out_shape: Iterable[int] = (16, 16),
dtype: torch.dtype,
device: torch.device = "cuda",
fp8_output: bool,
fp8_grad_input: bool,
):
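"""SwiGLU, with optional FP8 quantization of the output and grad input"""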

# Tensor dimensions
in_shape = list(out_shape)
in_shape[-1] *= 2

# Skip invalid configurations
fp8 = fp8_output or fp8_grad_input
if fp8:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")

# FP8 recipe
fp8_recipe = None
if fp8_grad_input:
fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)

# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)

# Plain PyTorch implementation
x1, x2 = x_ref.chunk(2, dim=-1)
y_ref = torch.nn.functional.silu(x1) * x2
y_ref.backward(dy_ref)

# Implementation with fusible operation
forward = te_ops.Sequential(
te_ops.Quantize(forward=False, backward=fp8_grad_input),
te_ops.SwiGLU(),
te_ops.Quantize(forward=fp8_output, backward=False),
)
with te.fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe):
y_test = forward(x_test)
y_test.backward(dy_test)

# Expected numerical error
tols = dtype_tols(dtype)
if fp8:
tols = dtype_tols(tex.DType.kFloat8E4M3)

# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)


class TestFusedOps:
"""Tests for fused operations"""
39 changes: 39 additions & 0 deletions transformer_engine/pytorch/cpp_extensions/transpose.py
@@ -16,6 +16,7 @@
"fp8_cast_transpose_fused",
"fp8_cast_transpose_bgrad_fused",
"fp8_cast_transpose_bgrad_dgelu_fused",
"fp8_dswiglu_cast_transpose_fused",
"fp8_multi_cast_transpose_fused",
"fp8_transpose_bgrad_fused",
]
@@ -168,6 +169,44 @@ def fp8_cast_transpose_bgrad_dgelu_fused(
)


def fp8_dswiglu_cast_transpose_fused(
grad_output: torch.Tensor,
inp: torch.Tensor,
*,
grad_input: torch.Tensor,
grad_input_transpose: torch.Tensor,
otype: tex.DType,
fp8_meta: Optional[tex.FP8TensorMeta] = None,
fp8_meta_index: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None] = None,
scale: Optional[torch.Tensor] = None,
amax: Optional[torch.Tensor] = None,
scale_inv: Optional[torch.Tensor] = None,
) -> None:
"""Fused SwiGLU backward + FP8 cast + FP8 transpose"""

# Get FP8 scaling factors
fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
scale=scale,
amax=amax,
scale_inv=scale_inv,
fp8_meta=fp8_meta,
fp8_meta_index=fp8_meta_index,
)

# Launch kernel
return tex.fused_dswiglu_cast_transpose(
grad_output,
inp,
grad_input,
grad_input_transpose,
fp8_scales["scale"],
fp8_scales["amax"],
fp8_scales["scale_inv"],
otype,
**fp8_scales_offsets,
)


def fp8_multi_cast_transpose_fused(
input_list: List[torch.Tensor],
fp8_meta_tensor: tex.FP8TensorMeta,
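
For reference, a minimal sketch of how the new fp8_dswiglu_cast_transpose_fused wrapper might be called. The buffer shapes and uint8 dtypes follow the dimension and dtype checks in transpose.cpp below; the explicit scale/amax/scale_inv tensors and the transformer_engine_torch alias are assumptions for illustration (an fp8_meta / fp8_meta_index pair can be passed instead):

import torch
import transformer_engine_torch as tex  # extension module alias (assumption)
from transformer_engine.pytorch.cpp_extensions import fp8_dswiglu_cast_transpose_fused

M, N = 16, 16
device, dtype = "cuda", torch.bfloat16

dy = torch.randn(M, N, dtype=dtype, device=device)     # grad w.r.t. SwiGLU output
x = torch.randn(M, 2 * N, dtype=dtype, device=device)  # SwiGLU forward input

# Pre-allocated FP8 (uint8) buffers for the grad input and its transpose
dx = torch.empty(M, 2 * N, dtype=torch.uint8, device=device)
dx_t = torch.empty(2 * N, M, dtype=torch.uint8, device=device)

# Explicit FP8 scaling factors (illustrative; amax is updated in place)
scale = torch.ones(1, dtype=torch.float32, device=device)
amax = torch.zeros(1, dtype=torch.float32, device=device)
scale_inv = torch.ones(1, dtype=torch.float32, device=device)

fp8_dswiglu_cast_transpose_fused(
    dy,
    x,
    grad_input=dx,
    grad_input_transpose=dx_t,
    otype=tex.DType.kFloat8E4M3,
    scale=scale,
    amax=amax,
    scale_inv=scale_inv,
)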
6 changes: 6 additions & 0 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -210,6 +210,12 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output,
int scale_offset = 0, int amax_offset = 0,
int scale_inv_offset = 0);

void fused_dswiglu_cast_transpose(at::Tensor grad_output, at::Tensor input, at::Tensor grad_input,
at::Tensor grad_input_transpose, at::Tensor scale,
at::Tensor amax, at::Tensor scale_inv,
transformer_engine::DType otype, int scale_offset = 0,
int amax_offset = 0, int scale_inv_offset = 0);

void fused_multi_cast_transpose(std::vector<at::Tensor> input_list,
std::vector<at::Tensor> scale_list,
std::vector<at::Tensor> cast_output_list,
6 changes: 6 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -91,6 +91,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("grad_output"), py::arg("gelu_input"), py::arg("scale"), py::arg("amax"),
py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0,
py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
m.def("fused_dswiglu_cast_transpose", &fused_dswiglu_cast_transpose,
"Fused SwiGLU backward + FP8 cast + FP8 transpose",
py::call_guard<py::gil_scoped_release>(), py::arg("grad_output"), py::arg("input"),
py::arg("grad_input"), py::arg("grad_input_transpose"), py::arg("scale"), py::arg("amax"),
py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0,
py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose,
"Fused Multi-tensor Cast + Transpose", py::call_guard<py::gil_scoped_release>());
m.def("fused_multi_cast_transpose_alloc", &fused_multi_cast_transpose_alloc,
69 changes: 69 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/transpose.cpp
@@ -196,6 +196,75 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output,
return {grad_bias, dgelu, dgelu_transpose};
}

void fused_dswiglu_cast_transpose(at::Tensor grad_output, at::Tensor input, at::Tensor grad_input,
at::Tensor grad_input_transpose, at::Tensor scale,
at::Tensor amax, at::Tensor scale_inv,
transformer_engine::DType otype, int scale_offset,
int amax_offset, int scale_inv_offset) {
using namespace transformer_engine;

// Tensor dimensions
auto outer_dim = [](const at::Tensor& tensor) -> size_t {
return tensor.numel() / tensor.size(-1);
};
const auto M = outer_dim(grad_output);
const auto N = static_cast<size_t>(grad_output.size(-1));

// Check tensor dims
NVTE_CHECK(grad_output.dim() == 2, "Expected grad output tensor to have 2 dims, but found ",
grad_output.dim());
NVTE_CHECK(input.dim() == 2, "Expected input tensor to have 2 dims, but found ", input.dim());
NVTE_CHECK(outer_dim(input) == M, "Expected input tensor to have outer dimension of ", M,
", but found ", outer_dim(input));
NVTE_CHECK(input.size(-1) == 2 * N, "Expected input tensor to have inner dimension of ", 2 * N,
", but found ", input.size(-1));
NVTE_CHECK(grad_input.dim() == 2, "Expected grad input tensor to have 2 dims, but found ",
grad_input.dim());
NVTE_CHECK(outer_dim(grad_input) == M, "Expected grad input tensor to have outer dimension of ",
M, ", but found ", outer_dim(grad_input));
NVTE_CHECK(grad_input.size(-1) == 2 * N, "Expected grad input tensor to have inner dimension of ",
2 * N, ", but found ", grad_input.size(-1));
NVTE_CHECK(grad_input_transpose.dim() == 2,
"Expected grad input transpose tensor to have 2 dims, but found ",
grad_input_transpose.dim());
NVTE_CHECK(grad_input_transpose.size(0) == 2 * N,
"Expected grad input transpose tensor to have first dimension of ", 2 * N, ", but found ",
grad_input_transpose.size(0));
NVTE_CHECK(grad_input_transpose.size(1) == M,
"Expected grad input transpose tensor to have second dimension of ", M, ", but found ",
grad_input_transpose.size(1));

// Check tensor format
NVTE_CHECK(grad_output.is_contiguous(), "Expected grad output tensor to be contiguous");
NVTE_CHECK(input.is_contiguous(), "Expected input tensor to be contiguous");
NVTE_CHECK(grad_input.is_contiguous(), "Expected grad input tensor to be contiguous");
NVTE_CHECK(grad_input_transpose.is_contiguous(),
"Expected grad input transpose tensor to be contiguous");
NVTE_CHECK(grad_output.scalar_type() == input.scalar_type(),
"Expected grad output tensor and input tensor to have same dtype");
NVTE_CHECK(grad_input.scalar_type() == at::ScalarType::Byte,
"Expected grad input tensor to be uint8 buffer");
NVTE_CHECK(grad_input_transpose.scalar_type() == at::ScalarType::Byte,
"Expected grad input transpose tensor to be uint8 buffer");

// Get pointers for FP8 scale, amax, scale-inverse
void* scale_dptr = getDataPtr(scale, scale_offset);
void* amax_dptr = getDataPtr(amax, amax_offset);
void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset);

// Construct Transformer Engine tensors
auto dy_cu = makeTransformerEngineTensor(grad_output);
auto x_cu = makeTransformerEngineTensor(input);
auto dx_cu = makeTransformerEngineTensor(grad_input.data_ptr(), {M, 2 * N}, otype, amax_dptr,
scale_dptr, scale_inv_dptr);
auto dx_t_cu = makeTransformerEngineTensor(grad_input_transpose.data_ptr(), {2 * N, M}, otype,
amax_dptr, scale_dptr, scale_inv_dptr);

// Launch kernel
nvte_dswiglu_cast_transpose(dy_cu.data(), x_cu.data(), dx_cu.data(), dx_t_cu.data(),
at::cuda::getCurrentCUDAStream());
}

void fused_multi_cast_transpose_base(std::vector<at::Tensor> input_list,
std::vector<void*> scale_dptr_list,
std::vector<at::Tensor> cast_output_list,
1 change: 1 addition & 0 deletions transformer_engine/pytorch/ops/basic/__init__.py
@@ -4,6 +4,7 @@

"""Single tensor operations supported by the operation fuser."""

from .activation import GELU, ReLU, GEGLU, ReGLU, SwiGLU
from .add_in_place import AddInPlace
from .all_gather import AllGather
from .all_reduce import AllReduce