-
Notifications
You must be signed in to change notification settings - Fork 103
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
sd2 oneflow compile #244
sd2 oneflow compile #244
Conversation
…fusers into sdxl_nn_graph
examples/unet_torch_interplay.py
Outdated
|
||
os.environ["ONEFLOW_MLIR_CSE"] = "1" | ||
os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "1" | ||
os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" | ||
os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1" | ||
os.environ["ONEFLOW_MLIR_FUSE_OPS_WITH_BACKWARD_IMPL"] = "1" | ||
os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "1" | ||
# TODO(): open ONEFLOW_MLIR_GROUP_MATMUL will raise in SDXL, need be fixed. | ||
os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "0" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
open ONEFLOW_MLIR_GROUP_MATMUL will raise in SDXL, need be fixed.
|
||
# Compile unet with oneflow | ||
if cmd_args.compile: | ||
pipe.unet = oneflow_compile(pipe.unet) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
和 torch.compile 体验一致的 oneflow_compile 函数。实现了一键替换 pipe 中的 unet 实例的效果(显存共享的)。且该 unet 支持 save/load/dynamic input。
一键替换 pipe 中的 unet 要解决的两个问题:
- pipe 中已经创建了 torch 的 eager unet 实例,如何创建 oneflow unet 实例,且显存共享;
- 如何创建 unet graph 实例,让其和 pipe 中的 eager unet 外部 attr 完全一样,且用 graph 执行;
这个作为 base(验证性能和正确性) 和保底方案。
后面基于这个去做 torch.compile 路线的 save/load/dynamic input 支持。torch.compile 本身对 save/load/dynamic input 这些特性的支持还太完善:
- unet dynamic 会报错;
- 还没有 save load 编译结果的机制;
Depends on
Oneflow-Inc/oneflow#10323
unet graph(sd2)
run
run and save
load and run
sd2 with unet graph
run torch eager
run oneflow graph
run oneflow graph save
run oneflow graph load