diff --git "a/docs/\350\256\255\347\273\203\350\220\245\344\275\234\344\270\232\344\273\213\347\273\215.md" "b/docs/\350\256\255\347\273\203\350\220\245\344\275\234\344\270\232\344\273\213\347\273\215.md" index 5b27825..80b08e5 100644 --- "a/docs/\350\256\255\347\273\203\350\220\245\344\275\234\344\270\232\344\273\213\347\273\215.md" +++ "b/docs/\350\256\255\347\273\203\350\220\245\344\275\234\344\270\232\344\273\213\347\273\215.md" @@ -11,6 +11,8 @@ 7. test_nativecpu_concat:依赖作业一、作业五 8. test_nativecpu_elementwise:依赖作业一、作业六 9. test_nativecpu_transpose:依赖作业一、作业二 +10. test_matmul:依赖作业六 +11. test_graph:无依赖 # 作业题目 @@ -189,7 +191,7 @@ optional> ConcatObj::inferShape(const TensorVec &inputs) { 难度:⭐⭐⭐ -对应测例:``test_element_wise``,``test_nativecpu_elementwise`` +对应测例:``test_element_wise``,``test_nativecpu_elementwise``,``test_matmul`` 需要实现的代码块位置:`src/utils/operator_utils.cc` @@ -220,6 +222,7 @@ optional> MatmulObj::inferShape(const TensorVec &inputs) // TODO:返回经过 matmul 操作后的 shape // REF: https://github.com/onnx/onnx/blob/main/docs/Operators.md#gemm // =================================== 作业 =================================== + return {{}}; } ```` @@ -236,6 +239,9 @@ void GraphObj::optimize() { // =================================== 作业 =================================== // TODO: 设计一个算法来实现指定的图优化规则 + // 图优化规则如下: + // 1. 去除冗余的算子(例如,两个相邻的算子都是 transpose 算子,且做的是相反的操作,可以将其全部删除) + // 2. 合并算子(例如,矩阵乘算子中含有属性transA、transB,如果其输入存在transpose,且对最后两个维度做交换,就可以将transpose融入到矩阵乘算子的属性中去) // =================================== 作业 =================================== } ```` \ No newline at end of file diff --git a/include/core/graph.h b/include/core/graph.h index bf8ce5c..9326eaf 100644 --- a/include/core/graph.h +++ b/include/core/graph.h @@ -37,10 +37,6 @@ namespace infini tensors.erase(it); } - void deleteConnection(Tensor tensor, Operator op); - void addConnection(Tensor tensor, Operator op); - void replaceConnection(Tensor oldInput, Tensor newInput, Operator op); - const TensorVec &getTensors() const { return tensors; } const OpVec &getOperators() const { return ops; } Tensor getTensor(int) const; diff --git a/include/core/op_type.h b/include/core/op_type.h index f1a5115..db67f33 100644 --- a/include/core/op_type.h +++ b/include/core/op_type.h @@ -19,7 +19,6 @@ namespace infini Concat, Div, Mul, - Reshape, MatMul, Relu, Sub, diff --git a/src/core/graph.cc b/src/core/graph.cc index 6903595..3a90637 100644 --- a/src/core/graph.cc +++ b/src/core/graph.cc @@ -102,6 +102,9 @@ namespace infini { // =================================== 作业 =================================== // TODO: 设计一个算法来实现指定的图优化规则 + // 图优化规则如下: + // 1. 去除冗余的算子(例如,两个相邻的算子都是 transpose 算子,且做的是相反的操作,可以将其全部删除) + // 2. 合并算子(例如,矩阵乘算子中含有属性transA、transB,如果其输入存在transpose,且对最后两个维度做交换,就可以将transpose融入到矩阵乘算子的属性中去) // =================================== 作业 =================================== } @@ -175,43 +178,6 @@ namespace infini return tensors; } - void GraphObj::deleteConnection(Tensor tensor, Operator op) - { - // if op is target - IT_ASSERT(std::find(tensor->getTargets().begin(), - tensor->getTargets().end(), - op) != tensor->getTargets().end()); - tensor->removeTarget(op); - if (tensor->getSource()) - { - tensor->getSource()->removeSuccessors(op); - op->removePredecessors(tensor->getSource()); - } - } - - // add op as a target - void GraphObj::addConnection(Tensor tensor, Operator op) - { - tensor->addTarget(op); - if (tensor->getSource()) - { - tensor->getSource()->addSuccessors(op); - op->addPredecessors(tensor->getSource()); - } - } - - void GraphObj::replaceConnection(Tensor oldTensor, Tensor newTensor, - Operator op) - { - // op is a target of old tensor - IT_ASSERT(std::find(oldTensor->getTargets().begin(), - oldTensor->getTargets().end(), - op) != oldTensor->getTargets().end()); - addConnection(newTensor, op); - deleteConnection(oldTensor, op); - op->replaceInput(oldTensor, newTensor); - } - // tensor's "source" and "target" must be in "ops". // tensor has no "source" and no "target" must not exist. // "inputs" or "outputs" of operators must be in "tensors" diff --git a/src/core/op_type.cc b/src/core/op_type.cc index 7699a90..b2a721a 100644 --- a/src/core/op_type.cc +++ b/src/core/op_type.cc @@ -17,7 +17,6 @@ namespace infini CASE(Div); CASE(Cast); CASE(Clip); - CASE(Reshape); CASE(Relu); CASE(Transpose); CASE(Concat); diff --git a/src/operators/matmul.cc b/src/operators/matmul.cc index 139fffb..0ce94a1 100644 --- a/src/operators/matmul.cc +++ b/src/operators/matmul.cc @@ -1,6 +1,4 @@ #include "operators/matmul.h" -#include "utils/operator_utils.h" -#include namespace infini { @@ -29,6 +27,7 @@ namespace infini // TODO:返回经过 matmul 操作后的 shape // REF: https://github.com/onnx/onnx/blob/main/docs/Operators.md#gemm // =================================== 作业 =================================== + return {{}}; } } // namespace infini \ No newline at end of file