From bd31ce2b40f9102ae829887ac6ec69c6aaa45eab Mon Sep 17 00:00:00 2001 From: strint Date: Wed, 26 Apr 2023 23:45:57 +0800 Subject: [PATCH 01/27] graph interplay --- .../test/graph/test_graph_interplay.py | 144 ++++++++++++++++++ 1 file changed, 144 insertions(+) create mode 100644 python/oneflow/test/graph/test_graph_interplay.py diff --git a/python/oneflow/test/graph/test_graph_interplay.py b/python/oneflow/test/graph/test_graph_interplay.py new file mode 100644 index 00000000000..e233804aa0b --- /dev/null +++ b/python/oneflow/test/graph/test_graph_interplay.py @@ -0,0 +1,144 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import os +import unittest +import numpy as np + +import oneflow.unittest + + +def __test_linear(test_case, device): + import oneflow as flow + linear = flow.nn.Linear(3, 8, False) + linear = linear.to(device) + input_arr = np.array( + [ + [-0.94630778, -0.83378579, -0.87060891], + [2.0289922, -0.28708987, -2.18369248], + [0.35217619, -0.67095644, -1.58943879], + [0.08086036, -1.81075924, 1.20752494], + [0.8901075, -0.49976737, -1.07153746], + [-0.44872912, -1.07275683, 0.06256855], + [-0.22556897, 0.74798368, 0.90416439], + [0.48339456, -2.32742195, -0.59321527], + ], + dtype=np.float32, + ) + np_weight = np.ones((3, 8)).astype(np.float32) + np_weight.fill(2.3) + x = flow.tensor(input_arr, device=device) + flow.nn.init.constant_(linear.weight, 2.3) + of_eager_out = linear(x) + np_out = np.matmul(input_arr, np_weight) + test_case.assertTrue(np.allclose(of_eager_out.numpy(), np_out, 1e-05, 1e-05)) + + class LinearGraph(flow.nn.Graph): + def __init__(self): + super().__init__() + self.my_linear = linear + + def build(self, x): + return self.my_linear(x) + + linear_g = LinearGraph() + linear_g.debug(0) + of_lazy_out = linear_g(x) + test_case.assertTrue(np.array_equal(of_lazy_out.numpy(), of_eager_out.numpy())) + +def _test_linear(test_case, device): + from typing import List + import torch + + linear = torch.nn.Linear(3, 8, False) + linear = linear.to(device) + input_arr = np.array( + [ + [-0.94630778, -0.83378579, -0.87060891], + [2.0289922, -0.28708987, -2.18369248], + [0.35217619, -0.67095644, -1.58943879], + [0.08086036, -1.81075924, 1.20752494], + [0.8901075, -0.49976737, -1.07153746], + [-0.44872912, -1.07275683, 0.06256855], + [-0.22556897, 0.74798368, 0.90416439], + [0.48339456, -2.32742195, -0.59321527], + ], + dtype=np.float32, + ) + np_weight = np.ones((3, 8)).astype(np.float32) + np_weight.fill(2.3) + x = torch.tensor(input_arr, device=device) + torch.nn.init.constant_(linear.weight, 2.3) + eager_out = linear(x) + + np_out = np.matmul(input_arr, np_weight) + test_case.assertTrue(np.allclose(eager_out.cpu().detach().numpy(), np_out, 1e-05, 1e-05)) + + def get_of(): + # TODO(): transform torch fx code to oneflow code + import oneflow as flow + linear = flow.nn.Linear(3, 8, False) + linear = linear.to(device) + flow.nn.init.constant_(linear.weight, 2.3) + + class LinearGraph(flow.nn.Graph): + def __init__(self): + super().__init__() + self.my_linear = linear + + def build(self, x): + return self.my_linear(x) + + linear_g = LinearGraph() + linear_g.debug(1) + return linear_g + + g = None + + def torch_interplay(x): + import oneflow as flow + x = flow.utils.tensor.from_torch(x) + nonlocal g + if g is None: + g = get_of() + # TODO(): This is a special pack trick, try to make it general. + return (flow.utils.tensor.to_torch(g(x)),) + + + def of_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + print("my_compiler() called with FX graph:") + gm.graph.print_tabular() + gm.forward = torch_interplay + return gm.forward # return a python callable + + @torch.compile(backend=of_compiler) + def fn(x): + y = linear(x) + return y + + compile_out = fn(x) + test_case.assertTrue(np.allclose(compile_out.cpu().detach().numpy(), np_out, 1e-05, 1e-05)) + compile_out = fn(x) + test_case.assertTrue(np.allclose(compile_out.cpu().detach().numpy(), np_out, 1e-05, 1e-05)) + + +@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") +@oneflow.unittest.skip_unless_1n1d() +class TestLinear(oneflow.unittest.TestCase): + def test_linear_(test_case): + _test_linear(test_case, "cuda") + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From 3a52d51c77d5bd6eba8e58d1c404ca2b89a17f0d Mon Sep 17 00:00:00 2001 From: strint Date: Wed, 26 Apr 2023 23:51:55 +0800 Subject: [PATCH 02/27] rm unused --- .../test/graph/test_graph_interplay.py | 39 ------------------- 1 file changed, 39 deletions(-) diff --git a/python/oneflow/test/graph/test_graph_interplay.py b/python/oneflow/test/graph/test_graph_interplay.py index e233804aa0b..13c191f425c 100644 --- a/python/oneflow/test/graph/test_graph_interplay.py +++ b/python/oneflow/test/graph/test_graph_interplay.py @@ -19,45 +19,6 @@ import oneflow.unittest - -def __test_linear(test_case, device): - import oneflow as flow - linear = flow.nn.Linear(3, 8, False) - linear = linear.to(device) - input_arr = np.array( - [ - [-0.94630778, -0.83378579, -0.87060891], - [2.0289922, -0.28708987, -2.18369248], - [0.35217619, -0.67095644, -1.58943879], - [0.08086036, -1.81075924, 1.20752494], - [0.8901075, -0.49976737, -1.07153746], - [-0.44872912, -1.07275683, 0.06256855], - [-0.22556897, 0.74798368, 0.90416439], - [0.48339456, -2.32742195, -0.59321527], - ], - dtype=np.float32, - ) - np_weight = np.ones((3, 8)).astype(np.float32) - np_weight.fill(2.3) - x = flow.tensor(input_arr, device=device) - flow.nn.init.constant_(linear.weight, 2.3) - of_eager_out = linear(x) - np_out = np.matmul(input_arr, np_weight) - test_case.assertTrue(np.allclose(of_eager_out.numpy(), np_out, 1e-05, 1e-05)) - - class LinearGraph(flow.nn.Graph): - def __init__(self): - super().__init__() - self.my_linear = linear - - def build(self, x): - return self.my_linear(x) - - linear_g = LinearGraph() - linear_g.debug(0) - of_lazy_out = linear_g(x) - test_case.assertTrue(np.array_equal(of_lazy_out.numpy(), of_eager_out.numpy())) - def _test_linear(test_case, device): from typing import List import torch From e2761aa9c667b76c2a01121122bb7ef36ca7bfc4 Mon Sep 17 00:00:00 2001 From: strint Date: Wed, 26 Apr 2023 23:52:39 +0800 Subject: [PATCH 03/27] rm unused --- python/oneflow/test/graph/test_graph_interplay.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/oneflow/test/graph/test_graph_interplay.py b/python/oneflow/test/graph/test_graph_interplay.py index 13c191f425c..8429333db24 100644 --- a/python/oneflow/test/graph/test_graph_interplay.py +++ b/python/oneflow/test/graph/test_graph_interplay.py @@ -102,4 +102,4 @@ def test_linear_(test_case): _test_linear(test_case, "cuda") if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From 0eb42fd09eb20e68d415d4f044eefe1a37e97f1f Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 27 Apr 2023 14:45:35 +0800 Subject: [PATCH 04/27] add relu test --- .../test/graph/test_graph_interplay.py | 84 ++++++++++++++++--- 1 file changed, 73 insertions(+), 11 deletions(-) diff --git a/python/oneflow/test/graph/test_graph_interplay.py b/python/oneflow/test/graph/test_graph_interplay.py index 8429333db24..3a61c2510aa 100644 --- a/python/oneflow/test/graph/test_graph_interplay.py +++ b/python/oneflow/test/graph/test_graph_interplay.py @@ -19,9 +19,71 @@ import oneflow.unittest +def _test_relu(test_case, device): + from typing import List + import torch + from torch._dynamo.backends.registry import register_backend + from torch._dynamo.backends.common import fake_tensor_unsupported + + input_arr = np.array( + [ + [-0.94630778, -0.83378579, -0.87060891], + [2.0289922, -0.28708987, -2.18369248], + [0.35217619, -0.67095644, -1.58943879], + [0.08086036, -1.81075924, 1.20752494], + [0.8901075, -0.49976737, -1.07153746], + [-0.44872912, -1.07275683, 0.06256855], + [-0.22556897, 0.74798368, 0.90416439], + [0.48339456, -2.32742195, -0.59321527], + ], + dtype=np.float32, + ) + x = torch.tensor(input_arr, device=device) + eager_out = torch.relu(x) + + @register_backend + @fake_tensor_unsupported + def oneflowc(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + # TODO(): fxGraphModule to nn.Graph + print("my_compiler() called with FX graph:") + gm.graph.print_tabular() + + import oneflow as flow + + @flow.nn.Graph.trace + def oneflow_fn(inputs): + print("==>inputs ", inputs) + with flow.mock_torch.enable(lazy=True): + import torch + outs = gm.forward(inputs) + #outs = torch.relu(inputs) + return outs + + oneflow_fn.debug(1) + + def from_to_torch(inputs): + flow_inputs = flow.utils.tensor.from_torch(inputs) + flow_outs = oneflow_fn(flow_inputs) + outs = flow.utils.tensor.to_torch(flow_outs) + return (outs, ) + + return from_to_torch + + @torch.compile(backend='oneflowc') + def fn(x): + y = torch.relu(x) + return y + + compile_out = fn(x) + test_case.assertTrue(np.allclose(compile_out.cpu().detach().numpy(), eager_out.cpu().detach().numpy(), 1e-05, 1e-05)) + compile_out = fn(x) + test_case.assertTrue(np.allclose(compile_out.cpu().detach().numpy(), eager_out.cpu().detach().numpy(), 1e-05, 1e-05)) + def _test_linear(test_case, device): from typing import List import torch + from torch._dynamo.backends.registry import register_backend + from torch._dynamo.backends.common import fake_tensor_unsupported linear = torch.nn.Linear(3, 8, False) linear = linear.to(device) @@ -38,15 +100,10 @@ def _test_linear(test_case, device): ], dtype=np.float32, ) - np_weight = np.ones((3, 8)).astype(np.float32) - np_weight.fill(2.3) x = torch.tensor(input_arr, device=device) torch.nn.init.constant_(linear.weight, 2.3) eager_out = linear(x) - np_out = np.matmul(input_arr, np_weight) - test_case.assertTrue(np.allclose(eager_out.cpu().detach().numpy(), np_out, 1e-05, 1e-05)) - def get_of(): # TODO(): transform torch fx code to oneflow code import oneflow as flow @@ -78,27 +135,32 @@ def torch_interplay(x): return (flow.utils.tensor.to_torch(g(x)),) - def of_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + @register_backend + @fake_tensor_unsupported + def oneflowc(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): print("my_compiler() called with FX graph:") gm.graph.print_tabular() gm.forward = torch_interplay return gm.forward # return a python callable - @torch.compile(backend=of_compiler) + @torch.compile(backend='oneflowc') def fn(x): y = linear(x) return y compile_out = fn(x) - test_case.assertTrue(np.allclose(compile_out.cpu().detach().numpy(), np_out, 1e-05, 1e-05)) + test_case.assertTrue(np.allclose(compile_out.cpu().detach().numpy(), eager_out.cpu().detach().numpy(), 1e-05, 1e-05)) compile_out = fn(x) - test_case.assertTrue(np.allclose(compile_out.cpu().detach().numpy(), np_out, 1e-05, 1e-05)) + test_case.assertTrue(np.allclose(compile_out.cpu().detach().numpy(), eager_out.cpu().detach().numpy(), 1e-05, 1e-05)) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @oneflow.unittest.skip_unless_1n1d() -class TestLinear(oneflow.unittest.TestCase): - def test_linear_(test_case): +class TestAsTorchBackend(oneflow.unittest.TestCase): + def test_relu(test_case): + _test_relu(test_case, "cuda") + + def test_linear(test_case): _test_linear(test_case, "cuda") if __name__ == "__main__": From 0191d98ffb1d437315a343db1a1e8b2a977ac8af Mon Sep 17 00:00:00 2001 From: oneflow-ci-bot Date: Thu, 27 Apr 2023 06:49:59 +0000 Subject: [PATCH 05/27] auto format by CI --- .../test/graph/test_graph_interplay.py | 59 +++++++++++++++---- 1 file changed, 46 insertions(+), 13 deletions(-) diff --git a/python/oneflow/test/graph/test_graph_interplay.py b/python/oneflow/test/graph/test_graph_interplay.py index 3a61c2510aa..caae8281d27 100644 --- a/python/oneflow/test/graph/test_graph_interplay.py +++ b/python/oneflow/test/graph/test_graph_interplay.py @@ -19,6 +19,7 @@ import oneflow.unittest + def _test_relu(test_case, device): from typing import List import torch @@ -55,29 +56,45 @@ def oneflow_fn(inputs): print("==>inputs ", inputs) with flow.mock_torch.enable(lazy=True): import torch + outs = gm.forward(inputs) - #outs = torch.relu(inputs) + # outs = torch.relu(inputs) return outs oneflow_fn.debug(1) - + def from_to_torch(inputs): flow_inputs = flow.utils.tensor.from_torch(inputs) flow_outs = oneflow_fn(flow_inputs) outs = flow.utils.tensor.to_torch(flow_outs) - return (outs, ) + return (outs,) return from_to_torch - @torch.compile(backend='oneflowc') + @torch.compile(backend="oneflowc") def fn(x): y = torch.relu(x) - return y + return y compile_out = fn(x) - test_case.assertTrue(np.allclose(compile_out.cpu().detach().numpy(), eager_out.cpu().detach().numpy(), 1e-05, 1e-05)) + test_case.assertTrue( + np.allclose( + compile_out.cpu().detach().numpy(), + eager_out.cpu().detach().numpy(), + 1e-05, + 1e-05, + ) + ) compile_out = fn(x) - test_case.assertTrue(np.allclose(compile_out.cpu().detach().numpy(), eager_out.cpu().detach().numpy(), 1e-05, 1e-05)) + test_case.assertTrue( + np.allclose( + compile_out.cpu().detach().numpy(), + eager_out.cpu().detach().numpy(), + 1e-05, + 1e-05, + ) + ) + def _test_linear(test_case, device): from typing import List @@ -107,6 +124,7 @@ def _test_linear(test_case, device): def get_of(): # TODO(): transform torch fx code to oneflow code import oneflow as flow + linear = flow.nn.Linear(3, 8, False) linear = linear.to(device) flow.nn.init.constant_(linear.weight, 2.3) @@ -122,11 +140,12 @@ def build(self, x): linear_g = LinearGraph() linear_g.debug(1) return linear_g - + g = None def torch_interplay(x): import oneflow as flow + x = flow.utils.tensor.from_torch(x) nonlocal g if g is None: @@ -134,7 +153,6 @@ def torch_interplay(x): # TODO(): This is a special pack trick, try to make it general. return (flow.utils.tensor.to_torch(g(x)),) - @register_backend @fake_tensor_unsupported def oneflowc(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): @@ -143,15 +161,29 @@ def oneflowc(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): gm.forward = torch_interplay return gm.forward # return a python callable - @torch.compile(backend='oneflowc') + @torch.compile(backend="oneflowc") def fn(x): y = linear(x) - return y + return y compile_out = fn(x) - test_case.assertTrue(np.allclose(compile_out.cpu().detach().numpy(), eager_out.cpu().detach().numpy(), 1e-05, 1e-05)) + test_case.assertTrue( + np.allclose( + compile_out.cpu().detach().numpy(), + eager_out.cpu().detach().numpy(), + 1e-05, + 1e-05, + ) + ) compile_out = fn(x) - test_case.assertTrue(np.allclose(compile_out.cpu().detach().numpy(), eager_out.cpu().detach().numpy(), 1e-05, 1e-05)) + test_case.assertTrue( + np.allclose( + compile_out.cpu().detach().numpy(), + eager_out.cpu().detach().numpy(), + 1e-05, + 1e-05, + ) + ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @@ -163,5 +195,6 @@ def test_relu(test_case): def test_linear(test_case): _test_linear(test_case, "cuda") + if __name__ == "__main__": unittest.main() From 1c0e5403af884d1a968886299b0cb2191c7926d9 Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 27 Apr 2023 15:52:30 +0800 Subject: [PATCH 06/27] mvp passed --- .../test/graph/test_graph_interplay.py | 31 +++++++++++++++---- 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/python/oneflow/test/graph/test_graph_interplay.py b/python/oneflow/test/graph/test_graph_interplay.py index 3a61c2510aa..f491d0970d1 100644 --- a/python/oneflow/test/graph/test_graph_interplay.py +++ b/python/oneflow/test/graph/test_graph_interplay.py @@ -24,6 +24,7 @@ def _test_relu(test_case, device): import torch from torch._dynamo.backends.registry import register_backend from torch._dynamo.backends.common import fake_tensor_unsupported + from torch import fx input_arr = np.array( [ @@ -41,22 +42,39 @@ def _test_relu(test_case, device): x = torch.tensor(input_arr, device=device) eager_out = torch.relu(x) + def to_of_transform(gm: torch.fx.GraphModule, + tracer_class : type = fx.Tracer) -> torch.fx.GraphModule: + for node in gm.graph.nodes: + # Checks if we're calling a function (i.e: + # torch.add) + if node.op == 'call_function': + # The target attribute is the function + # that call_function calls. + if node.target == torch.relu: + node.target = oneflow.relu + + gm.graph.lint() + gm.recompile() + return gm + @register_backend @fake_tensor_unsupported def oneflowc(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): # TODO(): fxGraphModule to nn.Graph print("my_compiler() called with FX graph:") gm.graph.print_tabular() + print("gm ", gm) import oneflow as flow + of_gm = to_of_transform(gm) + @flow.nn.Graph.trace def oneflow_fn(inputs): - print("==>inputs ", inputs) - with flow.mock_torch.enable(lazy=True): - import torch - outs = gm.forward(inputs) - #outs = torch.relu(inputs) + # with flow.mock_torch.enable(lazy=True): + # import torch + # outs = torch.relu(inputs) + outs = of_gm.forward(inputs) return outs oneflow_fn.debug(1) @@ -64,7 +82,8 @@ def oneflow_fn(inputs): def from_to_torch(inputs): flow_inputs = flow.utils.tensor.from_torch(inputs) flow_outs = oneflow_fn(flow_inputs) - outs = flow.utils.tensor.to_torch(flow_outs) + # TODO(): general output process + outs = flow.utils.tensor.to_torch(flow_outs[0]) return (outs, ) return from_to_torch From c7e8aaad04b1bef852b3dba1da8fc50b061fb7cb Mon Sep 17 00:00:00 2001 From: oneflow-ci-bot Date: Thu, 27 Apr 2023 07:57:35 +0000 Subject: [PATCH 07/27] auto format by CI --- python/oneflow/test/graph/test_graph_interplay.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/oneflow/test/graph/test_graph_interplay.py b/python/oneflow/test/graph/test_graph_interplay.py index d15d58c82b5..79ed4e4acda 100644 --- a/python/oneflow/test/graph/test_graph_interplay.py +++ b/python/oneflow/test/graph/test_graph_interplay.py @@ -43,12 +43,13 @@ def _test_relu(test_case, device): x = torch.tensor(input_arr, device=device) eager_out = torch.relu(x) - def to_of_transform(gm: torch.fx.GraphModule, - tracer_class : type = fx.Tracer) -> torch.fx.GraphModule: + def to_of_transform( + gm: torch.fx.GraphModule, tracer_class: type = fx.Tracer + ) -> torch.fx.GraphModule: for node in gm.graph.nodes: # Checks if we're calling a function (i.e: # torch.add) - if node.op == 'call_function': + if node.op == "call_function": # The target attribute is the function # that call_function calls. if node.target == torch.relu: @@ -85,7 +86,7 @@ def from_to_torch(inputs): flow_outs = oneflow_fn(flow_inputs) # TODO(): general output process outs = flow.utils.tensor.to_torch(flow_outs[0]) - return (outs, ) + return (outs,) return from_to_torch From 0498b8e21b8176ab82d721467b14f998bf8fec6b Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 27 Apr 2023 16:16:29 +0800 Subject: [PATCH 08/27] format mvp --- .../test/graph/test_graph_interplay.py | 50 +----------------- python/oneflow/utils/backend/__init__.py | 22 ++++++++ python/oneflow/utils/backend/torch_compile.py | 52 +++++++++++++++++++ .../utils/backend/torch_fx_to_oneflow.py | 33 ++++++++++++ 4 files changed, 108 insertions(+), 49 deletions(-) create mode 100644 python/oneflow/utils/backend/__init__.py create mode 100644 python/oneflow/utils/backend/torch_compile.py create mode 100644 python/oneflow/utils/backend/torch_fx_to_oneflow.py diff --git a/python/oneflow/test/graph/test_graph_interplay.py b/python/oneflow/test/graph/test_graph_interplay.py index d15d58c82b5..3f358042839 100644 --- a/python/oneflow/test/graph/test_graph_interplay.py +++ b/python/oneflow/test/graph/test_graph_interplay.py @@ -23,9 +23,7 @@ def _test_relu(test_case, device): from typing import List import torch - from torch._dynamo.backends.registry import register_backend - from torch._dynamo.backends.common import fake_tensor_unsupported - from torch import fx + from oneflow.utils.backend.torch_compile import register_oneflowc input_arr = np.array( [ @@ -43,52 +41,6 @@ def _test_relu(test_case, device): x = torch.tensor(input_arr, device=device) eager_out = torch.relu(x) - def to_of_transform(gm: torch.fx.GraphModule, - tracer_class : type = fx.Tracer) -> torch.fx.GraphModule: - for node in gm.graph.nodes: - # Checks if we're calling a function (i.e: - # torch.add) - if node.op == 'call_function': - # The target attribute is the function - # that call_function calls. - if node.target == torch.relu: - node.target = oneflow.relu - - gm.graph.lint() - gm.recompile() - return gm - - @register_backend - @fake_tensor_unsupported - def oneflowc(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): - # TODO(): fxGraphModule to nn.Graph - print("my_compiler() called with FX graph:") - gm.graph.print_tabular() - print("gm ", gm) - - import oneflow as flow - - of_gm = to_of_transform(gm) - - @flow.nn.Graph.trace - def oneflow_fn(inputs): - # with flow.mock_torch.enable(lazy=True): - # import torch - # outs = torch.relu(inputs) - outs = of_gm.forward(inputs) - return outs - - oneflow_fn.debug(1) - - def from_to_torch(inputs): - flow_inputs = flow.utils.tensor.from_torch(inputs) - flow_outs = oneflow_fn(flow_inputs) - # TODO(): general output process - outs = flow.utils.tensor.to_torch(flow_outs[0]) - return (outs, ) - - return from_to_torch - @torch.compile(backend="oneflowc") def fn(x): y = torch.relu(x) diff --git a/python/oneflow/utils/backend/__init__.py b/python/oneflow/utils/backend/__init__.py new file mode 100644 index 00000000000..c956130bfd1 --- /dev/null +++ b/python/oneflow/utils/backend/__init__.py @@ -0,0 +1,22 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +from oneflow.utils.backend.torch_compile import register_oneflowc + +register_oneflowc() + +__all__ = [ + "register_oneflowc", +] \ No newline at end of file diff --git a/python/oneflow/utils/backend/torch_compile.py b/python/oneflow/utils/backend/torch_compile.py new file mode 100644 index 00000000000..7d00e60c90c --- /dev/null +++ b/python/oneflow/utils/backend/torch_compile.py @@ -0,0 +1,52 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import oneflow +from .torch_fx_to_oneflow import to_of_transform + +def register_oneflowc(): + from typing import List + import torch + from torch import fx + from torch._dynamo.backends.registry import register_backend + from torch._dynamo.backends.common import fake_tensor_unsupported + + @register_backend + @fake_tensor_unsupported + def oneflowc(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + # TODO(): fxGraphModule to nn.Graph + print("my_compiler() called with FX graph:") + gm.graph.print_tabular() + print("gm ", gm) + + import oneflow as flow + + of_gm = to_of_transform(gm) + + @flow.nn.Graph.trace + def oneflow_fn(inputs): + outs = of_gm.forward(inputs) + return outs + + oneflow_fn.debug(1) + + def from_to_torch(inputs): + flow_inputs = flow.utils.tensor.from_torch(inputs) + flow_outs = oneflow_fn(flow_inputs) + # TODO(): general output process + outs = flow.utils.tensor.to_torch(flow_outs[0]) + return (outs, ) + + return from_to_torch \ No newline at end of file diff --git a/python/oneflow/utils/backend/torch_fx_to_oneflow.py b/python/oneflow/utils/backend/torch_fx_to_oneflow.py new file mode 100644 index 00000000000..92a3ed90ea1 --- /dev/null +++ b/python/oneflow/utils/backend/torch_fx_to_oneflow.py @@ -0,0 +1,33 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import oneflow +import torch +from torch import fx + +def to_of_transform(gm: torch.fx.GraphModule, + tracer_class : type = fx.Tracer) -> torch.fx.GraphModule: + for node in gm.graph.nodes: + # Checks if we're calling a function (i.e: + # torch.add) + if node.op == 'call_function': + # The target attribute is the function + # that call_function calls. + if node.target == torch.relu: + node.target = oneflow.relu + + gm.graph.lint() + gm.recompile() + return gm \ No newline at end of file From 459a618b5b9bb9fc44889b508195a1e6dcaed4b1 Mon Sep 17 00:00:00 2001 From: Xiaoyu Xu Date: Thu, 27 Apr 2023 16:19:35 +0800 Subject: [PATCH 09/27] Update torch_fx_to_oneflow.py --- python/oneflow/utils/backend/torch_fx_to_oneflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/oneflow/utils/backend/torch_fx_to_oneflow.py b/python/oneflow/utils/backend/torch_fx_to_oneflow.py index 92a3ed90ea1..9d9fc3dae32 100644 --- a/python/oneflow/utils/backend/torch_fx_to_oneflow.py +++ b/python/oneflow/utils/backend/torch_fx_to_oneflow.py @@ -30,4 +30,4 @@ def to_of_transform(gm: torch.fx.GraphModule, gm.graph.lint() gm.recompile() - return gm \ No newline at end of file + return gm From 00bea12b390c67fa39e56c4edade62bcf1e08d19 Mon Sep 17 00:00:00 2001 From: Xiaoyu Xu Date: Thu, 27 Apr 2023 16:20:05 +0800 Subject: [PATCH 10/27] Update torch_compile.py --- python/oneflow/utils/backend/torch_compile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/oneflow/utils/backend/torch_compile.py b/python/oneflow/utils/backend/torch_compile.py index 7d00e60c90c..d788f73779f 100644 --- a/python/oneflow/utils/backend/torch_compile.py +++ b/python/oneflow/utils/backend/torch_compile.py @@ -49,4 +49,4 @@ def from_to_torch(inputs): outs = flow.utils.tensor.to_torch(flow_outs[0]) return (outs, ) - return from_to_torch \ No newline at end of file + return from_to_torch From f88f9f2cbc6cc3376a6fe4d1732629f1e24d1ce3 Mon Sep 17 00:00:00 2001 From: Xiaoyu Xu Date: Thu, 27 Apr 2023 16:20:20 +0800 Subject: [PATCH 11/27] Update __init__.py --- python/oneflow/utils/backend/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/oneflow/utils/backend/__init__.py b/python/oneflow/utils/backend/__init__.py index c956130bfd1..5a71d879aea 100644 --- a/python/oneflow/utils/backend/__init__.py +++ b/python/oneflow/utils/backend/__init__.py @@ -19,4 +19,4 @@ __all__ = [ "register_oneflowc", -] \ No newline at end of file +] From 547358851e818ef3e75c93cdf0399bd1d97bfe31 Mon Sep 17 00:00:00 2001 From: oneflow-ci-bot Date: Thu, 27 Apr 2023 08:22:14 +0000 Subject: [PATCH 12/27] auto format by CI --- python/oneflow/utils/backend/torch_compile.py | 3 ++- python/oneflow/utils/backend/torch_fx_to_oneflow.py | 8 +++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/python/oneflow/utils/backend/torch_compile.py b/python/oneflow/utils/backend/torch_compile.py index d788f73779f..abd899e66c7 100644 --- a/python/oneflow/utils/backend/torch_compile.py +++ b/python/oneflow/utils/backend/torch_compile.py @@ -16,6 +16,7 @@ import oneflow from .torch_fx_to_oneflow import to_of_transform + def register_oneflowc(): from typing import List import torch @@ -47,6 +48,6 @@ def from_to_torch(inputs): flow_outs = oneflow_fn(flow_inputs) # TODO(): general output process outs = flow.utils.tensor.to_torch(flow_outs[0]) - return (outs, ) + return (outs,) return from_to_torch diff --git a/python/oneflow/utils/backend/torch_fx_to_oneflow.py b/python/oneflow/utils/backend/torch_fx_to_oneflow.py index 9d9fc3dae32..c8faa85e189 100644 --- a/python/oneflow/utils/backend/torch_fx_to_oneflow.py +++ b/python/oneflow/utils/backend/torch_fx_to_oneflow.py @@ -17,12 +17,14 @@ import torch from torch import fx -def to_of_transform(gm: torch.fx.GraphModule, - tracer_class : type = fx.Tracer) -> torch.fx.GraphModule: + +def to_of_transform( + gm: torch.fx.GraphModule, tracer_class: type = fx.Tracer +) -> torch.fx.GraphModule: for node in gm.graph.nodes: # Checks if we're calling a function (i.e: # torch.add) - if node.op == 'call_function': + if node.op == "call_function": # The target attribute is the function # that call_function calls. if node.target == torch.relu: From 7feeec7dfe296e1e09e8a5267937cfb4b23d8d72 Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 27 Apr 2023 18:14:38 +0800 Subject: [PATCH 13/27] rename --- .../test/graph/test_graph_interplay.py | 57 +++++++++++++++++-- python/oneflow/utils/backend/__init__.py | 6 +- python/oneflow/utils/backend/torch_compile.py | 4 +- 3 files changed, 57 insertions(+), 10 deletions(-) diff --git a/python/oneflow/test/graph/test_graph_interplay.py b/python/oneflow/test/graph/test_graph_interplay.py index 3f358042839..1b5da69c9fb 100644 --- a/python/oneflow/test/graph/test_graph_interplay.py +++ b/python/oneflow/test/graph/test_graph_interplay.py @@ -23,7 +23,7 @@ def _test_relu(test_case, device): from typing import List import torch - from oneflow.utils.backend.torch_compile import register_oneflowc + from oneflow.utils.backend.torch_compile import register_oneflowrt input_arr = np.array( [ @@ -41,7 +41,7 @@ def _test_relu(test_case, device): x = torch.tensor(input_arr, device=device) eager_out = torch.relu(x) - @torch.compile(backend="oneflowc") + @torch.compile(backend="oneflowrt") def fn(x): y = torch.relu(x) return y @@ -65,8 +65,55 @@ def fn(x): ) ) - def _test_linear(test_case, device): + from typing import List + import torch + from oneflow.utils.backend.torch_compile import register_oneflowrt + + linear = torch.nn.Linear(3, 8, False) + linear = linear.to(device) + input_arr = np.array( + [ + [-0.94630778, -0.83378579, -0.87060891], + [2.0289922, -0.28708987, -2.18369248], + [0.35217619, -0.67095644, -1.58943879], + [0.08086036, -1.81075924, 1.20752494], + [0.8901075, -0.49976737, -1.07153746], + [-0.44872912, -1.07275683, 0.06256855], + [-0.22556897, 0.74798368, 0.90416439], + [0.48339456, -2.32742195, -0.59321527], + ], + dtype=np.float32, + ) + x = torch.tensor(input_arr, device=device) + torch.nn.init.constant_(linear.weight, 2.3) + eager_out = linear(x) + + @torch.compile(backend="oneflowrt") + def fn(x): + y = linear(x) + return y + + compile_out = fn(x) + test_case.assertTrue( + np.allclose( + compile_out.cpu().detach().numpy(), + eager_out.cpu().detach().numpy(), + 1e-05, + 1e-05, + ) + ) + compile_out = fn(x) + test_case.assertTrue( + np.allclose( + compile_out.cpu().detach().numpy(), + eager_out.cpu().detach().numpy(), + 1e-05, + 1e-05, + ) + ) + +def __test_linear(test_case, device): from typing import List import torch from torch._dynamo.backends.registry import register_backend @@ -125,13 +172,13 @@ def torch_interplay(x): @register_backend @fake_tensor_unsupported - def oneflowc(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + def oneflowrt(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): print("my_compiler() called with FX graph:") gm.graph.print_tabular() gm.forward = torch_interplay return gm.forward # return a python callable - @torch.compile(backend="oneflowc") + @torch.compile(backend="oneflowrt") def fn(x): y = linear(x) return y diff --git a/python/oneflow/utils/backend/__init__.py b/python/oneflow/utils/backend/__init__.py index 5a71d879aea..614f4a0284f 100644 --- a/python/oneflow/utils/backend/__init__.py +++ b/python/oneflow/utils/backend/__init__.py @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. """ -from oneflow.utils.backend.torch_compile import register_oneflowc +from oneflow.utils.backend.torch_compile import register_oneflowrt -register_oneflowc() +register_oneflowrt() __all__ = [ - "register_oneflowc", + "register_oneflowrt", ] diff --git a/python/oneflow/utils/backend/torch_compile.py b/python/oneflow/utils/backend/torch_compile.py index abd899e66c7..e9642c70449 100644 --- a/python/oneflow/utils/backend/torch_compile.py +++ b/python/oneflow/utils/backend/torch_compile.py @@ -17,7 +17,7 @@ from .torch_fx_to_oneflow import to_of_transform -def register_oneflowc(): +def register_oneflowrt(): from typing import List import torch from torch import fx @@ -26,7 +26,7 @@ def register_oneflowc(): @register_backend @fake_tensor_unsupported - def oneflowc(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + def oneflowrt(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): # TODO(): fxGraphModule to nn.Graph print("my_compiler() called with FX graph:") gm.graph.print_tabular() From 6117e5c0aec18bd98d9fd74c62c782f300804668 Mon Sep 17 00:00:00 2001 From: oneflow-ci-bot Date: Thu, 27 Apr 2023 10:17:47 +0000 Subject: [PATCH 14/27] auto format by CI --- python/oneflow/test/graph/test_graph_interplay.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/oneflow/test/graph/test_graph_interplay.py b/python/oneflow/test/graph/test_graph_interplay.py index 1b5da69c9fb..19a46e04933 100644 --- a/python/oneflow/test/graph/test_graph_interplay.py +++ b/python/oneflow/test/graph/test_graph_interplay.py @@ -65,6 +65,7 @@ def fn(x): ) ) + def _test_linear(test_case, device): from typing import List import torch @@ -113,6 +114,7 @@ def fn(x): ) ) + def __test_linear(test_case, device): from typing import List import torch From 940fc5b499eef869b33c9a56cc7f90eabeded8f5 Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 27 Apr 2023 18:56:41 +0800 Subject: [PATCH 15/27] rename and ctrl graph --- .../test/graph/test_graph_interplay.py | 9 ++++---- python/oneflow/utils/backend/__init__.py | 6 +++--- python/oneflow/utils/backend/torch_compile.py | 21 ++++++++++++------- 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/python/oneflow/test/graph/test_graph_interplay.py b/python/oneflow/test/graph/test_graph_interplay.py index 1b5da69c9fb..224601d888b 100644 --- a/python/oneflow/test/graph/test_graph_interplay.py +++ b/python/oneflow/test/graph/test_graph_interplay.py @@ -23,7 +23,7 @@ def _test_relu(test_case, device): from typing import List import torch - from oneflow.utils.backend.torch_compile import register_oneflowrt + from oneflow.utils.backend.torch_compile import register_ofrt input_arr = np.array( [ @@ -41,7 +41,8 @@ def _test_relu(test_case, device): x = torch.tensor(input_arr, device=device) eager_out = torch.relu(x) - @torch.compile(backend="oneflowrt") + os.environ["ofrt_enable_graph"] = "1" + @torch.compile(backend="ofrt") def fn(x): y = torch.relu(x) return y @@ -68,7 +69,7 @@ def fn(x): def _test_linear(test_case, device): from typing import List import torch - from oneflow.utils.backend.torch_compile import register_oneflowrt + from oneflow.utils.backend.torch_compile import register_ofrt linear = torch.nn.Linear(3, 8, False) linear = linear.to(device) @@ -89,7 +90,7 @@ def _test_linear(test_case, device): torch.nn.init.constant_(linear.weight, 2.3) eager_out = linear(x) - @torch.compile(backend="oneflowrt") + @torch.compile(backend="ofrt") def fn(x): y = linear(x) return y diff --git a/python/oneflow/utils/backend/__init__.py b/python/oneflow/utils/backend/__init__.py index 614f4a0284f..fefd352fa2a 100644 --- a/python/oneflow/utils/backend/__init__.py +++ b/python/oneflow/utils/backend/__init__.py @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. """ -from oneflow.utils.backend.torch_compile import register_oneflowrt +from oneflow.utils.backend.torch_compile import register_ofrt -register_oneflowrt() +register_ofrt() __all__ = [ - "register_oneflowrt", + "register_ofrt", ] diff --git a/python/oneflow/utils/backend/torch_compile.py b/python/oneflow/utils/backend/torch_compile.py index e9642c70449..993bb2f8386 100644 --- a/python/oneflow/utils/backend/torch_compile.py +++ b/python/oneflow/utils/backend/torch_compile.py @@ -13,12 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. """ +import os import oneflow from .torch_fx_to_oneflow import to_of_transform -def register_oneflowrt(): - from typing import List +def register_ofrt(): + from typing import List, Optional, Dict, Any import torch from torch import fx from torch._dynamo.backends.registry import register_backend @@ -26,7 +27,7 @@ def register_oneflowrt(): @register_backend @fake_tensor_unsupported - def oneflowrt(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + def ofrt(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): # TODO(): fxGraphModule to nn.Graph print("my_compiler() called with FX graph:") gm.graph.print_tabular() @@ -36,12 +37,16 @@ def oneflowrt(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): of_gm = to_of_transform(gm) - @flow.nn.Graph.trace - def oneflow_fn(inputs): - outs = of_gm.forward(inputs) - return outs + enable_graph = bool(os.environ.get("ofrt_enable_graph", False)) - oneflow_fn.debug(1) + if not enable_graph: + oneflow_fn = of_gm.forward + else: + @flow.nn.Graph.trace + def oneflow_fn(inputs): + outs = of_gm.forward(inputs) + return outs + oneflow_fn.debug(1) def from_to_torch(inputs): flow_inputs = flow.utils.tensor.from_torch(inputs) From 0725f116e6ff5369dc57288ecadd0c9887a104ca Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 27 Apr 2023 19:02:50 +0800 Subject: [PATCH 16/27] format --- oneflow/extension/stack/stacktrace.h | 2 +- python/oneflow/test/graph/test_graph_interplay.py | 1 + python/oneflow/utils/backend/torch_compile.py | 2 ++ 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/oneflow/extension/stack/stacktrace.h b/oneflow/extension/stack/stacktrace.h index 5e53450f458..e2f5ee8988d 100644 --- a/oneflow/extension/stack/stacktrace.h +++ b/oneflow/extension/stack/stacktrace.h @@ -1123,7 +1123,7 @@ class StackTraceImpl : public StackTraceImplHolder { if (context()) { ucontext_t* uctx = reinterpret_cast(context()); -#ifdef REG_RIP // x86_64 +#ifdef REG_RIP // x86_64 if (uctx->uc_mcontext.gregs[REG_RIP] == reinterpret_cast(error_addr())) { uctx->uc_mcontext.gregs[REG_RIP] = *reinterpret_cast(uctx->uc_mcontext.gregs[REG_RSP]); diff --git a/python/oneflow/test/graph/test_graph_interplay.py b/python/oneflow/test/graph/test_graph_interplay.py index b8de07cc0b0..ed3fdf3d942 100644 --- a/python/oneflow/test/graph/test_graph_interplay.py +++ b/python/oneflow/test/graph/test_graph_interplay.py @@ -42,6 +42,7 @@ def _test_relu(test_case, device): eager_out = torch.relu(x) os.environ["ofrt_enable_graph"] = "1" + @torch.compile(backend="ofrt") def fn(x): y = torch.relu(x) diff --git a/python/oneflow/utils/backend/torch_compile.py b/python/oneflow/utils/backend/torch_compile.py index 993bb2f8386..bd77e4589e5 100644 --- a/python/oneflow/utils/backend/torch_compile.py +++ b/python/oneflow/utils/backend/torch_compile.py @@ -42,10 +42,12 @@ def ofrt(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): if not enable_graph: oneflow_fn = of_gm.forward else: + @flow.nn.Graph.trace def oneflow_fn(inputs): outs = of_gm.forward(inputs) return outs + oneflow_fn.debug(1) def from_to_torch(inputs): From 4e3fc4ae2bb9362f41a24d770110debdad8da45a Mon Sep 17 00:00:00 2001 From: oneflow-ci-bot Date: Thu, 27 Apr 2023 11:05:19 +0000 Subject: [PATCH 17/27] auto format by CI --- oneflow/extension/stack/stacktrace.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/extension/stack/stacktrace.h b/oneflow/extension/stack/stacktrace.h index e2f5ee8988d..5e53450f458 100644 --- a/oneflow/extension/stack/stacktrace.h +++ b/oneflow/extension/stack/stacktrace.h @@ -1123,7 +1123,7 @@ class StackTraceImpl : public StackTraceImplHolder { if (context()) { ucontext_t* uctx = reinterpret_cast(context()); -#ifdef REG_RIP // x86_64 +#ifdef REG_RIP // x86_64 if (uctx->uc_mcontext.gregs[REG_RIP] == reinterpret_cast(error_addr())) { uctx->uc_mcontext.gregs[REG_RIP] = *reinterpret_cast(uctx->uc_mcontext.gregs[REG_RSP]); From 0c697c51786b1b806442f3398c223e98b9718de1 Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 27 Apr 2023 19:09:36 +0800 Subject: [PATCH 18/27] refine enable grpah --- python/oneflow/utils/backend/torch_compile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/oneflow/utils/backend/torch_compile.py b/python/oneflow/utils/backend/torch_compile.py index bd77e4589e5..1e4e35d65c9 100644 --- a/python/oneflow/utils/backend/torch_compile.py +++ b/python/oneflow/utils/backend/torch_compile.py @@ -37,7 +37,7 @@ def ofrt(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): of_gm = to_of_transform(gm) - enable_graph = bool(os.environ.get("ofrt_enable_graph", False)) + enable_graph = os.getenv("ofrt_enable_graph", 'False').lower() in ('true', '1', 't') if not enable_graph: oneflow_fn = of_gm.forward From 021fde915a0d21fa1ecbc93a2e1d5ca8fefc63c5 Mon Sep 17 00:00:00 2001 From: oneflow-ci-bot Date: Thu, 27 Apr 2023 11:12:47 +0000 Subject: [PATCH 19/27] auto format by CI --- python/oneflow/utils/backend/torch_compile.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/oneflow/utils/backend/torch_compile.py b/python/oneflow/utils/backend/torch_compile.py index 1e4e35d65c9..b16d1c2e584 100644 --- a/python/oneflow/utils/backend/torch_compile.py +++ b/python/oneflow/utils/backend/torch_compile.py @@ -37,7 +37,11 @@ def ofrt(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): of_gm = to_of_transform(gm) - enable_graph = os.getenv("ofrt_enable_graph", 'False').lower() in ('true', '1', 't') + enable_graph = os.getenv("ofrt_enable_graph", "False").lower() in ( + "true", + "1", + "t", + ) if not enable_graph: oneflow_fn = of_gm.forward From 5ba20c236947862d68659678c0f1dea51b597b84 Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 11 Aug 2023 12:53:01 +0000 Subject: [PATCH 20/27] split fx transform --- python/oneflow/utils/backend/torch_compile.py | 23 ++----------------- .../utils/backend/torch_fx_to_oneflow.py | 23 +++++++++++++++++++ 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/python/oneflow/utils/backend/torch_compile.py b/python/oneflow/utils/backend/torch_compile.py index b16d1c2e584..5b5bfbfdd4d 100644 --- a/python/oneflow/utils/backend/torch_compile.py +++ b/python/oneflow/utils/backend/torch_compile.py @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. """ -import os import oneflow -from .torch_fx_to_oneflow import to_of_transform +from .torch_fx_to_oneflow import fx_tranform def register_ofrt(): @@ -32,27 +31,9 @@ def ofrt(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): print("my_compiler() called with FX graph:") gm.graph.print_tabular() print("gm ", gm) - import oneflow as flow - of_gm = to_of_transform(gm) - - enable_graph = os.getenv("ofrt_enable_graph", "False").lower() in ( - "true", - "1", - "t", - ) - - if not enable_graph: - oneflow_fn = of_gm.forward - else: - - @flow.nn.Graph.trace - def oneflow_fn(inputs): - outs = of_gm.forward(inputs) - return outs - - oneflow_fn.debug(1) + oneflow_fn = fx_tranform(gm) def from_to_torch(inputs): flow_inputs = flow.utils.tensor.from_torch(inputs) diff --git a/python/oneflow/utils/backend/torch_fx_to_oneflow.py b/python/oneflow/utils/backend/torch_fx_to_oneflow.py index c8faa85e189..ecc80fb489c 100644 --- a/python/oneflow/utils/backend/torch_fx_to_oneflow.py +++ b/python/oneflow/utils/backend/torch_fx_to_oneflow.py @@ -13,11 +13,34 @@ See the License for the specific language governing permissions and limitations under the License. """ +import os import oneflow import torch from torch import fx +def fx_tranform(gm): + import oneflow as flow + + of_gm = to_of_transform(gm) + + enable_graph = os.getenv("ofrt_enable_graph", "False").lower() in ( + "true", + "1", + "t", + ) + + if not enable_graph: + oneflow_fn = of_gm.forward + else: + @flow.nn.Graph.trace + def oneflow_fn(inputs): + outs = of_gm.forward(inputs) + return outs + + oneflow_fn.debug(1) + return oneflow_fn + def to_of_transform( gm: torch.fx.GraphModule, tracer_class: type = fx.Tracer ) -> torch.fx.GraphModule: From 364067ab64594b1c99815e166d5ae21d722bf9ab Mon Sep 17 00:00:00 2001 From: oneflow-ci-bot Date: Fri, 11 Aug 2023 12:54:59 +0000 Subject: [PATCH 21/27] auto format by CI --- python/oneflow/utils/backend/torch_fx_to_oneflow.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/oneflow/utils/backend/torch_fx_to_oneflow.py b/python/oneflow/utils/backend/torch_fx_to_oneflow.py index ecc80fb489c..dc0a20d3f4c 100644 --- a/python/oneflow/utils/backend/torch_fx_to_oneflow.py +++ b/python/oneflow/utils/backend/torch_fx_to_oneflow.py @@ -33,6 +33,7 @@ def fx_tranform(gm): if not enable_graph: oneflow_fn = of_gm.forward else: + @flow.nn.Graph.trace def oneflow_fn(inputs): outs = of_gm.forward(inputs) @@ -41,6 +42,7 @@ def oneflow_fn(inputs): oneflow_fn.debug(1) return oneflow_fn + def to_of_transform( gm: torch.fx.GraphModule, tracer_class: type = fx.Tracer ) -> torch.fx.GraphModule: From 581fab9895526e701fa0c6b8eec3eef27c85d042 Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 11 Aug 2023 14:53:13 +0000 Subject: [PATCH 22/27] get torch script --- .../test/graph/test_graph_interplay.py | 15 ++-- python/oneflow/test/graph/test_torch_jit.py | 25 +++++++ ...orch_fx_to_oneflow.py => from_torch_fx.py} | 0 .../utils/backend/from_torch_script.py | 68 +++++++++++++++++++ python/oneflow/utils/backend/torch_compile.py | 16 ++++- 5 files changed, 117 insertions(+), 7 deletions(-) create mode 100644 python/oneflow/test/graph/test_torch_jit.py rename python/oneflow/utils/backend/{torch_fx_to_oneflow.py => from_torch_fx.py} (100%) create mode 100644 python/oneflow/utils/backend/from_torch_script.py diff --git a/python/oneflow/test/graph/test_graph_interplay.py b/python/oneflow/test/graph/test_graph_interplay.py index ed3fdf3d942..f6b71b180ab 100644 --- a/python/oneflow/test/graph/test_graph_interplay.py +++ b/python/oneflow/test/graph/test_graph_interplay.py @@ -17,10 +17,13 @@ import unittest import numpy as np +# Must import torch before oneflow, otherwise torch.jit.trace will raise error: +# terminate called after throwing an instance of 'pybind11::stop_iteration' +import torch import oneflow.unittest -def _test_relu(test_case, device): +def _test_relu(test_case, device, from_script=False): from typing import List import torch from oneflow.utils.backend.torch_compile import register_ofrt @@ -41,6 +44,7 @@ def _test_relu(test_case, device): x = torch.tensor(input_arr, device=device) eager_out = torch.relu(x) + os.environ["ofrt_from_script"] = str(from_script) os.environ["ofrt_enable_graph"] = "1" @torch.compile(backend="ofrt") @@ -210,10 +214,13 @@ def fn(x): @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @oneflow.unittest.skip_unless_1n1d() class TestAsTorchBackend(oneflow.unittest.TestCase): - def test_relu(test_case): - _test_relu(test_case, "cuda") + def test_relu_with_fx(test_case): + _test_relu(test_case, "cuda", False) - def test_linear(test_case): + def test_relu_with_script(test_case): + _test_relu(test_case, "cuda", True) + + def _test_linear(test_case): _test_linear(test_case, "cuda") diff --git a/python/oneflow/test/graph/test_torch_jit.py b/python/oneflow/test/graph/test_torch_jit.py new file mode 100644 index 00000000000..99949f793bd --- /dev/null +++ b/python/oneflow/test/graph/test_torch_jit.py @@ -0,0 +1,25 @@ +import torch +import oneflow +import numpy as np + +input_arr = np.array( + [ + [-0.94630778, -0.83378579, -0.87060891], + [2.0289922, -0.28708987, -2.18369248], + [0.35217619, -0.67095644, -1.58943879], + [0.08086036, -1.81075924, 1.20752494], + [0.8901075, -0.49976737, -1.07153746], + [-0.44872912, -1.07275683, 0.06256855], + [-0.22556897, 0.74798368, 0.90416439], + [0.48339456, -2.32742195, -0.59321527], + ], + dtype=np.float32, +) +x = torch.tensor(input_arr, device="cuda") + +def fn(x): + y = torch.relu(x) + return y + +jit_mod = torch.jit.trace(fn, x) +print(jit_mod) diff --git a/python/oneflow/utils/backend/torch_fx_to_oneflow.py b/python/oneflow/utils/backend/from_torch_fx.py similarity index 100% rename from python/oneflow/utils/backend/torch_fx_to_oneflow.py rename to python/oneflow/utils/backend/from_torch_fx.py diff --git a/python/oneflow/utils/backend/from_torch_script.py b/python/oneflow/utils/backend/from_torch_script.py new file mode 100644 index 00000000000..7e410aeec22 --- /dev/null +++ b/python/oneflow/utils/backend/from_torch_script.py @@ -0,0 +1,68 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import os +import torch + +def _get_output_names(node): + return [output.debugName() for output in node.outputs()] + +def _get_input_names(node_or_graph): + return [inp.debugName() for inp in node_or_graph.inputs()] + +def script_tranform(gm, example_inputs): + import oneflow as flow + import pdb; pdb.set_trace() + print("transform from torch script") + + jit_mod = torch.jit.trace(gm, tuple(example_inputs)) + print("jit mod graph ", jit_mod.graph) + torch_graph = jit_mod.graph.copy() + + nodes = torch_graph.nodes() + for node in nodes: + print("===") + print("node: ", node) + operator = node.kind() + input_names = _get_input_names(node) + output_names = _get_output_names(node) + print("in: ", input_names) + print("out: ", output_names) + if operator == "prim::relu": + print("prim::relu") + elif operator == "prim::TupleConstruct": + print("prim::TupleConstruct") + else: + print(operator) + + + + enable_graph = os.getenv("ofrt_enable_graph", "False").lower() in ( + "true", + "1", + "t", + ) + return gm + + if not enable_graph: + oneflow_fn = of_gm.forward + else: + @flow.nn.Graph.trace + def oneflow_fn(inputs): + outs = of_gm.forward(inputs) + return outs + + oneflow_fn.debug(1) + return oneflow_fn diff --git a/python/oneflow/utils/backend/torch_compile.py b/python/oneflow/utils/backend/torch_compile.py index 5b5bfbfdd4d..a973326c572 100644 --- a/python/oneflow/utils/backend/torch_compile.py +++ b/python/oneflow/utils/backend/torch_compile.py @@ -13,8 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. """ +import os import oneflow -from .torch_fx_to_oneflow import fx_tranform +from .from_torch_script import script_tranform +from .from_torch_fx import fx_tranform def register_ofrt(): @@ -31,10 +33,18 @@ def ofrt(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): print("my_compiler() called with FX graph:") gm.graph.print_tabular() print("gm ", gm) - import oneflow as flow - oneflow_fn = fx_tranform(gm) + from_script = os.getenv("ofrt_from_script", "True").lower() in ( + "true", + "1", + "t", + ) + if from_script: + oneflow_fn = script_tranform(gm, example_inputs) + else: + oneflow_fn = fx_tranform(gm) + import oneflow as flow def from_to_torch(inputs): flow_inputs = flow.utils.tensor.from_torch(inputs) flow_outs = oneflow_fn(flow_inputs) From d20f1d0c90683c385827aaeaddea83a7ac3ea472 Mon Sep 17 00:00:00 2001 From: oneflow-ci-bot Date: Fri, 11 Aug 2023 14:54:43 +0000 Subject: [PATCH 23/27] auto format by CI --- python/oneflow/test/graph/test_torch_jit.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/python/oneflow/test/graph/test_torch_jit.py b/python/oneflow/test/graph/test_torch_jit.py index 99949f793bd..99f64de10db 100644 --- a/python/oneflow/test/graph/test_torch_jit.py +++ b/python/oneflow/test/graph/test_torch_jit.py @@ -1,3 +1,18 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import torch import oneflow import numpy as np From 0194b085d6dedbe9b58c0de1d43f7f195955de1b Mon Sep 17 00:00:00 2001 From: strint Date: Wed, 16 Aug 2023 11:27:54 +0000 Subject: [PATCH 24/27] add gm in graph test --- .../test/graph/test_fx_symbolic_trace_module.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/python/oneflow/test/graph/test_fx_symbolic_trace_module.py b/python/oneflow/test/graph/test_fx_symbolic_trace_module.py index 33e127e6d26..9d808d240d5 100644 --- a/python/oneflow/test/graph/test_fx_symbolic_trace_module.py +++ b/python/oneflow/test/graph/test_fx_symbolic_trace_module.py @@ -68,6 +68,21 @@ def test_alexnet(test_case): test_case.assertTrue( np.allclose(gm(input).numpy(), m(input).numpy(), equal_nan=True) ) + class AlexNetEvalGraph(flow.nn.Graph): + def __init__(self): + super().__init__() + self.alexnet = gm + + def build(self, inp): + return self.alexnet(inp) + gm_g = AlexNetEvalGraph() + gm_g.debug(1) + for i in range(5): + input = flow.randn(1, 3, 224, 224) + test_case.assertTrue( + np.allclose(gm_g(input).numpy(), m(input).numpy(), equal_nan=True) + ) + if __name__ == "__main__": From 3a543a1029d81d4e4be2a8c916dc443a3c0c3e8f Mon Sep 17 00:00:00 2001 From: strint Date: Wed, 16 Aug 2023 14:59:35 +0000 Subject: [PATCH 25/27] test linear passed --- .../test/graph/test_graph_interplay.py | 108 +----------------- 1 file changed, 6 insertions(+), 102 deletions(-) diff --git a/python/oneflow/test/graph/test_graph_interplay.py b/python/oneflow/test/graph/test_graph_interplay.py index f6b71b180ab..5e6b8d05892 100644 --- a/python/oneflow/test/graph/test_graph_interplay.py +++ b/python/oneflow/test/graph/test_graph_interplay.py @@ -77,55 +77,8 @@ def _test_linear(test_case, device): import torch from oneflow.utils.backend.torch_compile import register_ofrt - linear = torch.nn.Linear(3, 8, False) - linear = linear.to(device) - input_arr = np.array( - [ - [-0.94630778, -0.83378579, -0.87060891], - [2.0289922, -0.28708987, -2.18369248], - [0.35217619, -0.67095644, -1.58943879], - [0.08086036, -1.81075924, 1.20752494], - [0.8901075, -0.49976737, -1.07153746], - [-0.44872912, -1.07275683, 0.06256855], - [-0.22556897, 0.74798368, 0.90416439], - [0.48339456, -2.32742195, -0.59321527], - ], - dtype=np.float32, - ) - x = torch.tensor(input_arr, device=device) - torch.nn.init.constant_(linear.weight, 2.3) - eager_out = linear(x) - - @torch.compile(backend="ofrt") - def fn(x): - y = linear(x) - return y - - compile_out = fn(x) - test_case.assertTrue( - np.allclose( - compile_out.cpu().detach().numpy(), - eager_out.cpu().detach().numpy(), - 1e-05, - 1e-05, - ) - ) - compile_out = fn(x) - test_case.assertTrue( - np.allclose( - compile_out.cpu().detach().numpy(), - eager_out.cpu().detach().numpy(), - 1e-05, - 1e-05, - ) - ) - - -def __test_linear(test_case, device): - from typing import List - import torch - from torch._dynamo.backends.registry import register_backend - from torch._dynamo.backends.common import fake_tensor_unsupported + os.environ["ofrt_from_script"] = "0" + os.environ["ofrt_enable_graph"] = "1" linear = torch.nn.Linear(3, 8, False) linear = linear.to(device) @@ -146,47 +99,7 @@ def __test_linear(test_case, device): torch.nn.init.constant_(linear.weight, 2.3) eager_out = linear(x) - def get_of(): - # TODO(): transform torch fx code to oneflow code - import oneflow as flow - - linear = flow.nn.Linear(3, 8, False) - linear = linear.to(device) - flow.nn.init.constant_(linear.weight, 2.3) - - class LinearGraph(flow.nn.Graph): - def __init__(self): - super().__init__() - self.my_linear = linear - - def build(self, x): - return self.my_linear(x) - - linear_g = LinearGraph() - linear_g.debug(1) - return linear_g - - g = None - - def torch_interplay(x): - import oneflow as flow - - x = flow.utils.tensor.from_torch(x) - nonlocal g - if g is None: - g = get_of() - # TODO(): This is a special pack trick, try to make it general. - return (flow.utils.tensor.to_torch(g(x)),) - - @register_backend - @fake_tensor_unsupported - def oneflowrt(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): - print("my_compiler() called with FX graph:") - gm.graph.print_tabular() - gm.forward = torch_interplay - return gm.forward # return a python callable - - @torch.compile(backend="oneflowrt") + @torch.compile(backend="ofrt") def fn(x): y = linear(x) return y @@ -200,27 +113,18 @@ def fn(x): 1e-05, ) ) - compile_out = fn(x) - test_case.assertTrue( - np.allclose( - compile_out.cpu().detach().numpy(), - eager_out.cpu().detach().numpy(), - 1e-05, - 1e-05, - ) - ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @oneflow.unittest.skip_unless_1n1d() class TestAsTorchBackend(oneflow.unittest.TestCase): - def test_relu_with_fx(test_case): + def _test_relu_with_fx(test_case): _test_relu(test_case, "cuda", False) - def test_relu_with_script(test_case): + def _test_relu_with_script(test_case): _test_relu(test_case, "cuda", True) - def _test_linear(test_case): + def test_linear_with_fx(test_case): _test_linear(test_case, "cuda") From 0d72cd7a003319e713c1aacdba3b7898b9a6fea1 Mon Sep 17 00:00:00 2001 From: strint Date: Wed, 16 Aug 2023 15:03:46 +0000 Subject: [PATCH 26/27] test linear passed with fx graph and oneflow graph --- python/oneflow/nn/modules/linear.py | 1 + python/oneflow/utils/backend/from_torch_fx.py | 95 +++++++++++++++++-- 2 files changed, 86 insertions(+), 10 deletions(-) diff --git a/python/oneflow/nn/modules/linear.py b/python/oneflow/nn/modules/linear.py index c8647c2141e..179e7e4c964 100644 --- a/python/oneflow/nn/modules/linear.py +++ b/python/oneflow/nn/modules/linear.py @@ -127,6 +127,7 @@ def reset_parameters(self) -> None: flow.nn.init.uniform_(self.bias, -bound, bound) def forward(self, x): + print("run oneflow linear module") if self.use_fused_matmul_bias: return flow._C.fused_matmul_bias(x, self.weight, self.bias) else: diff --git a/python/oneflow/utils/backend/from_torch_fx.py b/python/oneflow/utils/backend/from_torch_fx.py index dc0a20d3f4c..351f3154e3b 100644 --- a/python/oneflow/utils/backend/from_torch_fx.py +++ b/python/oneflow/utils/backend/from_torch_fx.py @@ -14,13 +14,13 @@ limitations under the License. """ import os -import oneflow import torch +import oneflow as flow from torch import fx +from typing import Dict, Any, Dict, Tuple def fx_tranform(gm): - import oneflow as flow of_gm = to_of_transform(gm) @@ -33,19 +33,47 @@ def fx_tranform(gm): if not enable_graph: oneflow_fn = of_gm.forward else: + class OfGraph(flow.nn.Graph): + def __init__(self): + super().__init__() + self.m = of_gm + + def build(self, *args, **kwargs): + return self.m(*args, **kwargs) + + of_g = OfGraph() + of_g.debug(0) + oneflow_fn = lambda *args, **kwargs: of_g(*args, **kwargs) - @flow.nn.Graph.trace - def oneflow_fn(inputs): - outs = of_gm.forward(inputs) - return outs - - oneflow_fn.debug(1) return oneflow_fn +def _parent_name(target: str) -> Tuple[str, str]: + """ + Splits a qualname into parent path and last atom. + For example, `foo.bar.baz` -> (`foo.bar`, `baz`) + """ + *parent, name = target.rsplit(".", 1) + return parent[0] if parent else "", name -def to_of_transform( + +def _replace_node_module( + node: torch.fx.Node, modules: Dict[str, Any], new_module: flow.nn.Module +): + assert isinstance(node.target, str) + parent_name, name = _parent_name(node.target) + setattr(modules[parent_name], name, new_module) + + +def _get_module(origin_mod): + linear = flow.nn.Linear(3, 8, False) + linear = linear.to("cuda") + flow.nn.init.constant_(linear.weight, 2.3) + return linear + +def _to_of_transform( gm: torch.fx.GraphModule, tracer_class: type = fx.Tracer ) -> torch.fx.GraphModule: + modules = dict(gm.named_modules()) for node in gm.graph.nodes: # Checks if we're calling a function (i.e: # torch.add) @@ -53,8 +81,55 @@ def to_of_transform( # The target attribute is the function # that call_function calls. if node.target == torch.relu: - node.target = oneflow.relu + node.target = flow.relu + elif node.op == "call_module": + print(node.format_node()) + if type(modules[node.target] is torch.nn.Linear): + linear = modules[node.target] + print(linear) + _replace_node_module(node, modules, _get_module(linear)) gm.graph.lint() gm.recompile() + for node in gm.graph.nodes: + print(node.format_node()) return gm + + +def to_of_transform( + gm: torch.fx.GraphModule, tracer_class: type = fx.Tracer +) -> torch.fx.GraphModule: + name2node = {} + name2obj = {} + of_g = flow.fx.Graph() + modules = dict(gm.named_modules()) + for node in gm.graph.nodes: + print(node.format_node()) + if node.op == "call_function": + if node.target == torch.relu: + node.target = flow.relu + elif node.op == "call_module": + if type(modules[node.target] is torch.nn.Linear): + linear = modules[node.target] + name2obj[node.target] = _get_module(linear) + of_node = of_g.create_node('call_module', node.target, args=(name2node[node.args[0].name],)) + name2node[node.name] = of_node + elif node.op == "call_method": + ... + elif node.op == "get_attr": + ... + elif node.op == "placeholder": + of_node = of_g.create_node('placeholder', node.target) + name2node[node.name] = of_node + elif node.op == "output": + of_g.output((name2node[node.args[0][0].name],)) + else: + raise ValueError(f"not valid node type{node.foramt_node()}") + print("\n new of graph", of_g.print_tabular()) + for of_node in of_g.nodes: + print(of_node.format_node()) + + of_gm = flow.fx.GraphModule(name2obj, of_g) + of_gm.graph.lint() + of_gm.recompile() + return of_gm \ No newline at end of file From c45f52d2e81ddc1ed577acfa8ea8d1dc9d6133fb Mon Sep 17 00:00:00 2001 From: oneflow-ci-bot Date: Wed, 16 Aug 2023 15:06:25 +0000 Subject: [PATCH 27/27] auto format by CI --- .../test/graph/test_fx_symbolic_trace_module.py | 3 ++- python/oneflow/test/graph/test_torch_jit.py | 2 ++ python/oneflow/utils/backend/from_torch_fx.py | 17 +++++++++++------ .../oneflow/utils/backend/from_torch_script.py | 10 +++++++--- python/oneflow/utils/backend/torch_compile.py | 1 + 5 files changed, 23 insertions(+), 10 deletions(-) diff --git a/python/oneflow/test/graph/test_fx_symbolic_trace_module.py b/python/oneflow/test/graph/test_fx_symbolic_trace_module.py index 9d808d240d5..34396daea09 100644 --- a/python/oneflow/test/graph/test_fx_symbolic_trace_module.py +++ b/python/oneflow/test/graph/test_fx_symbolic_trace_module.py @@ -68,6 +68,7 @@ def test_alexnet(test_case): test_case.assertTrue( np.allclose(gm(input).numpy(), m(input).numpy(), equal_nan=True) ) + class AlexNetEvalGraph(flow.nn.Graph): def __init__(self): super().__init__() @@ -75,6 +76,7 @@ def __init__(self): def build(self, inp): return self.alexnet(inp) + gm_g = AlexNetEvalGraph() gm_g.debug(1) for i in range(5): @@ -84,6 +86,5 @@ def build(self, inp): ) - if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/graph/test_torch_jit.py b/python/oneflow/test/graph/test_torch_jit.py index 99f64de10db..cd02279c2e1 100644 --- a/python/oneflow/test/graph/test_torch_jit.py +++ b/python/oneflow/test/graph/test_torch_jit.py @@ -32,9 +32,11 @@ ) x = torch.tensor(input_arr, device="cuda") + def fn(x): y = torch.relu(x) return y + jit_mod = torch.jit.trace(fn, x) print(jit_mod) diff --git a/python/oneflow/utils/backend/from_torch_fx.py b/python/oneflow/utils/backend/from_torch_fx.py index 351f3154e3b..1990df9bf59 100644 --- a/python/oneflow/utils/backend/from_torch_fx.py +++ b/python/oneflow/utils/backend/from_torch_fx.py @@ -33,20 +33,22 @@ def fx_tranform(gm): if not enable_graph: oneflow_fn = of_gm.forward else: + class OfGraph(flow.nn.Graph): def __init__(self): super().__init__() self.m = of_gm - + def build(self, *args, **kwargs): return self.m(*args, **kwargs) - + of_g = OfGraph() of_g.debug(0) oneflow_fn = lambda *args, **kwargs: of_g(*args, **kwargs) return oneflow_fn + def _parent_name(target: str) -> Tuple[str, str]: """ Splits a qualname into parent path and last atom. @@ -70,6 +72,7 @@ def _get_module(origin_mod): flow.nn.init.constant_(linear.weight, 2.3) return linear + def _to_of_transform( gm: torch.fx.GraphModule, tracer_class: type = fx.Tracer ) -> torch.fx.GraphModule: @@ -112,14 +115,16 @@ def to_of_transform( if type(modules[node.target] is torch.nn.Linear): linear = modules[node.target] name2obj[node.target] = _get_module(linear) - of_node = of_g.create_node('call_module', node.target, args=(name2node[node.args[0].name],)) + of_node = of_g.create_node( + "call_module", node.target, args=(name2node[node.args[0].name],) + ) name2node[node.name] = of_node elif node.op == "call_method": ... elif node.op == "get_attr": ... elif node.op == "placeholder": - of_node = of_g.create_node('placeholder', node.target) + of_node = of_g.create_node("placeholder", node.target) name2node[node.name] = of_node elif node.op == "output": of_g.output((name2node[node.args[0][0].name],)) @@ -128,8 +133,8 @@ def to_of_transform( print("\n new of graph", of_g.print_tabular()) for of_node in of_g.nodes: print(of_node.format_node()) - + of_gm = flow.fx.GraphModule(name2obj, of_g) of_gm.graph.lint() of_gm.recompile() - return of_gm \ No newline at end of file + return of_gm diff --git a/python/oneflow/utils/backend/from_torch_script.py b/python/oneflow/utils/backend/from_torch_script.py index 7e410aeec22..889c3767aad 100644 --- a/python/oneflow/utils/backend/from_torch_script.py +++ b/python/oneflow/utils/backend/from_torch_script.py @@ -16,15 +16,20 @@ import os import torch + def _get_output_names(node): return [output.debugName() for output in node.outputs()] + def _get_input_names(node_or_graph): return [inp.debugName() for inp in node_or_graph.inputs()] + def script_tranform(gm, example_inputs): import oneflow as flow - import pdb; pdb.set_trace() + import pdb + + pdb.set_trace() print("transform from torch script") jit_mod = torch.jit.trace(gm, tuple(example_inputs)) @@ -47,8 +52,6 @@ def script_tranform(gm, example_inputs): else: print(operator) - - enable_graph = os.getenv("ofrt_enable_graph", "False").lower() in ( "true", "1", @@ -59,6 +62,7 @@ def script_tranform(gm, example_inputs): if not enable_graph: oneflow_fn = of_gm.forward else: + @flow.nn.Graph.trace def oneflow_fn(inputs): outs = of_gm.forward(inputs) diff --git a/python/oneflow/utils/backend/torch_compile.py b/python/oneflow/utils/backend/torch_compile.py index a973326c572..94131259cbf 100644 --- a/python/oneflow/utils/backend/torch_compile.py +++ b/python/oneflow/utils/backend/torch_compile.py @@ -45,6 +45,7 @@ def ofrt(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): oneflow_fn = fx_tranform(gm) import oneflow as flow + def from_to_torch(inputs): flow_inputs = flow.utils.tensor.from_torch(inputs) flow_outs = oneflow_fn(flow_inputs)