Merge branch 'main' into rewang/revert-deterministic-tests
phu0ngng authored Oct 31, 2024
2 parents 19afa3e + 23caab3 commit af4983a
Showing 7 changed files with 121 additions and 88 deletions.
5 changes: 1 addition & 4 deletions transformer_engine/jax/cpp_extensions/attention.py
@@ -379,7 +379,7 @@ def lowering(

wkspace_aval = ctx.avals_out[-1]

if is_ffi_enabled():
if is_ffi_enabled() and bool(os.getenv("NVTE_JAX_FUSED_ATTN_WITH_FFI", "0")):
name = "te_fused_attn_forward_ffi"
out = ffi.ffi_lowering(name)(
ctx,
@@ -401,14 +401,11 @@ def lowering(
bias_heads=bias_heads,
head_dim=head_dim,
max_segments_per_seq=config.max_segments_per_seq,
wkspace_size=wkspace_aval.size,
scaling_factor=float(config.scaling_factor),
dropout_probability=float(config.dropout_probability),
bias_type=int(config.attn_bias_type),
mask_type=int(config.attn_mask_type),
qkv_layout=int(config.qkv_layout),
dtype=int(jax_dtype_to_te_dtype(q_aval.dtype)),
wkspace_dtype=int(jax_dtype_to_te_dtype(wkspace_aval.dtype)),
is_training=config.is_training,
deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
window_size_left=config.window_size[0],
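The wkspace_size and wkspace_dtype attributes dropped from this lowering no longer need to be computed in Python: the C++ FFI handler already receives the workspace as a result buffer and can recover both values from it, as the attention.cpp change below does. A minimal sketch of that pattern, reusing names from the handler further down:

    // Inside an FFI handler bound with a Result_Type workspace_buf:
    size_t wkspace_size = product(workspace_buf->dimensions());  // element count of the workspace
    DType wkspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());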
7 changes: 3 additions & 4 deletions transformer_engine/jax/csrc/extensions/activation.cpp
@@ -110,7 +110,7 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type outp
auto *output = output_buf->untyped_data();

auto input_dims = input_buf.dimensions();
auto m = std::accumulate(input_dims.begin(), input_dims.end() - 2, 1, std::multiplies<>());
auto m = product(input_dims, 0, input_dims.size() - 2);
auto n = input_dims.back();
auto act_len = input_dims.end()[-2];
auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
@@ -175,7 +175,7 @@ Error_Type ActLuFP8FFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type a
}

auto input_dims = input_buf.dimensions();
auto m = std::accumulate(input_dims.begin(), input_dims.end() - 2, 1, std::multiplies<>());
auto m = product(input_dims, 0, input_dims.size() - 2);
auto n = input_dims.back();
auto act_len = input_dims.end()[-2];
auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
@@ -264,8 +264,7 @@ Error_Type DActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act
auto *output = output_buf->untyped_data();

auto act_input_dims = act_input_buf.dimensions();
auto m =
std::accumulate(act_input_dims.begin(), act_input_dims.end() - 2, 1, std::multiplies<>());
auto m = product(act_input_dims, 0, act_input_dims.size() - 2);
auto n = act_input_dims.back();
auto act_len = act_input_dims.end()[-2];

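The repeated std::accumulate idiom above is what the new product() helper in ffi.h folds away. A rough equivalence for the leading-dimension flattening used in these activation kernels:

    // Before: multiply every dimension except the trailing two.
    auto m_old = std::accumulate(input_dims.begin(), input_dims.end() - 2, 1, std::multiplies<>());
    // After: same value via the helper; the range [start_idx, end_idx) is half-open.
    auto m_new = product(input_dims, 0, input_dims.size() - 2);
    // The trailing dimensions are still read directly.
    auto n = input_dims.back();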
97 changes: 48 additions & 49 deletions transformer_engine/jax/csrc/extensions/attention.cpp
@@ -329,36 +329,55 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
descriptor.deterministic, descriptor.window_size_left, descriptor.window_size_right);
}

Error_Type FusedAttnForwardFFI(
cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf, Buffer_Type v_buf,
Buffer_Type bias_buf, Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf,
Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf, Buffer_Type seed_buf,
Result_Type output_buf, Result_Type softmax_aux_buf, Result_Type rng_state_buf,
Result_Type workspace_buf, int64_t input_batch_, int64_t bias_batch_, int64_t q_max_seqlen_,
int64_t kv_max_seqlen_, int64_t attn_heads_, int64_t num_gqa_groups_, int64_t bias_heads_,
int64_t head_dim_, int64_t max_segments_per_seq_, int64_t wkspace_size_, double scaling_factor_,
double dropout_probability_, int64_t bias_type_, int64_t mask_type_, int64_t qkv_layout_,
int64_t dtype_, int64_t wkspace_dtype_, bool is_training, bool deterministic,
int64_t window_size_left, int64_t window_size_right) {
NVTE_QKV_Layout qkv_layout = static_cast<NVTE_QKV_Layout>(qkv_layout_);
Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf,
Buffer_Type v_buf, Buffer_Type bias_buf,
Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf,
Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf,
Buffer_Type seed_buf, Result_Type output_buf,
Result_Type softmax_aux_buf, Result_Type rng_state_buf,
Result_Type workspace_buf, Dictionary attrs) {
/* Descriptor data type conversion */
size_t input_batch = get_attr_value<int64_t>(attrs, "input_batch");
size_t bias_batch = get_attr_value<int64_t>(attrs, "bias_batch");
size_t q_max_seqlen = get_attr_value<int64_t>(attrs, "q_max_seqlen");
size_t kv_max_seqlen = get_attr_value<int64_t>(attrs, "kv_max_seqlen");
size_t attn_heads = get_attr_value<int64_t>(attrs, "attn_heads");
size_t num_gqa_groups = get_attr_value<int64_t>(attrs, "num_gqa_groups");
size_t bias_heads = get_attr_value<int64_t>(attrs, "bias_heads");
size_t head_dim = get_attr_value<int64_t>(attrs, "head_dim");
size_t max_segments_per_seq = get_attr_value<int64_t>(attrs, "max_segments_per_seq");
auto window_size_left = get_attr_value<int64_t>(attrs, "window_size_left");
auto window_size_right = get_attr_value<int64_t>(attrs, "window_size_right");

float scaling_factor = get_attr_value<double>(attrs, "scaling_factor");
float dropout_probability = get_attr_value<double>(attrs, "dropout_probability");

NVTE_Bias_Type bias_type =
static_cast<NVTE_Bias_Type>(get_attr_value<int64_t>(attrs, "bias_type"));
NVTE_Mask_Type mask_type =
static_cast<NVTE_Mask_Type>(get_attr_value<int64_t>(attrs, "mask_type"));
NVTE_QKV_Layout qkv_layout =
static_cast<NVTE_QKV_Layout>(get_attr_value<int64_t>(attrs, "qkv_layout"));

bool is_training = get_attr_value<bool>(attrs, "is_training");
bool deterministic = get_attr_value<bool>(attrs, "deterministic");

auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;

size_t wkspace_size = product(workspace_buf->dimensions());
DType dtype = convert_ffi_datatype_to_te_dtype(q_buf.element_type());
DType wkspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());

FusedAttnForwardImpl(
stream, q_buf.untyped_data(), k_buf.untyped_data(), v_buf.untyped_data(),
bias_buf.untyped_data(), q_cu_seqlens_buf.untyped_data(), kv_cu_seqlens_buf.untyped_data(),
is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr,
is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr, seed_buf.untyped_data(),
output_buf->untyped_data(), softmax_aux_buf->untyped_data(), rng_state_buf->untyped_data(),
workspace_buf->untyped_data(), static_cast<size_t>(input_batch_),
static_cast<size_t>(bias_batch_), static_cast<size_t>(q_max_seqlen_),
static_cast<size_t>(kv_max_seqlen_), static_cast<size_t>(attn_heads_),
static_cast<size_t>(num_gqa_groups_), static_cast<size_t>(bias_heads_),
static_cast<size_t>(head_dim_), static_cast<size_t>(max_segments_per_seq_),
static_cast<size_t>(wkspace_size_), static_cast<float>(scaling_factor_),
static_cast<float>(dropout_probability_), static_cast<NVTE_Bias_Type>(bias_type_),
static_cast<NVTE_Mask_Type>(mask_type_), static_cast<NVTE_QKV_Layout>(qkv_layout_),
static_cast<DType>(dtype_), static_cast<DType>(wkspace_dtype_), is_training, deterministic,
window_size_left, window_size_right);
workspace_buf->untyped_data(), input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
attn_heads, num_gqa_groups, bias_heads, head_dim, max_segments_per_seq, wkspace_size,
scaling_factor, dropout_probability, bias_type, mask_type, qkv_layout, dtype, wkspace_dtype,
is_training, deterministic, window_size_left, window_size_right);

return ffi_with_cuda_error_check();
}
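With Dictionary attrs, each attribute is fetched by name at run time through get_attr_value (added to ffi.h below), so the handler signature no longer has to spell out every scalar. A condensed sketch for a hypothetical two-attribute handler; the string keys must match the keyword arguments passed from the Python lowering:

    // Hypothetical handler, sketch only.
    Error_Type MyOpFFI(cudaStream_t stream, Buffer_Type x_buf, Result_Type output_buf,
                       Dictionary attrs) {
      size_t input_batch = get_attr_value<int64_t>(attrs, "input_batch");
      bool is_training = get_attr_value<bool>(attrs, "is_training");
      // ... launch the kernel using input_batch / is_training ...
      return ffi_with_cuda_error_check();
    }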
@@ -379,27 +398,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnForwardHandler, FusedAttnForwardFFI,
.Ret<Buffer_Type>() // softmax_aux
.Ret<Buffer_Type>() // rng_state
.Ret<Buffer_Type>() // workspace
.Attr<int64_t>("input_batch")
.Attr<int64_t>("bias_batch")
.Attr<int64_t>("q_max_seqlen")
.Attr<int64_t>("kv_max_seqlen")
.Attr<int64_t>("attn_heads")
.Attr<int64_t>("num_gqa_groups")
.Attr<int64_t>("bias_heads")
.Attr<int64_t>("head_dim")
.Attr<int64_t>("max_segments_per_seq")
.Attr<int64_t>("wkspace_size")
.Attr<double>("scaling_factor")
.Attr<double>("dropout_probability")
.Attr<int64_t>("bias_type")
.Attr<int64_t>("mask_type")
.Attr<int64_t>("qkv_layout")
.Attr<int64_t>("dtype")
.Attr<int64_t>("wkspace_dtype")
.Attr<bool>("is_training")
.Attr<bool>("deterministic")
.Attr<int64_t>("window_size_left")
.Attr<int64_t>("window_size_right"),
.Attrs(),
FFI_CudaGraph_Traits);
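On the registration side, the long chain of typed .Attr<T>("name") calls collapses into the single .Attrs() above, which hands the whole attribute dictionary to the handler as its trailing Dictionary parameter. An assumed binding for the hypothetical MyOpFFI sketched earlier, mirroring the visible .Ret/.Attrs lines:

    XLA_FFI_DEFINE_HANDLER_SYMBOL(MyOpHandler, MyOpFFI,
                                  FFI::Bind()
                                      .Ctx<FFI_Stream_Type>()  // stream
                                      .Arg<Buffer_Type>()      // x
                                      .Ret<Buffer_Type>()      // output
                                      .Attrs(),                // all attributes as one Dictionary
                                  FFI_CudaGraph_Traits);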

pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
@@ -608,7 +607,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto dqkv = buffers[12];
auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype);
if (is_ragged) {
cudaMemsetAsync(dqkv, 0, product(qkv_shape) * typeToSize(dtype), stream);
cudaMemsetAsync(dqkv, 0, transformer_engine::product(qkv_shape) * typeToSize(dtype), stream);
}
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
@@ -630,8 +629,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto dkv = buffers[13];
auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype);
if (is_ragged) {
cudaMemsetAsync(dq, 0, product(q_shape) * typeToSize(dtype), stream);
cudaMemsetAsync(dkv, 0, product(kv_shape) * typeToSize(dtype), stream);
cudaMemsetAsync(dq, 0, transformer_engine::product(q_shape) * typeToSize(dtype), stream);
cudaMemsetAsync(dkv, 0, transformer_engine::product(kv_shape) * typeToSize(dtype), stream);
}
nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
@@ -659,9 +658,9 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto dv = buffers[14];
auto dv_tensor = TensorWrapper(dv, v_shape, dtype);
if (is_ragged) {
cudaMemsetAsync(dq, 0, product(q_shape) * typeToSize(dtype), stream);
cudaMemsetAsync(dk, 0, product(k_shape) * typeToSize(dtype), stream);
cudaMemsetAsync(dv, 0, product(v_shape) * typeToSize(dtype), stream);
cudaMemsetAsync(dq, 0, transformer_engine::product(q_shape) * typeToSize(dtype), stream);
cudaMemsetAsync(dk, 0, transformer_engine::product(k_shape) * typeToSize(dtype), stream);
cudaMemsetAsync(dv, 0, transformer_engine::product(v_shape) * typeToSize(dtype), stream);
}
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
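The explicit transformer_engine:: qualification in the cudaMemsetAsync calls is a side effect of the new product() overload added to the transformer_engine::jax namespace in ffi.h: unqualified name lookup inside this namespace now finds the Span-based jax helper first, which does not match the std::vector shapes used here. Qualifying the call keeps it bound to the pre-existing shape-vector helper in the parent namespace (its signature is assumed here, shown only for contrast):

    // New in this commit (transformer_engine::jax): element count of an XLA dimension span.
    size_t span_elems = product(workspace_buf->dimensions());
    // Pre-existing helper (transformer_engine), assumed to take a shape vector;
    // the qualification bypasses the jax overload that lookup finds first.
    size_t bytes = transformer_engine::product(qkv_shape) * typeToSize(dtype);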
2 changes: 0 additions & 2 deletions transformer_engine/jax/csrc/extensions/ffi.cpp
@@ -7,8 +7,6 @@

#include <iostream>

#include "common/util/logging.h"

namespace transformer_engine {
namespace jax {

57 changes: 56 additions & 1 deletion transformer_engine/jax/csrc/extensions/ffi.h
@@ -9,6 +9,8 @@

#include <numeric>

#include "common/util/logging.h"

namespace transformer_engine {
namespace jax {

@@ -17,10 +19,63 @@ using Result_Type = xla::ffi::Result<xla::ffi::AnyBuffer>;
using Error_Type = xla::ffi::Error;
using FFI = xla::ffi::Ffi;
using FFI_Stream_Type = xla::ffi::PlatformStream<cudaStream_t>;
using Dictionary = xla::ffi::Dictionary;
constexpr auto FFI_CudaGraph_Traits = {xla::ffi::Traits::kCmdBufferCompatible};

DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type);
DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType& type);

Error_Type ffi_with_cuda_error_check();

// source_location is not available in C++17, so we implement it ourselves
#if defined(__GNUC__) || defined(__clang__)
#define CURRENT_FILE __builtin_FILE()
#define CURRENT_LINE __builtin_LINE()
#define CURRENT_FUNCTION __builtin_FUNCTION()
#else
#define CURRENT_FILE __FILE__
#define CURRENT_LINE __LINE__
#define CURRENT_FUNCTION __func__
#endif

class source_location {
public:
static source_location current(const char* file = CURRENT_FILE, int line = CURRENT_LINE,
const char* function = CURRENT_FUNCTION) {
return source_location(file, line, function);
}

constexpr const char* file_name() const { return file_; }
constexpr int line() const { return line_; }
constexpr const char* function_name() const { return function_; }

private:
constexpr source_location(const char* file, int line, const char* function)
: file_(file), line_(line), function_(function) {}

const char* file_;
int line_;
const char* function_;
};

template <typename T>
T get_attr_value(Dictionary& attrs, std::string attr_name,
const source_location& loc = source_location::current()) {
auto attr = attrs.get<T>(attr_name);
if (attr.has_error()) {
NVTE_ERROR("Failure in getting attribute value of '", attr_name, "'\n",
"Called from: ", loc.file_name(), ":", loc.line(), "\n",
"In function: ", loc.function_name(), "\n",
"Please ensure the attribute name and datatype match between C++ and Python APIs.");
}
return attr.value();
}

inline size_t product(const xla::ffi::Span<const int64_t>& data, size_t start_idx = 0,
size_t end_idx = 0) {
end_idx = (end_idx == 0) ? data.size() : end_idx;
return std::accumulate(data.begin() + start_idx, data.begin() + end_idx, size_t(1),
std::multiplies<size_t>());
}

} // namespace jax
} // namespace transformer_engine
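The source_location shim above exists so get_attr_value can report where a failing lookup originated without requiring C++20. Because the defaulted loc parameter is evaluated at each call site, __builtin_FILE()/__builtin_LINE()/__builtin_FUNCTION() capture the caller's location rather than a line inside ffi.h. A usage sketch with a deliberately misspelled (hypothetical) attribute name and illustrative path/line values:

    // Evaluated in the caller's frame, e.g. inside FusedAttnForwardFFI.
    size_t heads = get_attr_value<int64_t>(attrs, "attn_head");  // typo: should be "attn_heads"
    // NVTE_ERROR then reports something along the lines of:
    //   Failure in getting attribute value of 'attn_head'
    //   Called from: .../attention.cpp:<line>
    //   In function: FusedAttnForwardFFI
    //   Please ensure the attribute name and datatype match between C++ and Python APIs.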
28 changes: 8 additions & 20 deletions transformer_engine/jax/csrc/extensions/normalization.cpp
@@ -264,19 +264,13 @@ Error_Type LayerNormForwardFP8FFI(cudaStream_t stream, Buffer_Type x_buf, Buffer
NVTE_CHECK(amax_out == amax,
"amax not bound to amax_out in TE/JAX LayerNormForwardFP8 primitive");

auto x_dims = x_buf.dimensions();
auto gamma_dims = gamma_buf.dimensions();
auto x_size = std::accumulate(x_dims.begin(), x_dims.end(), 1, std::multiplies<>());
auto gamma_size = std::accumulate(gamma_dims.begin(), gamma_dims.end(), 1, std::multiplies<>());
auto x_size = product(x_buf.dimensions());
auto gamma_size = product(gamma_buf.dimensions());
auto hidden_size = gamma_size;
auto batch_size = x_size / gamma_size;

auto wkspace_dims = wkspace_buf->dimensions();
auto barrier_dims = barrier_buf->dimensions();
auto wkspace_size =
std::accumulate(wkspace_dims.begin(), wkspace_dims.end(), 1, std::multiplies<>());
auto barrier_size =
std::accumulate(barrier_dims.begin(), barrier_dims.end(), 1, std::multiplies<>());
auto wkspace_size = product(wkspace_buf->dimensions());
auto barrier_size = product(barrier_buf->dimensions());

float eps = static_cast<float>(eps_);
int sm_margin = static_cast<int>(sm_margin_);
@@ -408,19 +402,13 @@ Error_Type LayerNormBackwardFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_
auto *dgamma_part = dgamma_part_buf->untyped_data();
auto *dbeta_part = dbeta_part_buf->untyped_data();

auto x_dims = x_buf.dimensions();
auto gamma_dims = gamma_buf.dimensions();
auto x_size = std::accumulate(x_dims.begin(), x_dims.end(), 1, std::multiplies<>());
auto gamma_size = std::accumulate(gamma_dims.begin(), gamma_dims.end(), 1, std::multiplies<>());
auto x_size = product(x_buf.dimensions());
auto gamma_size = product(gamma_buf.dimensions());
auto hidden_size = gamma_size;
auto batch_size = x_size / gamma_size;

auto wkspace_dims = wkspace_buf->dimensions();
auto barrier_dims = barrier_buf->dimensions();
auto wkspace_size =
std::accumulate(wkspace_dims.begin(), wkspace_dims.end(), 1, std::multiplies<>());
auto barrier_size =
std::accumulate(barrier_dims.begin(), barrier_dims.end(), 1, std::multiplies<>());
auto wkspace_size = product(wkspace_buf->dimensions());
auto barrier_size = product(barrier_buf->dimensions());

auto dgamma_part_dims = dgamma_part_buf->dimensions();
auto dbeta_part_dims = dbeta_part_buf->dimensions();
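These normalization kernels use the default-argument form of product(), which multiplies every dimension, and then derive the batch size as a ratio. For example, assuming an x buffer of shape (32, 128, 1024) and a gamma of shape (1024,):

    auto x_size = product(x_buf.dimensions());          // 32 * 128 * 1024 = 4194304
    auto gamma_size = product(gamma_buf.dimensions());  // 1024
    auto hidden_size = gamma_size;                      // 1024
    auto batch_size = x_size / gamma_size;              // 32 * 128 = 4096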
13 changes: 5 additions & 8 deletions transformer_engine/jax/csrc/extensions/transpose.cpp
@@ -46,10 +46,9 @@ Error_Type TransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type

auto input_dims = input_buf.dimensions();
if (transpose_axis < 0) transpose_axis += input_dims.size();
auto m = std::accumulate(input_dims.begin(), input_dims.begin() + transpose_axis, 1,
std::multiplies<>());
auto n = std::accumulate(input_dims.begin() + transpose_axis, input_dims.end(), 1,
std::multiplies<>());
auto m = product(input_dims, 0, transpose_axis);
auto n = product(input_dims, transpose_axis, input_dims.size());

auto input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{n, m};

@@ -124,10 +123,8 @@ Error_Type CastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T

auto input_dims = input_buf.dimensions();
if (transpose_axis < 0) transpose_axis += input_dims.size();
auto m = std::accumulate(input_dims.begin(), input_dims.begin() + transpose_axis, 1,
std::multiplies<>());
auto n = std::accumulate(input_dims.begin() + transpose_axis, input_dims.end(), 1,
std::multiplies<>());
auto m = product(input_dims, 0, transpose_axis);
auto n = product(input_dims, transpose_axis, input_dims.size());
auto input_shape = std::vector<size_t>{m, n};
auto input_trans_shape = std::vector<size_t>{n, m};

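In the transpose kernels the helper's start/end indices split the flattened matrix at transpose_axis: the leading dimensions form the row count and the rest the column count. For example, assuming an input of shape (2, 3, 4, 5) with transpose_axis = 2:

    // Hypothetical shape, for illustration only: input_dims = {2, 3, 4, 5}.
    auto m = product(input_dims, 0, transpose_axis);                  // 2 * 3 = 6
    auto n = product(input_dims, transpose_axis, input_dims.size());  // 4 * 5 = 20
    // The op then works on a 6 x 20 matrix whose transpose is 20 x 6.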
