Commit 1014133: Merge branch 'main' into debug-mcore-norm

timmoon10 authored Nov 15, 2024
2 parents 49a8f3c + 20b0473, commit 1014133
Showing 14 changed files with 705 additions and 97 deletions.
162 changes: 161 additions & 1 deletion tests/pytorch/test_fusible_ops.py
@@ -293,7 +293,7 @@ def test_fp8_scale_update(
)

# Check that scaling factors match expected
-w_amax_ref = max(w_vals[: step + 2])
+w_amax_ref = max(w_vals[: step + 1])
x_amax_ref = max(x_vals[: step + 1])
dy_amax_ref = max(dy_vals[: step + 1])
w_scale_ref = (fp8_format.value.max_fwd / w_amax_ref) / (2**margin)
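As a side note on this check: under delayed scaling the expected scale comes from the largest amax in the recorded history (up to and including the current step) and the margin, which is what the corrected slice above reflects. A minimal sketch with hypothetical names (amax_history, with fp8_max standing in for fp8_format.value.max_fwd):

def expected_scale(amax_history, fp8_max, margin):
    # e.g. amax_history = w_vals[: step + 1] in the test above
    amax_ref = max(amax_history)
    return (fp8_max / amax_ref) / (2**margin)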
@@ -1362,6 +1362,166 @@ def test_make_extra_output(
torch.testing.assert_close(y2_test, y2_ref, rtol=0, atol=0)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)

@pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu"))
@pytest.mark.parametrize("out_shape", ((37,), (2, 13), (4, 1, 16)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("fp8_input", (False, True))
@pytest.mark.parametrize("fp8_output", (False, True))
def test_activation(
self,
*,
activation: str,
out_shape: Iterable[int],
dtype: torch.dtype,
device: torch.device = "cuda",
fp8_input: bool,
fp8_output: bool,
) -> None:
"""Activation functions"""

# Tensor dimensions
in_shape = list(out_shape)
if activation in ("geglu", "reglu", "swiglu"):
in_shape[-1] *= 2

# Skip invalid configurations
if fp8_input or fp8_output:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")

# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8_input,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)

# Plain PyTorch implementation
y_ref: torch.Tensor
if activation == "gelu":
y_ref = torch.nn.functional.gelu(x_ref, approximate="tanh")
elif activation == "relu":
y_ref = torch.nn.functional.relu(x_ref)
elif activation == "geglu":
x1, x2 = x_ref.chunk(2, dim=-1)
y_ref = torch.nn.functional.gelu(x1, approximate="tanh") * x2
elif activation == "reglu":
x1, x2 = x_ref.chunk(2, dim=-1)
y_ref = torch.nn.functional.relu(x1) * x2
elif activation == "swiglu":
x1, x2 = x_ref.chunk(2, dim=-1)
y_ref = torch.nn.functional.silu(x1) * x2
else:
raise ValueError(f"Unexpected activation function ({activation})")
y_ref.backward(dy_ref)

# Implementation with fusible operation
make_op = dict(
gelu=te_ops.GELU,
relu=te_ops.ReLU,
geglu=te_ops.GEGLU,
reglu=te_ops.ReGLU,
swiglu=te_ops.SwiGLU,
)[activation]
forward = te_ops.Sequential(
make_op(),
te_ops.Quantize(forward=fp8_output, backward=False),
)
with te.fp8_autocast(enabled=fp8_output):
y_test = forward(x_test)
y_test.backward(dy_test)

# Expected numerical error
tols = dtype_tols(dtype)
if fp8_output:
tols = dtype_tols(tex.DType.kFloat8E4M3)

# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)

@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("fp8_output", (False, True))
@pytest.mark.parametrize("fp8_grad_input", (False, True))
def test_swiglu(
self,
*,
out_shape: Iterable[int] = (16, 16),
dtype: torch.dtype,
device: torch.device = "cuda",
fp8_output: bool,
fp8_grad_input: bool,
):

# Tensor dimensions
in_shape = list(out_shape)
in_shape[-1] *= 2

# Skip invalid configurations
fp8 = fp8_output or fp8_grad_input
if fp8:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")

# FP8 recipe
fp8_recipe = None
if fp8_grad_input:
fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)

# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)

# Plain PyTorch implementation
x1, x2 = x_ref.chunk(2, dim=-1)
y_ref = torch.nn.functional.silu(x1) * x2
y_ref.backward(dy_ref)

# Implementation with fusible operation
forward = te_ops.Sequential(
te_ops.Quantize(forward=False, backward=fp8_grad_input),
te_ops.SwiGLU(),
te_ops.Quantize(forward=fp8_output, backward=False),
)
with te.fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe):
y_test = forward(x_test)
y_test.backward(dy_test)

# Expected numerical error
tols = dtype_tols(dtype)
if fp8:
tols = dtype_tols(tex.DType.kFloat8E4M3)

# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)


class TestFusedOps:
"""Tests for fused operations"""
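For context, the operation-fusion pattern exercised by these new tests can be used directly. A minimal sketch, assuming Transformer Engine is installed with FP8-capable hardware and that the import paths below are available; shapes and dtypes are illustrative:

import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch import ops as te_ops

# SwiGLU halves the last dimension: input [16, 32] -> output [16, 16]
x = torch.randn(16, 32, dtype=torch.bfloat16, device="cuda", requires_grad=True)
dy = torch.randn(16, 16, dtype=torch.bfloat16, device="cuda")
forward = te_ops.Sequential(
    te_ops.SwiGLU(),
    te_ops.Quantize(forward=True, backward=False),  # FP8 output, higher-precision grads
)
with te.fp8_autocast(enabled=True):
    y = forward(x)
y.backward(dy)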
4 changes: 2 additions & 2 deletions transformer_engine/jax/csrc/extensions/activation.cpp
@@ -264,8 +264,8 @@ Error_Type DActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act
auto *output = output_buf->untyped_data();

auto act_input_dims = act_input_buf.dimensions();
-auto m = product(act_input_dims, 0, act_input_dims.size() - 2);
-auto n = act_input_dims.back();
+auto m = static_cast<size_t>(product(act_input_dims, 0, act_input_dims.size() - 2));
+auto n = static_cast<size_t>(act_input_dims.back());
auto act_len = act_input_dims.end()[-2];

auto input_shape = std::vector<size_t>{m, n};
35 changes: 18 additions & 17 deletions transformer_engine/pytorch/attention.py
@@ -2528,12 +2528,13 @@ def backward(ctx, dout):
recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

-(q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = ctx.saved_tensors[:6]
-(fp8_fwd_scales, fp8_fwd_scale_invs) = ctx.saved_tensors[6:8]
-cu_seqlens_q_per_step = ctx.saved_tensors[8 : 8 + cp_size]
-cu_seqlens_kv_per_step = ctx.saved_tensors[8 + cp_size : 8 + cp_size * 2]
-rng_states = ctx.saved_tensors[8 + cp_size * 2 : 8 + cp_size * 3]
-attn_biases = ctx.saved_tensors[8 + cp_size * 3 : 8 + cp_size * 4]
+(*saved_tensors,) = ctx.saved_tensors
+(q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = saved_tensors[:6]
+(fp8_fwd_scales, fp8_fwd_scale_invs) = saved_tensors[6:8]
+cu_seqlens_q_per_step = saved_tensors[8 : 8 + cp_size]
+cu_seqlens_kv_per_step = saved_tensors[8 + cp_size : 8 + cp_size * 2]
+rng_states = saved_tensors[8 + cp_size * 2 : 8 + cp_size * 3]
+attn_biases = saved_tensors[8 + cp_size * 3 : 8 + cp_size * 4]

causal = "causal" in ctx.attn_mask_type
padding = "padding" in ctx.attn_mask_type
@@ -3577,11 +3578,12 @@ def backward(ctx, dout):
cp_size = get_distributed_world_size(ctx.cp_group)
rank = get_distributed_rank(ctx.cp_group)

-(q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = ctx.saved_tensors[:5]
-cu_seqlens_kv_per_step = ctx.saved_tensors[5:7]
-out_per_step = ctx.saved_tensors[7:9]
-softmax_lse_per_step = ctx.saved_tensors[9:11]
-rng_states = ctx.saved_tensors[11:13]
+(*saved_tensors,) = ctx.saved_tensors
+(q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = saved_tensors[:5]
+cu_seqlens_kv_per_step = saved_tensors[5:7]
+out_per_step = saved_tensors[7:9]
+softmax_lse_per_step = saved_tensors[9:11]
+rng_states = saved_tensors[11:13]
kv_seq_range_per_step = ctx.kv_seq_range_per_step
window_size_per_step = ctx.window_size_per_step

@@ -4056,12 +4058,11 @@ def backward(ctx, dout):
# pylint: disable=missing-function-docstring
cp_size = get_distributed_world_size(ctx.cp_group)

-q, k, v, out = ctx.saved_tensors[:4]
-cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded = ctx.saved_tensors[
-    4:8
-]
-fp8_fwd_scales, fp8_fwd_scale_invs = ctx.saved_tensors[8:10]
-aux_ctx_tensors = ctx.saved_tensors[10:]
+(*saved_tensors,) = ctx.saved_tensors
+q, k, v, out = saved_tensors[:4]
+cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded = saved_tensors[4:8]
+fp8_fwd_scales, fp8_fwd_scale_invs = saved_tensors[8:10]
+aux_ctx_tensors = saved_tensors[10:]

qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
causal = "causal" in ctx.attn_mask_type
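The three changes above replace repeated indexing into ctx.saved_tensors with a single unpack into a local list that is then sliced. The same pattern in a toy autograd function, as a sketch with illustrative names only:

import torch

class _Toy(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b, c):
        ctx.save_for_backward(a, b, c)
        return a * b + c

    @staticmethod
    def backward(ctx, grad_out):
        # Unpack the saved tensors once, then slice the local list
        # instead of indexing ctx.saved_tensors on every access.
        (*saved_tensors,) = ctx.saved_tensors
        a, b = saved_tensors[:2]
        c = saved_tensors[2]  # kept only to mirror the slicing style
        return grad_out * b, grad_out * a, grad_out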
39 changes: 39 additions & 0 deletions transformer_engine/pytorch/cpp_extensions/transpose.py
@@ -16,6 +16,7 @@
"fp8_cast_transpose_fused",
"fp8_cast_transpose_bgrad_fused",
"fp8_cast_transpose_bgrad_dgelu_fused",
"fp8_dswiglu_cast_transpose_fused",
"fp8_multi_cast_transpose_fused",
"fp8_transpose_bgrad_fused",
]
@@ -168,6 +169,44 @@ def fp8_cast_transpose_bgrad_dgelu_fused(
)


def fp8_dswiglu_cast_transpose_fused(
grad_output: torch.Tensor,
inp: torch.Tensor,
*,
grad_input: torch.Tensor,
grad_input_transpose: torch.Tensor,
otype: tex.DType,
fp8_meta: Optional[tex.FP8TensorMeta] = None,
fp8_meta_index: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None] = None,
scale: Optional[torch.Tensor] = None,
amax: Optional[torch.Tensor] = None,
scale_inv: Optional[torch.Tensor] = None,
) -> None:
"""Fused SwiGLU backward + FP8 cast + FP8 transpose"""

# Get FP8 scaling factors
fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
scale=scale,
amax=amax,
scale_inv=scale_inv,
fp8_meta=fp8_meta,
fp8_meta_index=fp8_meta_index,
)

# Launch kernel
return tex.fused_dswiglu_cast_transpose(
grad_output,
inp,
grad_input,
grad_input_transpose,
fp8_scales["scale"],
fp8_scales["amax"],
fp8_scales["scale_inv"],
otype,
**fp8_scales_offsets,
)


def fp8_multi_cast_transpose_fused(
input_list: List[torch.Tensor],
fp8_meta_tensor: tex.FP8TensorMeta,
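A minimal call sketch for the new wrapper, assuming the import paths below and an FP8-capable CUDA device; the pre-allocated uint8 buffers and the single-element scale/amax/scale_inv tensors are illustrative stand-ins for buffers that would normally come from an FP8TensorMeta:

import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import fp8_dswiglu_cast_transpose_fused

M, N = 16, 16
dy = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")      # grad w.r.t. SwiGLU output
x = torch.randn(M, 2 * N, dtype=torch.bfloat16, device="cuda")   # SwiGLU input
dx = torch.empty(M, 2 * N, dtype=torch.uint8, device="cuda")     # FP8 grad-input buffer
dx_t = torch.empty(2 * N, M, dtype=torch.uint8, device="cuda")   # FP8 transposed grad-input buffer

fp8_dswiglu_cast_transpose_fused(
    dy,
    x,
    grad_input=dx,
    grad_input_transpose=dx_t,
    otype=tex.DType.kFloat8E4M3,
    scale=torch.ones(1, dtype=torch.float32, device="cuda"),
    amax=torch.zeros(1, dtype=torch.float32, device="cuda"),
    scale_inv=torch.ones(1, dtype=torch.float32, device="cuda"),
)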
6 changes: 6 additions & 0 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -210,6 +210,12 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output,
int scale_offset = 0, int amax_offset = 0,
int scale_inv_offset = 0);

void fused_dswiglu_cast_transpose(at::Tensor grad_output, at::Tensor input, at::Tensor grad_input,
at::Tensor grad_input_transpose, at::Tensor scale,
at::Tensor amax, at::Tensor scale_inv,
transformer_engine::DType otype, int scale_offset = 0,
int amax_offset = 0, int scale_inv_offset = 0);

void fused_multi_cast_transpose(std::vector<at::Tensor> input_list,
std::vector<at::Tensor> scale_list,
std::vector<at::Tensor> cast_output_list,
6 changes: 6 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -91,6 +91,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("grad_output"), py::arg("gelu_input"), py::arg("scale"), py::arg("amax"),
py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0,
py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
m.def("fused_dswiglu_cast_transpose", &fused_dswiglu_cast_transpose,
"Fused SwiGLU backward + FP8 cast + FP8 transpose",
py::call_guard<py::gil_scoped_release>(), py::arg("grad_output"), py::arg("input"),
py::arg("grad_input"), py::arg("grad_input_transpose"), py::arg("scale"), py::arg("amax"),
py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0,
py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose,
"Fused Multi-tensor Cast + Transpose", py::call_guard<py::gil_scoped_release>());
m.def("fused_multi_cast_transpose_alloc", &fused_multi_cast_transpose_alloc,
69 changes: 69 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/transpose.cpp
@@ -196,6 +196,75 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output,
return {grad_bias, dgelu, dgelu_transpose};
}

void fused_dswiglu_cast_transpose(at::Tensor grad_output, at::Tensor input, at::Tensor grad_input,
at::Tensor grad_input_transpose, at::Tensor scale,
at::Tensor amax, at::Tensor scale_inv,
transformer_engine::DType otype, int scale_offset,
int amax_offset, int scale_inv_offset) {
using namespace transformer_engine;

// Tensor dimensions
auto outer_dim = [](const at::Tensor& tensor) -> size_t {
return tensor.numel() / tensor.size(-1);
};
const auto M = outer_dim(grad_output);
const auto N = static_cast<size_t>(grad_output.size(-1));

// Check tensor dims
NVTE_CHECK(grad_output.dim() == 2, "Expected grad output tensor to have 2 dims, but found ",
grad_output.dim());
NVTE_CHECK(input.dim() == 2, "Expected input tensor to have 2 dims, but found ", input.dim());
NVTE_CHECK(outer_dim(input) == M, "Expected input tensor to have outer dimension of ", M,
", but found ", outer_dim(input));
NVTE_CHECK(input.size(-1) == 2 * N, "Expected input tensor to have inner dimension of ", 2 * N,
", but found ", input.size(-1));
NVTE_CHECK(grad_input.dim() == 2, "Expected grad input tensor to have 2 dims, but found ",
grad_input.dim());
NVTE_CHECK(outer_dim(grad_input) == M, "Expected grad input tensor to have outer dimension of ",
M, ", but found ", outer_dim(grad_input));
NVTE_CHECK(grad_input.size(-1) == 2 * N, "Expected grad input tensor to have inner dimension of ",
2 * N, ", but found ", grad_input.size(-1));
NVTE_CHECK(grad_input_transpose.dim() == 2,
"Expected grad input transpose tensor to have 2 dims, but found ",
grad_input_transpose.dim());
NVTE_CHECK(grad_input_transpose.size(0) == 2 * N,
"Expected grad input tensor to have outer dimension of ", 2 * N, ", but found ",
grad_input_transpose.size(0));
NVTE_CHECK(grad_input_transpose.size(1) == M,
"Expected grad input tensor to have outer dimension of ", M, ", but found ",
grad_input_transpose.size(1));

// Check tensor format
NVTE_CHECK(grad_output.is_contiguous(), "Expected grad output tensor to be contiguous");
NVTE_CHECK(input.is_contiguous(), "Expected input tensor to be contiguous");
NVTE_CHECK(grad_input.is_contiguous(), "Expected grad input tensor to be contiguous");
NVTE_CHECK(grad_input_transpose.is_contiguous(),
"Expected grad input transpose tensor to be contiguous");
NVTE_CHECK(grad_output.scalar_type() == input.scalar_type(),
"Expected grad output tensor and input tensor to have same dtype");
NVTE_CHECK(grad_input.scalar_type() == at::ScalarType::Byte,
"Expected grad input tensor to be uint8 buffer");
NVTE_CHECK(grad_input_transpose.scalar_type() == at::ScalarType::Byte,
"Expected grad input transpose tensor to be uint8 buffer");

// Get pointers for FP8 scale, amax, scale-inverse
void* scale_dptr = getDataPtr(scale, scale_offset);
void* amax_dptr = getDataPtr(amax, amax_offset);
void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset);

// Construct Transformer Engine tensors
auto dy_cu = makeTransformerEngineTensor(grad_output);
auto x_cu = makeTransformerEngineTensor(input);
auto dx_cu = makeTransformerEngineTensor(grad_input.data_ptr(), {M, 2 * N}, otype, amax_dptr,
scale_dptr, scale_inv_dptr);
auto dx_t_cu = makeTransformerEngineTensor(grad_input_transpose.data_ptr(), {2 * N, M}, otype,
amax_dptr, scale_dptr, scale_inv_dptr);

// Launch kernel
nvte_dswiglu_cast_transpose(dy_cu.data(), x_cu.data(), dx_cu.data(), dx_t_cu.data(),
at::cuda::getCurrentCUDAStream());
}

void fused_multi_cast_transpose_base(std::vector<at::Tensor> input_list,
std::vector<void*> scale_dptr_list,
std::vector<at::Tensor> cast_output_list,
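For reference, the gradient that this fused backward casts and transposes corresponds to the standard SwiGLU derivative: with y = silu(x1) * x2, dL/dx1 = dy * x2 * silu'(x1) and dL/dx2 = dy * silu(x1). A plain-PyTorch sketch (dswiglu_reference is a hypothetical name used only here):

import torch

def dswiglu_reference(dy: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Unfused reference for the SwiGLU backward (sketch only).
    x1, x2 = x.chunk(2, dim=-1)
    sig = torch.sigmoid(x1)
    dsilu = sig * (1 + x1 * (1 - sig))  # d/dx silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
    dx1 = dy * x2 * dsilu
    dx2 = dy * torch.nn.functional.silu(x1)
    return torch.cat([dx1, dx2], dim=-1)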