Feat graph load to new device (#10335)
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
strint and oneflow-ci-bot authored Oct 6, 2023
1 parent 84ef72f commit dea3f43
Showing 6 changed files with 271 additions and 35 deletions.
4 changes: 2 additions & 2 deletions oneflow/core/graph/stream_id.cpp
@@ -20,9 +20,9 @@ namespace oneflow {
 
 // StreamId encoding (bits)
 // | reserved | node_index | device_type | device_index | stream_index |
-// | -- 21 -- | ----- 19 ----- | ---- 5 ---- | ----- 7 ----- |         |
+// | -- 18 -- | ----- 19 ----- | ---- 5 ---- | ----- 7 ----- |         |
 // |          |              DeviceId                        |         |
-// |          | ------------------- 31 --------------------- | ---- 12 ---- |
+// |          | ------------------- 31 --------------------- | ---- 15 ---- |
 // |                          StreamId                                      |
 // | -------------------------------- 64 ---------------------------------- |
4 changes: 2 additions & 2 deletions oneflow/core/graph/task_id.cpp
@@ -20,9 +20,9 @@ namespace oneflow {
 
 // TaskId encoding (maybe extended to 128 bits in future)
 // |            rank            | device_type | device_index |              |
-// | ----------- 19 ----------- | ---- 5 ---- | ----- 7 ----- |             |
+// | ----------- 16 ----------- | ---- 5 ---- | ----- 7 ----- |             |
 // |                  DeviceId                 | stream_index |              |
-// | ------------------------- 31 --------------------------- | ---- 12 ---- |
+// | ------------------------- 31 --------------------------- | ---- 15 ---- |
 // |                          StreamId                         | task_index  |
 // | -------------------------------- 43 ----------------------------------- | --- 21 --- |
 // |                                   TaskId                                             |
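The relocation helpers added in python/oneflow/nn/graph/util.py below operate directly on this layout. As a minimal sketch (Python, illustrative only, not the C++ implementation), the new 16/5/7/15/21 TaskId split decodes like this, counting fields from the low bit:

```python
# Illustrative decode of the new 64-bit TaskId layout:
# rank(16) | device_type(5) | device_index(7) | stream_index(15) | task_index(21)
def decode_task_id(task_id: int) -> dict:
    return {
        "task_index": task_id & ((1 << 21) - 1),            # bits 0-20
        "stream_index": (task_id >> 21) & ((1 << 15) - 1),  # bits 21-35
        "device_index": (task_id >> 36) & ((1 << 7) - 1),   # bits 36-42
        "device_type": (task_id >> 43) & ((1 << 5) - 1),    # bits 43-47
        "rank": (task_id >> 48) & ((1 << 16) - 1),          # bits 48-63
    }
```

This is why `_task_id_to` in util.py tests the device_type field starting at bit 43 and rewrites the device_index field starting at bit 36 when retargeting a plan to another CUDA device.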
24 changes: 23 additions & 1 deletion python/oneflow/nn/graph/cache.py
@@ -20,7 +20,7 @@
 
 from oneflow.framework.args_tree import ArgsTree
 from oneflow.framework.tensor import Tensor
-import oneflow as flow
+import oneflow
 
 
 class LRUCache(object):
@@ -134,6 +134,28 @@ def runtime_state_dict(
         destination[state_dict["graph_name"]] = state_dict
         return destination
 
+    @staticmethod
+    def runtime_state_dict_to(
+        state_dict: Union[
+            Dict[str, Union[Dict[str, Tensor], str]],
+            Dict[str, Dict[str, Union[Dict[str, Tensor], str]]],
+        ],
+        device: str,
+    ) -> Union[
+        Dict[str, Union[Dict[str, Tensor], str]],
+        Dict[str, Dict[str, Union[Dict[str, Tensor], str]]],
+    ]:
+        destination = OrderedDict()
+        destination._metadata = OrderedDict()
+        for (key, sub_state_dict) in state_dict.items():
+            dest_sub_state_dict = oneflow.nn.Graph.runtime_state_dict_to(
+                sub_state_dict, device
+            )
+            dest_sub_state_dict["cache_order"] = sub_state_dict["cache_order"]
+            dest_sub_state_dict["cache_key"] = sub_state_dict["cache_key"]
+            destination[key] = dest_sub_state_dict
+        return destination
+
     def _init_and_get_a_graph_in_cache(self, cache_key):
         self._base_graph._print(
             0,
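A cache-level runtime state dict maps each cache key to a per-graph state dict and carries no top-level `job_id`, so the public `oneflow.nn.Graph.runtime_state_dict_to` (added in graph.py below) dispatches here and rewrites every entry while preserving its `cache_order` and `cache_key`. A minimal usage sketch, assuming `cached_rsd` was produced by a graph with dynamic input caching enabled:

```python
import oneflow

def move_cached_state(cached_rsd, device="cuda:1"):
    # cached_rsd has no top-level "job_id", so Graph.runtime_state_dict_to
    # falls through to GraphCache.runtime_state_dict_to, which relocates each
    # per-shape entry and keeps its cache_order/cache_key bookkeeping intact.
    return oneflow.nn.Graph.runtime_state_dict_to(cached_rsd, device)
```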
78 changes: 65 additions & 13 deletions python/oneflow/nn/graph/graph.py
@@ -52,6 +52,9 @@
     GraphIR,
     seq_to_func_return,
     sys_exc_error_msg,
+    _rsd_sub_destination_to,
+    _job_to,
+    _plan_to,
 )
 from oneflow.framework.args_tree import ArgsTree
 from oneflow.nn.modules.module import Module
@@ -1069,34 +1072,35 @@ def _fill_sub_destination(dest_dict, name_list, tensor_tuple):
             assert len(tensor_tuple) == len(name_list)
             for name_idx in range(len(name_list)):
                 tensor_item = tensor_tuple[name_idx]
-                dest_dict[name_list[name_idx]] = (tensor_item, tensor_item.device.type)
+                device_str = ":".join(
+                    (tensor_item.device.type, str(tensor_item.device.index))
+                )
+                dest_dict[name_list[name_idx]] = (tensor_item, device_str)
 
         # The original outputs are needed to build the output buffer.
         tuple_idx = -1
 
-        def gen_index_in_tuple(eager_out):
+        def gen_index_in_tuple(item):
             nonlocal tuple_idx
-            tuple_idx += 1
-            return "_OFTPI" + str(tuple_idx)  # oneflow tuple index
+            if isinstance(item, Tensor):
+                tuple_idx += 1
+                return "_OFTPI" + str(tuple_idx)  # oneflow tuple index
+            else:
+                return item
 
         inputs_sub_destination = OrderedDict()
         _fill_sub_destination(
             inputs_sub_destination, self._input_op_names, self._inputs_tensor_tuple
         )
 
-        _eager_inputs_args, _eager_inputs_kwargs = self.__map_io(
-            "input",
-            gen_index_in_tuple,
-            *self.inputs_original[0],
-            **self.inputs_original[1],
+        _eager_inputs_args, _eager_inputs_kwargs = self.__map_io_lite(
+            gen_index_in_tuple, *self.inputs_original[0], **self.inputs_original[1],
         )
         destination["inputs"] = inputs_sub_destination
         destination["inputs_original"] = (_eager_inputs_args, _eager_inputs_kwargs)
 
         tuple_idx = -1
-        _eager_outputs, _ = self.__map_io(
-            "output", gen_index_in_tuple, *self._eager_outputs
-        )
+        _eager_outputs, _ = self.__map_io_lite(gen_index_in_tuple, *self._eager_outputs)
         destination["outputs_original"] = _eager_outputs
         assert len(self._outputs_tensor_tuple) == tuple_idx + 1
         outputs_sub_destination = OrderedDict()
@@ -1146,7 +1150,7 @@ def load_runtime_state_dict(
             Dict[str, Dict[str, Union[Dict[str, Tensor], str]]],
         ],
         *,
-        warmup_with_run: bool = False,
+        warmup_with_run: bool = True,
     ) -> None:
         if self._run_with_cache == True:
             return self._dynamic_input_graph_cache.load_runtime_state_dict(
@@ -1293,6 +1297,7 @@ def get_tensor_in_tuple(tensor_tuple, map_item):
             self.__run(
                 *_eager_inputs_args, **_eager_inputs_kwargs
             )  # pre-run to warm up
+            oneflow._oneflow_internal.eager.Sync()
         build_graph_end = time.perf_counter()
         self.__print(
             0,
@@ -1304,6 +1309,53 @@ def get_tensor_in_tuple(tensor_tuple, map_item):
             + "\n",
         )
 
+    @staticmethod
+    def runtime_state_dict_to(
+        state_dict: Union[
+            Dict[str, Union[Dict[str, Tensor], str]],
+            Dict[str, Dict[str, Union[Dict[str, Tensor], str]]],
+        ],
+        device: str,
+    ) -> Union[
+        Dict[str, Union[Dict[str, Tensor], str]],
+        Dict[str, Dict[str, Union[Dict[str, Tensor], str]]],
+    ]:
+        if "job_id" not in state_dict:
+            from oneflow.nn.graph.cache import GraphCache
+
+            return GraphCache.runtime_state_dict_to(state_dict, device)
+
+        dest_device = oneflow.device(device)
+        assert dest_device.type == "cuda", "device must be cuda."
+
+        destination = OrderedDict()
+        destination._metadata = OrderedDict()
+        destination["oneflow_version"] = state_dict["oneflow_version"]
+        destination["graph_name"] = state_dict["graph_name"]
+        destination["job_id"] = state_dict["job_id"]
+        destination["inputs"] = _rsd_sub_destination_to(state_dict["inputs"], device)
+        destination["inputs_original"] = state_dict["inputs_original"]
+        destination["outputs"] = _rsd_sub_destination_to(state_dict["outputs"], device)
+        destination["outputs_original"] = state_dict["outputs_original"]
+        destination["oneflow_with_eager_tensor"] = state_dict[
+            "oneflow_with_eager_tensor"
+        ]
+        if "states" in state_dict:
+            destination["states"] = _rsd_sub_destination_to(
+                state_dict["states"], device
+            )
+        destination["exe_plan"] = _plan_to(state_dict["exe_plan"], dest_device)
+        if "forward_graph" in state_dict:
+            forward_graph = deepcopy(state_dict["forward_graph"])
+            _job_to(forward_graph, dest_device)
+            destination["forward_graph"] = forward_graph
+        if "compile_graph" in state_dict:
+            compile_graph = deepcopy(state_dict["compile_graph"])
+            _job_to(compile_graph, dest_device)
+            destination["compile_graph"] = compile_graph
+        destination["id_state"] = state_dict["id_state"]
+        return destination
+
     def build_graph(self, *args, **kwargs):
         # Build graph
         try:
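Together with `load_runtime_state_dict` (whose `warmup_with_run` default flips to True in this commit), the intended flow is: capture the runtime state on the source device, relocate it, and load it into a fresh graph. A sketch under the assumption that `src_graph` is already compiled and `dst_graph` is a newly constructed graph of the same class:

```python
import oneflow

def relocate_runtime_state(src_graph, dst_graph, device="cuda:1"):
    # Capture the compiled runtime state (plan, tensors, graph protos).
    rsd = src_graph.runtime_state_dict()
    # Rewrite device placements; runtime_state_dict_to asserts the target is cuda.
    moved = oneflow.nn.Graph.runtime_state_dict_to(rsd, device)
    # warmup_with_run now defaults to True, so loading also pre-runs and syncs.
    dst_graph.load_runtime_state_dict(moved)
    return dst_graph
```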
119 changes: 119 additions & 0 deletions python/oneflow/nn/graph/util.py
@@ -16,12 +16,15 @@
 import sys
 from string import Template
 from typing import Callable, Dict, Union, List, Tuple, Optional
+from collections import OrderedDict
 
 from google.protobuf import text_format
 from google.protobuf.message import Message
 
 import oneflow
 import oneflow.core.job.job_pb2 as job_pb
+import oneflow.core.job.plan_pb2 as plan_pb
+import oneflow.core.common.device_type_pb2 as device_type
 import oneflow.core.operator.op_conf_pb2 as op_conf_util
 from oneflow.framework.tensor import Tensor
@@ -308,3 +311,119 @@ def seq_to_func_return(seq, need_unpack=False):
     if need_unpack:
         return seq[0]
     return seq
+
+
+def _rsd_sub_destination_to(origin_dict, dest_device_str):
+    dest_dict = OrderedDict()
+    for k, v in origin_dict.items():
+        tensor_item, device_str = v
+        dest_dict[k] = (
+            tensor_item.to(device=oneflow.device(dest_device_str), copy=True),
+            dest_device_str,
+        )
+    return dest_dict
+
+
+def _parallel_conf_to(parallel_conf, dest_device):
+    if parallel_conf.device_tag == "cuda":
+        assert len(parallel_conf.device_name) == 1
+        parallel_conf.device_name[0] = "@0:" + str(dest_device.index)
+
+
+def _mem_case_to(mem_case, dest_device):
+    if mem_case.device_type == device_type.DeviceType.kCUDA:
+        mem_case.device_id = dest_device.index
+    if (
+        mem_case.HasField("pinned_device_type")
+        and mem_case.pinned_device_type == device_type.DeviceType.kCUDA
+    ):
+        mem_case.pinned_device_id = dest_device.index
+
+
+def _job_to(job, dest_device):
+    for pg in job.placement.placement_group:
+        _parallel_conf_to(pg.parallel_conf, dest_device)
+    for bpg in job.placement.blob_placement_group:
+        _parallel_conf_to(bpg.parallel_conf, dest_device)
+
+
+def _modify_bits(original_num, k, j, new_num):
+    if k > j:
+        return original_num
+    mask = ((1 << (j - k + 1)) - 1) << k
+    cleared_num = original_num & ~mask
+    modified_num = cleared_num | ((new_num & ((1 << (j - k + 1)) - 1)) << k)
+    return modified_num
+
+
+def _get_bits(original_num, k, j):
+    mask = ((1 << (j - k + 1)) - 1) << k
+    cleared_num = (original_num & mask) >> k
+    return cleared_num
+
+
+def _task_id_to(task_id, dest_device):
+    if _get_bits(task_id, 43, 48) == 2:
+        new_id = _modify_bits(task_id, 36, 43, dest_device.index)
+        return new_id
+    else:
+        return task_id
+
+
+def _thrd_id_to(thrd_id, dest_device):
+    if _get_bits(thrd_id, 22, 27) == 2:
+        new_id = _modify_bits(thrd_id, 15, 22, dest_device.index)
+        return new_id
+    else:
+        return thrd_id
+
+
+def _plan_to(plan_str, dest_device):
+    plan = plan_pb.Plan()
+    plan.ParseFromString(plan_str)
+    for task in plan.task:
+        task.task_id = _task_id_to(task.task_id, dest_device)
+        task.thrd_id = _thrd_id_to(task.thrd_id, dest_device)
+        for node in task.exec_sequence.exec_node:
+            _parallel_conf_to(
+                node.kernel_conf.op_attribute.parallel_conf_signature.op_parallel_conf,
+                dest_device,
+            )
+        for name, regst in task.produced_regst_desc.items():
+            regst.producer_task_id = _task_id_to(regst.producer_task_id, dest_device)
+            for c_task_id_idx in range(len(regst.consumer_task_id)):
+                regst.consumer_task_id[c_task_id_idx] = _task_id_to(
+                    regst.consumer_task_id[c_task_id_idx], dest_device
+                )
+            _mem_case_to(regst.mem_case, dest_device)
+    for mem_block in plan.block_chunk_list.mem_block:
+        _mem_case_to(mem_block.mem_case, dest_device)
+        mem_block.thrd_id_hint = _thrd_id_to(mem_block.thrd_id_hint, dest_device)
+    for chunk in plan.block_chunk_list.chunk:
+        _mem_case_to(chunk.mem_case, dest_device)
+
+    new_ctrl_regst_desc_id2producer_task_id = {}
+    for (
+        regst_desc_id,
+        producer_task_id,
+    ) in plan.ctrl_regst_desc_info.ctrl_regst_desc_id2producer_task_id.items():
+        new_ctrl_regst_desc_id2producer_task_id[regst_desc_id] = _task_id_to(
+            producer_task_id, dest_device
+        )
+    for (
+        regst_desc_id,
+        producer_task_id,
+    ) in new_ctrl_regst_desc_id2producer_task_id.items():
+        plan.ctrl_regst_desc_info.ctrl_regst_desc_id2producer_task_id[
+            regst_desc_id
+        ] = producer_task_id
+
+    for job_id, op_attr_tab in plan.job_id2op_attribute_ref_table.items():
+        for _, op_attr in op_attr_tab.op_name2op_attribute.items():
+            _parallel_conf_to(
+                op_attr.parallel_conf_signature.op_parallel_conf, dest_device
+            )
+
+    return plan.SerializeToString()
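The bit helpers are pure functions, so the device rewrite is easy to check in isolation; a worked example with made-up field values (the constants 2, 36-43, and 43-48 match `_task_id_to` above):

```python
# Standalone copies of the helpers above, for a quick sanity check.
def _get_bits(num, k, j):
    return (num & (((1 << (j - k + 1)) - 1) << k)) >> k

def _modify_bits(num, k, j, new_num):
    if k > j:
        return num
    mask = ((1 << (j - k + 1)) - 1) << k
    return (num & ~mask) | ((new_num & ((1 << (j - k + 1)) - 1)) << k)

# task_index=7, stream_index=5, device_index=0, device_type=2 (kCUDA)
task_id = (2 << 43) | (5 << 21) | 7
assert _get_bits(task_id, 43, 48) == 2    # CUDA task, so it gets remapped
moved = _modify_bits(task_id, 36, 43, 3)  # retarget to device_index 3
assert _get_bits(moved, 36, 42) == 3      # device_index now 3
assert _get_bits(moved, 43, 48) == 2      # device_type untouched
```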