diff --git a/python/oneflow/_dynamo/__init__.py b/python/oneflow/_dynamo/__init__.py new file mode 100644 index 00000000000..abc1eea891a --- /dev/null +++ b/python/oneflow/_dynamo/__init__.py @@ -0,0 +1,33 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import warnings + +# Reference: https://github.com/pytorch/pytorch/blob/v2.0.1/torch/_dynamo/__init__.py +__all__ = [ + "allow_in_graph", +] + + +def allow_in_graph(fn): + """ + """ + if isinstance(fn, (list, tuple)): + return [allow_in_graph(x) for x in fn] + assert callable(fn), "allow_in_graph expects a callable" + warnings.warn( + "The oneflow._dynamo.allow_in_graph interface is just to align the torch._dynamo.allow_in_graph interface and has no practical significance." + ) + return fn diff --git a/python/oneflow/framework/args_tree.py b/python/oneflow/framework/args_tree.py index afd38c9907b..50f4e6a7fcb 100644 --- a/python/oneflow/framework/args_tree.py +++ b/python/oneflow/framework/args_tree.py @@ -41,12 +41,15 @@ class NamedArg(object): named_input = NamedArg([NamedArg(1), NamedArg({key: NamedArg("value")})]) """ - def __init__(self, prefix="", name=None, global_index=0) -> None: + def __init__( + self, prefix="", name=None, global_index=0, tensor_type=Tensor + ) -> None: self._name = name if name is not None else str(global_index) self._prefix = prefix self._global_index = global_index self._is_value_set = False self._value = None + self._tensor_type = tensor_type def prefix(self): return self._prefix @@ -86,21 +89,28 @@ def __repr__(self): repr_str += "LIST" elif _is_raw_type(self._value, dict) or _is_raw_type(self._value, OrderedDict): repr_str += "DICT" - elif isinstance(self._value, Tensor): + elif isinstance(self._value, self._tensor_type): repr_str += "TENSOR" elif self._value is None: repr_str += "NONE" else: repr_str += "OPAQUE" - if isinstance(self._value, Tensor): - repr_str += ", value: " + self._value._meta_repr() + + if isinstance(self._value, self._tensor_type): + repr_str += ( + ", value: tensor(" + + str(self._value.shape) + + ", " + + str(self._value.dtype) + + ")" + ) elif ( _is_raw_type(self._value, dict) or _is_raw_type(self._value, OrderedDict) or _is_raw_type(self._value, list) or _is_raw_type(self._value, tuple) ): - pass + repr_str += ", value: " + repr(self._value) else: repr_str += ", value: " + repr(self._value) repr_str += ")" @@ -114,6 +124,7 @@ def __init__( gen_name: bool = False, root_prefix: str = "", root_name: str = None, + tensor_type=Tensor, ) -> None: self._io_args = io_args @@ -122,6 +133,7 @@ def __init__( self._root_name = root_name self._named_io_args = None self._next_global_index = 0 + self._tensor_type = tensor_type if self._gen_name: self._named_io_args = self._construct_named_io_args( @@ -178,7 +190,7 @@ def iter_named_nodes(self): yield (named_node.prefix() + "_" + named_node.name(), named_node) def _construct_named_io_args(self, value, prefix: str, name: str) -> NamedArg: - arg = NamedArg(prefix, name, self._next_global_index) + arg = NamedArg(prefix, name, self._next_global_index, self._tensor_type) self._next_global_index += 1 if _is_raw_type(value, list) or _is_raw_type(value, tuple): @@ -219,7 +231,7 @@ def map_tuple_leaf(self, map_function: Callable): stack = [] # Cases handled: tuple(tensor, ...), such as input args. - if len(self._io_args) > 0 and isinstance(self._io_args[0], Tensor): + if len(self._io_args) > 0 and isinstance(self._io_args[0], self._tensor_type): for i in self._io_args: mapped_value = map_function(i) stack.append(mapped_value) @@ -233,7 +245,7 @@ def map_tuple_leaf(self, map_function: Callable): elif ( len(self._io_args) > 0 and isinstance(self._io_args[0], (tuple, list)) - and all(isinstance(arg, Tensor) for arg in self._io_args[0]) + and all(isinstance(arg, self._tensor_type) for arg in self._io_args[0]) ): for i in self._io_args[0]: mapped_value = map_function(i) @@ -283,3 +295,9 @@ def _execute_mapping(self, value, map_function): mapped_value = map_function(value) return mapped_value + + def __repr__(self): + if self._named_io_args: + return self._named_io_args.__repr__() + else: + return str(self.__class__) diff --git a/python/oneflow/nn/graph/graph.py b/python/oneflow/nn/graph/graph.py index 2b6aca3627c..7093c228b72 100644 --- a/python/oneflow/nn/graph/graph.py +++ b/python/oneflow/nn/graph/graph.py @@ -1748,14 +1748,13 @@ def __build_io(self, io_type, build_func, *args, **kwargs): args_repr = [] tensor2op_name = {} - def build_tensor_or_none(tensor, name, repr_str): - assert tensor is None or (isinstance(tensor, Tensor)) + def build_tensor_or_any(tensor, name, repr_str): if isinstance(tensor, Tensor): build_arg = build_func(name, tensor) op_names.append(name) tensor2op_name[build_arg] = name else: - build_arg = None + build_arg = tensor args_repr.append(repr_str) self.__print(0, 1, repr_str) @@ -1771,18 +1770,13 @@ def leaf_arg_fn(arg): arg_repr = self.__io_item_check_and_gen_repr( arg.value(), Tensor, io_type, name ) - build_arg = build_tensor_or_none(arg.value(), name, arg_repr) + build_arg = build_tensor_or_any(arg.value(), name, arg_repr) return build_arg - elif arg.value() is None: - arg_repr = self.__io_item_check_and_gen_repr( - arg.value(), None, io_type, name - ) - build_arg = build_tensor_or_none(arg.value(), name, arg_repr) else: # Opaque - # Error arg_repr = self.__io_item_check_and_gen_repr( arg.value(), None, io_type, name ) + build_arg = build_tensor_or_any(arg.value(), name, arg_repr) out = args_tree.map_leaf(leaf_arg_fn) build_args = out[0] @@ -1792,7 +1786,7 @@ def leaf_arg_fn(arg): def __io_item_check_and_gen_repr(self, item, expect_type, io_type, name): assert io_type in ("input", "output") - if expect_type is None and item is None: + if expect_type is None: repr_str = ( "[WARNING](" + io_type.upper() @@ -1802,6 +1796,7 @@ def __io_item_check_and_gen_repr(self, item, expect_type, io_type, name): + str(type(item)) + ")" ) + self.__print(1, 0, repr_str) return repr_str elif expect_type is not None and isinstance(item, expect_type): if isinstance(item, Tensor): @@ -1831,27 +1826,21 @@ def __io_item_check_and_gen_repr(self, item, expect_type, io_type, name): def __map_io(self, io_type, func, *args, **kwargs): assert io_type in ("input", "output") - def mapping_tensor_or_none(tensor): - assert tensor is None or (isinstance(tensor, Tensor)) + def mapping_tensor_or_any(tensor): if isinstance(tensor, Tensor): mapped_arg = func(tensor) else: - mapped_arg = None + mapped_arg = tensor return mapped_arg def leaf_arg_fn(arg): arg_value = arg.value() - if isinstance(arg_value, Tensor) or arg_value is None: - return mapping_tensor_or_none(arg_value) - else: - self.__io_item_check( - arg_value, None, io_type, arg.prefix() + "_" + arg.name(), - ) + return mapping_tensor_or_any(arg_value) # NOTE(lixiang): Reduce the overhead of traversal and parsing of io args. if self._is_simple_tuple_output or self._is_simple_tuple_input: args_tree = ArgsTree(args, False) - out = args_tree.map_tuple_leaf(mapping_tensor_or_none) + out = args_tree.map_tuple_leaf(mapping_tensor_or_any) return out, kwargs args_tree = ArgsTree(