diff --git a/source/core/Pipeline.cpp b/source/core/Pipeline.cpp index 553b964e2..30266df05 100644 --- a/source/core/Pipeline.cpp +++ b/source/core/Pipeline.cpp @@ -933,6 +933,14 @@ ErrorCode Pipeline::fixResizeCache() { break; } } + if (mOutputStatic) { + for (auto t : cmd.workOutputs) { + if (TensorUtils::getDescribe(t)->usage != Tensor::InsideDescribe::NORMAL) { + cmd.group = 0; + break; + } + } + } } if (1 == cmd.group) { fixNumber++; diff --git a/transformers/llm/engine/include/llm/llm.hpp b/transformers/llm/engine/include/llm/llm.hpp index a4b868592..72fcaa742 100644 --- a/transformers/llm/engine/include/llm/llm.hpp +++ b/transformers/llm/engine/include/llm/llm.hpp @@ -86,6 +86,7 @@ class MNN_PUBLIC Llm { int64_t prefill_us_ = 0; int64_t decode_us_ = 0; bool is_single_ = true; + bool attention_fused_ = true; protected: std::shared_ptr config_; std::shared_ptr tokenizer_; diff --git a/transformers/llm/engine/src/llm.cpp b/transformers/llm/engine/src/llm.cpp index efa9a5d23..e01350eb9 100644 --- a/transformers/llm/engine/src/llm.cpp +++ b/transformers/llm/engine/src/llm.cpp @@ -119,6 +119,7 @@ void Llm::load() { // init module status key_value_shape_ = config_->key_value_shape(); is_single_ = config_->is_single(); + attention_fused_ = config_->attention_fused(); { std::ifstream embedding_bin(config_->embedding_file()); embedding_bin.close(); @@ -144,24 +145,18 @@ void Llm::load() { std::string model_path = config_->llm_model(); MNN_PRINT("load %s ... ", model_path.c_str()); runtime_manager_->setExternalFile(config_->llm_weight()); - modules_[0].reset(Module::load( - {"input_ids", "attention_mask", "position_ids", "past_key_values"}, - {"logits", "presents"}, model_path.c_str(), runtime_manager_, &module_config)); + if (attention_fused_) { + modules_[0].reset(Module::load( + {"input_ids", "attention_mask", "position_ids"}, + {"logits"}, model_path.c_str(), runtime_manager_, &module_config)); + } else { + modules_[0].reset(Module::load( + {"input_ids", "attention_mask", "position_ids", "past_key_values"}, + {"logits", "presents"}, model_path.c_str(), runtime_manager_, &module_config)); + } MNN_PRINT("Done!\n"); } else { - // load split models - modules_.resize(layer_nums + 2); - // load lm model - modules_[layer_nums].reset(Module::load({}, {}, config_->lm_model().c_str(), runtime_manager_, &module_config)); - // load block models - for (int i = 0; i < layer_nums; i++) { - std::string model_path = config_->block_model(i); - MNN_PRINT("load %s ... ", model_path.c_str()); - modules_[i].reset(Module::load( - {"inputs_embeds", "attention_mask", "position_ids", "past_key_values"}, - {"hidden_states", "presents"}, model_path.c_str(), runtime_manager_, &module_config)); - MNN_PRINT("Done!\n"); - } + MNN_ERROR("Split version is depercerate\n"); } decode_modules_.resize(modules_.size()); for (int v=0; v& input_ids) { auto position_ids = gen_position_ids(seq_len); VARP logits; if (is_single_) { - // single model + std::vector outputs; auto hidden_states = embedding(input_ids); - auto outputs = current_modules_.back()->onForward({hidden_states, attention_mask, position_ids, past_key_values_[0]}); + if (attention_fused_) { + outputs = current_modules_.back()->onForward({hidden_states, attention_mask, position_ids}); + } else { + outputs = current_modules_.back()->onForward({hidden_states, attention_mask, position_ids, past_key_values_[0]}); + } if (outputs.empty()) { return nullptr; } logits = outputs[0]; - past_key_values_[0] = outputs[1]; - } else { - // split block models - int layer_nums = config_->layer_nums(); - auto hidden_states = embedding(input_ids); - ExecutorScope::Current()->gc(Executor::FULL); - for (int i = 0; i < layer_nums; i++) { - AUTOTIME; - auto outputs = current_modules_[i]->onForward({hidden_states, attention_mask, position_ids, past_key_values_[i]}); - hidden_states = outputs[0]; - past_key_values_[i] = outputs[1]; - } - { - AUTOTIME; - auto outputs = current_modules_[layer_nums]->onForward({hidden_states}); - logits = outputs[0]; + if (!attention_fused_) { + past_key_values_[0] = outputs[1]; } + } else { + MNN_ERROR("Split models is depercarate\n"); + return nullptr; } all_seq_len_ += seq_len; gen_seq_len_++; @@ -390,6 +378,7 @@ std::vector Llm::generate(const std::vector& input_ids, int max_new_to // decode current_modules_ = decode_modules_; while (gen_seq_len_ < max_new_tokens) { + logits = nullptr; logits = forward({token}); if (logits.get() == nullptr) { return {}; @@ -429,6 +418,7 @@ std::string Llm::generate(const std::vector& input_ids, std::ostream* os, c while (gen_seq_len_ < config_->max_new_tokens()) { st = std::chrono::system_clock::now(); history_ids_.push_back(token); + logits = nullptr; logits = forward({token}); if (nullptr == logits.get()) { return ""; diff --git a/transformers/llm/engine/src/llmconfig.hpp b/transformers/llm/engine/src/llmconfig.hpp index b09ab6177..57cc924a8 100644 --- a/transformers/llm/engine/src/llmconfig.hpp +++ b/transformers/llm/engine/src/llmconfig.hpp @@ -279,6 +279,9 @@ class LlmConfig { std::string attention_mask() const { return llm_config_.value("attention_mask", "int"); } + bool attention_fused() const { + return llm_config_.value("attention_fused", true); + } std::string chat_template() const { return llm_config_.value("chat_template", ""); @@ -290,4 +293,4 @@ class LlmConfig { // llm model config end > }; } // Transformer -} // MNN \ No newline at end of file +} // MNN