Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sd2 oneflow compile #244

Merged
merged 17 commits into from
Aug 29, 2023
2 changes: 1 addition & 1 deletion examples/text_to_image_sdxl_fp16.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,5 @@
pipe.to("cuda")

for i in range(3):
image = pipe(prompt=args.prompt).images[0]
image = pipe(prompt=args.prompt, height=96, width=128, num_inference_steps=50).images[0]
image.save(f"{i}-{args.saved_image}")
57 changes: 57 additions & 0 deletions examples/text_to_image_sdxl_fp16_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import os
import argparse
# cv2 must be imported before diffusers and oneflow to avoid the error: AttributeError: module 'cv2.gapi' has no attribute 'wip'
# Maybe because oneflow uses a lower version of cv2
import cv2
import oneflow as flow
import torch
# oneflow_compile should be imported before importing any diffusers
from onediff.infer_compiler import oneflow_compile
from diffusers import StableDiffusionXLPipeline

parser = argparse.ArgumentParser()
parser.add_argument(
"--model", type=str, default="/share_nfs/hf_models/stable-diffusion-xl-base-1.0"
)
parser.add_argument("--variant", type=str, default="fp16")
parser.add_argument(
"--prompt",
type=str,
default="street style, detailed, raw photo, woman, face, shot on CineStill 800T",
)
parser.add_argument("--saved_image", type=str, required=False, default="xl-base-out.png")
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--compile", action=argparse.BooleanOptionalAction)
parser.add_argument("--save", action=argparse.BooleanOptionalAction)
parser.add_argument("--load", action=argparse.BooleanOptionalAction)
parser.add_argument("--file", type=str, required=False, default="unet_graph")
cmd_args = parser.parse_args()

# Normal SDXL
torch.manual_seed(cmd_args.seed)
pipe = StableDiffusionXLPipeline.from_pretrained(
cmd_args.model, torch_dtype=torch.float16, variant=cmd_args.variant, use_safetensors=True
)
pipe.to("cuda")

# Compile unet with oneflow
if cmd_args.compile:
pipe.unet = oneflow_compile(pipe.unet)
strint marked this conversation as resolved.
Show resolved Hide resolved
print("unet is compiled to oneflow.")
if cmd_args.load:
# Load compiled unet with oneflow
print("loading graphs...")
pipe.unet._graph_load(cmd_args.file)

# Run the (optionally compiled) pipeline once for every height/width pair.
resolutions = [1024, 896, 768]
for height in resolutions:
    for width in resolutions:
        for run_idx in range(1):
            result = pipe(
                prompt=cmd_args.prompt,
                height=height,
                width=width,
                num_inference_steps=2,
            )
            picture = result.images[0]
            picture.save(f"h{height}-w{width}-i{run_idx}-{cmd_args.saved_image}")

# Persist the compiled unet graphs; only done when both flags are set.
if cmd_args.compile and cmd_args.save:
    print("saving graphs...")
    pipe.unet._graph_save(cmd_args.file)
78 changes: 64 additions & 14 deletions examples/unet_torch_interplay.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import os
import random

os.environ["ONEFLOW_MLIR_CSE"] = "1"
os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "1"
os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1"
os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1"
os.environ["ONEFLOW_MLIR_FUSE_OPS_WITH_BACKWARD_IMPL"] = "1"
os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "1"
# TODO(): enabling ONEFLOW_MLIR_GROUP_MATMUL raises an error in SDXL; this needs to be fixed.
os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "0"
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

open ONEFLOW_MLIR_GROUP_MATMUL will raise in SDXL, need be fixed.

os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1"

os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_CONV_BIAS"] = "1"
Expand Down Expand Up @@ -45,13 +47,13 @@ def __exit__(self, exc_type, exc_val, exc_tb):
flow.mock_torch.disable()


def get_unet(token, _model_id):
def get_unet(token, _model_id, revision):
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
_model_id,
use_auth_token=token,
revision="fp16",
revision=revision,
torch_dtype=flow.float16,
subfolder="unet",
)
Expand All @@ -67,14 +69,15 @@ def __init__(self, unet):
self.unet = unet
self.config.enable_cudnn_conv_heuristic_search_algo(False)
self.config.allow_fuse_add_to_output(True)
self.debug(0)

def build(self, latent_model_input, t, text_embeddings):
def build(self, latent_model_input, t, text_embeddings, added_cond_kwargs=None):
text_embeddings = flow._C.amp_white_identity(text_embeddings)
return self.unet(
latent_model_input, t, encoder_hidden_states=text_embeddings
latent_model_input, t, encoder_hidden_states=text_embeddings, added_cond_kwargs=added_cond_kwargs,
).sample

def warmup_with_arg(self, arg_meta_of_sizes):
def warmup_with_arg(self, arg_meta_of_sizes, added):
for arg_metas in arg_meta_of_sizes:
print(f"warmup {arg_metas=}")
arg_tensors = [
Expand All @@ -87,7 +90,7 @@ def warmup_with_arg(self, arg_meta_of_sizes):
dtype=arg_metas.gettype("cross_attention_dim"),
).to("cuda"),
]
self(*arg_tensors) # build and warmup
self(*arg_tensors, added) # build and warmup

def warmup_with_load(self, file_path):
state_dict = flow.load(file_path)
Expand All @@ -97,6 +100,27 @@ def save_graph(self, file_path):
state_dict = self.runtime_state_dict()
flow.save(state_dict, file_path)

def get_deployable_unet(
    token,
    model_id,
    revision,
    load=False,
    file="./unet_graphs",
    added_cond_kwargs=None,
    batch_sizes=(2,),
    resolution_scales=(2, 1, 0),
    num_channels=4,
):
    """Build a ``UNetGraphWithCache`` for ``model_id`` and warm it up.

    The original body read ``BATCH_SIZES``, ``RESOLUTION_SCALES``,
    ``num_channels``, ``load`` and ``file`` from names that are not defined in
    this function, and called ``warmup_with_arg`` without its ``added``
    argument. They are now explicit keyword parameters; the defaults mirror
    the values used by ``benchmark``.

    Args:
        token: Auth token forwarded to ``get_unet``.
        model_id: Model identifier forwarded to ``get_unet``.
        revision: Model revision forwarded to ``get_unet``.
        load: When truthy, warm up from the graphs saved at ``file`` instead
            of building them from arguments.
        file: Path passed to ``warmup_with_load`` when ``load`` is truthy.
        added_cond_kwargs: Extra conditioning forwarded to ``warmup_with_arg``
            (``None`` when the model does not need it).
        batch_sizes: Batch sizes to generate warmup metadata for.
        resolution_scales: Resolution scales to generate warmup metadata for.
        num_channels: Channel count passed to ``get_arg_meta_of_sizes``.

    Returns:
        The warmed-up ``UNetGraphWithCache`` instance.
    """
    with MockCtx():
        unet = get_unet(token, model_id, revision)
        unet_graph = UNetGraphWithCache(unet)
        cross_attention_dim = unet.config["cross_attention_dim"]
        warmup_meta_of_sizes = get_arg_meta_of_sizes(
            batch_sizes=list(batch_sizes),
            resolution_scales=list(resolution_scales),
            num_channels=num_channels,
            cross_attention_dim=cross_attention_dim,
        )
        for case_no, meta in enumerate(warmup_meta_of_sizes, start=1):
            print(f"warmup case #{case_no}:", meta)
        if load:
            print("loading graphs...")
            unet_graph.warmup_with_load(file)
        else:
            print("warmup with arguments...")
            # warmup_with_arg requires the extra conditioning argument.
            unet_graph.warmup_with_arg(warmup_meta_of_sizes, added_cond_kwargs)
    return unet_graph



def img_dim(i, start, stride):
return start + stride * i
Expand Down Expand Up @@ -131,6 +155,22 @@ def get_arg_meta_of_sizes(
for j in resolution_scales
]

global_rng = random.Random()


def floats_tensor(shape, scale=1.0, rng=None, name=None):
    """Create a random float32 torch tensor of the given shape.

    Args:
        shape: Iterable of dimension sizes.
        scale: Multiplier applied to every sample drawn from ``rng.random()``.
        rng: Object exposing ``random()``; defaults to the module-level
            ``global_rng``.
        name: Unused; kept so existing call signatures keep working.

    Returns:
        A contiguous ``torch.float`` tensor viewed as ``shape``.
    """
    # Kept function-local, as in the original block.
    import torch

    if rng is None:
        rng = global_rng

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = [rng.random() * scale for _ in range(total_dims)]
    return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous()

@click.command()
@click.option("--token")
Expand All @@ -140,18 +180,29 @@ def get_arg_meta_of_sizes(
@click.option("--load", is_flag=True)
@click.option("--file", type=str, default="./unet_graphs")
@click.option("--model_id", type=str, default="runwayml/stable-diffusion-v1-5")
def benchmark(token, repeat, sync_interval, save, load, file, model_id):
@click.option("--revision", type=str, default="fp16")
def benchmark(token, repeat, sync_interval, save, load, file, model_id, revision):
RESOLUTION_SCALES = [2, 1, 0]
BATCH_SIZES = [2]
# TODO: reproduce bug caused by changing batch
# BATCH_SIZES = [4, 2]

num_channels = 4
import torch
if model_id == "stabilityai/stable-diffusion-xl-base-1.0":
# sdxl needed
add_text_embeds = flow.utils.tensor.from_torch(floats_tensor((2, 1280)).to("cuda").to(torch.float16))
add_time_ids = flow.utils.tensor.from_torch(floats_tensor((2, 6)).to("cuda").to(torch.float16))
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
else:
added_cond_kwargs = None

# create a mocked unet graph
with MockCtx():
unet = get_unet(token, model_id)
unet = get_unet(token, model_id, revision)
unet_graph = UNetGraphWithCache(unet)
cross_attention_dim = unet.config["cross_attention_dim"]

warmup_meta_of_sizes = get_arg_meta_of_sizes(
batch_sizes=BATCH_SIZES,
resolution_scales=RESOLUTION_SCALES,
Expand All @@ -160,17 +211,15 @@ def benchmark(token, repeat, sync_interval, save, load, file, model_id):
)
for (i, m) in enumerate(warmup_meta_of_sizes):
print(f"warmup case #{i + 1}:", m)

if load == True:
print("loading graphs...")
unet_graph.warmup_with_load(file)
else:
print("warmup with arguments...")
unet_graph.warmup_with_arg(warmup_meta_of_sizes)
unet_graph.warmup_with_arg(warmup_meta_of_sizes, added_cond_kwargs)

# generate inputs with torch
from diffusers.utils import floats_tensor
import torch

time_step = torch.tensor([10]).to("cuda")
encoder_hidden_states_of_sizes = {
batch_size: floats_tensor((batch_size, 77, cross_attention_dim))
Expand All @@ -187,6 +236,7 @@ def benchmark(token, repeat, sync_interval, save, load, file, model_id):
k: flow.utils.tensor.from_torch(v)
for k, v in encoder_hidden_states_of_sizes.items()
}

# convert to oneflow tensors
time_step = flow.utils.tensor.from_torch(time_step)

Expand All @@ -199,7 +249,7 @@ def benchmark(token, repeat, sync_interval, save, load, file, model_id):

noise = random.choice(noise_of_sizes)
encoder_hidden_states = encoder_hidden_states_of_sizes[noise.shape[0]]
out = unet_graph(noise, time_step, encoder_hidden_states)
out = unet_graph(noise, time_step, encoder_hidden_states, added_cond_kwargs)
# convert to torch tensors
out = flow.utils.tensor.to_torch(out)
if r == repeat - 1 or r % sync_interval == 0:
Expand Down
1 change: 1 addition & 0 deletions src/onediff/infer_compiler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import torch
import oneflow as flow
from .with_onef_graph import oneflow_compile
from .with_fx_interpreter import OneFlowInterpreter
from .with_fx_graph import fx_node_tranform

Expand Down
96 changes: 92 additions & 4 deletions src/onediff/infer_compiler/obj_1f_from_torch.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,32 @@
import importlib
import os
from collections import OrderedDict
import torch
import oneflow as flow

__of_mds = {}
# Fully-qualified diffusers classes that are re-imported under oneflow's
# mock_torch. Each entry appears once; a duplicated CrossAttnUpBlock2D entry
# was removed.
__convert_list = [
    "diffusers.models.unet_2d_condition.UNet2DConditionModel",
    "diffusers.models.embeddings.TimestepEmbedding",
    "diffusers.models.embeddings.Timesteps",
    "diffusers.models.resnet.ResnetBlock2D",
    "diffusers.models.resnet.Downsample2D",
    "diffusers.models.resnet.Upsample2D",
    "diffusers.models.transformer_2d.Transformer2DModel",
    "diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D",
    "diffusers.models.unet_2d_blocks.DownBlock2D",
    "diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D",
    "diffusers.models.unet_2d_blocks.UpBlock2D",
    "diffusers.models.unet_2d_blocks.UNetMidBlock2DCrossAttn",
]

with flow.mock_torch.enable(lazy=False):
for md_name in __convert_list:
strint marked this conversation as resolved.
Show resolved Hide resolved
p, m = md_name.rsplit('.', 1)
md = importlib.import_module(p)
__of_mds[md_name] = getattr(md, m)

import diffusers
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from .attention_1f import BasicTransformerBlock, FeedForward, GEGLU
Expand Down Expand Up @@ -32,6 +57,10 @@ def replace_class(cls):
return LoRACompatibleLinear
elif cls == diffusers.models.lora.LoRACompatibleConv:
return LoRACompatibleConv

full_cls_name = str(cls.__module__) + '.' + str(cls.__name__)
if full_cls_name in __of_mds:
return __of_mds[full_cls_name]

if _is_diffusers_quant_available:
if cls == diffusers_quant.FakeQuantModule:
Expand Down Expand Up @@ -152,10 +181,6 @@ def __getattribute__(self, attribute):
self._1f_proxy_children[attribute] = a
else:
a = self._1f_proxy_children[attribute]
assert (
type(a).__module__.startswith("torch") == False
and type(a).__module__.startswith("diffusers") == False
), "can't be a torch module at this point! But found " + str(type(a))
strint marked this conversation as resolved.
Show resolved Hide resolved
return a

def __call__(self, *args: Any, **kwargs: Any) -> Any:
Expand All @@ -166,3 +191,66 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
raise RuntimeError(
"can't find oneflow module for: " + str(type(self._1f_proxy_submod))
)

def _get_module_list(origin_mod, torch2flow):
    """Convert a ``torch.nn.ModuleList`` into a ``flow.nn.ModuleList``.

    ``torch2flow`` is used as a cache: a list that was already converted is
    returned as-is, and a new conversion is recorded before returning.
    """
    assert isinstance(origin_mod, torch.nn.ModuleList)
    if origin_mod not in torch2flow:
        converted = flow.nn.ModuleList()
        # Children are converted one by one, sharing the same cache.
        for child in origin_mod:
            converted.append(_get_module(child, torch2flow))
        torch2flow[origin_mod] = converted
    return torch2flow[origin_mod]


def _get_module(origin_mod, torch2flow):
    """Return the oneflow counterpart of ``origin_mod``, converting on demand.

    Args:
        origin_mod: The torch module to convert.
        torch2flow: Cache mapping already-converted torch objects to their
            oneflow counterparts; updated in place.

    Returns:
        The cached object if ``origin_mod`` was seen before, the result of
        ``_get_module_list`` for a ``torch.nn.ModuleList``, otherwise a new
        instance of a dynamically created subclass of the class returned by
        ``replace_class``.
    """
    if origin_mod in torch2flow:
        return torch2flow[origin_mod]

    if isinstance(origin_mod, torch.nn.ModuleList):
        return _get_module_list(origin_mod, torch2flow)

    proxy_md = ProxySubmodule(origin_mod)
    new_md_cls = replace_class(type(origin_mod))

    def init(self):
        # Note: the base-class __init__ is not called; state is filled in
        # directly from the proxied torch module.
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._modules = OrderedDict()
        for (n, p) in list(proxy_md.named_parameters("", False)):
            self._parameters[n] = flow.utils.tensor.from_torch(p.data)
        for (n, b) in list(proxy_md.named_buffers("", False)):
            self._buffers[n] = flow.utils.tensor.from_torch(b.data)
        for (n, m) in proxy_md._modules.items():
            self._modules[n] = _get_module(m, torch2flow)

        # Copy remaining instance attributes that the new module lacks.
        for k in proxy_md.__dict__:
            if k not in self.__dict__:
                try:
                    attr = getattr(proxy_md, k)
                except Exception:
                    # Best effort: skip attributes the proxy cannot resolve,
                    # but let KeyboardInterrupt/SystemExit propagate.
                    continue
                self.__dict__[k] = attr

    def proxy_getattr(self, attr):
        # Only reached when normal lookup fails; the three core containers
        # must exist on the instance, everything else falls back to the proxy.
        if attr in ["_parameters", "_buffers", "_modules"]:
            raise ValueError(f"missing attr {attr} in base class")
        else:
            return getattr(proxy_md, attr)

    of_md_cls = type(
        str(new_md_cls), (new_md_cls,), {"__init__": init, "__getattr__": proxy_getattr},
    )

    new_md = of_md_cls()

    torch2flow[origin_mod] = new_md
    return new_md

def _get_attr(gm, node, torch2flow):
    """Look up ``node.target`` on ``gm`` and return its oneflow counterpart.

    The conversion is done with ``replace_obj`` the first time an attribute
    is seen and cached in ``torch2flow`` for later calls.
    """
    torch_attr = getattr(gm, node.target)
    if torch_attr not in torch2flow:
        torch2flow[torch_attr] = replace_obj(torch_attr)
    return torch2flow[torch_attr]
Loading