[MNN:Sync] Sync Internal 2.8.4 #2839

Merged · 2 commits · Apr 19, 2024
17 changes: 16 additions & 1 deletion docs/faq.md
@@ -179,7 +179,7 @@ TensorArray and control-flow support require MNN-Express,
- When loading the network, add the intermediate results you want to fetch to the output names (see the sketch below)
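A minimal sketch of that approach, assuming the Express `Module` API where the intermediate tensor's name is simply appended to the output-name list passed to `Module::load`; the model path and tensor names are placeholders, not part of this PR:

```cpp
#include <MNN/expr/Module.hpp>
#include <memory>
#include <string>
#include <vector>

using namespace MNN::Express;

// Load the model with an extra output name so the intermediate result
// "conv3_output" (placeholder) can be fetched alongside the final "prob".
std::shared_ptr<Module> loadWithIntermediateOutput() {
    std::vector<std::string> inputs  = {"input"};
    std::vector<std::string> outputs = {"prob", "conv3_output"};
    return std::shared_ptr<Module>(Module::load(inputs, outputs, "model.mnn"));
}
```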


### GPU backend cannot be used
### OpenCL or Vulkan backend cannot be used
A simple solution on Linux:
cmake .. -DMNN_USE_SYSTEM_LIB=true -DMNN_SEP_BUILD=false

@@ -193,6 +193,21 @@ OpenCL / Vulkan register their backends into the main MNN library via static-variable self-registration.
1. Set MNN_SEP_BUILD = OFF (cmake -DMNN_SEP_BUILD=OFF) to build the opencl / vulkan backends directly into MNN's .so.
1. Add dlopen("libMNN_CL.so") in your own code (see the sketch below). Refer to [https://github.com/alibaba/MNN/issues/105](https://github.com/alibaba/MNN/issues/105) .
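A minimal sketch of the dlopen workaround from option 2, assuming a Linux/Android build where libMNN_CL.so is on the library search path; call this before creating any MNN interpreter or session:

```cpp
#include <dlfcn.h>
#include <cstdio>

// Force-load the separately built OpenCL backend so its static
// self-registration runs against the main MNN library.
static bool loadMNNOpenCLBackend() {
    void* handle = dlopen("libMNN_CL.so", RTLD_NOW | RTLD_GLOBAL);
    if (handle == nullptr) {
        std::fprintf(stderr, "dlopen(libMNN_CL.so) failed: %s\n", dlerror());
        return false;
    }
    return true;  // keep the handle open for the lifetime of the process
}
```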

#### OpenCL library fails to open in an Android app due to permission restrictions
Newer versions of Android have tightened permission control, so loading the OpenCL library may fail. Modify the corresponding section of AndroidManifest.xml to declare the native-library requirement for the OpenCL .so:

```
<application>
...

<uses-native-library android:name="libOpenCL.so"
android:required="true"/>

...

</application>
```

### Some models segfault when run with MNNV2Basic

- The model does not meet the runtime requirements
12 changes: 6 additions & 6 deletions docs/tools/convert.md
@@ -15,7 +15,7 @@ Usage:

--batch arg If the model's input batch is dynamic, the batch size to use after conversion can be specified here

--keepInputFormat Whether to keep the original model's input format, default:
--keepInputFormat Whether to keep the original model's input format, default:

--optimizeLevel arg Graph optimization level, default 1:
- 0: do not run graph optimization; only for the case where the original model is already MNN;
@@ -34,8 +34,6 @@ Usage:
--fp16 Save the float32 parameters of conv/matmul/LSTM as float16;
the model shrinks by half with essentially no loss of accuracy

--benchmarkModel Do not save the parameters of layers such as conv/matmul/BN in the model; only for benchmark testing

--bizCode arg MNN model flag, e.g. MNN

--debug Use debug mode to show more conversion information
@@ -46,6 +44,8 @@ Usage:
only optimizes model size; the weights are decoded back to float32 after the model is loaded; bit width can be 2~8,
and runtime speed matches the float32 model; at 8 bits accuracy is essentially lossless and model size shrinks 4x
default: 0, i.e. no weight quantization

--weightQuantAsymmetric Used together with weightQuantBits to decide whether asymmetric quantization is used, default `true`

--compressionParamsFile arg
Model-compression information file generated by the MNN model compression toolbox
@@ -79,12 +79,12 @@ Usage:
Valid values: {0, 1}, default 1; detects whether the weights can use sparse acceleration

--saveExternalData Store weights, constants and other data in an external file, default `false`

```
**Note 1: the benchmarkModel option removes parameters such as convolution weights and BN mean/var from the model to shrink the converted model file; parameters are randomly initialized at runtime, which is convenient for benchmarking model performance.**

**Note 2: the weightQuantBits option is used as --weightQuantBits numBits, with numBits in 2~8. It only quantizes the float32 weights of conv/matmul/LSTM and only optimizes model size; the weights are decoded back to float32 when the model is loaded, so runtime speed matches the float32 model. Internal tests show essentially lossless accuracy at 8 bits and a 4x reduction in model size. default: 0, i.e. no weight quantization.**
**Note 1: the weightQuantBits option is used as --weightQuantBits numBits, with numBits in 2~8. It only quantizes the float32 weights of conv/matmul/LSTM and only optimizes model size; the weights are decoded back to float32 when the model is loaded, so runtime speed matches the float32 model. Internal tests show essentially lossless accuracy at 8 bits and a 4x reduction in model size. default: 0, i.e. no weight quantization.**

**Note 3: if you develop with the Android JNI Java interface, the interface does not provide `copyFromHost`, so use `keepInputFormat` when converting the model**
**Note 2: if you develop with the Interpreter-Session C++ interface, NC4HW4 combines well with ImageProcess, so consider letting the converter pick the memory layout automatically with `--keepInputFormat=0` (a sketch follows below)**
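A minimal sketch of the Interpreter-Session + ImageProcess flow referred to in Note 2, assuming the public MNN C++ API (`Interpreter`, `CV::ImageProcess`); the model path, image formats, and the helper name are placeholders, not part of this PR:

```cpp
#include <MNN/Interpreter.hpp>
#include <MNN/ImageProcess.hpp>
#include <memory>

// Feed an RGBA image into a model converted with --keepInputFormat=0,
// letting ImageProcess write directly into the session's input tensor.
int runOnce(const uint8_t* rgba, int width, int height) {
    std::shared_ptr<MNN::Interpreter> net(MNN::Interpreter::createFromFile("model.mnn"));
    MNN::ScheduleConfig conf;
    auto session = net->createSession(conf);
    auto input   = net->getSessionInput(session, nullptr);

    MNN::CV::ImageProcess::Config ipConfig;
    ipConfig.sourceFormat = MNN::CV::RGBA;   // source pixel format (assumed)
    ipConfig.destFormat   = MNN::CV::BGR;    // what the model expects (assumed)
    std::shared_ptr<MNN::CV::ImageProcess> process(MNN::CV::ImageProcess::create(ipConfig));
    process->convert(rgba, width, height, 0, input);  // color conversion + layout handled here

    return net->runSession(session) == MNN::NO_ERROR ? 0 : -1;
}
```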

## Converting other models to MNN
### TensorFlow to MNN
9 changes: 7 additions & 2 deletions docs/tools/test.md
@@ -29,7 +29,7 @@ Model Version: < 2.0.0
`./MNNV2Basic.out model [runLoops runMask forwardType numberThread precision_memory inputSize]`
- `model:str` path to the model file
- `runLoops:int` number of loops for performance testing, optional, default `1`
- `runMask:int` whether to dump intermediate inference results: 0 = no dump; 1 = dump only each op's output ({op_name}.txt); 2 = dump each op's input (Input_{op_name}.txt) and output ({op_name}.txt). Results go to the output directory under the current directory (create the output directory yourself before using the tool). Optional, default `0`
- `runMask:int` whether to dump intermediate inference results: 0 = no dump; 1 = dump only each op's output ({op_name}.txt); 2 = dump each op's input (Input_{op_name}.txt) and output ({op_name}.txt). Results go to the output directory under the current directory (create the output directory yourself before using the tool). 16 = enable automatic backend selection; 32 = enable the memory-saving mode for the Winograd algorithm, which lowers runtime memory for models containing Winograd Convolution ops but may cost some op performance. Optional, default `0`
- `forwardType:int` compute device for inference; valid values: 0 (CPU), 1 (Metal), 2 (CUDA), 3 (OpenCL), 6 (OpenGL), 7 (Vulkan), 9 (TensorRT); optional, default `0`
- `numberThread:int` thread count, only effective for CPU; optional, default `4`
- `precision_memory:int` precision and memory mode for testing. precision_memory % 16 is the precision; valid values: 0 (Normal), 1 (High), 2 (Low), 3 (Low_BF16); optional, default `2`. precision_memory / 16 is the memory setting, default 0 (memory_normal). For example, to test memory low (2) with precision high (1), set precision_memory = 9 (2 * 4 + 1)
@@ -79,6 +79,10 @@ Avg= 5.570600 ms, OpSum = 7.059200 ms min= 3.863000 ms, max= 11.596001 ms
### Default output
Written to the output folder under the current directory, printed in order as 0.txt, 1.txt, 2.txt, etc.

### Generating a test folder
- If you have the original TensorFlow / ONNX model, you can generate one with scripts such as testMNNFromTf.py / testMNNFromOnnx.py / testMNNFromTflite.py
- If you only have the MNN model, use the tools/script/make_test_for_mnn.py script to generate a test folder: mkdir testdir && python3 make_test_for_mnn.py XXX.mnn testdir

### runMask parameter description
- 1 : dump intermediate inference results; each op's input is saved to (Input_{op_name}.txt) and its output to ({op_name}.txt), written to the output directory under the current directory (create the output directory yourself before using the tool); cannot be combined with 2 / 4
- 2 : print statistics (max / min / mean) of the intermediate inference results; only floating-point tensors are supported; cannot be combined with 1 / 4
@@ -88,6 +92,7 @@ Avg= 5.570600 ms, OpSum = 7.059200 ms min= 3.863000 ms, max= 11.596001 ms
- 32 : set rearrange to true, lowering memory after the model is loaded but increasing model-load initialization time
- 64 : after creating the model, clone a new model and run it, to test the correctness of the clone feature (mainly used for concurrent inference)
- 128 : use input.mnn and output.mnn in the folder as the input and the reference output; preferable when the data volume is large
- 512 : enable the memory optimization for convolutions computed with the Winograd algorithm; this lowers the model's runtime memory but may cost some performance


### Example
@@ -114,7 +119,7 @@ Avg= 9.946699 ms, min= 9.472000 ms, max= 10.227000 ms
`./SequenceModuleTest.out model [forwardType] [shapeMutable] dir1 dir2 ......`
- `model:str` path to the model file
- `forwardType:int` compute device for inference; valid values: 0 (CPU), 1 (Metal), 2 (CUDA), 3 (OpenCL), 6 (OpenGL), 7 (Vulkan), 9 (TensorRT)
- `shapeMutable:int` whether the input shape is mutable
- `numberThread:int` thread count, or GPU mode
- `dir_n:str` folders with input/output information; can be generated with scripts such as testMNNFromOnnx.py, see the correctness-check section of the model conversion docs
```bash
./SequenceModuleTest.out transformer.mnn 0 1 tr tr1 tr2 tr3 tr4 > error.txt
23 changes: 13 additions & 10 deletions express/Executor.cpp
@@ -58,10 +58,6 @@ void Executor::setGlobalExecutorConfig(MNNForwardType type, const BackendConfig&
mAttr->firstType = std::make_pair(type, numberThread);
info.mode = Backend::Info::DIRECT;
info.numThread = numberThread;
if (MNN_FORWARD_METAL == type) {
// Close metal's defer encoder
info.numThread |= MNN_GPU_RECORD_OP;
}
info.user = (BackendConfig*)&config;
std::shared_ptr<Runtime> bn(creator->onCreate(info));
mRuntimes[mAttr->firstType] = bn;
@@ -264,6 +260,8 @@ void Executor::RuntimeManager::setHint(Interpreter::HintMode mode, int value) {
case Interpreter::MEM_ALLOCATOR_TYPE:
mInside->modes.memoryAllocatorType = value;
break;
case Interpreter::WINOGRAD_MEMORY_LEVEL:
mInside->modes.winogradMemoryUsed = value;
default:
break;
}
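For context on the new WINOGRAD_MEMORY_LEVEL case above, a minimal usage sketch through the express RuntimeManager; the hint name and the setHint / createRuntimeManager signatures come from this diff, while the surrounding setup and the level value 0 are assumptions:

```cpp
#include <MNN/expr/Executor.hpp>
#include <MNN/Interpreter.hpp>

using namespace MNN;
using namespace MNN::Express;

// Ask the runtime to favor lower memory when Winograd convolution is used.
std::shared_ptr<Executor::RuntimeManager> makeLowMemoryRuntime() {
    ScheduleConfig config;
    config.type      = MNN_FORWARD_CPU;
    config.numThread = 4;
    std::shared_ptr<Executor::RuntimeManager> rtMgr(
        Executor::RuntimeManager::createRuntimeManager(config));
    // 0 is assumed here to mean "minimize Winograd memory"; check the
    // Interpreter::HintMode documentation for the actual value range.
    rtMgr->setHint(Interpreter::WINOGRAD_MEMORY_LEVEL, 0);
    return rtMgr;
}
```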
@@ -303,6 +301,7 @@ Executor::RuntimeManager::RuntimeManager() {
mInside->modes.outputMode = Interpreter::Session_Output_User;
}
Executor::RuntimeManager::~RuntimeManager() {
updateCache();
delete mInside;
}
Executor::RuntimeManager* Executor::RuntimeManager::createRuntimeManager(const ScheduleConfig &config) {
@@ -367,7 +366,7 @@ void Executor::RuntimeManager::setCache(std::string cacheName) {
MNN_ERROR("Empty cacheFile\n");
return;
}
std::unique_ptr<FileLoader> loader(new FileLoader(mInside->mCache->cacheFile.c_str()));
std::unique_ptr<FileLoader> loader(new FileLoader(mInside->mCache->cacheFile.c_str(), true));
if (!loader->valid()) {
MNN_ERROR("Load Cache file error.\n");
return;
@@ -394,16 +393,19 @@ void Executor::RuntimeManager::setCache(std::string cacheName) {
// Reset cache
loadCache(mInside->mInfo, nullptr, 0);
MNN_PRINT("Cache invalid, will be reset\n");
} else {
mInside->mCache->lastCacheSize = mInside->mCache->cacheBuffer.size() - mInside->mCache->cacheOffset;
}

mInside->mCache->lastCacheSize = mInside->mCache->cacheBuffer.size() - mInside->mCache->cacheOffset;
}

void Executor::RuntimeManager::setExternalFile(std::string fileName) {
mInside->mExternalFile = fileName;
}

void Executor::RuntimeManager::updateCache() {
if (nullptr == mInside->mCache) {
return;
}
std::lock_guard<std::mutex> _l(mLock);

// Backend_Auto and no Async work, then don't need updateCache
@@ -489,7 +491,10 @@ void Executor::_makeCache(const std::vector<EXPRP>& expr, bool forceCPU) {
if (dfsStack.empty()) {
return;
}
auto current = ExecutorScope::Current();
auto rt = current->getRuntime();
Schedule::ScheduleInfo scheduleInfo;
scheduleInfo.externalWeightPath = current->getAttr()->externalFile;
scheduleInfo.pipelineInfo.resize(1);
auto& pipeline = scheduleInfo.pipelineInfo[0].second;
std::vector<std::shared_ptr<BufferStorage>> opBuffers;
@@ -577,8 +582,6 @@ void Executor::_makeCache(const std::vector<EXPRP>& expr, bool forceCPU) {
}
pipeline.emplace_back(std::move(opInfo));
}
auto current = ExecutorScope::Current();
auto rt = current->getRuntime();
Session::ModeGroup group;
group.inputMode = Interpreter::Session_Input_User;
group.outputMode = Interpreter::Session_Output_User;
@@ -593,9 +596,9 @@ void Executor::_makeCache(const std::vector<EXPRP>& expr, bool forceCPU) {
cahce->mCacheBuffers = std::move(opBuffers);
// Don't report error when use expr dynamic compute, which will be called in model convert
scheduleInfo.pipelineInfo[0].first.reportError = false;
scheduleInfo.pipelineInfo[0].first.info.numThread = 1;
if (forceCPU) {
scheduleInfo.pipelineInfo[0].first.info.type = MNN_FORWARD_CPU;
scheduleInfo.pipelineInfo[0].first.info.numThread = 1;
} else {
scheduleInfo.pipelineInfo[0].first.info.type = current->getAttr()->firstType.first;
scheduleInfo.pipelineInfo[0].first.info.numThread = current->getAttr()->firstType.second;
31 changes: 20 additions & 11 deletions express/Expr.cpp
@@ -85,9 +85,9 @@ bool VARP::fix(VARP::InputType type) const {
VARP newVARP = Express::Variable::create(Express::Expr::create(tensor, true));
newVARP->expr().first->mType = type;
auto& pipelineInfo = inside->mCache->getSession()->getPipelineInfo(0);
if (TensorUtils::getDescribe(tensor)->getBackend() == pipelineInfo.first.cache.first.get()) {
if (TensorUtils::getDescribeOrigin(tensor)->getBackend() == pipelineInfo.first.cache.first.get()) {
newVARP->expr().first->inside()->mHoldBackend = pipelineInfo.first.cache.first;
} else if (TensorUtils::getDescribe(tensor)->getBackend() == pipelineInfo.first.cache.second.get()) {
} else if (TensorUtils::getDescribeOrigin(tensor)->getBackend() == pipelineInfo.first.cache.second.get()) {
newVARP->expr().first->inside()->mHoldBackend = pipelineInfo.first.cache.second;
}
Variable::replace(VARP(mContent), newVARP);
@@ -224,6 +224,16 @@ EXPRP Expr::create(const OpT* op, std::vector<VARP> inputs, int outputSize) {
return create(std::move(info), nullptr, VARP::INPUT);
}
if (OpType_Const == op->type || OpType_TrainableParam == op->type) {
if (!op->externalPath.empty()) {
flatbuffers::FlatBufferBuilder builder;
auto offset = Op::Pack(builder, op);
builder.Finish(offset);
std::shared_ptr<BufferStorage> extra(new BufferStorage);
extra->storage = builder.ReleaseRaw(extra->allocated_size, extra->offset);
auto resExpr = Expr::create(extra, std::move(inputs), outputSize);
resExpr->setName(op->name);
return resExpr;
}
Variable::Info info;
info.dim = op->main.AsBlob()->dims;
info.order = Utils::revertFormat(op->main.AsBlob()->dataFormat);
@@ -568,7 +578,7 @@ bool Variable::copyToDevicePtr(void* devicePtr, int memoryType) {
auto inside = mFrom->inside();
auto originTensor = inside->mOutputTensors[mFromIndex];

auto bn = TensorUtils::getDescribe(originTensor)->getBackend();
auto bn = TensorUtils::getDescribeOrigin(originTensor)->getBackend();
if(bn == nullptr) {
MNN_ERROR("Error: Varp copyToDevicePtr can't find backend\n");
return false;
Expand All @@ -577,7 +587,7 @@ bool Variable::copyToDevicePtr(void* devicePtr, int memoryType) {
MNN::Tensor tempTensor(originTensor->dimensions(), originTensor->getDimensionType());
tempTensor.setDevicePtr(devicePtr, memoryType);

TensorUtils::getDescribe(originTensor)->getBackend()->onCopyBuffer(originTensor, &tempTensor);
TensorUtils::getDescribeOrigin(originTensor)->getBackend()->onCopyBuffer(originTensor, &tempTensor);
// Sync the result
tempTensor.wait(Tensor::MAP_TENSOR_READ, true);
return true;
@@ -738,12 +748,11 @@ bool Variable::resize(INTS dims) {
info.syncSize();
Utils::copyInfoToTensor(mFrom->inside()->mOutputTensors[0], mFrom->inside()->mOutputInfos.data());
Utils::releaseMemoryForHostTensor(mFrom->inside()->mOutputTensors[0]);
if (0 >= info.size) {
return false;
}
bool res = Utils::allocMemoryForHostTensor(mFrom->inside()->mOutputTensors[0]);
if (!res) {
return false;
if (0 < info.size) {
bool res = Utils::allocMemoryForHostTensor(mFrom->inside()->mOutputTensors[0]);
if (!res) {
return false;
}
}

mFrom->mValid = true;
@@ -946,7 +955,7 @@ bool Expr::setInfoDirty() {
std::vector<VARP> Variable::load(const char* fileName) {
AutoStorage<uint8_t> buffer;
{
FileLoader loader(fileName);
FileLoader loader(fileName, true);
if (!loader.valid()) {
MNN_ERROR("Error for open %s\n", fileName);
return {};
24 changes: 0 additions & 24 deletions express/NeuralNetWorkOp.cpp
@@ -1756,18 +1756,6 @@ VARP _Int8ToFloat(VARP x, VARP scale) {
auto xInfo = x->getInfo();
auto scaleInfo = scale->getInfo();
auto scalePtr = scale->readMap<float>();
if (nullptr == scalePtr || nullptr == xInfo || nullptr == scaleInfo) {
MNN_ERROR("Error for _Int8ToFloat because var not ready\n");
return nullptr;
}
if (xInfo->order != NC4HW4 || xInfo->type.code != halide_type_int) {
MNN_ERROR("Not Support Input for _Int8ToFloat because var not NC4HW4 or not int8\n");
return nullptr;
}
if ((scaleInfo->size != xInfo->dim[1]) && (scaleInfo->size != 1)) {
MNN_ERROR("_Int8ToFloat Scale's size not match input's channel\n");
return nullptr;
}
std::unique_ptr<OpT> op(new OpT);
op->type = OpType_Int8ToFloat;
op->main.type = OpParameter_QuantizedFloatParam;
@@ -1781,18 +1769,6 @@ VARP _Int8ToFloat(VARP x, VARP scale, int8_t zeroPoint) {
auto xInfo = x->getInfo();
auto scaleInfo = scale->getInfo();
auto scalePtr = scale->readMap<float>();
if (nullptr == scalePtr || nullptr == xInfo || nullptr == scaleInfo) {
MNN_ERROR("Error for _Int8ToFloat because var not ready\n");
return nullptr;
}
if (xInfo->order != NC4HW4 || xInfo->type.code != halide_type_int) {
MNN_ERROR("Not Support Input for _Int8ToFloat because var not NC4HW4 or not int8\n");
return nullptr;
}
if ((scaleInfo->size != xInfo->dim[1]) && (scaleInfo->size != 1)) {
MNN_ERROR("_Int8ToFloat Scale's size not match input's channel\n");
return nullptr;
}
std::unique_ptr<OpT> op(new OpT);
op->type = OpType_Int8ToFloat;
op->main.type = OpParameter_QuantizedFloatParam;
1 change: 1 addition & 0 deletions express/RuntimeAttr.hpp
@@ -26,6 +26,7 @@ struct RuntimeAttr {
struct ExecutorAttr {
std::shared_ptr<Backend> constantBackend;
std::pair<MNNForwardType, int> firstType;
std::string externalFile;
};
};
};