Feat oneflow as backend of torch #10205

Open: wants to merge 33 commits into base master from feat_interplay

Commits (33)
bd31ce2  graph interplay (strint, Apr 26, 2023)
3a52d51  rm unused (strint, Apr 26, 2023)
e2761aa  rm unused (strint, Apr 26, 2023)
0eb42fd  add relu test (strint, Apr 27, 2023)
0191d98  auto format by CI (oneflow-ci-bot, Apr 27, 2023)
1c0e540  mvp passed (strint, Apr 27, 2023)
8564420  mvp passed (strint, Apr 27, 2023)
c7e8aaa  auto format by CI (oneflow-ci-bot, Apr 27, 2023)
0498b8e  format mvp (strint, Apr 27, 2023)
231ead4  merge upstream (strint, Apr 27, 2023)
459a618  Update torch_fx_to_oneflow.py (strint, Apr 27, 2023)
00bea12  Update torch_compile.py (strint, Apr 27, 2023)
f88f9f2  Update __init__.py (strint, Apr 27, 2023)
5473588  auto format by CI (oneflow-ci-bot, Apr 27, 2023)
7feeec7  rename (strint, Apr 27, 2023)
6117e5c  auto format by CI (oneflow-ci-bot, Apr 27, 2023)
940fc5b  rename and ctrl graph (strint, Apr 27, 2023)
6f9e193  Merge branch 'feat_interplay' of https://github.com/Oneflow-Inc/onefl… (strint, Apr 27, 2023)
0725f11  format (strint, Apr 27, 2023)
4e3fc4a  auto format by CI (oneflow-ci-bot, Apr 27, 2023)
0c697c5  refine enable grpah (strint, Apr 27, 2023)
4c45918  Merge branch 'feat_interplay' of https://github.com/Oneflow-Inc/onefl… (strint, Apr 27, 2023)
021fde9  auto format by CI (oneflow-ci-bot, Apr 27, 2023)
37f0bae  Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into … (strint, Aug 10, 2023)
5ba20c2  split fx transform (strint, Aug 11, 2023)
364067a  auto format by CI (oneflow-ci-bot, Aug 11, 2023)
581fab9  get torch script (strint, Aug 11, 2023)
1d495a1  Merge branch 'feat_interplay' of https://github.com/Oneflow-Inc/onefl… (strint, Aug 11, 2023)
d20f1d0  auto format by CI (oneflow-ci-bot, Aug 11, 2023)
0194b08  add gm in graph test (strint, Aug 16, 2023)
3a543a1  test linear passed (strint, Aug 16, 2023)
0d72cd7  test linear passed with fx graph and oneflow graph (strint, Aug 16, 2023)
c45f52d  auto format by CI (oneflow-ci-bot, Aug 16, 2023)
1 change: 1 addition & 0 deletions python/oneflow/nn/modules/linear.py
@@ -127,6 +127,7 @@ def reset_parameters(self) -> None:
flow.nn.init.uniform_(self.bias, -bound, bound)

def forward(self, x):
print("run oneflow linear module")
if self.use_fused_matmul_bias:
return flow._C.fused_matmul_bias(x, self.weight, self.bias)
else:
16 changes: 16 additions & 0 deletions python/oneflow/test/graph/test_fx_symbolic_trace_module.py
@@ -69,6 +69,22 @@ def test_alexnet(test_case):
np.allclose(gm(input).numpy(), m(input).numpy(), equal_nan=True)
)

class AlexNetEvalGraph(flow.nn.Graph):
def __init__(self):
super().__init__()
self.alexnet = gm

def build(self, inp):
return self.alexnet(inp)

gm_g = AlexNetEvalGraph()
gm_g.debug(1)
for i in range(5):
input = flow.randn(1, 3, 224, 224)
test_case.assertTrue(
np.allclose(gm_g(input).numpy(), m(input).numpy(), equal_nan=True)
)


if __name__ == "__main__":
unittest.main()
132 changes: 132 additions & 0 deletions python/oneflow/test/graph/test_graph_interplay.py
@@ -0,0 +1,132 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import unittest
import numpy as np

# torch must be imported before oneflow; otherwise torch.jit.trace raises:
# terminate called after throwing an instance of 'pybind11::stop_iteration'
import torch
import oneflow.unittest


def _test_relu(test_case, device, from_script=False):
import torch
from oneflow.utils.backend.torch_compile import register_ofrt

input_arr = np.array(
[
[-0.94630778, -0.83378579, -0.87060891],
[2.0289922, -0.28708987, -2.18369248],
[0.35217619, -0.67095644, -1.58943879],
[0.08086036, -1.81075924, 1.20752494],
[0.8901075, -0.49976737, -1.07153746],
[-0.44872912, -1.07275683, 0.06256855],
[-0.22556897, 0.74798368, 0.90416439],
[0.48339456, -2.32742195, -0.59321527],
],
dtype=np.float32,
)
x = torch.tensor(input_arr, device=device)
eager_out = torch.relu(x)

os.environ["ofrt_from_script"] = str(from_script)
os.environ["ofrt_enable_graph"] = "1"

@torch.compile(backend="ofrt")
def fn(x):
y = torch.relu(x)
return y

compile_out = fn(x)
test_case.assertTrue(
np.allclose(
compile_out.cpu().detach().numpy(),
eager_out.cpu().detach().numpy(),
1e-05,
1e-05,
)
)
    # Call again to exercise the already-compiled artifact on a second run.
    compile_out = fn(x)
test_case.assertTrue(
np.allclose(
compile_out.cpu().detach().numpy(),
eager_out.cpu().detach().numpy(),
1e-05,
1e-05,
)
)


def _test_linear(test_case, device):
import torch
from oneflow.utils.backend.torch_compile import register_ofrt

os.environ["ofrt_from_script"] = "0"
os.environ["ofrt_enable_graph"] = "1"

linear = torch.nn.Linear(3, 8, False)
linear = linear.to(device)
input_arr = np.array(
[
[-0.94630778, -0.83378579, -0.87060891],
[2.0289922, -0.28708987, -2.18369248],
[0.35217619, -0.67095644, -1.58943879],
[0.08086036, -1.81075924, 1.20752494],
[0.8901075, -0.49976737, -1.07153746],
[-0.44872912, -1.07275683, 0.06256855],
[-0.22556897, 0.74798368, 0.90416439],
[0.48339456, -2.32742195, -0.59321527],
],
dtype=np.float32,
)
x = torch.tensor(input_arr, device=device)
torch.nn.init.constant_(linear.weight, 2.3)
eager_out = linear(x)
Review comment from @strint (Contributor, Author), Aug 7, 2023:
Later, when bbuff tried the fx approach on a stateful module like Linear, we ran into the complication that parameter tensor sharing also has to account for sharing with oneflow.
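A minimal torch-only sketch (not code from this PR) of the aliasing problem described above: two torch modules can share one Parameter object, but a conversion that rebuilds the module from copied values, as the _get_module helper later in this PR does, silently breaks that sharing.

import torch

a = torch.nn.Linear(3, 8, bias=False)
b = torch.nn.Linear(3, 8, bias=False)
b.weight = a.weight  # b now aliases a's parameter storage

with torch.no_grad():
    a.weight.fill_(2.3)
assert torch.equal(a.weight, b.weight)  # one tensor, seen by two modules

copied = a.weight.detach().clone()  # what a rebuild-and-copy conversion keeps
with torch.no_grad():
    a.weight.fill_(0.0)
assert not torch.equal(a.weight, copied)  # the copy no longer tracks updates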


@torch.compile(backend="ofrt")
def fn(x):
y = linear(x)
return y

compile_out = fn(x)
test_case.assertTrue(
np.allclose(
compile_out.cpu().detach().numpy(),
eager_out.cpu().detach().numpy(),
1e-05,
1e-05,
)
)


@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
@oneflow.unittest.skip_unless_1n1d()
class TestAsTorchBackend(oneflow.unittest.TestCase):
    # Disabled for now: the leading underscore keeps unittest from collecting them.
    def _test_relu_with_fx(test_case):
_test_relu(test_case, "cuda", False)

def _test_relu_with_script(test_case):
_test_relu(test_case, "cuda", True)

def test_linear_with_fx(test_case):
_test_linear(test_case, "cuda")


if __name__ == "__main__":
unittest.main()
42 changes: 42 additions & 0 deletions python/oneflow/test/graph/test_torch_jit.py
@@ -0,0 +1,42 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
import oneflow
import numpy as np

input_arr = np.array(
[
[-0.94630778, -0.83378579, -0.87060891],
[2.0289922, -0.28708987, -2.18369248],
[0.35217619, -0.67095644, -1.58943879],
[0.08086036, -1.81075924, 1.20752494],
[0.8901075, -0.49976737, -1.07153746],
[-0.44872912, -1.07275683, 0.06256855],
[-0.22556897, 0.74798368, 0.90416439],
[0.48339456, -2.32742195, -0.59321527],
],
dtype=np.float32,
)
x = torch.tensor(input_arr, device="cuda")


def fn(x):
y = torch.relu(x)
return y


jit_mod = torch.jit.trace(fn, x)
print(jit_mod)
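As a side note (not part of this diff), the traced artifact can be inspected beyond print(jit_mod); assuming the standard TorchScript attributes, .code shows the recovered source and .graph the underlying IR:

print(jit_mod.code)   # decompiled TorchScript source for fn
print(jit_mod.graph)  # the underlying SSA graph IR
print(jit_mod(x))     # the traced function is callable like the original fn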
22 changes: 22 additions & 0 deletions python/oneflow/utils/backend/__init__.py
@@ -0,0 +1,22 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from oneflow.utils.backend.torch_compile import register_ofrt

register_ofrt()

__all__ = [
"register_ofrt",
]
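torch_compile.py itself is not part of this diff, so the following is only a hedged sketch of what register_ofrt plausibly does: register a TorchDynamo backend under the name "ofrt" so that torch.compile(backend="ofrt") in the tests above can resolve it. The registry import and the fx_transform wiring here are assumptions, not code from this PR.

import torch
from torch._dynamo.backends.registry import register_backend


def _ofrt_backend(gm: torch.fx.GraphModule, example_inputs):
    # Lower the Dynamo-captured torch.fx graph to a oneflow callable;
    # fx_transform is defined in from_torch_fx.py below.
    from oneflow.utils.backend.from_torch_fx import fx_transform

    return fx_transform(gm)


def register_ofrt():
    # Make the backend discoverable by name via torch.compile(backend="ofrt").
    register_backend(_ofrt_backend, name="ofrt")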
140 changes: 140 additions & 0 deletions python/oneflow/utils/backend/from_torch_fx.py
@@ -0,0 +1,140 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import torch
import oneflow as flow
from torch import fx
from typing import Any, Dict, Tuple


def fx_transform(gm):
    of_gm = to_of_transform(gm)

enable_graph = os.getenv("ofrt_enable_graph", "False").lower() in (
"true",
"1",
"t",
)

if not enable_graph:
oneflow_fn = of_gm.forward
else:

class OfGraph(flow.nn.Graph):
def __init__(self):
super().__init__()
self.m = of_gm

def build(self, *args, **kwargs):
return self.m(*args, **kwargs)

of_g = OfGraph()
of_g.debug(0)
oneflow_fn = lambda *args, **kwargs: of_g(*args, **kwargs)

return oneflow_fn


def _parent_name(target: str) -> Tuple[str, str]:
"""
Splits a qualname into parent path and last atom.
For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
"""
*parent, name = target.rsplit(".", 1)
return parent[0] if parent else "", name


def _replace_node_module(
node: torch.fx.Node, modules: Dict[str, Any], new_module: flow.nn.Module
):
assert isinstance(node.target, str)
parent_name, name = _parent_name(node.target)
setattr(modules[parent_name], name, new_module)


def _get_module(origin_mod):
    # Test-only placeholder: ignores origin_mod and rebuilds the exact Linear
    # used in the linear test (3x8, no bias, weights filled with 2.3) rather
    # than copying or sharing the original torch parameters.
    linear = flow.nn.Linear(3, 8, False)
    linear = linear.to("cuda")
    flow.nn.init.constant_(linear.weight, 2.3)
    return linear


def _to_of_transform(
gm: torch.fx.GraphModule, tracer_class: type = fx.Tracer
) -> torch.fx.GraphModule:
    # Earlier in-place variant: retargets nodes of the torch fx graph directly.
    # fx_transform uses to_of_transform below, which builds a fresh oneflow
    # fx graph instead; this version is kept for reference.
modules = dict(gm.named_modules())
for node in gm.graph.nodes:
        # A call_function node invokes a plain function (e.g. torch.relu);
        # node.target holds the callable itself.
        if node.op == "call_function":
if node.target == torch.relu:
node.target = flow.relu
elif node.op == "call_module":
print(node.format_node())
            if isinstance(modules[node.target], torch.nn.Linear):
                linear = modules[node.target]
                print(linear)
                _replace_node_module(node, modules, _get_module(linear))

gm.graph.lint()
gm.recompile()
for node in gm.graph.nodes:
print(node.format_node())
return gm


def to_of_transform(
gm: torch.fx.GraphModule, tracer_class: type = fx.Tracer
) -> torch.fx.GraphModule:
name2node = {}
name2obj = {}
of_g = flow.fx.Graph()
modules = dict(gm.named_modules())
for node in gm.graph.nodes:
print(node.format_node())
if node.op == "call_function":
if node.target == torch.relu:
node.target = flow.relu
elif node.op == "call_module":
            if isinstance(modules[node.target], torch.nn.Linear):
linear = modules[node.target]
name2obj[node.target] = _get_module(linear)
of_node = of_g.create_node(
"call_module", node.target, args=(name2node[node.args[0].name],)
)
name2node[node.name] = of_node
elif node.op == "call_method":
...
elif node.op == "get_attr":
...
elif node.op == "placeholder":
of_node = of_g.create_node("placeholder", node.target)
name2node[node.name] = of_node
elif node.op == "output":
of_g.output((name2node[node.args[0][0].name],))
else:
            raise ValueError(f"invalid node type: {node.format_node()}")
print("\n new of graph", of_g.print_tabular())
for of_node in of_g.nodes:
print(of_node.format_node())

of_gm = flow.fx.GraphModule(name2obj, of_g)
of_gm.graph.lint()
of_gm.recompile()
return of_gm
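To make the node-rewriting mechanism concrete, here is a self-contained torch-only sketch (not part of this PR) of the same steps to_of_transform relies on: walk graph.nodes, retarget matching call_function nodes, then lint and recompile.

import torch
from torch import fx


def swap_relu_for_sigmoid(gm: fx.GraphModule) -> fx.GraphModule:
    for node in gm.graph.nodes:
        # call_function nodes keep the Python callable in node.target.
        if node.op == "call_function" and node.target is torch.relu:
            node.target = torch.sigmoid  # retarget the node in place
    gm.graph.lint()  # verify the mutated graph is still well formed
    gm.recompile()  # regenerate gm.forward from the mutated graph
    return gm


def f(x):
    return torch.relu(x) + 1


gm = swap_relu_for_sigmoid(fx.symbolic_trace(f))
x = torch.randn(4)
assert torch.allclose(gm(x), torch.sigmoid(x) + 1)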