Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sd2 oneflow compile #244

Merged
merged 17 commits into from
Aug 29, 2023
2 changes: 1 addition & 1 deletion examples/text_to_image_sdxl_fp16.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,5 @@
pipe.to("cuda")

for i in range(3):
image = pipe(prompt=args.prompt).images[0]
image = pipe(prompt=args.prompt, height=96, width=128, num_inference_steps=50).images[0]
image.save(f"{i}-{args.saved_image}")
57 changes: 57 additions & 0 deletions examples/text_to_image_sdxl_fp16_with_oneflow_compile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import os
import argparse
# cv2 must be imported before diffusers and oneflow to avlid error: AttributeError: module 'cv2.gapi' has no attribute 'wip'
# Maybe bacause oneflow use a lower version of cv2
import cv2
import oneflow as flow
import torch
# oneflow_compile should be imported before importing any diffusers
from onediff.infer_compiler import oneflow_compile
from diffusers import StableDiffusionXLPipeline

parser = argparse.ArgumentParser()
parser.add_argument(
"--model", type=str, default="/share_nfs/hf_models/stable-diffusion-xl-base-1.0"
)
parser.add_argument("--variant", type=str, default="fp16")
parser.add_argument(
"--prompt",
type=str,
default="street style, detailed, raw photo, woman, face, shot on CineStill 800T",
)
parser.add_argument("--saved_image", type=str, required=False, default="xl-base-out.png")
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--compile", action=argparse.BooleanOptionalAction)
parser.add_argument("--save", action=argparse.BooleanOptionalAction)
parser.add_argument("--load", action=argparse.BooleanOptionalAction)
parser.add_argument("--file", type=str, required=False, default="unet_compiled")
cmd_args = parser.parse_args()

# Normal SDXL
torch.manual_seed(cmd_args.seed)
pipe = StableDiffusionXLPipeline.from_pretrained(
cmd_args.model, torch_dtype=torch.float16, variant=cmd_args.variant, use_safetensors=True
)
pipe.to("cuda")

# Compile unet with oneflow
if cmd_args.compile:
pipe.unet = oneflow_compile(pipe.unet)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

和 torch.compile 体验一致的 oneflow_compile 函数。实现了一键替换 pipe 中的 unet 实例的效果(显存共享的)。且该 unet 支持 save/load/dynamic input。

一键替换 pipe 中的 unet 要解决的两个问题:

  • pipe 中已经创建了 torch 的 eager unet 实例,如何创建 oneflow unet 实例,且显存共享;
  • 如何创建 unet graph 实例,让其和 pipe 中的 eager unet 外部 attr 完全一样,且用 graph 执行;

这个作为 base(验证性能和正确性) 和保底方案。

后面基于这个去做 torch.compile 路线的 save/load/dynamic input 支持。torch.compile 本身对 save/load/dynamic input 这些特性的支持还太完善:

  • unet dynamic 会报错;
  • 还没有 save load 编译结果的机制;

print("unet is compiled to oneflow.")
if cmd_args.load:
# Load compiled unet with oneflow
print("loading graphs...")
pipe.unet._graph_load(cmd_args.file)

# Normal SDXL call
sizes = [1024, 896, 768]
for h in sizes:
for w in sizes:
for i in range(1):
image = pipe(prompt=cmd_args.prompt, height=h, width=w, num_inference_steps=2).images[0]
image.save(f"h{h}-w{w}-i{i}-{cmd_args.saved_image}")

# Save compiled unet with oneflow
if cmd_args.compile and cmd_args.save:
print("saving graphs...")
pipe.unet._graph_save(cmd_args.file)
73 changes: 45 additions & 28 deletions examples/unet_torch_interplay.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import os
import random

os.environ["ONEFLOW_MLIR_CSE"] = "1"
os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "1"
os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1"
os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1"
os.environ["ONEFLOW_MLIR_FUSE_OPS_WITH_BACKWARD_IMPL"] = "1"
os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "1"
# TODO(): open ONEFLOW_MLIR_GROUP_MATMUL will raise in SDXL, need be fixed.
os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "0"
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

open ONEFLOW_MLIR_GROUP_MATMUL will raise in SDXL, need be fixed.

os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1"

os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_CONV_BIAS"] = "1"
Expand All @@ -20,6 +22,9 @@
os.environ["ONEFLOW_LINEAR_EMBEDDING_SKIP_INIT"] = "1"

import click
# cv2 must be imported before diffusers and oneflow to avlid error: AttributeError: module 'cv2.gapi' has no attribute 'wip'
# Maybe bacause oneflow use a lower version of cv2
import cv2
import oneflow as flow
from tqdm import tqdm
from dataclasses import dataclass, fields
Expand All @@ -45,13 +50,13 @@ def __exit__(self, exc_type, exc_val, exc_tb):
flow.mock_torch.disable()


def get_unet(token, _model_id):
def get_unet(token, _model_id, revision):
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
_model_id,
use_auth_token=token,
revision="fp16",
revision=revision,
torch_dtype=flow.float16,
subfolder="unet",
)
Expand All @@ -67,14 +72,15 @@ def __init__(self, unet):
self.unet = unet
self.config.enable_cudnn_conv_heuristic_search_algo(False)
self.config.allow_fuse_add_to_output(True)
self.debug(0)

def build(self, latent_model_input, t, text_embeddings):
def build(self, latent_model_input, t, text_embeddings, added_cond_kwargs=None):
text_embeddings = flow._C.amp_white_identity(text_embeddings)
return self.unet(
latent_model_input, t, encoder_hidden_states=text_embeddings
latent_model_input, t, encoder_hidden_states=text_embeddings, added_cond_kwargs=added_cond_kwargs,
).sample

def warmup_with_arg(self, arg_meta_of_sizes):
def warmup_with_arg(self, arg_meta_of_sizes, added):
for arg_metas in arg_meta_of_sizes:
print(f"warmup {arg_metas=}")
arg_tensors = [
Expand All @@ -87,7 +93,7 @@ def warmup_with_arg(self, arg_meta_of_sizes):
dtype=arg_metas.gettype("cross_attention_dim"),
).to("cuda"),
]
self(*arg_tensors) # build and warmup
self(*arg_tensors, added) # build and warmup

def warmup_with_load(self, file_path):
state_dict = flow.load(file_path)
Expand Down Expand Up @@ -131,7 +137,6 @@ def get_arg_meta_of_sizes(
for j in resolution_scales
]


@click.command()
@click.option("--token")
@click.option("--repeat", default=100)
Expand All @@ -140,37 +145,48 @@ def get_arg_meta_of_sizes(
@click.option("--load", is_flag=True)
@click.option("--file", type=str, default="./unet_graphs")
@click.option("--model_id", type=str, default="runwayml/stable-diffusion-v1-5")
def benchmark(token, repeat, sync_interval, save, load, file, model_id):
@click.option("--revision", type=str, default="fp16")
def benchmark(token, repeat, sync_interval, save, load, file, model_id, revision):
RESOLUTION_SCALES = [2, 1, 0]
BATCH_SIZES = [2]
# TODO: reproduce bug caused by changing batch
# BATCH_SIZES = [4, 2]

num_channels = 4
# create a mocked unet graph
# unet mock should be placed before importing any diffusers
with MockCtx():
unet = get_unet(token, model_id)
unet = get_unet(token, model_id, revision)
unet_graph = UNetGraphWithCache(unet)
cross_attention_dim = unet.config["cross_attention_dim"]
warmup_meta_of_sizes = get_arg_meta_of_sizes(
batch_sizes=BATCH_SIZES,
resolution_scales=RESOLUTION_SCALES,
num_channels=num_channels,
cross_attention_dim=cross_attention_dim,
)
for (i, m) in enumerate(warmup_meta_of_sizes):
print(f"warmup case #{i + 1}:", m)
if load == True:
print("loading graphs...")
unet_graph.warmup_with_load(file)
else:
print("warmup with arguments...")
unet_graph.warmup_with_arg(warmup_meta_of_sizes)

# generate inputs with torch
num_channels = 4
cross_attention_dim = unet.config["cross_attention_dim"]
from diffusers.utils import floats_tensor
import torch
if model_id == "stabilityai/stable-diffusion-xl-base-1.0":
# sdxl needed
add_text_embeds = flow.utils.tensor.from_torch(floats_tensor((2, 1280)).to("cuda").to(torch.float16))
add_time_ids = flow.utils.tensor.from_torch(floats_tensor((2, 6)).to("cuda").to(torch.float16))
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
else:
added_cond_kwargs = None

warmup_meta_of_sizes = get_arg_meta_of_sizes(
batch_sizes=BATCH_SIZES,
resolution_scales=RESOLUTION_SCALES,
num_channels=num_channels,
cross_attention_dim=cross_attention_dim,
)
for (i, m) in enumerate(warmup_meta_of_sizes):
print(f"warmup case #{i + 1}:", m)

if load == True:
print("loading graphs...")
unet_graph.warmup_with_load(file)
else:
print("warmup with arguments...")
unet_graph.warmup_with_arg(warmup_meta_of_sizes, added_cond_kwargs)

# generate inputs with torch
time_step = torch.tensor([10]).to("cuda")
encoder_hidden_states_of_sizes = {
batch_size: floats_tensor((batch_size, 77, cross_attention_dim))
Expand All @@ -187,6 +203,7 @@ def benchmark(token, repeat, sync_interval, save, load, file, model_id):
k: flow.utils.tensor.from_torch(v)
for k, v in encoder_hidden_states_of_sizes.items()
}

# convert to oneflow tensors
time_step = flow.utils.tensor.from_torch(time_step)

Expand All @@ -199,7 +216,7 @@ def benchmark(token, repeat, sync_interval, save, load, file, model_id):

noise = random.choice(noise_of_sizes)
encoder_hidden_states = encoder_hidden_states_of_sizes[noise.shape[0]]
out = unet_graph(noise, time_step, encoder_hidden_states)
out = unet_graph(noise, time_step, encoder_hidden_states, added_cond_kwargs)
# convert to torch tensors
out = flow.utils.tensor.to_torch(out)
if r == repeat - 1 or r % sync_interval == 0:
Expand Down
1 change: 1 addition & 0 deletions src/onediff/infer_compiler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import torch
import oneflow as flow
from .with_oneflow_compile import oneflow_compile
from .with_fx_interpreter import OneFlowInterpreter
from .with_fx_graph import fx_node_tranform

Expand Down
96 changes: 92 additions & 4 deletions src/onediff/infer_compiler/obj_1f_from_torch.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,32 @@
import importlib
import os
from collections import OrderedDict
import torch
import oneflow as flow

__of_mds = {}
__convert_list = [
"diffusers.models.unet_2d_condition.UNet2DConditionModel",
"diffusers.models.embeddings.TimestepEmbedding",
"diffusers.models.embeddings.Timesteps",
"diffusers.models.resnet.ResnetBlock2D",
"diffusers.models.resnet.Downsample2D",
"diffusers.models.resnet.Upsample2D",
"diffusers.models.transformer_2d.Transformer2DModel",
"diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D",
"diffusers.models.unet_2d_blocks.DownBlock2D",
"diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D",
"diffusers.models.unet_2d_blocks.UpBlock2D",
"diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D",
"diffusers.models.unet_2d_blocks.UNetMidBlock2DCrossAttn",
]

with flow.mock_torch.enable(lazy=False):
for md_name in __convert_list:
strint marked this conversation as resolved.
Show resolved Hide resolved
p, m = md_name.rsplit('.', 1)
md = importlib.import_module(p)
__of_mds[md_name] = getattr(md, m)

import diffusers
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from .attention_1f import BasicTransformerBlock, FeedForward, GEGLU
Expand Down Expand Up @@ -32,6 +57,10 @@ def replace_class(cls):
return LoRACompatibleLinear
elif cls == diffusers.models.lora.LoRACompatibleConv:
return LoRACompatibleConv

full_cls_name = str(cls.__module__) + '.' + str(cls.__name__)
if full_cls_name in __of_mds:
return __of_mds[full_cls_name]

if _is_diffusers_quant_available:
if cls == diffusers_quant.FakeQuantModule:
Expand Down Expand Up @@ -152,10 +181,6 @@ def __getattribute__(self, attribute):
self._1f_proxy_children[attribute] = a
else:
a = self._1f_proxy_children[attribute]
assert (
type(a).__module__.startswith("torch") == False
and type(a).__module__.startswith("diffusers") == False
), "can't be a torch module at this point! But found " + str(type(a))
strint marked this conversation as resolved.
Show resolved Hide resolved
return a

def __call__(self, *args: Any, **kwargs: Any) -> Any:
Expand All @@ -166,3 +191,66 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
raise RuntimeError(
"can't find oneflow module for: " + str(type(self._1f_proxy_submod))
)

def _get_module_list(origin_mod, torch2flow):
assert isinstance(origin_mod, torch.nn.ModuleList)
if origin_mod in torch2flow:
return torch2flow[origin_mod]
of_md_list = flow.nn.ModuleList()
for m in origin_mod:
of_md_list.append(_get_module(m, torch2flow))
torch2flow[origin_mod] = of_md_list
return of_md_list


def _get_module(origin_mod, torch2flow):
if origin_mod in torch2flow:
return torch2flow[origin_mod]

if isinstance(origin_mod, torch.nn.ModuleList):
return _get_module_list(origin_mod, torch2flow)

proxy_md = ProxySubmodule(origin_mod)
new_md_cls = replace_class(type(origin_mod))

def init(self):
self._parameters = OrderedDict()
self._buffers = OrderedDict()
self._modules = OrderedDict()
for (n, p) in list(proxy_md.named_parameters("", False)):
self._parameters[n] = flow.utils.tensor.from_torch(p.data)
for (n, b) in list(proxy_md.named_buffers("", False)):
self._buffers[n] = flow.utils.tensor.from_torch(b.data)
for (n, m) in proxy_md._modules.items():
self._modules[n] = _get_module(m, torch2flow)

for k, v in proxy_md.__dict__.items():
if k not in self.__dict__:
try:
attr = getattr(proxy_md, k)
except:
continue
self.__dict__[k] = attr

def proxy_getattr(self, attr):
if attr in ["_parameters", "_buffers", "_modules"]:
raise ValueError(f"missing attr {attr} in base class")
else:
return getattr(proxy_md, attr)

of_md_cls = type(
str(new_md_cls), (new_md_cls,), {"__init__": init, "__getattr__": proxy_getattr},
)

new_md = of_md_cls()

torch2flow[origin_mod] = new_md
return new_md

def _get_attr(gm, node, torch2flow):
attr = getattr(gm, node.target)
if attr in torch2flow:
return torch2flow[attr]
of_attr = replace_obj(attr)
torch2flow[attr] = of_attr
return of_attr
Loading