[MNN:Sync] Sync Internal Gitlab
xiaying committed Jul 18, 2023
1 parent c293f9e commit ac5b331
Showing 90 changed files with 3,512 additions and 523 deletions.
7 changes: 7 additions & 0 deletions codegen/CMakeLists.txt
@@ -1,5 +1,6 @@
option(MNN_CODEGEN_OPENCL "Build OpenCL op fuse." OFF)
option(MNN_CODEGEN_METAL "Build Metal op fuse." OFF)
option(MNN_CODEGEN_CUDA "Build CUDA op fuse." OFF)

file(GLOB MNN_FUSE_SRCS "${CMAKE_CURRENT_LIST_DIR}/*.*")

@@ -15,6 +16,12 @@ if(MNN_CODEGEN_METAL)
list(APPEND MNN_FUSE_SRCS ${METAL_SRCS})
endif()

if(MNN_CODEGEN_CUDA)
add_definitions(-DMNN_CODEGEN_CUDA)
file(GLOB CUDA_SRCS "${CMAKE_CURRENT_LIST_DIR}/cuda/*.*")
list(APPEND MNN_FUSE_SRCS ${CUDA_SRCS})
endif()

add_library(MNNFuse OBJECT ${MNN_FUSE_SRCS})
# set_property(TARGET MNNFuse PROPERTY CXX_STANDARD 14)
list(APPEND MNN_OBJECTS_TO_LINK $<TARGET_OBJECTS:MNNFuse>)
199 changes: 155 additions & 44 deletions codegen/OpFuse.cpp
@@ -11,8 +11,12 @@
#include "SourceModule.hpp"
#include "opencl/OpenCLTarget.hpp"
#include "metal/MetalTarget.hpp"
#ifdef MNN_CODEGEN_CUDA
#include "cuda/CUDATarget.hpp"
#endif
#include <queue>
#include <unordered_map>
#include "core/OpCommonUtils.hpp"

namespace MNN {
static void dumpOp(const Op* op) {
@@ -37,7 +41,7 @@ static void dumpTensor(const Tensor* t) {
MNN_PRINT("\t%p [", t);
for (int d : t->shape())
MNN_PRINT("%d,", d);
MNN_PRINT("],\n");
MNN_PRINT("], format:%d\n", TensorUtils::getDescribe(t)->dimensionFormat);
auto des = TensorUtils::getDescribe(t);
if (des->memoryType == Tensor::InsideDescribe::MEMORY_VIRTUAL) {
MNN_PRINT("Regions:");
@@ -59,40 +63,99 @@ static void dumpCmd(const Command* cmd) {
}

// whether the op is a legal type to fuse
bool isLegal(const Command* cmd) {
bool isLegal(Command* cmd, MNNForwardType forwardType) {
auto type = cmd->op->type();
bool elemWise = type == OpType_BinaryOp
|| type == OpType_UnaryOp
|| type == OpType_ReLU
|| type == OpType_ReLU6
|| type == OpType_Eltwise;
if (elemWise) {
for (auto t : cmd->inputs) {
if (t->width() * UP_DIV(t->channel(), 4) > 16384) {
return false;
}
auto des = TensorUtils::getDescribe(t)->regions;
for(auto region : des)
{
auto tensor = region.origin;
if (tensor->width() * UP_DIV(tensor->channel(), 4) > 16384) {
if(forwardType == MNN_FORWARD_OPENCL) {
for (auto t : cmd->inputs) {
if (t->width() * UP_DIV(t->channel(), 4) > 16384) {
return false;
}
auto des = TensorUtils::getDescribe(t)->regions;
for(auto region : des)
{
auto tensor = region.origin;
if (tensor->width() * UP_DIV(tensor->channel(), 4) > 16384) {
return false;
}
}
}
}
if(TensorUtils::getDescribe(cmd->outputs[0])->dimensionFormat == MNN_DATA_FORMAT_NC4HW4) {
cmd->canVectorize = true;
} else {
int count = 1;
for(int i = 0; i < cmd->outputs[0]->dimensions(); i++) {
count *= cmd->outputs[0]->length(i);
}
if(count % 4 == 0) {
cmd->canVectorize = true;
} else {
cmd->canVectorize = false;
}
}
return true;
}
#ifdef fuse_raster
if (type == OpType_Raster) {
auto outputFormat = TensorUtils::getDescribe(cmd->outputs[0])->dimensionFormat;
bool legalFormat = outputFormat != MNN_DATA_FORMAT_NC4HW4;
if (TensorUtils::getDescribe(cmd->inputs[0])->regions.size() > 1) return false;
for (auto reg : TensorUtils::getDescribe(cmd->inputs[0])->regions) {
legalFormat &= TensorUtils::getDescribe(reg.origin)->dimensionFormat == outputFormat;

if (forwardType == MNN_FORWARD_CUDA && type == OpType_Raster) {
// Fuse NC4HW4 -> NCHW/NHWC
OpCommonUtils::TensorConvertParameter singleConvert;
auto input = cmd->outputs[0];
OpCommonUtils::rasterInputReset(cmd->inputs, cmd->outputs[0]);
singleConvert.type = 0;
auto des = TensorUtils::getDescribe(input);
if(des->regions.size() == 1) {
OpCommonUtils::turnRegion2Convert(des->regions[0], cmd->outputs[0], singleConvert);
if (singleConvert.type > 0){
auto realInput = TensorUtils::getDescribe(input)->regions[0].origin;
auto sourceFormat = TensorUtils::getDescribe(realInput)->dimensionFormat;
if (MNN_DATA_FORMAT_NC4HW4 == sourceFormat) { // NC4HW4 -> NCHW/NHWC is Supported!
if(singleConvert.type == 1) { // output NCHW
if(singleConvert.batch != cmd->outputs[0]->length(0)) {
return false;
}
if(cmd->outputs[0]->dimensions() < 3 || singleConvert.channel != cmd->outputs[0]->length(1)) {
return false;
}
int area = 1;
for(int i = 2; i < cmd->outputs[0]->dimensions(); i++) {
area *= cmd->outputs[0]->length(i);
}
if(singleConvert.area != area) {
return false;
}
return true;
}
if(singleConvert.type == 2) { // output NHWC
if(singleConvert.batch != cmd->outputs[0]->length(0)) {
return false;
}
int dims = cmd->outputs[0]->dimensions();
if(dims < 3 || singleConvert.channel != cmd->outputs[0]->length(dims-1)) {
return false;
}
int area = 1;
for(int i = 1; i < dims-1; i++) {
area *= cmd->outputs[0]->length(i);
}
if(singleConvert.area != area) {
return false;
}
if(singleConvert.channel % 4 == 0) {
cmd->canVectorize = true;
}
return true;
}
return false;
}
}
}
return legalFormat;
}
#endif
return false;
}
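
The reworked isLegal above now encodes two backend-specific rules for elementwise ops: an OpenCL-only guard that rejects tensors whose C4-packed width exceeds the image limit, and a new canVectorize flag set when the output is NC4HW4 or its flattened element count is divisible by 4. Below is a standalone sketch of both rules, using plain ints and a shape vector as stand-ins for MNN's Tensor accessors (UP_DIV is MNN's ceiling-division macro).

#include <vector>

// Ceiling division, matching MNN's UP_DIV macro.
static int upDiv(int x, int y) { return (x + y - 1) / y; }

// OpenCL-only guard from isLegal: a C4-packed tensor must fit the
// backend's image-width budget (16384, as in the check above).
static bool fitsOpenCLImage(int width, int channel) {
    return width * upDiv(channel, 4) <= 16384;
}

// Vectorization rule from isLegal: NC4HW4 outputs are always
// vec4-friendly; otherwise the flattened element count must be a
// multiple of 4.
static bool canVectorize4(bool outputIsNC4HW4, const std::vector<int>& outShape) {
    if (outputIsNC4HW4) {
        return true;
    }
    int count = 1;
    for (int len : outShape) {
        count *= len;
    }
    return count % 4 == 0;
}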

@@ -109,14 +172,17 @@ Node* LCA(Node* x, Node* y) {
}
return x;
}
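
Only the tail of LCA is visible in this hunk; it walks two nodes up the post-dominator tree until they meet. A minimal sketch of that climb follows, under the assumption (from the surrounding code) that each Node stores its immediate post-dominator and a topological index.

struct FuseNodeSketch {
    FuseNodeSketch* dominate; // immediate post-dominator (tree parent)
    int topoIndex;            // topological order of the command
};

// Climb the node that sits earlier in topological order, since a
// post-dominator always comes later; the loop ends at the common ancestor.
static FuseNodeSketch* lcaSketch(FuseNodeSketch* x, FuseNodeSketch* y) {
    while (x != y) {
        if (x->topoIndex < y->topoIndex) {
            x = x->dominate;
        } else {
            y = y->dominate;
        }
    }
    return x;
}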
bool allPathLegal(Node* s, Node* t) {
bool allPathLegal(Node* s, Node* t, MNNForwardType type) {
bool legal = true;
std::queue<Node*> q;
q.push(s);
while (!q.empty()) {
auto node = q.front();
q.pop();
legal &= isLegal(node->cmd);
legal &= isLegal(node->cmd, type);
if(!legal) {
return false;
}
for (auto succ : node->succ) {
if (succ != t) {
q.push(succ);
@@ -125,37 +191,58 @@ bool allPathLegal(Node* s, Node* t) {
}
return legal;
}
std::vector<Node*> fuseNode(Node* root, std::vector<Node*>& edges) {
std::vector<Node*> fuseNode(Node* root, std::vector<Node*>& edges, MNNForwardType type) {
std::vector<Node*> fuseSet;
std::queue<Node*> q;
q.push(root);
int rasterCount = 0;
bool insert = false;
while (!q.empty()) {
insert = false;
auto node = q.front();
fuseSet.insert(fuseSet.begin(), node);
if(node->cmd->op->type() == OpType_Raster) {
// Currently only fuse a single Raster op per region
rasterCount++;
if(rasterCount < 2) {
fuseSet.insert(fuseSet.begin(), node);
insert = true;
}
} else {
fuseSet.insert(fuseSet.begin(), node);
insert = true;
}

q.pop();
for (auto child : node->domainateSucc) {
if (isLegal(child->cmd) && allPathLegal(child, root)) {
q.push(child);
} else {
edges.push_back(child);
if(insert) {
for (auto child : node->domainateSucc) {
if (isLegal(child->cmd, type) && allPathLegal(child, root, type)) {
q.push(child);
} else {
edges.push_back(child);
}
}
}
}
return fuseSet;
}
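
fuseNode grows one fusion region by BFS from a legal root: each admitted node joins the set, at most one Raster op is allowed per region (the new rasterCount gate), and successors that fail isLegal or allPathLegal become boundary edges. A simplified, self-contained mirror of that loop is sketched below; FakeNode stands in for the real Node/post-dominance structures.

#include <queue>
#include <vector>

struct FakeNode {
    bool isRaster = false;
    bool legal = true;
    std::vector<FakeNode*> succ; // stand-in for domainateSucc
};

static std::vector<FakeNode*> growRegion(FakeNode* root,
                                         std::vector<FakeNode*>& boundary) {
    std::vector<FakeNode*> region;
    std::queue<FakeNode*> q;
    q.push(root);
    int rasterCount = 0;
    while (!q.empty()) {
        FakeNode* node = q.front();
        q.pop();
        // Admit at most one Raster op per fused region.
        bool admitted = !node->isRaster || ++rasterCount < 2;
        if (!admitted) {
            continue; // a second Raster stops expansion through this node
        }
        region.insert(region.begin(), node);
        for (FakeNode* child : node->succ) {
            if (child->legal) {
                q.push(child);
            } else {
                boundary.push_back(child); // region boundary, like `edges`
            }
        }
    }
    return region;
}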

bool codegen(std::vector<Schedule::OpCacheInfo>& infos, std::vector<std::vector<Node*>>& fuseSets, MNNForwardType type) {
bool codegen(std::vector<Schedule::OpCacheInfo>& infos, std::vector<std::vector<Node*>>& fuseSets, MNNForwardType type, BackendConfig::PrecisionMode precision) {
// generate Kernel
std::unique_ptr<Target> target;
switch (type) {
#ifdef MNN_CODEGEN_OPENCL
case MNN_FORWARD_OPENCL:
target.reset(new OpenCLTarget);
target.reset(new OpenCLTarget(precision));
break;
#endif
#ifdef MNN_CODEGEN_METAL
case MNN_FORWARD_METAL:
target.reset(new MetalTarget);
target.reset(new MetalTarget(precision));
break;
#endif
#ifdef MNN_CODEGEN_CUDA
case MNN_FORWARD_CUDA:
target.reset(new CUDATarget(precision));
break;
#endif
default:
@@ -166,13 +253,23 @@ bool codegen(std::vector<Schedule::OpCacheInfo>& infos, std::vector<std::vector<
MNN_PRINT(">>>>>>>>>>>>> fuseSets.size = %lu\n", fuseSets.size());
}
#endif
std::map<std::string, int> mapKernelSources;
for (int i = 0; i < fuseSets.size(); i++) {
auto& compSet = fuseSets[i];
/*
for (auto comp : compSet) {
dumpCmd(comp->cmd);
}
*/
bool fuseKernelVectorize = true;
for (auto& node : compSet) {
auto cmd = node->cmd;
if(!cmd->canVectorize) {
fuseKernelVectorize = false;
break;
}
}
target->setFuseKernelVectorize(fuseKernelVectorize);
SourceModule fuseModule(target.get());
InOutTensors tensors = fuseModule.buildKernel(compSet, i);
auto inputs = tensors.first;
@@ -181,13 +278,21 @@ bool codegen(std::vector<Schedule::OpCacheInfo>& infos, std::vector<std::vector<
SharedPtr<Command> cmdPlugin;
{
auto sourceCode = fuseModule.codegen();
if(mapKernelSources.find(sourceCode) == mapKernelSources.end()) {
int kernelCount = mapKernelSources.size();
mapKernelSources.insert(std::pair<std::string, int>(sourceCode, kernelCount));
}
std::string kernelName = "kernel_" + std::to_string(mapKernelSources[sourceCode]);
sourceCode.insert(fuseModule.strIndexForKernelNum(), kernelName);

std::unique_ptr<OpT> fuseOp(new OpT);
fuseOp->type = OpType_Extra;
fuseOp->name = fuseModule.opName();
ExtraT* extra_param = new ExtraT;
extra_param->type = fuseModule.kernelName();
extra_param->type = kernelName;
extra_param->info.resize(sourceCode.size() + 1);
memcpy(extra_param->info.data(), sourceCode.data(), sourceCode.size() + 1);
extra_param->vector = fuseKernelVectorize;
fuseOp->main.type = OpParameter_Extra;
fuseOp->main.value = extra_param;
flatbuffers::FlatBufferBuilder builder;
@@ -218,10 +323,20 @@ bool codegen(std::vector<Schedule::OpCacheInfo>& infos, std::vector<std::vector<
}
}
}
// Clear useless cacheBuffer
for (auto& info : infos) {
for (auto iter = info.cacheBuffer.command.begin(); iter != info.cacheBuffer.command.end();) {
if (iter->get()->op == nullptr) {
iter = info.cacheBuffer.command.erase(iter);
} else {
++iter;
}
}
}
return true;
}
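
The mapKernelSources bookkeeping added above deduplicates generated kernels: identical source strings share one index, so each unique kernel body is emitted once and every fused op referring to it reuses a stable kernel_<n> name. A self-contained sketch of that naming scheme (cache plays the role of mapKernelSources):

#include <map>
#include <string>

// Return a stable name for a generated kernel source; a source
// identical to one seen earlier reuses that source's index.
static std::string kernelNameFor(const std::string& source,
                                 std::map<std::string, int>& cache) {
    auto it = cache.find(source);
    if (it == cache.end()) {
        int index = static_cast<int>(cache.size());
        it = cache.emplace(source, index).first;
    }
    return "kernel_" + std::to_string(it->second);
}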

bool opFuse(std::vector<Schedule::OpCacheInfo>& infos, MNNForwardType type) {
bool opFuse(std::vector<Schedule::OpCacheInfo>& infos, MNNForwardType type, BackendConfig::PrecisionMode precision) {
std::unordered_map<const Tensor*, Node*> outputTensor;
// build graph
std::vector<std::unique_ptr<Node>> graph;
@@ -246,13 +361,7 @@ bool opFuse(std::vector<Schedule::OpCacheInfo>& infos, MNNForwardType type) {
node->cmd = iter.get();
node->topoIndex = i;
for (auto input : iter->inputs) {
if (!TensorUtils::getDescribe(input)->regions.empty()) {
for (auto& region : TensorUtils::getDescribe(input)->regions) {
insertEdge(region.origin, node.get());
}
} else {
insertEdge(input, node.get());
}
insertEdge(input, node.get());
}
for (auto output : iter->outputs) {
outputTensor[output] = node.get();
@@ -289,8 +398,8 @@ bool opFuse(std::vector<Schedule::OpCacheInfo>& infos, MNNForwardType type) {
continue;
}
std::vector<Node*> childs;
if (isLegal(root->cmd)) {
auto fuseSet = fuseNode(root, childs);
if (isLegal(root->cmd, type)) {
auto fuseSet = fuseNode(root, childs, type);
if (fuseSet.size() > 1) {
fuseSets.emplace_back(std::move(fuseSet));
}
@@ -301,7 +410,9 @@ bool opFuse(std::vector<Schedule::OpCacheInfo>& infos, MNNForwardType type) {
postDominateNodeQueue.push(child);
}
}

#if 0
MNN_PRINT("fuse total number: %lu \n", fuseSets.size());
for (auto compSet : fuseSets) {
MNN_PRINT("set size: %lu \n", compSet.size());
if (true) {
@@ -313,7 +424,7 @@ bool opFuse(std::vector<Schedule::OpCacheInfo>& infos, MNNForwardType type) {
}
}
#endif
return codegen(infos, fuseSets, type);
return codegen(infos, fuseSets, type, precision);
}
} // namespace MNN

4 changes: 2 additions & 2 deletions codegen/OpFuse.hpp
@@ -7,8 +7,8 @@
//

#include "geometry/GeometryComputerUtils.hpp"

#include <map>
namespace MNN {
bool opFuse(std::vector<Schedule::OpCacheInfo>& infos, MNNForwardType type);
bool opFuse(std::vector<Schedule::OpCacheInfo>& infos, MNNForwardType type, BackendConfig::PrecisionMode precision);
} // namespace MNN
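
With the widened signature, callers now pass the backend's precision mode through to code generation. A hypothetical call-site sketch follows; the infos vector would come from MNN's scheduler, the include path follows the repo layout shown above, and the function name is an assumption.

#include "codegen/OpFuse.hpp"

// Run the fusion pass for the CUDA backend at normal precision.
void runFusePassExample(std::vector<MNN::Schedule::OpCacheInfo>& infos) {
    MNN::opFuse(infos, MNN_FORWARD_CUDA,
                MNN::BackendConfig::Precision_Normal);
}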
