Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[For dev] Sdxl dev #10323

Closed
wants to merge 16 commits into from
Closed
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions oneflow/core/functional/impl/nn_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5223,6 +5223,7 @@ class GroupedMatmulFunctor {
Maybe<TensorTuple> operator()(const TensorTuple& xs, const TensorTuple& weights) const {
const int64_t input_size = xs.size();
const int64_t weight_size = weights.size();
CHECK_LT_OR_RETURN(input_size, kMaxInputCount);
CHECK_GE_OR_RETURN(input_size, 1)
<< Error::RuntimeError() << "The number of xs should be greater equal than 1.";
CHECK_EQ_OR_RETURN(weight_size, input_size)
Expand Down
8 changes: 7 additions & 1 deletion oneflow/user/kernels/grouped_matmul_bias.cu
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,13 @@ class GroupedMatmulBiasKernel final : public user_op::OpKernel, public user_op::
}
void* workspace = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0)->mut_dptr();
for (const auto& group : groups) {
ApplyGroup<T>(group.first, group.second, has_biases, workspace, ctx->stream());
for (size_t i = 0; i < group.second.size(); i += kMaxProblemBatch) {
std::vector<Buffer<T>> ptrs(
{group.second.begin() + i,
group.second.begin() + i
+ std::min<size_t>(group.second.size() - i, kMaxProblemBatch)});
ApplyGroup<T>(group.first, ptrs, has_biases, workspace, ctx->stream());
}
}
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
Expand Down
33 changes: 33 additions & 0 deletions python/oneflow/_dynamo/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import warnings

# Reference: https://github.com/pytorch/pytorch/blob/v2.0.1/torch/_dynamo/__init__.py
__all__ = [
"allow_in_graph",
]


def allow_in_graph(fn):
    """No-op compatibility shim for ``torch._dynamo.allow_in_graph``.

    Exists only so code written against the torch._dynamo API keeps working
    on OneFlow; it performs no graph-capture registration. A warning is
    emitted on every call to make that explicit.

    Args:
        fn: A callable, or a list/tuple of callables (processed recursively).

    Returns:
        ``fn`` unchanged; for a list/tuple input, a list whose elements are
        the recursively processed entries.

    Raises:
        AssertionError: If ``fn`` (or any nested element) is not callable.
    """
    if isinstance(fn, (list, tuple)):
        return [allow_in_graph(x) for x in fn]
    assert callable(fn), "allow_in_graph expects a callable"
    warnings.warn(
        "The oneflow._dynamo.allow_in_graph interface is just to align the torch._dynamo.allow_in_graph interface and has no practical significance."
    )
    return fn
34 changes: 26 additions & 8 deletions python/oneflow/framework/args_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,15 @@ class NamedArg(object):
named_input = NamedArg([NamedArg(1), NamedArg({key: NamedArg("value")})])
"""

def __init__(self, prefix="", name=None, global_index=0) -> None:
def __init__(
    self, prefix="", name=None, global_index=0, tensor_type=Tensor
) -> None:
    """Create a named node; the global index doubles as the default name."""
    self._prefix = prefix
    self._name = str(global_index) if name is None else name
    self._global_index = global_index
    self._tensor_type = tensor_type
    # Value is attached later; track whether it has been set yet.
    self._value = None
    self._is_value_set = False

def prefix(self):
return self._prefix
Expand Down Expand Up @@ -86,21 +89,28 @@ def __repr__(self):
repr_str += "LIST"
elif _is_raw_type(self._value, dict) or _is_raw_type(self._value, OrderedDict):
repr_str += "DICT"
elif isinstance(self._value, Tensor):
elif isinstance(self._value, self._tensor_type):
repr_str += "TENSOR"
elif self._value is None:
repr_str += "NONE"
else:
repr_str += "OPAQUE"
if isinstance(self._value, Tensor):
repr_str += ", value: " + self._value._meta_repr()

if isinstance(self._value, self._tensor_type):
repr_str += (
", value: tensor("
+ str(self._value.shape)
+ ", "
+ str(self._value.dtype)
+ ")"
)
elif (
_is_raw_type(self._value, dict)
or _is_raw_type(self._value, OrderedDict)
or _is_raw_type(self._value, list)
or _is_raw_type(self._value, tuple)
):
pass
repr_str += ", value: " + repr(self._value)
else:
repr_str += ", value: " + repr(self._value)
repr_str += ")"
Expand All @@ -114,6 +124,7 @@ def __init__(
gen_name: bool = False,
root_prefix: str = "",
root_name: str = None,
tensor_type=Tensor,
) -> None:

self._io_args = io_args
Expand All @@ -122,6 +133,7 @@ def __init__(
self._root_name = root_name
self._named_io_args = None
self._next_global_index = 0
self._tensor_type = tensor_type

if self._gen_name:
self._named_io_args = self._construct_named_io_args(
Expand Down Expand Up @@ -178,7 +190,7 @@ def iter_named_nodes(self):
yield (named_node.prefix() + "_" + named_node.name(), named_node)

def _construct_named_io_args(self, value, prefix: str, name: str) -> NamedArg:
arg = NamedArg(prefix, name, self._next_global_index)
arg = NamedArg(prefix, name, self._next_global_index, self._tensor_type)
self._next_global_index += 1

if _is_raw_type(value, list) or _is_raw_type(value, tuple):
Expand Down Expand Up @@ -219,7 +231,7 @@ def map_tuple_leaf(self, map_function: Callable):
stack = []

# Cases handled: tuple(tensor, ...), such as input args.
if len(self._io_args) > 0 and isinstance(self._io_args[0], Tensor):
if len(self._io_args) > 0 and isinstance(self._io_args[0], self._tensor_type):
for i in self._io_args:
mapped_value = map_function(i)
stack.append(mapped_value)
Expand All @@ -233,7 +245,7 @@ def map_tuple_leaf(self, map_function: Callable):
elif (
len(self._io_args) > 0
and isinstance(self._io_args[0], (tuple, list))
and all(isinstance(arg, Tensor) for arg in self._io_args[0])
and all(isinstance(arg, self._tensor_type) for arg in self._io_args[0])
):
for i in self._io_args[0]:
mapped_value = map_function(i)
Expand Down Expand Up @@ -283,3 +295,9 @@ def _execute_mapping(self, value, map_function):
mapped_value = map_function(value)

return mapped_value

def __repr__(self):
    """Show the named-args tree when one was built, else just the class."""
    if not self._named_io_args:
        return str(self.__class__)
    return repr(self._named_io_args)
31 changes: 10 additions & 21 deletions python/oneflow/nn/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1748,14 +1748,13 @@ def __build_io(self, io_type, build_func, *args, **kwargs):
args_repr = []
tensor2op_name = {}

def build_tensor_or_none(tensor, name, repr_str):
assert tensor is None or (isinstance(tensor, Tensor))
def build_tensor_or_any(tensor, name, repr_str):
if isinstance(tensor, Tensor):
build_arg = build_func(name, tensor)
op_names.append(name)
tensor2op_name[build_arg] = name
else:
build_arg = None
build_arg = tensor

args_repr.append(repr_str)
self.__print(0, 1, repr_str)
Expand All @@ -1771,18 +1770,13 @@ def leaf_arg_fn(arg):
arg_repr = self.__io_item_check_and_gen_repr(
arg.value(), Tensor, io_type, name
)
build_arg = build_tensor_or_none(arg.value(), name, arg_repr)
build_arg = build_tensor_or_any(arg.value(), name, arg_repr)
return build_arg
elif arg.value() is None:
arg_repr = self.__io_item_check_and_gen_repr(
arg.value(), None, io_type, name
)
build_arg = build_tensor_or_none(arg.value(), name, arg_repr)
else: # Opaque
# Error
arg_repr = self.__io_item_check_and_gen_repr(
arg.value(), None, io_type, name
)
build_arg = build_tensor_or_any(arg.value(), name, arg_repr)

out = args_tree.map_leaf(leaf_arg_fn)
build_args = out[0]
Expand All @@ -1792,7 +1786,7 @@ def leaf_arg_fn(arg):

def __io_item_check_and_gen_repr(self, item, expect_type, io_type, name):
assert io_type in ("input", "output")
if expect_type is None and item is None:
if expect_type is None:
repr_str = (
"[WARNING]("
+ io_type.upper()
Expand All @@ -1802,6 +1796,7 @@ def __io_item_check_and_gen_repr(self, item, expect_type, io_type, name):
+ str(type(item))
+ ")"
)
self.__print(1, 0, repr_str)
return repr_str
elif expect_type is not None and isinstance(item, expect_type):
if isinstance(item, Tensor):
Expand Down Expand Up @@ -1831,27 +1826,21 @@ def __io_item_check_and_gen_repr(self, item, expect_type, io_type, name):
def __map_io(self, io_type, func, *args, **kwargs):
assert io_type in ("input", "output")

def mapping_tensor_or_none(tensor):
assert tensor is None or (isinstance(tensor, Tensor))
def mapping_tensor_or_any(tensor):
    # Tensors go through the mapping `func`; anything else passes untouched.
    if not isinstance(tensor, Tensor):
        return tensor
    return func(tensor)

def leaf_arg_fn(arg):
arg_value = arg.value()
if isinstance(arg_value, Tensor) or arg_value is None:
return mapping_tensor_or_none(arg_value)
else:
self.__io_item_check(
arg_value, None, io_type, arg.prefix() + "_" + arg.name(),
)
return mapping_tensor_or_any(arg_value)

# NOTE(lixiang): Reduce the overhead of traversal and parsing of io args.
if self._is_simple_tuple_output or self._is_simple_tuple_input:
args_tree = ArgsTree(args, False)
out = args_tree.map_tuple_leaf(mapping_tensor_or_none)
out = args_tree.map_tuple_leaf(mapping_tensor_or_any)
return out, kwargs

args_tree = ArgsTree(
Expand Down
Loading