
ResNet-50-like model: the Vulkan backend's inference deviates significantly — is this expected? #2433

Closed
ldfandian opened this issue Jun 15, 2023 · 25 comments

Comments

@ldfandian

ldfandian commented Jun 15, 2023

Platform (include the target platform as well if cross-compiling):

Android / Xiaomi 12S Ultra (Qualcomm Snapdragon 8+ Gen 1)

GitHub Version:

Provide the date (or better yet, the git revision from the comment section of the zip, obtainable by running 7z l PATH/TO/ZIP and searching for "Comment" in the output) if you downloaded the source as a zip; otherwise provide the first commit id from the output of git log.

Compiling Method:

Paste the cmake arguments or the path of the build script used, along with the full cmake output, here or on pastebin.

Build Log:

Paste the log here or on pastebin.

The model is a ResNet-50-like CLIP model, first converted to ONNX and then to MNN.
BTW, libOpenCL.so on this device does not seem usable by third-party apps, so the CL numbers below presumably come from MNN's bundled implementation (a guess).
The end result: only the Vulkan backend shows a large inference deviation, while OpenCL/NNAPI/Auto/CPU all track the original ONNX model closely (identical code and data; only the single line selecting the backend differs). Is this expected?

Also, I found that running the CPU backend with 4 threads is faster than both OpenCL and Vulkan, which is surprising. CLIP is not a small model; how can the CPU end up faster?

#resultOnnx = dims=[1, 1024], output[0]=[0.004379074, 0.05519787, -0.012666782, 0.110597305, -0.045369, 0.052906048, 0.06956524, 0.1803551, ...]
#  Onnx-MnnCL = 0.022962905: [2.0631822E-4, 2.1298602E-4, -7.469002E-4, 4.6975166E-4, 4.868023E-4, 1.5204772E-4, 0.0010392666, 7.560998E-4, ...]
#  Onnx-MnnVulkan = 6.1780834: [-0.002903183, 0.041785393, 0.16604416, 0.26709145, -0.28926548, 0.14683916, 0.010391653, 0.25097278, ...]
#  Onnx-MnnNNAPI = 0.022984184: [2.0515174E-4, 2.1444634E-4, -7.5005554E-4, 4.7250837E-4, 4.892647E-4, 1.5044957E-4, 0.0010413975, 7.5531006E-4, ...]
#  Onnx-MnnAuto = 0.022962905: [2.0631822E-4, 2.1298602E-4, -7.469002E-4, 4.6975166E-4, 4.868023E-4, 1.5204772E-4, 0.0010392666, 7.560998E-4, ...]
#  Onnx-MnnCPU = 0.022984184: [2.0515174E-4, 2.1444634E-4, -7.5005554E-4, 4.7250837E-4, 4.892647E-4, 1.5044957E-4, 0.0010413975, 7.5531006E-4, ...]
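The scalar printed for each backend above looks like an aggregate of the element-wise differences against the ONNX reference; a minimal sketch of such a comparison (the exact metric used in the logs above is unknown, so the L2 norm here is an assumption):

```python
import math

def deviation(ref, out):
    """L2 norm of the element-wise difference between a reference
    output (e.g. ONNX) and a backend's output. The metric choice is
    an assumption for illustration, not MNN's actual reporting code."""
    assert len(ref) == len(out)
    return math.sqrt(sum((r - o) ** 2 for r, o in zip(ref, out)))
```

A well-matched backend should yield a deviation near zero, while the Vulkan number above is orders of magnitude larger than the others.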
@jxt1234
Collaborator

jxt1234 commented Jun 16, 2023

Test with backendTest.out; you can also try the Vulkan buffer version (-DMNN_VULKAN_IMAGE=OFF).

@jxt1234
Collaborator

jxt1234 commented Jun 16, 2023

Could you upload the model?

@ldfandian
Author

Here you go: link: https://pan.baidu.com/s/1RovXbErY2GLU7b0uouMUhQ?pwd=jhhb extraction code: jhhb

@ldfandian
Author

ldfandian commented Jun 17, 2023

Test with backendTest.out; you can also try the Vulkan buffer version (-DMNN_VULKAN_IMAGE=OFF)

backendTest.out also reports errors.

> ../testCommon.sh ./backendTest.out ./test-clip.mnn 3 0.1 1
==== executing: adb push ./test-clip.mnn /data/local/tmp/MNN/test-clip.mnn
./test-clip.mnn: 1 file pushed, 0 skipped. 34.9 MB/s (76687308 bytes in 2.096s)
==== executing: cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.:&& ./backendTest.out test-clip.mnn 3 0.1 1    
Test forward type: 3
Tolerance Rate: 0.100000
Open Model test-clip.mnn
The device support i8sdot:1, support fp16:1, support i8mm: 1
Input: 224,224,3,1
precision=1 in main, 271 
modeNum=1 in main, 276 
stopOp.c_str()=s  in main, 281 
Correct ! Run second pass
Correct !
> ../testCommon.sh ./backendTest.out ./test-clip.mnn 7 0.1 1
==== executing: adb push ./test-clip.mnn /data/local/tmp/MNN/test-clip.mnn
./test-clip.mnn: 1 file pushed, 0 skipped. 35.4 MB/s (76687308 bytes in 2.064s)
==== executing: cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.:&& ./backendTest.out test-clip.mnn 7 0.1 1    
Test forward type: 7
Tolerance Rate: 0.100000
Open Model test-clip.mnn
The device support i8sdot:1, support fp16:1, support i8mm: 1
Input: 224,224,3,1
precision=1 in main, 271 
modeNum=1 in main, 276 
stopOp.c_str()=s  in main, 281 
16: 1.252289 != 3.000625
onnx::Reshape_637 - 0 is error
> ../testCommon.sh ./backendTest.out ./test-clip.mnn 7 0.1 2
==== executing: adb push ./test-clip.mnn /data/local/tmp/MNN/test-clip.mnn
./test-clip.mnn: 1 file pushed, 0 skipped. 36.1 MB/s (76687308 bytes in 2.028s)
==== executing: cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.:&& ./backendTest.out test-clip.mnn 7 0.1 2    
Test forward type: 7
Tolerance Rate: 0.100000
Open Model test-clip.mnn
The device support i8sdot:1, support fp16:1, support i8mm: 1
Input: 224,224,3,1
precision=2 in main, 271 
modeNum=1 in main, 276 
stopOp.c_str()=s  in main, 281 
16: 1.283203 != 3.000625
onnx::Reshape_637 - 0 is error
> ../testCommon.sh ./backendTest.out ./test-clip.mnn 7 0.1 3
==== executing: adb push ./test-clip.mnn /data/local/tmp/MNN/test-clip.mnn
./test-clip.mnn: 1 file pushed, 0 skipped. 35.2 MB/s (76687308 bytes in 2.075s)
==== executing: cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.:&& ./backendTest.out test-clip.mnn 7 0.1 3    
Test forward type: 7
Tolerance Rate: 0.100000
Open Model test-clip.mnn
The device support i8sdot:1, support fp16:1, support i8mm: 1
Input: 224,224,3,1
precision=3 in main, 271 
modeNum=1 in main, 276 
stopOp.c_str()=s  in main, 281 
16: 1.283203 != 3.000625
onnx::Reshape_637 - 0 is error
> ../testCommon.sh ./backendTest.out ./test-clip.mnn 3 0.1 2
==== executing: adb push ./test-clip.mnn /data/local/tmp/MNN/test-clip.mnn
./test-clip.mnn: 1 file pushed, 0 skipped. 35.3 MB/s (76687308 bytes in 2.074s)
==== executing: cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.:&& ./backendTest.out test-clip.mnn 3 0.1 2    
Test forward type: 3
Tolerance Rate: 0.100000
Open Model test-clip.mnn
The device support i8sdot:1, support fp16:1, support i8mm: 1
Input: 224,224,3,1
precision=2 in main, 271 
modeNum=1 in main, 276 
stopOp.c_str()=s  in main, 281 
Correct ! Run second pass
Correct !
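backendTest.out compares the tested backend against the CPU reference and reports the first element exceeding the tolerance rate (here `16: 1.252289 != 3.000625`). A rough sketch of such a check; the exact error metric MNN uses internally is an assumption:

```python
def first_mismatch(ref, out, tolerance=0.1):
    """Return (index, ref_val, out_val) for the first element whose
    relative error exceeds the tolerance, or None if all match.
    The relative-error formula is an illustrative assumption."""
    for i, (r, o) in enumerate(zip(ref, out)):
        denom = max(abs(r), 1e-6)  # guard against division by zero
        if abs(r - o) / denom > tolerance:
            return (i, r, o)
    return None
```

With a tolerance rate of 0.1, the pair 1.252289 vs 3.000625 above fails immediately, matching the tool's output.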

@ldfandian
Author

Also, given performance numbers like these, is there anything left to optimize?
(And how can I tell whether the model ran entirely on the OpenCL/Vulkan GPU, or whether part of the computation fell back to the CPU?)

thor:/data/local/tmp $ LD_LIBRARY_PATH=/data/local/tmp  ./timeProfile.out benchmark_models/modified-resnet50.pt.fp16.mnn 100 
... ...
The device support i8sdot:1, support fp16:1, support i8mm: 1
... ...
Sort by time cost !
Node Type   	Avg(ms)   	%         	Called times 	Flops Rate 	
Reduction   	0.021170  	0.031292  	1.000000     	0.000032   	
Softmax     	0.061860  	0.091436  	1.000000     	0.001246   	
While       	0.285560  	0.422090  	6.000000     	0.164346   	
Raster      	0.353309  	0.522231  	14.000000    	0.022340   	
Pooling     	0.393220  	0.581223  	8.000000     	0.048469   	
BinaryOp    	1.786501  	2.640648  	18.000000    	0.089185   	
Convolution 	64.752686 	95.711708 	57.000000    	99.675491  	
total time : 67.653885 ms, total mflops : 6120.963867 
main, 147, cost time: 6780.418457 ms
thor:/data/local/tmp $ 


thor:/data/local/tmp $ LD_LIBRARY_PATH=/data/local/tmp  ./MNNV2Basic.out  benchmark_models/modified-resnet50.pt.fp16.mnn 100 0 3 4 2                                                                                              
Use extra forward type: 3
... ...
Session Info: memory use 52.734776 MB, flops is 6121.452637 M, backendType is 3
===========> Session Resize Done.
===========> Session Start running...
Input size:150528
	**Tensor shape**: 1, 3, 224, 224, 
output: unnorm_image_features
precision:2, memory: 0, Run 100 time:
... ...


@jxt1234
Collaborator

jxt1234 commented Jun 20, 2023

Here you go: link: https://pan.baidu.com/s/1RovXbErY2GLU7b0uouMUhQ?pwd=jhhb extraction code: jhhb

Uh, we can't use Baidu Netdisk here. Is there another way to upload it, e.g. Google Drive, or send it to my email?

@jxt1234
Collaborator

jxt1234 commented Jun 20, 2023

Also, given performance numbers like these, is there anything left to optimize? (And how can I tell whether the model ran entirely on the OpenCL/Vulkan GPU, or whether part of the computation fell back to the CPU?)

[timeProfile.out / MNNV2Basic.out logs quoted from the previous comment]
MNNV2Basic.out prints per-op times. GPU ops will show very small times (they only include the CPU time spent issuing commands), so by looking at which ops take noticeably longer you can estimate which ones ran on the CPU.
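That elimination trick can be scripted: scan the per-op lines MNNV2Basic.out prints and flag ops whose average cost is non-negligible, since on a GPU backend those are the candidates for CPU fallback. The line format is taken from the logs later in this thread; the helper itself is a sketch, not an MNN tool:

```python
import re

# Matches lines such as:
# "359__matmul_converted  [Convolution] run 100 average cost 51.993515 ms, 63.858 %, FlopsRate: 10.278 %"
LINE = re.compile(r"(\S+)\s+\[(\w+)\]\s+run\s+\d+\s+average cost\s+([\d.]+)\s+ms")

def likely_cpu_ops(log_text, threshold_ms=1.0):
    """Ops whose average per-run cost exceeds the threshold; on a GPU
    backend these probably ran on the CPU (GPU ops only log the
    command-issue time, which is tiny)."""
    hits = []
    for line in log_text.splitlines():
        m = LINE.search(line)
        if m and float(m.group(3)) > threshold_ms:
            hits.append((m.group(1), m.group(2), float(m.group(3))))
    return hits
```

Applied to the OpenCL run below, this singles out 359__matmul_converted as the dominant op.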

@ldfandian
Author

Also, given performance numbers like these, is there anything left to optimize? (And how can I tell whether the model ran entirely on the GPU, or partly fell back to the CPU?)

[timeProfile.out / MNNV2Basic.out logs quoted from the previous comments]

MNNV2Basic.out prints per-op times; GPU ops show very small times (only the CPU command-issue time), so ops with noticeably larger times likely ran on the CPU.

Got it, so it can be used that way. Brilliant! I had assumed MNNV2Basic.out was not meant to be used with GPU backends at all... it turns out it works well for a process of elimination.

@ldfandian
Author

Here you go: link: https://pan.baidu.com/s/1RovXbErY2GLU7b0uouMUhQ?pwd=jhhb extraction code: jhhb

Uh, we can't use Baidu Netdisk here. Is there another way to upload it, e.g. Google Drive, or send it to my email?

please go here: https://drive.google.com/drive/folders/1gyJfSZ2sHM69vjBCok-PA_CEDEVUa5dm?usp=sharing

@ldfandian
Author

Also, given performance numbers like these, is there anything left to optimize? (And how can I tell whether the model ran entirely on the GPU, or partly fell back to the CPU?)

[timeProfile.out / MNNV2Basic.out logs and replies quoted from the previous comments]

It looks like OpenCL and CPU basically did not use the GPU at all, and only Vulkan actually did. Seen this way, OpenCL doesn't seem to offer much of an advantage.

> ls -l /tmp/run-*.log
-rw-r--r--  1 dianfan  wheel  14140  6 21 20:21 /tmp/run-opencl.log
-rw-r--r--  1 dianfan  wheel  14138  6 21 20:20 /tmp/run-vulkan.log
> grep " FlopsRate: " /tmp/run-*.log | grep -v " FlopsRate: 0." | grep -v " average cost 0.0"
/tmp/run-opencl.log:                                          input.43 	[Convolution] run 100 average cost 0.245180 ms, 0.301 %, FlopsRate: 1.889 %
/tmp/run-opencl.log:                             359__matmul_converted 	[Convolution] run 100 average cost 51.993515 ms, 63.858 %, FlopsRate: 10.278 %

@jxt1234
Collaborator

jxt1234 commented Jun 22, 2023

For OpenCL, you can try setting numberThread to 65, which uses the buffer memory format.

@ldfandian
Author

For OpenCL, you can try setting numberThread to 65, which uses the buffer memory format.

I tried both settings; neither helped.

@qiankl

qiankl commented Jul 7, 2023

Any progress on this? I'm seeing a similar problem here.

@jxt1234
Collaborator

jxt1234 commented Jul 26, 2023

Root cause located; waiting for the fix to be synced. In Vulkan Image mode the convert cache key is wrong. You can patch the relevant code yourself:

diff --git a/source/backend/vulkan/image/backend/VulkanBackend.cpp b/source/backend/vulkan/image/backend/VulkanBackend.cpp
index 8e56dd22e..d05d47cc0 100644
--- a/source/backend/vulkan/image/backend/VulkanBackend.cpp
+++ b/source/backend/vulkan/image/backend/VulkanBackend.cpp
@@ -284,7 +284,7 @@ void VulkanBackend::onCopyBuffer(const Tensor* srcTensor, const Tensor* dstTenso
         tempTensor->buffer().host = (uint8_t*)mHostBuffer->map();
         MNNCPUCopyBuffer(srcTensor, tempTensor.get());
         mHostBuffer->unmap();
-        auto key    = std::make_tuple(dstTensor, true, format);
+        auto key    = std::make_tuple(TensorUtils::getDescribe(dstTensor), true, format);
         auto iter   = mConverters.find(key);
         if (iter == mConverters.end()) {
             if (mConverters.size() > MNN_VULKAN_MAX_CACHE_CONVSIZE) {
@@ -315,7 +315,7 @@ void VulkanBackend::onCopyBuffer(const Tensor* srcTensor, const Tensor* dstTenso
         _finish();
         _allocHostBuffer(size);
         auto format = TensorUtils::getDescribe(srcTensor)->dimensionFormat;
-        auto key    = std::make_tuple(srcTensor, false, format);
+        auto key    = std::make_tuple(TensorUtils::getDescribe(srcTensor), false, format);
         auto iter = mConverters.find(key);
         if (iter == mConverters.end()) {
diff --git a/source/backend/vulkan/image/backend/VulkanBackend.hpp b/source/backend/vulkan/image/backend/VulkanBackend.hpp
index 337dd128d..7784593bb 100644
--- a/source/backend/vulkan/image/backend/VulkanBackend.hpp
+++ b/source/backend/vulkan/image/backend/VulkanBackend.hpp
@@ -13,6 +13,7 @@
 #include "MNN_generated.h"
 #include "VulkanRuntime.hpp"
 #include "VulkanTensor.hpp"
+#include "core/TensorUtils.hpp"
 
 namespace MNN {
 class VulkanImageConverter;
@@ -92,7 +93,7 @@ private:
     mutable std::shared_ptr mFence;
-    mutable std::map<std::tuple<const Tensor*, bool, MNN_DATA_FORMAT>,
+    mutable std::map<std::tuple<const Tensor::InsideDescribe::NativeInsideDescribe*, bool, MNN_DATA_FORMAT>,
                      std::pair<std::shared_ptr<VulkanImageConverter>, std::shared_ptr<VulkanCommandPool::Buffer>>>
         mConverters;
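The bug the patch addresses is a stale-cache hazard: keying the converter cache on the raw Tensor pointer means that when a tensor is destroyed and a new one is allocated at the same address, the cache hands back a converter bound to the old tensor's underlying image. A toy Python model of that failure (integer "addresses" stand in for pointers; this is an illustration, not MNN code):

```python
class Tensor:
    def __init__(self, addr, image):
        self.addr = addr    # simulated memory address (the old cache key)
        self.image = image  # the backing GPU image this tensor uses

converters = {}

def get_converter(t):
    # Buggy scheme: keyed on the "pointer", which allocators may reuse.
    if t.addr not in converters:
        converters[t.addr] = ("converter", t.image)
    return converters[t.addr]

a = Tensor(addr=0x1000, image="image-A")
get_converter(a)                          # caches a converter for image-A
del a                                     # tensor freed...
b = Tensor(addr=0x1000, image="image-B")  # ...and the address is reused
stale = get_converter(b)                  # stale hit: still targets image-A
```

Keying on TensorUtils::getDescribe(...) instead, as the patch above does, ties the cache entry to the tensor's describe structure rather than the reusable Tensor address.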

@jxt1234
Collaborator

jxt1234 commented Jul 31, 2023

Fixed in 2.6.2.

@cmelody999

cmelody999 commented Aug 9, 2023

Hi, with version 2.6.2 the error in Vulkan Image mode (-DMNN_VULKAN_IMAGE=ON) is still large:
(screenshot: img_v2_fd2dd052-f234-4c00-b54e-005c67db62fg)

In Vulkan Buffer mode (-DMNN_VULKAN_IMAGE=OFF) the error is small, but performance is poor:
(screenshot)

@ldfandian
Author

@jxt1234 here is our update~ please help take a look~

@qiankl

qiankl commented Aug 10, 2023

Hi, with version 2.6.2 the error in Vulkan Image mode (-DMNN_VULKAN_IMAGE=ON) is still large; in Vulkan Buffer mode (-DMNN_VULKAN_IMAGE=OFF) the error is small, but performance is poor.

I see the same behavior with MNN 2.6.2; the problem isn't solved. Could you reopen this issue? @jxt1234

@jxt1234
Collaborator

jxt1234 commented Aug 15, 2023

You're right, that was an oversight on my part.

@jxt1234
Collaborator

jxt1234 commented Aug 16, 2023

diff --git a/source/backend/vulkan/image/execution/VulkanImageConverter.cpp b/source/backend/vulkan/image/execution/VulkanImageConverter.cpp
index 7bd65523..b1c2bee9 100644
--- a/source/backend/vulkan/image/execution/VulkanImageConverter.cpp
+++ b/source/backend/vulkan/image/execution/VulkanImageConverter.cpp
@@ -126,10 +126,6 @@ void VulkanImageConverter::encodeTensorToBuffer(const Tensor* srcTensor, VkBuffe
                                                 const VulkanCommandPool::Buffer* cmdBuffer) {
     auto sourceFormat = TensorUtils::getDescribe(srcTensor)->dimensionFormat;
     auto destFormat = destBufferFormat;
-    if (sourceFormat == MNN_DATA_FORMAT_NC4HW4 && 1 >= srcTensor->width() && 1 >= srcTensor->height() &&
-        srcTensor->channel() % 4 == 0) {
-        destFormat = MNN_DATA_FORMAT_NC4HW4;
-    }

This optimization is the cause; you can disable it for now.

jxt1234 pushed a commit that referenced this issue Aug 16, 2023
@jxt1234
Collaborator

jxt1234 commented Aug 16, 2023

#2542
Fixed.

@cmelody999

I'll give it a try.

@cmelody999

With #2542, the error in Vulkan Image mode (-DMNN_VULKAN_IMAGE=ON) is now small, but performance has become poor:
(screenshot)

@cmelody999

cmelody999 commented Aug 18, 2023

@jxt1234 could you take a look?

@jxt1234
Collaborator

jxt1234 commented Sep 4, 2023

The runtime error has been fixed.

@jxt1234 jxt1234 closed this as completed Sep 4, 2023