Commit: test linear passed
strint committed Aug 16, 2023
1 parent 0194b08 commit 3a543a1
Showing 1 changed file with 6 additions and 102 deletions.

python/oneflow/test/graph/test_graph_interplay.py
@@ -77,55 +77,8 @@ def _test_linear(test_case, device):
     import torch
     from oneflow.utils.backend.torch_compile import register_ofrt

-    linear = torch.nn.Linear(3, 8, False)
-    linear = linear.to(device)
-    input_arr = np.array(
-        [
-            [-0.94630778, -0.83378579, -0.87060891],
-            [2.0289922, -0.28708987, -2.18369248],
-            [0.35217619, -0.67095644, -1.58943879],
-            [0.08086036, -1.81075924, 1.20752494],
-            [0.8901075, -0.49976737, -1.07153746],
-            [-0.44872912, -1.07275683, 0.06256855],
-            [-0.22556897, 0.74798368, 0.90416439],
-            [0.48339456, -2.32742195, -0.59321527],
-        ],
-        dtype=np.float32,
-    )
-    x = torch.tensor(input_arr, device=device)
-    torch.nn.init.constant_(linear.weight, 2.3)
-    eager_out = linear(x)
-
-    @torch.compile(backend="ofrt")
-    def fn(x):
-        y = linear(x)
-        return y
-
-    compile_out = fn(x)
-    test_case.assertTrue(
-        np.allclose(
-            compile_out.cpu().detach().numpy(),
-            eager_out.cpu().detach().numpy(),
-            1e-05,
-            1e-05,
-        )
-    )
-    compile_out = fn(x)
-    test_case.assertTrue(
-        np.allclose(
-            compile_out.cpu().detach().numpy(),
-            eager_out.cpu().detach().numpy(),
-            1e-05,
-            1e-05,
-        )
-    )
-
-
-def __test_linear(test_case, device):
-    from typing import List
-    import torch
-    from torch._dynamo.backends.registry import register_backend
-    from torch._dynamo.backends.common import fake_tensor_unsupported
+    os.environ["ofrt_from_script"] = "0"
+    os.environ["ofrt_enable_graph"] = "1"

     linear = torch.nn.Linear(3, 8, False)
     linear = linear.to(device)
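
Note on the surviving path: `_test_linear` now goes entirely through the backend that `register_ofrt` registers under the name "ofrt", steered by the two environment flags added above. For reference, a minimal sketch of the `torch.compile` custom-backend contract using only stock PyTorch (the `passthrough` backend below is a hypothetical stand-in, not part of this repository):

    import torch

    # A Dynamo backend is any callable that takes the captured FX graph and
    # example inputs and returns a Python callable to run in place of the
    # original code.
    def passthrough(gm: torch.fx.GraphModule, example_inputs):
        return gm.forward  # run the captured graph unchanged

    @torch.compile(backend=passthrough)  # a callable works as well as a registered name
    def fn(x):
        return x * 2.0

    fn(torch.ones(2))  # first call triggers graph capture and "compilation"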
@@ -146,47 +99,7 @@ def __test_linear(test_case, device):
     torch.nn.init.constant_(linear.weight, 2.3)
     eager_out = linear(x)

-    def get_of():
-        # TODO(): transform torch fx code to oneflow code
-        import oneflow as flow
-
-        linear = flow.nn.Linear(3, 8, False)
-        linear = linear.to(device)
-        flow.nn.init.constant_(linear.weight, 2.3)
-
-        class LinearGraph(flow.nn.Graph):
-            def __init__(self):
-                super().__init__()
-                self.my_linear = linear
-
-            def build(self, x):
-                return self.my_linear(x)
-
-        linear_g = LinearGraph()
-        linear_g.debug(1)
-        return linear_g
-
-    g = None
-
-    def torch_interplay(x):
-        import oneflow as flow
-
-        x = flow.utils.tensor.from_torch(x)
-        nonlocal g
-        if g is None:
-            g = get_of()
-        # TODO(): This is a special pack trick, try to make it general.
-        return (flow.utils.tensor.to_torch(g(x)),)
-
-    @register_backend
-    @fake_tensor_unsupported
-    def oneflowrt(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
-        print("my_compiler() called with FX graph:")
-        gm.graph.print_tabular()
-        gm.forward = torch_interplay
-        return gm.forward  # return a python callable
-
-    @torch.compile(backend="oneflowrt")
+    @torch.compile(backend="ofrt")
     def fn(x):
         y = linear(x)
         return y
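
The deleted `oneflowrt` function above is the inline version of what the `ofrt` backend packages up: register a named TorchDynamo backend, then return a Python callable that reroutes execution through an `oneflow.nn.Graph`. A stripped-down sketch of just the registration half, using the same two `torch._dynamo` helpers the deleted code imported (the backend name `demo` and function `fn` are hypothetical):

    from typing import List
    import torch
    from torch._dynamo.backends.registry import register_backend
    from torch._dynamo.backends.common import fake_tensor_unsupported

    @register_backend           # bare decorator: registers under the function's own name
    @fake_tensor_unsupported    # hand the backend real tensors, not fake ones
    def demo(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        gm.graph.print_tabular()  # inspect the FX graph Dynamo captured
        return gm.forward         # any Python callable is a valid compiled fn

    @torch.compile(backend="demo")  # now resolvable by its registered name
    def fn(x):
        return x + 1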
@@ -200,27 +113,18 @@ def fn(x):
             1e-05,
         )
     )
-    compile_out = fn(x)
-    test_case.assertTrue(
-        np.allclose(
-            compile_out.cpu().detach().numpy(),
-            eager_out.cpu().detach().numpy(),
-            1e-05,
-            1e-05,
-        )
-    )


 @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
 @oneflow.unittest.skip_unless_1n1d()
 class TestAsTorchBackend(oneflow.unittest.TestCase):
-    def test_relu_with_fx(test_case):
+    def _test_relu_with_fx(test_case):
         _test_relu(test_case, "cuda", False)

-    def test_relu_with_script(test_case):
+    def _test_relu_with_script(test_case):
         _test_relu(test_case, "cuda", True)

-    def _test_linear(test_case):
+    def test_linear_with_fx(test_case):
         _test_linear(test_case, "cuda")

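The renames in `TestAsTorchBackend` are how cases get toggled: unittest discovery only collects methods whose names begin with `test`, so the underscore prefix parks the two relu cases while `test_linear_with_fx` becomes runnable, matching the commit message. A minimal illustration:

    import unittest

    class Demo(unittest.TestCase):
        def test_enabled(self):      # name starts with "test": collected and run
            self.assertTrue(True)

        def _test_parked(self):      # leading underscore: ignored by discovery
            raise RuntimeError("never invoked by the runner")

    if __name__ == "__main__":
        unittest.main()  # runs test_enabled only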
