sd2 oneflow compile #244
Merged
Commits (17)
d81c136 add env var for graph opt (strint)
5fcf5b6 test pass of sdxl graph (strint)
2909c92 draft graph run (strint)
9b0a373 unet sd1.5 test passed (strint)
0eaa2c9 part2 (strint)
4d95606 Merge branch 'refactor-backend' of https://github.com/Oneflow-Inc/dif… (strint)
a1eaf26 sdxl unet graph test passed (strint)
da4f474 oneflow unet replace test passed (strint)
a3adab9 add unet graph to sdxl test passed (strint)
97d69cf sdxl unet graph multi input shape (strint)
570d383 sdxl unet graph save load multi input test passed (strint)
39263e5 refine oneflow_compile api (strint)
b62c21f refine code (strint)
83a19f3 format (strint)
f666512 deal with special case (strint)
7e8c0f8 open GROUP_MATMUL (strint)
4c7d6da add (base and refiner) full pipeline (strint)
@@ -0,0 +1,57 @@
import os
import argparse
# cv2 must be imported before diffusers and oneflow to avoid the error: AttributeError: module 'cv2.gapi' has no attribute 'wip'
# Maybe because oneflow uses a lower version of cv2
import cv2
import oneflow as flow
import torch
# oneflow_compile should be imported before importing any diffusers
from onediff.infer_compiler import oneflow_compile
from diffusers import StableDiffusionXLPipeline

parser = argparse.ArgumentParser()
parser.add_argument(
    "--model", type=str, default="/share_nfs/hf_models/stable-diffusion-xl-base-1.0"
)
parser.add_argument("--variant", type=str, default="fp16")
parser.add_argument(
    "--prompt",
    type=str,
    default="street style, detailed, raw photo, woman, face, shot on CineStill 800T",
)
parser.add_argument("--saved_image", type=str, required=False, default="xl-base-out.png")
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--compile", action=argparse.BooleanOptionalAction)
parser.add_argument("--save", action=argparse.BooleanOptionalAction)
parser.add_argument("--load", action=argparse.BooleanOptionalAction)
parser.add_argument("--file", type=str, required=False, default="unet_compiled")
cmd_args = parser.parse_args()

# Normal SDXL
torch.manual_seed(cmd_args.seed)
pipe = StableDiffusionXLPipeline.from_pretrained(
    cmd_args.model, torch_dtype=torch.float16, variant=cmd_args.variant, use_safetensors=True
)
pipe.to("cuda")

# Compile unet with oneflow
if cmd_args.compile:
    pipe.unet = oneflow_compile(pipe.unet)
    print("unet is compiled to oneflow.")
    if cmd_args.load:
        # Load compiled unet with oneflow
        print("loading graphs...")
        pipe.unet._graph_load(cmd_args.file)

# Normal SDXL call
sizes = [1024, 896, 768]
for h in sizes:
    for w in sizes:
        for i in range(1):
            image = pipe(prompt=cmd_args.prompt, height=h, width=w, num_inference_steps=2).images[0]
            image.save(f"h{h}-w{w}-i{i}-{cmd_args.saved_image}")

# Save compiled unet with oneflow
if cmd_args.compile and cmd_args.save:
    print("saving graphs...")
    pipe.unet._graph_save(cmd_args.file)
@@ -1,11 +1,13 @@
import os
import random

os.environ["ONEFLOW_MLIR_CSE"] = "1"
os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "1"
os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1"
os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1"
os.environ["ONEFLOW_MLIR_FUSE_OPS_WITH_BACKWARD_IMPL"] = "1"
os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "1"
# TODO(): enabling ONEFLOW_MLIR_GROUP_MATMUL will raise an error in SDXL; needs to be fixed.
os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "0"

Review comment: enabling ONEFLOW_MLIR_GROUP_MATMUL will raise an error in SDXL and needs to be fixed.

os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1"

os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_CONV_BIAS"] = "1"
@@ -20,6 +22,9 @@
os.environ["ONEFLOW_LINEAR_EMBEDDING_SKIP_INIT"] = "1"

import click
# cv2 must be imported before diffusers and oneflow to avoid the error: AttributeError: module 'cv2.gapi' has no attribute 'wip'
# Maybe because oneflow uses a lower version of cv2
import cv2
import oneflow as flow
from tqdm import tqdm
from dataclasses import dataclass, fields
@@ -45,13 +50,13 @@ def __exit__(self, exc_type, exc_val, exc_tb):
        flow.mock_torch.disable()


def get_unet(token, _model_id):
def get_unet(token, _model_id, revision):
    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel.from_pretrained(
        _model_id,
        use_auth_token=token,
        revision="fp16",
        revision=revision,
        torch_dtype=flow.float16,
        subfolder="unet",
    )
@@ -67,14 +72,15 @@ def __init__(self, unet):
        self.unet = unet
        self.config.enable_cudnn_conv_heuristic_search_algo(False)
        self.config.allow_fuse_add_to_output(True)
        self.debug(0)

    def build(self, latent_model_input, t, text_embeddings):
    def build(self, latent_model_input, t, text_embeddings, added_cond_kwargs=None):
        text_embeddings = flow._C.amp_white_identity(text_embeddings)
        return self.unet(
            latent_model_input, t, encoder_hidden_states=text_embeddings
            latent_model_input, t, encoder_hidden_states=text_embeddings, added_cond_kwargs=added_cond_kwargs,
        ).sample

    def warmup_with_arg(self, arg_meta_of_sizes):
    def warmup_with_arg(self, arg_meta_of_sizes, added):
        for arg_metas in arg_meta_of_sizes:
            print(f"warmup {arg_metas=}")
            arg_tensors = [

@@ -87,7 +93,7 @@ def warmup_with_arg(self, arg_meta_of_sizes):
                    dtype=arg_metas.gettype("cross_attention_dim"),
                ).to("cuda"),
            ]
            self(*arg_tensors)  # build and warmup
            self(*arg_tensors, added)  # build and warmup

    def warmup_with_load(self, file_path):
        state_dict = flow.load(file_path)
@@ -131,7 +137,6 @@ def get_arg_meta_of_sizes(
        for j in resolution_scales
    ]


@click.command()
@click.option("--token")
@click.option("--repeat", default=100)
@@ -140,37 +145,48 @@ def get_arg_meta_of_sizes(
@click.option("--load", is_flag=True)
@click.option("--file", type=str, default="./unet_graphs")
@click.option("--model_id", type=str, default="runwayml/stable-diffusion-v1-5")
def benchmark(token, repeat, sync_interval, save, load, file, model_id):
@click.option("--revision", type=str, default="fp16")
def benchmark(token, repeat, sync_interval, save, load, file, model_id, revision):
    RESOLUTION_SCALES = [2, 1, 0]
    BATCH_SIZES = [2]
    # TODO: reproduce bug caused by changing batch
    # BATCH_SIZES = [4, 2]

    num_channels = 4
    # create a mocked unet graph
    # unet mock should be placed before importing any diffusers
    with MockCtx():
        unet = get_unet(token, model_id)
        unet = get_unet(token, model_id, revision)
        unet_graph = UNetGraphWithCache(unet)
        cross_attention_dim = unet.config["cross_attention_dim"]
        warmup_meta_of_sizes = get_arg_meta_of_sizes(
            batch_sizes=BATCH_SIZES,
            resolution_scales=RESOLUTION_SCALES,
            num_channels=num_channels,
            cross_attention_dim=cross_attention_dim,
        )
        for (i, m) in enumerate(warmup_meta_of_sizes):
            print(f"warmup case #{i + 1}:", m)
        if load == True:
            print("loading graphs...")
            unet_graph.warmup_with_load(file)
        else:
            print("warmup with arguments...")
            unet_graph.warmup_with_arg(warmup_meta_of_sizes)

    # generate inputs with torch
    num_channels = 4
    cross_attention_dim = unet.config["cross_attention_dim"]
    from diffusers.utils import floats_tensor
    import torch
    if model_id == "stabilityai/stable-diffusion-xl-base-1.0":
        # sdxl needed
        add_text_embeds = flow.utils.tensor.from_torch(floats_tensor((2, 1280)).to("cuda").to(torch.float16))
        add_time_ids = flow.utils.tensor.from_torch(floats_tensor((2, 6)).to("cuda").to(torch.float16))
        added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
    else:
        added_cond_kwargs = None

    warmup_meta_of_sizes = get_arg_meta_of_sizes(
        batch_sizes=BATCH_SIZES,
        resolution_scales=RESOLUTION_SCALES,
        num_channels=num_channels,
        cross_attention_dim=cross_attention_dim,
    )
    for (i, m) in enumerate(warmup_meta_of_sizes):
        print(f"warmup case #{i + 1}:", m)

    if load == True:
        print("loading graphs...")
        unet_graph.warmup_with_load(file)
    else:
        print("warmup with arguments...")
        unet_graph.warmup_with_arg(warmup_meta_of_sizes, added_cond_kwargs)

    # generate inputs with torch
    time_step = torch.tensor([10]).to("cuda")
    encoder_hidden_states_of_sizes = {
        batch_size: floats_tensor((batch_size, 77, cross_attention_dim))
@@ -187,6 +203,7 @@ def benchmark(token, repeat, sync_interval, save, load, file, model_id):
        k: flow.utils.tensor.from_torch(v)
        for k, v in encoder_hidden_states_of_sizes.items()
    }

    # convert to oneflow tensors
    time_step = flow.utils.tensor.from_torch(time_step)


@@ -199,7 +216,7 @@ def benchmark(token, repeat, sync_interval, save, load, file, model_id):

        noise = random.choice(noise_of_sizes)
        encoder_hidden_states = encoder_hidden_states_of_sizes[noise.shape[0]]
        out = unet_graph(noise, time_step, encoder_hidden_states)
        out = unet_graph(noise, time_step, encoder_hidden_states, added_cond_kwargs)
        # convert to torch tensors
        out = flow.utils.tensor.to_torch(out)
        if r == repeat - 1 or r % sync_interval == 0:
An oneflow_compile function with the same user experience as torch.compile. It replaces the unet instance in the pipe in one step (sharing GPU memory with the original), and the compiled unet supports save/load/dynamic input.
Two problems had to be solved to replace the unet in the pipe in one step:
This serves as the base (for verifying performance and correctness) and as the fallback plan.
Save/load/dynamic input support for the torch.compile route will be built on top of this later; torch.compile's own support for save/load/dynamic input is still not mature enough:
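For reference, here is a minimal sketch of the user-facing flow described above, distilled from the new SDXL example file in this PR. The model id, prompt, sizes, and graph file name are placeholders, and the _graph_save/_graph_load calls mirror the example script rather than a documented public API.

import torch
# oneflow_compile should be imported before any diffusers import, as in the example
from onediff.infer_compiler import oneflow_compile
from diffusers import StableDiffusionXLPipeline

# Build the SDXL pipeline as usual, then swap the unet for an oneflow-compiled wrapper.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # placeholder model id
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
).to("cuda")
pipe.unet = oneflow_compile(pipe.unet)  # one-step replacement; GPU memory is shared with the original unet

# Dynamic input: different resolutions reuse the same compiled wrapper after warmup.
for size in (1024, 768):
    pipe(prompt="a placeholder prompt", height=size, width=size, num_inference_steps=2)

# Save the compiled graphs so a later run can skip recompilation.
pipe.unet._graph_save("unet_compiled")

# In a fresh process, compile the wrapper again and load the cached graphs instead of warming up:
# pipe.unet = oneflow_compile(pipe.unet)
# pipe.unet._graph_load("unet_compiled")

The point, as stated above, is that this path feels like torch.compile to the user while already providing the save/load and multi-shape support that torch.compile itself does not yet offer.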