From 44ad9947152865f873df524abdcaa93f851e09ac Mon Sep 17 00:00:00 2001
From: chen de <72677659+cccddd77@users.noreply.github.com>
Date: Thu, 18 Apr 2024 15:30:36 +0800
Subject: [PATCH] Flash attention v2 forward (#10484)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Integrate the flash attention v2 forward operator.

---------

Co-authored-by: oneflow-ci-bot
---
 cmake/oneflow.cmake                           |   5 +
 cmake/third_party.cmake                       |   8 +
 cmake/third_party/flash_attention.cmake       |  39 +++
 oneflow/core/functional/functional_api.yaml   |   4 +
 oneflow/core/functional/impl/nn_functor.cpp   | 119 ++++++++
 oneflow/ir/include/OneFlow/OneFlowUserOps.td  |  26 ++
 .../scaled_dot_product_attention_kernel.cu    | 258 ++++++++++++++++++
 .../scaled_dot_product_attention_kernel.h     |  39 +++
 .../scaled_dot_product_attention_util.h       | 232 ++++++++++++++++
 .../scaled_dot_product_flash_attention_op.cpp | 118 ++++++++
 .../test_scaled_dot_product_attention.py      | 105 +++++++
 11 files changed, 953 insertions(+)
 create mode 100644 cmake/third_party/flash_attention.cmake
 create mode 100644 oneflow/user/kernels/scaled_dot_product_attention_kernel.cu
 create mode 100644 oneflow/user/kernels/scaled_dot_product_attention_kernel.h
 create mode 100644 oneflow/user/kernels/scaled_dot_product_attention_util.h
 create mode 100644 oneflow/user/ops/scaled_dot_product_flash_attention_op.cpp
 create mode 100644 python/oneflow/test/modules/test_scaled_dot_product_attention.py

diff --git a/cmake/oneflow.cmake b/cmake/oneflow.cmake
index 1beb41b1776..b37535367e1 100644
--- a/cmake/oneflow.cmake
+++ b/cmake/oneflow.cmake
@@ -632,6 +632,11 @@ if(BUILD_CPP_API)
   checkdirandappendslash(DIR ${TRT_FLASH_ATTENTION_LIBRARY_DIR} OUTPUT
                          TRT_FLASH_ATTENTION_LIBRARY_DIR_APPENDED)
   list(APPEND LIBONEFLOW_THIRD_PARTY_DIRS ${TRT_FLASH_ATTENTION_LIBRARY_DIR_APPENDED})
+  if(CUDA_VERSION VERSION_GREATER_EQUAL "11.7")
+    checkdirandappendslash(DIR ${FLASH_ATTENTION_LIBRARY_DIR} OUTPUT
+                           FLASH_ATTENTION_LIBRARY_DIR_APPENDED)
+    list(APPEND LIBONEFLOW_THIRD_PARTY_DIRS ${FLASH_ATTENTION_LIBRARY_DIR_APPENDED})
+  endif()
   if(WITH_CUTLASS)
     checkdirandappendslash(DIR ${CUTLASS_LIBRARY_DIR} OUTPUT CUTLASS_LIBRARY_DIR_APPENDED)
     list(APPEND LIBONEFLOW_THIRD_PARTY_DIRS ${CUTLASS_LIBRARY_DIR_APPENDED})
diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake
index c7ac2893e6e..43d731af056 100644
--- a/cmake/third_party.cmake
+++ b/cmake/third_party.cmake
@@ -145,6 +145,9 @@ if(BUILD_CUDA)
   include(nccl)
   include(cutlass)
   include(trt_flash_attention)
+  if(CUDA_VERSION VERSION_GREATER_EQUAL "11.7")
+    include(flash_attention)
+  endif()
 
   list(APPEND oneflow_third_party_libs ${NCCL_LIBRARIES})
   list(APPEND oneflow_third_party_libs ${CUDNN_LIBRARIES})
@@ -164,6 +167,11 @@ if(BUILD_CUDA)
   list(APPEND oneflow_third_party_dependencies trt_flash_attention)
   list(APPEND oneflow_third_party_libs ${TRT_FLASH_ATTENTION_LIBRARIES})
   list(APPEND ONEFLOW_THIRD_PARTY_INCLUDE_DIRS ${TRT_FLASH_ATTENTION_INCLUDE_DIR})
+  if(CUDA_VERSION VERSION_GREATER_EQUAL "11.7")
+    list(APPEND oneflow_third_party_dependencies flash_attention)
+    list(APPEND oneflow_third_party_libs ${FLASH_ATTENTION_LIBRARIES})
+    list(APPEND ONEFLOW_THIRD_PARTY_INCLUDE_DIRS ${FLASH_ATTENTION_INCLUDE_DIR})
+  endif()
 endif()
 
 if(BUILD_RDMA)
diff --git a/cmake/third_party/flash_attention.cmake b/cmake/third_party/flash_attention.cmake
new file mode 100644
index 00000000000..6958afadaef
--- /dev/null
+++ b/cmake/third_party/flash_attention.cmake
@@ -0,0 +1,39 @@
+include(ExternalProject)
+
+find_package(Threads)
+
+# NOTE: A git version of 1.6.5
or later is required if this download method is used. +find_package(Git QUIET REQUIRED) + +set(FLASH_ATTENTION_PROJECT flash_attention) + +set(FLASH_ATTENTION_URL https://github.com/Oneflow-Inc/flash-attention-v2.git) +set(FLASH_ATTENTION_TAG eed2e82b880e06237af3e50ceac4cf6728b15645) + +set(FLASH_ATTENTION_INSTALL_DIR ${THIRD_PARTY_DIR}/flash_attention) +set(FLASH_ATTENTION_INCLUDE_DIR ${FLASH_ATTENTION_INSTALL_DIR}/include CACHE PATH "" FORCE) +set(FLASH_ATTENTION_LIBRARY_DIR ${FLASH_ATTENTION_INSTALL_DIR}/lib CACHE PATH "" FORCE) +set(FLASH_ATTENTION_LIBRARIES ${FLASH_ATTENTION_LIBRARY_DIR}/libflash_attention.so) + +if(THIRD_PARTY) + ExternalProject_Add( + ${FLASH_ATTENTION_PROJECT} + PREFIX flash_attention + GIT_REPOSITORY ${FLASH_ATTENTION_URL} + GIT_TAG ${FLASH_ATTENTION_TAG} + UPDATE_COMMAND "" + BUILD_BYPRODUCTS ${FLASH_ATTENTION_LIBRARIES} + CMAKE_ARGS -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE} + -DCMAKE_CXX_FLAGS:STRING=${CMAKE_CXX_FLAGS} + -DCMAKE_CXX_FLAGS_DEBUG:STRING=${CMAKE_CXX_FLAGS_DEBUG} + -DCMAKE_CXX_FLAGS_RELEASE:STRING=${CMAKE_CXX_FLAGS_RELEASE} + -DCMAKE_CUDA_ARCHITECTURES:STRING=${CMAKE_CUDA_ARCHITECTURES} + CMAKE_CACHE_ARGS + -DCMAKE_CUDA_COMPILER:STRING=${CUDAToolkit_NVCC_EXECUTABLE} + -DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER} + -DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER} + -DCMAKE_INSTALL_PREFIX:PATH=${FLASH_ATTENTION_INSTALL_DIR} + -DCMAKE_INSTALL_LIBDIR:PATH=${FLASH_ATTENTION_LIBRARY_DIR} + -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE} + -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE}) +endif(THIRD_PARTY) diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index 3323db0e93c..5829165e708 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -2684,6 +2684,10 @@ signature: "TensorTuple (Tensor x, Tensor bias, Tensor mask, *, Float fill_value=0.0, Float scale=1.0, Float p=0.5, Bool training=True, Generator generator=None) => FusedBiasAddScaleMaskSoftmaxDropout" bind_python: True +- name: "scaled_dot_product_attention" + signature: "Tensor (Tensor query, Tensor key, Tensor value, Tensor attn_mask=None, Float dropout_p=0.0, Bool is_causal=False, Float scale=None, Int64 seed=0) => ScaledDotProductFlashAttention" + bind_python: True + - name: "fused_multi_head_attention_inference" signature: "Tensor (Tensor query, Tensor key, Tensor value, Int64 num_heads, Bool causal=False, Int64 query_hidden_slice_start=0, Int64 query_hidden_slice_end=-1, Int64 key_hidden_slice_start=0, Int64 key_hidden_slice_end=-1, Int64 value_hidden_slice_start=0, Int64 value_hidden_slice_end=-1, Tensor attn_bias=None, Int64 causal_diagonal_offset=0) => FusedMultiHeadAttentionInference" bind_python: True diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index ef4f2f92070..648e0832e98 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -28,6 +28,7 @@ limitations under the License. 
#include "oneflow/user/kernels/dropout_kernel.h" #include "oneflow/user/kernels/distributions/common.h" #include "oneflow/user/kernels/random_seed_util.h" +#include "oneflow/user/kernels/scaled_dot_product_attention_kernel.h" #include "oneflow/core/common/container_util.h" #include "fmt/core.h" @@ -5420,6 +5421,123 @@ class NonContiguousBinaryOpGradFunctor { std::shared_ptr op_; }; +namespace { + +template +Maybe pad_last_dim(const std::shared_ptr& input) { + auto num_dims = input->shape()->NumAxes(); + auto last_dim_size = input->shape()->At(num_dims - 1); + if (last_dim_size % alignment_size == 0) { return input; } + auto pad_count = alignment_size - (last_dim_size % alignment_size); + + return JUST(functional::Pad(input, {0, pad_count}, "constant", Scalar(0))); + ; +} + +} // namespace + +class ScaledDotProductFlashAttentionFunctor { + public: + ScaledDotProductFlashAttentionFunctor() { +#if CUDA_VERSION >= 11070 + op_ = CHECK_JUST(one::OpBuilder("scaled_dot_product_flash_attention") + .Input("query") + .Input("key") + .Input("value") + .Output("out") + .Output("softmax_lse") + .Output("rng_state") + .Build()); +#endif // CUDA_VERSION >= 11070 + } + + Maybe operator()(const std::shared_ptr& query, + const std::shared_ptr& key, + const std::shared_ptr& value, + const Optional& attn_mask, const float& dropout_p, + const bool& is_causal, const Optional& scale, + const int64_t& seed = 0) const { +#if CUDA_VERSION >= 11070 + const auto og_size = query->shape()->At(3); + const auto batch_size = query->shape()->At(0); + const auto seqlen_q = query->shape()->At(2); + const auto num_heads = query->shape()->At(1); + const auto num_heads_k = key->shape()->At(1); + const auto max_seqlen_batch_k = key->shape()->At(2); + const auto max_seqlen_batch_v = value->shape()->At(2); + + CHECK_EQ_OR_RETURN(batch_size, key->shape()->At(0)) + << " key has different batch size from query."; + CHECK_EQ_OR_RETURN(batch_size, value->shape()->At(0)) + << " value has different batch size from query."; + CHECK_EQ_OR_RETURN(num_heads_k, value->shape()->At(1)) + << " value has different num_heads from key."; + CHECK_EQ_OR_RETURN(max_seqlen_batch_k, max_seqlen_batch_v) + << "value has different seqlen from key."; + CHECK_EQ_OR_RETURN(og_size, key->shape()->At(3)) << " key has different head dims from query."; + CHECK_EQ_OR_RETURN(og_size, value->shape()->At(3)) + << " value has different head dims from query."; + + // Query (Batch x Num_heads x Q_seq_len x Dim_per_head) + // Key (Batch x Num_heads x KV_seq_len x Dim_per_head) + // Value (Batch x Num_heads x KV_seq_len x Dim_per_head) + std::shared_ptr q_padded, k_padded, v_padded; + bool padded = og_size % 8; + if (padded) { + q_padded = JUST(pad_last_dim<8>(query)); + k_padded = JUST(pad_last_dim<8>(key)); + v_padded = JUST(pad_last_dim<8>(value)); + } else { + q_padded = query; + k_padded = key; + v_padded = value; + } + + auto q_ = JUST(functional::Transpose(q_padded, {0, 2, 1, 3})); + auto k_ = JUST(functional::Transpose(k_padded, {0, 2, 1, 3})); + auto v_ = JUST(functional::Transpose(v_padded, {0, 2, 1, 3})); + // Query -> Query(Batch x Q_seq_len x Num_heads x Dim_per_head) + // Key -> Key (Batch x KV_seq_len x Num_heads x Dim_per_head) + // Value -> Value(Batch x KV_seq_len x Num_heads x Dim_per_head) + + const auto& scale_ = + scale.has_value() ? 
scale : (1.0f / std::sqrt(static_cast(query->shape()->At(3)))); + + auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("p_dropout", "softmax_scale", "is_causal", + "window_size_left", "window_size_right", "seed"); + attrs.SetAllAttrs(dropout_p, scale_, is_causal, -1, -1, seed); + + auto gen = JUST(one::DefaultAutoGenerator()); + gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), query)); + const auto& state = std::make_shared(gen); + OpExprInterpContext ctx(attrs, state); + + std::shared_ptr output_ = + JUST(OpInterpUtil::Dispatch(*op_, {q_, k_, v_}, ctx)); + + auto output_padded = JUST(functional::Transpose(output_, {0, 2, 1, 3})); + + std::shared_ptr output; + if (padded) { + output = + JUST(functional::Slice(output_padded, {0, 0, 0, 0}, + {batch_size, num_heads, seqlen_q, og_size}, {1, 1, 1, 1}, false)); + } else { + output = output_padded; + } + + return output; +#endif // CUDA_VERSION >= 11070 + + UNIMPLEMENTED_THEN_RETURN() << "only support CUDA_VERSION >= 11070."; + } + + private: +#if CUDA_VERSION >= 11070 + std::shared_ptr op_; +#endif // CUDA_VERSION >= 11070 +}; + } // namespace impl ONEFLOW_FUNCTION_LIBRARY(m) { @@ -5557,6 +5675,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("NonContiguousBinaryOpGrad"); m.add_functor("MultiTensorYoloV5WeightUpdate"); m.add_functor("FusedClipGrad"); + m.add_functor("ScaledDotProductFlashAttention"); } } // namespace functional diff --git a/oneflow/ir/include/OneFlow/OneFlowUserOps.td b/oneflow/ir/include/OneFlow/OneFlowUserOps.td index 05f5ea56bc3..eb7c6da6e58 100644 --- a/oneflow/ir/include/OneFlow/OneFlowUserOps.td +++ b/oneflow/ir/include/OneFlow/OneFlowUserOps.td @@ -2877,6 +2877,32 @@ def OneFlow_FusedCrossFeatureInteractionV2GradOp : OneFlow_BaseOp<"fused_cross_f let has_data_type_infer_fn = 1; } +def OneFlow_ScaledDotProductFlashAttentionOp : OneFlow_BaseOp<"scaled_dot_product_flash_attention", [NoMemoryEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$query, + OneFlow_Tensor:$key, + OneFlow_Tensor:$value, + Optional:$alibi_slopes_ + ); + let output = (outs + OneFlow_Tensor:$out, + OneFlow_Tensor:$softmax_lse, + OneFlow_Tensor:$rng_state + ); + let attrs = (ins + DefaultValuedAttr:$p_dropout, + DefaultValuedAttr:$softmax_scale, + DefaultValuedAttr:$is_causal, + SI32Attr:$window_size_left, + SI32Attr:$window_size_right, + DefaultValuedAttr:$seed + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + def OneFlow_FusedMultiHeadAttentionInferenceOp : OneFlow_BaseOp<"fused_multi_head_attention_inference", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$query, diff --git a/oneflow/user/kernels/scaled_dot_product_attention_kernel.cu b/oneflow/user/kernels/scaled_dot_product_attention_kernel.cu new file mode 100644 index 00000000000..e032bbd0150 --- /dev/null +++ b/oneflow/user/kernels/scaled_dot_product_attention_kernel.cu @@ -0,0 +1,258 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include +#include "oneflow/core/common/container_util.h" +#include "oneflow/core/common/data_type.h" +#include "oneflow/core/common/data_type.pb.h" +#include "oneflow/core/common/just.h" +#include "oneflow/core/common/maybe.h" +#include "oneflow/core/common/shape_view.h" +#include "oneflow/core/common/throw.h" +#include "oneflow/core/common/util.h" +#include "oneflow/core/framework/op_kernel.h" +#include "oneflow/core/framework/user_op_tensor.h" + +#if CUDA_VERSION >= 11070 + +#ifdef WITH_CUTLASS + +#include "oneflow/core/framework/framework.h" +#include "oneflow/core/ep/cuda/cuda_stream.h" +#include "oneflow/core/cuda/elementwise.cuh" +#include "oneflow/core/ep/include/primitive/permute.h" +#include "cutlass/arch/mma.h" +#include "cutlass/gemm/warp/mma.h" +#include "oneflow/core/kernel/cuda_graph_support.h" +#include "oneflow/user/kernels/random_seed_util.h" +#include "oneflow/user/kernels/scaled_dot_product_attention_kernel.h" +// from flash_attention +#include "oneflow/user/kernels/scaled_dot_product_attention_util.h" + +namespace oneflow { + +namespace user_op { + +namespace { + +static size_t InferTmpBufferSizeForFlashAttentionKernel(InferContext* ctx) { + const float p_dropout = ctx->Attr("p_dropout"); + const auto& q_shape = ctx->InputTensorDesc("query", 0).shape(); + const auto& k_shape = ctx->InputTensorDesc("key", 0).shape(); + const int batch_size = q_shape.At(0); + const int seqlen_q = q_shape.At(1); + const int num_heads = q_shape.At(2); + const int head_size_og = q_shape.At(3); + const int seqlen_k = k_shape.At(1); + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size = round_multiple(head_size_og, 8); + const int head_size_rounded = round_multiple(head_size, 32); + + int dev; + { + cudaError_t err = cudaGetDevice(&dev); + if (err != cudaSuccess) { return err; } + } + int sm_count; + { + cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev); + if (err != cudaSuccess) { return err; } + } + const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); + const int num_n_blocks = (seqlen_k + block_n - 1) / block_n; + const int num_m_blocks = (seqlen_q + 64 - 1) / 64; + size_t buffer_size = 0; + // for splitKV and splitKV is not implemented for dropout. 
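+  // The temporary buffer backs the split-KV path and must match the pointer layout set up by
+  // set_params_splitkv() in scaled_dot_product_attention_util.h:
+  //   softmax_lse_accum : num_splits * batch_size * num_heads * seqlen_q floats
+  //   out_accum         : num_splits * batch_size * num_heads * seqlen_q * head_size_rounded floats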
+ if (p_dropout == 0.0f) { + int num_splits = + num_splits_heuristic(batch_size * num_heads * num_m_blocks, sm_count, num_n_blocks, 128); + buffer_size += GetCudaAlignedSize(num_splits * batch_size * num_heads * seqlen_q + * GetSizeOfDataType(DataType::kFloat)); + buffer_size += GetCudaAlignedSize(num_splits * batch_size * num_heads * seqlen_q + * head_size_rounded * GetSizeOfDataType(DataType::kFloat)); + } + return buffer_size; +} + +class ScaledDotProductFlashAttentionKernel final : public user_op::OpKernel, + public user_op::CudaGraphSupport { + public: + ScaledDotProductFlashAttentionKernel() = default; + ~ScaledDotProductFlashAttentionKernel() override = default; + + std::shared_ptr CreateOpKernelState( + user_op::KernelInitContext* ctx) const override { + const auto& generator = CHECK_JUST(one::MakeGenerator(DeviceType::kCUDA)); + generator->set_current_seed( + CHECK_JUST(GetOpKernelRandomSeedInCurrentRank(ctx, ctx->Attr("seed")))); + return std::make_shared(generator); + } + + private: + using user_op::OpKernel::Compute; + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { + const Tensor* query = ctx->Tensor4ArgNameAndIndex("query", 0); + const Tensor* key = ctx->Tensor4ArgNameAndIndex("key", 0); + const Tensor* value = ctx->Tensor4ArgNameAndIndex("value", 0); + const Tensor* alibi_slopes_ = nullptr; + if (ctx->has_input("alibi_slopes_", 0)) { + // default to null, it will never get input for current flash-attn version. + alibi_slopes_ = ctx->Tensor4ArgNameAndIndex("alibi_slopes_", 0); + CHECK(!alibi_slopes_) << "alibi_slopes should not have value"; + } + + const float p_dropout = ctx->Attr("p_dropout"); + const float softmax_scale = ctx->Attr("softmax_scale"); + bool is_causal = ctx->Attr("is_causal"); + int window_size_left = ctx->Attr("window_size_left"); + int window_size_right = ctx->Attr("window_size_right"); + uint64_t seed = ctx->Attr("seed"); + + Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); + Tensor* softmax_lse = ctx->Tensor4ArgNameAndIndex("softmax_lse", 0); + Tensor* rng_state = ctx->Tensor4ArgNameAndIndex("rng_state", 0); + Tensor* tmp = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); + void* tmp_ptr = tmp->mut_dptr(); + + auto* cuda_device = dynamic_cast(ctx->stream()->device()); + auto dprops = cuda_device->properties(); + auto* cuda_stream = ctx->stream()->As(); + + const int arch = cuda_stream->cuda_arch() / 10; + const bool is_supported_arch = (arch == 80 || arch == 86 || arch == 89 || arch == 90); + CHECK(is_supported_arch) << "only supports CUDA Arch 80, 86, 89 and 90."; + + const DataType data_type = query->data_type(); + const bool is_supported_dtype = + (data_type == DataType::kFloat16 || data_type == DataType::kBFloat16); + CHECK(is_supported_dtype); + CHECK_EQ(key->data_type(), data_type); + CHECK_EQ(value->data_type(), data_type); + CHECK_EQ(out->data_type(), data_type); + + CHECK_EQ(softmax_lse->data_type(), DataType::kFloat); + + // check contiguous last dimension. + CHECK_EQ(CHECK_JUST(VectorAt(query->stride(), 3)), 1); + CHECK_EQ(CHECK_JUST(VectorAt(key->stride(), 3)), 1); + CHECK_EQ(CHECK_JUST(VectorAt(value->stride(), 3)), 1); + + const int batch_size = query->shape_view().At(0); + const int seqlen_q = query->shape_view().At(1); + const int num_heads = query->shape_view().At(2); + const int head_size_og = query->shape_view().At(3); + const int seqlen_k = key->shape_view().At(1); + const int num_heads_k = key->shape_view().At(2); + + // check tensor shape. 
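+    // Layouts expected here (the functor has already transposed to batch-seq-head-dim order):
+    //   query:       (batch_size, seqlen_q, num_heads,   head_size_og)
+    //   key/value:   (batch_size, seqlen_k, num_heads_k, head_size_og)
+    //   out:         (batch_size, seqlen_q, num_heads,   head_size_og)
+    //   softmax_lse: (batch_size, num_heads, seqlen_q), always float32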
+    CHECK_EQ(query->shape_view().At(0), batch_size);
+    CHECK_EQ(query->shape_view().At(1), seqlen_q);
+    CHECK_EQ(query->shape_view().At(2), num_heads);
+    CHECK_EQ(query->shape_view().At(3), head_size_og);
+    CHECK_EQ(key->shape_view().At(0), batch_size);
+    CHECK_EQ(key->shape_view().At(1), seqlen_k);
+    CHECK_EQ(key->shape_view().At(2), num_heads_k);
+    CHECK_EQ(key->shape_view().At(3), head_size_og);
+    CHECK_EQ(value->shape_view().At(0), batch_size);
+    CHECK_EQ(value->shape_view().At(1), seqlen_k);
+    CHECK_EQ(value->shape_view().At(2), num_heads_k);
+    CHECK_EQ(value->shape_view().At(3), head_size_og);
+    CHECK_EQ(out->shape_view().At(0), batch_size);
+    CHECK_EQ(out->shape_view().At(1), seqlen_q);
+    CHECK_EQ(out->shape_view().At(2), num_heads);
+    CHECK_EQ(out->shape_view().At(3), head_size_og);
+    CHECK_EQ(softmax_lse->shape_view().At(0), batch_size);
+    CHECK_EQ(softmax_lse->shape_view().At(1), num_heads);
+    CHECK_EQ(softmax_lse->shape_view().At(2), seqlen_q);
+
+    CHECK_GT(batch_size, 0);      // batch size must be positive
+    CHECK_LE(head_size_og, 256);  // only support head dimensions at most 256
+    CHECK(num_heads % num_heads_k
+          == 0);  // Number of heads in key/value must divide number of heads in query
+
+    if (window_size_left >= seqlen_k) { window_size_left = -1; }
+    if (window_size_right >= seqlen_k) { window_size_right = -1; }
+
+    // causal=true is the same as causal=false in this case
+    if (seqlen_q == 1 && !alibi_slopes_) { is_causal = false; }
+    if (is_causal) { window_size_right = 0; }
+
+    const int seqlenq_ngroups_swapped = 0;
+
+    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+    const int head_size = round_multiple(head_size_og, 8);
+    const int head_size_rounded = round_multiple(head_size, 32);
+    const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
+    const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
+
+    Flash_fwd_params params;
+    set_params_fprop(params, batch_size, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded,
+                     num_heads, num_heads_k, head_size, head_size_rounded, query, key, value, out,
+                     /*cu_seqlens_q_d=*/nullptr,
+                     /*cu_seqlens_k_d=*/nullptr,
+                     /*seqused_k=*/nullptr,
+                     /*return_softmax=*/nullptr, softmax_lse->mut_dptr(), p_dropout, softmax_scale,
+                     window_size_left, window_size_right);
+
+    int64_t counter_offset = params.b * params.h * 32;
+    params.rng_state = rng_state->mut_dptr<uint64_t>();
+
+    set_params_splitkv(params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
+                       head_size_rounded, p_dropout, /*num_splits*/ 0, dprops, tmp_ptr);
+
+    if (p_dropout > 0.0f) {
+      // TODO: generator handling.
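+      // Dropout needs a philox (seed, offset) pair: the kernel-state generator below hands out
+      // an offset that reserves counter_offset (= batch_size * num_heads * 32) random numbers
+      // for this launch, and the pair is recorded in rng_state so a backward pass can replay
+      // the identical dropout mask.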
+ auto* flash_attention_kernel_state = + dynamic_cast(state); + CHECK_NOTNULL(flash_attention_kernel_state); + const auto& generator = flash_attention_kernel_state->generator(); + CHECK_NOTNULL(generator); + const auto device_index = cuda_device->device_index(); + std::shared_ptr cuda_generator = + CHECK_JUST(generator->Get(device_index)); + params.philox_args = + at::PhiloxCudaState(seed, cuda_generator->get_philox_offset(counter_offset)); + } + + set_params_alibi(params, alibi_slopes_, batch_size, num_heads); + + if (seqlen_k > 0) { run_mha_fwd(params, cuda_stream->cuda_stream()); } + } + + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +#define REGISTER_SCALED_DOT_PRODUCT_FLASH_ATTENTION_KERNEL(dtype) \ + REGISTER_USER_KERNEL("scaled_dot_product_flash_attention") \ + .SetCreateFn() \ + .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ + && (user_op::HobDataType("out", 0) == dtype)) \ + .SetInferTmpSizeFn(InferTmpBufferSizeForFlashAttentionKernel); + +REGISTER_SCALED_DOT_PRODUCT_FLASH_ATTENTION_KERNEL(DataType::kFloat16) +REGISTER_SCALED_DOT_PRODUCT_FLASH_ATTENTION_KERNEL(DataType::kBFloat16) + +} // namespace + +} // namespace user_op + +} // namespace oneflow + +#endif // WITH_CUTLASS + +#endif // CUDA_VERSION >= 11070 diff --git a/oneflow/user/kernels/scaled_dot_product_attention_kernel.h b/oneflow/user/kernels/scaled_dot_product_attention_kernel.h new file mode 100644 index 00000000000..20486bd4a40 --- /dev/null +++ b/oneflow/user/kernels/scaled_dot_product_attention_kernel.h @@ -0,0 +1,39 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#ifndef ONEFLOW_USER_KERNELS_FLASH_ATTENTION_KERNEL_H_ +#define ONEFLOW_USER_KERNELS_FLASH_ATTENTION_KERNEL_H_ + +#include "oneflow/user/kernels/random_mask_generator.h" +#include "oneflow/core/framework/framework.h" + +namespace oneflow { + +class ScaledDotProductFlashAttentionKernelState : public user_op::OpKernelState { + public: + explicit ScaledDotProductFlashAttentionKernelState( + const std::shared_ptr& generator) + : generator_(generator) {} + + const std::shared_ptr& generator() const { return generator_; } + + private: + std::shared_ptr generator_; +}; + +} // namespace oneflow + +#endif // ONEFLOW_USER_KERNELS_FLASH_ATTENTION_KERNEL_H_ diff --git a/oneflow/user/kernels/scaled_dot_product_attention_util.h b/oneflow/user/kernels/scaled_dot_product_attention_util.h new file mode 100644 index 00000000000..29d93c01da5 --- /dev/null +++ b/oneflow/user/kernels/scaled_dot_product_attention_util.h @@ -0,0 +1,232 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#ifndef ONEFLOW_USER_KERNELS_FLASH_ATTENTION_UTIL_H_ +#define ONEFLOW_USER_KERNELS_FLASH_ATTENTION_UTIL_H_ + +#include "oneflow/core/framework/user_op_tensor.h" +#include "oneflow/core/common/util.h" +#include "flash.h" +#include "static_switch.h" + +namespace oneflow { + +namespace user_op { + +namespace { + +void set_params_fprop(Flash_fwd_params& params, + // sizes + const size_t b, const size_t seqlen_q, const size_t seqlen_k, + const size_t seqlen_q_rounded, const size_t seqlen_k_rounded, const size_t h, + const size_t h_k, const size_t d, const size_t d_rounded, + // device pointers + const Tensor* q, const Tensor* k, const Tensor* v, Tensor* out, + void* cu_seqlens_q_d, void* cu_seqlens_k_d, void* seqused_k, void* p_d, + void* softmax_lse_d, float p_dropout, float softmax_scale, + int window_size_left, int window_size_right, + bool seqlenq_ngroups_swapped = false) { + // Reset the parameters + std::memset(¶ms, 0, sizeof(params)); + + params.is_bf16 = q->data_type() == DataType::kBFloat16; + + // Set the pointers and strides. + params.q_ptr = const_cast(q->dptr()); + params.k_ptr = const_cast(k->dptr()); + params.v_ptr = const_cast(v->dptr()); + // All stride are in elements, not bytes. + params.q_row_stride = CHECK_JUST(VectorAt(q->stride(), 1)); + params.k_row_stride = CHECK_JUST(VectorAt(k->stride(), 1)); + params.v_row_stride = CHECK_JUST(VectorAt(v->stride(), 1)); + params.q_head_stride = CHECK_JUST(VectorAt(q->stride(), 2)); + params.k_head_stride = CHECK_JUST(VectorAt(k->stride(), 2)); + params.v_head_stride = CHECK_JUST(VectorAt(v->stride(), 2)); + params.o_ptr = out->mut_dptr(); + params.o_row_stride = CHECK_JUST(VectorAt(out->stride(), 1)); + params.o_head_stride = CHECK_JUST(VectorAt(out->stride(), 2)); + + if (cu_seqlens_q_d == nullptr) { + params.q_batch_stride = CHECK_JUST(VectorAt(q->stride(), 0)); + params.k_batch_stride = CHECK_JUST(VectorAt(k->stride(), 0)); + params.v_batch_stride = CHECK_JUST(VectorAt(v->stride(), 0)); + params.o_batch_stride = CHECK_JUST(VectorAt(out->stride(), 0)); + if (seqlenq_ngroups_swapped) { + params.q_batch_stride *= seqlen_q; + params.o_batch_stride *= seqlen_q; + } + } + + params.cu_seqlens_q = static_cast(cu_seqlens_q_d); + params.cu_seqlens_k = static_cast(cu_seqlens_k_d); + params.seqused_k = static_cast(seqused_k); + + // P = softmax(QK^T) + params.p_ptr = p_d; + + // Softmax sum + params.softmax_lse_ptr = softmax_lse_d; + + // Set the dimensions. + params.b = b; + params.h = h; + params.h_k = h_k; + params.h_h_k_ratio = h / h_k; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.seqlen_q_rounded = seqlen_q_rounded; + params.seqlen_k_rounded = seqlen_k_rounded; + params.d = d; + params.d_rounded = d_rounded; + + // Set the different scale values. + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = softmax_scale * M_LOG2E; + + // Set this to probability of keeping an element to simplify things. + params.p_dropout = 1.f - p_dropout; + // Convert p from float to int so we don't have to convert the random uint to float to compare. 
+ // [Minor] We want to round down since when we do the comparison we use <= instead of < + // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0)); + // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0)); + params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); + params.rp_dropout = 1.f / params.p_dropout; + params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax; + CHECK_LT(p_dropout, 1.f); +#ifdef FLASHATTENTION_DISABLE_DROPOUT + TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout."); +#endif + + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. + params.is_causal = window_size_left < 0 && window_size_right == 0; + + if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; } + if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; } + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + +#ifdef FLASHATTENTION_DISABLE_LOCAL + TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0), + "This flash attention build does not support local attention."); +#endif + + params.is_seqlens_k_cumulative = true; + +#ifdef FLASHATTENTION_DISABLE_UNEVEN_K + TORCH_CHECK(d == d_rounded, + "This flash attention build does not support headdim not being a multiple of 32."); +#endif +} + +void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream, bool force_split_kernel = false) { + FP16_SWITCH(!params.is_bf16, [&] { + HEADDIM_SWITCH(params.d, [&] { + if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 + run_mha_fwd_(params, stream); + } else { + run_mha_fwd_splitkv_dispatch(params, stream); + } + }); + }); +} + +// Find the number of splits that maximizes the occupancy. For example, if we have +// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is +// better than having 3 splits (efficiency = 0.67). However, we also don't want too many +// splits as that would incur more HBM reads/writes. +// So we find the best efficiency, then find the smallest number of splits that gets 85% +// of the best efficiency. +inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, + int max_splits) { + // If we have enough to almost fill the SMs, then just use 1 split + if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; } + max_splits = std::min({max_splits, num_SMs, num_n_blocks}); + float max_efficiency = 0.f; + std::vector efficiency; + efficiency.reserve(max_splits); + auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; + // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits, + // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks + // (i.e. it's 11 splits anyway). + // So we check if the number of blocks per split is the same as the previous num_splits. 
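+  // Worked example of the efficiency metric computed below, using the numbers from the comment
+  // above the function (batch_nheads_mblocks = 48, num_SMs = 108):
+  //   2 splits: n_waves = 48 * 2 / 108 = 0.89, eff = 0.89 / ceil(0.89) = 0.89
+  //   3 splits: n_waves = 48 * 3 / 108 = 1.33, eff = 1.33 / ceil(1.33) = 0.67
+  // so 2 splits is preferred; with 3 splits the second wave would run only one-third full.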
+ auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) { + return num_splits == 1 + || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1); + }; + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + if (!is_split_eligible(num_splits)) { + efficiency.push_back(0.f); + } else { + float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs; + float eff = n_waves / ceil(n_waves); + // printf("num_splits = %d, eff = %f\n", num_splits, eff); + if (eff > max_efficiency) { max_efficiency = eff; } + efficiency.push_back(eff); + } + } + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + if (!is_split_eligible(num_splits)) { continue; } + if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) { + // printf("num_splits chosen = %d\n", num_splits); + return num_splits; + } + } + return 1; +} + +void set_params_splitkv(Flash_fwd_params& params, const int batch_size, const int num_heads, + const int head_size, const int max_seqlen_k, const int max_seqlen_q, + const int head_size_rounded, const float p_dropout, const int num_splits, + cudaDeviceProp& dprops, void* tmp_ptr) { + // This needs to match with run_mha_fwd_splitkv_dispatch + const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); + const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n; + // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel. + // In any case we don't expect seqlen_q to be larger than 64 for inference. + const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64; + params.num_splits = num_splits; + if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout + if (num_splits < 1) { + params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, + dprops.multiProcessorCount, num_n_blocks, 128); + } + if (params.num_splits > 1) { + size_t softmax_lse_accum_size = + params.num_splits * batch_size * num_heads * max_seqlen_q * sizeof(float); + params.softmax_lseaccum_ptr = tmp_ptr; + params.oaccum_ptr = + reinterpret_cast(tmp_ptr) + GetCudaAlignedSize(softmax_lse_accum_size); + } + CHECK_LE(params.num_splits, 128); + } +} + +void set_params_alibi(Flash_fwd_params& params, const Tensor* alibi_slopes_, int batch_size, + int num_heads) { + // TODO(ChenDe): Need Support Alibi params. + // default to null + CHECK(!alibi_slopes_) << "alibi_slopes should be null."; + params.alibi_slopes_ptr = nullptr; +} + +} // namespace + +} // namespace user_op + +} // namespace oneflow + +#endif // ONEFLOW_USER_KERNELS_FLASH_ATTENTION_UTIL_H_ diff --git a/oneflow/user/ops/scaled_dot_product_flash_attention_op.cpp b/oneflow/user/ops/scaled_dot_product_flash_attention_op.cpp new file mode 100644 index 00000000000..7e0ebc0102a --- /dev/null +++ b/oneflow/user/ops/scaled_dot_product_flash_attention_op.cpp @@ -0,0 +1,118 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#include "oneflow/core/common/data_type.pb.h" +#include "oneflow/core/common/just.h" +#include "oneflow/core/common/maybe.h" +#include "oneflow/core/common/shape.h" +#include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" + +namespace oneflow { + +Maybe ScaledDotProductFlashAttentionOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& q_shape = ctx->InputShape("query", 0); + const Shape& k_shape = ctx->InputShape("key", 0); + const Shape& v_shape = ctx->InputShape("value", 0); + + auto batch_size = q_shape.At(0); + auto seqlen_q = q_shape.At(1); + auto num_heads = q_shape.At(2); + auto head_size_og = q_shape.At(3); + auto seqlen_k = k_shape.At(1); + auto num_heads_k = k_shape.At(2); + + // check input tensor shape. + CHECK_EQ_OR_RETURN(batch_size, k_shape.At(0)) << "query has different batch size from key."; + CHECK_EQ_OR_RETURN(batch_size, v_shape.At(0)) << "query has different batch size from value."; + + CHECK_EQ_OR_RETURN(seqlen_k, v_shape.At(1)) << "key has different seqlen from value."; + CHECK_EQ_OR_RETURN(num_heads_k, v_shape.At(2)) << "key has different num_heads from value."; + + CHECK_EQ_OR_RETURN(head_size_og, k_shape.At(3)) << "query has different head_size from key"; + CHECK_EQ_OR_RETURN(head_size_og, v_shape.At(3)) << "query has different head_size from value"; + + // batch size must be positive. + CHECK_GT_OR_RETURN(batch_size, 0) << "batch size must be positive"; + + // only support head dimensions at most 256. + CHECK_LE_OR_RETURN(head_size_og, 256) << "only support head dimensions at most 256"; + + // number of heads in key/value must devide number of heads in query. + CHECK_EQ_OR_RETURN(num_heads % num_heads_k, 0) + << "number of heads in key/value must devide number of heads in query."; + + ctx->SetOutputShape("out", 0, Shape({batch_size, seqlen_q, num_heads, head_size_og})); + // save for backward + ctx->SetOutputShape("softmax_lse", 0, Shape({batch_size, num_heads, seqlen_q})); + // save seed and offset for backward. + ctx->SetOutputShape("rng_state", 0, Shape({2})); + + return Maybe::Ok(); +} + +Maybe ScaledDotProductFlashAttentionOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return ScaledDotProductFlashAttentionOp::InferLogicalTensorDesc(ctx); +} + +Maybe ScaledDotProductFlashAttentionOp::GetSbp(user_op::SbpContext* ctx) { + auto parallel_num = ctx->parallel_num(); + const Shape& q_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("query", 0).shape(); + const Shape& k_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("key", 0).shape(); + auto num_heads = q_shape.At(2); + auto num_heads_k = k_shape.At(2); + bool can_spilt_num_heads = + num_heads == num_heads_k || (!(num_heads % parallel_num) && !(num_heads_k % parallel_num)); + if (can_spilt_num_heads) { + // prior to split on num_heads. + ctx->NewBuilder() + .Split(user_op::OpArg("query", 0), 2) + .Split(user_op::OpArg("key", 0), 2) + .Split(user_op::OpArg("value", 0), 2) + .Split(user_op::OpArg("out", 0), 2) + .Split(user_op::OpArg("softmax", 0), 1) + .Broadcast(user_op::OpArg("rng_state", 0)) + .Build(); + } else { + // otherwise split on batch_size. 
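+    // Splitting on the batch dimension is always legal for attention (no cross-batch
+    // reduction), so it is the fallback when the query/key head counts cannot all be split
+    // evenly across the parallel ranks.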
+ ctx->NewBuilder() + .Split(user_op::OpArg("query", 0), 0) + .Split(user_op::OpArg("key", 0), 0) + .Split(user_op::OpArg("value", 0), 0) + .Split(user_op::OpArg("out", 0), 0) + .Split(user_op::OpArg("softmax", 0), 0) + .Broadcast(user_op::OpArg("rng_state", 0)) + .Build(); + } + return Maybe::Ok(); +} + +Maybe ScaledDotProductFlashAttentionOp::InferDataType(user_op::InferContext* ctx) { + auto q_datatype = ctx->InputDType("query", 0); + auto k_datatype = ctx->InputDType("key", 0); + auto v_datatype = ctx->InputDType("value", 0); + + CHECK_EQ_OR_RETURN(q_datatype, k_datatype) << "query has different data type from key."; + CHECK_EQ_OR_RETURN(q_datatype, v_datatype) << "query has different data type from value."; + + ctx->SetOutputDType("out", 0, q_datatype); + ctx->SetOutputDType("softmax_lse", 0, DataType::kFloat); + ctx->SetOutputDType("rng_state", 0, DataType::kUInt64); + + return Maybe::Ok(); +} + +} // namespace oneflow diff --git a/python/oneflow/test/modules/test_scaled_dot_product_attention.py b/python/oneflow/test/modules/test_scaled_dot_product_attention.py new file mode 100644 index 00000000000..51863b79643 --- /dev/null +++ b/python/oneflow/test/modules/test_scaled_dot_product_attention.py @@ -0,0 +1,105 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import unittest +from collections import OrderedDict +import numpy as np +from oneflow.test_utils.test_util import GenArgList +import math +import os + +import oneflow as flow + + +def _scaled_dot_product_attention( + query, key, value, +): + # input dims will equal 3 or 4. 
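+    # Reference math: out = softmax(Q @ K^T / sqrt(head_size)) @ V.
+    # The permute below only transposes the last two dims of K (both for the 4-D per-batch
+    # case and the 3-D per-head slices used in the GQA check), so that flow.matmul
+    # contracts over the head_size dimension.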
+ if key.ndim == 4: + key = key.permute(0, 1, 3, 2) + elif key.ndim == 3: + key = key.permute(0, 2, 1) + scores = flow.matmul(query, key) / math.sqrt(query.shape[-1]) + attn = flow.softmax(scores, dim=-1) + out = flow.matmul(attn, value) + return out + + +def _test_scaled_dot_product_attention( + test_case, batch_size, num_head_pair, seq_len_pair, head_size, dtype, +): + num_heads = num_head_pair[0] + num_heads_k = num_head_pair[1] + seq_len_q = seq_len_pair[0] + seq_len_kv = seq_len_pair[1] + query = flow.randn( + (batch_size, num_heads, seq_len_q, head_size), device="cuda", dtype=flow.float, + ).to(dtype) + key = flow.randn( + (batch_size, num_heads_k, seq_len_kv, head_size), + device="cuda", + dtype=flow.float, + ).to(dtype) + value = flow.randn( + (batch_size, num_heads_k, seq_len_kv, head_size), + device="cuda", + dtype=flow.float, + ).to(dtype) + + fused_out = ( + flow._C.scaled_dot_product_attention(query=query, key=key, value=value,) + .cpu() + .numpy() + ) + if num_heads == num_heads_k: + ref_out = _scaled_dot_product_attention(query, key, value,).cpu().numpy() + else: # For GQA + ref_out = flow.empty(query.shape, device="cuda", dtype=dtype) + stride = num_heads / num_heads_k + for i in range(0, num_heads): + j = int(i / stride) + ref_out[:, i, :, :] = _scaled_dot_product_attention( + query[:, i, :, :], key[:, j, :, :], value[:, j, :, :] + ) + + if dtype == flow.float16: + test_case.assertTrue(np.allclose(ref_out, fused_out, atol=1e-2, rtol=1e-2)) + elif dtype == flow.bfloat16: + test_case.assertTrue(np.allclose(ref_out, fused_out, atol=1e-1, rtol=1e-1)) + else: + test_case.assertTrue(np.allclose(ref_out, fused_out, atol=1e-3, rtol=1e-3)) + + +@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") +@flow.unittest.skip_unless_1n1d() +class TestScaledDotProductAttention(flow.unittest.TestCase): + def test_scaled_dot_product_attention(test_case): + args_dict = OrderedDict() + args_dict["test_fun"] = [_test_scaled_dot_product_attention] + args_dict["batchsize"] = [1, 2, 4] + args_dict["num_head_pair"] = [[16, 16], [16, 8]] + args_dict["seqlen_pair"] = [[4096, 4096], [4096, 77], [1024, 1024], [1024, 77]] + args_dict["head_size"] = [40, 80, 160, 41] + args_dict["dtype"] = [flow.float16, flow.bfloat16] + + if flow._oneflow_internal.flags.with_cuda(): + if flow._oneflow_internal.flags.cuda_version() >= 11070: + if flow.cuda.get_device_capability()[0] >= 8: + for arg in GenArgList(args_dict): + arg[0](test_case, *arg[1:]) + + +if __name__ == "__main__": + unittest.main()
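For reference, a minimal usage sketch of the new functional API (illustrative only; the shapes and head counts are made up for the example). It assumes a CUDA build with CUDA >= 11.7 and an sm80-or-newer GPU, which is what the kernel registration requires, and it follows the test above in creating float32 tensors and casting to half, since half/bfloat16 are the dtypes the kernel supports.

    import oneflow as flow

    batch_size, num_heads, num_heads_k, seq_len, head_size = 2, 16, 8, 1024, 80

    # Layout expected by the functional: (batch, num_heads, seq_len, head_size).
    query = flow.randn(
        (batch_size, num_heads, seq_len, head_size), device="cuda", dtype=flow.float,
    ).to(flow.float16)
    key = flow.randn(
        (batch_size, num_heads_k, seq_len, head_size), device="cuda", dtype=flow.float,
    ).to(flow.float16)
    value = flow.randn(
        (batch_size, num_heads_k, seq_len, head_size), device="cuda", dtype=flow.float,
    ).to(flow.float16)

    # softmax scale defaults to 1 / sqrt(head_size); is_causal=True masks future positions.
    out = flow._C.scaled_dot_product_attention(
        query=query, key=key, value=value, is_causal=True
    )
    print(out.shape)  # (2, 16, 1024, 80)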