Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fx trans for nn graph #237

Merged
merged 9 commits into from
Aug 18, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions examples/torch_interpretor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# HF_HUB_OFFLINE=1 python3 examples/torch_interpretor.py
import os
import torch
from diffusers import StableDiffusionPipeline
from onediff.infer_compiler import torchbackend
Expand All @@ -10,13 +11,14 @@
torch_dtype=torch.float16,
)

os.environ["with_interp"] = "0"
os.environ["with_graph"] = "1"
pipe.unet = torch.compile(pipe.unet, fullgraph=True, mode="reduce-overhead", backend=torchbackend)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
with torch.autocast("cuda"):
images = pipe(prompt).images
images = pipe(prompt).images
images = pipe(prompt).images
for i, image in enumerate(images):
image.save(f"{prompt}-of-{i}.png")
for i in range(3):
images = pipe(prompt).images
for j, image in enumerate(images):
image.save(f'{prompt}-of-{i}-{j}.png')
156 changes: 151 additions & 5 deletions src/onediff/infer_compiler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import os
from collections import OrderedDict
import torch
import torch.fx as fx
import diffusers
import oneflow
import oneflow as flow
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.node import map_aggregate
from torch.func import functionalize
import importlib
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
Expand Down Expand Up @@ -153,8 +156,6 @@ class OneFlowInterpreter(torch.fx.Interpreter):
from torch.fx.node import Argument, Node, Target, map_arg, map_aggregate

def call_function(self, target: Target, args: Tuple, kwargs: Dict) -> Any:
if target == torch.sigmoid:
return torch.neg(*args, **kwargs)
args, kwargs = map_args(args, kwargs)
target = replace_func(target)
return super().call_function(target, args, kwargs)
Expand All @@ -172,13 +173,158 @@ def call_module(


def torchbackend(gm, example_inputs):
with_interp = os.getenv("with_interp", "True").lower() in (
strint marked this conversation as resolved.
Show resolved Hide resolved
"true",
"1",
"t",
)
if not with_interp:
transformed_fn = fx_node_tranform(gm)

def wrapped_forward(*args, **kwargs):
print("==> id of gm", id(gm))
args = [flow.utils.tensor.from_torch(a) for a in args]
output = OneFlowInterpreter(gm, garbage_collect_values=False).run(
*args, **kwargs
)
if with_interp:
output = OneFlowInterpreter(gm, garbage_collect_values=False).run(
*args, **kwargs
)
else:
output = transformed_fn(*args, **kwargs)
if isinstance(output, tuple):
return tuple(flow.utils.tensor.to_torch(i) for i in output)
return flow.utils.tensor.to_torch(output)

return wrapped_forward

def fx_node_tranform(gm):
print("==> gm node transform")
of_gm = to_of_transform(gm)

enable_graph = os.getenv("with_graph", "False").lower() in (
"true",
"1",
"t",
)

if not enable_graph:
oneflow_fn = of_gm.forward
else:
class OfGraph(flow.nn.Graph):
def __init__(self):
super().__init__()
self.m = of_gm
strint marked this conversation as resolved.
Show resolved Hide resolved

def build(self, *args, **kwargs):
return self.m(*args, **kwargs)

of_g = OfGraph()
of_g.debug(0)
oneflow_fn = lambda *args, **kwargs: of_g(*args, **kwargs)

return oneflow_fn


def to_of_transform(
gm: torch.fx.GraphModule, tracer_class: type = fx.Tracer
) -> torch.fx.GraphModule:
name2node = {}
name2obj = {}
torch2flow = {}
of_g = flow.fx.Graph()
modules = dict(gm.named_modules())
for node in gm.graph.nodes:
if node.op == "placeholder":
of_node = of_g.create_node('placeholder', node.target)
name2node[node.name] = of_node
elif node.op == "output":
of_node = of_g.output(node_replace_args(node.args, name2node)[0])
name2node[node.name] = of_node
elif node.op == "call_function":
of_node = of_g.create_node('call_function', replace_func(node.target), args=node_replace_args(node.args, name2node), kwargs=node_replace_args(node.kwargs, name2node))
name2node[node.name] = of_node
elif node.op == "call_method":
of_node = of_g.create_node('call_method', node.target, args=node_replace_args(node.args, name2node), kwargs=node_replace_args(node.kwargs, name2node))
name2node[node.name] = of_node
elif node.op == "call_module":
torch_md = modules[node.target]
name2obj[node.target] = _get_module(torch_md, torch2flow)
of_node = of_g.create_node('call_module', node.target, args=node_replace_args(node.args, name2node), kwargs=node_replace_args(node.kwargs, name2node))
name2node[node.name] = of_node
elif node.op == "get_attr":
of_node = of_g.create_node('get_attr', node.target)
name2node[node.name] = of_node
name2obj[node.target] = _get_attr(gm, node, torch2flow)
else:
raise ValueError(f"not valid node type{node.foramt_node()}")

of_gm = flow.fx.GraphModule(name2obj, of_g)
of_gm.graph.lint()
of_gm.recompile()
return of_gm

def replace_node(node, name2node):
if isinstance(node, torch.fx.Node):
return name2node[node.name]
else:
return replace_obj(node)


def node_replace_args(args, name2node):
return map_aggregate(args, lambda node: replace_node(node, name2node))


def _get_module_list(origin_mod, torch2flow):
assert isinstance(origin_mod, torch.nn.ModuleList)
if origin_mod in torch2flow:
return torch2flow[origin_mod]
of_md_list = flow.nn.ModuleList()
for m in origin_mod:
of_md_list.append(_get_module(m, torch2flow))
torch2flow[origin_mod] = of_md_list
return of_md_list


def _get_module(origin_mod, torch2flow):
if origin_mod in torch2flow:
return torch2flow[origin_mod]

if isinstance(origin_mod, torch.nn.ModuleList):
return _get_module_list(origin_mod, torch2flow)

proxy_md = ProxySubmodule(origin_mod)
new_md_cls = replace_class(type(origin_mod))

def init(self):
self._parameters = OrderedDict()
self._buffers = OrderedDict()
self._modules = OrderedDict()
for (n, p) in list(proxy_md.named_parameters("", False)):
self._parameters[n] = flow.utils.tensor.from_torch(p.data)
for (n, b) in list(proxy_md.named_buffers("", False)):
self._buffers[n] = flow.utils.tensor.from_torch(b.data)
for (n, m) in proxy_md._modules.items():
# TODO
self._modules[n] = _get_module(m, torch2flow)

def proxy_getattr(self, attr):
if attr in ["_parameters", "_buffers", "_modules"]:
raise ValueError(f"missing attr {attr} in base class")
else:
return getattr(proxy_md, attr)
strint marked this conversation as resolved.
Show resolved Hide resolved

of_md_cls = type(
str(new_md_cls), (new_md_cls,), {"__init__": init, "__getattr__": proxy_getattr},
)

new_md = of_md_cls()

torch2flow[origin_mod] = new_md
return new_md

def _get_attr(gm, node, torch2flow):
attr = getattr(gm, node.target)
if attr in torch2flow:
return torch2flow[attr]
of_attr = replace_obj(attr)
torch2flow[attr] = of_attr
return of_attr
56 changes: 56 additions & 0 deletions tests/test_trans.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from typing import Tuple, Dict, Any
import torch
from torch.fx.node import Argument
import oneflow as flow

def test_torch_trans():
class NegSigmSwapXformer(torch.fx.Transformer):
def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
if target == torch.sigmoid:
return torch.neg(*args, **kwargs)
return super().call_function(n)

def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
if target == 'neg':
call_self, *args_tail = args
return call_self.sigmoid(*args_tail, **kwargs)
return super().call_method(n)

def fn(x):
return torch.sigmoid(x).neg()

gm = torch.fx.symbolic_trace(fn)
print(gm)

transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform()
print(transformed)
strint marked this conversation as resolved.
Show resolved Hide resolved

input = torch.randn(3, 4)
torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())

test_torch_trans()

def test_oneflow_trans():
class NegSigmSwapXformer(torch.fx.Transformer):
def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
if target == torch.sigmoid:
return flow.neg(*args, **kwargs)
return super().call_function(n)

def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
if target == 'neg':
call_self, *args_tail = args
return call_self.sigmoid(*args_tail, **kwargs)
return super().call_method(n)

def fn(x):
return torch.sigmoid(x).neg()

gm = torch.fx.symbolic_trace(fn)
print(gm)

transformed : flow.nn.Module = NegSigmSwapXformer(gm).transform()
print(transformed)

input = flow.randn(3, 4)
flow.testing.assert_close(transformed(input), flow.neg(input).sigmoid())