diff --git a/python/oneflow/nn/modules/linear.py b/python/oneflow/nn/modules/linear.py
index c8647c2141e..179e7e4c964 100644
--- a/python/oneflow/nn/modules/linear.py
+++ b/python/oneflow/nn/modules/linear.py
@@ -127,6 +127,7 @@ def reset_parameters(self) -> None:
             flow.nn.init.uniform_(self.bias, -bound, bound)
 
     def forward(self, x):
+        print("run oneflow linear module")
         if self.use_fused_matmul_bias:
             return flow._C.fused_matmul_bias(x, self.weight, self.bias)
         else:
diff --git a/python/oneflow/test/graph/test_fx_symbolic_trace_module.py b/python/oneflow/test/graph/test_fx_symbolic_trace_module.py
index 33e127e6d26..34396daea09 100644
--- a/python/oneflow/test/graph/test_fx_symbolic_trace_module.py
+++ b/python/oneflow/test/graph/test_fx_symbolic_trace_module.py
@@ -69,6 +69,22 @@ def test_alexnet(test_case):
             np.allclose(gm(input).numpy(), m(input).numpy(), equal_nan=True)
         )
+        class AlexNetEvalGraph(flow.nn.Graph):
+            def __init__(self):
+                super().__init__()
+                self.alexnet = gm
+
+            def build(self, inp):
+                return self.alexnet(inp)
+
+        gm_g = AlexNetEvalGraph()
+        gm_g.debug(1)
+        for i in range(5):
+            input = flow.randn(1, 3, 224, 224)
+            test_case.assertTrue(
+                np.allclose(gm_g(input).numpy(), m(input).numpy(), equal_nan=True)
+            )
+
 
 
 if __name__ == "__main__":
     unittest.main()
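Note on the test above: wrapping an FX-traced `GraphModule` in a `flow.nn.Graph` subclass is the same recipe `from_torch_fx.py` uses below (`OfGraph`). A minimal, generic sketch of that recipe — `wrap_in_graph` is an illustrative helper, not part of this PR:

```python
import oneflow as flow

def wrap_in_graph(gm):
    """Run a traced flow module under lazy-mode flow.nn.Graph execution."""

    class WrapperGraph(flow.nn.Graph):
        def __init__(self):
            super().__init__()
            self.m = gm  # the traced GraphModule becomes a block of the graph

        def build(self, *args, **kwargs):
            return self.m(*args, **kwargs)

    return WrapperGraph()

# usage, mirroring the test: gm_g = wrap_in_graph(gm); out = gm_g(input)
```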
+""" +import os +import unittest +import numpy as np + +# Must import torch before oneflow, otherwise torch.jit.trace will raise error: +# terminate called after throwing an instance of 'pybind11::stop_iteration' +import torch +import oneflow.unittest + + +def _test_relu(test_case, device, from_script=False): + from typing import List + import torch + from oneflow.utils.backend.torch_compile import register_ofrt + + input_arr = np.array( + [ + [-0.94630778, -0.83378579, -0.87060891], + [2.0289922, -0.28708987, -2.18369248], + [0.35217619, -0.67095644, -1.58943879], + [0.08086036, -1.81075924, 1.20752494], + [0.8901075, -0.49976737, -1.07153746], + [-0.44872912, -1.07275683, 0.06256855], + [-0.22556897, 0.74798368, 0.90416439], + [0.48339456, -2.32742195, -0.59321527], + ], + dtype=np.float32, + ) + x = torch.tensor(input_arr, device=device) + eager_out = torch.relu(x) + + os.environ["ofrt_from_script"] = str(from_script) + os.environ["ofrt_enable_graph"] = "1" + + @torch.compile(backend="ofrt") + def fn(x): + y = torch.relu(x) + return y + + compile_out = fn(x) + test_case.assertTrue( + np.allclose( + compile_out.cpu().detach().numpy(), + eager_out.cpu().detach().numpy(), + 1e-05, + 1e-05, + ) + ) + compile_out = fn(x) + test_case.assertTrue( + np.allclose( + compile_out.cpu().detach().numpy(), + eager_out.cpu().detach().numpy(), + 1e-05, + 1e-05, + ) + ) + + +def _test_linear(test_case, device): + from typing import List + import torch + from oneflow.utils.backend.torch_compile import register_ofrt + + os.environ["ofrt_from_script"] = "0" + os.environ["ofrt_enable_graph"] = "1" + + linear = torch.nn.Linear(3, 8, False) + linear = linear.to(device) + input_arr = np.array( + [ + [-0.94630778, -0.83378579, -0.87060891], + [2.0289922, -0.28708987, -2.18369248], + [0.35217619, -0.67095644, -1.58943879], + [0.08086036, -1.81075924, 1.20752494], + [0.8901075, -0.49976737, -1.07153746], + [-0.44872912, -1.07275683, 0.06256855], + [-0.22556897, 0.74798368, 0.90416439], + [0.48339456, -2.32742195, -0.59321527], + ], + dtype=np.float32, + ) + x = torch.tensor(input_arr, device=device) + torch.nn.init.constant_(linear.weight, 2.3) + eager_out = linear(x) + + @torch.compile(backend="ofrt") + def fn(x): + y = linear(x) + return y + + compile_out = fn(x) + test_case.assertTrue( + np.allclose( + compile_out.cpu().detach().numpy(), + eager_out.cpu().detach().numpy(), + 1e-05, + 1e-05, + ) + ) + + +@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") +@oneflow.unittest.skip_unless_1n1d() +class TestAsTorchBackend(oneflow.unittest.TestCase): + def _test_relu_with_fx(test_case): + _test_relu(test_case, "cuda", False) + + def _test_relu_with_script(test_case): + _test_relu(test_case, "cuda", True) + + def test_linear_with_fx(test_case): + _test_linear(test_case, "cuda") + + +if __name__ == "__main__": + unittest.main() diff --git a/python/oneflow/test/graph/test_torch_jit.py b/python/oneflow/test/graph/test_torch_jit.py new file mode 100644 index 00000000000..cd02279c2e1 --- /dev/null +++ b/python/oneflow/test/graph/test_torch_jit.py @@ -0,0 +1,42 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
diff --git a/python/oneflow/test/graph/test_torch_jit.py b/python/oneflow/test/graph/test_torch_jit.py
new file mode 100644
index 00000000000..cd02279c2e1
--- /dev/null
+++ b/python/oneflow/test/graph/test_torch_jit.py
@@ -0,0 +1,42 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import torch
+import oneflow
+import numpy as np
+
+input_arr = np.array(
+    [
+        [-0.94630778, -0.83378579, -0.87060891],
+        [2.0289922, -0.28708987, -2.18369248],
+        [0.35217619, -0.67095644, -1.58943879],
+        [0.08086036, -1.81075924, 1.20752494],
+        [0.8901075, -0.49976737, -1.07153746],
+        [-0.44872912, -1.07275683, 0.06256855],
+        [-0.22556897, 0.74798368, 0.90416439],
+        [0.48339456, -2.32742195, -0.59321527],
+    ],
+    dtype=np.float32,
+)
+x = torch.tensor(input_arr, device="cuda")
+
+
+def fn(x):
+    y = torch.relu(x)
+    return y
+
+
+jit_mod = torch.jit.trace(fn, x)
+print(jit_mod)
diff --git a/python/oneflow/utils/backend/__init__.py b/python/oneflow/utils/backend/__init__.py
new file mode 100644
index 00000000000..fefd352fa2a
--- /dev/null
+++ b/python/oneflow/utils/backend/__init__.py
@@ -0,0 +1,22 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+from oneflow.utils.backend.torch_compile import register_ofrt

+register_ofrt()
+
+__all__ = [
+    "register_ofrt",
+]
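Because `__init__.py` calls `register_ofrt()` at import time, importing the package is enough to make the backend visible to dynamo. A quick sanity check, assuming a PyTorch 2.x build that exposes `torch._dynamo.list_backends()`:

```python
import torch  # torch must come before oneflow, per the comment in test_graph_interplay.py
import oneflow.utils.backend  # import side effect: registers the "ofrt" backend

print("ofrt" in torch._dynamo.list_backends())  # expected: True
```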
diff --git a/python/oneflow/utils/backend/from_torch_fx.py b/python/oneflow/utils/backend/from_torch_fx.py
new file mode 100644
index 00000000000..1990df9bf59
--- /dev/null
+++ b/python/oneflow/utils/backend/from_torch_fx.py
@@ -0,0 +1,140 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import os
+import torch
+import oneflow as flow
+from torch import fx
+from typing import Any, Dict, Tuple
+
+
+def fx_transform(gm):
+
+    of_gm = to_of_transform(gm)
+
+    enable_graph = os.getenv("ofrt_enable_graph", "False").lower() in (
+        "true",
+        "1",
+        "t",
+    )
+
+    if not enable_graph:
+        oneflow_fn = of_gm.forward
+    else:
+
+        class OfGraph(flow.nn.Graph):
+            def __init__(self):
+                super().__init__()
+                self.m = of_gm
+
+            def build(self, *args, **kwargs):
+                return self.m(*args, **kwargs)
+
+        of_g = OfGraph()
+        of_g.debug(0)
+        oneflow_fn = lambda *args, **kwargs: of_g(*args, **kwargs)
+
+    return oneflow_fn
+
+
+def _parent_name(target: str) -> Tuple[str, str]:
+    """
+    Splits a qualname into parent path and last atom.
+    For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
+    """
+    *parent, name = target.rsplit(".", 1)
+    return parent[0] if parent else "", name
+
+
+def _replace_node_module(
+    node: torch.fx.Node, modules: Dict[str, Any], new_module: flow.nn.Module
+):
+    assert isinstance(node.target, str)
+    parent_name, name = _parent_name(node.target)
+    setattr(modules[parent_name], name, new_module)
+
+
+def _get_module(origin_mod):
+    linear = flow.nn.Linear(3, 8, False)  # hardcoded to match _test_linear; TODO: derive from origin_mod
+    linear = linear.to("cuda")
+    flow.nn.init.constant_(linear.weight, 2.3)
+    return linear
+
+
+def _to_of_transform(
+    gm: torch.fx.GraphModule, tracer_class: type = fx.Tracer
+) -> torch.fx.GraphModule:
+    modules = dict(gm.named_modules())
+    for node in gm.graph.nodes:
+        # Checks if we're calling a function (i.e.,
+        # torch.add)
+        if node.op == "call_function":
+            # The target attribute is the function
+            # that call_function calls.
+            if node.target == torch.relu:
+                node.target = flow.relu
+        elif node.op == "call_module":
+            print(node.format_node())
+            if isinstance(modules[node.target], torch.nn.Linear):
+                linear = modules[node.target]
+                print(linear)
+                _replace_node_module(node, modules, _get_module(linear))
+
+    gm.graph.lint()
+    gm.recompile()
+    for node in gm.graph.nodes:
+        print(node.format_node())
+    return gm
+
+
+def to_of_transform(
+    gm: torch.fx.GraphModule, tracer_class: type = fx.Tracer
+) -> torch.fx.GraphModule:
+    name2node = {}
+    name2obj = {}
+    of_g = flow.fx.Graph()
+    modules = dict(gm.named_modules())
+    for node in gm.graph.nodes:
+        print(node.format_node())
+        if node.op == "call_function":
+            if node.target == torch.relu:
+                node.target = flow.relu  # NOTE: only retargets the torch node; relu is not mapped into of_g yet
+        elif node.op == "call_module":
+            if isinstance(modules[node.target], torch.nn.Linear):
+                linear = modules[node.target]
+                name2obj[node.target] = _get_module(linear)
+                of_node = of_g.create_node(
+                    "call_module", node.target, args=(name2node[node.args[0].name],)
+                )
+                name2node[node.name] = of_node
+        elif node.op == "call_method":
+            ...
+        elif node.op == "get_attr":
+            ...
+        elif node.op == "placeholder":
+            of_node = of_g.create_node("placeholder", node.target)
+            name2node[node.name] = of_node
+        elif node.op == "output":
+            of_g.output((name2node[node.args[0][0].name],))
+        else:
+            raise ValueError(f"not a valid node type: {node.format_node()}")
+    of_g.print_tabular()  # dump the converted oneflow graph in tabular form
+    for of_node in of_g.nodes:
+        print(of_node.format_node())
+
+    of_gm = flow.fx.GraphModule(name2obj, of_g)
+    of_gm.graph.lint()
+    of_gm.recompile()
+    return of_gm
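`_get_module` above hardcodes `Linear(3, 8, False)` with a constant weight of 2.3 to match `_test_linear`. A generalized replacement would derive the layer configuration from the torch module and copy its parameters across. A sketch under that assumption — `_get_linear_module` is hypothetical, reusing the PR's own `flow.utils.tensor.from_torch` helper:

```python
import torch
import oneflow as flow

def _get_linear_module(origin_mod: torch.nn.Linear) -> flow.nn.Linear:
    # Derive the configuration from the source module instead of hardcoding it.
    linear = flow.nn.Linear(
        origin_mod.in_features,
        origin_mod.out_features,
        origin_mod.bias is not None,
    ).to("cuda")
    # Copy the trained parameters across frameworks.
    linear.weight = flow.nn.Parameter(
        flow.utils.tensor.from_torch(origin_mod.weight.detach().cpu()).to("cuda")
    )
    if origin_mod.bias is not None:
        linear.bias = flow.nn.Parameter(
            flow.utils.tensor.from_torch(origin_mod.bias.detach().cpu()).to("cuda")
        )
    return linear
```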
+""" +import os +import torch + + +def _get_output_names(node): + return [output.debugName() for output in node.outputs()] + + +def _get_input_names(node_or_graph): + return [inp.debugName() for inp in node_or_graph.inputs()] + + +def script_tranform(gm, example_inputs): + import oneflow as flow + import pdb + + pdb.set_trace() + print("transform from torch script") + + jit_mod = torch.jit.trace(gm, tuple(example_inputs)) + print("jit mod graph ", jit_mod.graph) + torch_graph = jit_mod.graph.copy() + + nodes = torch_graph.nodes() + for node in nodes: + print("===") + print("node: ", node) + operator = node.kind() + input_names = _get_input_names(node) + output_names = _get_output_names(node) + print("in: ", input_names) + print("out: ", output_names) + if operator == "prim::relu": + print("prim::relu") + elif operator == "prim::TupleConstruct": + print("prim::TupleConstruct") + else: + print(operator) + + enable_graph = os.getenv("ofrt_enable_graph", "False").lower() in ( + "true", + "1", + "t", + ) + return gm + + if not enable_graph: + oneflow_fn = of_gm.forward + else: + + @flow.nn.Graph.trace + def oneflow_fn(inputs): + outs = of_gm.forward(inputs) + return outs + + oneflow_fn.debug(1) + return oneflow_fn diff --git a/python/oneflow/utils/backend/torch_compile.py b/python/oneflow/utils/backend/torch_compile.py new file mode 100644 index 00000000000..94131259cbf --- /dev/null +++ b/python/oneflow/utils/backend/torch_compile.py @@ -0,0 +1,56 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import os +import oneflow +from .from_torch_script import script_tranform +from .from_torch_fx import fx_tranform + + +def register_ofrt(): + from typing import List, Optional, Dict, Any + import torch + from torch import fx + from torch._dynamo.backends.registry import register_backend + from torch._dynamo.backends.common import fake_tensor_unsupported + + @register_backend + @fake_tensor_unsupported + def ofrt(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + # TODO(): fxGraphModule to nn.Graph + print("my_compiler() called with FX graph:") + gm.graph.print_tabular() + print("gm ", gm) + + from_script = os.getenv("ofrt_from_script", "True").lower() in ( + "true", + "1", + "t", + ) + if from_script: + oneflow_fn = script_tranform(gm, example_inputs) + else: + oneflow_fn = fx_tranform(gm) + + import oneflow as flow + + def from_to_torch(inputs): + flow_inputs = flow.utils.tensor.from_torch(inputs) + flow_outs = oneflow_fn(flow_inputs) + # TODO(): general output process + outs = flow.utils.tensor.to_torch(flow_outs[0]) + return (outs,) + + return from_to_torch