From d9eb1991407c38db51570aa336308ac542556c15 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com> Date: Thu, 8 Feb 2024 19:32:40 +0100 Subject: [PATCH] [common] Added new unfused softmax cuda kernel to support causal attention mask (#652) * Added new unfused softmax cuda kernel to support causal attention mask Signed-off-by: Oleg Goncharov * Added test suite for unfused causal softmax kernel Signed-off-by: Oleg Goncharov * Removed test cases with large matrices from the causal softmax test suite Signed-off-by: Oleg Goncharov * Cleaned up the code per lint Signed-off-by: Oleg Goncharov * Added a compute buffer to causal softmax testing suite to store intermediate results without casting Signed-off-by: Oleg Goncharov * Added more tests cases Signed-off-by: Oleg Goncharov * Relaxed absolute tolerance atol Signed-off-by: Oleg Goncharov * Relaxed absolute tolerance for BF16 Signed-off-by: Oleg Goncharov --------- Signed-off-by: Oleg Goncharov --- tests/cpp/operator/CMakeLists.txt | 1 + tests/cpp/operator/test_causal_softmax.cu | 276 +++++++ transformer_engine/common/CMakeLists.txt | 4 +- .../scaled_aligned_causal_masked_softmax.cu | 671 ++++++++++++++++++ .../fused_softmax/scaled_masked_softmax.cu | 71 +- .../scaled_upper_triang_masked_softmax.cu | 58 +- .../include/transformer_engine/softmax.h | 35 + transformer_engine/pytorch/csrc/extensions.h | 11 + .../pytorch/csrc/extensions/pybind.cpp | 6 + .../pytorch/csrc/extensions/softmax.cu | 78 +- transformer_engine/pytorch/softmax.py | 68 +- 11 files changed, 1255 insertions(+), 24 deletions(-) create mode 100644 tests/cpp/operator/test_causal_softmax.cu create mode 100644 transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index f2598d0200..0dd2a6d8e2 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -15,6 +15,7 @@ add_executable(test_operator test_layernorm.cu test_rmsnorm.cu test_multi_cast_transpose.cu + test_causal_softmax.cu ../test_common.cu) list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB}) diff --git a/tests/cpp/operator/test_causal_softmax.cu b/tests/cpp/operator/test_causal_softmax.cu new file mode 100644 index 0000000000..1d7d70e572 --- /dev/null +++ b/tests/cpp/operator/test_causal_softmax.cu @@ -0,0 +1,276 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include + +#include +#include "../test_common.h" + +using namespace transformer_engine; + +namespace { + +using compute_t = float; + +template +void compute_single_head_fwd( + Type *softmax_out, + const Type *data_in, + compute_t *buff, + const float scaling_factor, + const int rows, + const int cols) +{ + for (int i = 0; i < rows; ++i) { + size_t offset = i * cols; + + const int masked_elements = i + cols - rows + 1; + compute_t max_value = static_cast(-10'000.f); + for (int j = 0; j < masked_elements; ++j) { + compute_t tmp = scaling_factor * static_cast(data_in[offset + j]); + buff[offset + j] = tmp; + max_value = std::max(max_value, tmp); + } + + compute_t accumulator = static_cast(0.f); + for (int j = 0; j < masked_elements; ++j) { + compute_t tmp = std::exp(buff[offset + j] - max_value); + buff[offset + j] = tmp; + accumulator += tmp; + } + + for (int j = 0; j < cols; ++j) { + if (j < masked_elements) { + compute_t tmp = buff[offset + j] / accumulator; + softmax_out[offset + j] = static_cast(tmp); + } else { + softmax_out[offset + j] = static_cast(0.f); + } + } + } +} + +template +void compute_single_head_bwd( + Type *grad_out, + const Type *grad_in, + const Type *softmax_in, + compute_t *buff, + const float scaling_factor, + const int batches, + const int heads, + const int rows, + const int cols) +{ + for (int i = 0; i < rows; ++i) { + size_t offset = i * cols; + + const int masked_elements = i + cols - rows + 1; + compute_t accumulator = static_cast(0.f); + for (int j = 0; j < masked_elements; ++j) { + compute_t tmp = static_cast(softmax_in[offset + j]) + * static_cast(grad_in[offset + j]); + buff[offset + j] = tmp; + accumulator += tmp; + } + + for (int j = 0; j < cols; ++j) { + if (j < masked_elements) { + compute_t tmp = buff[offset + j] + - static_cast(softmax_in[offset + j]) * accumulator; + grad_out[offset + j] = static_cast(scaling_factor * tmp); + } else { + grad_out[offset + j] = static_cast(0.f); + } + } + } +} + +template +void compute_fwd_ref( + Type *softmax_out, + const Type *data_in, + compute_t *buff, + const float scaling_factor, + const int batches, + const int heads, + const int rows, + const int cols) +{ + size_t head_size = rows * cols; + size_t batch_size = heads * head_size; + + for (int b = 0; b < batches; ++b) { + for (int h = 0; h < heads; ++h) { + size_t offset = b * batch_size + h * head_size; + compute_single_head_fwd(softmax_out + offset, data_in + offset, + buff + offset, scaling_factor, rows, cols); + } + } +} + +template +void compute_bwd_ref( + Type *grad_out, + const Type *grad_in, + const Type *softmax_in, + compute_t *buff, + const float scaling_factor, + const int batches, + const int heads, + const int rows, + const int cols) +{ + size_t head_size = rows * cols; + size_t batch_size = heads * head_size; + + for (int b = 0; b < batches; ++b) { + for (int h = 0; h < heads; ++h) { + size_t offset = b * batch_size + h * head_size; + compute_single_head_bwd(grad_out + offset, grad_in + offset, softmax_in + offset, + buff + offset, scaling_factor, batches, heads, rows, cols); + } + } +} + + +// Query Sequence Length = rows +// Key Sequence Length = cols +template +void performTest( + const size_t batches, + const size_t heads, + const size_t rows, + const size_t cols, + float scaling_factor) +{ + using namespace test; + + DType itype = TypeInfo::dtype; + + Tensor data_in({ batches, heads, rows, cols }, itype); + Tensor softmax_out({ batches, heads, 
rows, cols }, itype); + Tensor softmax_in({ batches, heads, rows, cols }, itype); + Tensor grads_in({ batches, heads, rows, cols }, itype); + Tensor grads_out({ batches, heads, rows, cols }, itype); + + const size_t elements_total = batches * heads * rows * cols; + std::unique_ptr softmax_out_ref = std::make_unique(elements_total); + std::unique_ptr grads_out_ref = std::make_unique(elements_total); + std::unique_ptr compute_buffer = std::make_unique(elements_total); + + fillUniform(&data_in); + fillUniform(&softmax_in); + fillUniform(&grads_in); + + nvte_scaled_aligned_causal_masked_softmax_forward( + data_in.data(), softmax_out.data(), scaling_factor, 0); + nvte_scaled_aligned_causal_masked_softmax_backward( + grads_in.data(), softmax_in.data(), grads_out.data(), scaling_factor, 0); + + + // Reference implementations + compute_fwd_ref(softmax_out_ref.get(), data_in.cpu_dptr(), + compute_buffer.get(), scaling_factor, batches, heads, rows, cols); + compute_bwd_ref(grads_out_ref.get(), grads_in.cpu_dptr(), softmax_in.cpu_dptr(), + compute_buffer.get(), scaling_factor, batches, heads, rows, cols); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + auto [atol, rtol] = getTolerances(itype); + if(itype == DType::kBFloat16) { + atol = 1e-3; + } + compareResults("softmax_fwd", softmax_out, softmax_out_ref.get(), atol, rtol); + compareResults("softmax_bwd", grads_out, grads_out_ref.get(), atol, rtol); +} + +// [Batches, Attention Heads, Query Sequence Length, Key Sequence Length, Scaling Factor] +std::vector> test_cases = { + { 1, 1, 1, 16, -1.0f}, + { 1, 2, 17, 32, 0.8f}, + { 2, 1, 37, 112, 1.0f}, + { 2, 4, 127, 128, -0.2f}, + { 8, 6, 128, 256, 1.3f}, + { 1, 4, 270, 256, 0.8f}, + { 2, 2, 512, 512, -1.5f}, + { 1, 2, 819, 1024, 2.1f}, + { 1, 2, 281, 1024, 0.2f}, + { 1, 2, 277, 1024, -2.1f}, + { 1, 2, 127, 1024, 1.1f}, + { 2, 2, 107, 2048, 0.4f}, + { 2, 1, 103, 2048, -3.0f}, + { 2, 2, 101, 2048, 2.6f}, + { 1, 1, 1024, 4096, 0.6f}, + { 1, 2, 61, 4096, 0.6f}, + { 1, 2, 59, 4096, -4.9f}, + { 1, 2, 53, 4096, 3.5f}, + { 1, 1, 37, 8192, 0.7f}, + { 1, 1, 31, 8192, -5.8f}, + { 1, 1, 29, 8192, 4.4f}, + { 1, 1, 23, 12288, 0.8f}, + { 1, 1, 19, 12288, -6.7f}, + { 1, 1, 17, 12288, 3.3f}, + { 1, 1, 13, 16384, 0.9f}, + { 1, 1, 11, 16384, -7.6f}, + { 1, 1, 7, 16384, 6.2f}}; + +} // namespace + +class CausalSoftmaxTestSuite + : public ::testing::TestWithParam>> {}; + +TEST_P(CausalSoftmaxTestSuite, TestCausalSoftmax) { + using namespace transformer_engine; + using namespace test; + + const DType input_type = std::get<0>(GetParam()); + const auto size = std::get<1>(GetParam()); + + const size_t batches = std::get<0>(size); + const size_t heads = std::get<1>(size); + const size_t query_seq_len = std::get<2>(size); + const size_t key_seq_len = std::get<3>(size); + const float scaling_factor = std::get<4>(size); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, + performTest(batches, heads, query_seq_len, key_seq_len, scaling_factor); + ); +} + + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + CausalSoftmaxTestSuite, + ::testing::Combine( + ::testing::Values(DType::kFloat16, DType::kBFloat16), + ::testing::ValuesIn(test_cases)), + [](const testing::TestParamInfo& info) { + const auto size = std::get<1>(info.param); + const size_t batches = std::get<0>(size); + const size_t heads = std::get<1>(size); + const size_t query_seq_len = std::get<2>(size); + const size_t key_seq_len = std::get<3>(size); + + std::string scaling_factor = 
std::to_string(std::get<4>(size)); + for (char& c : scaling_factor) { + if (c == '-') { c = 'N'; } + if (c == '.') { c = 'p'; } + } + + std::string name = test::typeName(std::get<0>(info.param)) + "X" + + std::to_string(batches) + "X" + + std::to_string(heads) + "X" + + std::to_string(query_seq_len) + "X" + + std::to_string(key_seq_len) + "X" + + scaling_factor; + return name; + }); diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 4997762e03..2be5a6cc3f 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -33,8 +33,7 @@ list(APPEND transformer_engine_SOURCES util/system.cpp fused_softmax/scaled_masked_softmax.cu fused_softmax/scaled_upper_triang_masked_softmax.cu - fused_softmax/scaled_masked_softmax.cu - fused_softmax/scaled_upper_triang_masked_softmax.cu + fused_softmax/scaled_aligned_causal_masked_softmax.cu fused_rope/fused_rope.cu) add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) target_include_directories(transformer_engine PUBLIC @@ -87,6 +86,7 @@ target_include_directories(transformer_engine PRIVATE # Compiler options set_source_files_properties(fused_softmax/scaled_masked_softmax.cu fused_softmax/scaled_upper_triang_masked_softmax.cu + fused_softmax/scaled_aligned_causal_masked_softmax.cu PROPERTIES COMPILE_OPTIONS "--use_fast_math") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") diff --git a/transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu b/transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu new file mode 100644 index 0000000000..648f9ab6b1 --- /dev/null +++ b/transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu @@ -0,0 +1,671 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include "../common.h" +#include "../utils.cuh" +#include "../util/logging.h" + + +namespace transformer_engine { + +template +__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); + +template<> +__device__ __inline__ void copy_vector(bf16 *dst, const bf16 *src) { + *dst = *src; +} + +template<> +__device__ __inline__ void copy_vector(bf16 *dst, const bf16 *src) { + *((uint64_t*) dst) = *((uint64_t*) src); // NOLINT(*) +} + +template<> +__device__ __inline__ void copy_vector(fp16 *dst, const fp16 *src) { + *dst = *src; +} + +template<> +__device__ __inline__ void copy_vector(fp16 *dst, const fp16 *src) { + *((uint64_t*) dst) = *((uint64_t*) src); // NOLINT(*) +} + +template<> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { + *dst = *src; +} + +template<> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { + *((uint32_t*) dst) = *((uint32_t*) src); // NOLINT(*) +} + +template +__device__ __inline__ void copy_zero_vector(Datatype *dst); + +template <> +__device__ __inline__ void copy_zero_vector(bf16 *dst) { + *dst = 0.0f; +} + +template <> +__device__ __inline__ void copy_zero_vector(bf16 *dst) { + *((float2*) dst) = make_float2(0.0f, 0.0f); // NOLINT(*) +} + +template <> +__device__ __inline__ void copy_zero_vector(fp16 *dst) { + *dst = 0.0f; +} + +template <> +__device__ __inline__ void copy_zero_vector(fp16 *dst) { + *((float2*) dst) = make_float2(0.0f, 0.0f); // NOLINT(*) +} + + +template +struct Add { + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } +}; + +template +struct Max { + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? b : a; + } +}; + +template +__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, + unsigned int mask = 0xffffffff) { +#if CUDA_VERSION >= 9000 + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t* sum) { + ReduceOp r; + #pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + #pragma unroll + for (int i = 0; i < WARP_ROWS; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); + } + } +} + +/* + * Extended softmax (from native aten pytorch) with the following additional features + * 1) input scaling + * 2) implicit causal masking + * + * works for all cases: + * k > q + * k < q + * k = q + * + * where: + * microbatches = batches * attn_heads * query_seq_len + * rows = query_seq_len + * cols = key_seq_len + */ +template +__global__ void scaled_aligned_causal_masked_softmax_warp_forward( + output_t *dst, + const input_t *src, + const acc_t scale, + const int microbatches, + const int rows, + const int cols +) { + // 1) WARP_WIDTH must match the value of warp_size + // 2) WARP_ROWS must match the value of rows_per_warp + // of the dispatch_scaled_aligned_causal_masked_softmax_forward method. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_WIDTH = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two + : THREADS_PER_WARP; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_WIDTH; + constexpr int WARP_ROWS = (next_power_of_two <= 128) ? 
2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + const int global_row_idx = (blockIdx.x * blockDim.y + threadIdx.y) * WARP_ROWS; + const int col = threadIdx.x * ELEMENTS_PER_LDG_STG; + + const size_t thread_offset = global_row_idx * cols + col; + + src += thread_offset; + dst += thread_offset; + + // load data from global memory into registers WITH scaling + acc_t elements[WARP_ROWS][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; + + #pragma unroll + for (int w = 0; w < WARP_ROWS; ++w) { + const int microbatch = global_row_idx + w; + const int i = microbatch % rows; // local row index of attention matrix + const int masked_elements = i + cols - rows + 1; + + if (microbatch >= microbatches) { + break; + } + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + const int j = col + it * WARP_WIDTH; + const int itr_idx = w * cols + it * WARP_WIDTH; + + if (j < masked_elements) { + copy_vector(temp_data, src + itr_idx); + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (j + element < masked_elements) { + elements[w][it + element] = (acc_t)temp_data[element] * scale; + } else { + elements[w][it + element] = (acc_t)( -10'000 ); + } + } + } else { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[w][it + element] = (acc_t)( -10'000 ); + } + } + } + } + + // compute max_value + acc_t max_value[WARP_ROWS]; + #pragma unroll + for (int w = 0; w < WARP_ROWS; ++w) { + max_value[w] = elements[w][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[w] = + (max_value[w] > elements[w][it]) ? max_value[w] : elements[w][it]; + } + } + warp_reduce(max_value); + + acc_t sum[WARP_ROWS] { 0.0f }; + #pragma unroll + for (int w = 0; w < WARP_ROWS; ++w) { + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + elements[w][it] = expf((elements[w][it] - max_value[w])); + sum[w] += elements[w][it]; + } + } + warp_reduce(sum); + + output_t out[ELEMENTS_PER_LDG_STG] { 0.0f }; + // store result + #pragma unroll + for (int w = 0; w < WARP_ROWS; ++w) { + const int microbatch = global_row_idx + w; + const int i = microbatch % rows; + const int masked_elements = i + cols - rows + 1; + + // out of Attention matrix bounds (rows) + if (microbatch >= microbatches) { + break; + } + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + const int j = col + it * WARP_WIDTH; // index of the first column + const int itr_idx = w * cols + it * WARP_WIDTH; + + if (j < masked_elements) { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (j + element < masked_elements) { + out[element] = elements[w][it + element] / sum[w]; + } else { + out[element] = (output_t)( 0.0f ); + } + } + copy_vector(dst + itr_idx, out); + } else if (j < cols) { + copy_zero_vector(dst + itr_idx); + } else { + break; + } + } + } +} + + +template +__global__ void scaled_aligned_causal_masked_softmax_warp_backward( + output_t *gradInput, + const input_t *grad, + const input_t *softmax_output, + const acc_t scale, + const int microbatches, + const int rows, + const int cols +) { + // 1) WARP_WIDTH must match the value of warp_size + // 2) WARP_ROWS must match the value of rows_per_warp + // of the dispatch_scaled_aligned_causal_masked_softmax_forward method. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_WIDTH = (next_power_of_two < THREADS_PER_WARP) ? 
next_power_of_two + : THREADS_PER_WARP; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_WIDTH; + constexpr int WARP_ROWS = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + const int global_row_idx = (blockIdx.x * blockDim.y + threadIdx.y) * WARP_ROWS; + const int col = threadIdx.x * ELEMENTS_PER_LDG_STG; + + const size_t thread_offset = global_row_idx * cols + col; + + grad += thread_offset; + softmax_output += thread_offset; + gradInput += thread_offset; + + // load data from global memory into registers + acc_t grad_reg[WARP_ROWS][WARP_ITERATIONS] { 0.0f }; + acc_t softmax_output_reg[WARP_ROWS][WARP_ITERATIONS] { 0.0f }; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; + + #pragma unroll + for (int w = 0; w < WARP_ROWS; ++w) { + const int microbatch = global_row_idx + w; + const int i = microbatch % rows; // local row index of attention matrix + const int masked_elements = i + cols - rows + 1; + + if (microbatch >= microbatches) { + break; + } + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + const int j = col + it * WARP_WIDTH; // index of the first column + const int itr_idx = w * cols + it * WARP_WIDTH; + + if (j < masked_elements) { + copy_vector(temp_grad, grad + itr_idx); + copy_vector(temp_output, softmax_output + itr_idx); + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (j + element < masked_elements) { + softmax_output_reg[w][it + element] = (acc_t)temp_output[element]; + grad_reg[w][it + element] = + (acc_t)temp_grad[element] * softmax_output_reg[w][it + element]; + } + } + } + } + } + + acc_t sum[WARP_ROWS]; + #pragma unroll + for (int w = 0; w < WARP_ROWS; ++w) { + sum[w] = grad_reg[w][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[w] += grad_reg[w][it]; + } + } + + warp_reduce(sum); + + // store result + #pragma unroll + for (int w = 0; w < WARP_ROWS; ++w) { + const int microbatch = global_row_idx + w; + if (microbatch >= microbatches) { + break; + } + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + const int j = col + it * WARP_WIDTH; // index of the first column + const int itr_idx = w * cols + it * WARP_WIDTH; + + if (j < cols) { + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = (output_t)(scale * (grad_reg[w][it + element] - + softmax_output_reg[w][it + element] * sum[w])); + } + copy_vector(gradInput + itr_idx, out); + } + } + } +} + +template +void call_kernel_scaled_aligned_causal_masked_softmax_forward( + dim3 grid_size, + dim3 block_size, + const int shmem_size, + cudaStream_t stream, + output_t *dst, + const input_t *src, + const acc_t scale, + const int microbatches, + const int query_seq_len, + const int key_seq_len +) { + scaled_aligned_causal_masked_softmax_warp_forward + <<>>( + dst, src, scale, microbatches, query_seq_len, key_seq_len); +} + +template +void call_kernel_scaled_aligned_causal_masked_softmax_backward( + dim3 grid_size, + dim3 block_size, + const int shmem_size, + cudaStream_t stream, + output_t *gradInput, + const input_t *grad, + const input_t *output, + const acc_t scale, + const int microbatches, + const int query_seq_len, + const int key_seq_len +) { + scaled_aligned_causal_masked_softmax_warp_backward + <<>>( + gradInput, grad, output, scale, microbatches, query_seq_len, key_seq_len); +} + 
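+// Dispatch note for the wrappers above and the tables below: each kernel is
+// templated on log2_elements, so there is one instantiation per supported
+// power-of-two column count (2^MIN_SUPPORTED_POWER .. 2^MAX_SUPPORTED_POWER,
+// i.e. 16 .. 16384 columns). Instead of a long switch statement, the dispatch
+// functions keep a static table of function pointers indexed by log2_elements,
+// populated once via the recursive CompileTimeLoopForward /
+// CompileTimeLoopBackward templates. For example, key_seq_len = 1000 rounds up
+// to next_power_of_two = 1024, so the table entry at index
+// log2_ceil(1000) = 10 is invoked.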
+template +struct FunctionWrapper { + using ForwardType = std::function< + void( + dim3 grid_size, + dim3 block_size, + const int shmem_size, + cudaStream_t stream, + output_t *dst, + const input_t *src, + const acc_t scale, + const int microbatches, + const int query_seq_len, + const int key_seq_len + ) + >; + using BackwardType = std::function< + void( + dim3 grid_size, + dim3 block_size, + const int shmem_size, + cudaStream_t stream, + output_t *gradInput, + const input_t *grad, + const input_t *output, + const acc_t scale, + const int microbatches, + const int query_seq_len, + const int key_seq_len + ) + >; +}; + + +constexpr int MIN_SUPPORTED_POWER = 4; +constexpr int MAX_SUPPORTED_POWER = 14; +constexpr int MIN_POWER = MIN_SUPPORTED_POWER - 1; +constexpr int MAX_POWER = MAX_SUPPORTED_POWER + 1; + +// Recursively instantiate the function for the limit of "log2_elements", +// i.e. "MAX_POWER" defined above. +template +struct CompileTimeLoopForward { + using ForwardFuncType = typename FunctionWrapper::ForwardType; + static void populate(std::array* arr) { + CompileTimeLoopForward::populate(arr); + (*arr)[log2_elements] = &call_kernel_scaled_aligned_causal_masked_softmax_forward< + output_t, input_t, acc_t, log2_elements>; + } +}; + +template +struct CompileTimeLoopForward { + using ForwardFuncType = typename FunctionWrapper::ForwardType; + static void populate(std::array* arr) { + (*arr)[MIN_POWER] = nullptr; + } +}; + +template +struct CompileTimeLoopBackward { + using BackwardFuncType = typename FunctionWrapper::BackwardType; + static void populate(std::array* arr) { + CompileTimeLoopBackward::populate(arr); + (*arr)[log2_elements] = &call_kernel_scaled_aligned_causal_masked_softmax_backward< + output_t, input_t, acc_t, log2_elements>; + } +}; + +template +struct CompileTimeLoopBackward { + using BackwardFuncType = typename FunctionWrapper::BackwardType; + static void populate(std::array* arr) { + (*arr)[MIN_POWER] = nullptr; + } +}; + +template +void dispatch_scaled_aligned_causal_masked_softmax_forward( + output_t *dst, + const input_t *src, + const input_t scale, + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads, + cudaStream_t stream +) { + NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 16384, "Unsupported shape."); + + if (key_seq_len == 0) { + return; + } + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + + // This value must match the WARP_WIDTH constexpr + // value computed inside scaled_aligned_causal_masked_softmax_warp_forward. + int warp_width = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two + : THREADS_PER_WARP; + + // This value must match the WARP_ROWS constexpr + // value computed inside scaled_aligned_causal_masked_softmax_warp_forward. + int microbatches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = threads_per_block / warp_width; + int microbatches_per_block = warps_per_block * microbatches_per_warp; + int microbatches = batches * attn_heads * query_seq_len; + int blocks = DIVUP(microbatches, microbatches_per_block); + + dim3 block_size(warp_width, warps_per_block); + dim3 grid_size(blocks); + + // create an array of pointers to functions + using ForwardFuncType = typename FunctionWrapper::ForwardType; + static std::array forwardFunctionArray; + static bool is_initialized = false; + if (!is_initialized) { + CompileTimeLoopForward::populate( + &forwardFunctionArray); + is_initialized = true; + } + // Call the corresponding kernel + forwardFunctionArray[log2_elements](grid_size, block_size, 0, stream, dst, src, scale, + microbatches, query_seq_len, key_seq_len); +} + +template +void dispatch_scaled_aligned_causal_masked_softmax_backward( + output_t *grad_input, + const input_t *grad, + const input_t *output, + const acc_t scale, + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads, + cudaStream_t stream +) { + NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 16384, "Unsupported shape."); + + if (key_seq_len == 0) { + return; + } + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + + // This value must match the WARP_WIDTH constexpr + // value computed inside scaled_aligned_causal_masked_softmax_warp_forward. + int warp_width = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two : THREADS_PER_WARP; + + // This value must match the WARP_ROWS constexpr + // value computed inside scaled_aligned_causal_masked_softmax_warp_forward. + int microbatches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = threads_per_block / warp_width; + int microbatches_per_block = warps_per_block * microbatches_per_warp; + int microbatches = batches * attn_heads * query_seq_len; + int blocks = DIVUP(microbatches, microbatches_per_block); + + dim3 block_size(warp_width, warps_per_block); + dim3 grid_size(blocks); + + // create an array of pointers to functions + using BackwardFuncType = typename FunctionWrapper::BackwardType; + static std::array backwardFunctionArray; + static bool is_initialized = false; + if (!is_initialized) { + CompileTimeLoopBackward::populate( + &backwardFunctionArray); + is_initialized = true; + } + // Call the corresponding kernel + backwardFunctionArray[log2_elements](grid_size, block_size, 0, stream, grad_input, grad, + output, scale, microbatches, query_seq_len, key_seq_len); +} + + +void scaled_aligned_causal_masked_softmax_forward( + const Tensor &input, + Tensor *softmax_results, + float scale_factor, + cudaStream_t stream) { + + const int batches = input.data.shape[0]; + const int attn_heads = input.data.shape[1]; + const int query_seq_len = input.data.shape[2]; + const int key_seq_len = input.data.shape[3]; + + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(input.data.dtype, softmax_type, + dispatch_scaled_aligned_causal_masked_softmax_forward( + reinterpret_cast(softmax_results->data.dptr), + reinterpret_cast(input.data.dptr), + scale_factor, + query_seq_len, + key_seq_len, + batches, + attn_heads, + stream);); +} + +void scaled_aligned_causal_masked_softmax_backward( + Tensor output_grads, + const Tensor incoming_grads, + const Tensor softmax_results, + float scale_factor, + cudaStream_t stream) { + + // output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] + const int batches = output_grads.data.shape[0]; + const int attn_heads = output_grads.data.shape[1]; + const int query_seq_len = output_grads.data.shape[2]; + const int key_seq_len = output_grads.data.shape[3]; + + // Softmax Grad + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(output_grads.data.dtype, softmax_type, + dispatch_scaled_aligned_causal_masked_softmax_backward( + reinterpret_cast(output_grads.data.dptr), + reinterpret_cast(incoming_grads.data.dptr), + reinterpret_cast(softmax_results.data.dptr), + scale_factor, + query_seq_len, + key_seq_len, + batches, + attn_heads, + stream);); +} +} // end namespace transformer_engine + + +void nvte_scaled_aligned_causal_masked_softmax_forward( + const NVTETensor input, + NVTETensor softmax_results, + float scale_factor, + cudaStream_t stream +) { + NVTE_API_CALL(nvte_scaled_aligned_causal_masked_softmax_forward); + using namespace transformer_engine; + scaled_aligned_causal_masked_softmax_forward( + *reinterpret_cast(input), + reinterpret_cast(softmax_results), + scale_factor, + stream); +} + + +void nvte_scaled_aligned_causal_masked_softmax_backward( + const NVTETensor incoming_grads, + const NVTETensor softmax_results, + NVTETensor output_grads, + float scale_factor, + cudaStream_t stream +) { + NVTE_API_CALL(nvte_scaled_aligned_causal_masked_softmax_backward); + using namespace transformer_engine; + scaled_aligned_causal_masked_softmax_backward( + *reinterpret_cast(output_grads), + *reinterpret_cast(incoming_grads), + *reinterpret_cast(softmax_results), + scale_factor, + stream); +} diff --git a/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu 
b/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu index 53582f28aa..7a7194878e 100644 --- a/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu +++ b/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu @@ -466,7 +466,7 @@ void dispatch_scaled_softmax_forward( int batches, int attn_heads, cudaStream_t stream) { - NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 4096, "Unsupported shape."); + NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 16384, "Unsupported shape."); if (key_seq_len == 0) { return; } else { @@ -597,6 +597,22 @@ void dispatch_scaled_softmax_forward( batch_count, key_seq_len); break; + case 13: // 8192 + scaled_softmax_warp_forward + <<>>(dst, + src, + scale, + batch_count, + key_seq_len); + break; + case 14: // 16384 + scaled_softmax_warp_forward + <<>>(dst, + src, + scale, + batch_count, + key_seq_len); + break; default: break; } @@ -615,7 +631,7 @@ void dispatch_scaled_masked_softmax_forward( int attn_heads, int pad_batches, cudaStream_t stream) { - NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 4096, "Unsupported shape."); + NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 16384, "Unsupported shape."); if (key_seq_len == 0) { return; } else { @@ -772,6 +788,26 @@ void dispatch_scaled_masked_softmax_forward( key_seq_len, pad_batches); break; + case 13: // 8192 + scaled_masked_softmax_warp_forward + <<>>(dst, + src, + mask, + scale, + batch_count, + key_seq_len, + pad_batches); + break; + case 14: // 16384 + scaled_masked_softmax_warp_forward + <<>>(dst, + src, + mask, + scale, + batch_count, + key_seq_len, + pad_batches); + break; default: break; } @@ -789,7 +825,7 @@ void dispatch_scaled_masked_softmax_backward( int batches, int attn_heads, cudaStream_t stream) { - NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 4096, "Unsupported shape."); + NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 16384, "Unsupported shape."); if (key_seq_len == 0) { return; } else { @@ -833,7 +869,6 @@ void dispatch_scaled_masked_softmax_backward( batch_count, key_seq_len); break; - break; case 2: // 4 scaled_masked_softmax_warp_backward <<>>(grad_input, @@ -843,7 +878,6 @@ void dispatch_scaled_masked_softmax_backward( batch_count, key_seq_len); break; - break; case 3: // 8 scaled_masked_softmax_warp_backward <<>>(grad_input, @@ -853,7 +887,6 @@ void dispatch_scaled_masked_softmax_backward( batch_count, key_seq_len); break; - break; case 4: // 16 scaled_masked_softmax_warp_backward <<>>(grad_input, @@ -863,7 +896,6 @@ void dispatch_scaled_masked_softmax_backward( batch_count, key_seq_len); break; - break; case 5: // 32 scaled_masked_softmax_warp_backward <<>>(grad_input, @@ -873,7 +905,6 @@ void dispatch_scaled_masked_softmax_backward( batch_count, key_seq_len); break; - break; case 6: // 64 scaled_masked_softmax_warp_backward <<>>(grad_input, @@ -883,7 +914,6 @@ void dispatch_scaled_masked_softmax_backward( batch_count, key_seq_len); break; - break; case 7: // 128 scaled_masked_softmax_warp_backward <<>>(grad_input, @@ -893,7 +923,6 @@ void dispatch_scaled_masked_softmax_backward( batch_count, key_seq_len); break; - break; case 8: // 256 scaled_masked_softmax_warp_backward <<>>(grad_input, @@ -903,7 +932,6 @@ void dispatch_scaled_masked_softmax_backward( batch_count, key_seq_len); break; - break; case 9: // 512 scaled_masked_softmax_warp_backward <<>>(grad_input, @@ -913,7 +941,6 @@ void dispatch_scaled_masked_softmax_backward( batch_count, key_seq_len); break; - break; case 10: // 1024 scaled_masked_softmax_warp_backward <<>>(grad_input, @@ -923,7 +950,6 @@ void 
dispatch_scaled_masked_softmax_backward( batch_count, key_seq_len); break; - break; case 11: // 2048 scaled_masked_softmax_warp_backward <<>>(grad_input, @@ -933,7 +959,6 @@ void dispatch_scaled_masked_softmax_backward( batch_count, key_seq_len); break; - break; case 12: // 4096 scaled_masked_softmax_warp_backward <<>>(grad_input, @@ -943,8 +968,24 @@ void dispatch_scaled_masked_softmax_backward( batch_count, key_seq_len); break; + case 13: // 8192 + scaled_masked_softmax_warp_backward + <<>>(grad_input, + grad, + output, + scale, + batch_count, + key_seq_len); + break; + case 14: // 16384 + scaled_masked_softmax_warp_backward + <<>>(grad_input, + grad, + output, + scale, + batch_count, + key_seq_len); break; - default: break; } diff --git a/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu b/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu index ad54013bb5..7b1ae14b1d 100644 --- a/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu +++ b/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu @@ -368,7 +368,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward( int softmax_elements_stride, int attn_batches, cudaStream_t stream) { - NVTE_CHECK(softmax_elements >= 0 && softmax_elements <= 2048, "Unsupported shape."); + NVTE_CHECK(softmax_elements >= 0 && softmax_elements <= 16384, "Unsupported shape."); if (softmax_elements == 0) { return; } else { @@ -506,6 +506,33 @@ void dispatch_scaled_upper_triang_masked_softmax_forward( softmax_elements_stride, softmax_elements); break; + case 12: // 4096 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, + src, + scale, + batch_count, + softmax_elements_stride, + softmax_elements); + break; + case 13: // 8192 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, + src, + scale, + batch_count, + softmax_elements_stride, + softmax_elements); + break; + case 14: // 16384 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, + src, + scale, + batch_count, + softmax_elements_stride, + softmax_elements); + break; default: break; } @@ -522,7 +549,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward( int softmax_elements_stride, int attn_batches, cudaStream_t stream) { - NVTE_CHECK(softmax_elements >= 0 && softmax_elements <= 2048, "Unsupported shape."); + NVTE_CHECK(softmax_elements >= 0 && softmax_elements <= 16384, "Unsupported shape."); if (softmax_elements == 0) { return; } else { @@ -660,6 +687,33 @@ void dispatch_scaled_upper_triang_masked_softmax_backward( softmax_elements_stride, softmax_elements); break; + case 12: // 4096 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, + grad, output, + scale, + batch_count, + softmax_elements_stride, + softmax_elements); + break; + case 13: // 8192 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, + grad, output, + scale, + batch_count, + softmax_elements_stride, + softmax_elements); + break; + case 14: // 16384 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, + grad, output, + scale, + batch_count, + softmax_elements_stride, + softmax_elements); + break; default: break; } diff --git a/transformer_engine/common/include/transformer_engine/softmax.h b/transformer_engine/common/include/transformer_engine/softmax.h index 08abad76b4..50f0a006ee 100644 --- a/transformer_engine/common/include/transformer_engine/softmax.h +++ b/transformer_engine/common/include/transformer_engine/softmax.h @@ -125,6 +125,41 @@ 
void nvte_scaled_upper_triang_masked_softmax_backward( ); +/*! \brief Compute scaled softmax activation using an implicit 2D mask aligned to the bottom right corner of the input matrix. + * + * \param[in] input Input tensor for softmax. + * \param[out] softmax_results Output tensor. + * \param[in] scale_factor Scalar for the input tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_scaled_aligned_causal_masked_softmax_forward( + const NVTETensor input, + NVTETensor softmax_results, + float scale_factor, + cudaStream_t stream +); + + +/*! \brief Compute the backward pass of the scaled softmax activation using an implicit 2D mask aligned to the bottom right corner of the input matrix. + * + * - `incoming_grads` is the input tensor containing the gradients received from the following layer. + * - `softmax_results` is the output tensor of the corresponding forward softmax operation. + * - `output_grads` is the output tensor containing the computed gradients. + * + * \param[in] incoming_grads Input gradient tensor for backward. + * \param[in] softmax_results Output tensor of softmax forward. + * \param[out] output_grads Output tensor. + * \param[in] scale_factor Scalar for the output tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_scaled_aligned_causal_masked_softmax_backward( + const NVTETensor incoming_grads, + const NVTETensor softmax_results, + NVTETensor output_grads, + float scale_factor, + cudaStream_t stream +); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 7233614a55..98b1db35dc 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -521,6 +521,17 @@ at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_, float scale_factor ); + +at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, + float scale_factor +); + + +at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads_, + at::Tensor softmax_results_, + float scale_factor +); + /*************************************************************************************************** * Rotary positional embedding **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index c0c6cb4e2b..937ee11eed 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -23,6 +23,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("scaled_upper_triang_masked_softmax_backward", &scaled_upper_triang_masked_softmax_backward, "Scaled Upper-Triangular Masked Softmax BWD"); + m.def("scaled_aligned_causal_masked_softmax_forward", + &scaled_aligned_causal_masked_softmax_forward, + "Scaled Bottom-Right Corner Aligned Masked Softmax FWD"); + m.def("scaled_aligned_causal_masked_softmax_backward", + &scaled_aligned_causal_masked_softmax_backward, + "Scaled Bottom-Right Corner Aligned Masked Softmax BWD"); // Other granular functions m.def("layernorm_fwd_fp8", &layernorm_fwd_fp8, "LN FWD FP8"); diff --git a/transformer_engine/pytorch/csrc/extensions/softmax.cu b/transformer_engine/pytorch/csrc/extensions/softmax.cu index 4830ca8e13..6bae5f6b46 100644 --- a/transformer_engine/pytorch/csrc/extensions/softmax.cu +++ 
b/transformer_engine/pytorch/csrc/extensions/softmax.cu @@ -20,7 +20,7 @@ at::Tensor scaled_softmax_forward(at::Tensor input, const int query_seq_len = input.size(2); const int key_seq_len = input.size(3); - AT_ASSERTM(key_seq_len <= 4096, "Key sequence length must be 4096 or less"); + AT_ASSERTM(key_seq_len <= 16384, "Key sequence length must be 16384 or less"); AT_ASSERTM(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8"); AT_ASSERTM(query_seq_len > 1, "Query sequence length must be greater than 1"); @@ -92,7 +92,7 @@ at::Tensor scaled_masked_softmax_forward(at::Tensor input, const int query_seq_len = input.size(2); const int key_seq_len = input.size(3); - AT_ASSERTM(key_seq_len <= 4096, "Key sequence length must be 4096 or less"); + AT_ASSERTM(key_seq_len <= 16384, "Key sequence length must be 16384 or less"); AT_ASSERTM(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8"); AT_ASSERTM(query_seq_len > 1, "Query sequence length must be greater than 1"); TORCH_CHECK(pad_batches == 1 || pad_batches == batches); @@ -160,7 +160,7 @@ at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, const int attn_batches = input.size(0); const int seq_len = input.size(1); - AT_ASSERTM(seq_len <= 2048, "Sequence length must be 2048 or less"); + AT_ASSERTM(seq_len <= 16384, "Sequence length must be 16384 or less"); // Output auto act_options = input.options().requires_grad(false); @@ -212,3 +212,75 @@ at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_, return output_grads; } + + +at::Tensor scaled_aligned_causal_masked_softmax_forward( + at::Tensor input, + float scale_factor +) { + using namespace transformer_engine; + AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + const int batches = input.size(0); + const int attn_heads = input.size(1); + const int query_seq_len = input.size(2); + const int key_seq_len = input.size(3); + + AT_ASSERTM(key_seq_len <= 16384, "Key sequence length must be 16384 or less"); + AT_ASSERTM(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8"); + AT_ASSERTM(query_seq_len >= 1, "Query sequence length must be greater or equal to 1"); + + // Output + auto act_options = input.options().requires_grad(false); + auto softmax_results = + torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); + + auto input_cu = makeTransformerEngineTensor(input); + auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); + + nvte_scaled_aligned_causal_masked_softmax_forward( + input_cu.data(), + softmax_results_cu.data(), + scale_factor, + at::cuda::getCurrentCUDAStream()); + + return softmax_results; +} + + +at::Tensor scaled_aligned_causal_masked_softmax_backward( + at::Tensor output_grad_, + at::Tensor softmax_results_, + float scale_factor +) { + using namespace transformer_engine; + + auto output_grads = output_grad_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); + + AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor"); + AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor"); + + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + 
"Only fp16 and bf16 are supported"); + + auto output_grads_cu = makeTransformerEngineTensor(output_grads); + auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); + + // Produce gradients in place. + nvte_scaled_aligned_causal_masked_softmax_backward( + output_grads_cu.data(), + softmax_results_cu.data(), + output_grads_cu.data(), + scale_factor, + at::cuda::getCurrentCUDAStream()); + + return output_grads; +} diff --git a/transformer_engine/pytorch/softmax.py b/transformer_engine/pytorch/softmax.py index 8d0c95e29b..e25c7b3268 100644 --- a/transformer_engine/pytorch/softmax.py +++ b/transformer_engine/pytorch/softmax.py @@ -113,6 +113,70 @@ def triangular_mask(): return out +class ScaledAlignedCausalMaskedSoftmax(torch.autograd.Function): + """ + Fused operation which performs following three operations in sequence + 1. Scale the tensor. + 2. Apply causal mask aligned to the bottom right corner of the input matrix + 3. Perform softmax. + """ + + @staticmethod + def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor: + """ScaledAlignedCausalMaskedSoftmax fwd""" + scale_t = torch.tensor([scale]) + softmax_results = tex.scaled_aligned_causal_masked_softmax_forward( + inputs, scale_t[0] + ) + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward( + ctx, output_grads: torch.Tensor + ) -> Tuple[Union[torch.Tensor, None], ...]: + """ScaledAlignedCausalMaskedSoftmax bwd""" + softmax_results, scale_t = ctx.saved_tensors + input_grads = tex.scaled_aligned_causal_masked_softmax_backward( + output_grads, softmax_results, scale_t[0] + ) + + return input_grads, None + + @staticmethod + @fp32_compute + def symbolic(g: torch.Graph, inputs: torch._C.Value, scale: float) -> torch._C.Value: + """ScaledAlignedCausalMaskedSoftmax symbolic method""" + def triangular_mask(): + dtype = _type_utils.JitScalarType.INT64 + ones = torch.onnx.symbolic_opset9.ones_like(g, inputs, dtype) + k = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) + + # rectangular causal mask aligned to the bottom right corner of Attention matrix + rows = inputs.size(dim=-2) + cols = inputs.size(dim=-1) + diag_shift = cols - rows + 1 + + mask = g.op("Trilu", ones, k, upper_i=diag_shift) + mask = g.op("Cast", mask, to_i=_C_onnx.TensorProtoDataType.BOOL) + return mask + + # Captures the logic of function scaled_aligned_masked_softmax_warp_forward + mask = triangular_mask() + one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) + inv_mask = g.op("Sub", one, mask) + + neg_tenK = g.op("Constant", value_t=torch.tensor(-10000., dtype=torch.float16)) + softmax_mask = g.op("Mul", mask, neg_tenK) + + scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16)) + scaled = g.op("Mul", inputs, scale_input) + masked_scaled = g.op("Mul", inv_mask, scaled) + masked = g.op("Add", masked_scaled, softmax_mask) + out = g.op("Softmax", masked) + return out + + class ScaledMaskedSoftmax(torch.autograd.Function): """ Fused operation which performs following three operations in sequence @@ -272,13 +336,13 @@ def is_kernel_available(self, mask: torch.Tensor, b: int, np: int, sq: int, sk: if ( # pylint: disable=too-many-boolean-expressions self.scaled_masked_softmax_fusion # user wants to fuse and self.input_in_float16 # input must be fp16 - and 16 < sk <= 4096 # sk must be 16 ~ 2048 + and 16 <= sk <= 16384 # sk must be 16 ~ 16384 and sk % 8 == 0 # sk must be divisor of 8 and sq % 4 == 0 # sq must be divisor of 4 and attn_batches % 4 == 0 
# np * b must be divisor of 4 and self.attn_mask_type != "arbitrary" # Custom masks not supported ): - if 0 <= sk <= 4096: + if 0 <= sk <= 16384: batch_per_block = self.get_batch_per_block(int(sk)) if self.attn_mask_type == "causal":
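A minimal PyTorch sketch (not a definitive usage guide) of how the new fused op is called and what the bottom-right-aligned causal mask computes. It assumes a CUDA device, an fp16 input shaped [batches, attn_heads, query_seq_len, key_seq_len] with key_seq_len a multiple of 8, and the module path transformer_engine/pytorch/softmax.py from the diff above; tolerances are illustrative.

import torch
from transformer_engine.pytorch.softmax import ScaledAlignedCausalMaskedSoftmax

# [batches, attn_heads, query_seq_len, key_seq_len]
scores = torch.randn(2, 4, 127, 128, dtype=torch.float16, device="cuda")
scale = 0.125

probs = ScaledAlignedCausalMaskedSoftmax.apply(scores, scale)

# Plain-PyTorch reference: the mask is aligned to the bottom-right corner of the
# attention matrix, i.e. row i may attend to columns j <= i + (cols - rows).
rows, cols = scores.shape[-2], scores.shape[-1]
mask = torch.ones(rows, cols, device=scores.device).tril(diagonal=cols - rows).bool()
masked = torch.where(mask, scale * scores.float(),
                     torch.full_like(scores, -10000.0, dtype=torch.float32))
ref = torch.softmax(masked, dim=-1) * mask  # fused kernel writes exact zeros at masked positions

torch.testing.assert_close(probs.float(), ref, atol=1e-3, rtol=1e-3)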
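The backward kernel (and the compute_single_head_bwd reference in the new test) computes dX = scale * (Y*dY - Y * sum_j(Y*dY)) over the unmasked region. A small sketch, under the same assumptions as above, of checking autograd through the fused op against that formula; the reference is built before calling backward because the extension produces gradients in place in the incoming-gradient buffer.

import torch
from transformer_engine.pytorch.softmax import ScaledAlignedCausalMaskedSoftmax

scale = 0.5
x = torch.randn(1, 2, 64, 64, dtype=torch.float16, device="cuda", requires_grad=True)
y = ScaledAlignedCausalMaskedSoftmax.apply(x, scale)
dy = torch.randn_like(y)

# Build the reference first: the fused backward writes its result into the
# incoming-gradient buffer ("Produce gradients in place" above), which may
# alias dy once autograd hands it to the Function.
yf, dyf = y.float(), dy.float()
prod = yf * dyf                      # masked positions of y are exactly zero
ref_dx = scale * (prod - yf * prod.sum(dim=-1, keepdim=True))

y.backward(dy)
torch.testing.assert_close(x.grad.float(), ref_dx, atol=2e-3, rtol=1e-2)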
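To see how a given key_seq_len maps onto the warp layout chosen by the kernel and its dispatchers, here is a small Python mirror of the constexpr logic (WARP_WIDTH, WARP_ITERATIONS, WARP_ROWS, ELEMENTS_PER_LDG_STG) and the block geometry. The helper name is made up; THREADS_PER_WARP is assumed to be 32 and threads_per_block follows the 128-thread choice in the dispatchers.

def causal_softmax_launch_config(key_seq_len, threads_per_warp=32, threads_per_block=128):
    """Mirror of the constexpr/launch logic in scaled_aligned_causal_masked_softmax.cu."""
    log2_elements = (key_seq_len - 1).bit_length()    # log2_ceil(key_seq_len)
    next_pow2 = 1 << log2_elements
    warp_width = min(next_pow2, threads_per_warp)     # WARP_WIDTH
    warp_iterations = next_pow2 // warp_width         # WARP_ITERATIONS
    warp_rows = 2 if next_pow2 <= 128 else 1          # WARP_ROWS / microbatches_per_warp
    elements_per_ldg_stg = 1 if warp_iterations < 4 else 4
    warps_per_block = threads_per_block // warp_width
    microbatches_per_block = warps_per_block * warp_rows
    return dict(warp_width=warp_width, warp_iterations=warp_iterations,
                warp_rows=warp_rows, elements_per_ldg_stg=elements_per_ldg_stg,
                microbatches_per_block=microbatches_per_block)

# e.g. key_seq_len = 1024 -> warp_width 32, 32 iterations per row, 1 row per warp,
# 4-element vector loads, 4 rows (microbatches) handled per 128-thread block
print(causal_softmax_launch_config(1024))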