Skip to content

Commit

Permalink
Fix dynamo mock error (#10318)
Browse files Browse the repository at this point in the history
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
  • Loading branch information
strint and oneflow-ci-bot authored Aug 29, 2023
1 parent 57c4a49 commit e4118c7
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 29 deletions.
33 changes: 33 additions & 0 deletions python/oneflow/_dynamo/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import warnings

# Reference: https://github.com/pytorch/pytorch/blob/v2.0.1/torch/_dynamo/__init__.py
__all__ = [
"allow_in_graph",
]


def allow_in_graph(fn):
"""
"""
if isinstance(fn, (list, tuple)):
return [allow_in_graph(x) for x in fn]
assert callable(fn), "allow_in_graph expects a callable"
warnings.warn(
"The oneflow._dynamo.allow_in_graph interface is just to align the torch._dynamo.allow_in_graph interface and has no practical significance."
)
return fn
34 changes: 26 additions & 8 deletions python/oneflow/framework/args_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,15 @@ class NamedArg(object):
named_input = NamedArg([NamedArg(1), NamedArg({key: NamedArg("value")})])
"""

def __init__(self, prefix="", name=None, global_index=0) -> None:
def __init__(
self, prefix="", name=None, global_index=0, tensor_type=Tensor
) -> None:
self._name = name if name is not None else str(global_index)
self._prefix = prefix
self._global_index = global_index
self._is_value_set = False
self._value = None
self._tensor_type = tensor_type

def prefix(self):
return self._prefix
Expand Down Expand Up @@ -86,21 +89,28 @@ def __repr__(self):
repr_str += "LIST"
elif _is_raw_type(self._value, dict) or _is_raw_type(self._value, OrderedDict):
repr_str += "DICT"
elif isinstance(self._value, Tensor):
elif isinstance(self._value, self._tensor_type):
repr_str += "TENSOR"
elif self._value is None:
repr_str += "NONE"
else:
repr_str += "OPAQUE"
if isinstance(self._value, Tensor):
repr_str += ", value: " + self._value._meta_repr()

if isinstance(self._value, self._tensor_type):
repr_str += (
", value: tensor("
+ str(self._value.shape)
+ ", "
+ str(self._value.dtype)
+ ")"
)
elif (
_is_raw_type(self._value, dict)
or _is_raw_type(self._value, OrderedDict)
or _is_raw_type(self._value, list)
or _is_raw_type(self._value, tuple)
):
pass
repr_str += ", value: " + repr(self._value)
else:
repr_str += ", value: " + repr(self._value)
repr_str += ")"
Expand All @@ -114,6 +124,7 @@ def __init__(
gen_name: bool = False,
root_prefix: str = "",
root_name: str = None,
tensor_type=Tensor,
) -> None:

self._io_args = io_args
Expand All @@ -122,6 +133,7 @@ def __init__(
self._root_name = root_name
self._named_io_args = None
self._next_global_index = 0
self._tensor_type = tensor_type

if self._gen_name:
self._named_io_args = self._construct_named_io_args(
Expand Down Expand Up @@ -178,7 +190,7 @@ def iter_named_nodes(self):
yield (named_node.prefix() + "_" + named_node.name(), named_node)

def _construct_named_io_args(self, value, prefix: str, name: str) -> NamedArg:
arg = NamedArg(prefix, name, self._next_global_index)
arg = NamedArg(prefix, name, self._next_global_index, self._tensor_type)
self._next_global_index += 1

if _is_raw_type(value, list) or _is_raw_type(value, tuple):
Expand Down Expand Up @@ -219,7 +231,7 @@ def map_tuple_leaf(self, map_function: Callable):
stack = []

# Cases handled: tuple(tensor, ...), such as input args.
if len(self._io_args) > 0 and isinstance(self._io_args[0], Tensor):
if len(self._io_args) > 0 and isinstance(self._io_args[0], self._tensor_type):
for i in self._io_args:
mapped_value = map_function(i)
stack.append(mapped_value)
Expand All @@ -233,7 +245,7 @@ def map_tuple_leaf(self, map_function: Callable):
elif (
len(self._io_args) > 0
and isinstance(self._io_args[0], (tuple, list))
and all(isinstance(arg, Tensor) for arg in self._io_args[0])
and all(isinstance(arg, self._tensor_type) for arg in self._io_args[0])
):
for i in self._io_args[0]:
mapped_value = map_function(i)
Expand Down Expand Up @@ -283,3 +295,9 @@ def _execute_mapping(self, value, map_function):
mapped_value = map_function(value)

return mapped_value

def __repr__(self):
if self._named_io_args:
return self._named_io_args.__repr__()
else:
return str(self.__class__)
31 changes: 10 additions & 21 deletions python/oneflow/nn/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1748,14 +1748,13 @@ def __build_io(self, io_type, build_func, *args, **kwargs):
args_repr = []
tensor2op_name = {}

def build_tensor_or_none(tensor, name, repr_str):
assert tensor is None or (isinstance(tensor, Tensor))
def build_tensor_or_any(tensor, name, repr_str):
if isinstance(tensor, Tensor):
build_arg = build_func(name, tensor)
op_names.append(name)
tensor2op_name[build_arg] = name
else:
build_arg = None
build_arg = tensor

args_repr.append(repr_str)
self.__print(0, 1, repr_str)
Expand All @@ -1771,18 +1770,13 @@ def leaf_arg_fn(arg):
arg_repr = self.__io_item_check_and_gen_repr(
arg.value(), Tensor, io_type, name
)
build_arg = build_tensor_or_none(arg.value(), name, arg_repr)
build_arg = build_tensor_or_any(arg.value(), name, arg_repr)
return build_arg
elif arg.value() is None:
arg_repr = self.__io_item_check_and_gen_repr(
arg.value(), None, io_type, name
)
build_arg = build_tensor_or_none(arg.value(), name, arg_repr)
else: # Opaque
# Error
arg_repr = self.__io_item_check_and_gen_repr(
arg.value(), None, io_type, name
)
build_arg = build_tensor_or_any(arg.value(), name, arg_repr)

out = args_tree.map_leaf(leaf_arg_fn)
build_args = out[0]
Expand All @@ -1792,7 +1786,7 @@ def leaf_arg_fn(arg):

def __io_item_check_and_gen_repr(self, item, expect_type, io_type, name):
assert io_type in ("input", "output")
if expect_type is None and item is None:
if expect_type is None:
repr_str = (
"[WARNING]("
+ io_type.upper()
Expand All @@ -1802,6 +1796,7 @@ def __io_item_check_and_gen_repr(self, item, expect_type, io_type, name):
+ str(type(item))
+ ")"
)
self.__print(1, 0, repr_str)
return repr_str
elif expect_type is not None and isinstance(item, expect_type):
if isinstance(item, Tensor):
Expand Down Expand Up @@ -1831,27 +1826,21 @@ def __io_item_check_and_gen_repr(self, item, expect_type, io_type, name):
def __map_io(self, io_type, func, *args, **kwargs):
assert io_type in ("input", "output")

def mapping_tensor_or_none(tensor):
assert tensor is None or (isinstance(tensor, Tensor))
def mapping_tensor_or_any(tensor):
if isinstance(tensor, Tensor):
mapped_arg = func(tensor)
else:
mapped_arg = None
mapped_arg = tensor
return mapped_arg

def leaf_arg_fn(arg):
arg_value = arg.value()
if isinstance(arg_value, Tensor) or arg_value is None:
return mapping_tensor_or_none(arg_value)
else:
self.__io_item_check(
arg_value, None, io_type, arg.prefix() + "_" + arg.name(),
)
return mapping_tensor_or_any(arg_value)

# NOTE(lixiang): Reduce the overhead of traversal and parsing of io args.
if self._is_simple_tuple_output or self._is_simple_tuple_input:
args_tree = ArgsTree(args, False)
out = args_tree.map_tuple_leaf(mapping_tensor_or_none)
out = args_tree.map_tuple_leaf(mapping_tensor_or_any)
return out, kwargs

args_tree = ArgsTree(
Expand Down

0 comments on commit e4118c7

Please sign in to comment.