-
Notifications
You must be signed in to change notification settings - Fork 667
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat oneflow as backend of torch #10205
base: master
Are you sure you want to change the base?
Changes from 13 commits
bd31ce2
3a52d51
e2761aa
0eb42fd
0191d98
1c0e540
8564420
c7e8aaa
0498b8e
231ead4
459a618
00bea12
f88f9f2
5473588
7feeec7
6117e5c
940fc5b
6f9e193
0725f11
4e3fc4a
0c697c5
4c45918
021fde9
37f0bae
5ba20c2
364067a
581fab9
1d495a1
d20f1d0
0194b08
3a543a1
0d72cd7
c45f52d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
""" | ||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
import os | ||
import unittest | ||
import numpy as np | ||
|
||
import oneflow.unittest | ||
|
||
|
||
def _test_relu(test_case, device): | ||
from typing import List | ||
import torch | ||
from oneflow.utils.backend.torch_compile import register_oneflowc | ||
|
||
input_arr = np.array( | ||
[ | ||
[-0.94630778, -0.83378579, -0.87060891], | ||
[2.0289922, -0.28708987, -2.18369248], | ||
[0.35217619, -0.67095644, -1.58943879], | ||
[0.08086036, -1.81075924, 1.20752494], | ||
[0.8901075, -0.49976737, -1.07153746], | ||
[-0.44872912, -1.07275683, 0.06256855], | ||
[-0.22556897, 0.74798368, 0.90416439], | ||
[0.48339456, -2.32742195, -0.59321527], | ||
], | ||
dtype=np.float32, | ||
) | ||
x = torch.tensor(input_arr, device=device) | ||
eager_out = torch.relu(x) | ||
|
||
@torch.compile(backend="oneflowc") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 使用 oneflowc |
||
def fn(x): | ||
y = torch.relu(x) | ||
return y | ||
|
||
compile_out = fn(x) | ||
test_case.assertTrue( | ||
np.allclose( | ||
compile_out.cpu().detach().numpy(), | ||
eager_out.cpu().detach().numpy(), | ||
1e-05, | ||
1e-05, | ||
) | ||
) | ||
compile_out = fn(x) | ||
test_case.assertTrue( | ||
np.allclose( | ||
compile_out.cpu().detach().numpy(), | ||
eager_out.cpu().detach().numpy(), | ||
1e-05, | ||
1e-05, | ||
) | ||
) | ||
|
||
|
||
def _test_linear(test_case, device): | ||
from typing import List | ||
import torch | ||
from torch._dynamo.backends.registry import register_backend | ||
from torch._dynamo.backends.common import fake_tensor_unsupported | ||
|
||
linear = torch.nn.Linear(3, 8, False) | ||
linear = linear.to(device) | ||
input_arr = np.array( | ||
[ | ||
[-0.94630778, -0.83378579, -0.87060891], | ||
[2.0289922, -0.28708987, -2.18369248], | ||
[0.35217619, -0.67095644, -1.58943879], | ||
[0.08086036, -1.81075924, 1.20752494], | ||
[0.8901075, -0.49976737, -1.07153746], | ||
[-0.44872912, -1.07275683, 0.06256855], | ||
[-0.22556897, 0.74798368, 0.90416439], | ||
[0.48339456, -2.32742195, -0.59321527], | ||
], | ||
dtype=np.float32, | ||
) | ||
x = torch.tensor(input_arr, device=device) | ||
torch.nn.init.constant_(linear.weight, 2.3) | ||
eager_out = linear(x) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 后来 bbuff 尝试用 fx 的思路做 lnear 这种带状态的 module 时,遇到了参数 tensor 共享也需要考虑和 oneflow 共享的麻烦。 |
||
|
||
def get_of(): | ||
# TODO(): transform torch fx code to oneflow code | ||
import oneflow as flow | ||
|
||
linear = flow.nn.Linear(3, 8, False) | ||
linear = linear.to(device) | ||
flow.nn.init.constant_(linear.weight, 2.3) | ||
|
||
class LinearGraph(flow.nn.Graph): | ||
def __init__(self): | ||
super().__init__() | ||
self.my_linear = linear | ||
|
||
def build(self, x): | ||
return self.my_linear(x) | ||
|
||
linear_g = LinearGraph() | ||
linear_g.debug(1) | ||
return linear_g | ||
|
||
g = None | ||
|
||
def torch_interplay(x): | ||
import oneflow as flow | ||
|
||
x = flow.utils.tensor.from_torch(x) | ||
nonlocal g | ||
if g is None: | ||
g = get_of() | ||
# TODO(): This is a special pack trick, try to make it general. | ||
return (flow.utils.tensor.to_torch(g(x)),) | ||
|
||
@register_backend | ||
@fake_tensor_unsupported | ||
def oneflowc(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): | ||
print("my_compiler() called with FX graph:") | ||
gm.graph.print_tabular() | ||
gm.forward = torch_interplay | ||
return gm.forward # return a python callable | ||
|
||
@torch.compile(backend="oneflowc") | ||
def fn(x): | ||
y = linear(x) | ||
return y | ||
|
||
compile_out = fn(x) | ||
test_case.assertTrue( | ||
np.allclose( | ||
compile_out.cpu().detach().numpy(), | ||
eager_out.cpu().detach().numpy(), | ||
1e-05, | ||
1e-05, | ||
) | ||
) | ||
compile_out = fn(x) | ||
test_case.assertTrue( | ||
np.allclose( | ||
compile_out.cpu().detach().numpy(), | ||
eager_out.cpu().detach().numpy(), | ||
1e-05, | ||
1e-05, | ||
) | ||
) | ||
|
||
|
||
@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") | ||
@oneflow.unittest.skip_unless_1n1d() | ||
class TestAsTorchBackend(oneflow.unittest.TestCase): | ||
def test_relu(test_case): | ||
_test_relu(test_case, "cuda") | ||
|
||
def test_linear(test_case): | ||
_test_linear(test_case, "cuda") | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
""" | ||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
from oneflow.utils.backend.torch_compile import register_oneflowc | ||
|
||
register_oneflowc() | ||
|
||
__all__ = [ | ||
"register_oneflowc", | ||
] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
""" | ||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
import oneflow | ||
from .torch_fx_to_oneflow import to_of_transform | ||
|
||
def register_oneflowc(): | ||
from typing import List | ||
import torch | ||
from torch import fx | ||
from torch._dynamo.backends.registry import register_backend | ||
from torch._dynamo.backends.common import fake_tensor_unsupported | ||
|
||
@register_backend | ||
@fake_tensor_unsupported | ||
def oneflowc(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oneflowc的作为后端 |
||
# TODO(): fxGraphModule to nn.Graph | ||
print("my_compiler() called with FX graph:") | ||
gm.graph.print_tabular() | ||
print("gm ", gm) | ||
|
||
import oneflow as flow | ||
|
||
of_gm = to_of_transform(gm) | ||
|
||
@flow.nn.Graph.trace | ||
def oneflow_fn(inputs): | ||
outs = of_gm.forward(inputs) | ||
return outs | ||
|
||
oneflow_fn.debug(1) | ||
|
||
def from_to_torch(inputs): | ||
flow_inputs = flow.utils.tensor.from_torch(inputs) | ||
flow_outs = oneflow_fn(flow_inputs) | ||
# TODO(): general output process | ||
outs = flow.utils.tensor.to_torch(flow_outs[0]) | ||
return (outs, ) | ||
|
||
return from_to_torch |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
""" | ||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
import oneflow | ||
import torch | ||
from torch import fx | ||
|
||
def to_of_transform(gm: torch.fx.GraphModule, | ||
strint marked this conversation as resolved.
Show resolved
Hide resolved
|
||
tracer_class : type = fx.Tracer) -> torch.fx.GraphModule: | ||
for node in gm.graph.nodes: | ||
# Checks if we're calling a function (i.e: | ||
# torch.add) | ||
if node.op == 'call_function': | ||
# The target attribute is the function | ||
# that call_function calls. | ||
if node.target == torch.relu: | ||
node.target = oneflow.relu | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这是最早的思路,通过 torch python api -> torch fx -> oneflow python api There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 我觉得这个方法挺好的啊,这里都不用这样写吧。 getattr(oneflow, ".".join(node.target.split(".")[1::])) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 哦不对,要把function对象取得名字,应该有办法吧,node这个对象也会有字符串的field吧 |
||
|
||
gm.graph.lint() | ||
gm.recompile() | ||
return gm |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
注册 oneflowc