Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sd2 oneflow compile #244

Merged
merged 17 commits into from
Aug 29, 2023
Merged

sd2 oneflow compile #244

merged 17 commits into from
Aug 29, 2023

Conversation

strint
Copy link
Collaborator

@strint strint commented Aug 23, 2023

Depends on
Oneflow-Inc/oneflow#10323

unet graph(sd2)

run

python examples/unet_torch_interplay.py --model_id=stabilityai/stable-diffusion-xl-base-1.0 --revision=main

run and save

python examples/unet_torch_interplay.py --model_id=stabilityai/stable-diffusion-xl-base-1.0 --revision=main --save

load and run

python examples/unet_torch_interplay.py --model_id=stabilityai/stable-diffusion-xl-base-1.0 --revision=main --load

sd2 with unet graph

run torch eager

python examples/text_to_image_sdxl_fp16_with_oneflow_compile.py

run oneflow graph

python examples/text_to_image_sdxl_fp16_with_oneflow_compile.py --compile

run oneflow graph save

python examples/text_to_image_sdxl_fp16_with_oneflow_compile.py --compile --save

run oneflow graph load

python examples/text_to_image_sdxl_fp16_with_oneflow_compile.py --compile --load


os.environ["ONEFLOW_MLIR_CSE"] = "1"
os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "1"
os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1"
os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1"
os.environ["ONEFLOW_MLIR_FUSE_OPS_WITH_BACKWARD_IMPL"] = "1"
os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "1"
# TODO(): open ONEFLOW_MLIR_GROUP_MATMUL will raise in SDXL, need be fixed.
os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "0"
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

open ONEFLOW_MLIR_GROUP_MATMUL will raise in SDXL, need be fixed.

@strint strint changed the title [Draft]sd2 nn graph sd2 nn graph Aug 24, 2023
@strint strint changed the title sd2 nn graph sd2 oneflow compile Aug 25, 2023

# Compile unet with oneflow
if cmd_args.compile:
pipe.unet = oneflow_compile(pipe.unet)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

和 torch.compile 体验一致的 oneflow_compile 函数。实现了一键替换 pipe 中的 unet 实例的效果(显存共享的)。且该 unet 支持 save/load/dynamic input。

一键替换 pipe 中的 unet 要解决的两个问题:

  • pipe 中已经创建了 torch 的 eager unet 实例,如何创建 oneflow unet 实例,且显存共享;
  • 如何创建 unet graph 实例,让其和 pipe 中的 eager unet 外部 attr 完全一样,且用 graph 执行;

这个作为 base(验证性能和正确性) 和保底方案。

后面基于这个去做 torch.compile 路线的 save/load/dynamic input 支持。torch.compile 本身对 save/load/dynamic input 这些特性的支持还太完善:

  • unet dynamic 会报错;
  • 还没有 save load 编译结果的机制;

@hjchen2 hjchen2 merged commit 3680e73 into refactor-backend Aug 29, 2023
@hjchen2 hjchen2 deleted the sdxl_nn_graph branch August 29, 2023 06:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants