Add graph optimization assignment
bitzyz committed Jul 31, 2024
1 parent 154146b commit c80e1d9
Showing 11 changed files with 311 additions and 89 deletions.
34 changes: 34 additions & 0 deletions docs/训练营作业介绍.md
@@ -184,3 +184,37 @@ Shape infer_broadcast(const Shape &A, const Shape &B) {
}
````
## Assignment 7: Matmul shape inference
Difficulty: ⭐⭐⭐
Corresponding test case: ``test_matmul``
Code block to implement: `src/operators/matmul.cc`
````c++
optional<vector<Shape>> MatmulObj::inferShape(const TensorVec &inputs)
{
    // =================================== Assignment ===================================
    // TODO: return the output shape of the matmul operation
    // REF: https://github.com/onnx/onnx/blob/main/docs/Operators.md#gemm
    // =================================== Assignment ===================================
}
````
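
The rule mirrors ONNX MatMul/Gemm semantics: the last two axes multiply as matrices (after applying `transA`/`transB`), and all leading axes broadcast. Below is a minimal sketch of one possible solution — not the reference answer — assuming both inputs have rank ≥ 2, that `getDims()` exposes a tensor's `Shape`, and reusing the `infer_broadcast` helper from the earlier broadcast assignment (shown at the top of this diff):

````c++
optional<vector<Shape>> MatmulObj::inferShape(const TensorVec &inputs)
{
    const auto shapeA = inputs[0]->getDims(); // accessor assumed
    const auto shapeB = inputs[1]->getDims();
    const auto rankA = shapeA.size(), rankB = shapeB.size();
    // The last two axes form the matrices; transA/transB swap them.
    m = transA ? shapeA[rankA - 1] : shapeA[rankA - 2];
    k = transA ? shapeA[rankA - 2] : shapeA[rankA - 1];
    n = transB ? shapeB[rankB - 2] : shapeB[rankB - 1];
    // Leading (batch) axes follow the numpy-style broadcasting rule.
    Shape batchA(shapeA.begin(), shapeA.end() - 2);
    Shape batchB(shapeB.begin(), shapeB.end() - 2);
    Shape result = infer_broadcast(batchA, batchB);
    result.push_back(m);
    result.push_back(n);
    return {{result}};
}
````

Caching `m`, `n`, and `k` matches the auxiliary attributes declared in `MatmulObj`; a full solution would also assert that the reduction dimensions of A and B agree.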

## Assignment 8: Simple graph optimization rules

Difficulty: ⭐⭐⭐⭐

Corresponding test case: ``test_graph``

Code block to implement: `src/core/graph.cc`

````c++
void GraphObj::optimize()
{
    // =================================== Assignment ===================================
    // TODO: design an algorithm that implements the specified graph optimization rules
    // =================================== Assignment ===================================
}
````
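
This diff does not state the rules themselves (``test_graph`` is the authoritative spec). Purely as an illustration of the rewrite pattern, here is a rough sketch that erases pairs of back-to-back `Transpose` operators whose permutations cancel; `TransposeObj`, `getPermute`, `as<>`, and the `removeOperator` helper are assumptions not shown in the excerpts below:

````c++
void GraphObj::optimize()
{
    // Hypothetical rule: a Transpose -> Transpose chain whose permutations
    // compose to the identity can be deleted entirely.
    bool changed = true;
    while (changed)
    {
        changed = false;
        for (auto &op : ops)
        {
            if (op->getOpType().underlying() != OpType::Transpose)
                continue;
            auto succs = op->getSuccessors();
            if (succs.size() != 1 ||
                succs[0]->getOpType().underlying() != OpType::Transpose)
                continue;
            auto next = succs[0];
            // Assumed accessors; would need #include "operators/transpose.h".
            auto perm1 = as<TransposeObj>(op)->getPermute();
            auto perm2 = as<TransposeObj>(next)->getPermute();
            bool cancels = perm1.size() == perm2.size();
            for (size_t i = 0; cancels && i < perm1.size(); ++i)
                cancels = (perm2[perm1[i]] == static_cast<int>(i));
            if (!cancels)
                continue;
            // Rewire consumers of the second Transpose to the original input,
            // then drop both operators; a full solution also removes the
            // now-dangling intermediate tensors.
            for (auto &consumer : next->getSuccessors())
                replaceConnection(next->getOutput(), op->getInputs(0), consumer);
            removeOperator(next); // assumed GraphObj helper
            removeOperator(op);   // assumed GraphObj helper
            changed = true;
            break; // `ops` was mutated; rescan from the start
        }
    }
}
````

Whatever the real rule set, the pattern generalizes: apply one local rewrite via `replaceConnection`, then rescan until a fixed point is reached.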
7 changes: 0 additions & 7 deletions include/core/graph.h
@@ -41,13 +41,6 @@ namespace infini
void addConnection(Tensor tensor, Operator op);
void replaceConnection(Tensor oldInput, Tensor newInput, Operator op);

Operator cloneOperator(Operator op, TensorVec inputs, TensorVec outputs)
{
auto opClone = op->clone(inputs, outputs);
addOperatorAndConnect(opClone);
return opClone;
}

const TensorVec &getTensors() const { return tensors; }
const OpVec &getOperators() const { return ops; }
Tensor getTensor(int) const;
2 changes: 2 additions & 0 deletions include/core/op_type.h
@@ -19,6 +19,8 @@ namespace infini
Concat,
Div,
Mul,
Reshape,
MatMul,
Relu,
Sub,
Transpose,
149 changes: 77 additions & 72 deletions include/core/operator.h
@@ -3,86 +3,91 @@
#include "core/op_type.h"
#include "core/tensor.h"

namespace infini
{
    using KernelAttrs = std::tuple<Device, OpType::underlying_t>;

    class GraphObj;
    class OperatorObj : public Object
    {
        friend class GraphObj;

    protected:
        OpType type;
        TensorVec inputs;
        TensorVec outputs;
        vector<WRef<OperatorObj>> predecessors;
        vector<WRef<OperatorObj>> successors;

    public:
        OperatorObj(OpType opType, TensorVec inputs, TensorVec outputs);
        virtual optional<vector<Shape>> inferShape(const TensorVec &inputs) = 0;
        virtual vector<DataType> inferDataType(const TensorVec &inputs) const;
        /**
         * @brief Constructs outputs (if required) and checks whether the
         * operator is valid.
         *
         * @param graph If graph is not nullptr, outputs should be created in
         * this function.
         */
        bool checkValid(GraphObj *graph);

    public: // getter and setter
        const TensorVec &getInputs() const { return inputs; }
        const TensorVec &getOutputs() const { return outputs; }
        Tensor getInputs(size_t i) const { return inputs.at(i); }
        Tensor getOutput() const
        {
            IT_ASSERT(outputs.size() == 1, "Unimplemented");
            return outputs[0];
        }
        Tensor getOutput(size_t i) const
        {
            IT_ASSERT(i < outputs.size(), "Index exceeded");
            return outputs.at(i);
        }
        OpVec getPredecessors() const { return wrefs_to_refs(predecessors); }
        OpVec getSuccessors() const { return wrefs_to_refs(successors); }
        OpType getOpType() const { return type; }
        // HACK: set correct data type
        DataType getDType() const { return getInputs(0)->getDType(); }
        DataType getOutDType() const { return getOutput()->getDType(); }
        virtual int numInputs() const = 0;
        virtual int numOutputs() const = 0;

        /**
         * @brief Clone this operator and replace its inputs and outputs.
         *
         * @param newInputs
         * @param newOutputs
         * @return Operator
         */
        virtual Operator clone(const TensorVec &newInputs,
                               const TensorVec &newOutputs) const = 0;

    protected:
        optional<vector<Shape>> inferShape();
        vector<DataType> inferDataType() const;

    private:
        void addPredecessors(const Operator &op) { predecessors.emplace_back(op); }
        void addSuccessors(const Operator &op) { successors.emplace_back(op); }
        void removePredecessors(const Operator &op);
        void removeSuccessors(const Operator &op);
        void replaceInput(Tensor t1, Tensor t2);
    };

#define OP_CLONE(OpObj)                                                    \
    virtual Operator clone(const TensorVec &newInputs,                     \
                           const TensorVec &newOutputs) const override     \
    {                                                                      \
        auto op = infini::make_ref<OpObj>(*this);                          \
        op->inputs = newInputs;                                            \
        op->outputs = newOutputs;                                          \
        op->predecessors.clear();                                          \
        op->successors.clear();                                            \
        IT_ASSERT(op->checkValid(nullptr));                                \
        return op;                                                         \
    }

} // namespace infini
4 changes: 2 additions & 2 deletions include/core/tensor.h
@@ -7,7 +7,8 @@
#include <cstring>
#include <fstream>

namespace infini
{
class GraphObj;
using ShapeElem = int;
using Shape = vector<ShapeElem>;
@@ -158,7 +159,6 @@ namespace infini {
++itr;
}
}

};

} // namespace infini
60 changes: 60 additions & 0 deletions include/operators/matmul.h
@@ -0,0 +1,60 @@
#pragma once
#include "core/operator.h"

namespace infini
{
/**
* @brief Matrix multiplication.
*
*/
class MatmulObj : public OperatorObj
{
private:
    // InfiniTensor assumes a row-major tensor layout. `transA`=false means
    // default dims; true means A should be transposed before matmul. This is
    // the opposite of column-major BLAS.
bool transA, transB;

// Auxiliary attributes which are not a part of operator attributes.
int m, n, k;

public:
/**
     * @brief Matmul operator with batch broadcasting and tensor transpose
     * support. Only one tensor with a single batch can be broadcast due to
     * the BLAS interface restriction. Transpose indicates whether the last
     * two dimensions should be transposed before Matmul; it does not affect
     * the other leading dimensions.
     *
     * Matmul shows how operators are defined in InfiniTensor. The constructor of
* an operator can create output tensors for the operator or not, which
* depends on `graph`.
*
* @param graph The computation graph that this operator belongs to.
* @param A The input tensor.
* @param B The input tensor.
* @param C C is the output of Matmul. If outputs are going to be created in
* the constructor, C should be an empty Ref.
* @param transA If matrix A should be transposed when computing.
* @param transB If matrix B should be transposed when computing.
*/
MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C,
bool transA = false, bool transB = false);
OP_CLONE(MatmulObj);

std::string toString() const override;
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;

int numInputs() const override { return inputs.size(); }
int numOutputs() const override { return 1; }

bool getTransA() const { return transA; }
bool getTransB() const { return transB; }
void setTransA(bool transA) { this->transA = transA; }
void setTransB(bool transB) { this->transB = transB; }
int getM() const { return m; }
int getN() const { return n; }
int getK() const { return k; }
};

} // namespace infini
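
For orientation, constructing the operator could look like the following sketch; `NativeCpuRuntimeObj`, `addTensor`, `addOp`, and `DataType::Float32` are assumptions based on typical InfiniTensor usage and are not part of this diff:

````c++
// Hypothetical usage: pass nullptr for C so that checkValid() creates the
// output tensor from inferShape().
Runtime runtime = NativeCpuRuntimeObj::getInstance(); // assumed API
Graph g = make_ref<GraphObj>(runtime);
Tensor A = g->addTensor({2, 3, 5}, DataType::Float32); // assumed API
Tensor B = g->addTensor({2, 5, 4}, DataType::Float32);
auto matmul = g->addOp<MatmulObj>(A, B, nullptr); // output shape {2, 3, 4}
````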
11 changes: 3 additions & 8 deletions src/core/graph.cc
@@ -100,14 +100,9 @@ namespace infini

void GraphObj::optimize()
{
    // =================================== Assignment ===================================
    // TODO: design an algorithm that implements the specified graph optimization rules
    // =================================== Assignment ===================================
}

Tensor GraphObj::getTensor(int fuid) const
2 changes: 2 additions & 0 deletions src/core/op_type.cc
@@ -17,9 +17,11 @@ namespace infini
CASE(Div);
CASE(Cast);
CASE(Clip);
CASE(Reshape);
CASE(Relu);
CASE(Transpose);
CASE(Concat);
CASE(MatMul);

default:
return "Unknown";
34 changes: 34 additions & 0 deletions src/operators/matmul.cc
@@ -0,0 +1,34 @@
#include "operators/matmul.h"
#include "utils/operator_utils.h"
#include <numeric>

namespace infini
{

MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA,
bool transB)
: OperatorObj(OpType::MatMul, TensorVec{A, B}, {C}),
transA(transA), transB(transB)
{
IT_ASSERT(checkValid(graph));
}

string MatmulObj::toString() const
{
std::ostringstream os;
os << "Matmul([" << (transA ? "A^T" : "A") << "," << (transB ? "B^T" : "B]")
<< ",A=" << inputs[0]->getGuid()
<< ",B=" << inputs[1]->getGuid() << ",C=" << outputs[0]->getGuid()
<< ",mnk=[" << m << "," << n << "," << k << "])";
return os.str();
}

optional<vector<Shape>> MatmulObj::inferShape(const TensorVec &inputs)
{
    // =================================== Assignment ===================================
    // TODO: return the output shape of the matmul operation
    // REF: https://github.com/onnx/onnx/blob/main/docs/Operators.md#gemm
    // =================================== Assignment ===================================
}

} // namespace infini