LLM:Optimize: Reduce unusable malloc for llm forward, fix bug for fixResizeCache for output mutable
xiaying committed Aug 27, 2024
1 parent 9f38e39 commit d9f041a
Showing 4 changed files with 37 additions and 35 deletions.
8 changes: 8 additions & 0 deletions source/core/Pipeline.cpp
@@ -933,6 +933,14 @@ ErrorCode Pipeline::fixResizeCache() {
                     break;
                 }
             }
+            if (mOutputStatic) {
+                for (auto t : cmd.workOutputs) {
+                    if (TensorUtils::getDescribe(t)->usage != Tensor::InsideDescribe::NORMAL) {
+                        cmd.group = 0;
+                        break;
+                    }
+                }
+            }
         }
         if (1 == cmd.group) {
             fixNumber++;
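The Pipeline.cpp hunk is the "fix bug for fixResizeCache for output mutable" half of the commit: when the pipeline's outputs are static (mOutputStatic), any command whose output tensor is not an ordinary internal tensor (usage != NORMAL) is dropped from the fixed resize cache by resetting cmd.group to 0, so its output buffer stays mutable. The sketch below restates that check with simplified, hypothetical stand-in types (not MNN's real Tensor/Command structures), purely to illustrate the rule.

// Minimal sketch of the added check; all types here are made up for illustration.
#include <vector>

enum class Usage { NORMAL, INPUT, OUTPUT, CONSTANT };  // stand-in for the tensor usage enum

struct FakeTensor { Usage usage = Usage::NORMAL; };

struct FakeCommand {
    int group = 1;                         // 1: eligible for the fixed resize cache
    std::vector<FakeTensor*> workOutputs;
};

// Mirrors the mOutputStatic branch: a command that writes to a user-visible
// tensor must not be frozen into the cache, so its group is reset to 0.
void excludeMutableOutputs(FakeCommand& cmd, bool outputStatic) {
    if (!outputStatic) {
        return;
    }
    for (auto t : cmd.workOutputs) {
        if (t->usage != Usage::NORMAL) {
            cmd.group = 0;
            break;
        }
    }
}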
1 change: 1 addition & 0 deletions transformers/llm/engine/include/llm/llm.hpp
@@ -86,6 +86,7 @@ class MNN_PUBLIC Llm {
     int64_t prefill_us_ = 0;
     int64_t decode_us_ = 0;
     bool is_single_ = true;
+    bool attention_fused_ = true;
 protected:
     std::shared_ptr<LlmConfig> config_;
     std::shared_ptr<Tokenizer> tokenizer_;
58 changes: 24 additions & 34 deletions transformers/llm/engine/src/llm.cpp
@@ -119,6 +119,7 @@ void Llm::load() {
     // init module status
     key_value_shape_ = config_->key_value_shape();
    is_single_ = config_->is_single();
+    attention_fused_ = config_->attention_fused();
     {
         std::ifstream embedding_bin(config_->embedding_file());
         embedding_bin.close();
@@ -144,24 +145,18 @@ void Llm::load() {
         std::string model_path = config_->llm_model();
         MNN_PRINT("load %s ... ", model_path.c_str());
         runtime_manager_->setExternalFile(config_->llm_weight());
-        modules_[0].reset(Module::load(
-            {"input_ids", "attention_mask", "position_ids", "past_key_values"},
-            {"logits", "presents"}, model_path.c_str(), runtime_manager_, &module_config));
+        if (attention_fused_) {
+            modules_[0].reset(Module::load(
+                {"input_ids", "attention_mask", "position_ids"},
+                {"logits"}, model_path.c_str(), runtime_manager_, &module_config));
+        } else {
+            modules_[0].reset(Module::load(
+                {"input_ids", "attention_mask", "position_ids", "past_key_values"},
+                {"logits", "presents"}, model_path.c_str(), runtime_manager_, &module_config));
+        }
         MNN_PRINT("Done!\n");
     } else {
-        // load split models
-        modules_.resize(layer_nums + 2);
-        // load lm model
-        modules_[layer_nums].reset(Module::load({}, {}, config_->lm_model().c_str(), runtime_manager_, &module_config));
-        // load block models
-        for (int i = 0; i < layer_nums; i++) {
-            std::string model_path = config_->block_model(i);
-            MNN_PRINT("load %s ... ", model_path.c_str());
-            modules_[i].reset(Module::load(
-                {"inputs_embeds", "attention_mask", "position_ids", "past_key_values"},
-                {"hidden_states", "presents"}, model_path.c_str(), runtime_manager_, &module_config));
-            MNN_PRINT("Done!\n");
-        }
+        MNN_ERROR("Split version is deprecated\n");
     }
     decode_modules_.resize(modules_.size());
     for (int v=0; v<modules_.size(); ++v) {
@@ -238,30 +233,23 @@ VARP Llm::forward(const std::vector<int>& input_ids) {
     auto position_ids = gen_position_ids(seq_len);
     VARP logits;
     if (is_single_) {
-        // single model
+        std::vector<MNN::Express::VARP> outputs;
         auto hidden_states = embedding(input_ids);
-        auto outputs = current_modules_.back()->onForward({hidden_states, attention_mask, position_ids, past_key_values_[0]});
+        if (attention_fused_) {
+            outputs = current_modules_.back()->onForward({hidden_states, attention_mask, position_ids});
+        } else {
+            outputs = current_modules_.back()->onForward({hidden_states, attention_mask, position_ids, past_key_values_[0]});
+        }
         if (outputs.empty()) {
             return nullptr;
         }
         logits = outputs[0];
-        past_key_values_[0] = outputs[1];
+        if (!attention_fused_) {
+            past_key_values_[0] = outputs[1];
+        }
     } else {
-        // split block models
-        int layer_nums = config_->layer_nums();
-        auto hidden_states = embedding(input_ids);
-        ExecutorScope::Current()->gc(Executor::FULL);
-        for (int i = 0; i < layer_nums; i++) {
-            AUTOTIME;
-            auto outputs = current_modules_[i]->onForward({hidden_states, attention_mask, position_ids, past_key_values_[i]});
-            hidden_states = outputs[0];
-            past_key_values_[i] = outputs[1];
-        }
-        {
-            AUTOTIME;
-            auto outputs = current_modules_[layer_nums]->onForward({hidden_states});
-            logits = outputs[0];
-        }
+        MNN_ERROR("Split models are deprecated\n");
+        return nullptr;
     }
     all_seq_len_ += seq_len;
     gen_seq_len_++;
@@ -390,6 +378,7 @@ std::vector<int> Llm::generate(const std::vector<int>& input_ids, int max_new_to
     // decode
     current_modules_ = decode_modules_;
     while (gen_seq_len_ < max_new_tokens) {
+        logits = nullptr;
         logits = forward({token});
         if (logits.get() == nullptr) {
             return {};
@@ -429,6 +418,7 @@ std::string Llm::generate(const std::vector<int>& input_ids, std::ostream* os, c
     while (gen_seq_len_ < config_->max_new_tokens()) {
         st = std::chrono::system_clock::now();
         history_ids_.push_back(token);
+        logits = nullptr;
         logits = forward({token});
         if (nullptr == logits.get()) {
             return "";
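The llm.cpp changes carry the "reduce unusable malloc" part of the commit. With attention_fused_ set, the single model neither takes a "past_key_values" input nor returns a "presents" output, so the caller no longer allocates and threads a KV-cache VARP through every forward call; the split-model path is retired with an error. In both generate() loops, logits is also reset to nullptr before the next forward(), releasing the previous logits tensor instead of briefly holding two. Below is a hypothetical sketch of the fused/non-fused input-output selection; the tensor names are taken from the diff, but the helper itself is illustrative and not MNN API.

#include <string>
#include <vector>

struct ModuleIO {
    std::vector<std::string> inputs;
    std::vector<std::string> outputs;
};

// Illustrative helper: pick the tensor names the module is loaded with,
// depending on whether attention (and its KV cache) is fused into the graph.
ModuleIO selectLlmModuleIO(bool attentionFused) {
    if (attentionFused) {
        // KV cache is managed inside the fused attention op, so neither
        // "past_key_values" nor "presents" crosses the module boundary.
        return {{"input_ids", "attention_mask", "position_ids"}, {"logits"}};
    }
    // Legacy layout: the caller owns the KV cache and passes it every step.
    return {{"input_ids", "attention_mask", "position_ids", "past_key_values"},
            {"logits", "presents"}};
}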
5 changes: 4 additions & 1 deletion transformers/llm/engine/src/llmconfig.hpp
@@ -279,6 +279,9 @@ class LlmConfig {
     std::string attention_mask() const {
         return llm_config_.value("attention_mask", "int");
     }
+    bool attention_fused() const {
+        return llm_config_.value("attention_fused", true);
+    }
 
     std::string chat_template() const {
         return llm_config_.value("chat_template", "");
@@ -290,4 +293,4 @@ class LlmConfig {
     // llm model config end >
 };
 } // Transformer
-} // MNN
\ No newline at end of file
+} // MNN
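llmconfig.hpp exposes the new flag as attention_fused(), defaulting to true when the key is missing, so only models exported without fused attention have to opt out in their llm_config. A small usage sketch follows, assuming the config object behaves like nlohmann::json (which the .value(key, default) call above suggests); the JSON content is made up for illustration.

#include <iostream>
#include <nlohmann/json.hpp>

int main() {
    // Key absent: attention_fused falls back to the default, true.
    nlohmann::json llm_config = nlohmann::json::parse(R"({"attention_mask": "int"})");
    bool fused = llm_config.value("attention_fused", true);   // -> true

    // A legacy export without fused attention opts out explicitly.
    llm_config["attention_fused"] = false;
    bool legacy = llm_config.value("attention_fused", true);  // -> false

    std::cout << fused << " " << legacy << "\n";              // prints "1 0"
    return 0;
}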
