[PyTorch] Implement Fp8 padding and unpadding module (#1129)
* [TE/PyTorch][MoE] Add FP8 padding and unpadding modules

 1. Add a multi-tensor padding kernel for FP8 with a padding alignment of 16.
 2. Add the Fp8Padding and Fp8Unpadding modules (a usage sketch follows below).
 3. Add padded GroupedLinear unit tests.
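
A minimal usage sketch of the new modules (illustrative only, not part of this change): pad each expert's token count up to a multiple of 16 so GroupedLinear can run its FP8 grouped GEMM on aligned shapes, then strip the padding afterwards. The sizes, split values, and dtype below are made-up assumptions, and an FP8-capable GPU is assumed.

```python
import torch
from transformer_engine.pytorch import (
    Fp8Padding,
    Fp8Unpadding,
    GroupedLinear,
    fp8_autocast,
)

num_gemms, hidden = 4, 768
pad = Fp8Padding(num_gemms)
unpad = Fp8Unpadding(num_gemms)
linear = GroupedLinear(
    num_gemms, hidden, hidden, bias=False, params_dtype=torch.bfloat16, device="cuda"
)

# Tokens routed to each of the 4 experts; most counts are not multiples of 16.
m_splits = [100, 300, 250, 350]
inp = torch.randn(sum(m_splits), hidden, dtype=torch.bfloat16, device="cuda")

with fp8_autocast(enabled=True):
    # Zero-pad each expert's rows to the next multiple of 16, run the grouped
    # GEMM on the padded splits, then drop the padded rows again.
    padded_inp, padded_m_splits = pad(inp, m_splits)
    padded_out = linear(padded_inp, padded_m_splits)
    out = unpad(padded_out, m_splits)  # shape: (sum(m_splits), hidden)
```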

---------

Signed-off-by: beinggod <zhangruibin@01.ai>
Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
BeingGod and phu0ngng authored Sep 5, 2024
1 parent 454e389 commit 215db88
Showing 16 changed files with 995 additions and 1 deletion.
1 change: 1 addition & 0 deletions tests/cpp/operator/CMakeLists.txt
@@ -13,6 +13,7 @@ add_executable(test_operator
test_layernorm.cu
test_rmsnorm.cu
test_multi_cast_transpose.cu
test_multi_padding.cu
test_causal_softmax.cu
../test_common.cu)

169 changes: 169 additions & 0 deletions tests/cpp/operator/test_multi_padding.cu
@@ -0,0 +1,169 @@
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <vector>
#include <cstdio>

#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>

#include <transformer_engine/padding.h>
#include "../test_common.h"

using namespace transformer_engine;

namespace {

template <typename InputType, typename OutputType>
void compute_ref(const std::vector<std::vector<InputType>>& input_list,
                 std::vector<std::vector<OutputType>>& output_list,
                 const std::vector<size_t>& height_list,
                 const std::vector<size_t>& width_list,
                 const std::vector<int>& padded_height_list) {
  using compute_t = float;
  for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
    const auto& input = input_list[tensor_id];
    auto& output = output_list[tensor_id];
    const size_t height = height_list[tensor_id];
    const size_t width = width_list[tensor_id];
    const size_t padded_height = padded_height_list[tensor_id];

    for (size_t i = 0; i < padded_height; ++i) {
      if (i < height) {
        for (size_t j = 0; j < width; ++j) {
          const compute_t x = static_cast<compute_t>(input[i * width + j]);
          const OutputType y = static_cast<OutputType>(x);
          output[i * width + j] = y;
        }
      } else {
        for (size_t j = 0; j < width; ++j) {
          output[i * width + j] = static_cast<OutputType>(0.f);
        }
      }
    }
  }
}

template <typename InputType, typename OutputType>
void performTest() {
  using namespace test;

  const DType itype = TypeInfo<InputType>::dtype;
  const DType otype = TypeInfo<OutputType>::dtype;
  const std::vector<std::pair<size_t, size_t>> tensor_dims = {{1,1},
                                                              {1,768},
                                                              {768,1},
                                                              {768,768},
                                                              {43,43},
                                                              {43,256},
                                                              {256,43},
                                                              {256,256}};
  const size_t num_tensors = tensor_dims.size();
  constexpr int align = 16;

  // Buffers for Transformer Engine implementation
  std::vector<Tensor> input_list, output_list, output_t_list;

  // Buffers for reference implementation
  std::vector<std::vector<InputType>> ref_input_list;
  std::vector<std::vector<OutputType>> ref_output_list;
  std::vector<size_t> ref_height_list(num_tensors), ref_width_list(num_tensors);
  std::vector<int> ref_padded_height_list(num_tensors);

  // Initialize buffers
  for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
    const size_t height = tensor_dims[tensor_id].first;
    const size_t width = tensor_dims[tensor_id].second;
    const size_t padded_height = (height + align - 1) / align * align;
    input_list.emplace_back(Tensor({ height, width }, itype));
    output_list.emplace_back(Tensor({ padded_height, width }, otype));

    auto& input = input_list.back();
    auto& output = output_list.back();
    fillUniform(&input);
    setRandomScale(&output);

    ref_input_list.emplace_back(height*width);
    ref_output_list.emplace_back(padded_height*width);

    std::copy(input.cpu_dptr<InputType>(),
              input.cpu_dptr<InputType>() + height * width,
              ref_input_list.back().begin());
    ref_height_list[tensor_id] = height;
    ref_width_list[tensor_id] = width;
    ref_padded_height_list[tensor_id] = padded_height;
  }

  // Transformer Engine implementation
  auto make_nvte_vector = [](std::vector<Tensor>& tensor_list)
      -> std::vector<NVTETensor> {
    std::vector<NVTETensor> nvte_tensor_list;
    for (auto& tensor : tensor_list) {
      nvte_tensor_list.emplace_back(tensor.data());
    }
    return nvte_tensor_list;
  };
  nvte_multi_padding(num_tensors,
                     make_nvte_vector(input_list).data(),
                     make_nvte_vector(output_list).data(),
                     ref_padded_height_list.data(),
                     0);
  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  // Reference implementation
  compute_ref<InputType, OutputType>(ref_input_list,
                                     ref_output_list,
                                     ref_height_list,
                                     ref_width_list,
                                     ref_padded_height_list);

  // Check correctness
  for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
    auto [atol, rtol] = getTolerances(otype);
    compareResults("output",
                   output_list[tensor_id],
                   ref_output_list[tensor_id].data(),
                   atol, rtol);
  }
}

} // namespace

class MultiPaddingTestSuite
    : public ::testing::TestWithParam<
          transformer_engine::DType> {};

TEST_P(MultiPaddingTestSuite, TestMultiPaddingTranspose) {
  using namespace transformer_engine;
  using namespace test;

  const DType input_type = GetParam();
  const DType output_type = input_type;

  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
      performTest<InputType, OutputType>();
    );
  );
}


INSTANTIATE_TEST_SUITE_P(
    OperatorTest,
    MultiPaddingTestSuite,
    ::testing::ValuesIn(test::all_fp_types),
    [](const testing::TestParamInfo<MultiPaddingTestSuite::ParamType>& info) {
      std::string name = test::typeName(info.param);
      return name;
    });
189 changes: 189 additions & 0 deletions tests/pytorch/test_numerics.py
@@ -7,6 +7,7 @@
from typing import Dict, List, Optional
import pytest
import copy
import random

import torch
import torch.nn as nn
@@ -30,6 +31,8 @@
    TransformerLayer,
    LayerNorm,
    InferenceParams,
    Fp8Padding,
    Fp8Unpadding,
)
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import fp8_gemm, fp8_grouped_gemm, gemm, grouped_gemm
@@ -354,6 +357,40 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
        return (input > 0) * input * input


class TorchGroupedLinearWithPadding(nn.Module):

    def __init__(
        self, num_gemms, in_features, out_features, bias, params_dtype, parallel_mode, fp8
    ) -> None:
        super().__init__()

        self.padding = Fp8Padding(num_gemms)
        self.linear_fn = GroupedLinear(
            num_gemms,
            in_features,
            out_features,
            bias=bias,
            params_dtype=params_dtype,
            parallel_mode=parallel_mode,
            device="cuda",
        )
        self.unpadding = Fp8Unpadding(num_gemms)

        self.fp8 = fp8

    def forward(self, inp: torch.Tensor, m_splits: List[int]) -> torch.Tensor:
        if self.fp8:
            orig_m_splits = m_splits
            inp, m_splits = self.padding(inp, m_splits)

        out = self.linear_fn(inp, m_splits)

        if self.fp8:
            out = self.unpadding(out, orig_m_splits)

        return out


_supported_act = {
    "geglu": nn.GELU(approximate="tanh"),
    "gelu": nn.GELU(approximate="tanh"),
@@ -1328,6 +1365,158 @@ def test_grouped_linear_accuracy_parallel_mode(parallel_mode):
)


def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False):

    def _pad_tensor_for_fp8(hidden_states, tokens_per_expert):
        """Padding tensor shapes to multiples of 16."""
        padded_tokens_per_expert = [
            (num_tokens + 15) // 16 * 16 for num_tokens in tokens_per_expert
        ]
        hidden_states = torch.split(hidden_states, tokens_per_expert)
        padded_hidden_states = []
        for hidden_state, actual_num_tokens, padded_num_tokens in zip(
            hidden_states, tokens_per_expert, padded_tokens_per_expert
        ):
            padded_hidden_states.append(hidden_state)
            if padded_num_tokens > actual_num_tokens:
                pad_tensor = torch.zeros(
                    padded_num_tokens - actual_num_tokens,
                    hidden_state.shape[1],
                    dtype=hidden_state.dtype,
                    device=hidden_state.device,
                )
                padded_hidden_states.append(pad_tensor)
        padded_hidden_states = torch.cat(padded_hidden_states, dim=0)
        return padded_hidden_states, padded_tokens_per_expert

    def _unpad_tensor_for_fp8(padded_hidden_states, actual_tokens_per_expert, tokens_per_expert):
        inputmats = torch.split(
            padded_hidden_states.view(-1, padded_hidden_states.shape[-1]), tokens_per_expert
        )
        hidden_states = torch.cat(
            [
                grad_output_mat[: actual_tokens_per_expert[i]]
                for i, grad_output_mat in enumerate(inputmats)
            ],
            dim=0,
        )

        return hidden_states

    def _generate_random_numbers(n, total_sum):
        if n <= 0:
            return []

        # reset seed
        random.seed(seed)

        breaks = sorted(random.sample(range(1, total_sum), n - 1))
        random_numbers = (
            [breaks[0]]
            + [breaks[i] - breaks[i - 1] for i in range(1, n - 1)]
            + [total_sum - breaks[-1]]
        )

        return random_numbers

    reset_rng_states()
    if fp8:
        FP8GlobalStateManager.reset()

    inp_hidden_states = torch.randn(
        (config.seq_len * bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    inp_hidden_states.retain_grad()

    m_splits = _generate_random_numbers(num_gemms, config.seq_len * bs)

    with fp8_autocast(enabled=fp8):
        if isinstance(block, TorchGroupedLinearWithPadding):
            out = block(inp_hidden_states, m_splits)
        else:
            if fp8:
                padded_inp_hidden_states, padding_m_splits = _pad_tensor_for_fp8(
                    inp_hidden_states, m_splits
                )
                padded_inp_hidden_states = block(padded_inp_hidden_states, padding_m_splits)
                out = _unpad_tensor_for_fp8(padded_inp_hidden_states, m_splits, padding_m_splits)
            else:
                out = block(inp_hidden_states, m_splits)

        loss = out.sum()
        loss.backward()

    torch.cuda.synchronize()
    outputs = [out, inp_hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            outputs.append(p.grad)
    return outputs


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("num_gemms", [3, 6])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("fp8", [True])
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_padding_grouped_linear_accuracy(
    dtype, num_gemms, bs, model, fp8, fp8_model_params, parallel_mode=None
):
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    config = model_configs[model]
    if config.seq_len % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    with fp8_model_init(enabled=fp8 and fp8_model_params):
        grouped_linear = TorchGroupedLinearWithPadding(
            num_gemms,
            config.hidden_size,
            4 * config.hidden_size,
            bias=False,
            params_dtype=dtype,
            parallel_mode=parallel_mode,
            fp8=fp8,
        ).eval()

    with fp8_model_init(enabled=fp8 and fp8_model_params):
        ref_grouped_linear = GroupedLinear(
            num_gemms,
            config.hidden_size,
            4 * config.hidden_size,
            bias=False,
            params_dtype=dtype,
            parallel_mode=parallel_mode,
            device="cuda",
        ).eval()

    # Share params
    with torch.no_grad():
        inner_grouped_linear = grouped_linear.linear_fn
        for i in range(num_gemms):
            setattr(
                ref_grouped_linear,
                f"weight{i}",
                Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()),
            )

    outputs = _test_padding_grouped_linear_accuracy(
        grouped_linear, num_gemms, bs, dtype, config, fp8
    )
    outputs_ref = _test_padding_grouped_linear_accuracy(
        ref_grouped_linear, num_gemms, bs, dtype, config, fp8
    )

    # Should be a bit-wise match
    for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)


def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph):
    reset_rng_states()
