Commit
MNN:Sync: Fix bug for llama2/llama3 attention fuse, refactor llm usage
xiaying committed Jun 15, 2024
1 parent 226f1bc commit 65ec0ea
Showing 110 changed files with 12,586 additions and 2,772 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -61,6 +61,8 @@ option(MNN_BUILD_LLM "Build llm library based MNN." OFF)
option(MNN_BUILD_DIFFUSION "Build diffusion demo based MNN." OFF)
option(MNN_INTERNAL "Build with MNN internal features, such as model authentication, metrics logging" OFF)
option(MNN_JNI "Build MNN Jni for java to use" OFF)
option(MNN_SUPPORT_BF16 "Enable MNN's bf16 op" OFF)
option(MNN_LOW_MEMORY "Build MNN support low memory for weight quant model." OFF)

IF (OHOS)
include($ENV{NODE_PATH}/@ali/tcpkg/tcpkg.cmake)
6 changes: 3 additions & 3 deletions docs/compile/engine.md
@@ -3,14 +3,14 @@
## Linux/MacOS
- Requirements
- cmake >= 3.10
- gcc >= 4.9
- gcc >= 4.9, or clang
- Relevant build options
- `MNN_ONEDNN` Whether to use the oneDNN library to accelerate convolution
- `MNN_AVX512` Whether to use AVX512 instructions; requires gcc 9 or newer
- `MNN_OPENCL` Whether to build the OpenCL backend, for GPU devices
- `MNN_METAL` Whether to build the Metal backend, for MacOS/iOS GPU devices
- `MNN_VULKAN` Whether to build the Vulkan backend, for GPU devices
- `MNN_CUDA` Whether to build the CUDA backend, for Nvidia GPU devices
- `MNN_TENSORRT` Whether to build the TensorRT backend, for Nvidia GPU devices
- Other build options are listed in CMakeLists.txt
- Steps
1. Preparation (optional; required after modifying the MNN schema)
```bash
24 changes: 23 additions & 1 deletion docs/compile/tools.md → docs/compile/other.md
@@ -1,4 +1,4 @@
# Building the tool modules
# Building other modules

## Model converter
- Relevant build options
@@ -31,6 +31,28 @@
- `runTrainDemo.out` Entry program for running the training framework demo
- `transformer` Training model converter; converts an MNN inference model into an MNN model that performs training
- `extractForInfer` Extracts parameters from the training MNN model and updates the corresponding inference MNN model
## Generative models
- Relevant build options
- `MNN_BUILD_DIFFUSION` Whether to build the diffusion model inference demo
- `MNN_BUILD_LLM` Whether to build the large language model (LLM) inference engine
- `MNN_SUPPORT_TRANSFORMER_FUSE` Whether to support `transformer`-related fused operators, mainly to accelerate transformer models
- Build commands
- Build the diffusion model inference demo
```bash
mkdir build && cd build
cmake .. -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON -DMNN_BUILD_DIFFUSION=ON -DMNN_SUPPORT_TRANSFORMER_FUSE=ON
make -j4
```
- Build the LLM inference engine
```bash
mkdir build && cd build
cmake .. -DMNN_BUILD_LLM=ON -DMNN_SUPPORT_TRANSFORMER_FUSE=ON
make -j4
```
- Build outputs
- `libllm.so` LLM inference library
- `llm_demo` LLM inference demo program
- `diffusion_demo` Diffusion model demo program
## Test tools
- Relevant build options
- `MNN_BUILD_TOOL` Whether to build the test tools
10 changes: 9 additions & 1 deletion docs/index.rst
@@ -31,7 +31,7 @@

compile/cmake
compile/engine
compile/tools
compile/other
compile/pymnn

.. toctree::
@@ -62,6 +62,14 @@
train/finetune
train/distl

.. toctree::
:maxdepth: 1
:caption: Generative models
:name: transformers

transformers/diffusion
transformers/llm

.. toctree::
:maxdepth: 1
:caption: Test tools
3 changes: 3 additions & 0 deletions docs/transformers/diffusion.md
@@ -0,0 +1,3 @@
# Diffusion Models

TODO
198 changes: 198 additions & 0 deletions docs/transformers/llm.md
@@ -0,0 +1,198 @@
# Large Language Models

An LLM inference engine built on MNN, supporting today's mainstream open-source LLMs. The feature has two parts:
- Model export: export the torch model to onnx and then convert it to an mnn model; also export the tokenizer, embedding and other files;
- Model inference: run inference with the exported model, supporting text generation with LLMs;

## Model export

`llm_export` is an LLM export tool that can export an LLM to onnx and mnn models.

### Usage
1. Clone the LLM project you want to export, e.g. Qwen2-0.5B-Instruct
```sh
git clone https://www.modelscope.cn/qwen/Qwen2-0.5B-Instruct.git
```
2. Run `llm_export.py` to export the model
```sh
cd ./transformers/llm/export
# Export the model, tokenizer and embedding, and convert to the corresponding mnn model
python llm_export.py \
--type Qwen2-0_5B-Instruct \
--path /path/to/Qwen2-0.5B-Instruct \
--export \
--export_token \
--export_embed --embed_bin \
--export_mnn
```
3. Export outputs
The export outputs are:
1. `embeddings_bf16.bin`: binary file with the model's embedding weights, used at inference time;
2. `llm_config.json`: model configuration, used at inference time;
3. `llm.onnx`: the model's onnx file, not used at inference time;
4. `tokenizer.txt`: the model's tokenizer file, used at inference time;
5. `llm.mnn`: the model's mnn file, used at inference time;
6. `llm.mnn.weight`: the model's mnn weights, used at inference time;
The directory layout is as follows:
```
.
├── onnx
| ├── embeddings_bf16.bin
| ├── llm_config.json
| ├── llm.onnx
| └── tokenizer.txt
└── mnn
├── llm.mnn
└── llm.mnn.weight
```

### Features
- Export the whole model as a single onnx model with `--export`
- Export the model split into several onnx models with `--export_split`
- Export the model's vocabulary to a text file, one token per line with tokens base64-encoded, with `--export_verbose`
- Export the model's embedding layer as an onnx model with `--export_embed`; bf16 is also supported via `--embed_bf16`
- Export the model's blocks layer by layer: `--export_blocks` exports all layers, `--export_block $id` exports the given layer
- Export the model's lm_head layer as an onnx model with `--export_lm`
- Export the visual part of a multimodal model as an onnx model with `--export_visual`
- Run a chat test against the model with `--test $query`, which returns the LLM's reply
- Verify result consistency with onnxruntime after exporting the onnx model with `--export_test`
- Export the tokenizer to a text file with `--export_token`
- Convert the exported onnx model to an mnn model, by default with asymmetric 4-bit quantization, with `--export_mnn`
- Specify export paths with `--onnx_path` and `--mnn_path`
- onnx-slim is used to optimize the onnx model by default; skip this step with `--skip_slim`
- Merge lora weights before exporting; specify the lora weight directory with `--lora_path`

### Parameters
```
usage: llm_export.py [-h] --path PATH
[--type {chatglm-6b,chatglm2-6b,chatglm3-6b,codegeex2-6b,Qwen-7B-Chat,Qwen-1_8B-Chat,Qwen-1_8B,Qwen-VL-Chat,Qwen1_5-0_5B-Chat,Qwen1_5-1_8B-Chat,Qwen1_5-4B-Chat,Qwen1_5-7B-Chat,Qwen2-1_5B-Instruct,Baichuan2-7B-Chat,Llama-2-7b-chat-ms,Llama-3-8B-Instruct,internlm-chat-7b,TinyLlama-1_1B-Chat,Yi-6B-Chat,deepseek-llm-7b-chat,phi-2,bge-large-zh,lora}]
[--lora_path LORA_PATH] [--onnx_path ONNX_PATH] [--mnn_path MNN_PATH] [--export_mnn] [--export_verbose] [--export_test] [--test TEST] [--export] [--export_split] [--export_token]
[--export_embed] [--export_visual] [--export_lm] [--export_block EXPORT_BLOCK] [--export_blocks] [--embed_bin] [--embed_bf16] [--skip_slim]
llm_exporter
options:
-h, --help show this help message and exit
--path PATH path(`str` or `os.PathLike`):
Can be either:
- A string, the *model id* of a pretrained model like `THUDM/chatglm-6b`. [TODO]
- A path to a *directory* clone from repo like `../chatglm-6b`.
--type {chatglm-6b,chatglm2-6b,chatglm3-6b,codegeex2-6b,Qwen-7B-Chat,Qwen-1_8B-Chat,Qwen-1_8B,Qwen-VL-Chat,Qwen1_5-0_5B-Chat,Qwen1_5-1_8B-Chat,Qwen1_5-4B-Chat,Qwen1_5-7B-Chat,Qwen2-1_5B-Instruct,Baichuan2-7B-Chat,Llama-2-7b-chat-ms,Llama-3-8B-Instruct,internlm-chat-7b,TinyLlama-1_1B-Chat,Yi-6B-Chat,deepseek-llm-7b-chat,phi-2,bge-large-zh,lora}
type(`str`, *optional*):
The pretrained llm model type.
--lora_path LORA_PATH
lora path, default is `None`, meaning lora is not applied.
--onnx_path ONNX_PATH
export onnx model path, default is `./onnx`.
--mnn_path MNN_PATH export mnn model path, default is `./mnn`.
--export_mnn Whether or not to export mnn model after onnx.
--export_verbose Whether or not to export onnx with verbose.
--export_test Whether or not to export onnx with test using onnxruntime.
--test TEST test model inference with query `TEST`.
--export export model to an `onnx` model.
--export_split export model split to some `onnx` models:
- embedding model.
- block models.
- lm_head model.
--export_token export llm tokenizer to a txt file.
--export_embed export llm embedding to an `onnx` model.
--export_visual export llm visual model to an `onnx` model.
--export_lm export llm lm_head to an `onnx` model.
--export_block EXPORT_BLOCK
export llm block [id] to an `onnx` model.
--export_blocks export llm all blocks to `onnx` models.
--embed_bin export embedding weight as bin file with dtype `bfloat16`
--embed_bf16 using `bfloat16` replace `float32` in embedding.
--skip_slim Whether or not to skip onnx-slim.
```

## Model inference

### Build

[Build from source](../compile/tools.html#id4)

### Usage
#### Runtime configuration

##### Runtime files
Put the export outputs needed for inference into one folder and add a configuration file `config.json` describing the model name and inference parameters. The directory looks like:
```
.
└── model_dir
├── config.json
├── embeddings_bf16.bin
├── llm_config.json
├── llm.mnn
├── llm.mnn.weight
└── tokenizer.txt
```

##### Configuration options
The configuration file supports the following options:
- Model file information
- base_dir: directory from which model files are loaded; defaults to the directory containing config.json, or the model directory;
- llm_config: the actual path of `llm_config.json` is `base_dir + llm_config`, default `base_dir + 'config.json'`
- llm_model: the actual path of `llm.mnn` is `base_dir + llm_model`, default `base_dir + 'llm.mnn'`
- llm_weight: the actual path of `llm.mnn.weight` is `base_dir + llm_weight`, default `base_dir + 'llm.mnn.weight'`
- block_model: for split models, the actual path of `block_{idx}.mnn` is `base_dir + block_model`, default `base_dir + 'block_{idx}.mnn'`
- lm_model: for split models, the actual path of `lm.mnn` is `base_dir + lm_model`, default `base_dir + 'lm.mnn'`
- embedding_model: when the embedding is computed by a model, its actual path is `base_dir + embedding_model`, default `base_dir + 'embedding.mnn'`
- embedding_file: when the embedding is read from a binary file, its actual path is `base_dir + embedding_file`, default `base_dir + 'embeddings_bf16.bin'`
- tokenizer_file: the actual path of `tokenizer.txt` is `base_dir + tokenizer_file`, default `base_dir + 'tokenizer.txt'`
- visual_model: for VL models, the actual path of the visual model is `base_dir + visual_model`, default `base_dir + 'visual.mnn'`
- Inference configuration
- max_new_tokens: maximum number of tokens to generate, default `512`
- Hardware configuration
- backend_type: hardware backend used for inference, default `"cpu"`
- thread_num: number of threads used for inference, default `4`
- precision: precision policy used for inference, default `"low"`, which prefers `fp16`
- memory: memory policy used for inference, default `"low"`, which enables runtime quantization

##### Example configuration files
- `config.json`
```json
{
"llm_model": "qwen2-1.5b-int4.mnn",
"llm_weight": "qwen2-1.5b-int4.mnn.weight",

"backend_type": "cpu",
"thread_num": 4,
"precision": "low",
"memory": "low"
}
```
- `llm_config.json`
```json
{
"hidden_size": 1536,
"layer_nums": 28,
"attention_mask": "float",
"key_value_shape": [
2,
1,
0,
2,
128
],
"prompt_template": "<|im_start|>user\n%s<|im_end|>\n<|im_start|>assistant\n",
"is_visual": false,
"is_single": true
}
```
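
The `%s` placeholder in `prompt_template` marks where the user query is inserted before tokenization. The sketch below only illustrates that substitution under the assumption that it amounts to a plain string replacement; the helper name is hypothetical and not part of the MNN API.

```cpp
#include <string>

// Hypothetical helper illustrating how the "%s" slot in prompt_template could
// be filled with the user query; the real engine's implementation may differ.
static std::string applyPromptTemplate(const std::string& promptTemplate,
                                       const std::string& query) {
    auto pos = promptTemplate.find("%s");
    if (pos == std::string::npos) {
        return promptTemplate + query; // no placeholder: append the query
    }
    std::string prompt = promptTemplate;
    prompt.replace(pos, 2, query);
    return prompt;
}
// applyPromptTemplate("<|im_start|>user\n%s<|im_end|>\n<|im_start|>assistant\n", "Hi")
// yields "<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n"
```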

#### Inference usage
`llm_demo` is used as follows:
```
# With config.json
## Interactive chat
./llm_demo model_dir/config.json
## Reply to every line in prompt.txt
./llm_demo model_dir/config.json prompt.txt

# Without config.json, using the default configuration
## Interactive chat
./llm_demo model_dir/llm.mnn
## Reply to every line in prompt.txt
./llm_demo model_dir/llm.mnn prompt.txt
```
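
Beyond `llm_demo`, the `libllm.so` library can be linked directly. The following is a minimal C++ sketch of such an integration; the header name, class name and method signatures are assumptions for illustration only — consult `transformers/llm/engine` in the repository for the actual interface.

```cpp
// Hypothetical usage sketch of the llm library built with MNN_BUILD_LLM=ON.
// All identifiers below (header, Llm, createLLM, load, response) are assumed
// for illustration; see transformers/llm/engine for the real API.
#include <iostream>
#include <memory>
#include "llm.hpp" // assumed public header of libllm.so

int main() {
    // Assumed factory: build an Llm instance from the config.json described above.
    std::unique_ptr<Llm> llm(Llm::createLLM("model_dir/config.json"));
    llm->load();                                // load llm.mnn / llm.mnn.weight
    auto reply = llm->response("Hello, MNN!");  // single-turn text generation
    std::cout << reply << std::endl;
    return 0;
}
```
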
32 changes: 2 additions & 30 deletions express/Executor.cpp
@@ -243,38 +243,10 @@ void Executor::RuntimeManager::destroy(RuntimeManager* rtmgr) {
}

void Executor::RuntimeManager::setMode(Interpreter::SessionMode mode) {
if (mode == Interpreter::Session_Input_Inside || mode == Interpreter::Session_Input_User) {
mInside->modes.inputMode = mode;
} else if (mode == Interpreter::Session_Output_User || mode == Interpreter::Session_Output_Inside) {
mInside->modes.outputMode = mode;
} else if (mode == Interpreter::Session_Backend_Auto || mode == Interpreter::Session_Backend_Fix) {
mInside->modes.backendMode = mode;
} else if (mode == Interpreter::Session_Debug || mode == Interpreter::Session_Release) {
mInside->modes.callBackMode = mode;
} else if (mode == Interpreter::Session_Resize_Direct || mode == Interpreter::Session_Resize_Defer) {
mInside->modes.resizeMode = mode;
} else if(mode == Interpreter::Session_Memory_Collect || mode == Interpreter::Session_Memory_Cache) {
mInside->modes.memoryUsageMode = mode;
} else if(mode == Interpreter::Session_Codegen_Disable || mode == Interpreter::Session_Codegen_Enable) {
mInside->modes.codegenMode = mode;
}
mInside->modes.setMode(mode);
}
void Executor::RuntimeManager::setHint(Interpreter::HintMode mode, int value) {
switch (mode) {
case Interpreter::MAX_TUNING_NUMBER:
mInside->modes.maxTuningNumber = value;
break;
case Interpreter::STRICT_CHECK_MODEL:
mInside->checkNetBuffer = value > 0;
break;
case Interpreter::MEM_ALLOCATOR_TYPE:
mInside->modes.memoryAllocatorType = value;
break;
case Interpreter::WINOGRAD_MEMORY_LEVEL:
mInside->modes.winogradMemoryUsed = value;
default:
break;
}
mInside->modes.setHint(mode, value);
}
bool Executor::RuntimeManager::getInfo(Interpreter::SessionInfoCode code, void* ptr) {
// Only support get memory
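
The change above only moves the per-mode branching out of `RuntimeManager` and into `mInside->modes`; the public `setMode`/`setHint` interface is unchanged. A minimal sketch of caller-side usage follows (the chosen mode and hint values are examples, not requirements):

```cpp
#include <memory>
#include <MNN/Interpreter.hpp>
#include <MNN/expr/Executor.hpp>

using namespace MNN;
using namespace MNN::Express;

// Sketch: configure a RuntimeManager; after this commit the calls are dispatched
// to modes.setMode / modes.setHint internally, with identical observable behavior.
void configureRuntime() {
    ScheduleConfig config;
    config.type = MNN_FORWARD_CPU;
    config.numThread = 4;
    std::shared_ptr<Executor::RuntimeManager> rtmgr(
        Executor::RuntimeManager::createRuntimeManager(config));
    rtmgr->setMode(Interpreter::Session_Backend_Fix);       // session mode
    rtmgr->setHint(Interpreter::WINOGRAD_MEMORY_LEVEL, 3);  // tuning hint
}
```
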
5 changes: 3 additions & 2 deletions express/Expr.cpp
@@ -372,7 +372,7 @@ VARP Variable::create(EXPRP expr, int index) {
res.fix(VARP::CONSTANT);
return res;
}
// CONTENT Mode
// CONTENT Mode, Use Geometry Computer to Decompress Expr
do {
if (!(executor->getLazyMode() & Executor::LAZY_CONTENT)) {
break;
@@ -398,7 +398,8 @@ VARP Variable::create(EXPRP expr, int index) {
outputTensors[i] = expr->mInside->mOutputTensors[i];
}
auto bn = executor->getAttr()->constantBackend;
GeometryComputer::Context context(bn);
// TODO: Support set mask
GeometryComputer::Context context(Interpreter::GeometryComputeMask::GEOMETRCOMPUTEMASK_ALL, bn);
auto geo = GeometryComputer::search(expr->get()->type(), Runtime::Compiler_Loop);
CommandBuffer cmd;
res = geo->onCompute(expr->get(), inputTensors, outputTensors, context, cmd);
1 change: 0 additions & 1 deletion express/RuntimeAttr.hpp
@@ -21,7 +21,6 @@ struct RuntimeAttr {
// Use for static module to compute flops
float mFlops;
std::string mExternalFile;
bool checkNetBuffer = true;
};
struct ExecutorAttr {
std::shared_ptr<Backend> constantBackend;
2 changes: 1 addition & 1 deletion express/module/Module.cpp
@@ -351,7 +351,7 @@ static Module* loadInternal(const std::vector<std::string>& inputs, const std::v
}
bool checkMNNBuffer = true;
if (nullptr != _rtMgr) {
checkMNNBuffer = _rtMgr->getInside()->checkNetBuffer;
checkMNNBuffer = _rtMgr->getInside()->modes.checkNetBuffer;
}
if (checkMNNBuffer) {
flatbuffers::Verifier verify(buffer, length);
21 changes: 21 additions & 0 deletions include/MNN/Interpreter.hpp
@@ -203,7 +203,28 @@
MEM_ALLOCATOR_TYPE = 2,
// Winograd unit candidates count, default 3. if set 0, will use less unit candidates for less memory at the expense of performance.
WINOGRAD_MEMORY_LEVEL = 3,

// Geometry Compute option, default is 0xFFFF
GEOMETRY_COMPUTE_MASK = 4,
};

enum GeometryComputeMask {
// Support Region Fuse
GEOMETRCOMPUTEMASK_FUSEREGION = 1 << 0,

// Support Region Fuse to input with multi-region, eg: pad + concat
GEOMETRCOMPUTEMASK_FUSEREGION_MULTI = 1 << 1,

// Use loop instead of raster + compute if possible
GEOMETRCOMPUTEMASK_USELOOP = 1 << 2,

// Support Geometry Cache, if shape changed, will try recompute, and then run compute if failed
GEOMETRCOMPUTEMASK_OPENCACHE = 1 << 3,

// Full option open mask, for example, if want to close useloop, can set mask as (GEOMETRCOMPUTEMASK_ALL - GEOMETRCOMPUTEMASK_USELOOP)
GEOMETRCOMPUTEMASK_ALL = 0xFFFF,
};

/**
* @brief The API should be called before creating a session.
* @param mode Hint type
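
As the comment on `GEOMETRCOMPUTEMASK_ALL` describes, individual geometry-compute features can be disabled by clearing their bits and passing the result through the new `GEOMETRY_COMPUTE_MASK` hint. A brief sketch, routing the hint through `RuntimeManager::setHint` as in the Executor change above (whether every backend honors each bit is not asserted here):

```cpp
#include <memory>
#include <MNN/Interpreter.hpp>
#include <MNN/expr/Executor.hpp>

using namespace MNN;
using namespace MNN::Express;

// Sketch: keep all geometry-compute options except loop fusion, mirroring the
// example given in the GEOMETRCOMPUTEMASK_ALL comment.
void disableGeometryLoop(std::shared_ptr<Executor::RuntimeManager> rtmgr) {
    int mask = Interpreter::GEOMETRCOMPUTEMASK_ALL
             - Interpreter::GEOMETRCOMPUTEMASK_USELOOP;
    rtmgr->setHint(Interpreter::GEOMETRY_COMPUTE_MASK, mask);
}
```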