Flash attention v2 forward (#10484)
Integrated the flash attn v2 forward operator.

---------

Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
cccddd77 and oneflow-ci-bot committed Apr 18, 2024
1 parent 3d7b87d commit 44ad994
Showing 11 changed files with 953 additions and 0 deletions.
5 changes: 5 additions & 0 deletions cmake/oneflow.cmake
@@ -632,6 +632,11 @@ if(BUILD_CPP_API)
checkdirandappendslash(DIR ${TRT_FLASH_ATTENTION_LIBRARY_DIR} OUTPUT
TRT_FLASH_ATTENTION_LIBRARY_DIR_APPENDED)
list(APPEND LIBONEFLOW_THIRD_PARTY_DIRS ${TRT_FLASH_ATTENTION_LIBRARY_DIR_APPENDED})
if(CUDA_VERSION VERSION_GREATER_EQUAL "11.7")
checkdirandappendslash(DIR ${FLASH_ATTENTION_LIBRARY_DIR} OUTPUT
FLASH_ATTENTION_LIBRARY_DIR_APPENDED)
list(APPEND LIBONEFLOW_THIRD_PARTY_DIRS ${FLASH_ATTENTION_LIBRARY_DIR_APPENDED})
endif()
if(WITH_CUTLASS)
checkdirandappendslash(DIR ${CUTLASS_LIBRARY_DIR} OUTPUT CUTLASS_LIBRARY_DIR_APPENDED)
list(APPEND LIBONEFLOW_THIRD_PARTY_DIRS ${CUTLASS_LIBRARY_DIR_APPENDED})
8 changes: 8 additions & 0 deletions cmake/third_party.cmake
@@ -145,6 +145,9 @@ if(BUILD_CUDA)
include(nccl)
include(cutlass)
include(trt_flash_attention)
if(CUDA_VERSION VERSION_GREATER_EQUAL "11.7")
include(flash_attention)
endif()

list(APPEND oneflow_third_party_libs ${NCCL_LIBRARIES})
list(APPEND oneflow_third_party_libs ${CUDNN_LIBRARIES})
@@ -164,6 +167,11 @@ if(BUILD_CUDA)
list(APPEND oneflow_third_party_dependencies trt_flash_attention)
list(APPEND oneflow_third_party_libs ${TRT_FLASH_ATTENTION_LIBRARIES})
list(APPEND ONEFLOW_THIRD_PARTY_INCLUDE_DIRS ${TRT_FLASH_ATTENTION_INCLUDE_DIR})
if(CUDA_VERSION VERSION_GREATER_EQUAL "11.7")
list(APPEND oneflow_third_party_dependencies flash_attention)
list(APPEND oneflow_third_party_libs ${FLASH_ATTENTION_LIBRARIES})
list(APPEND ONEFLOW_THIRD_PARTY_INCLUDE_DIRS ${FLASH_ATTENTION_INCLUDE_DIR})
endif()
endif()

if(BUILD_RDMA)
39 changes: 39 additions & 0 deletions cmake/third_party/flash_attention.cmake
@@ -0,0 +1,39 @@
include(ExternalProject)

find_package(Threads)

# NOTE: A git version of 1.6.5 or later is required if this download method is used.
find_package(Git QUIET REQUIRED)

set(FLASH_ATTENTION_PROJECT flash_attention)

set(FLASH_ATTENTION_URL https://github.com/Oneflow-Inc/flash-attention-v2.git)
set(FLASH_ATTENTION_TAG eed2e82b880e06237af3e50ceac4cf6728b15645)

set(FLASH_ATTENTION_INSTALL_DIR ${THIRD_PARTY_DIR}/flash_attention)
set(FLASH_ATTENTION_INCLUDE_DIR ${FLASH_ATTENTION_INSTALL_DIR}/include CACHE PATH "" FORCE)
set(FLASH_ATTENTION_LIBRARY_DIR ${FLASH_ATTENTION_INSTALL_DIR}/lib CACHE PATH "" FORCE)
set(FLASH_ATTENTION_LIBRARIES ${FLASH_ATTENTION_LIBRARY_DIR}/libflash_attention.so)

if(THIRD_PARTY)
ExternalProject_Add(
${FLASH_ATTENTION_PROJECT}
PREFIX flash_attention
GIT_REPOSITORY ${FLASH_ATTENTION_URL}
GIT_TAG ${FLASH_ATTENTION_TAG}
UPDATE_COMMAND ""
BUILD_BYPRODUCTS ${FLASH_ATTENTION_LIBRARIES}
CMAKE_ARGS -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE}
-DCMAKE_CXX_FLAGS:STRING=${CMAKE_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_DEBUG:STRING=${CMAKE_CXX_FLAGS_DEBUG}
-DCMAKE_CXX_FLAGS_RELEASE:STRING=${CMAKE_CXX_FLAGS_RELEASE}
-DCMAKE_CUDA_ARCHITECTURES:STRING=${CMAKE_CUDA_ARCHITECTURES}
CMAKE_CACHE_ARGS
-DCMAKE_CUDA_COMPILER:STRING=${CUDAToolkit_NVCC_EXECUTABLE}
-DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER}
-DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER}
-DCMAKE_INSTALL_PREFIX:PATH=${FLASH_ATTENTION_INSTALL_DIR}
-DCMAKE_INSTALL_LIBDIR:PATH=${FLASH_ATTENTION_LIBRARY_DIR}
-DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE}
-DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE})
endif(THIRD_PARTY)
4 changes: 4 additions & 0 deletions oneflow/core/functional/functional_api.yaml
@@ -2684,6 +2684,10 @@
signature: "TensorTuple (Tensor x, Tensor bias, Tensor mask, *, Float fill_value=0.0, Float scale=1.0, Float p=0.5, Bool training=True, Generator generator=None) => FusedBiasAddScaleMaskSoftmaxDropout"
bind_python: True

- name: "scaled_dot_product_attention"
signature: "Tensor (Tensor query, Tensor key, Tensor value, Tensor attn_mask=None, Float dropout_p=0.0, Bool is_causal=False, Float scale=None, Int64 seed=0) => ScaledDotProductFlashAttention"
bind_python: True

- name: "fused_multi_head_attention_inference"
signature: "Tensor (Tensor query, Tensor key, Tensor value, Int64 num_heads, Bool causal=False, Int64 query_hidden_slice_start=0, Int64 query_hidden_slice_end=-1, Int64 key_hidden_slice_start=0, Int64 key_hidden_slice_end=-1, Int64 value_hidden_slice_start=0, Int64 value_hidden_slice_end=-1, Tensor attn_bias=None, Int64 causal_diagonal_offset=0) => FusedMultiHeadAttentionInference"
bind_python: True
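For context, a hypothetical usage sketch of the scaled_dot_product_attention entry added above, assuming the binding is exposed as oneflow._C.scaled_dot_product_attention (the Python-side export path is not part of this diff):

import oneflow as flow

# Inputs use the functor's expected layout (batch, num_heads, seq_len, head_dim),
# in half precision on a CUDA device.
q = flow.randn(2, 8, 128, 64, dtype=flow.float16, device="cuda")
k = flow.randn(2, 8, 128, 64, dtype=flow.float16, device="cuda")
v = flow.randn(2, 8, 128, 64, dtype=flow.float16, device="cuda")

# Hypothetical entry point; dropout_p, is_causal and scale mirror the YAML signature,
# and scale=None falls back to 1/sqrt(head_dim) inside the functor.
out = flow._C.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
print(out.shape)  # expected: (2, 8, 128, 64)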
119 changes: 119 additions & 0 deletions oneflow/core/functional/impl/nn_functor.cpp
@@ -28,6 +28,7 @@ limitations under the License.
#include "oneflow/user/kernels/dropout_kernel.h"
#include "oneflow/user/kernels/distributions/common.h"
#include "oneflow/user/kernels/random_seed_util.h"
#include "oneflow/user/kernels/scaled_dot_product_attention_kernel.h"

#include "oneflow/core/common/container_util.h"
#include "fmt/core.h"
@@ -5420,6 +5421,123 @@ class NonContiguousBinaryOpGradFunctor {
std::shared_ptr<OpExpr> op_;
};

namespace {

template<int alignment_size>
Maybe<one::Tensor> pad_last_dim(const std::shared_ptr<one::Tensor>& input) {
auto num_dims = input->shape()->NumAxes();
auto last_dim_size = input->shape()->At(num_dims - 1);
if (last_dim_size % alignment_size == 0) { return input; }
auto pad_count = alignment_size - (last_dim_size % alignment_size);

return JUST(functional::Pad(input, {0, pad_count}, "constant", Scalar(0)));
}

} // namespace

class ScaledDotProductFlashAttentionFunctor {
public:
ScaledDotProductFlashAttentionFunctor() {
#if CUDA_VERSION >= 11070
op_ = CHECK_JUST(one::OpBuilder("scaled_dot_product_flash_attention")
.Input("query")
.Input("key")
.Input("value")
.Output("out")
.Output("softmax_lse")
.Output("rng_state")
.Build());
#endif // CUDA_VERSION >= 11070
}

Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& query,
const std::shared_ptr<one::Tensor>& key,
const std::shared_ptr<one::Tensor>& value,
const Optional<one::Tensor>& attn_mask, const float& dropout_p,
const bool& is_causal, const Optional<float>& scale,
const int64_t& seed = 0) const {
#if CUDA_VERSION >= 11070
const auto og_size = query->shape()->At(3);
const auto batch_size = query->shape()->At(0);
const auto seqlen_q = query->shape()->At(2);
const auto num_heads = query->shape()->At(1);
const auto num_heads_k = key->shape()->At(1);
const auto max_seqlen_batch_k = key->shape()->At(2);
const auto max_seqlen_batch_v = value->shape()->At(2);

CHECK_EQ_OR_RETURN(batch_size, key->shape()->At(0))
<< " key has different batch size from query.";
CHECK_EQ_OR_RETURN(batch_size, value->shape()->At(0))
<< " value has different batch size from query.";
CHECK_EQ_OR_RETURN(num_heads_k, value->shape()->At(1))
<< " value has different num_heads from key.";
CHECK_EQ_OR_RETURN(max_seqlen_batch_k, max_seqlen_batch_v)
<< "value has different seqlen from key.";
CHECK_EQ_OR_RETURN(og_size, key->shape()->At(3)) << " key has different head dims from query.";
CHECK_EQ_OR_RETURN(og_size, value->shape()->At(3))
<< " value has different head dims from query.";

// Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
// Key (Batch x Num_heads x KV_seq_len x Dim_per_head)
// Value (Batch x Num_heads x KV_seq_len x Dim_per_head)
std::shared_ptr<Tensor> q_padded, k_padded, v_padded;
bool padded = og_size % 8;
if (padded) {
q_padded = JUST(pad_last_dim<8>(query));
k_padded = JUST(pad_last_dim<8>(key));
v_padded = JUST(pad_last_dim<8>(value));
} else {
q_padded = query;
k_padded = key;
v_padded = value;
}

auto q_ = JUST(functional::Transpose(q_padded, {0, 2, 1, 3}));
auto k_ = JUST(functional::Transpose(k_padded, {0, 2, 1, 3}));
auto v_ = JUST(functional::Transpose(v_padded, {0, 2, 1, 3}));
// Query -> Query(Batch x Q_seq_len x Num_heads x Dim_per_head)
// Key -> Key (Batch x KV_seq_len x Num_heads x Dim_per_head)
// Value -> Value(Batch x KV_seq_len x Num_heads x Dim_per_head)

const float scale_ =
scale.has_value() ? JUST(scale) : (1.0f / std::sqrt(static_cast<float>(query->shape()->At(3))));

auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("p_dropout", "softmax_scale", "is_causal",
"window_size_left", "window_size_right", "seed");
attrs.SetAllAttrs(dropout_p, scale_, is_causal, -1, -1, seed);

auto gen = JUST(one::DefaultAutoGenerator());
gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), query));
const auto& state = std::make_shared<ScaledDotProductFlashAttentionKernelState>(gen);
OpExprInterpContext ctx(attrs, state);

std::shared_ptr<one::Tensor> output_ =
JUST(OpInterpUtil::Dispatch<one::Tensor>(*op_, {q_, k_, v_}, ctx));

auto output_padded = JUST(functional::Transpose(output_, {0, 2, 1, 3}));

std::shared_ptr<Tensor> output;
if (padded) {
output =
JUST(functional::Slice(output_padded, {0, 0, 0, 0},
{batch_size, num_heads, seqlen_q, og_size}, {1, 1, 1, 1}, false));
} else {
output = output_padded;
}

return output;
#endif // CUDA_VERSION >= 11070

UNIMPLEMENTED_THEN_RETURN() << "only supports CUDA_VERSION >= 11070.";
}

private:
#if CUDA_VERSION >= 11070
std::shared_ptr<OpExpr> op_;
#endif // CUDA_VERSION >= 11070
};

} // namespace impl

ONEFLOW_FUNCTION_LIBRARY(m) {
@@ -5557,6 +5675,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::NonContiguousBinaryOpGradFunctor>("NonContiguousBinaryOpGrad");
m.add_functor<impl::MultiTensorYoloV5WeightUpdateFunctor>("MultiTensorYoloV5WeightUpdate");
m.add_functor<impl::FusedClipGradFunctor>("FusedClipGrad");
m.add_functor<impl::ScaledDotProductFlashAttentionFunctor>("ScaledDotProductFlashAttention");
}

} // namespace functional
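The functor above pads the head dimension up to a multiple of 8, transposes the inputs from (Batch, Num_heads, Seq_len, Dim_per_head) to the (Batch, Seq_len, Num_heads, Dim_per_head) layout the kernel expects, then transposes the output back and slices off the padding. A minimal Python sketch of that bookkeeping, as an illustration only (not the OneFlow implementation):

import oneflow as flow

def pad_last_dim_to_multiple(x, alignment=8):
    # Illustrative helper mirroring pad_last_dim<8>: zero-pad the trailing (head) dim
    # up to a multiple of `alignment`.
    remainder = x.shape[-1] % alignment
    if remainder == 0:
        return x
    return flow.nn.functional.pad(x, (0, alignment - remainder))

q = flow.randn(2, 8, 128, 60, dtype=flow.float16, device="cuda")  # head_dim 60 is not a multiple of 8
q_padded = flow.transpose(pad_last_dim_to_multiple(q), 1, 2)      # (2, 128, 8, 64): padded, kernel layout
# ... the fused kernel consumes the padded (B, S, H, D) tensors and produces output in that layout ...
out = flow.transpose(q_padded, 1, 2)[..., :60]                    # back to (2, 8, 128, 60)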
26 changes: 26 additions & 0 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -2877,6 +2877,32 @@ def OneFlow_FusedCrossFeatureInteractionV2GradOp : OneFlow_BaseOp<"fused_cross_f
let has_data_type_infer_fn = 1;
}

def OneFlow_ScaledDotProductFlashAttentionOp : OneFlow_BaseOp<"scaled_dot_product_flash_attention", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$query,
OneFlow_Tensor:$key,
OneFlow_Tensor:$value,
Optional<OneFlow_Tensor>:$alibi_slopes_
);
let output = (outs
OneFlow_Tensor:$out,
OneFlow_Tensor:$softmax_lse,
OneFlow_Tensor:$rng_state
);
let attrs = (ins
DefaultValuedAttr<F32Attr, "0.">:$p_dropout,
DefaultValuedAttr<F32Attr, "0.">:$softmax_scale,
DefaultValuedAttr<BoolAttr, "false">:$is_causal,
SI32Attr:$window_size_left,
SI32Attr:$window_size_right,
DefaultValuedAttr<SI64Attr, "0">:$seed
);
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
}

def OneFlow_FusedMultiHeadAttentionInferenceOp : OneFlow_BaseOp<"fused_multi_head_attention_inference", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$query,
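The attributes on the new op follow flash-attention conventions: the functor fills softmax_scale with 1/sqrt(head_dim) when no scale is given, is_causal masks future key positions, and window_size_left/window_size_right are passed as -1, the upstream flash-attention value for no sliding window. As a reference for the attention semantics these attributes describe (an unfused illustration, not the fused kernel):

import math
import oneflow as flow

def reference_sdpa(q, k, v, softmax_scale=None, is_causal=False):
    # Illustrative reference: softmax(Q K^T * scale) V with an optional causal mask.
    head_dim = q.shape[-1]
    scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(head_dim)
    scores = flow.matmul(q, k.transpose(-2, -1)) * scale  # (B, H, S_q, S_k)
    if is_causal:
        s_q, s_k = scores.shape[-2], scores.shape[-1]
        future = flow.triu(flow.ones(s_q, s_k, dtype=flow.bool, device=scores.device), diagonal=1)
        scores = scores.masked_fill(future, float("-inf"))
    return flow.matmul(flow.softmax(scores, dim=-1), v)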