From d81c1367dab3e73c892ee6b37279b01daf8d8cac Mon Sep 17 00:00:00 2001 From: strint Date: Mon, 21 Aug 2023 07:09:38 +0000 Subject: [PATCH 01/16] add env var for graph opt --- examples/torch_interpretor.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/examples/torch_interpretor.py b/examples/torch_interpretor.py index 2566941a2..94c8cfa35 100644 --- a/examples/torch_interpretor.py +++ b/examples/torch_interpretor.py @@ -12,7 +12,25 @@ ) os.environ["with_interp"] = "0" + +# optimize with oneflow graph os.environ["with_graph"] = "1" +os.environ["ONEFLOW_MLIR_CSE"] = "1" +os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "1" +os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" +os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1" +os.environ["ONEFLOW_MLIR_FUSE_OPS_WITH_BACKWARD_IMPL"] = "1" +os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "1" +os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1" +os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_CONV_BIAS"] = "1" +os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR"] = "1" +os.environ["ONEFLOW_KERNEL_CONV_CUTLASS_IMPL_ENABLE_TUNING_WARMUP"] = "1" +os.environ["ONEFLOW_KERNEL_CONV_ENABLE_CUTLASS_IMPL"] = "1" +os.environ["ONEFLOW_CONV_ALLOW_HALF_PRECISION_ACCUMULATION"] = "1" +os.environ["ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION"] = "1" +os.environ["ONEFLOW_LINEAR_EMBEDDING_SKIP_INIT"] = "1" + + pipe.unet = torch.compile(pipe.unet, fullgraph=True, mode="reduce-overhead", backend=torchbackend) pipe = pipe.to("cuda") From 5fcf5b644e36208adc7e606cf5635986c1a37da8 Mon Sep 17 00:00:00 2001 From: strint Date: Mon, 21 Aug 2023 08:28:51 +0000 Subject: [PATCH 02/16] test pass of sdxl graph --- examples/text-to-image-sdxl-fp16.py | 33 ------------- examples/text_to_image_sdxl_fp16.py | 55 +++++++++++++++++++++ src/onediff/infer_compiler/with_fx_graph.py | 3 ++ 3 files changed, 58 insertions(+), 33 deletions(-) delete mode 100644 examples/text-to-image-sdxl-fp16.py create mode 100644 examples/text_to_image_sdxl_fp16.py diff --git a/examples/text-to-image-sdxl-fp16.py b/examples/text-to-image-sdxl-fp16.py deleted file mode 100644 index a84654df4..000000000 --- a/examples/text-to-image-sdxl-fp16.py +++ /dev/null @@ -1,33 +0,0 @@ -import argparse -from diffusers import StableDiffusionXLPipeline -import torch - -from onediff.infer_compiler import torchbackend - - -parser = argparse.ArgumentParser() -parser.add_argument( - "--model", type=str, default="/share_nfs/hf_models/stable-diffusion-xl-base-1.0" -) -parser.add_argument("--variant", type=str, default="fp16") -parser.add_argument( - "--prompt", - type=str, - default="street style, detailed, raw photo, woman, face, shot on CineStill 800T", -) -parser.add_argument("--saved_image", type=str, required=False, default="xl-base-out.png") -parser.add_argument("--seed", type=int, default=1) -args = parser.parse_args() - -torch.manual_seed(args.seed) - -pipe = StableDiffusionXLPipeline.from_pretrained( - args.model, torch_dtype=torch.float16, variant=args.variant, use_safetensors=True -) - -pipe.unet = torch.compile(pipe.unet, fullgraph=True, mode="reduce-overhead", backend=torchbackend) - -pipe.to("cuda") - -image = pipe(prompt=args.prompt).images[0] -image.save(args.saved_image) diff --git a/examples/text_to_image_sdxl_fp16.py b/examples/text_to_image_sdxl_fp16.py new file mode 100644 index 000000000..1ea605c75 --- /dev/null +++ b/examples/text_to_image_sdxl_fp16.py @@ -0,0 +1,55 @@ +import os +import argparse +from diffusers import StableDiffusionXLPipeline +import torch + +from onediff.infer_compiler import torchbackend + + +parser = argparse.ArgumentParser() +parser.add_argument( + "--model", type=str, default="/share_nfs/hf_models/stable-diffusion-xl-base-1.0" +) +parser.add_argument("--variant", type=str, default="fp16") +parser.add_argument( + "--prompt", + type=str, + default="street style, detailed, raw photo, woman, face, shot on CineStill 800T", +) +parser.add_argument("--saved_image", type=str, required=False, default="xl-base-out.png") +parser.add_argument("--seed", type=int, default=1) +parser.add_argument("--graph", action=argparse.BooleanOptionalAction) +args = parser.parse_args() + +torch.manual_seed(args.seed) +print(args.graph) + +pipe = StableDiffusionXLPipeline.from_pretrained( + args.model, torch_dtype=torch.float16, variant=args.variant, use_safetensors=True +) + +if args.graph: + os.environ["with_graph"] = "1" + os.environ["ONEFLOW_MLIR_CSE"] = "1" + os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "1" + os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" + os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1" + os.environ["ONEFLOW_MLIR_FUSE_OPS_WITH_BACKWARD_IMPL"] = "1" + # Open this will raise error + # os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "1" + os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1" + os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_CONV_BIAS"] = "1" + os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR"] = "1" + os.environ["ONEFLOW_KERNEL_CONV_CUTLASS_IMPL_ENABLE_TUNING_WARMUP"] = "1" + os.environ["ONEFLOW_KERNEL_CONV_ENABLE_CUTLASS_IMPL"] = "1" + os.environ["ONEFLOW_CONV_ALLOW_HALF_PRECISION_ACCUMULATION"] = "1" + os.environ["ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION"] = "1" + os.environ["ONEFLOW_LINEAR_EMBEDDING_SKIP_INIT"] = "1" + +pipe.unet = torch.compile(pipe.unet, fullgraph=True, mode="reduce-overhead", backend=torchbackend) + +pipe.to("cuda") + +for i in range(3): + image = pipe(prompt=args.prompt).images[0] + image.save(f"{i}-{args.saved_image}") diff --git a/src/onediff/infer_compiler/with_fx_graph.py b/src/onediff/infer_compiler/with_fx_graph.py index ad7de58d7..d73eb9be3 100644 --- a/src/onediff/infer_compiler/with_fx_graph.py +++ b/src/onediff/infer_compiler/with_fx_graph.py @@ -23,11 +23,14 @@ class OfGraph(flow.nn.Graph): def __init__(self): super().__init__() self.fx_md = of_gm + self.config.enable_cudnn_conv_heuristic_search_algo(False) + self.config.allow_fuse_add_to_output(True) def build(self, *args, **kwargs): return self.fx_md(*args, **kwargs) of_g = OfGraph() + of_g.debug(0) oneflow_fn = lambda *args, **kwargs: of_g(*args, **kwargs) return oneflow_fn From 2909c924129735d98581241e3fba15aed1868a6f Mon Sep 17 00:00:00 2001 From: strint Date: Mon, 21 Aug 2023 08:51:05 +0000 Subject: [PATCH 03/16] draft graph run --- examples/text_to_image_sdxl_fp16.py | 45 ++++++++++++++++------------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/examples/text_to_image_sdxl_fp16.py b/examples/text_to_image_sdxl_fp16.py index 1ea605c75..b0a700bf6 100644 --- a/examples/text_to_image_sdxl_fp16.py +++ b/examples/text_to_image_sdxl_fp16.py @@ -18,35 +18,40 @@ ) parser.add_argument("--saved_image", type=str, required=False, default="xl-base-out.png") parser.add_argument("--seed", type=int, default=1) +parser.add_argument("--compile", action=argparse.BooleanOptionalAction) parser.add_argument("--graph", action=argparse.BooleanOptionalAction) args = parser.parse_args() +if args.compile: + print("unet is compiled to oneflow.") + if args.graph: + print("unet is compiled to oneflow graph.") + torch.manual_seed(args.seed) -print(args.graph) pipe = StableDiffusionXLPipeline.from_pretrained( args.model, torch_dtype=torch.float16, variant=args.variant, use_safetensors=True ) -if args.graph: - os.environ["with_graph"] = "1" - os.environ["ONEFLOW_MLIR_CSE"] = "1" - os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "1" - os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" - os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1" - os.environ["ONEFLOW_MLIR_FUSE_OPS_WITH_BACKWARD_IMPL"] = "1" - # Open this will raise error - # os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "1" - os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1" - os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_CONV_BIAS"] = "1" - os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR"] = "1" - os.environ["ONEFLOW_KERNEL_CONV_CUTLASS_IMPL_ENABLE_TUNING_WARMUP"] = "1" - os.environ["ONEFLOW_KERNEL_CONV_ENABLE_CUTLASS_IMPL"] = "1" - os.environ["ONEFLOW_CONV_ALLOW_HALF_PRECISION_ACCUMULATION"] = "1" - os.environ["ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION"] = "1" - os.environ["ONEFLOW_LINEAR_EMBEDDING_SKIP_INIT"] = "1" - -pipe.unet = torch.compile(pipe.unet, fullgraph=True, mode="reduce-overhead", backend=torchbackend) +if args.compile: + if args.graph: + os.environ["with_graph"] = "1" + os.environ["ONEFLOW_MLIR_CSE"] = "1" + os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "1" + os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" + os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1" + os.environ["ONEFLOW_MLIR_FUSE_OPS_WITH_BACKWARD_IMPL"] = "1" + # Open this will raise error + # os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "1" + os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1" + os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_CONV_BIAS"] = "1" + os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR"] = "1" + os.environ["ONEFLOW_KERNEL_CONV_CUTLASS_IMPL_ENABLE_TUNING_WARMUP"] = "1" + os.environ["ONEFLOW_KERNEL_CONV_ENABLE_CUTLASS_IMPL"] = "1" + os.environ["ONEFLOW_CONV_ALLOW_HALF_PRECISION_ACCUMULATION"] = "1" + os.environ["ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION"] = "1" + os.environ["ONEFLOW_LINEAR_EMBEDDING_SKIP_INIT"] = "1" + pipe.unet = torch.compile(pipe.unet, fullgraph=True, mode="reduce-overhead", backend=torchbackend) pipe.to("cuda") From 9b0a373684b91a997692deb213b345c27eed7ea9 Mon Sep 17 00:00:00 2001 From: strint Date: Wed, 23 Aug 2023 07:02:06 +0000 Subject: [PATCH 04/16] unet sd1.5 test passed --- examples/unet_torch_interplay.py | 49 ++++++++++++++++++++++++++++---- 1 file changed, 44 insertions(+), 5 deletions(-) diff --git a/examples/unet_torch_interplay.py b/examples/unet_torch_interplay.py index 9a5af8477..c753eddfe 100644 --- a/examples/unet_torch_interplay.py +++ b/examples/unet_torch_interplay.py @@ -1,4 +1,5 @@ import os +import random os.environ["ONEFLOW_MLIR_CSE"] = "1" os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "1" @@ -45,13 +46,13 @@ def __exit__(self, exc_type, exc_val, exc_tb): flow.mock_torch.disable() -def get_unet(token, _model_id): +def get_unet(token, _model_id, revision): from diffusers import UNet2DConditionModel unet = UNet2DConditionModel.from_pretrained( _model_id, use_auth_token=token, - revision="fp16", + revision=revision, torch_dtype=flow.float16, subfolder="unet", ) @@ -97,6 +98,27 @@ def save_graph(self, file_path): state_dict = self.runtime_state_dict() flow.save(state_dict, file_path) +def get_deployable_unet(token, model_id, revision, ): + with MockCtx(): + unet = get_unet(token, model_id, revision) + unet_graph = UNetGraphWithCache(unet) + cross_attention_dim = unet.config["cross_attention_dim"] + warmup_meta_of_sizes = get_arg_meta_of_sizes( + batch_sizes=BATCH_SIZES, + resolution_scales=RESOLUTION_SCALES, + num_channels=num_channels, + cross_attention_dim=cross_attention_dim, + ) + for (i, m) in enumerate(warmup_meta_of_sizes): + print(f"warmup case #{i + 1}:", m) + if load == True: + print("loading graphs...") + unet_graph.warmup_with_load(file) + else: + print("warmup with arguments...") + unet_graph.warmup_with_arg(warmup_meta_of_sizes) + + def img_dim(i, start, stride): return start + stride * i @@ -131,6 +153,22 @@ def get_arg_meta_of_sizes( for j in resolution_scales ] +global_rng = random.Random() +def floats_tensor(shape, scale=1.0, rng=None, name=None): + import torch + """Creates a random float32 tensor""" + if rng is None: + rng = global_rng + + total_dims = 1 + for dim in shape: + total_dims *= dim + + values = [] + for _ in range(total_dims): + values.append(rng.random() * scale) + + return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous() @click.command() @click.option("--token") @@ -140,7 +178,8 @@ def get_arg_meta_of_sizes( @click.option("--load", is_flag=True) @click.option("--file", type=str, default="./unet_graphs") @click.option("--model_id", type=str, default="runwayml/stable-diffusion-v1-5") -def benchmark(token, repeat, sync_interval, save, load, file, model_id): +@click.option("--revision", type=str, default="fp16") +def benchmark(token, repeat, sync_interval, save, load, file, model_id, revision): RESOLUTION_SCALES = [2, 1, 0] BATCH_SIZES = [2] # TODO: reproduce bug caused by changing batch @@ -149,7 +188,7 @@ def benchmark(token, repeat, sync_interval, save, load, file, model_id): num_channels = 4 # create a mocked unet graph with MockCtx(): - unet = get_unet(token, model_id) + unet = get_unet(token, model_id, revision) unet_graph = UNetGraphWithCache(unet) cross_attention_dim = unet.config["cross_attention_dim"] warmup_meta_of_sizes = get_arg_meta_of_sizes( @@ -168,7 +207,7 @@ def benchmark(token, repeat, sync_interval, save, load, file, model_id): unet_graph.warmup_with_arg(warmup_meta_of_sizes) # generate inputs with torch - from diffusers.utils import floats_tensor + #from diffusers.utils import floats_tensor import torch time_step = torch.tensor([10]).to("cuda") From 0eaa2c9342fabbb1ab44cffa76fdb4ad53f574b1 Mon Sep 17 00:00:00 2001 From: strint Date: Wed, 23 Aug 2023 07:52:32 +0000 Subject: [PATCH 05/16] part2 --- examples/text_to_image_sdxl_fp16_graph.py | 60 +++++++++++++++++++++++ examples/unet_torch_interplay.py | 6 ++- 2 files changed, 64 insertions(+), 2 deletions(-) create mode 100644 examples/text_to_image_sdxl_fp16_graph.py diff --git a/examples/text_to_image_sdxl_fp16_graph.py b/examples/text_to_image_sdxl_fp16_graph.py new file mode 100644 index 000000000..195d09d0c --- /dev/null +++ b/examples/text_to_image_sdxl_fp16_graph.py @@ -0,0 +1,60 @@ +import os +import argparse +from diffusers import StableDiffusionXLPipeline +import torch +import oneflow as flow +from onediff.infer_compiler import torchbackend +from .unet_torch_interplay import get_unet, UNetGraphWithCache + +parser = argparse.ArgumentParser() +parser.add_argument( + "--model", type=str, default="/share_nfs/hf_models/stable-diffusion-xl-base-1.0" +) +parser.add_argument("--variant", type=str, default="fp16") +parser.add_argument( + "--prompt", + type=str, + default="street style, detailed, raw photo, woman, face, shot on CineStill 800T", +) +parser.add_argument("--saved_image", type=str, required=False, default="xl-base-out.png") +parser.add_argument("--seed", type=int, default=1) +parser.add_argument("--compile", action=argparse.BooleanOptionalAction) +parser.add_argument("--graph", action=argparse.BooleanOptionalAction) +args = parser.parse_args() + +if args.compile: + print("unet is compiled to oneflow.") + if args.graph: + print("unet is compiled to oneflow graph.") + +torch.manual_seed(args.seed) + +pipe = StableDiffusionXLPipeline.from_pretrained( + args.model, torch_dtype=torch.float16, variant=args.variant, use_safetensors=True +) + +if args.compile: + if args.graph: + os.environ["with_graph"] = "1" + os.environ["ONEFLOW_MLIR_CSE"] = "1" + os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "1" + os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" + os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1" + os.environ["ONEFLOW_MLIR_FUSE_OPS_WITH_BACKWARD_IMPL"] = "1" + # Open this will raise error + # os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "1" + os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1" + os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_CONV_BIAS"] = "1" + os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR"] = "1" + os.environ["ONEFLOW_KERNEL_CONV_CUTLASS_IMPL_ENABLE_TUNING_WARMUP"] = "1" + os.environ["ONEFLOW_KERNEL_CONV_ENABLE_CUTLASS_IMPL"] = "1" + os.environ["ONEFLOW_CONV_ALLOW_HALF_PRECISION_ACCUMULATION"] = "1" + os.environ["ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION"] = "1" + os.environ["ONEFLOW_LINEAR_EMBEDDING_SKIP_INIT"] = "1" + pipe.unet = torch.compile(pipe.unet, fullgraph=True, mode="reduce-overhead", backend=torchbackend) + +pipe.to("cuda") + +for i in range(3): + image = pipe(prompt=args.prompt).images[0] + image.save(f"{i}-{args.saved_image}") diff --git a/examples/unet_torch_interplay.py b/examples/unet_torch_interplay.py index c753eddfe..e09dde614 100644 --- a/examples/unet_torch_interplay.py +++ b/examples/unet_torch_interplay.py @@ -69,10 +69,10 @@ def __init__(self, unet): self.config.enable_cudnn_conv_heuristic_search_algo(False) self.config.allow_fuse_add_to_output(True) - def build(self, latent_model_input, t, text_embeddings): + def build(self, latent_model_input, t, text_embeddings, added_cond_kwargs=None): text_embeddings = flow._C.amp_white_identity(text_embeddings) return self.unet( - latent_model_input, t, encoder_hidden_states=text_embeddings + latent_model_input, t, encoder_hidden_states=text_embeddings, added_cond_kwargs=added_cond_kwargs, ).sample def warmup_with_arg(self, arg_meta_of_sizes): @@ -191,6 +191,7 @@ def benchmark(token, repeat, sync_interval, save, load, file, model_id, revision unet = get_unet(token, model_id, revision) unet_graph = UNetGraphWithCache(unet) cross_attention_dim = unet.config["cross_attention_dim"] + warmup_meta_of_sizes = get_arg_meta_of_sizes( batch_sizes=BATCH_SIZES, resolution_scales=RESOLUTION_SCALES, @@ -199,6 +200,7 @@ def benchmark(token, repeat, sync_interval, save, load, file, model_id, revision ) for (i, m) in enumerate(warmup_meta_of_sizes): print(f"warmup case #{i + 1}:", m) + if load == True: print("loading graphs...") unet_graph.warmup_with_load(file) From a1eaf26738c8f1fc9cb0bab11d06a171fe569c72 Mon Sep 17 00:00:00 2001 From: strint Date: Wed, 23 Aug 2023 09:42:06 +0000 Subject: [PATCH 06/16] sdxl unet graph test passed --- examples/text_to_image_sdxl_fp16.py | 2 +- examples/unet_torch_interplay.py | 25 +++++++++++++++++-------- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/examples/text_to_image_sdxl_fp16.py b/examples/text_to_image_sdxl_fp16.py index b0a700bf6..c3b99a6f1 100644 --- a/examples/text_to_image_sdxl_fp16.py +++ b/examples/text_to_image_sdxl_fp16.py @@ -56,5 +56,5 @@ pipe.to("cuda") for i in range(3): - image = pipe(prompt=args.prompt).images[0] + image = pipe(prompt=args.prompt, height=96, width=128, num_inference_steps=1).images[0] image.save(f"{i}-{args.saved_image}") diff --git a/examples/unet_torch_interplay.py b/examples/unet_torch_interplay.py index e09dde614..9c2677d50 100644 --- a/examples/unet_torch_interplay.py +++ b/examples/unet_torch_interplay.py @@ -6,7 +6,8 @@ os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1" os.environ["ONEFLOW_MLIR_FUSE_OPS_WITH_BACKWARD_IMPL"] = "1" -os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "1" +# TODO(): open ONEFLOW_MLIR_GROUP_MATMUL will raise in SDXL, need be fixed. +os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "0" os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1" os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_CONV_BIAS"] = "1" @@ -68,6 +69,7 @@ def __init__(self, unet): self.unet = unet self.config.enable_cudnn_conv_heuristic_search_algo(False) self.config.allow_fuse_add_to_output(True) + self.debug(0) def build(self, latent_model_input, t, text_embeddings, added_cond_kwargs=None): text_embeddings = flow._C.amp_white_identity(text_embeddings) @@ -75,7 +77,7 @@ def build(self, latent_model_input, t, text_embeddings, added_cond_kwargs=None): latent_model_input, t, encoder_hidden_states=text_embeddings, added_cond_kwargs=added_cond_kwargs, ).sample - def warmup_with_arg(self, arg_meta_of_sizes): + def warmup_with_arg(self, arg_meta_of_sizes, added): for arg_metas in arg_meta_of_sizes: print(f"warmup {arg_metas=}") arg_tensors = [ @@ -88,7 +90,7 @@ def warmup_with_arg(self, arg_meta_of_sizes): dtype=arg_metas.gettype("cross_attention_dim"), ).to("cuda"), ] - self(*arg_tensors) # build and warmup + self(*arg_tensors, added) # build and warmup def warmup_with_load(self, file_path): state_dict = flow.load(file_path) @@ -186,6 +188,15 @@ def benchmark(token, repeat, sync_interval, save, load, file, model_id, revision # BATCH_SIZES = [4, 2] num_channels = 4 + import torch + if model_id == "stabilityai/stable-diffusion-xl-base-1.0": + # sdxl needed + add_text_embeds = flow.utils.tensor.from_torch(floats_tensor((2, 1280)).to("cuda").to(torch.float16)) + add_time_ids = flow.utils.tensor.from_torch(floats_tensor((2, 6)).to("cuda").to(torch.float16)) + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + else: + added_cond_kwargs = None + # create a mocked unet graph with MockCtx(): unet = get_unet(token, model_id, revision) @@ -206,12 +217,9 @@ def benchmark(token, repeat, sync_interval, save, load, file, model_id, revision unet_graph.warmup_with_load(file) else: print("warmup with arguments...") - unet_graph.warmup_with_arg(warmup_meta_of_sizes) + unet_graph.warmup_with_arg(warmup_meta_of_sizes, added_cond_kwargs) # generate inputs with torch - #from diffusers.utils import floats_tensor - import torch - time_step = torch.tensor([10]).to("cuda") encoder_hidden_states_of_sizes = { batch_size: floats_tensor((batch_size, 77, cross_attention_dim)) @@ -228,6 +236,7 @@ def benchmark(token, repeat, sync_interval, save, load, file, model_id, revision k: flow.utils.tensor.from_torch(v) for k, v in encoder_hidden_states_of_sizes.items() } + # convert to oneflow tensors time_step = flow.utils.tensor.from_torch(time_step) @@ -240,7 +249,7 @@ def benchmark(token, repeat, sync_interval, save, load, file, model_id, revision noise = random.choice(noise_of_sizes) encoder_hidden_states = encoder_hidden_states_of_sizes[noise.shape[0]] - out = unet_graph(noise, time_step, encoder_hidden_states) + out = unet_graph(noise, time_step, encoder_hidden_states, added_cond_kwargs) # convert to torch tensors out = flow.utils.tensor.to_torch(out) if r == repeat - 1 or r % sync_interval == 0: From da4f47403bc5c0077fcc11fe30021ef0dd97550c Mon Sep 17 00:00:00 2001 From: strint Date: Wed, 23 Aug 2023 16:41:25 +0000 Subject: [PATCH 07/16] oneflow unet replace test passed --- examples/text_to_image_sdxl_fp16_graph.py | 50 +++++++++++++-- .../infer_compiler/obj_1f_from_torch.py | 37 +++++++++-- src/onediff/infer_compiler/with_fx_graph.py | 64 +++++++++++++++++++ 3 files changed, 142 insertions(+), 9 deletions(-) diff --git a/examples/text_to_image_sdxl_fp16_graph.py b/examples/text_to_image_sdxl_fp16_graph.py index 195d09d0c..fa4e34a87 100644 --- a/examples/text_to_image_sdxl_fp16_graph.py +++ b/examples/text_to_image_sdxl_fp16_graph.py @@ -1,10 +1,16 @@ import os import argparse +# cv2 must be imported before diffusers and oneflow to avlid error: AttributeError: module 'cv2.gapi' has no attribute 'wip' +# Maybe bacause oneflow use a lower version of cv2 +import cv2 +import oneflow as flow +# obj_1f_from_torch should be import before import any diffusers +from onediff.infer_compiler import obj_1f_from_torch + from diffusers import StableDiffusionXLPipeline import torch -import oneflow as flow -from onediff.infer_compiler import torchbackend -from .unet_torch_interplay import get_unet, UNetGraphWithCache +#from onediff.infer_compiler import torchbackend +from onediff.infer_compiler.with_fx_graph import _get_of_module parser = argparse.ArgumentParser() parser.add_argument( @@ -32,6 +38,7 @@ pipe = StableDiffusionXLPipeline.from_pretrained( args.model, torch_dtype=torch.float16, variant=args.variant, use_safetensors=True ) +pipe.to("cuda") if args.compile: if args.graph: @@ -51,9 +58,42 @@ os.environ["ONEFLOW_CONV_ALLOW_HALF_PRECISION_ACCUMULATION"] = "1" os.environ["ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION"] = "1" os.environ["ONEFLOW_LINEAR_EMBEDDING_SKIP_INIT"] = "1" - pipe.unet = torch.compile(pipe.unet, fullgraph=True, mode="reduce-overhead", backend=torchbackend) + torch2flow = {} -pipe.to("cuda") + def get_deployable(of_md): + from oneflow.framework.args_tree import ArgsTree + def input_fn(value): + if isinstance(value, torch.Tensor): + return flow.utils.tensor.from_torch(value) + else: + return value + + def output_fn(value): + if isinstance(value, flow.Tensor): + return flow.utils.tensor.to_torch(value) + else: + return value + + class DeplayableModule(of_md.__class__): + def __call__(self, *args, **kwargs): + args_tree = ArgsTree((args, kwargs), False, tensor_type=torch.Tensor) + out = args_tree.map_leaf(input_fn) + mapped_args = out[0] + mapped_kwargs = out[1] + + output = super().__call__(*mapped_args, **mapped_kwargs) + + out_tree = ArgsTree((output, None), False) + out = out_tree.map_leaf(output_fn) + return out[0] + + of_md.__class__ = DeplayableModule + return of_md + + unet = _get_of_module(pipe.unet, torch2flow) + d_unet = get_deployable(unet) + print(type(unet)) + pipe.unet = d_unet for i in range(3): image = pipe(prompt=args.prompt).images[0] diff --git a/src/onediff/infer_compiler/obj_1f_from_torch.py b/src/onediff/infer_compiler/obj_1f_from_torch.py index 012225d45..1bca1274c 100644 --- a/src/onediff/infer_compiler/obj_1f_from_torch.py +++ b/src/onediff/infer_compiler/obj_1f_from_torch.py @@ -2,6 +2,31 @@ import os import torch import oneflow as flow + +__of_mds = {} +__convert_list = [ + "diffusers.models.unet_2d_condition.UNet2DConditionModel", + "diffusers.models.embeddings.TimestepEmbedding", + "diffusers.models.embeddings.Timesteps", + "diffusers.models.resnet.ResnetBlock2D", + "diffusers.models.resnet.Downsample2D", + "diffusers.models.resnet.Upsample2D", + "diffusers.models.transformer_2d.Transformer2DModel", + "diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D", + "diffusers.models.unet_2d_blocks.DownBlock2D", + "diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D", + "diffusers.models.unet_2d_blocks.UpBlock2D", + "diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D", + "diffusers.models.unet_2d_blocks.UNetMidBlock2DCrossAttn", +] + +with flow.mock_torch.enable(lazy=False): + for md_name in __convert_list: + p, m = md_name.rsplit('.', 1) + md = importlib.import_module(p) + __of_mds[md_name] = getattr(md, m) + print(f"import {md_name}") + import diffusers from typing import Any, Dict, Iterator, List, Optional, Tuple, Union from .attention_1f import BasicTransformerBlock, FeedForward, GEGLU @@ -32,6 +57,10 @@ def replace_class(cls): return LoRACompatibleLinear elif cls == diffusers.models.lora.LoRACompatibleConv: return LoRACompatibleConv + + full_cls_name = str(cls.__module__) + '.' + str(cls.__name__) + if full_cls_name in __of_mds: + return __of_mds[full_cls_name] if _is_diffusers_quant_available: if cls == diffusers_quant.FakeQuantModule: @@ -152,10 +181,10 @@ def __getattribute__(self, attribute): self._1f_proxy_children[attribute] = a else: a = self._1f_proxy_children[attribute] - assert ( - type(a).__module__.startswith("torch") == False - and type(a).__module__.startswith("diffusers") == False - ), "can't be a torch module at this point! But found " + str(type(a)) + # assert ( + # type(a).__module__.startswith("torch") == False + # and type(a).__module__.startswith("diffusers") == False + # ), "can't be a torch module at this point! But found " + str(type(a)) return a def __call__(self, *args: Any, **kwargs: Any) -> Any: diff --git a/src/onediff/infer_compiler/with_fx_graph.py b/src/onediff/infer_compiler/with_fx_graph.py index d73eb9be3..8fab7f7fc 100644 --- a/src/onediff/infer_compiler/with_fx_graph.py +++ b/src/onediff/infer_compiler/with_fx_graph.py @@ -147,3 +147,67 @@ def _get_attr(gm, node, torch2flow): of_attr = replace_obj(attr) torch2flow[attr] = of_attr return of_attr + +def _get_of_module_list(origin_mod, torch2flow): + assert isinstance(origin_mod, torch.nn.ModuleList) + if origin_mod in torch2flow: + return torch2flow[origin_mod] + of_md_list = flow.nn.ModuleList() + for m in origin_mod: + of_md_list.append(_get_of_module(m, torch2flow)) + torch2flow[origin_mod] = of_md_list + return of_md_list + +def _get_of_module(origin_mod, torch2flow): + if origin_mod in torch2flow: + return torch2flow[origin_mod] + + if isinstance(origin_mod, torch.nn.ModuleList): + return _get_of_module_list(origin_mod, torch2flow) + + proxy_md = ProxySubmodule(origin_mod) + new_md_cls = replace_class(type(origin_mod)) + if new_md_cls: + print("succeed") + import inspect + print(inspect.getmro(new_md_cls)) + else: + print("failed") + import inspect + print(inspect.getmro(type(origin_mod))) + import pdb; pdb.set_trace() + new_md_cls = flow.nn.Module + + def init(self): + self._parameters = OrderedDict() + self._buffers = OrderedDict() + self._modules = OrderedDict() + for (n, p) in list(proxy_md.named_parameters("", False)): + self._parameters[n] = flow.utils.tensor.from_torch(p.data) + for (n, b) in list(proxy_md.named_buffers("", False)): + self._buffers[n] = flow.utils.tensor.from_torch(b.data) + for (n, m) in proxy_md._modules.items(): + self._modules[n] = _get_of_module(m, torch2flow) + + for k, v in proxy_md.__dict__.items(): + if k not in self.__dict__: + try: + attr = getattr(proxy_md, k) + except: + continue + self.__dict__[k] = attr + + def proxy_getattr(self, attr): + if attr in ["_parameters", "_buffers", "_modules"]: + raise ValueError(f"missing attr {attr} in base class") + else: + return getattr(proxy_md, attr) + + of_md_cls = type( + str(new_md_cls), (new_md_cls,), {"__init__": init, "__getattr__": proxy_getattr}, + ) + + new_md = of_md_cls() + + torch2flow[origin_mod] = new_md + return new_md From a3adab92b0188ecdb4cb8438e5590d50ce943a55 Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 24 Aug 2023 10:06:12 +0000 Subject: [PATCH 08/16] add unet graph to sdxl test passed --- examples/text_to_image_sdxl_fp16.py | 2 +- examples/text_to_image_sdxl_fp16_graph.py | 48 ++++++---------- .../infer_compiler/obj_1f_from_torch.py | 1 - src/onediff/infer_compiler/with_fx_graph.py | 56 +++++++++++++++---- 4 files changed, 65 insertions(+), 42 deletions(-) diff --git a/examples/text_to_image_sdxl_fp16.py b/examples/text_to_image_sdxl_fp16.py index c3b99a6f1..fba1f6623 100644 --- a/examples/text_to_image_sdxl_fp16.py +++ b/examples/text_to_image_sdxl_fp16.py @@ -56,5 +56,5 @@ pipe.to("cuda") for i in range(3): - image = pipe(prompt=args.prompt, height=96, width=128, num_inference_steps=1).images[0] + image = pipe(prompt=args.prompt, height=96, width=128, num_inference_steps=50).images[0] image.save(f"{i}-{args.saved_image}") diff --git a/examples/text_to_image_sdxl_fp16_graph.py b/examples/text_to_image_sdxl_fp16_graph.py index fa4e34a87..2ffa4e9d2 100644 --- a/examples/text_to_image_sdxl_fp16_graph.py +++ b/examples/text_to_image_sdxl_fp16_graph.py @@ -10,7 +10,7 @@ from diffusers import StableDiffusionXLPipeline import torch #from onediff.infer_compiler import torchbackend -from onediff.infer_compiler.with_fx_graph import _get_of_module +from onediff.infer_compiler.with_fx_graph import _get_of_module, UNetGraph parser = argparse.ArgumentParser() parser.add_argument( @@ -26,40 +26,21 @@ parser.add_argument("--seed", type=int, default=1) parser.add_argument("--compile", action=argparse.BooleanOptionalAction) parser.add_argument("--graph", action=argparse.BooleanOptionalAction) -args = parser.parse_args() +cmd_args = parser.parse_args() -if args.compile: +if cmd_args.compile: print("unet is compiled to oneflow.") - if args.graph: + if cmd_args.graph: print("unet is compiled to oneflow graph.") -torch.manual_seed(args.seed) +torch.manual_seed(cmd_args.seed) pipe = StableDiffusionXLPipeline.from_pretrained( - args.model, torch_dtype=torch.float16, variant=args.variant, use_safetensors=True + cmd_args.model, torch_dtype=torch.float16, variant=cmd_args.variant, use_safetensors=True ) pipe.to("cuda") -if args.compile: - if args.graph: - os.environ["with_graph"] = "1" - os.environ["ONEFLOW_MLIR_CSE"] = "1" - os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "1" - os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" - os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1" - os.environ["ONEFLOW_MLIR_FUSE_OPS_WITH_BACKWARD_IMPL"] = "1" - # Open this will raise error - # os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "1" - os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1" - os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_CONV_BIAS"] = "1" - os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR"] = "1" - os.environ["ONEFLOW_KERNEL_CONV_CUTLASS_IMPL_ENABLE_TUNING_WARMUP"] = "1" - os.environ["ONEFLOW_KERNEL_CONV_ENABLE_CUTLASS_IMPL"] = "1" - os.environ["ONEFLOW_CONV_ALLOW_HALF_PRECISION_ACCUMULATION"] = "1" - os.environ["ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION"] = "1" - os.environ["ONEFLOW_LINEAR_EMBEDDING_SKIP_INIT"] = "1" - torch2flow = {} - +if cmd_args.compile: def get_deployable(of_md): from oneflow.framework.args_tree import ArgsTree def input_fn(value): @@ -74,6 +55,9 @@ def output_fn(value): else: return value + if cmd_args.graph: + unet_graph = UNetGraph(of_md) + class DeplayableModule(of_md.__class__): def __call__(self, *args, **kwargs): args_tree = ArgsTree((args, kwargs), False, tensor_type=torch.Tensor) @@ -81,7 +65,11 @@ def __call__(self, *args, **kwargs): mapped_args = out[0] mapped_kwargs = out[1] - output = super().__call__(*mapped_args, **mapped_kwargs) + if cmd_args.graph: + output = unet_graph(*mapped_args, **mapped_kwargs) + else: + output = super().__call__(*mapped_args, **mapped_kwargs) + out_tree = ArgsTree((output, None), False) out = out_tree.map_leaf(output_fn) @@ -90,11 +78,11 @@ def __call__(self, *args, **kwargs): of_md.__class__ = DeplayableModule return of_md + torch2flow = {} unet = _get_of_module(pipe.unet, torch2flow) d_unet = get_deployable(unet) - print(type(unet)) pipe.unet = d_unet for i in range(3): - image = pipe(prompt=args.prompt).images[0] - image.save(f"{i}-{args.saved_image}") + image = pipe(prompt=cmd_args.prompt, height=96, width=128, num_inference_steps=50).images[0] + image.save(f"{i}-{cmd_args.saved_image}") diff --git a/src/onediff/infer_compiler/obj_1f_from_torch.py b/src/onediff/infer_compiler/obj_1f_from_torch.py index 1bca1274c..05a8b7c83 100644 --- a/src/onediff/infer_compiler/obj_1f_from_torch.py +++ b/src/onediff/infer_compiler/obj_1f_from_torch.py @@ -25,7 +25,6 @@ p, m = md_name.rsplit('.', 1) md = importlib.import_module(p) __of_mds[md_name] = getattr(md, m) - print(f"import {md_name}") import diffusers from typing import Any, Dict, Iterator, List, Optional, Tuple, Union diff --git a/src/onediff/infer_compiler/with_fx_graph.py b/src/onediff/infer_compiler/with_fx_graph.py index 8fab7f7fc..4f512b2c9 100644 --- a/src/onediff/infer_compiler/with_fx_graph.py +++ b/src/onediff/infer_compiler/with_fx_graph.py @@ -167,16 +167,6 @@ def _get_of_module(origin_mod, torch2flow): proxy_md = ProxySubmodule(origin_mod) new_md_cls = replace_class(type(origin_mod)) - if new_md_cls: - print("succeed") - import inspect - print(inspect.getmro(new_md_cls)) - else: - print("failed") - import inspect - print(inspect.getmro(type(origin_mod))) - import pdb; pdb.set_trace() - new_md_cls = flow.nn.Module def init(self): self._parameters = OrderedDict() @@ -211,3 +201,49 @@ def proxy_getattr(self, attr): torch2flow[origin_mod] = new_md return new_md + + +class UNetGraph(flow.nn.Graph): + @flow.nn.Graph.with_dynamic_input_shape(size=9) + def __init__(self, unet): + super().__init__(enable_get_runtime_state_dict=True) + self.unet = unet + self.config.enable_cudnn_conv_heuristic_search_algo(False) + self.config.allow_fuse_add_to_output(True) + self.debug(0) + + os.environ["ONEFLOW_MLIR_CSE"] = "1" + os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "1" + os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" + os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1" + os.environ["ONEFLOW_MLIR_FUSE_OPS_WITH_BACKWARD_IMPL"] = "1" + # Open this will raise error + # os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "1" + os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1" + os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_CONV_BIAS"] = "1" + os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR"] = "1" + os.environ["ONEFLOW_KERNEL_CONV_CUTLASS_IMPL_ENABLE_TUNING_WARMUP"] = "1" + os.environ["ONEFLOW_KERNEL_CONV_ENABLE_CUTLASS_IMPL"] = "1" + os.environ["ONEFLOW_CONV_ALLOW_HALF_PRECISION_ACCUMULATION"] = "1" + os.environ["ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION"] = "1" + os.environ["ONEFLOW_LINEAR_EMBEDDING_SKIP_INIT"] = "1" + + def build(self, latent_model_input, t, encoder_hidden_states, cross_attention_kwargs, added_cond_kwargs=None, return_dict=False): + encoder_hidden_states = flow._C.amp_white_identity(encoder_hidden_states) + pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=return_dict, + ) + return pred + + def warmup_with_load(self, file_path): + state_dict = flow.load(file_path) + self.load_runtime_state_dict(state_dict) + + def save_graph(self, file_path): + state_dict = self.runtime_state_dict() + flow.save(state_dict, file_path) From 97d69cf51f1aa47fa062819a02f2cdace5acffb4 Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 24 Aug 2023 10:30:01 +0000 Subject: [PATCH 09/16] sdxl unet graph multi input shape --- examples/text_to_image_sdxl_fp16_graph.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/examples/text_to_image_sdxl_fp16_graph.py b/examples/text_to_image_sdxl_fp16_graph.py index 2ffa4e9d2..efc6323c9 100644 --- a/examples/text_to_image_sdxl_fp16_graph.py +++ b/examples/text_to_image_sdxl_fp16_graph.py @@ -83,6 +83,9 @@ def __call__(self, *args, **kwargs): d_unet = get_deployable(unet) pipe.unet = d_unet -for i in range(3): - image = pipe(prompt=cmd_args.prompt, height=96, width=128, num_inference_steps=50).images[0] - image.save(f"{i}-{cmd_args.saved_image}") +sizes = [1024, 896, 768] +for h in sizes: + for w in sizes: + for i in range(2): + image = pipe(prompt=cmd_args.prompt, height=h, width=w, num_inference_steps=50).images[0] + image.save(f"h{h}-w{w}-i{i}-{cmd_args.saved_image}") From 570d383a1b83fba53249e492766bbbf03590bf2f Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 24 Aug 2023 11:38:20 +0000 Subject: [PATCH 10/16] sdxl unet graph save load multi input test passed --- examples/text_to_image_sdxl_fp16_graph.py | 101 +++++++++++++--------- 1 file changed, 61 insertions(+), 40 deletions(-) diff --git a/examples/text_to_image_sdxl_fp16_graph.py b/examples/text_to_image_sdxl_fp16_graph.py index efc6323c9..cfbd191ea 100644 --- a/examples/text_to_image_sdxl_fp16_graph.py +++ b/examples/text_to_image_sdxl_fp16_graph.py @@ -26,66 +26,87 @@ parser.add_argument("--seed", type=int, default=1) parser.add_argument("--compile", action=argparse.BooleanOptionalAction) parser.add_argument("--graph", action=argparse.BooleanOptionalAction) +parser.add_argument("--file", type=str, required=False, default="deployable_unet") +parser.add_argument("--save", action=argparse.BooleanOptionalAction) +parser.add_argument("--load", action=argparse.BooleanOptionalAction) cmd_args = parser.parse_args() +# For compile with oneflow if cmd_args.compile: print("unet is compiled to oneflow.") if cmd_args.graph: print("unet is compiled to oneflow graph.") + +def get_deployable(torch_md): + torch2flow = {} + of_md = _get_of_module(torch_md, torch2flow) + from oneflow.framework.args_tree import ArgsTree + def input_fn(value): + if isinstance(value, torch.Tensor): + return flow.utils.tensor.from_torch(value) + else: + return value -torch.manual_seed(cmd_args.seed) + def output_fn(value): + if isinstance(value, flow.Tensor): + return flow.utils.tensor.to_torch(value) + else: + return value -pipe = StableDiffusionXLPipeline.from_pretrained( - cmd_args.model, torch_dtype=torch.float16, variant=cmd_args.variant, use_safetensors=True -) -pipe.to("cuda") + if cmd_args.graph: + dpl_graph = UNetGraph(of_md) -if cmd_args.compile: - def get_deployable(of_md): - from oneflow.framework.args_tree import ArgsTree - def input_fn(value): - if isinstance(value, torch.Tensor): - return flow.utils.tensor.from_torch(value) - else: - return value + class DeplayableModule(of_md.__class__): + def __call__(self, *args, **kwargs): + args_tree = ArgsTree((args, kwargs), False, tensor_type=torch.Tensor) + out = args_tree.map_leaf(input_fn) + mapped_args = out[0] + mapped_kwargs = out[1] - def output_fn(value): - if isinstance(value, flow.Tensor): - return flow.utils.tensor.to_torch(value) + if cmd_args.graph: + output = self._dpl_graph(*mapped_args, **mapped_kwargs) else: - return value - - if cmd_args.graph: - unet_graph = UNetGraph(of_md) + output = super().__call__(*mapped_args, **mapped_kwargs) - class DeplayableModule(of_md.__class__): - def __call__(self, *args, **kwargs): - args_tree = ArgsTree((args, kwargs), False, tensor_type=torch.Tensor) - out = args_tree.map_leaf(input_fn) - mapped_args = out[0] - mapped_kwargs = out[1] - if cmd_args.graph: - output = unet_graph(*mapped_args, **mapped_kwargs) - else: - output = super().__call__(*mapped_args, **mapped_kwargs) + out_tree = ArgsTree((output, None), False) + out = out_tree.map_leaf(output_fn) + return out[0] + + def _dpl_load(self, file_path): + self._dpl_graph.warmup_with_load(file_path) + + def _dpl_save(self, file_path): + self._dpl_graph.save_graph(file_path) + of_md.__class__ = DeplayableModule + if cmd_args.graph: + of_md._dpl_graph = dpl_graph + if cmd_args.load: + print("loading deployable graphs...") + of_md._dpl_load(cmd_args.file) + return of_md - out_tree = ArgsTree((output, None), False) - out = out_tree.map_leaf(output_fn) - return out[0] - - of_md.__class__ = DeplayableModule - return of_md +# Normal SDXL +torch.manual_seed(cmd_args.seed) +pipe = StableDiffusionXLPipeline.from_pretrained( + cmd_args.model, torch_dtype=torch.float16, variant=cmd_args.variant, use_safetensors=True +) +pipe.to("cuda") - torch2flow = {} - unet = _get_of_module(pipe.unet, torch2flow) - d_unet = get_deployable(unet) - pipe.unet = d_unet +# Compile unet with oneflow +if cmd_args.compile: + pipe.unet = get_deployable(pipe.unet) +# Normal SDXL call sizes = [1024, 896, 768] for h in sizes: for w in sizes: for i in range(2): image = pipe(prompt=cmd_args.prompt, height=h, width=w, num_inference_steps=50).images[0] image.save(f"h{h}-w{w}-i{i}-{cmd_args.saved_image}") + +# Save compiled unet with oneflow +if cmd_args.save: + print("saving deployable graphs...") + pipe.unet._dpl_save(cmd_args.file) \ No newline at end of file From 39263e59d54c3913da908d0f713c2d980eb15d6f Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 25 Aug 2023 03:22:11 +0000 Subject: [PATCH 11/16] refine oneflow_compile api --- examples/text_to_image_sdxl_fp16_graph.py | 85 ++------- src/onediff/infer_compiler/__init__.py | 1 + .../infer_compiler/obj_1f_from_torch.py | 68 +++++++- src/onediff/infer_compiler/with_fx_graph.py | 162 ------------------ src/onediff/infer_compiler/with_onef_graph.py | 97 +++++++++++ 5 files changed, 177 insertions(+), 236 deletions(-) create mode 100644 src/onediff/infer_compiler/with_onef_graph.py diff --git a/examples/text_to_image_sdxl_fp16_graph.py b/examples/text_to_image_sdxl_fp16_graph.py index cfbd191ea..5c2dc49c6 100644 --- a/examples/text_to_image_sdxl_fp16_graph.py +++ b/examples/text_to_image_sdxl_fp16_graph.py @@ -4,13 +4,10 @@ # Maybe bacause oneflow use a lower version of cv2 import cv2 import oneflow as flow -# obj_1f_from_torch should be import before import any diffusers -from onediff.infer_compiler import obj_1f_from_torch - -from diffusers import StableDiffusionXLPipeline import torch -#from onediff.infer_compiler import torchbackend -from onediff.infer_compiler.with_fx_graph import _get_of_module, UNetGraph +# oneflow_compile should be imported before importing any diffusers +from onediff.infer_compiler import oneflow_compile +from diffusers import StableDiffusionXLPipeline parser = argparse.ArgumentParser() parser.add_argument( @@ -25,68 +22,11 @@ parser.add_argument("--saved_image", type=str, required=False, default="xl-base-out.png") parser.add_argument("--seed", type=int, default=1) parser.add_argument("--compile", action=argparse.BooleanOptionalAction) -parser.add_argument("--graph", action=argparse.BooleanOptionalAction) -parser.add_argument("--file", type=str, required=False, default="deployable_unet") parser.add_argument("--save", action=argparse.BooleanOptionalAction) parser.add_argument("--load", action=argparse.BooleanOptionalAction) +parser.add_argument("--file", type=str, required=False, default="unet_graph") cmd_args = parser.parse_args() -# For compile with oneflow -if cmd_args.compile: - print("unet is compiled to oneflow.") - if cmd_args.graph: - print("unet is compiled to oneflow graph.") - -def get_deployable(torch_md): - torch2flow = {} - of_md = _get_of_module(torch_md, torch2flow) - from oneflow.framework.args_tree import ArgsTree - def input_fn(value): - if isinstance(value, torch.Tensor): - return flow.utils.tensor.from_torch(value) - else: - return value - - def output_fn(value): - if isinstance(value, flow.Tensor): - return flow.utils.tensor.to_torch(value) - else: - return value - - if cmd_args.graph: - dpl_graph = UNetGraph(of_md) - - class DeplayableModule(of_md.__class__): - def __call__(self, *args, **kwargs): - args_tree = ArgsTree((args, kwargs), False, tensor_type=torch.Tensor) - out = args_tree.map_leaf(input_fn) - mapped_args = out[0] - mapped_kwargs = out[1] - - if cmd_args.graph: - output = self._dpl_graph(*mapped_args, **mapped_kwargs) - else: - output = super().__call__(*mapped_args, **mapped_kwargs) - - - out_tree = ArgsTree((output, None), False) - out = out_tree.map_leaf(output_fn) - return out[0] - - def _dpl_load(self, file_path): - self._dpl_graph.warmup_with_load(file_path) - - def _dpl_save(self, file_path): - self._dpl_graph.save_graph(file_path) - - of_md.__class__ = DeplayableModule - if cmd_args.graph: - of_md._dpl_graph = dpl_graph - if cmd_args.load: - print("loading deployable graphs...") - of_md._dpl_load(cmd_args.file) - return of_md - # Normal SDXL torch.manual_seed(cmd_args.seed) pipe = StableDiffusionXLPipeline.from_pretrained( @@ -96,17 +36,22 @@ def _dpl_save(self, file_path): # Compile unet with oneflow if cmd_args.compile: - pipe.unet = get_deployable(pipe.unet) + pipe.unet = oneflow_compile(pipe.unet) + print("unet is compiled to oneflow.") + if cmd_args.load: + # Load compiled unet with oneflow + print("loading graphs...") + pipe.unet._graph_load(cmd_args.file) # Normal SDXL call sizes = [1024, 896, 768] for h in sizes: for w in sizes: - for i in range(2): - image = pipe(prompt=cmd_args.prompt, height=h, width=w, num_inference_steps=50).images[0] + for i in range(1): + image = pipe(prompt=cmd_args.prompt, height=h, width=w, num_inference_steps=2).images[0] image.save(f"h{h}-w{w}-i{i}-{cmd_args.saved_image}") # Save compiled unet with oneflow -if cmd_args.save: - print("saving deployable graphs...") - pipe.unet._dpl_save(cmd_args.file) \ No newline at end of file +if cmd_args.compile and cmd_args.save: + print("saving graphs...") + pipe.unet._graph_save(cmd_args.file) \ No newline at end of file diff --git a/src/onediff/infer_compiler/__init__.py b/src/onediff/infer_compiler/__init__.py index bce15e29d..83d052f6e 100644 --- a/src/onediff/infer_compiler/__init__.py +++ b/src/onediff/infer_compiler/__init__.py @@ -1,6 +1,7 @@ import os import torch import oneflow as flow +from .with_onef_graph import oneflow_compile from .with_fx_interpreter import OneFlowInterpreter from .with_fx_graph import fx_node_tranform diff --git a/src/onediff/infer_compiler/obj_1f_from_torch.py b/src/onediff/infer_compiler/obj_1f_from_torch.py index 05a8b7c83..6aa703ae1 100644 --- a/src/onediff/infer_compiler/obj_1f_from_torch.py +++ b/src/onediff/infer_compiler/obj_1f_from_torch.py @@ -1,5 +1,6 @@ import importlib import os +from collections import OrderedDict import torch import oneflow as flow @@ -180,10 +181,6 @@ def __getattribute__(self, attribute): self._1f_proxy_children[attribute] = a else: a = self._1f_proxy_children[attribute] - # assert ( - # type(a).__module__.startswith("torch") == False - # and type(a).__module__.startswith("diffusers") == False - # ), "can't be a torch module at this point! But found " + str(type(a)) return a def __call__(self, *args: Any, **kwargs: Any) -> Any: @@ -194,3 +191,66 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: raise RuntimeError( "can't find oneflow module for: " + str(type(self._1f_proxy_submod)) ) + +def _get_module_list(origin_mod, torch2flow): + assert isinstance(origin_mod, torch.nn.ModuleList) + if origin_mod in torch2flow: + return torch2flow[origin_mod] + of_md_list = flow.nn.ModuleList() + for m in origin_mod: + of_md_list.append(_get_module(m, torch2flow)) + torch2flow[origin_mod] = of_md_list + return of_md_list + + +def _get_module(origin_mod, torch2flow): + if origin_mod in torch2flow: + return torch2flow[origin_mod] + + if isinstance(origin_mod, torch.nn.ModuleList): + return _get_module_list(origin_mod, torch2flow) + + proxy_md = ProxySubmodule(origin_mod) + new_md_cls = replace_class(type(origin_mod)) + + def init(self): + self._parameters = OrderedDict() + self._buffers = OrderedDict() + self._modules = OrderedDict() + for (n, p) in list(proxy_md.named_parameters("", False)): + self._parameters[n] = flow.utils.tensor.from_torch(p.data) + for (n, b) in list(proxy_md.named_buffers("", False)): + self._buffers[n] = flow.utils.tensor.from_torch(b.data) + for (n, m) in proxy_md._modules.items(): + self._modules[n] = _get_module(m, torch2flow) + + for k, v in proxy_md.__dict__.items(): + if k not in self.__dict__: + try: + attr = getattr(proxy_md, k) + except: + continue + self.__dict__[k] = attr + + def proxy_getattr(self, attr): + if attr in ["_parameters", "_buffers", "_modules"]: + raise ValueError(f"missing attr {attr} in base class") + else: + return getattr(proxy_md, attr) + + of_md_cls = type( + str(new_md_cls), (new_md_cls,), {"__init__": init, "__getattr__": proxy_getattr}, + ) + + new_md = of_md_cls() + + torch2flow[origin_mod] = new_md + return new_md + +def _get_attr(gm, node, torch2flow): + attr = getattr(gm, node.target) + if attr in torch2flow: + return torch2flow[attr] + of_attr = replace_obj(attr) + torch2flow[attr] = of_attr + return of_attr diff --git a/src/onediff/infer_compiler/with_fx_graph.py b/src/onediff/infer_compiler/with_fx_graph.py index 4f512b2c9..fde958d6c 100644 --- a/src/onediff/infer_compiler/with_fx_graph.py +++ b/src/onediff/infer_compiler/with_fx_graph.py @@ -3,7 +3,6 @@ import torch.fx as fx import oneflow as flow from torch.fx.node import map_aggregate -from collections import OrderedDict from typing import Any, Dict, Iterator, List, Optional, Tuple, Union from .obj_1f_from_torch import replace_obj, replace_func, replace_class, ProxySubmodule @@ -85,165 +84,4 @@ def node_replace_args(args, name2node): return map_aggregate(args, lambda node: replace_node(node, name2node)) -def _get_module_list(origin_mod, torch2flow): - assert isinstance(origin_mod, torch.nn.ModuleList) - if origin_mod in torch2flow: - return torch2flow[origin_mod] - of_md_list = flow.nn.ModuleList() - for m in origin_mod: - of_md_list.append(_get_module(m, torch2flow)) - torch2flow[origin_mod] = of_md_list - return of_md_list - -def _get_module(origin_mod, torch2flow): - if origin_mod in torch2flow: - return torch2flow[origin_mod] - - if isinstance(origin_mod, torch.nn.ModuleList): - return _get_module_list(origin_mod, torch2flow) - - proxy_md = ProxySubmodule(origin_mod) - new_md_cls = replace_class(type(origin_mod)) - - def init(self): - self._parameters = OrderedDict() - self._buffers = OrderedDict() - self._modules = OrderedDict() - for (n, p) in list(proxy_md.named_parameters("", False)): - self._parameters[n] = flow.utils.tensor.from_torch(p.data) - for (n, b) in list(proxy_md.named_buffers("", False)): - self._buffers[n] = flow.utils.tensor.from_torch(b.data) - for (n, m) in proxy_md._modules.items(): - self._modules[n] = _get_module(m, torch2flow) - - for k, v in proxy_md.__dict__.items(): - if k not in self.__dict__: - try: - attr = getattr(proxy_md, k) - except: - continue - self.__dict__[k] = attr - - def proxy_getattr(self, attr): - if attr in ["_parameters", "_buffers", "_modules"]: - raise ValueError(f"missing attr {attr} in base class") - else: - return getattr(proxy_md, attr) - - of_md_cls = type( - str(new_md_cls), (new_md_cls,), {"__init__": init, "__getattr__": proxy_getattr}, - ) - - new_md = of_md_cls() - - torch2flow[origin_mod] = new_md - return new_md - -def _get_attr(gm, node, torch2flow): - attr = getattr(gm, node.target) - if attr in torch2flow: - return torch2flow[attr] - of_attr = replace_obj(attr) - torch2flow[attr] = of_attr - return of_attr - -def _get_of_module_list(origin_mod, torch2flow): - assert isinstance(origin_mod, torch.nn.ModuleList) - if origin_mod in torch2flow: - return torch2flow[origin_mod] - of_md_list = flow.nn.ModuleList() - for m in origin_mod: - of_md_list.append(_get_of_module(m, torch2flow)) - torch2flow[origin_mod] = of_md_list - return of_md_list - -def _get_of_module(origin_mod, torch2flow): - if origin_mod in torch2flow: - return torch2flow[origin_mod] - - if isinstance(origin_mod, torch.nn.ModuleList): - return _get_of_module_list(origin_mod, torch2flow) - - proxy_md = ProxySubmodule(origin_mod) - new_md_cls = replace_class(type(origin_mod)) - - def init(self): - self._parameters = OrderedDict() - self._buffers = OrderedDict() - self._modules = OrderedDict() - for (n, p) in list(proxy_md.named_parameters("", False)): - self._parameters[n] = flow.utils.tensor.from_torch(p.data) - for (n, b) in list(proxy_md.named_buffers("", False)): - self._buffers[n] = flow.utils.tensor.from_torch(b.data) - for (n, m) in proxy_md._modules.items(): - self._modules[n] = _get_of_module(m, torch2flow) - - for k, v in proxy_md.__dict__.items(): - if k not in self.__dict__: - try: - attr = getattr(proxy_md, k) - except: - continue - self.__dict__[k] = attr - - def proxy_getattr(self, attr): - if attr in ["_parameters", "_buffers", "_modules"]: - raise ValueError(f"missing attr {attr} in base class") - else: - return getattr(proxy_md, attr) - - of_md_cls = type( - str(new_md_cls), (new_md_cls,), {"__init__": init, "__getattr__": proxy_getattr}, - ) - - new_md = of_md_cls() - - torch2flow[origin_mod] = new_md - return new_md - - -class UNetGraph(flow.nn.Graph): - @flow.nn.Graph.with_dynamic_input_shape(size=9) - def __init__(self, unet): - super().__init__(enable_get_runtime_state_dict=True) - self.unet = unet - self.config.enable_cudnn_conv_heuristic_search_algo(False) - self.config.allow_fuse_add_to_output(True) - self.debug(0) - - os.environ["ONEFLOW_MLIR_CSE"] = "1" - os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "1" - os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" - os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1" - os.environ["ONEFLOW_MLIR_FUSE_OPS_WITH_BACKWARD_IMPL"] = "1" - # Open this will raise error - # os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "1" - os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1" - os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_CONV_BIAS"] = "1" - os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR"] = "1" - os.environ["ONEFLOW_KERNEL_CONV_CUTLASS_IMPL_ENABLE_TUNING_WARMUP"] = "1" - os.environ["ONEFLOW_KERNEL_CONV_ENABLE_CUTLASS_IMPL"] = "1" - os.environ["ONEFLOW_CONV_ALLOW_HALF_PRECISION_ACCUMULATION"] = "1" - os.environ["ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION"] = "1" - os.environ["ONEFLOW_LINEAR_EMBEDDING_SKIP_INIT"] = "1" - - def build(self, latent_model_input, t, encoder_hidden_states, cross_attention_kwargs, added_cond_kwargs=None, return_dict=False): - encoder_hidden_states = flow._C.amp_white_identity(encoder_hidden_states) - pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - added_cond_kwargs=added_cond_kwargs, - return_dict=return_dict, - ) - return pred - - def warmup_with_load(self, file_path): - state_dict = flow.load(file_path) - self.load_runtime_state_dict(state_dict) - - def save_graph(self, file_path): - state_dict = self.runtime_state_dict() - flow.save(state_dict, file_path) diff --git a/src/onediff/infer_compiler/with_onef_graph.py b/src/onediff/infer_compiler/with_onef_graph.py new file mode 100644 index 000000000..ded3fc301 --- /dev/null +++ b/src/onediff/infer_compiler/with_onef_graph.py @@ -0,0 +1,97 @@ +from . import obj_1f_from_torch +from .obj_1f_from_torch import _get_module +import os +import oneflow as flow +import torch + +class UNetGraph(flow.nn.Graph): + @flow.nn.Graph.with_dynamic_input_shape(size=9) + def __init__(self, unet): + super().__init__(enable_get_runtime_state_dict=True) + self.unet = unet + self.config.enable_cudnn_conv_heuristic_search_algo(False) + self.config.allow_fuse_add_to_output(True) + self.debug(0) + + os.environ["ONEFLOW_MLIR_CSE"] = "1" + os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "1" + os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" + os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1" + os.environ["ONEFLOW_MLIR_FUSE_OPS_WITH_BACKWARD_IMPL"] = "1" + # Open this will raise error + # os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "1" + os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1" + os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_CONV_BIAS"] = "1" + os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR"] = "1" + os.environ["ONEFLOW_KERNEL_CONV_CUTLASS_IMPL_ENABLE_TUNING_WARMUP"] = "1" + os.environ["ONEFLOW_KERNEL_CONV_ENABLE_CUTLASS_IMPL"] = "1" + os.environ["ONEFLOW_CONV_ALLOW_HALF_PRECISION_ACCUMULATION"] = "1" + os.environ["ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION"] = "1" + os.environ["ONEFLOW_LINEAR_EMBEDDING_SKIP_INIT"] = "1" + + def build(self, latent_model_input, t, encoder_hidden_states, cross_attention_kwargs, added_cond_kwargs=None, return_dict=False): + encoder_hidden_states = flow._C.amp_white_identity(encoder_hidden_states) + pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=return_dict, + ) + return pred + + def warmup_with_load(self, file_path): + state_dict = flow.load(file_path) + self.load_runtime_state_dict(state_dict) + + def save_graph(self, file_path): + state_dict = self.runtime_state_dict() + flow.save(state_dict, file_path) + +def oneflow_compile(torch_unet, use_graph=True): + torch2flow = {} + of_md = _get_module(torch_unet, torch2flow) + from oneflow.framework.args_tree import ArgsTree + def input_fn(value): + if isinstance(value, torch.Tensor): + return flow.utils.tensor.from_torch(value) + else: + return value + + def output_fn(value): + if isinstance(value, flow.Tensor): + return flow.utils.tensor.to_torch(value) + else: + return value + + if use_graph: + dpl_graph = UNetGraph(of_md) + + class DeplayableModule(of_md.__class__): + def __call__(self, *args, **kwargs): + args_tree = ArgsTree((args, kwargs), False, tensor_type=torch.Tensor) + out = args_tree.map_leaf(input_fn) + mapped_args = out[0] + mapped_kwargs = out[1] + + if use_graph: + output = self._dpl_graph(*mapped_args, **mapped_kwargs) + else: + output = super().__call__(*mapped_args, **mapped_kwargs) + + + out_tree = ArgsTree((output, None), False) + out = out_tree.map_leaf(output_fn) + return out[0] + + def _graph_load(self, file_path): + self._dpl_graph.warmup_with_load(file_path) + + def _graph_save(self, file_path): + self._dpl_graph.save_graph(file_path) + + of_md.__class__ = DeplayableModule + if use_graph: + of_md._dpl_graph = dpl_graph + return of_md \ No newline at end of file From b62c21f101bf8f607856a1424710c2dfa86d5e56 Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 25 Aug 2023 04:04:15 +0000 Subject: [PATCH 12/16] refine code --- ...o_image_sdxl_fp16_with_oneflow_compile.py} | 2 +- examples/unet_torch_interplay.py | 83 ++++++------------- src/onediff/infer_compiler/__init__.py | 2 +- ..._onef_graph.py => with_oneflow_compile.py} | 3 +- 4 files changed, 28 insertions(+), 62 deletions(-) rename examples/{text_to_image_sdxl_fp16_graph.py => text_to_image_sdxl_fp16_with_oneflow_compile.py} (99%) rename src/onediff/infer_compiler/{with_onef_graph.py => with_oneflow_compile.py} (98%) diff --git a/examples/text_to_image_sdxl_fp16_graph.py b/examples/text_to_image_sdxl_fp16_with_oneflow_compile.py similarity index 99% rename from examples/text_to_image_sdxl_fp16_graph.py rename to examples/text_to_image_sdxl_fp16_with_oneflow_compile.py index 5c2dc49c6..2b9f4ba29 100644 --- a/examples/text_to_image_sdxl_fp16_graph.py +++ b/examples/text_to_image_sdxl_fp16_with_oneflow_compile.py @@ -24,7 +24,7 @@ parser.add_argument("--compile", action=argparse.BooleanOptionalAction) parser.add_argument("--save", action=argparse.BooleanOptionalAction) parser.add_argument("--load", action=argparse.BooleanOptionalAction) -parser.add_argument("--file", type=str, required=False, default="unet_graph") +parser.add_argument("--file", type=str, required=False, default="unet_compiled") cmd_args = parser.parse_args() # Normal SDXL diff --git a/examples/unet_torch_interplay.py b/examples/unet_torch_interplay.py index 9c2677d50..98affcb69 100644 --- a/examples/unet_torch_interplay.py +++ b/examples/unet_torch_interplay.py @@ -22,6 +22,9 @@ os.environ["ONEFLOW_LINEAR_EMBEDDING_SKIP_INIT"] = "1" import click +# cv2 must be imported before diffusers and oneflow to avlid error: AttributeError: module 'cv2.gapi' has no attribute 'wip' +# Maybe bacause oneflow use a lower version of cv2 +import cv2 import oneflow as flow from tqdm import tqdm from dataclasses import dataclass, fields @@ -100,27 +103,6 @@ def save_graph(self, file_path): state_dict = self.runtime_state_dict() flow.save(state_dict, file_path) -def get_deployable_unet(token, model_id, revision, ): - with MockCtx(): - unet = get_unet(token, model_id, revision) - unet_graph = UNetGraphWithCache(unet) - cross_attention_dim = unet.config["cross_attention_dim"] - warmup_meta_of_sizes = get_arg_meta_of_sizes( - batch_sizes=BATCH_SIZES, - resolution_scales=RESOLUTION_SCALES, - num_channels=num_channels, - cross_attention_dim=cross_attention_dim, - ) - for (i, m) in enumerate(warmup_meta_of_sizes): - print(f"warmup case #{i + 1}:", m) - if load == True: - print("loading graphs...") - unet_graph.warmup_with_load(file) - else: - print("warmup with arguments...") - unet_graph.warmup_with_arg(warmup_meta_of_sizes) - - def img_dim(i, start, stride): return start + stride * i @@ -155,23 +137,6 @@ def get_arg_meta_of_sizes( for j in resolution_scales ] -global_rng = random.Random() -def floats_tensor(shape, scale=1.0, rng=None, name=None): - import torch - """Creates a random float32 tensor""" - if rng is None: - rng = global_rng - - total_dims = 1 - for dim in shape: - total_dims *= dim - - values = [] - for _ in range(total_dims): - values.append(rng.random() * scale) - - return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous() - @click.command() @click.option("--token") @click.option("--repeat", default=100) @@ -187,7 +152,15 @@ def benchmark(token, repeat, sync_interval, save, load, file, model_id, revision # TODO: reproduce bug caused by changing batch # BATCH_SIZES = [4, 2] + # create a mocked unet graph + # unet mock should be placed before importing any diffusers + with MockCtx(): + unet = get_unet(token, model_id, revision) + unet_graph = UNetGraphWithCache(unet) + num_channels = 4 + cross_attention_dim = unet.config["cross_attention_dim"] + from diffusers.utils import floats_tensor import torch if model_id == "stabilityai/stable-diffusion-xl-base-1.0": # sdxl needed @@ -197,27 +170,21 @@ def benchmark(token, repeat, sync_interval, save, load, file, model_id, revision else: added_cond_kwargs = None - # create a mocked unet graph - with MockCtx(): - unet = get_unet(token, model_id, revision) - unet_graph = UNetGraphWithCache(unet) - cross_attention_dim = unet.config["cross_attention_dim"] + warmup_meta_of_sizes = get_arg_meta_of_sizes( + batch_sizes=BATCH_SIZES, + resolution_scales=RESOLUTION_SCALES, + num_channels=num_channels, + cross_attention_dim=cross_attention_dim, + ) + for (i, m) in enumerate(warmup_meta_of_sizes): + print(f"warmup case #{i + 1}:", m) - warmup_meta_of_sizes = get_arg_meta_of_sizes( - batch_sizes=BATCH_SIZES, - resolution_scales=RESOLUTION_SCALES, - num_channels=num_channels, - cross_attention_dim=cross_attention_dim, - ) - for (i, m) in enumerate(warmup_meta_of_sizes): - print(f"warmup case #{i + 1}:", m) - - if load == True: - print("loading graphs...") - unet_graph.warmup_with_load(file) - else: - print("warmup with arguments...") - unet_graph.warmup_with_arg(warmup_meta_of_sizes, added_cond_kwargs) + if load == True: + print("loading graphs...") + unet_graph.warmup_with_load(file) + else: + print("warmup with arguments...") + unet_graph.warmup_with_arg(warmup_meta_of_sizes, added_cond_kwargs) # generate inputs with torch time_step = torch.tensor([10]).to("cuda") diff --git a/src/onediff/infer_compiler/__init__.py b/src/onediff/infer_compiler/__init__.py index 83d052f6e..698611b6a 100644 --- a/src/onediff/infer_compiler/__init__.py +++ b/src/onediff/infer_compiler/__init__.py @@ -1,7 +1,7 @@ import os import torch import oneflow as flow -from .with_onef_graph import oneflow_compile +from .with_oneflow_compile import oneflow_compile from .with_fx_interpreter import OneFlowInterpreter from .with_fx_graph import fx_node_tranform diff --git a/src/onediff/infer_compiler/with_onef_graph.py b/src/onediff/infer_compiler/with_oneflow_compile.py similarity index 98% rename from src/onediff/infer_compiler/with_onef_graph.py rename to src/onediff/infer_compiler/with_oneflow_compile.py index ded3fc301..4a3e4e63b 100644 --- a/src/onediff/infer_compiler/with_onef_graph.py +++ b/src/onediff/infer_compiler/with_oneflow_compile.py @@ -11,7 +11,6 @@ def __init__(self, unet): self.unet = unet self.config.enable_cudnn_conv_heuristic_search_algo(False) self.config.allow_fuse_add_to_output(True) - self.debug(0) os.environ["ONEFLOW_MLIR_CSE"] = "1" os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "1" @@ -94,4 +93,4 @@ def _graph_save(self, file_path): of_md.__class__ = DeplayableModule if use_graph: of_md._dpl_graph = dpl_graph - return of_md \ No newline at end of file + return of_md From 83a19f346bb9fee143ee9055c420ef63a5353e7f Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 25 Aug 2023 04:08:48 +0000 Subject: [PATCH 13/16] format --- examples/text_to_image_sdxl_fp16_with_oneflow_compile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image_sdxl_fp16_with_oneflow_compile.py b/examples/text_to_image_sdxl_fp16_with_oneflow_compile.py index 2b9f4ba29..e0d83ef78 100644 --- a/examples/text_to_image_sdxl_fp16_with_oneflow_compile.py +++ b/examples/text_to_image_sdxl_fp16_with_oneflow_compile.py @@ -54,4 +54,4 @@ # Save compiled unet with oneflow if cmd_args.compile and cmd_args.save: print("saving graphs...") - pipe.unet._graph_save(cmd_args.file) \ No newline at end of file + pipe.unet._graph_save(cmd_args.file) From f6665125a5221d1d4c34528d58cf112c3767f583 Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 25 Aug 2023 11:31:57 +0000 Subject: [PATCH 14/16] deal with special case --- .../infer_compiler/obj_1f_from_torch.py | 50 +++++++++++-------- 1 file changed, 30 insertions(+), 20 deletions(-) diff --git a/src/onediff/infer_compiler/obj_1f_from_torch.py b/src/onediff/infer_compiler/obj_1f_from_torch.py index 6aa703ae1..0daab1d5d 100644 --- a/src/onediff/infer_compiler/obj_1f_from_torch.py +++ b/src/onediff/infer_compiler/obj_1f_from_torch.py @@ -3,25 +3,27 @@ from collections import OrderedDict import torch import oneflow as flow +import logging +logger = logging.getLogger(__name__) __of_mds = {} -__convert_list = [ - "diffusers.models.unet_2d_condition.UNet2DConditionModel", - "diffusers.models.embeddings.TimestepEmbedding", - "diffusers.models.embeddings.Timesteps", - "diffusers.models.resnet.ResnetBlock2D", - "diffusers.models.resnet.Downsample2D", - "diffusers.models.resnet.Upsample2D", - "diffusers.models.transformer_2d.Transformer2DModel", - "diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D", - "diffusers.models.unet_2d_blocks.DownBlock2D", - "diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D", - "diffusers.models.unet_2d_blocks.UpBlock2D", - "diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D", - "diffusers.models.unet_2d_blocks.UNetMidBlock2DCrossAttn", -] - with flow.mock_torch.enable(lazy=False): + __convert_list = [ + "diffusers.models.unet_2d_condition.UNet2DConditionModel", + "diffusers.models.embeddings.TimestepEmbedding", + "diffusers.models.embeddings.Timesteps", + "diffusers.models.resnet.ResnetBlock2D", + "diffusers.models.resnet.Downsample2D", + "diffusers.models.resnet.Upsample2D", + "diffusers.models.transformer_2d.Transformer2DModel", + "diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D", + "diffusers.models.unet_2d_blocks.DownBlock2D", + "diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D", + "diffusers.models.unet_2d_blocks.UpBlock2D", + "diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D", + "diffusers.models.unet_2d_blocks.UNetMidBlock2DCrossAttn", + ] + for md_name in __convert_list: p, m = md_name.rsplit('.', 1) md = importlib.import_module(p) @@ -181,6 +183,17 @@ def __getattribute__(self, attribute): self._1f_proxy_children[attribute] = a else: a = self._1f_proxy_children[attribute] + + full_name = '.'.join((type(a).__module__, type(a).__name__)) + if full_name == "diffusers.configuration_utils.FrozenDict": + return a + if full_name == "diffusers.models.attention_processor.AttnProcessor2_0": + return a + + assert ( + type(a).__module__.startswith("torch") == False + and type(a).__module__.startswith("diffusers") == False + ), "can't be a torch module at this point! But found " + str(type(a)) return a def __call__(self, *args: Any, **kwargs: Any) -> Any: @@ -226,10 +239,7 @@ def init(self): for k, v in proxy_md.__dict__.items(): if k not in self.__dict__: - try: - attr = getattr(proxy_md, k) - except: - continue + attr = getattr(proxy_md, k) self.__dict__[k] = attr def proxy_getattr(self, attr): From 7e8c0f83da814c554798713e0d2ef9af3eead5af Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 25 Aug 2023 12:00:40 +0000 Subject: [PATCH 15/16] open GROUP_MATMUL --- examples/text_to_image_sdxl_fp16.py | 3 +-- examples/text_to_image_sdxl_fp16_with_oneflow_compile.py | 5 +++-- examples/unet_torch_interplay.py | 3 +-- src/onediff/infer_compiler/with_oneflow_compile.py | 3 +-- 4 files changed, 6 insertions(+), 8 deletions(-) diff --git a/examples/text_to_image_sdxl_fp16.py b/examples/text_to_image_sdxl_fp16.py index fba1f6623..ed8d4076e 100644 --- a/examples/text_to_image_sdxl_fp16.py +++ b/examples/text_to_image_sdxl_fp16.py @@ -41,8 +41,7 @@ os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1" os.environ["ONEFLOW_MLIR_FUSE_OPS_WITH_BACKWARD_IMPL"] = "1" - # Open this will raise error - # os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "1" + os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "1" os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1" os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_CONV_BIAS"] = "1" os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR"] = "1" diff --git a/examples/text_to_image_sdxl_fp16_with_oneflow_compile.py b/examples/text_to_image_sdxl_fp16_with_oneflow_compile.py index e0d83ef78..d7c6e0eaa 100644 --- a/examples/text_to_image_sdxl_fp16_with_oneflow_compile.py +++ b/examples/text_to_image_sdxl_fp16_with_oneflow_compile.py @@ -45,10 +45,11 @@ # Normal SDXL call sizes = [1024, 896, 768] +#sizes = [1024] for h in sizes: for w in sizes: - for i in range(1): - image = pipe(prompt=cmd_args.prompt, height=h, width=w, num_inference_steps=2).images[0] + for i in range(3): + image = pipe(prompt=cmd_args.prompt, height=h, width=w, num_inference_steps=30).images[0] image.save(f"h{h}-w{w}-i{i}-{cmd_args.saved_image}") # Save compiled unet with oneflow diff --git a/examples/unet_torch_interplay.py b/examples/unet_torch_interplay.py index 98affcb69..2dcf95172 100644 --- a/examples/unet_torch_interplay.py +++ b/examples/unet_torch_interplay.py @@ -6,8 +6,7 @@ os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1" os.environ["ONEFLOW_MLIR_FUSE_OPS_WITH_BACKWARD_IMPL"] = "1" -# TODO(): open ONEFLOW_MLIR_GROUP_MATMUL will raise in SDXL, need be fixed. -os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "0" +os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "1" os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1" os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_CONV_BIAS"] = "1" diff --git a/src/onediff/infer_compiler/with_oneflow_compile.py b/src/onediff/infer_compiler/with_oneflow_compile.py index 4a3e4e63b..e28a895b5 100644 --- a/src/onediff/infer_compiler/with_oneflow_compile.py +++ b/src/onediff/infer_compiler/with_oneflow_compile.py @@ -17,8 +17,7 @@ def __init__(self, unet): os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1" os.environ["ONEFLOW_MLIR_FUSE_OPS_WITH_BACKWARD_IMPL"] = "1" - # Open this will raise error - # os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "1" + os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "1" os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1" os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_CONV_BIAS"] = "1" os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR"] = "1" From 4c7d6da1ad8d06c885572eac8dfdc0fc68753576 Mon Sep 17 00:00:00 2001 From: strint Date: Mon, 28 Aug 2023 11:46:11 +0000 Subject: [PATCH 16/16] add (base and refiner) full pipeline --- ...6_base_and_refiner_with_oneflow_compile.py | 92 +++++++++++++++++++ 1 file changed, 92 insertions(+) create mode 100644 examples/text_to_image_sdxl_fp16_base_and_refiner_with_oneflow_compile.py diff --git a/examples/text_to_image_sdxl_fp16_base_and_refiner_with_oneflow_compile.py b/examples/text_to_image_sdxl_fp16_base_and_refiner_with_oneflow_compile.py new file mode 100644 index 000000000..d7e6c70fb --- /dev/null +++ b/examples/text_to_image_sdxl_fp16_base_and_refiner_with_oneflow_compile.py @@ -0,0 +1,92 @@ +import os +import argparse +# cv2 must be imported before diffusers and oneflow to avlid error: AttributeError: module 'cv2.gapi' has no attribute 'wip' +# Maybe bacause oneflow use a lower version of cv2 +import cv2 +import oneflow as flow +import torch +# oneflow_compile should be imported before importing any diffusers +from onediff.infer_compiler import oneflow_compile +from diffusers import DiffusionPipeline + +parser = argparse.ArgumentParser() +parser.add_argument( + "--base", type=str, default="/share_nfs/hf_models/stable-diffusion-xl-base-1.0" +) +parser.add_argument( + "--refiner", type=str, default="stabilityai/stable-diffusion-xl-refiner-1.0" +) +parser.add_argument("--variant", type=str, default="fp16") +parser.add_argument( + "--prompt", + type=str, + default="street style, detailed, raw photo, woman, face, shot on CineStill 800T", +) +parser.add_argument("--n_steps", type=int, default=30) +parser.add_argument("--saved_image", type=str, required=False, default="xl-refiner-out.png") +parser.add_argument("--seed", type=int, default=1) +parser.add_argument("--compile", action=argparse.BooleanOptionalAction) +parser.add_argument("--save", action=argparse.BooleanOptionalAction) +parser.add_argument("--load", action=argparse.BooleanOptionalAction) +parser.add_argument("--file", type=str, required=False, default="unet_compiled") +cmd_args = parser.parse_args() + +# Normal SDXL +seed = torch.manual_seed(cmd_args.seed) +# SDXL base: StableDiffusionXLPipeline +base = DiffusionPipeline.from_pretrained( + cmd_args.base, + torch_dtype=torch.float16, + variant=cmd_args.variant, + use_safetensors=True, +) +base.to("cuda") +# SDXL refiner: StableDiffusionXLImg2ImgPipeline +refiner = DiffusionPipeline.from_pretrained( + cmd_args.refiner, + text_encoder_2=base.text_encoder_2, + vae=base.vae, + torch_dtype=torch.float16, + use_safetensors=True, + variant=cmd_args.variant, +) +refiner.to("cuda") + +# Compile unet with oneflow +if cmd_args.compile: + base.unet = oneflow_compile(base.unet) + refiner.unet = oneflow_compile(refiner.unet) + print("unet is compiled to oneflow.") + if cmd_args.load: + # Load compiled unet with oneflow + print("loading graphs...") + base.unet._graph_load("base_" + cmd_args.file) + refiner.unet._graph_load("refiner_" + cmd_args.file) + +# Normal SDXL call +sizes = [1024, 896, 768] +#sizes = [1024] +for h in sizes: + for w in sizes: + for i in range(3): + image = base( + prompt=cmd_args.prompt, + height=h, + width=w, + generator=seed, + num_inference_steps=cmd_args.n_steps, + output_type="latent", + ).images + image = refiner( + prompt=cmd_args.prompt, + generator=seed, + num_inference_steps=cmd_args.n_steps, + image=image, + ).images[0] + image.save(f"h{h}-w{w}-i{i}-{cmd_args.saved_image}") + +# Save compiled unet with oneflow +if cmd_args.compile and cmd_args.save: + print("saving graphs...") + base.unet._graph_save("base_" + cmd_args.file) + refiner.unet._graph_save("refiner_" + cmd_args.file) \ No newline at end of file