From a9a339be9cc18bca4753259f1aaad2c5c4f632cf Mon Sep 17 00:00:00 2001 From: Xiaoyu Xu Date: Sun, 4 Jun 2023 09:55:54 +0800 Subject: [PATCH] Plan rank compiler (#10141) Co-authored-by: oneflow-ci-bot --- oneflow/core/graph/graph.h | 11 + oneflow/core/graph/op_graph.cpp | 59 ++- oneflow/core/graph/op_graph.h | 16 +- oneflow/core/graph/task_graph.cpp | 633 ++++++++++++++++++++++++---- oneflow/core/graph/task_graph.h | 114 ++++- oneflow/core/graph/task_node.cpp | 7 + oneflow/core/graph/task_node.h | 5 + oneflow/core/job/compiler.cpp | 24 +- oneflow/core/job/parallel_desc.cpp | 9 + oneflow/core/job/parallel_desc.h | 2 + oneflow/core/job/plan_util.cpp | 60 ++- oneflow/core/job/plan_util.h | 10 +- oneflow/core/job/rank_compiler.cpp | 110 +++++ oneflow/core/job/rank_compiler.h | 43 ++ oneflow/core/operator/op_conf.proto | 1 + 15 files changed, 957 insertions(+), 147 deletions(-) create mode 100644 oneflow/core/job/rank_compiler.cpp create mode 100644 oneflow/core/job/rank_compiler.h diff --git a/oneflow/core/graph/graph.h b/oneflow/core/graph/graph.h index b9f62e01696..ef04abafc23 100644 --- a/oneflow/core/graph/graph.h +++ b/oneflow/core/graph/graph.h @@ -45,6 +45,7 @@ class Graph { std::function(NodeType*)> NodeHandler) const; void ReverseTopoForEachNode(std::function NodeHandler) const; void ForEachEdge(std::function EdgeHandler) const; + Maybe MaybeForEachEdge(std::function(EdgeType*)> EdgeHandler) const; void SortedTopoForEachNode(std::function LessThan, std::function NodeHandler) const; @@ -292,6 +293,16 @@ void Graph::ForEachEdge(std::function EdgeH } } +template +Maybe Graph::MaybeForEachEdge( + std::function(EdgeType*)> EdgeHandler) const { + for (auto& x : edges_) { + if (x->src_node() == nullptr && x->dst_node() == nullptr) { continue; } + JUST(EdgeHandler(x.get())); + } + return Maybe::Ok(); +} + template NodeType* Graph::SoleNode() const { CHECK_EQ(nodes_.size(), 1); diff --git a/oneflow/core/graph/op_graph.cpp b/oneflow/core/graph/op_graph.cpp index edbe3258ced..0f19fafa1b7 100644 --- a/oneflow/core/graph/op_graph.cpp +++ b/oneflow/core/graph/op_graph.cpp @@ -17,11 +17,37 @@ limitations under the License. 
#include "oneflow/core/job/job_builder.h" #include "oneflow/core/job/local_sig_infer_hint.h" #include "oneflow/core/job/lazy_mode.h" +#include "oneflow/core/common/container_util.h" +#include "oneflow/core/persistence/tee_persistent_log_stream.h" #include "oneflow/core/auto_parallel/algorithm_util.h" #include "oneflow/core/framework/nd_sbp.h" +#include "oneflow/core/framework/sbp_infer_util.h" namespace oneflow { +bool OpEdge::NeedBoxing() const { + if (src_node()->parallel_desc_sym() != dst_node()->parallel_desc_sym()) { return true; } + if (src_node()->parallel_desc().parallel_num() == 1) { return false; } + for (const auto& lbi : *lbis_) { + Shape src_reduced_hierarchy; + Shape dst_reduced_hierarchy; + NdSbp src_reduced_nd_sbp; + NdSbp dst_reduced_nd_sbp; + + InOutParallelDimReduce(*src_node()->parallel_desc().hierarchy(), + *dst_node()->parallel_desc().hierarchy(), src_node()->NdSbp4Lbi(lbi), + dst_node()->NdSbp4Lbi(lbi), &src_reduced_hierarchy, + &dst_reduced_hierarchy, &src_reduced_nd_sbp, &dst_reduced_nd_sbp, + src_node()->LogicalBlobDesc4Lbi(lbi).shape()); + if (src_reduced_hierarchy != dst_reduced_hierarchy + || src_reduced_nd_sbp != dst_reduced_nd_sbp) { + // Not one to one + return true; + } + } + return false; +} + std::string OpEdge::VisualStr() const { std::string str; int32_t idx = 0; @@ -54,12 +80,11 @@ const NdSbp& OpNode::NdSbp4Lbi(const LogicalBlobId& lbi) const { return it->second; } -OpNode::OpNode(const std::shared_ptr& parallel_desc, - const OperatorConf& op_conf) +OpNode::OpNode(Symbol parallel_desc, const OperatorConf& op_conf) : parallel_desc_(parallel_desc), op_(CHECK_JUST(ConstructOp(op_conf, parallel_desc->device_type()))), ibns_(op_->input_bns().begin(), op_->input_bns().end()) { - CHECK_JUST(op_->FillOpParallelDesc(parallel_desc)); + CHECK_JUST(op_->FillOpParallelDesc(parallel_desc.shared_from_symbol())); } std::string OpNode::VisualStr() const { @@ -194,16 +219,14 @@ void OpGraph::CheckIsDAG() const { namespace { -std::function(const std::string&)> -MakeGetterParallelDesc4OpName(const Job& job) { +std::function(const std::string&)> MakeGetterParallelDesc4OpName( + const Job& job) { const Placement& placement = job.placement(); - auto op_name2parallel_desc = - std::make_shared>>(); + auto op_name2parallel_desc = std::make_shared>>(); op_name2parallel_desc->reserve(job.net().op_size()); for (const auto& placement_group : placement.placement_group()) { const ParallelConf& parallel_conf = placement_group.parallel_conf(); - std::shared_ptr parallel_desc = - std::make_shared(parallel_conf); + Symbol parallel_desc = SymbolOf(ParallelDesc(parallel_conf)); for (const std::string& op_name : placement_group.op_set().op_name()) { CHECK(op_name2parallel_desc->emplace(op_name, parallel_desc).second) << "op_name: " << op_name; @@ -566,6 +589,11 @@ Maybe OpGraph::ForEachOpNode(const std::function(const OpNode& return Maybe::Ok(); } +std::function OpGraph::CreatePredicatorIsReachable() + const { + return MakePredicatorIsReachable(); +} + // Print the graph with SBP in order void OpGraph::PrintSBPGraphDebugInfo() const { // test debug @@ -622,4 +650,17 @@ void OpGraph::PrintSBPGraphDebugInfo() const { } } +OpGraphSingletonGuard::OpGraphSingletonGuard(const Job& job) { + // new Singleton and set log configs. 
+ Singleton::New(job); + const JobDesc& job_desc = GlobalJobDesc(); + if (Singleton::Get()->enable_debug_mode()) { + TeePersistentLogStream::Create(StrCat("optimized_job", job_desc.job_id()))->Write(job); + Singleton::Get()->ToDotWithFilePath( + "optimized_dlnet_" + std::to_string(job_desc.job_id()) + "_op_graph.dot"); + } +} + +OpGraphSingletonGuard::~OpGraphSingletonGuard() { Singleton::Delete(); } + } // namespace oneflow diff --git a/oneflow/core/graph/op_graph.h b/oneflow/core/graph/op_graph.h index f36acf04ada..e2281e4749e 100644 --- a/oneflow/core/graph/op_graph.h +++ b/oneflow/core/graph/op_graph.h @@ -34,8 +34,7 @@ class OpGraph; class OpNode final : public Node { public: OF_DISALLOW_COPY_AND_MOVE(OpNode); - explicit OpNode(const std::shared_ptr& parallel_desc, - const OperatorConf& op_conf); + explicit OpNode(Symbol parallel_desc, const OperatorConf& op_conf); ~OpNode() = default; // Getters @@ -43,6 +42,7 @@ class OpNode final : public Node { const Operator& op() const { return *op_; } std::shared_ptr shared_op() const { return op_; } const ParallelDesc& parallel_desc() const { return *parallel_desc_; } + Symbol parallel_desc_sym() const { return parallel_desc_; } const SbpSignature& sbp_signature() const { return *CHECK_JUST(op().sbp_signature()); } const NdSbpSignature& nd_sbp_signature() const { return *CHECK_JUST(op().nd_sbp_signature()); } const SbpParallel& SbpParallel4Lbi(const LogicalBlobId& lbi) const; @@ -67,7 +67,7 @@ class OpNode final : public Node { void InitLbi2SourceNode(); void InitLbi2NdSbp(); - std::shared_ptr parallel_desc_; + Symbol parallel_desc_; std::shared_ptr op_; HashSet ibns_; HashMap lbi2source_node_; @@ -88,6 +88,8 @@ class OpEdge final : public Edge { const std::vector& lbis() const { return *lbis_; } const HashMap& lbi2obn() const { return *lbi2obn_; } const HashMap>& lbi2ibns() const { return *lbi2ibns_; } + + bool NeedBoxing() const; std::string VisualStr() const override; private: @@ -130,6 +132,7 @@ class OpGraph final : public Graph { Maybe Init(const Job& job); + std::function CreatePredicatorIsReachable() const; // Print the graph with SBP in order void PrintSBPGraphDebugInfo() const; @@ -155,6 +158,13 @@ class OpGraph final : public Graph { HashMap> producer_op_name2ctrl_consumer_op_names_; }; +class OpGraphSingletonGuard { + public: + OF_DISALLOW_COPY_AND_MOVE(OpGraphSingletonGuard); + explicit OpGraphSingletonGuard(const Job& job); + ~OpGraphSingletonGuard(); +}; + } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_OP_GRAPH_H_ diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp index 19ec3be4098..10280a8dfe5 100644 --- a/oneflow/core/graph/task_graph.cpp +++ b/oneflow/core/graph/task_graph.cpp @@ -14,11 +14,15 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/graph/task_graph.h" +#include "oneflow/core/common/just.h" +#include "oneflow/core/common/maybe.h" #include "oneflow/core/common/util.h" +#include "oneflow/core/common/container_util.h" #include "oneflow/core/common/env_var/debug_mode.h" #include "oneflow/core/graph/inplace_lbi_graph.h" #include "oneflow/core/job/job_conf.pb.h" #include "oneflow/core/job/job_desc.h" +#include "oneflow/core/job/task.pb.h" #include "oneflow/core/register/blob_desc.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/operator/variable_op.h" @@ -36,7 +40,10 @@ limitations under the License. 
#include "oneflow/core/graph/straighten_nodes.h" #include "oneflow/core/register/runtime_register_desc.h" #include "oneflow/core/common/env_var/env_var.h" +#include "oneflow/core/graph/boxing_task_graph.pb.h" +#include "oneflow/core/graph/task_graph_rebuild_ctx.h" #include "oneflow/core/framework/user_op_registry_manager.h" +#include "oneflow/core/graph/task_type_visitor.h" namespace oneflow { @@ -74,8 +81,9 @@ bool IsTickOpConf(const OperatorConf& conf) { return IsClassRegistered(conf.op_type_case()); } -std::string GetOpConfCalculationPassName(const OperatorConf& op_conf) { +const std::string& GetOpConfCalculationPassName(const OperatorConf& op_conf) { CHECK(op_conf.has_scope_symbol_id()); + if (op_conf.has_calculation_pass_name()) { return op_conf.calculation_pass_name(); } int64_t scope_symbol_id = op_conf.scope_symbol_id(); CHECK(Singleton>::Get()->Has(scope_symbol_id)) << " Error! op : \n " << op_conf.DebugString() @@ -284,37 +292,64 @@ Maybe MakeGetterTaskNode4MachineId7ThrdId( return Maybe::Ok(); } +namespace { + +StreamId GetStreamId(const OpNode* op_node, int64_t parallel_id, TaskType task_type) { + const ParallelDesc& parallel_desc = op_node->parallel_desc(); + int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id)); + int64_t dev_phy_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id)); + + DeviceId::device_index_t device_index = parallel_desc.device_type() == DeviceType::kCPU + ? 0 + : static_cast(dev_phy_id); + DeviceId device_id{static_cast(machine_id), parallel_desc.device_type(), + device_index}; + StreamId::stream_index_t stream_index = 0; + if (op_node->op().op_conf().has_stream_name_hint()) { + const std::string& stream_name_hint = op_node->op().op_conf().stream_name_hint(); + VLOG(3) << "set op: " << op_node->op().op_name() << " to stream: " << stream_name_hint; + stream_index = Singleton::Get()->GetNamedTaskStreamIndex( + device_id, stream_name_hint); + } else { + stream_index = + Singleton::Get()->GetTaskStreamIndex(task_type, device_id); + } + return StreamId{device_id, stream_index}; +} + +TaskType TaskType4OpNode(const OpNode* op_node) { + std::unique_ptr comp_task_node(NewCompTaskNode4OpNode(op_node)); + return comp_task_node->GetTaskType(); +} + +} // namespace + +CompTaskNode* GenCompTaskNode( + const OpNode* op_node, int64_t parallel_id, + const std::function& + GetOrCreateStreamId) { + const ParallelDesc& parallel_desc = op_node->parallel_desc(); + int64_t parallel_num = parallel_desc.parallel_num(); + CompTaskNode* comp_task_node = NewCompTaskNode4OpNode(op_node); + int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id)); + comp_task_node->set_machine_id(machine_id); + comp_task_node->mut_parallel_ctx()->set_parallel_id(parallel_id); + comp_task_node->mut_parallel_ctx()->set_parallel_num(parallel_num); + StreamId stream_id = GetOrCreateStreamId(op_node, parallel_id, comp_task_node->GetTaskType()); + comp_task_node->set_thrd_id(EncodeStreamIdToInt64(stream_id)); + comp_task_node->set_op_node(op_node); + return comp_task_node; +} + void GenSortedCompTaskNodes(const OpNode* op_node, std::vector* sorted_comp_tasks) { int64_t parallel_idx = 0; const ParallelDesc& parallel_desc = op_node->parallel_desc(); - int64_t parallel_num = parallel_desc.parallel_num(); for (int64_t machine_id : parallel_desc.sorted_machine_ids()) { for (int64_t dev_phy_id : parallel_desc.sorted_dev_phy_ids(machine_id)) { - CompTaskNode* comp_task_node = NewCompTaskNode4OpNode(op_node); - comp_task_node->set_machine_id(machine_id); 
- comp_task_node->mut_parallel_ctx()->set_parallel_id(parallel_idx++); - comp_task_node->mut_parallel_ctx()->set_parallel_num(parallel_num); - - DeviceId::device_index_t device_index = - parallel_desc.device_type() == DeviceType::kCPU - ? 0 - : static_cast(dev_phy_id); - DeviceId device_id{static_cast(machine_id), parallel_desc.device_type(), - device_index}; - StreamId::stream_index_t stream_index = 0; - if (op_node->op().op_conf().has_stream_name_hint()) { - const std::string& stream_name_hint = op_node->op().op_conf().stream_name_hint(); - VLOG(3) << "set op: " << op_node->op().op_name() << " to stream: " << stream_name_hint; - stream_index = Singleton::Get()->GetNamedTaskStreamIndex( - device_id, stream_name_hint); - } else { - stream_index = Singleton::Get()->GetTaskStreamIndex( - comp_task_node->GetTaskType(), device_id); - } - comp_task_node->set_thrd_id(EncodeStreamIdToInt64(StreamId{device_id, stream_index})); - comp_task_node->set_op_node(op_node); - sorted_comp_tasks->emplace_back(comp_task_node); + sorted_comp_tasks->emplace_back(GenCompTaskNode(op_node, parallel_idx++, &GetStreamId)); + (void)dev_phy_id; } + (void)machine_id; } } @@ -410,11 +445,12 @@ BldSubTskGphMthd GetMthdForBldSubTskGph(const OpEdge* op_edge) { void ForEachOpGraphNecessaryCtrlEdge( const OpGraph* op_graph, const std::function& Handler) { - auto IsOpGraphDataReachable = op_graph->MakePredicatorIsReachable(); + auto IsOpGraphDataReachable = op_graph->CreatePredicatorIsReachable(); op_graph->ForEachNode([&](OpNode* dst) { for (const auto& ctrl_in_op_name : dst->op().op_conf().ctrl_in_op_name()) { const OpNode* src = op_graph->OpNode4OpName(ctrl_in_op_name); CHECK(!IsOpGraphDataReachable(dst, src)); + // src has ctrl to dst, but src has no data path to dst. if (!IsOpGraphDataReachable(src, dst)) { CHECK_EQ(dst->parallel_desc().parallel_num(), src->parallel_desc().parallel_num()); const Shape* src_time_shape = CHECK_JUST(src->op().GetOpTimeShape()).get(); @@ -483,6 +519,9 @@ HashMap* GlobalDeviceType2CreateSubTskGphB } // namespace +TaskGraph::TaskGraph() = default; +TaskGraph::~TaskGraph() = default; + Maybe RegisterCreateSubTskGphBuilderFn(DeviceType device_type, const CreateSubTskGphBuilderFn& fn) { auto* global_device_type_create_sub_tsk_gph_builder_fn = @@ -491,47 +530,6 @@ Maybe RegisterCreateSubTskGphBuilderFn(DeviceType device_type, return Maybe::Ok(); } -TaskGraph::TaskGraph() { - OpGraph* op_graph = Singleton::Get(); - sub_tsk_gph_builder_ctx_.reset(new SubTskGphBuilderCtx(this)); - boxing_logger_ = CreateBoxingLogger(); - const auto* global_device_type_create_sub_tsk_gph_builder_fn = - GlobalDeviceType2CreateSubTskGphBuilderFn(); - for (const auto& pair : *global_device_type_create_sub_tsk_gph_builder_fn) { - device_type2sub_tsk_gph_builder_.emplace(pair.first, pair.second()); - } - hierarchical_sub_tsk_gph_builder_.reset(new DispatchHierarchicalSubTskGphBuilder()); - HashMap> op_node2sorted_comp_tasks; - - op_graph->ForEachNode([&](const OpNode* op_node) { - std::vector* sorted_comp_tasks = &(op_node2sorted_comp_tasks[op_node]); - GenSortedCompTaskNodes(op_node, sorted_comp_tasks); - for (CompTaskNode* comp_task : *sorted_comp_tasks) { AddAllocatedNode(comp_task); } - }); - - op_graph->ForEachEdge([&](const OpEdge* op_edge) { - BldSubTskGphMthd method = GetMthdForBldSubTskGph(op_edge); - (this->*method)(op_edge, op_node2sorted_comp_tasks.at(op_edge->src_node()), - op_node2sorted_comp_tasks.at(op_edge->dst_node())); - }); - - ForEachOpGraphNecessaryCtrlEdge(op_graph, [&](const OpNode* src, const 
OpNode* dst) {
-    const auto& src_task_nodes = op_node2sorted_comp_tasks.at(src);
-    const auto& dst_task_nodes = op_node2sorted_comp_tasks.at(dst);
-    if (src->op().op_conf().has_src_subset_tick_conf()) {
-      UNIMPLEMENTED();
-    } else if (dst->op().op_conf().has_dst_subset_tick_conf()) {
-      UNIMPLEMENTED();
-    } else {
-      ConnectCtrlEdges(src_task_nodes, dst_task_nodes);
-    }
-  });
-
-  if (Singleton<ResourceDesc, ForSession>::Get()->enable_debug_mode()) { ToDotWithAutoFilePath(); }
-}
-
-TaskGraph::~TaskGraph() = default;
-
 TaskEdge* TaskGraph::NewTaskEdgeWithLbi(const LogicalBlobId& lbi) {
   TaskEdge* edge = NewEdge();
   edge->AddLbi(lbi);
@@ -604,15 +602,19 @@ TaskNode* TaskGraph::GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi,
   return GetProxyNode(src_node, lbi, mem_zone_id);
 }
 
+void TaskGraph::ConnectCtrlEdge(CompTaskNode* src_task_node, CompTaskNode* dst_task_node) {
+  std::string regst_desc_name;
+  src_task_node->BuildCtrlRegstDesc(dst_task_node, &regst_desc_name);
+  TaskEdge* edge = NewEdge();
+  Connect<TaskNode, TaskEdge>(src_task_node, edge, dst_task_node);
+  src_task_node->BindEdgeWithProducedRegst(edge, regst_desc_name);
+}
+
 void TaskGraph::ConnectCtrlEdges(const std::vector<CompTaskNode*>& src_task_nodes,
                                  const std::vector<CompTaskNode*>& dst_task_nodes) {
   CHECK_EQ(src_task_nodes.size(), dst_task_nodes.size());
   FOR_RANGE(int32_t, i, 0, src_task_nodes.size()) {
-    std::string regst_desc_name;
-    src_task_nodes.at(i)->BuildCtrlRegstDesc(dst_task_nodes.at(i), &regst_desc_name);
-    TaskEdge* edge = NewEdge();
-    Connect(src_task_nodes.at(i), edge, dst_task_nodes.at(i));
-    src_task_nodes.at(i)->BindEdgeWithProducedRegst(edge, regst_desc_name);
+    ConnectCtrlEdge(src_task_nodes.at(i), dst_task_nodes.at(i));
   }
 }
 
@@ -625,12 +627,14 @@ void TaskGraph::RemoveEmptyRegsts() {
 
 void TaskGraph::MergeChainAndAddOrderingCtrlEdgeInSameChain() {
   if (EnableLogicalChain()) {
+    // Ctrl edges within a chain have already been added in the logical chain pass, so
+    // there is no need to call BuildCtrlRegstDescInSameChain here.
    MergeChainByLogicalChainId();
   } else {
     // TODO(chengcheng): erase old chain version in the future.
     MergeChainByPhysicalTaskGraph();
+    BuildCtrlRegstDescInSameChain();
   }
-  BuildCtrlRegstDescInSameChain();
 }
 
 void TaskGraph::InitOrderedTaskNodes() {
@@ -687,6 +691,8 @@ void TaskGraph::BuildCtrlRegstDescInSameChain() {
     return (node->chain_id() << 31) | (node->machine_id());
   };
   HashMap<int64_t, TaskNode*> physical_chain_id2node;
+  // Note that the topological order of ordered_task_nodes_ is not guaranteed in separation
+  // plan compilation, so adding ctrl edges along ordered_task_nodes_ there may cause deadlock.
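+  // Illustrative scenario (not from this diff): if one rank's ordered_task_nodes_ visits a
+  // chain as [A, B] while another rank visits it as [B, A], the per-rank ordering ctrl edges
+  // A->B and B->A could close a cycle once the boxing edges between ranks are considered.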
for (auto* node : ordered_task_nodes_) { if (IsConnectToTickOp(node)) { continue; } // NOTE(chengcheng): skip invalid chain id @@ -806,12 +812,34 @@ void TaskGraph::EnableInplaceMemSharing( const std::function& IsOpNameDataOrCtrlReachable) { ForEachGpuDeviceNodes([&](const HashSet& dev_nodes) { - InplaceObasInfo safe_inplace_obas_info; - GetSafeInplaceOpBlobArgList(&safe_inplace_obas_info, dev_nodes, IsOpNameDataOrCtrlReachable); - SetTaskRegstInplaceInfo(safe_inplace_obas_info, dev_nodes); + EnableInplaceMemSharing(dev_nodes, IsOpNameDataOrCtrlReachable); }); } +void TaskGraph::EnableInplaceMemSharing( + const HashSet& dev_nodes, + const std::function& + IsOpNameDataOrCtrlReachable) { + InplaceObasInfo safe_inplace_obas_info; + GetSafeInplaceOpBlobArgList(&safe_inplace_obas_info, dev_nodes, IsOpNameDataOrCtrlReachable); + SetTaskRegstInplaceInfo(safe_inplace_obas_info, dev_nodes); +} + +void TaskGraph::DecideExecutionOrder() { + // For one machine with no transfer available, the straighten algorithm for overlaps consume a lot + // of memory + StraightenAlgorithmTag straighten_algorithm_tag = + GlobalJobDesc().job_conf().straighten_algorithm_tag_in_task_graph(); + if (straighten_algorithm_tag == StraightenAlgorithmTag::kDisableStraighten + || (straighten_algorithm_tag == StraightenAlgorithmTag::kOverlap4Transfer + && GlobalProcessCtx::WorldSize() == 1)) { + InitOrderedTaskNodes(); + } else { + StraightenNodes(this, &ordered_task_nodes_, + Singleton::Get()->nccl_use_compute_stream()); + } +} + #define DEFINE_BLD_SUB_TASK_GRAPH_METHOD(method_name) \ void TaskGraph::method_name BLD_SUB_TSK_GPH_MTHD_ARGS() @@ -852,16 +880,14 @@ DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing) { if (device_type != DeviceType::kCPU && device_type2sub_tsk_gph_builder_.find(device_type) != device_type2sub_tsk_gph_builder_.end()) { - auto maybe_status = TRY( // NOLINT + status = CHECK_JUST( // NOLINT device_type2sub_tsk_gph_builder_ // NOLINT .at(device_type) // NOLINT ->Build(sub_tsk_gph_builder_ctx_.get(), in_nodes, &out_nodes, // NOLINT &sorted_ctrl_tasks, src_parallel_desc, dst_parallel_desc, lbi, // NOLINT blob_desc, src_nd_sbp, dst_nd_sbp, // NOLINT *(CHECK_JUST(src_op_node->op().GetOpTimeShape()).get()))); // NOLINT - if (maybe_status.IsOk()) { status = CHECK_JUST(maybe_status); } - } - if (!status) { + } else { status = CHECK_JUST(hierarchical_sub_tsk_gph_builder_->Build( sub_tsk_gph_builder_ctx_.get(), in_nodes, &out_nodes, &sorted_ctrl_tasks, src_parallel_desc, dst_parallel_desc, lbi, blob_desc, src_nd_sbp, dst_nd_sbp, @@ -1022,19 +1048,442 @@ void TaskGraph::BuildTaskPath(TaskNode* src_node, TaskNode* dst_node, const Logi ConnectWithLbi(proxy_node, dst_node, lbi); } -void TaskGraph::DecideExecutionOrder() { - // For one machine with no transfer available, the straighten algorithm for overlaps consume a lot - // of memory - StraightenAlgorithmTag straighten_algorithm_tag = - GlobalJobDesc().job_conf().straighten_algorithm_tag_in_task_graph(); - if (straighten_algorithm_tag == StraightenAlgorithmTag::kDisableStraighten - || (straighten_algorithm_tag == StraightenAlgorithmTag::kOverlap4Transfer - && GlobalProcessCtx::WorldSize() == 1)) { - InitOrderedTaskNodes(); +Maybe GlobalTaskGraph::Init() { + OpGraph* op_graph = Singleton::Get(); + sub_tsk_gph_builder_ctx_.reset(new SubTskGphBuilderCtx(this)); + boxing_logger_ = CreateBoxingLogger(); + hierarchical_sub_tsk_gph_builder_.reset(new DispatchHierarchicalSubTskGphBuilder()); + HashMap> op_node2sorted_comp_tasks; + + op_graph->ForEachNode([&](const 
OpNode* op_node) { + std::vector* sorted_comp_tasks = &(op_node2sorted_comp_tasks[op_node]); + GenSortedCompTaskNodes(op_node, sorted_comp_tasks); + for (CompTaskNode* comp_task : *sorted_comp_tasks) { AddAllocatedNode(comp_task); } + }); + + op_graph->ForEachEdge([&](const OpEdge* op_edge) { + BldSubTskGphMthd method = GetMthdForBldSubTskGph(op_edge); + (this->*method)(op_edge, op_node2sorted_comp_tasks.at(op_edge->src_node()), + op_node2sorted_comp_tasks.at(op_edge->dst_node())); + }); + + ForEachOpGraphNecessaryCtrlEdge(op_graph, [&](const OpNode* src, const OpNode* dst) { + const auto& src_task_nodes = op_node2sorted_comp_tasks.at(src); + const auto& dst_task_nodes = op_node2sorted_comp_tasks.at(dst); + if (src->op().op_conf().has_src_subset_tick_conf()) { + UNIMPLEMENTED(); + } else if (dst->op().op_conf().has_dst_subset_tick_conf()) { + UNIMPLEMENTED(); + } else { + ConnectCtrlEdges(src_task_nodes, dst_task_nodes); + } + }); + + if (Singleton::Get()->enable_debug_mode()) { ToDotWithAutoFilePath(); } + return Maybe::Ok(); +} + +Maybe BoxingTaskGraph::Init( + const std::function&)>& ParallelRunLoop) { + OpGraph* op_graph = Singleton::Get(); + sub_tsk_gph_builder_ctx_.reset(new SubTskGphBuilderCtx(this)); + boxing_logger_ = CreateBoxingLogger(); + hierarchical_sub_tsk_gph_builder_.reset(new DispatchHierarchicalSubTskGphBuilder()); + + const auto& TryCreateSortedCompTaskNodes = [&](const OpNode* op_node) { + if (boxing_related_op_node2sorted_comp_tasks_.count(op_node) > 0) { return; } + std::vector* sorted_comp_tasks = + &(boxing_related_op_node2sorted_comp_tasks_[op_node]); + GenSortedCompTaskNodes(op_node, sorted_comp_tasks); + for (CompTaskNode* comp_task : *sorted_comp_tasks) { AddAllocatedNode(comp_task); } + }; + op_graph->ForEachEdge([&](const OpEdge* op_edge) { + if (!op_edge->NeedBoxing()) { return; } + TryCreateSortedCompTaskNodes(op_edge->src_node()); + TryCreateSortedCompTaskNodes(op_edge->dst_node()); + BldSubTskGphMthd method = GetMthdForBldSubTskGph(op_edge); + (this->*method)(op_edge, boxing_related_op_node2sorted_comp_tasks_.at(op_edge->src_node()), + boxing_related_op_node2sorted_comp_tasks_.at(op_edge->dst_node())); + }); + ForEachNode(std::bind(&TaskNode::ProduceAllRegstsAndBindEdges, std::placeholders::_1)); + CreateOpNode2TaskIds(ParallelRunLoop); + return Maybe::Ok(); +} + +void BoxingTaskGraph::CreateOpNode2TaskIds( + const std::function&)>& ParallelRunLoop) { + const OpGraph* op_graph = Singleton::Get(); + std::vector op_nodes; + op_nodes.reserve(op_graph->node_num()); + op_graph->ForEachNode([&](OpNode* op_node) { + if (boxing_related_op_node2sorted_comp_tasks_.count(op_node) == 0) { + op_nodes.push_back(op_node); + boxing_unrelated_op_node2sorted_task_ids_[op_node].reserve( + op_node->parallel_desc().parallel_num()); + } + }); + ParallelRunLoop(op_nodes.size(), [&](size_t i) { + const OpNode* op_node = op_nodes.at(i); + TaskType task_type = TaskType4OpNode(op_node); + const auto& parallel_desc = op_node->parallel_desc(); + auto* task_ids = &boxing_unrelated_op_node2sorted_task_ids_[op_node]; + for (int parallel_id = 0; parallel_id < parallel_desc.parallel_num(); ++parallel_id) { + const auto& stream_id = GetStreamId(op_node, parallel_id, task_type); + task_ids->push_back(Singleton::Get()->GetTaskIdGenerator()->Generate(stream_id)); + } + }); +} + +namespace { + +bool IsComputTaskNodeDutyRank(int64_t current_rank, const ParallelDesc& parallel_desc, + int64_t task_node_rank) { + if (current_rank == 0) { + // make sure master knows at least one op_node. 
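+    // e.g. (illustrative): for a placement on ranks {2, 3}, rank 0 treats the task at
+    // parallel_id 0 (which lives on rank 2) as its duty, so every op appears in the
+    // master's view at least once.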
+    return CHECK_JUST(parallel_desc.MachineId4ParallelId(0)) == task_node_rank;
+  } else if (parallel_desc.HasMachineId(current_rank)) {
+    // workers only care about their own rank.
+    return current_rank == task_node_rank;
   } else {
-    StraightenNodes(this, &ordered_task_nodes_,
-                    Singleton<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream());
+    return false;
+  }
+}
+
+// A template function to process a task node according to its concrete type.
+// RetT, the function return type
+// HandleTansportTaskNode, called if the task node is a transport task node
+// HandleComputeTaskNode, called if the task node is a compute task node
+// task_node, the input task node
+template<typename RetT, typename HandleTansportTaskNodeT, typename HandleComputeTaskNodeT>
+RetT TaskNodeVisitor(TaskNode* task_node, const HandleTansportTaskNodeT& HandleTansportTaskNode,
+                     const HandleComputeTaskNodeT& HandleComputeTaskNode) {
+  auto* transport_task_node = dynamic_cast<TransportTaskNode*>(task_node);
+  if (transport_task_node != nullptr) {
+    return HandleTansportTaskNode(transport_task_node);
+  } else {
+    auto* comp_task_node = dynamic_cast<CompTaskNode*>(task_node);
+    if (comp_task_node != nullptr) {
+      return HandleComputeTaskNode(comp_task_node);
+    } else {
+      UNIMPLEMENTED();
+    }
+  }
+}
+
+}  // namespace
+
+/*static*/ bool BoxingTaskGraph::SelectTaskNodeByRank(TaskNode* task_node, int64_t rank) {
+  return TaskNodeVisitor(
+      task_node, [&](TransportTaskNode* task_node) { return task_node->machine_id() == rank; },
+      [&](CompTaskNode* task_node) {
+        const auto& machine_id = task_node->machine_id();
+        return IsComputTaskNodeDutyRank(rank, task_node->op_node()->parallel_desc(), machine_id);
+      });
+}
+
+void BoxingTaskGraph::ToProto(const std::function<bool(TaskNode*)>& Pick,
+                              BoxingTaskGraphProto* proto) const {
+  const auto sources = [&]() -> std::list<TaskNode*> {
+    HashSet<TaskNode*> sources;
+    ForEachNode([&](TaskNode* task_node) {
+      if (Pick(task_node)) { sources.insert(task_node); }
+    });
+    HashSet<TaskNode*> sources_out;
+    for (auto* source : sources) {
+      // The consumed task_ids must be generated from out_nodes.
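+      // One reading of this invariant (illustrative): the out-nodes consume regsts produced
+      // by the picked nodes, so they are serialized together; otherwise their consumed
+      // task_ids could not be re-linked when the proto is rebuilt on another rank.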
+      source->ForEachNodeOnOutEdge([&](TaskNode* out_node) {
+        if (!sources.count(out_node)) { sources_out.insert(out_node); }
+      });
+    }
+    sources.insert(sources_out.begin(), sources_out.end());
+    return std::list<TaskNode*>{sources.begin(), sources.end()};
+  }();
+  const auto& TransportTaskNodeToProto = [&](TransportTaskNode* task_node) {
+    task_node->ToTransportTaskProtoIf(proto->mutable_transport_task()->Add());
+  };
+  const auto& ComputeTaskNodeToProto = [&](CompTaskNode* task_node) {
+    auto* map = proto->mutable_boxing_related_op_name2compute_tasks();
+    const auto& op_name = task_node->op_node()->op().op_name();
+    auto* parallel_id2task_proto = (*map)[op_name].mutable_parallel_id2task();
+    int64_t parallel_id = task_node->parallel_id();
+    task_node->ToProto(&(*parallel_id2task_proto)[parallel_id], /*check=*/false);
+  };
+  HashSet<TaskNode*> rank_task_nodes;
+  BfsForEachNode(sources, &TaskNode::ForEachNodeOnInEdge, [&](TaskNode* task_node) {
+    rank_task_nodes.insert(task_node);
+    TaskNodeVisitor(task_node, TransportTaskNodeToProto, ComputeTaskNodeToProto);
+  });
+  const auto rank_task_edges = [&] {
+    HashSet<TaskEdge*> rank_task_edges;
+    const auto& TryInsertEdge = [&](TaskEdge* edge) {
+      if (rank_task_nodes.count(edge->src_node()) > 0
+          && rank_task_nodes.count(edge->dst_node()) > 0) {
+        rank_task_edges.insert(edge);
+      }
+    };
+    for (const auto* task_node : rank_task_nodes) {
+      for (auto* in_edge : task_node->in_edges()) { TryInsertEdge(in_edge); }
+      for (auto* out_edge : task_node->out_edges()) { TryInsertEdge(out_edge); }
+    }
+    return rank_task_edges;
+  }();
+  for (auto* edge : rank_task_edges) { edge->ToProto(proto->mutable_task_edge()->Add()); }
+  for (const auto& pair : boxing_unrelated_op_node2sorted_task_ids_) {
+    const auto& op_name = pair.first->op().op_name();
+    auto* vec = &(*proto->mutable_boxing_unrelated_op_name2task_ids())[op_name];
+    for (const auto& task_id : pair.second) { vec->add_task_id(EncodeTaskIdToInt64(task_id)); }
+  }
+}
+
+RankTaskGraph::RankTaskGraph(const std::shared_ptr<BoxingTaskGraphProto>& boxing_task_graph_proto,
+                             int64_t current_rank)
+    : boxing_task_graph_proto_(boxing_task_graph_proto),
+      current_rank_(current_rank),
+      task_graph_rebuild_ctx_(std::make_unique<TaskGraphRebuildCtx>()) {}
+
+Maybe<CompTaskNode*> RankTaskGraph::TryGetBoxingRelatedComTaskNode(const OpNode* op_node,
+                                                                   int64_t parallel_id) {
+  const auto& op_name = op_node->op().op_name();
+  auto iter = boxing_task_graph_proto_->boxing_related_op_name2compute_tasks().find(op_name);
+  if (iter == boxing_task_graph_proto_->boxing_related_op_name2compute_tasks().end()) {
+    return nullptr;
+  }
+  auto task_iter = iter->second.parallel_id2task().find(parallel_id);
+  if (task_iter == iter->second.parallel_id2task().end()) { return nullptr; }
+  int64_t task_id = task_iter->second.task_id();
+  auto* task_node = JUST(task_graph_rebuild_ctx_->TaskNode4Id(task_id));
+  auto* comp_task_node = dynamic_cast<CompTaskNode*>(task_node);
+  CHECK_NOTNULL_OR_RETURN(comp_task_node) << "invalid task_type. task_id: " << task_id;
+  return comp_task_node;
+}
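+// Lookup order used below (summary of this diff's logic): first search the boxing-related
+// tasks deserialized from BoxingTaskGraphProto, then the comp task nodes already created on
+// this rank, and only then generate a fresh node that reuses the task_id pre-assigned by the
+// master in boxing_unrelated_op_name2task_ids.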
task_id: " << task_id; + return comp_task_node; +} + +Maybe RankTaskGraph::CreateOrFindRankCompTaskNodeByParallelId(const OpNode* op_node, + int64_t parallel_id) { + auto* comp_task_node = JUST(TryGetBoxingRelatedComTaskNode(op_node, parallel_id)); + if (comp_task_node != nullptr) { return comp_task_node; } + auto iter = op_node2comp_task_node_.find(op_node); + if (iter != op_node2comp_task_node_.end()) { return iter->second; } + + const TaskId task_id = *JUST([&]() -> Maybe { + const auto& map = boxing_task_graph_proto_->boxing_unrelated_op_name2task_ids(); + const auto& iter = map.find(op_node->op().op_name()); + CHECK_OR_RETURN(iter != map.end()); + CHECK_LT_OR_RETURN(parallel_id, iter->second.task_id_size()); + return DecodeTaskIdFromInt64(iter->second.task_id().Get(parallel_id)); + }()); + const auto& GetStreamIdFromMaster = [&](const OpNode* op_node, int64_t parallel_id, TaskType) { + return task_id.stream_id(); + }; + auto comp_task_node_ptr = GenCompTaskNode(op_node, parallel_id, GetStreamIdFromMaster); + comp_task_node_ptr->update_new_task_id(task_id); + AddAllocatedNode(comp_task_node_ptr); + CHECK_OR_RETURN(op_node2comp_task_node_.emplace(op_node, comp_task_node_ptr).second) + << "Got dupliacted op_node " << op_node->op().op_name(); + return comp_task_node_ptr; +} + +Maybe RankTaskGraph::CreateOrFindRankCompTaskNodeByRank(const OpNode* op_node, + int64_t rank) { + CHECK_OR_RETURN(op_node->parallel_desc().HasMachineId(rank)) + << "rank is not contained in the placment"; + int64_t parallel_id = -1; + CHECK_OR_RETURN(JUST(op_node->parallel_desc().TryGetParallelId(rank, ¶llel_id))) + << "parallel_id not found."; + return CreateOrFindRankCompTaskNodeByParallelId(op_node, parallel_id); +} + +Maybe RankTaskGraph::TryGetRankCompTaskNode(const OpNode* op_node, int64_t rank) { + if (!op_node->parallel_desc().HasMachineId(rank)) { return nullptr; } + int64_t parallel_id = -1; + CHECK_OR_RETURN(JUST(op_node->parallel_desc().TryGetParallelId(rank, ¶llel_id))) + << "parallel_id not found."; + auto* comp_task_node = JUST(TryGetBoxingRelatedComTaskNode(op_node, parallel_id)); + if (comp_task_node != nullptr) { return comp_task_node; } + auto iter = op_node2comp_task_node_.find(op_node); + CHECK_OR_RETURN(iter != op_node2comp_task_node_.end()) + << "op_node " << op_node->op().op_name() << " not found."; + return iter->second; +} + +Maybe RankTaskGraph::AddBoxingReletedCompTaskNodesFromProto() { + OpGraph* op_graph = Singleton::Get(); + for (const auto& pair : boxing_task_graph_proto_->boxing_related_op_name2compute_tasks()) { + const OpNode* op_node = op_graph->OpNode4OpName(pair.first); + for (const auto& pair : pair.second.parallel_id2task()) { + const auto& task_proto = pair.second; + CHECK_OR_RETURN(task_id2task_proto_.emplace(task_proto.task_id(), &task_proto).second) + << "redundant task_id."; + CompTaskNode* comp_task_node = NewCompTaskNode4OpNode(op_node); + comp_task_node->set_op_node(op_node); + AddAllocatedNode(comp_task_node); + // Note here has no consume regst + // Init task node and produce regst + comp_task_node->InitFromProtoExceptConsumedRegsts(task_proto); + JUST(task_graph_rebuild_ctx_->AddTaskNode(comp_task_node)); + } + } + return Maybe::Ok(); +} + +Maybe RankTaskGraph::CreateAndPartiallyInitTransportTaskNodesFromProto() { + for (const auto& transport_task_proto : boxing_task_graph_proto_->transport_task()) { + const auto& task_proto = transport_task_proto.task_proto(); + CHECK_OR_RETURN(task_id2task_proto_.emplace(task_proto.task_id(), &task_proto).second) + << "redundant 
task_id."; + auto* task_node = JUST(CreateTransportTask::Visit(task_proto.task_type())); + AddAllocatedNode(task_node); + // Init task node and produce regst + task_node->InitFromProtoExceptConsumedRegsts(transport_task_proto.task_proto()); + JUST(task_graph_rebuild_ctx_->AddTaskNode(task_node)); + } + return Maybe::Ok(); +} + +Maybe RankTaskGraph::AddTransportTaskEdgesFromProto() { + for (const auto& task_edge_proto : boxing_task_graph_proto_->task_edge()) { + TaskEdge* edge = NewEdge(); + auto* src_task_node = JUST(task_graph_rebuild_ctx_->TaskNode4Id(task_edge_proto.src_task_id())); + auto* dst_task_node = JUST(task_graph_rebuild_ctx_->TaskNode4Id(task_edge_proto.dst_task_id())); + Connect(src_task_node, edge, dst_task_node); + JUST(edge->InitFromProto(task_edge_proto, *task_graph_rebuild_ctx_)); + JUST(task_graph_rebuild_ctx_->AddTaskEdge(edge, task_edge_proto.task_edge_uid())); + } + return Maybe::Ok(); +} + +Maybe RankTaskGraph::InitTransportTaskNodesFromProto() { + for (const auto& transport_task_proto : boxing_task_graph_proto_->transport_task()) { + int64_t task_id = transport_task_proto.task_proto().task_id(); + auto* task_node = JUST(task_graph_rebuild_ctx_->TaskNode4Id(task_id)); + auto* transport_task_node = dynamic_cast(task_node); + CHECK_NOTNULL_OR_RETURN(transport_task_node) + << "task node is not a TransportTaskNode. task_id" << task_id; + JUST(transport_task_node->InitTransportTaskFromProtoIf(transport_task_proto, + *task_graph_rebuild_ctx_)); + } + return Maybe::Ok(); +} + +bool RankTaskGraph::ContainRank(const OpNode* op_node, int64_t rank) const { + return op_node->parallel_desc().HasMachineId(rank); +} + +Maybe RankTaskGraph::ConnectDataEdges(const OpEdge* op_edge, int64_t rank) { + if (!op_edge->NeedBoxing()) { + auto* src_task_node = JUST(TryGetRankCompTaskNode(op_edge->src_node(), rank)); + auto* dst_task_node = JUST(TryGetRankCompTaskNode(op_edge->dst_node(), rank)); + if (ContainRank(op_edge->src_node(), rank)) { + CHECK_NOTNULL_OR_RETURN(src_task_node) << "src_task_node should not be nullptr. op_name: " + << op_edge->src_node()->op().op_name(); + } + if (ContainRank(op_edge->dst_node(), rank)) { + CHECK_NOTNULL_OR_RETURN(dst_task_node) << "dst_task_node should not be nullptr. op_name: " + << op_edge->dst_node()->op().op_name(); + } + if (src_task_node != nullptr && dst_task_node != nullptr) { + for (const auto& lbi : op_edge->lbis()) { ConnectWithLbi(src_task_node, dst_task_node, lbi); } + } } + return Maybe::Ok(); } +Maybe RankTaskGraph::ConnectCtrlEdges(const OpNode* src, const OpNode* dst, int64_t rank) { + if ((ContainRank(src, rank) && ContainRank(dst, rank))) { + auto* src_task_node = CHECK_JUST(TryGetRankCompTaskNode(src, rank)); + auto* dst_task_node = CHECK_JUST(TryGetRankCompTaskNode(dst, rank)); + if (src->op().op_conf().has_src_subset_tick_conf()) { + UNIMPLEMENTED_THEN_RETURN() << "ctrl edge from src_subset_tick is not supported."; + } else if (dst->op().op_conf().has_dst_subset_tick_conf()) { + UNIMPLEMENTED_THEN_RETURN() << "ctrl edge to dst_subset_tick is not supported."; + } else { + ConnectCtrlEdge(CHECK_NOTNULL(src_task_node), CHECK_NOTNULL(dst_task_node)); + } + } + return Maybe::Ok(); +} + +bool RankTaskGraph::IsDutyRank(const ParallelDesc& parallel_desc, int64_t rank) const { + return IsComputTaskNodeDutyRank(current_rank_, parallel_desc, rank); +} + +template +Maybe RankTaskGraph::DoRankDuty(const ParallelDesc& parallel_desc, + const DoEachRankT& DoWithRank) { + if (current_rank_ == 0) { + // make sure master knows at least one op_node. 
+    JUST(DoWithRank(JUST(parallel_desc.MachineId4ParallelId(0))));
+  } else if (parallel_desc.HasMachineId(current_rank_)) {
+    // workers only care about their own rank.
+    JUST(DoWithRank(current_rank_));
+  } else {
+    // Do nothing.
+  }
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> RankTaskGraph::InitRegstDescsConsumers() {
+  const auto& RegstDesc4Id = [&](int64_t regst_desc_id) -> Maybe<RegstDesc> {
+    return JUST(task_graph_rebuild_ctx_->RegstDesc4Id(regst_desc_id));
+  };
+  JUST(MaybeForEachNode([&](TaskNode* task_node) -> Maybe<void> {
+    const auto& task_proto = *JUST(MapAt(task_id2task_proto_, task_node->task_id()));
+    JUST(task_node->InitConsumedRegstsFromProto(task_proto, RegstDesc4Id));
+    return Maybe<void>::Ok();
+  }));
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> RankTaskGraph::Init(const HashSet<std::string>& var_op_names) {
+  JUST(AddBoxingReletedCompTaskNodesFromProto());
+  JUST(CreateAndPartiallyInitTransportTaskNodesFromProto());
+  JUST(AddTransportTaskEdgesFromProto());
+  JUST(InitTransportTaskNodesFromProto());
+  JUST(InitRegstDescsConsumers());
+  // Note that the tasks added above all come from BoxingTaskGraph, so they are all
+  // boxing related.
+  OpGraph* op_graph = Singleton<OpGraph>::Get();
+  JUST(op_graph->MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {
+    JUST(DoRankDuty(op_node->parallel_desc(), [&](int64_t rank) -> Maybe<void> {
+      JUST(CreateOrFindRankCompTaskNodeByRank(op_node, rank));
+      return Maybe<void>::Ok();
+    }));
+    if (var_op_names.count(op_node->op().op_name()) > 0
+        && !IsDutyRank(op_node->parallel_desc(), current_rank_)) {
+      // To make sure all ranks know all var_op_names, at least one task for each variable op
+      // is needed in the plan.
+      JUST(CreateOrFindRankCompTaskNodeByParallelId(op_node, /*parallel_id=*/0));
+    }
+    return Maybe<void>::Ok();
+  }));
+
+  JUST(op_graph->MaybeForEachEdge([&](const OpEdge* op_edge) -> Maybe<void> {
+    return DoRankDuty(op_edge->src_node()->parallel_desc(),
+                      [&](int64_t rank) { return ConnectDataEdges(op_edge, rank); });
+  }));
+
+  ForEachOpGraphNecessaryCtrlEdge(op_graph, [&](const OpNode* src, const OpNode* dst) {
+    if (!src->parallel_desc_sym()->EqualsIgnoringHierarchy(*dst->parallel_desc_sym())) {
+      LOG(INFO) << " src " << src->parallel_desc_sym()->data().DebugString() << " dst "
+                << dst->parallel_desc_sym()->data().DebugString();
+      return;
+    }
+    CHECK_JUST(DoRankDuty(src->parallel_desc(),
+                          [&](int64_t rank) { return ConnectCtrlEdges(src, dst, rank); }));
+  });
+
+  if (Singleton<ResourceDesc, ForSession>::Get()->enable_debug_mode()) { ToDotWithAutoFilePath(); }
+
+  ForEachNode([&](TaskNode* task_node) { task_node->ProduceAllRegstsAndBindEdges(); });
+  ForEachEdge([&](TaskEdge* edge) {
+    CHECK(edge->HasRegst()) << "Found an edge without a bound regst, src task "
+                            << edge->src_node()->VisualStr();
+  });
+  return Maybe<void>::Ok();
+}
+
+RankTaskGraph::~RankTaskGraph() {}
+
 }  // namespace oneflow
diff --git a/oneflow/core/graph/task_graph.h b/oneflow/core/graph/task_graph.h
index 676ab8ed7ad..3002d04d2c7 100644
--- a/oneflow/core/graph/task_graph.h
+++ b/oneflow/core/graph/task_graph.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef ONEFLOW_CORE_GRAPH_TASK_GRAPH_H_ #define ONEFLOW_CORE_GRAPH_TASK_GRAPH_H_ +#include "oneflow/core/graph/task_node.h" #include "oneflow/core/job/id_manager.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/operator/operator.h" @@ -38,12 +39,10 @@ class HierarchicalSubTskGphBuilder; class TaskGraph; using BldSubTskGphMthd = void(TaskGraph::*) BLD_SUB_TSK_GPH_MTHD_ARGS(); -class TaskGraph final : public Graph { +class TaskGraph : public Graph { public: OF_DISALLOW_COPY_AND_MOVE(TaskGraph); - ~TaskGraph() override; - - explicit TaskGraph(); + virtual ~TaskGraph() override; const char* TypeName() const override { return "TaskGraph"; } void RemoveEmptyRegsts(); @@ -53,6 +52,10 @@ class TaskGraph final : public Graph { void EnableInplaceMemSharing(const std::function& IsOpNameDataOrCtrlReachable); + void EnableInplaceMemSharing(const HashSet& dev_nodes, + const std::function& + IsOpNameDataOrCtrlReachable); + TaskNode* GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi, const MemZoneId& dst_mem_zone_id); @@ -75,13 +78,20 @@ class TaskGraph final : public Graph { DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByDstSubsetConnect); DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphNormalForwardToDecodeH2D); - private: + void ForEachGpuDeviceNodes( + const std::function& dev_nodes)>& Handler) const; + + protected: + explicit TaskGraph(); + void BuildTaskPath(TaskNode* src_node, TaskNode* dst_node, const LogicalBlobId& lbi, bool is_host_mem_input); void ConnectCtrlEdges(const std::vector& src_task_nodes, const std::vector& dst_task_nodes); + void ConnectCtrlEdge(CompTaskNode* src_task_node, CompTaskNode* dst_task_node); + void InitOrderedTaskNodes(); void MergeChainByPhysicalTaskGraph(); void MergeChainByLogicalChainId(); @@ -97,9 +107,6 @@ class TaskGraph final : public Graph { IsOpNameDataOrCtrlReachable) const; void SetTaskRegstInplaceInfo(const InplaceObasInfo& obas_info, const HashSet& dev_nodes) const; - void ForEachGpuDeviceNodes( - const std::function& dev_nodes)>& Handler) const; - std::vector ordered_task_nodes_; HashMap> device_type2sub_tsk_gph_builder_; @@ -130,6 +137,97 @@ class TaskGraph final : public Graph { HashMap proxy2node; }; +class GlobalTaskGraph final : public TaskGraph { + public: + OF_DISALLOW_COPY_AND_MOVE(GlobalTaskGraph); + ~GlobalTaskGraph() = default; + static Maybe New() { + std::shared_ptr graph(new GlobalTaskGraph()); + JUST(graph->Init()); + return graph; + } + + private: + GlobalTaskGraph() = default; + Maybe Init(); +}; + +class BoxingTaskGraphProto; + +class BoxingTaskGraph final : public TaskGraph { + public: + OF_DISALLOW_COPY_AND_MOVE(BoxingTaskGraph); + ~BoxingTaskGraph() = default; + + static Maybe New( + const std::function&)>& ParallelRunLoop) { + std::shared_ptr graph(new BoxingTaskGraph()); + JUST(graph->Init(ParallelRunLoop)); + return graph; + } + + void ToProto(const std::function& Pick, BoxingTaskGraphProto* proto) const; + static bool SelectTaskNodeByRank(TaskNode*, int64_t rank); + + private: + BoxingTaskGraph() = default; + Maybe Init( + const std::function&)>& ParallelRunLoop); + + void CreateOpNode2TaskIds( + const std::function&)>& ParallelRunLoop); + + HashMap> boxing_related_op_node2sorted_comp_tasks_; + HashMap> boxing_unrelated_op_node2sorted_task_ids_; +}; + +class TaskGraphRebuildCtx; + +class RankTaskGraph final : public TaskGraph { + public: + OF_DISALLOW_COPY_AND_MOVE(RankTaskGraph); + ~RankTaskGraph(); + + static Maybe New( + const std::shared_ptr& boxing_task_graph_proto, + const HashSet& var_op_names, 
int64_t current_rank) { + std::shared_ptr graph(new RankTaskGraph(boxing_task_graph_proto, current_rank)); + JUST(graph->Init(var_op_names)); + return graph; + } + + // Is `rank` my duty. + bool IsDutyRank(const ParallelDesc& parallel_desc, int64_t rank) const; + + private: + RankTaskGraph(const std::shared_ptr& boxing_task_graph_proto, int64_t rank); + + Maybe Init(const HashSet& var_op_names); + bool ContainRank(const OpNode* op_node, int64_t rank) const; + Maybe AddBoxingReletedCompTaskNodesFromProto(); + Maybe CreateAndPartiallyInitTransportTaskNodesFromProto(); + Maybe AddTransportTaskEdgesFromProto(); + Maybe InitTransportTaskNodesFromProto(); + Maybe InitRegstDescsConsumers(); + template + Maybe DoRankDuty(const ParallelDesc& parallel_desc, const DoEachRankT& DoWithRank); + + Maybe TryGetBoxingRelatedComTaskNode(const OpNode* op_node, int64_t parallel_id); + Maybe CreateOrFindRankCompTaskNodeByParallelId(const OpNode* op_node, + int64_t parallel_id); + Maybe CreateOrFindRankCompTaskNodeByRank(const OpNode* op_node, int64_t rank); + Maybe TryGetRankCompTaskNode(const OpNode* op_node, int64_t rank); + + Maybe ConnectDataEdges(const OpEdge* op_edge, int64_t rank); + Maybe ConnectCtrlEdges(const OpNode* src, const OpNode* dst, int64_t rank); + + std::shared_ptr boxing_task_graph_proto_; + HashMap task_id2task_proto_; + const int64_t current_rank_; + std::unique_ptr task_graph_rebuild_ctx_; + HashMap op_node2comp_task_node_; +}; + using CreateSubTskGphBuilderFn = std::function()>; Maybe RegisterCreateSubTskGphBuilderFn(DeviceType device_type, diff --git a/oneflow/core/graph/task_node.cpp b/oneflow/core/graph/task_node.cpp index affa2818481..ae1df0aac73 100644 --- a/oneflow/core/graph/task_node.cpp +++ b/oneflow/core/graph/task_node.cpp @@ -393,6 +393,13 @@ void TaskNode::UpdateTaskId() { task_id_ = EncodeTaskIdToInt64(*new_task_id_); } +void TaskNode::update_new_task_id(const TaskId& task_id) { + CHECK(static_cast(new_task_id_)); + CHECK(new_task_id_->stream_id() == task_id.stream_id()); + *new_task_id_ = task_id; + task_id_ = EncodeTaskIdToInt64(*new_task_id_); +} + void TaskNode::EraseConsumedRegstsByName(const std::string& name) { if (consumed_regsts_.find(name) != consumed_regsts_.end()) { for (auto& regst : consumed_regsts_[name]) { regst->DeleteConsumer(this); } diff --git a/oneflow/core/graph/task_node.h b/oneflow/core/graph/task_node.h index 81937610ae1..4c8dc28378e 100644 --- a/oneflow/core/graph/task_node.h +++ b/oneflow/core/graph/task_node.h @@ -124,6 +124,11 @@ class TaskNode : public Node { TaskEdge* SoleOutDataEdge() const; size_t in_data_edges_size() const; size_t out_data_edges_size() const; + const TaskId& new_task_id() const { + CHECK(has_new_task_id()); + return *new_task_id_; + } + void update_new_task_id(const TaskId& task_id); bool has_new_task_id() const { return static_cast(new_task_id_); } protected: diff --git a/oneflow/core/job/compiler.cpp b/oneflow/core/job/compiler.cpp index 515ed921837..a3ddfeeb9c6 100644 --- a/oneflow/core/job/compiler.cpp +++ b/oneflow/core/job/compiler.cpp @@ -27,26 +27,6 @@ limitations under the License. 
namespace oneflow { -void CreateOpAttributeRef(Plan* plan, int64_t job_id, TaskProto* task_proto) { - auto* job_id2op_attribute_ref_table = plan->mutable_job_id2op_attribute_ref_table(); - CHECK(task_proto->exec_sequence().exec_node_size() == 1); - auto* exec_node = task_proto->mutable_exec_sequence()->mutable_exec_node(0); - CHECK(exec_node->kernel_conf().has_op_attribute()); - const std::string op_name = exec_node->kernel_conf().op_attribute().op_conf().name(); - auto* op_name2op_attribute = - (*job_id2op_attribute_ref_table)[job_id].mutable_op_name2op_attribute(); - auto find_it = op_name2op_attribute->find(op_name); - if (find_it == op_name2op_attribute->end()) { - op_name2op_attribute->insert( - {op_name, task_proto->exec_sequence().exec_node(0).kernel_conf().op_attribute()}); - } - auto* kernel_conf = - task_proto->mutable_exec_sequence()->mutable_exec_node(0)->mutable_kernel_conf(); - kernel_conf->set_op_attribute_ref(op_name); - // NOTE(levi): memory of op_attribute_ is released here. - kernel_conf->set_allocated_op_attribute(nullptr); -} - void Compiler::Compile(Job* job, Plan* plan) const { const auto& job_name = job->job_conf().job_name(); auto compile_tc = std::make_unique>(true, true); @@ -57,7 +37,7 @@ void Compiler::Compile(Job* job, Plan* plan) const { // Step2: build task_gph. // TODO(levi): we can rewrite this part of code in visitor pattern. - auto task_gph = std::make_unique(); + auto task_gph = CHECK_JUST(GlobalTaskGraph::New()); using std::placeholders::_1; LazyMode::Guard guard(true); task_gph->ForEachNode(std::bind(&TaskNode::ProduceAllRegstsAndBindEdges, _1)); @@ -89,7 +69,7 @@ void Compiler::Compile(Job* job, Plan* plan) const { std::unique_lock guard(mtx); if (task_node->GetTaskType() == kNormalForward || task_node->GetTaskType() == kRepeat || task_node->GetTaskType() == kAcc) { - CreateOpAttributeRef(plan, job_desc.job_id(), &task_proto); + PlanUtil::CreateOpAttributeRef(plan, job_desc.job_id(), &task_proto); } plan->mutable_task()->Add(std::move(task_proto)); } // guard(mtx) diff --git a/oneflow/core/job/parallel_desc.cpp b/oneflow/core/job/parallel_desc.cpp index 50a932a66bc..af1ea80bc4b 100644 --- a/oneflow/core/job/parallel_desc.cpp +++ b/oneflow/core/job/parallel_desc.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #include #include "oneflow/core/job/parallel_desc.h" +#include "oneflow/core/common/container_util.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/cpp_attribute.h" @@ -195,6 +196,14 @@ bool ParallelDesc::TryGetParallelId(int64_t machine_id, int64_t device_id, return true; } +Maybe ParallelDesc::TryGetParallelId(int64_t rank, int64_t* parallel_id) const { + if (!HasMachineId(rank)) { return false; } + const auto& device_ids = sorted_dev_phy_ids(rank); + CHECK_EQ_OR_RETURN(device_ids.size(), 1) << "only sole device_id supported. 
parallel_conf: \n" + << parallel_conf().DebugString(); + return TryGetParallelId(rank, JUST(VectorAt(device_ids, 0)), parallel_id); +} + Maybe ParallelDesc::GetParallelContext(ParallelContext* parallel_ctx, int64_t machine_id, int64_t device_id) const { parallel_ctx->set_parallel_num(parallel_num()); diff --git a/oneflow/core/job/parallel_desc.h b/oneflow/core/job/parallel_desc.h index a15deaa6e25..dcca6c48e52 100644 --- a/oneflow/core/job/parallel_desc.h +++ b/oneflow/core/job/parallel_desc.h @@ -103,6 +103,8 @@ class ParallelDesc final { bool ContainingMachineId(int64_t machine_id) const; bool TryGetParallelId(int64_t machine_id, int64_t device_id, int64_t* parallel_id) const; + Maybe TryGetParallelId(int64_t rank, int64_t* parallel_id) const; + Maybe CheckDeviceIdsIsValid() const; private: diff --git a/oneflow/core/job/plan_util.cpp b/oneflow/core/job/plan_util.cpp index 2a14e8620c5..c4ac9d6e311 100644 --- a/oneflow/core/job/plan_util.cpp +++ b/oneflow/core/job/plan_util.cpp @@ -237,7 +237,7 @@ void GenChunkForMultiNNGraphMemoryReuseInMultiClient( } // namespace -void PlanUtil::MergeMemBlockIdByLogicalChainId(Plan* plan, const Job& job) { +void PlanUtil::MergeMemBlockIdByLogicalChainId(Plan* plan, const Job& job, int64_t limited_rank) { if (job.logical_chain_groups_size() == 0) { return; } HashMap> logical_chain_id2machine_id2mem_block_id; @@ -271,8 +271,16 @@ void PlanUtil::MergeMemBlockIdByLogicalChainId(Plan* plan, const Job& job) { for (const auto& logical_chain_group : job.logical_chain_groups()) { CHECK_GE(logical_chain_group.logical_chain_id_list_size(), 2); int64_t merged_logical_chain_id = logical_chain_group.logical_chain_id_list(0); - CHECK(logical_chain_id2machine_id2mem_block_id.find(merged_logical_chain_id) - != logical_chain_id2machine_id2mem_block_id.end()); + if (limited_rank == -1) { + CHECK(logical_chain_id2machine_id2mem_block_id.find(merged_logical_chain_id) + != logical_chain_id2machine_id2mem_block_id.end()); + } else { + if (logical_chain_id2machine_id2mem_block_id.find(merged_logical_chain_id) + == logical_chain_id2machine_id2mem_block_id.end()) { + // Skip when doing rank compile and this logical chain group is not related to this rank. 
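+          // e.g. (illustrative): with limited_rank = 2, a logical chain group that lives only
+          // on ranks {0, 1} contributes no mem blocks to this plan, so it is skipped here.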
+        continue;
+      }
+    }
     const auto& merged_rank2block =
         logical_chain_id2machine_id2mem_block_id.at(merged_logical_chain_id);
     for (int64_t i = 1; i < logical_chain_group.logical_chain_id_list_size(); ++i) {
@@ -285,7 +293,12 @@ void PlanUtil::MergeMemBlockIdByLogicalChainId(Plan* plan, const Job& job) {
       for (const auto& pair : this_rank2block) {
         int64_t this_machine_id = pair.first;
         int64_t this_mem_block_id = pair.second;
-        CHECK(merged_rank2block.find(this_machine_id) != merged_rank2block.end());
+        if (limited_rank == -1) {
+          CHECK(merged_rank2block.find(this_machine_id) != merged_rank2block.end());
+        } else {
+          if (merged_rank2block.find(this_machine_id) == merged_rank2block.end()) { continue; }
+        }
+
         int64_t merged_mem_block_id = merged_rank2block.at(this_machine_id);
         CHECK(mem_block_id2merged_mem_block_id.emplace(this_mem_block_id, merged_mem_block_id)
                   .second);
@@ -313,12 +326,14 @@
       // merge mem_block_id
       int64_t merged_mem_block_id = mem_block_id2merged_mem_block_id.at(mem_block_id);
       regst_desc->set_mem_block_id(merged_mem_block_id);
-      const auto& data_regst = regst_desc->regst_desc_type().data_regst_desc();
-      CHECK_GE(data_regst.lbi2blob_desc_size(), 1);
-      const auto& lbi2blob_desc_pair = data_regst.lbi2blob_desc(0);
-      std::string tensor_name = GenLogicalBlobName(lbi2blob_desc_pair.lbi());
-      VLOG(3) << " regst: " << tensor_name << " merge mem block id " << mem_block_id << " to "
-              << merged_mem_block_id;
+      if (VLOG_IS_ON(3)) {
+        const auto& data_regst = regst_desc->regst_desc_type().data_regst_desc();
+        CHECK_GE(data_regst.lbi2blob_desc_size(), 1);
+        const auto& lbi2blob_desc_pair = data_regst.lbi2blob_desc(0);
+        std::string tensor_name = GenLogicalBlobName(lbi2blob_desc_pair.lbi());
+        VLOG(3) << " regst: " << tensor_name << " merge mem block id " << mem_block_id << " to "
+                << merged_mem_block_id;
+      }
     }
   }
 }
@@ -801,10 +816,13 @@ std::function<RegstDescProto*(int64_t)> PlanUtil::MakeMutRegstDesc4Id(Plan* plan
   };
 }
 
-void PlanUtil::SetForceInplaceMemBlock(Plan* plan) {
+void PlanUtil::SetForceInplaceMemBlock(Plan* plan, int64_t limited_rank) {
   auto RegstDesc4Id = MakeMutRegstDesc4Id(plan);
   for (int i = 0; i < plan->task_size(); i++) {
    TaskProto* task = plan->mutable_task(i);
+    // When doing separation compilation, some ranks' plans (such as rank 0's) contain other
+    // ranks' task nodes for compilation purposes. There is no need to set mem blocks for
+    // those task nodes.
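+    // e.g. (illustrative): rank 0's plan may carry one variable task of another rank only so
+    // that all op names are known on rank 0; such a task never allocates memory here.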
+ if (limited_rank >= 0 && task->machine_id() != limited_rank) { continue; } for (auto& pair : *task->mutable_produced_regst_desc()) { RegstDescProto* regst_desc = &pair.second; if (regst_desc->has_force_inplace_consumed_regst_desc_id()) { @@ -1381,4 +1399,24 @@ void PlanUtil::PopulateOpAttribute( return GetStreamId(task).device_id().device_index(); } +/*static*/ void PlanUtil::CreateOpAttributeRef(Plan* plan, int64_t job_id, TaskProto* task_proto) { + auto* job_id2op_attribute_ref_table = plan->mutable_job_id2op_attribute_ref_table(); + CHECK(task_proto->exec_sequence().exec_node_size() == 1); + auto* exec_node = task_proto->mutable_exec_sequence()->mutable_exec_node(0); + CHECK(exec_node->kernel_conf().has_op_attribute()); + const std::string op_name = exec_node->kernel_conf().op_attribute().op_conf().name(); + auto* op_name2op_attribute = + (*job_id2op_attribute_ref_table)[job_id].mutable_op_name2op_attribute(); + auto find_it = op_name2op_attribute->find(op_name); + if (find_it == op_name2op_attribute->end()) { + op_name2op_attribute->insert( + {op_name, task_proto->exec_sequence().exec_node(0).kernel_conf().op_attribute()}); + } + auto* kernel_conf = + task_proto->mutable_exec_sequence()->mutable_exec_node(0)->mutable_kernel_conf(); + kernel_conf->set_op_attribute_ref(op_name); + // NOTE(levi): memory of op_attribute_ is released here. + kernel_conf->set_allocated_op_attribute(nullptr); +} + } // namespace oneflow diff --git a/oneflow/core/job/plan_util.h b/oneflow/core/job/plan_util.h index 9de7142a629..c8601ae6682 100644 --- a/oneflow/core/job/plan_util.h +++ b/oneflow/core/job/plan_util.h @@ -28,7 +28,10 @@ namespace oneflow { struct PlanUtil { static RegstDescProto* GetSoleProducedDataRegst(TaskProto* task_proto); static std::function MakeGetterTaskProto4TaskId(const Plan& plan); - static void MergeMemBlockIdByLogicalChainId(Plan* plan, const Job& job); + // limited_rank equals -1 means taking care of all ranks. + // Otherwise, only take care of rank limited_rank. + static void MergeMemBlockIdByLogicalChainId(Plan* plan, const Job& job, + int64_t limited_rank = -1); static void SetUniqueMemBlockId4UnreusedMemRegst(Plan* plan); static void GenMemBlockAndChunk4Plan(Plan* plan); static void GenMemBlockAndChunkWithVariableOpNames4Plan( @@ -36,7 +39,9 @@ struct PlanUtil { static void CleanUselessMemBlockAndCheckValid(Plan* plan); static void ToDotFile(const Plan& plan, const std::string& filepath); static std::function MakeMutRegstDesc4Id(Plan* plan); - static void SetForceInplaceMemBlock(Plan* plan); + // limited_rank equals -1 means taking care of all ranks. + // Otherwise, only take care of rank limited_rank. + static void SetForceInplaceMemBlock(Plan* plan, int64_t limited_rank = -1); static void DumpCtrlRegstInfoToPlan(Plan* plan); static void GenCollectiveBoxingPlan(Job* job, Plan* plan); static void GenRegisterHint(Plan* plan); @@ -50,6 +55,7 @@ struct PlanUtil { const PbMap& job_id2op_attribute_ref_table); static StreamId GetStreamId(const TaskProto& task); static int64_t GetDeviceIndex(const TaskProto& task); + static void CreateOpAttributeRef(Plan* plan, int64_t job_id, TaskProto* task_proto); }; } // namespace oneflow diff --git a/oneflow/core/job/rank_compiler.cpp b/oneflow/core/job/rank_compiler.cpp new file mode 100644 index 00000000000..fa9af000ce6 --- /dev/null +++ b/oneflow/core/job/rank_compiler.cpp @@ -0,0 +1,110 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. 
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/job/rank_compiler.h"
+#include "oneflow/core/device/cuda_util.h"
+#include "oneflow/core/job/global_for.h"
+#include "oneflow/core/job/intra_job_mem_sharing_util.h"
+#include "oneflow/core/job/plan_util.h"
+#include "oneflow/core/persistence/tee_persistent_log_stream.h"
+#include "oneflow/core/graph/op_graph.h"
+#include "oneflow/core/job_rewriter/job_completer.h"
+#include "oneflow/core/thread/thread_pool.h"
+#include "oneflow/core/common/blocking_counter.h"
+#include "oneflow/core/rpc/include/global_process_ctx.h"
+
+namespace oneflow {
+
+Maybe<void> RankCompiler::Compile(const HashSet<std::string>& var_op_names, Job* job,
+                                  Plan* plan) const {
+#ifdef WITH_CUDA
+  // Use the right device when plan compilation needs CUDA, to avoid creating an unnecessary CUDA
+  // context on cuda:0.
+  CudaCurrentDeviceGuard guard(GetCudaDeviceIndex());
+#endif  // WITH_CUDA
+  auto task_gph = JUST(RankTaskGraph::New(boxing_task_graph_proto_, var_op_names, rank_));
+  using std::placeholders::_1;
+  const auto& IsNotMyDuty = [&](const CompTaskNode* comp_task_node) {
+    if (comp_task_node == nullptr) { return false; }
+    const auto& parallel_desc = comp_task_node->op_node()->parallel_desc();
+    return !task_gph->IsDutyRank(parallel_desc, comp_task_node->machine_id());
+  };
+  task_gph->ForEachNode([&](TaskNode* task_node) {
+    auto* comp_task_node = dynamic_cast<CompTaskNode*>(task_node);
+    if (IsNotMyDuty(comp_task_node)) {
+      auto* fake_consumed_regsts_provider =
+          dynamic_cast<FakeConsumedRegstsProvider*>(comp_task_node);
+      CHECK_NOTNULL(fake_consumed_regsts_provider)->ConsumeFakeRegstsIf();
+    } else {
+      task_node->ConsumeAllRegsts();
+    }
+  });
+  task_gph->ForEachNode([&](TaskNode* task_node) {
+    auto* comp_task_node = dynamic_cast<CompTaskNode*>(task_node);
+    if (IsNotMyDuty(comp_task_node)) {
+      // Do nothing, because all consumed regsts are fake.
+    } else {
+      task_node->PinConsumedRegst();
+    }
+  });
+  task_gph->TopoForEachNode(&TaskNode::Build);
+  task_gph->RemoveEmptyRegsts();
+  task_gph->TopoForEachNode(&TaskNode::InferTimeShapeIfMeaningful);
+  task_gph->DecideExecutionOrder();
+  task_gph->MergeChainAndAddOrderingCtrlEdgeInSameChain();
+  auto IsReachable = Singleton<OpGraph>::Get()->MakePredicatorIsOpNameDataOrCtrlReachable();
+  const JobDesc& job_desc = GlobalJobDesc();
+  if (job_desc.enable_inplace()) {
+    task_gph->ForEachGpuDeviceNodes([&](const HashSet<TaskNode*>& dev_nodes) {
+      if (dev_nodes.empty()) { return; }
+      if ((*dev_nodes.begin())->machine_id() != rank_) { return; }  // other ranks are ignored.
+      task_gph->EnableInplaceMemSharing(dev_nodes, IsReachable);
+    });
+  }
+  task_gph->ForEachEdge([&](TaskEdge* task_edge) { task_edge->CheckRegstLbiValid(); });
+
+  // Put information from task_gph into the plan.
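+  // For each meaningful task node below: erase the fake regsts of comp task nodes that are not
+  // this rank's duty, serialize the node into a TaskProto, deduplicate heavyweight op attributes
+  // via CreateOpAttributeRef, and append the proto to this rank's plan.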
+  task_gph->ForEachNode([&](TaskNode* task_node) {
+    if (task_node->IsMeaningLess()) { return; }
+    auto* comp_task_node = dynamic_cast<CompTaskNode*>(task_node);
+    if (comp_task_node != nullptr) {
+      const auto& parallel_desc = comp_task_node->op_node()->parallel_desc();
+      if (!task_gph->IsDutyRank(parallel_desc, task_node->machine_id())) {
+        auto* fake_consumed_regsts_provider =
+            dynamic_cast<FakeConsumedRegstsProvider*>(comp_task_node);
+        CHECK_NOTNULL(fake_consumed_regsts_provider)->EraseFakeRegstsIf();
+      }
+    }
+    TaskProto task_proto;
+    task_node->ToProto(&task_proto);
+    if (task_node->GetTaskType() == kNormalForward || task_node->GetTaskType() == kRepeat
+        || task_node->GetTaskType() == kAcc) {
+      PlanUtil::CreateOpAttributeRef(plan, job_desc.job_id(), &task_proto);
+    }
+    plan->mutable_task()->Add(std::move(task_proto));
+  });
+
+  // post-process for plan and delete Singleton<OpGraph>.
+  auto* job_id2job_conf = plan->mutable_job_confs()->mutable_job_id2job_conf();
+  (*job_id2job_conf)[GlobalJobDesc().job_id()] = GlobalJobDesc().job_conf();
+  // NOTE(chengcheng): infer mem block id & set inplace & add ctrl
+  IntraJobMemSharingUtil::InferMemBlockId4MemReusedRegst(plan);
+  PlanUtil::MergeMemBlockIdByLogicalChainId(plan, *job, rank_);
+  PlanUtil::SetUniqueMemBlockId4UnreusedMemRegst(plan);
+  PlanUtil::SetForceInplaceMemBlock(plan, rank_);
+  return Maybe<void>::Ok();
+}
+
+}  // namespace oneflow
diff --git a/oneflow/core/job/rank_compiler.h b/oneflow/core/job/rank_compiler.h
new file mode 100644
index 00000000000..f00fe77fbeb
--- /dev/null
+++ b/oneflow/core/job/rank_compiler.h
@@ -0,0 +1,43 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#ifndef ONEFLOW_CORE_JOB_RANK_COMPILER_H_
+#define ONEFLOW_CORE_JOB_RANK_COMPILER_H_
+
+#include "oneflow/core/common/protobuf.h"
+#include "oneflow/core/graph/task_graph.h"
+#include "oneflow/core/graph/boxing_task_graph.pb.h"
+#include "oneflow/core/job/plan.pb.h"
+#include "oneflow/core/operator/operator.h"
+
+namespace oneflow {
+
+class RankCompiler final {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(RankCompiler);
+  RankCompiler(const std::shared_ptr<BoxingTaskGraphProto>& boxing_task_graph_proto, int64_t rank)
+      : boxing_task_graph_proto_(boxing_task_graph_proto), rank_(rank) {}
+  ~RankCompiler() = default;
+
+  Maybe<void> Compile(const HashSet<std::string>& var_op_names, Job* job, Plan* plan) const;
+
+ private:
+  std::shared_ptr<BoxingTaskGraphProto> boxing_task_graph_proto_;
+  int64_t rank_;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_JOB_RANK_COMPILER_H_
diff --git a/oneflow/core/operator/op_conf.proto b/oneflow/core/operator/op_conf.proto
index fdc58cadc9c..e587fa3ef5b 100644
--- a/oneflow/core/operator/op_conf.proto
+++ b/oneflow/core/operator/op_conf.proto
@@ -399,6 +399,7 @@ message OperatorConf {
   optional string loc = 11 [default = ""];
   optional int64 logical_chain_id = 12 [default = -1];
   optional int64 order_in_logical_chain = 13 [default = -1];
+  optional string calculation_pass_name = 14 [default = "forward_pass"];
   oneof op_type {
     // system op
     CopyCommNetOpConf copy_comm_net_conf = 106;
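
A minimal usage sketch of the new RankCompiler interface, assuming a hypothetical driver function
CompilePlanForCurrentRank (only the RankCompiler class itself comes from this patch; the actual
call site is the compiler wiring elsewhere in this PR):

#include "oneflow/core/job/rank_compiler.h"
#include "oneflow/core/rpc/include/global_process_ctx.h"

namespace oneflow {

// Hypothetical driver: each process compiles only the plan slice owned by its own rank.
// The rank passed to the constructor selects which tasks are this process's duty; comp task
// nodes of other ranks only consume fake regsts during compilation.
Maybe<void> CompilePlanForCurrentRank(
    const std::shared_ptr<BoxingTaskGraphProto>& boxing_task_graph_proto,
    const HashSet<std::string>& var_op_names, Job* job, Plan* plan) {
  RankCompiler compiler(boxing_task_graph_proto, GlobalProcessCtx::Rank());
  JUST(compiler.Compile(var_op_names, job, plan));  // fills `plan` with this rank's tasks
  return Maybe<void>::Ok();
}

}  // namespace oneflow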