Feat oneflow as backend of torch #10205

Status: Open. 33 commits to merge into base: master (the diff below shows changes from 4 of the 33 commits).

Commits (33)
bd31ce2  graph interplay (strint, Apr 26, 2023)
3a52d51  rm unused (strint, Apr 26, 2023)
e2761aa  rm unused (strint, Apr 26, 2023)
0eb42fd  add relu test (strint, Apr 27, 2023)
0191d98  auto format by CI (oneflow-ci-bot, Apr 27, 2023)
1c0e540  mvp passed (strint, Apr 27, 2023)
8564420  mvp passed (strint, Apr 27, 2023)
c7e8aaa  auto format by CI (oneflow-ci-bot, Apr 27, 2023)
0498b8e  format mvp (strint, Apr 27, 2023)
231ead4  merge upstream (strint, Apr 27, 2023)
459a618  Update torch_fx_to_oneflow.py (strint, Apr 27, 2023)
00bea12  Update torch_compile.py (strint, Apr 27, 2023)
f88f9f2  Update __init__.py (strint, Apr 27, 2023)
5473588  auto format by CI (oneflow-ci-bot, Apr 27, 2023)
7feeec7  rename (strint, Apr 27, 2023)
6117e5c  auto format by CI (oneflow-ci-bot, Apr 27, 2023)
940fc5b  rename and ctrl graph (strint, Apr 27, 2023)
6f9e193  Merge branch 'feat_interplay' of https://github.com/Oneflow-Inc/onefl… (strint, Apr 27, 2023)
0725f11  format (strint, Apr 27, 2023)
4e3fc4a  auto format by CI (oneflow-ci-bot, Apr 27, 2023)
0c697c5  refine enable grpah (strint, Apr 27, 2023)
4c45918  Merge branch 'feat_interplay' of https://github.com/Oneflow-Inc/onefl… (strint, Apr 27, 2023)
021fde9  auto format by CI (oneflow-ci-bot, Apr 27, 2023)
37f0bae  Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into … (strint, Aug 10, 2023)
5ba20c2  split fx transform (strint, Aug 11, 2023)
364067a  auto format by CI (oneflow-ci-bot, Aug 11, 2023)
581fab9  get torch script (strint, Aug 11, 2023)
1d495a1  Merge branch 'feat_interplay' of https://github.com/Oneflow-Inc/onefl… (strint, Aug 11, 2023)
d20f1d0  auto format by CI (oneflow-ci-bot, Aug 11, 2023)
0194b08  add gm in graph test (strint, Aug 16, 2023)
3a543a1  test linear passed (strint, Aug 16, 2023)
0d72cd7  test linear passed with fx graph and oneflow graph (strint, Aug 16, 2023)
c45f52d  auto format by CI (oneflow-ci-bot, Aug 16, 2023)
167 changes: 167 additions & 0 deletions python/oneflow/test/graph/test_graph_interplay.py
@@ -0,0 +1,167 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import unittest
import numpy as np

import oneflow.unittest


def _test_relu(test_case, device):
    from typing import List
    import torch
    from torch._dynamo.backends.registry import register_backend
    from torch._dynamo.backends.common import fake_tensor_unsupported

    input_arr = np.array(
        [
            [-0.94630778, -0.83378579, -0.87060891],
            [2.0289922, -0.28708987, -2.18369248],
            [0.35217619, -0.67095644, -1.58943879],
            [0.08086036, -1.81075924, 1.20752494],
            [0.8901075, -0.49976737, -1.07153746],
            [-0.44872912, -1.07275683, 0.06256855],
            [-0.22556897, 0.74798368, 0.90416439],
            [0.48339456, -2.32742195, -0.59321527],
        ],
        dtype=np.float32,
    )
    x = torch.tensor(input_arr, device=device)
    eager_out = torch.relu(x)

    @register_backend
    @fake_tensor_unsupported
    def oneflowc(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        # TODO(): fx GraphModule to nn.Graph
        print("oneflowc() called with FX graph:")
        gm.graph.print_tabular()

        import oneflow as flow

        # Replay the FX graph inside a traced oneflow nn.Graph; mock_torch is
        # intended to make the torch calls inside gm resolve to oneflow ops.
        @flow.nn.Graph.trace
        def oneflow_fn(inputs):
            print("==>inputs ", inputs)
            with flow.mock_torch.enable(lazy=True):
                import torch

                outs = gm.forward(inputs)
strint (Contributor Author) commented:

This still ends up calling into torch:

File "<eval_with_key>.1", line 5, in forward
    relu = torch.relu(x);  x = None
TypeError: relu(): argument 'input' (position 1) must be Tensor, not Tensor

strint (Contributor Author) commented on Apr 27, 2023:

If we cannot get around this, we may have to rewrite the graph with fx instead of mocking torch (see the sketch after this test function).

Contributor commented:

Is the mock failing here because gm was not created inside the mock_torch scope, so the mock does not take effect?
                # outs = torch.relu(inputs)
            return outs

        oneflow_fn.debug(1)

        def from_to_torch(inputs):
            # Convert tensors at the torch/oneflow boundary on every call.
            flow_inputs = flow.utils.tensor.from_torch(inputs)
            flow_outs = oneflow_fn(flow_inputs)
            outs = flow.utils.tensor.to_torch(flow_outs)
            return (outs,)

        return from_to_torch

    @torch.compile(backend="oneflowc")
    def fn(x):
        y = torch.relu(x)
        return y

    compile_out = fn(x)
    test_case.assertTrue(
        np.allclose(
            compile_out.cpu().detach().numpy(),
            eager_out.cpu().detach().numpy(),
            1e-05,
            1e-05,
        )
    )
    # Run again to exercise the cached compiled function.
    compile_out = fn(x)
    test_case.assertTrue(
        np.allclose(
            compile_out.cpu().detach().numpy(),
            eager_out.cpu().detach().numpy(),
            1e-05,
            1e-05,
        )
    )
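The thread above proposes rewriting the graph with fx instead of mocking torch. A minimal sketch of that direction, assuming a hand-maintained op map; the helper rewrite_fx_to_oneflow and its target_map are hypothetical, not part of this PR:

import torch
import oneflow as flow


def rewrite_fx_to_oneflow(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Hypothetical mapping from torch callables to oneflow equivalents;
    # it would need one entry per supported op.
    target_map = {torch.relu: flow.relu}
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target in target_map:
            node.target = target_map[node.target]
    gm.graph.lint()
    gm.recompile()  # regenerate gm.forward from the mutated graph
    # Inputs would still need flow.utils.tensor.from_torch before calling
    # gm.forward, and to_torch on the way back.
    return gm

This sidesteps the scoping problem the thread points at, because the regenerated forward references oneflow callables directly rather than whatever the name torch resolves to at run time.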

def _test_linear(test_case, device):
    from typing import List
    import torch
    from torch._dynamo.backends.registry import register_backend
    from torch._dynamo.backends.common import fake_tensor_unsupported

    linear = torch.nn.Linear(3, 8, False)
    linear = linear.to(device)
    input_arr = np.array(
        [
            [-0.94630778, -0.83378579, -0.87060891],
            [2.0289922, -0.28708987, -2.18369248],
            [0.35217619, -0.67095644, -1.58943879],
            [0.08086036, -1.81075924, 1.20752494],
            [0.8901075, -0.49976737, -1.07153746],
            [-0.44872912, -1.07275683, 0.06256855],
            [-0.22556897, 0.74798368, 0.90416439],
            [0.48339456, -2.32742195, -0.59321527],
        ],
        dtype=np.float32,
    )
    x = torch.tensor(input_arr, device=device)
    torch.nn.init.constant_(linear.weight, 2.3)
    eager_out = linear(x)
strint (Contributor Author) commented on Aug 7, 2023:

Later, when bbuff tried the fx approach on a stateful module like linear, he ran into the complication that the parameter tensors also need to be shared with oneflow (a sketch of that sharing follows this test function).

    def get_of():
        # TODO(): transform torch fx code to oneflow code
        import oneflow as flow

        # Build an equivalent oneflow module with the same constant init,
        # then wrap it in a static nn.Graph.
        linear = flow.nn.Linear(3, 8, False)
        linear = linear.to(device)
        flow.nn.init.constant_(linear.weight, 2.3)

        class LinearGraph(flow.nn.Graph):
            def __init__(self):
                super().__init__()
                self.my_linear = linear

            def build(self, x):
                return self.my_linear(x)

        linear_g = LinearGraph()
        linear_g.debug(1)
        return linear_g

    g = None

    def torch_interplay(x):
        import oneflow as flow

        x = flow.utils.tensor.from_torch(x)
        nonlocal g
        if g is None:
            g = get_of()
        # TODO(): This is a special pack trick, try to make it general.
        return (flow.utils.tensor.to_torch(g(x)),)

    @register_backend
    @fake_tensor_unsupported
    def oneflowc(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        print("oneflowc() called with FX graph:")
        gm.graph.print_tabular()
        # Replace the compiled module's forward with the oneflow bridge;
        # a Dynamo backend only has to return a Python callable.
        gm.forward = torch_interplay
        return gm.forward

    @torch.compile(backend="oneflowc")
    def fn(x):
        y = linear(x)
        return y

    compile_out = fn(x)
    test_case.assertTrue(
        np.allclose(
            compile_out.cpu().detach().numpy(),
            eager_out.cpu().detach().numpy(),
            1e-05,
            1e-05,
        )
    )
    compile_out = fn(x)
    test_case.assertTrue(
        np.allclose(
            compile_out.cpu().detach().numpy(),
            eager_out.cpu().detach().numpy(),
            1e-05,
            1e-05,
        )
    )
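On the parameter-sharing complication raised in the Aug 7 comment: one option is to rebind the oneflow module's parameter to the converted torch weight rather than re-initializing both sides the way get_of() above does. A minimal sketch, assuming flow.utils.tensor.from_torch (already used in these tests) can wrap an existing parameter's data; this is illustrative, not this PR's code:

import torch
import oneflow as flow

torch_linear = torch.nn.Linear(3, 8, bias=False)
flow_linear = flow.nn.Linear(3, 8, bias=False)
# Rebind instead of re-initializing, so both modules compute with the
# same weight values (assumption: from_torch can wrap parameter data).
flow_linear.weight = flow.nn.Parameter(
    flow.utils.tensor.from_torch(torch_linear.weight.data)
)

Whether the two frameworks then share one buffer or hold separate copies depends on from_torch's semantics, which is exactly the difficulty the comment points at.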


@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
@oneflow.unittest.skip_unless_1n1d()
class TestAsTorchBackend(oneflow.unittest.TestCase):
    def test_relu(test_case):
        _test_relu(test_case, "cuda")

    def test_linear(test_case):
        _test_linear(test_case, "cuda")


if __name__ == "__main__":
    unittest.main()
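To run this test file directly (assuming a single-GPU machine with torch 2.x and this branch of oneflow installed):

python3 python/oneflow/test/graph/test_graph_interplay.py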