-
Notifications
You must be signed in to change notification settings - Fork 53
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
277 additions
and
89 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,8 @@ namespace infini | |
Concat, | ||
Div, | ||
Mul, | ||
Reshape, | ||
MatMul, | ||
Relu, | ||
Sub, | ||
Transpose, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
#pragma once | ||
#include "core/operator.h" | ||
|
||
namespace infini | ||
{ | ||
/** | ||
* @brief Matrix multiplication. | ||
* | ||
*/ | ||
class MatmulObj : public OperatorObj | ||
{ | ||
private: | ||
// InfiniTensor assumes a row-major tensor layout. `transA`=false means | ||
// default dims, true means A should be transposed before matmul. This is in | ||
// oppsite to the column-major BLAS. | ||
bool transA, transB; | ||
|
||
// Auxiliary attributes which are not a part of operator attributes. | ||
int m, n, k; | ||
|
||
public: | ||
/** | ||
* @brief Matmul operator with batch broadcast and tensor transpose | ||
* supports. Only one tensor with singe batch can be broadcasted due to the | ||
* BLAS interface restriction. Tranpose indicates whether the last two | ||
* dimensions should be transposed before Matmul and does not affect other | ||
* leading dimensions. | ||
* | ||
* Matmul show how operators are defined in InfiniTensor. The constructor of | ||
* an operator can create output tensors for the operator or not, which | ||
* depends on `graph`. | ||
* | ||
* @param graph The computation graph that this operator belongs to. | ||
* @param A The input tensor. | ||
* @param B The input tensor. | ||
* @param C C is the output of Matmul. If outputs are going to be created in | ||
* the constructor, C should be an empty Ref. | ||
* @param transA If matrix A should be transposed when computing. | ||
* @param transB If matrix B should be transposed when computing. | ||
*/ | ||
MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, | ||
bool transA = false, bool transB = false); | ||
OP_CLONE(MatmulObj); | ||
|
||
std::string toString() const override; | ||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override; | ||
|
||
int numInputs() const override { return inputs.size(); } | ||
int numOutputs() const override { return 1; } | ||
|
||
bool getTransA() const { return transA; } | ||
bool getTransB() const { return transB; } | ||
void setTransA(bool transA) { this->transA = transA; } | ||
void setTransB(bool transB) { this->transB = transB; } | ||
int getM() const { return m; } | ||
int getN() const { return n; } | ||
int getK() const { return k; } | ||
}; | ||
|
||
} // namespace infini |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
#include "operators/matmul.h" | ||
#include "utils/operator_utils.h" | ||
#include <numeric> | ||
|
||
namespace infini | ||
{ | ||
|
||
MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA, | ||
bool transB) | ||
: OperatorObj(OpType::MatMul, TensorVec{A, B}, {C}), | ||
transA(transA), transB(transB) | ||
{ | ||
IT_ASSERT(checkValid(graph)); | ||
} | ||
|
||
string MatmulObj::toString() const | ||
{ | ||
std::ostringstream os; | ||
os << "Matmul([" << (transA ? "A^T" : "A") << "," << (transB ? "B^T" : "B]") | ||
<< ",A=" << inputs[0]->getGuid() | ||
<< ",B=" << inputs[1]->getGuid() << ",C=" << outputs[0]->getGuid() | ||
<< ",mnk=[" << m << "," << n << "," << k << "])"; | ||
return os.str(); | ||
} | ||
|
||
optional<vector<Shape>> MatmulObj::inferShape(const TensorVec &inputs) | ||
{ | ||
// =================================== 作业 =================================== | ||
// TODO:返回经过 matmul 操作后的 shape | ||
// REF: https://github.com/onnx/onnx/blob/main/docs/Operators.md#gemm | ||
// =================================== 作业 =================================== | ||
} | ||
|
||
} // namespace infini |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
#include "core/graph.h" | ||
#include "core/kernel.h" | ||
#include "core/runtime.h" | ||
#include "operators/matmul.h" | ||
#include "operators/transpose.h" | ||
|
||
#include "test.h" | ||
|
||
namespace infini | ||
{ | ||
TEST(Graph, Optimize) | ||
{ | ||
Runtime runtime = NativeCpuRuntimeObj::getInstance(); | ||
Graph g = make_ref<GraphObj>(runtime); | ||
Tensor i1 = g->addTensor({2, 3, 4, 5}, DataType::UInt32); | ||
Tensor i2 = g->addTensor({2, 3, 4, 5}, DataType::UInt32); | ||
Tensor t1 = g->addTensor({2, 3, 5, 4}, DataType::UInt32); | ||
Tensor t2 = g->addTensor({2, 3, 4, 5}, DataType::UInt32); | ||
Tensor t3 = g->addTensor({2, 3, 5, 4}, DataType::UInt32); | ||
Tensor o = g->addTensor({2, 3, 4, 4}, DataType::UInt32); | ||
g->addOpWithOutputs<TransposeObj>(i1, t1, Shape{0, 1, 3, 2}); | ||
g->addOpWithOutputs<TransposeObj>(t1, t2, Shape{0, 1, 3, 2}); | ||
g->addOpWithOutputs<TransposeObj>(i2, t3, Shape{0, 1, 3, 2}); | ||
g->addOpWithOutputs<MatmulObj>(t2, t3, o); | ||
// 优化前 | ||
g->print(); | ||
g->optimize(); | ||
// 优化后 | ||
g->print(); | ||
EXPECT_EQ(g->getOperators().size(), 1); | ||
EXPECT_EQ(g->getTensors().size(), 3); | ||
EXPECT_EQ(g->getOperators()[0]->getOpType().underlying(), 8); | ||
auto op = as<MatmulObj>(g->getOperators()[0]); | ||
EXPECT_EQ(op->getInputs(0)->getGuid(), 2); | ||
EXPECT_EQ(op->getInputs(1)->getGuid(), 3); | ||
EXPECT_EQ(op->getOutputs()[0], o); | ||
EXPECT_EQ(op->getTransA(), false); | ||
EXPECT_EQ(op->getTransB(), true); | ||
} | ||
} |
Oops, something went wrong.