
Feat oneflow as backend of torch #10205

Open · wants to merge 33 commits into base: master
Changes from 13 commits

Commits (33):
bd31ce2  graph interplay (strint, Apr 26, 2023)
3a52d51  rm unused (strint, Apr 26, 2023)
e2761aa  rm unused (strint, Apr 26, 2023)
0eb42fd  add relu test (strint, Apr 27, 2023)
0191d98  auto format by CI (oneflow-ci-bot, Apr 27, 2023)
1c0e540  mvp passed (strint, Apr 27, 2023)
8564420  mvp passed (strint, Apr 27, 2023)
c7e8aaa  auto format by CI (oneflow-ci-bot, Apr 27, 2023)
0498b8e  format mvp (strint, Apr 27, 2023)
231ead4  merge upstream (strint, Apr 27, 2023)
459a618  Update torch_fx_to_oneflow.py (strint, Apr 27, 2023)
00bea12  Update torch_compile.py (strint, Apr 27, 2023)
f88f9f2  Update __init__.py (strint, Apr 27, 2023)
5473588  auto format by CI (oneflow-ci-bot, Apr 27, 2023)
7feeec7  rename (strint, Apr 27, 2023)
6117e5c  auto format by CI (oneflow-ci-bot, Apr 27, 2023)
940fc5b  rename and ctrl graph (strint, Apr 27, 2023)
6f9e193  Merge branch 'feat_interplay' of https://github.com/Oneflow-Inc/onefl… (strint, Apr 27, 2023)
0725f11  format (strint, Apr 27, 2023)
4e3fc4a  auto format by CI (oneflow-ci-bot, Apr 27, 2023)
0c697c5  refine enable grpah (strint, Apr 27, 2023)
4c45918  Merge branch 'feat_interplay' of https://github.com/Oneflow-Inc/onefl… (strint, Apr 27, 2023)
021fde9  auto format by CI (oneflow-ci-bot, Apr 27, 2023)
37f0bae  Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into … (strint, Aug 10, 2023)
5ba20c2  split fx transform (strint, Aug 11, 2023)
364067a  auto format by CI (oneflow-ci-bot, Aug 11, 2023)
581fab9  get torch script (strint, Aug 11, 2023)
1d495a1  Merge branch 'feat_interplay' of https://github.com/Oneflow-Inc/onefl… (strint, Aug 11, 2023)
d20f1d0  auto format by CI (oneflow-ci-bot, Aug 11, 2023)
0194b08  add gm in graph test (strint, Aug 16, 2023)
3a543a1  test linear passed (strint, Aug 16, 2023)
0d72cd7  test linear passed with fx graph and oneflow graph (strint, Aug 16, 2023)
c45f52d  auto format by CI (oneflow-ci-bot, Aug 16, 2023)
170 changes: 170 additions & 0 deletions python/oneflow/test/graph/test_graph_interplay.py
@@ -0,0 +1,170 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import unittest
import numpy as np

import oneflow.unittest


def _test_relu(test_case, device):
    from typing import List
    import torch
    from oneflow.utils.backend.torch_compile import register_oneflowc
[Review comment by @strint (Contributor Author), Apr 27, 2023]: Register oneflowc.

    input_arr = np.array(
        [
            [-0.94630778, -0.83378579, -0.87060891],
            [2.0289922, -0.28708987, -2.18369248],
            [0.35217619, -0.67095644, -1.58943879],
            [0.08086036, -1.81075924, 1.20752494],
            [0.8901075, -0.49976737, -1.07153746],
            [-0.44872912, -1.07275683, 0.06256855],
            [-0.22556897, 0.74798368, 0.90416439],
            [0.48339456, -2.32742195, -0.59321527],
        ],
        dtype=np.float32,
    )
    x = torch.tensor(input_arr, device=device)
    eager_out = torch.relu(x)

    @torch.compile(backend="oneflowc")
[Review comment by @strint (Contributor Author)]: Use oneflowc.

    def fn(x):
        y = torch.relu(x)
        return y

    compile_out = fn(x)
    test_case.assertTrue(
        np.allclose(
            compile_out.cpu().detach().numpy(),
            eager_out.cpu().detach().numpy(),
            1e-05,
            1e-05,
        )
    )
    # Run a second time to cover the already-compiled path.
    compile_out = fn(x)
    test_case.assertTrue(
        np.allclose(
            compile_out.cpu().detach().numpy(),
            eager_out.cpu().detach().numpy(),
            1e-05,
            1e-05,
        )
    )


def _test_linear(test_case, device):
    from typing import List
    import torch
    from torch._dynamo.backends.registry import register_backend
    from torch._dynamo.backends.common import fake_tensor_unsupported

    linear = torch.nn.Linear(3, 8, False)
    linear = linear.to(device)
    input_arr = np.array(
        [
            [-0.94630778, -0.83378579, -0.87060891],
            [2.0289922, -0.28708987, -2.18369248],
            [0.35217619, -0.67095644, -1.58943879],
            [0.08086036, -1.81075924, 1.20752494],
            [0.8901075, -0.49976737, -1.07153746],
            [-0.44872912, -1.07275683, 0.06256855],
            [-0.22556897, 0.74798368, 0.90416439],
            [0.48339456, -2.32742195, -0.59321527],
        ],
        dtype=np.float32,
    )
    x = torch.tensor(input_arr, device=device)
    torch.nn.init.constant_(linear.weight, 2.3)
    eager_out = linear(x)
[Review comment by @strint (Contributor Author), Aug 7, 2023]: Later, when bbuff tried the fx approach on a stateful module like linear, he ran into the complication that the parameter tensors also need to be shared with oneflow. (A short sketch of this concern follows this file's diff.)


    def get_of():
        # TODO(): transform torch fx code to oneflow code
        import oneflow as flow

        linear = flow.nn.Linear(3, 8, False)
        linear = linear.to(device)
        flow.nn.init.constant_(linear.weight, 2.3)

        class LinearGraph(flow.nn.Graph):
            def __init__(self):
                super().__init__()
                self.my_linear = linear

            def build(self, x):
                return self.my_linear(x)

        linear_g = LinearGraph()
        linear_g.debug(1)
        return linear_g

    g = None

    def torch_interplay(x):
        import oneflow as flow

        x = flow.utils.tensor.from_torch(x)
        nonlocal g
        if g is None:
            g = get_of()
        # TODO(): This is a special pack trick, try to make it general.
        return (flow.utils.tensor.to_torch(g(x)),)

    @register_backend
    @fake_tensor_unsupported
    def oneflowc(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        print("my_compiler() called with FX graph:")
        gm.graph.print_tabular()
        gm.forward = torch_interplay
        return gm.forward  # return a python callable

    @torch.compile(backend="oneflowc")
    def fn(x):
        y = linear(x)
        return y

    compile_out = fn(x)
    test_case.assertTrue(
        np.allclose(
            compile_out.cpu().detach().numpy(),
            eager_out.cpu().detach().numpy(),
            1e-05,
            1e-05,
        )
    )
    # Run a second time to cover the already-compiled path.
    compile_out = fn(x)
    test_case.assertTrue(
        np.allclose(
            compile_out.cpu().detach().numpy(),
            eager_out.cpu().detach().numpy(),
            1e-05,
            1e-05,
        )
    )


@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
@oneflow.unittest.skip_unless_1n1d()
class TestAsTorchBackend(oneflow.unittest.TestCase):
    def test_relu(test_case):
        _test_relu(test_case, "cuda")

    def test_linear(test_case):
        _test_linear(test_case, "cuda")


if __name__ == "__main__":
    unittest.main()
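To make the parameter-sharing concern in the review comment above concrete, here is a hypothetical sketch (not part of this PR) of handing a torch module's parameter over to an oneflow twin; whether such a bridge shares or copies storage is exactly the trouble described:

import torch
import oneflow as flow

# Hypothetical sketch only: bridge a stateful torch module's parameter
# into an equivalent oneflow module via the same from_torch utility the
# PR uses for activations.
t_linear = torch.nn.Linear(3, 8, False)
f_linear = flow.nn.Linear(3, 8, False)
f_linear.weight = flow.nn.Parameter(
    flow.utils.tensor.from_torch(t_linear.weight.data)
)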
22 changes: 22 additions & 0 deletions python/oneflow/utils/backend/__init__.py
@@ -0,0 +1,22 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from oneflow.utils.backend.torch_compile import register_oneflowc

register_oneflowc()

__all__ = [
"register_oneflowc",
]
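Since __init__.py runs register_oneflowc() at import time, user code only needs to import the package and pass the backend name to torch.compile. A minimal usage sketch based on the tests in this PR (device and shapes are illustrative):

import torch
import oneflow.utils.backend  # noqa: F401 -- registers the "oneflowc" backend on import

@torch.compile(backend="oneflowc")
def fn(x):
    return torch.relu(x)

y = fn(torch.randn(8, 3, device="cuda"))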
52 changes: 52 additions & 0 deletions python/oneflow/utils/backend/torch_compile.py
@@ -0,0 +1,52 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow
from .torch_fx_to_oneflow import to_of_transform

def register_oneflowc():
    from typing import List
    import torch
    from torch import fx
    from torch._dynamo.backends.registry import register_backend
    from torch._dynamo.backends.common import fake_tensor_unsupported

    @register_backend
    @fake_tensor_unsupported
    def oneflowc(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
[Review comment by @strint (Contributor Author)]: oneflowc serves as the backend.

        # TODO(): fxGraphModule to nn.Graph
        print("my_compiler() called with FX graph:")
        gm.graph.print_tabular()
        print("gm ", gm)

        import oneflow as flow

        of_gm = to_of_transform(gm)

        @flow.nn.Graph.trace
        def oneflow_fn(inputs):
            outs = of_gm.forward(inputs)
            return outs

        oneflow_fn.debug(1)

        def from_to_torch(inputs):
            flow_inputs = flow.utils.tensor.from_torch(inputs)
            flow_outs = oneflow_fn(flow_inputs)
            # TODO(): general output process
            outs = flow.utils.tensor.to_torch(flow_outs[0])
            return (outs,)

        return from_to_torch
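For context, a minimal round-trip sketch of the torch <-> oneflow tensor bridging that from_to_torch relies on (values are illustrative):

import torch
import oneflow as flow

t = torch.ones(2, 3)
f = flow.utils.tensor.from_torch(t)     # torch.Tensor -> oneflow.Tensor
out = flow.relu(f)                      # run an op on the oneflow side
back = flow.utils.tensor.to_torch(out)  # oneflow.Tensor -> torch.Tensor
assert isinstance(back, torch.Tensor)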
33 changes: 33 additions & 0 deletions python/oneflow/utils/backend/torch_fx_to_oneflow.py
@@ -0,0 +1,33 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow
import torch
from torch import fx

def to_of_transform(
    gm: torch.fx.GraphModule, tracer_class: type = fx.Tracer
) -> torch.fx.GraphModule:
    for node in gm.graph.nodes:
        # Checks if we're calling a function (i.e.,
        # torch.add)
        if node.op == 'call_function':
            # The target attribute is the function
            # that call_function calls.
            if node.target == torch.relu:
                node.target = oneflow.relu
[Review comment by @strint (Contributor Author)]: This was the earliest approach: torch python api -> torch fx -> oneflow python api.

[Review comment by a Collaborator]: I actually think this approach is quite good; there should be no need to hard-code each op like this here. Could it be written generically, something like:

getattr(oneflow, ".".join(node.target.split(".")[1::]))

[Review comment by a Collaborator]: Oh wait, that's not quite right: we first need to get the name from the function object. There should be a way; the node object presumably also has a string field for it. (A sketch of this generic mapping follows the diff.)

    gm.graph.lint()
    gm.recompile()
    return gm
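As referenced in the review discussion above, a hypothetical sketch of the generic mapping the collaborator suggests, assuming oneflow mirrors the torch function names (the helper name to_of_transform_generic is invented for illustration):

import oneflow
import torch


def to_of_transform_generic(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in gm.graph.nodes:
        if node.op == "call_function":
            # node.target is the function object itself, so recover its
            # name via __name__ rather than parsing a dotted string.
            name = getattr(node.target, "__name__", None)
            if name is not None and hasattr(oneflow, name):
                node.target = getattr(oneflow, name)
    gm.graph.lint()
    gm.recompile()
    return gm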