Skip to content

Commit

Permalink
Plan rank compiler (#10141)
Browse files Browse the repository at this point in the history
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
  • Loading branch information
strint and oneflow-ci-bot authored Jun 4, 2023
1 parent 4ac3692 commit a9a339b
Show file tree
Hide file tree
Showing 15 changed files with 957 additions and 147 deletions.
11 changes: 11 additions & 0 deletions oneflow/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class Graph {
std::function<Maybe<void>(NodeType*)> NodeHandler) const;
void ReverseTopoForEachNode(std::function<void(NodeType*)> NodeHandler) const;
void ForEachEdge(std::function<void(EdgeType*)> EdgeHandler) const;
Maybe<void> MaybeForEachEdge(std::function<Maybe<void>(EdgeType*)> EdgeHandler) const;

void SortedTopoForEachNode(std::function<bool(const EdgeType* lhs, const EdgeType* rhs)> LessThan,
std::function<void(NodeType*)> NodeHandler) const;
Expand Down Expand Up @@ -292,6 +293,16 @@ void Graph<NodeType, EdgeType>::ForEachEdge(std::function<void(EdgeType*)> EdgeH
}
}

// Fallible counterpart of an edge traversal: invokes `EdgeHandler` on every edge that is
// attached to at least one node, and stops at the first handler failure (propagated to the
// caller through JUST). Edges whose source AND destination are both null are skipped.
template<typename NodeType, typename EdgeType>
Maybe<void> Graph<NodeType, EdgeType>::MaybeForEachEdge(
    std::function<Maybe<void>(EdgeType*)> EdgeHandler) const {
  for (auto& edge : edges_) {
    // An edge is visited as soon as either endpoint is set.
    const bool is_attached = edge->src_node() != nullptr || edge->dst_node() != nullptr;
    if (is_attached) { JUST(EdgeHandler(edge.get())); }
  }
  return Maybe<void>::Ok();
}

template<typename NodeType, typename EdgeType>
NodeType* Graph<NodeType, EdgeType>::SoleNode() const {
CHECK_EQ(nodes_.size(), 1);
Expand Down
59 changes: 50 additions & 9 deletions oneflow/core/graph/op_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,37 @@ limitations under the License.
#include "oneflow/core/job/job_builder.h"
#include "oneflow/core/job/local_sig_infer_hint.h"
#include "oneflow/core/job/lazy_mode.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/persistence/tee_persistent_log_stream.h"
#include "oneflow/core/auto_parallel/algorithm_util.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/framework/sbp_infer_util.h"

namespace oneflow {

bool OpEdge::NeedBoxing() const {
if (src_node()->parallel_desc_sym() != dst_node()->parallel_desc_sym()) { return true; }
if (src_node()->parallel_desc().parallel_num() == 1) { return false; }
for (const auto& lbi : *lbis_) {
Shape src_reduced_hierarchy;
Shape dst_reduced_hierarchy;
NdSbp src_reduced_nd_sbp;
NdSbp dst_reduced_nd_sbp;

InOutParallelDimReduce(*src_node()->parallel_desc().hierarchy(),
*dst_node()->parallel_desc().hierarchy(), src_node()->NdSbp4Lbi(lbi),
dst_node()->NdSbp4Lbi(lbi), &src_reduced_hierarchy,
&dst_reduced_hierarchy, &src_reduced_nd_sbp, &dst_reduced_nd_sbp,
src_node()->LogicalBlobDesc4Lbi(lbi).shape());
if (src_reduced_hierarchy != dst_reduced_hierarchy
|| src_reduced_nd_sbp != dst_reduced_nd_sbp) {
// Not one to one
return true;
}
}
return false;
}

std::string OpEdge::VisualStr() const {
std::string str;
int32_t idx = 0;
Expand Down Expand Up @@ -54,12 +80,11 @@ const NdSbp& OpNode::NdSbp4Lbi(const LogicalBlobId& lbi) const {
return it->second;
}

// Constructs an OpNode for `op_conf` placed on `parallel_desc`: stores the placement, builds the
// operator via ConstructOp for the placement's device type, copies the operator's input blob
// names into ibns_, and finally hands the placement to the operator (FillOpParallelDesc).
// Any failure in ConstructOp or FillOpParallelDesc is checked with CHECK_JUST.
// NOTE(review): this span shows both the pre-change lines (std::shared_ptr<const ParallelDesc>
// parameter, FillOpParallelDesc(parallel_desc)) and the post-change lines (Symbol<ParallelDesc>
// parameter, FillOpParallelDesc(parallel_desc.shared_from_symbol())) of the same constructor.
OpNode::OpNode(const std::shared_ptr<const ParallelDesc>& parallel_desc,
const OperatorConf& op_conf)
OpNode::OpNode(Symbol<ParallelDesc> parallel_desc, const OperatorConf& op_conf)
: parallel_desc_(parallel_desc),
op_(CHECK_JUST(ConstructOp(op_conf, parallel_desc->device_type()))),
ibns_(op_->input_bns().begin(), op_->input_bns().end()) {
CHECK_JUST(op_->FillOpParallelDesc(parallel_desc));
CHECK_JUST(op_->FillOpParallelDesc(parallel_desc.shared_from_symbol()));
}

std::string OpNode::VisualStr() const {
Expand Down Expand Up @@ -194,16 +219,14 @@ void OpGraph::CheckIsDAG() const {

namespace {

std::function<std::shared_ptr<const ParallelDesc>(const std::string&)>
MakeGetterParallelDesc4OpName(const Job& job) {
std::function<Symbol<ParallelDesc>(const std::string&)> MakeGetterParallelDesc4OpName(
const Job& job) {
const Placement& placement = job.placement();
auto op_name2parallel_desc =
std::make_shared<HashMap<std::string, std::shared_ptr<const ParallelDesc>>>();
auto op_name2parallel_desc = std::make_shared<HashMap<std::string, Symbol<ParallelDesc>>>();
op_name2parallel_desc->reserve(job.net().op_size());
for (const auto& placement_group : placement.placement_group()) {
const ParallelConf& parallel_conf = placement_group.parallel_conf();
std::shared_ptr<const ParallelDesc> parallel_desc =
std::make_shared<const ParallelDesc>(parallel_conf);
Symbol<ParallelDesc> parallel_desc = SymbolOf(ParallelDesc(parallel_conf));
for (const std::string& op_name : placement_group.op_set().op_name()) {
CHECK(op_name2parallel_desc->emplace(op_name, parallel_desc).second)
<< "op_name: " << op_name;
Expand Down Expand Up @@ -566,6 +589,11 @@ Maybe<void> OpGraph::ForEachOpNode(const std::function<Maybe<void>(const OpNode&
return Maybe<void>::Ok();
}

std::function<bool(const OpNode* src, const OpNode* dst)> OpGraph::CreatePredicatorIsReachable()
const {
return MakePredicatorIsReachable();
}

// Print the graph with SBP in order
void OpGraph::PrintSBPGraphDebugInfo() const {
// test debug
Expand Down Expand Up @@ -622,4 +650,17 @@ void OpGraph::PrintSBPGraphDebugInfo() const {
}
}

// Creates the global Singleton<OpGraph> from `job`. When the session resource has debug mode
// enabled, additionally writes the job to a TeePersistentLogStream and dumps the op graph to a
// .dot file, both named after the current job id.
OpGraphSingletonGuard::OpGraphSingletonGuard(const Job& job) {
  Singleton<OpGraph>::New(job);
  const JobDesc& job_desc = GlobalJobDesc();

  const bool debug_enabled = Singleton<ResourceDesc, ForSession>::Get()->enable_debug_mode();
  if (!debug_enabled) { return; }

  // Debug artifacts: the serialized job first, then the graph visualization.
  const auto job_id = job_desc.job_id();
  TeePersistentLogStream::Create(StrCat("optimized_job", job_id))->Write(job);
  const std::string dot_path = "optimized_dlnet_" + std::to_string(job_id) + "_op_graph.dot";
  Singleton<OpGraph>::Get()->ToDotWithFilePath(dot_path);
}

// Destroys the global Singleton<OpGraph> when the guard goes away.
OpGraphSingletonGuard::~OpGraphSingletonGuard() {
  Singleton<OpGraph>::Delete();
}

} // namespace oneflow
16 changes: 13 additions & 3 deletions oneflow/core/graph/op_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ class OpGraph;
class OpNode final : public Node<OpNode, OpEdge> {
public:
OF_DISALLOW_COPY_AND_MOVE(OpNode);
explicit OpNode(const std::shared_ptr<const ParallelDesc>& parallel_desc,
const OperatorConf& op_conf);
explicit OpNode(Symbol<ParallelDesc> parallel_desc, const OperatorConf& op_conf);
~OpNode() = default;

// Getters
bool IsTimeShapeIdentity() const;
const Operator& op() const { return *op_; }
std::shared_ptr<const Operator> shared_op() const { return op_; }
const ParallelDesc& parallel_desc() const { return *parallel_desc_; }
Symbol<ParallelDesc> parallel_desc_sym() const { return parallel_desc_; }
const SbpSignature& sbp_signature() const { return *CHECK_JUST(op().sbp_signature()); }
const NdSbpSignature& nd_sbp_signature() const { return *CHECK_JUST(op().nd_sbp_signature()); }
const SbpParallel& SbpParallel4Lbi(const LogicalBlobId& lbi) const;
Expand All @@ -67,7 +67,7 @@ class OpNode final : public Node<OpNode, OpEdge> {
void InitLbi2SourceNode();
void InitLbi2NdSbp();

std::shared_ptr<const ParallelDesc> parallel_desc_;
Symbol<ParallelDesc> parallel_desc_;
std::shared_ptr<Operator> op_;
HashSet<std::string> ibns_;
HashMap<LogicalBlobId, OpNode*> lbi2source_node_;
Expand All @@ -88,6 +88,8 @@ class OpEdge final : public Edge<OpNode, OpEdge> {
const std::vector<LogicalBlobId>& lbis() const { return *lbis_; }
const HashMap<LogicalBlobId, std::string>& lbi2obn() const { return *lbi2obn_; }
const HashMap<LogicalBlobId, std::vector<std::string>>& lbi2ibns() const { return *lbi2ibns_; }

bool NeedBoxing() const;
std::string VisualStr() const override;

private:
Expand Down Expand Up @@ -130,6 +132,7 @@ class OpGraph final : public Graph<OpNode, OpEdge> {

Maybe<void> Init(const Job& job);

std::function<bool(const OpNode* src, const OpNode* dst)> CreatePredicatorIsReachable() const;
// Print the graph with SBP in order
void PrintSBPGraphDebugInfo() const;

Expand All @@ -155,6 +158,13 @@ class OpGraph final : public Graph<OpNode, OpEdge> {
HashMap<std::string, HashSet<std::string>> producer_op_name2ctrl_consumer_op_names_;
};

// Scope guard tying the lifetime of the global OpGraph singleton to an object:
// the constructor takes the Job to build the graph from, the destructor releases it.
// Instances can be neither copied nor moved.
class OpGraphSingletonGuard {
 public:
  // Takes the job whose op graph should be available for the guard's lifetime.
  explicit OpGraphSingletonGuard(const Job& job);
  ~OpGraphSingletonGuard();

  OF_DISALLOW_COPY_AND_MOVE(OpGraphSingletonGuard);
};

} // namespace oneflow

#endif // ONEFLOW_CORE_GRAPH_OP_GRAPH_H_
Loading

0 comments on commit a9a339b

Please sign in to comment.