diff --git a/README.md b/README.md index bfaaf1f..9a8ec40 100644 --- a/README.md +++ b/README.md @@ -121,7 +121,7 @@ parallelize_pipe( ) parallelize_vae(pipe.vae, mesh=mesh._flatten()) -# pipe.enable_model_cpu_offload() +# pipe.enable_model_cpu_offload(gpu_id=dist.get_rank()) # torch._inductor.config.reorder_for_compute_comm_overlap = True # pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs") @@ -178,12 +178,12 @@ pipe = HunyuanVideoPipeline.from_pretrained( pipe.vae.enable_tiling( # Make it runnable on GPUs with 48GB memory - tile_sample_min_height=128, - tile_sample_stride_height=96, - tile_sample_min_width=128, - tile_sample_stride_width=96, - tile_sample_min_num_frames=32, - tile_sample_stride_num_frames=24, + # tile_sample_min_height=128, + # tile_sample_stride_height=96, + # tile_sample_min_width=128, + # tile_sample_stride_width=96, + # tile_sample_min_num_frames=32, + # tile_sample_stride_num_frames=24, ) from para_attn.context_parallel import init_context_parallel_mesh @@ -199,16 +199,16 @@ parallelize_pipe( ) parallelize_vae(pipe.vae, mesh=mesh._flatten()) -# pipe.enable_model_cpu_offload() +# pipe.enable_model_cpu_offload(gpu_id=dist.get_rank()) # torch._inductor.config.reorder_for_compute_comm_overlap = True # pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs") output = pipe( prompt="A cat walks on the grass, realistic", - height=320, - width=512, - num_frames=61, + height=720, + width=1280, + num_frames=129, num_inference_steps=30, output_type="pil" if dist.get_rank() == 0 else "pt", ).frames[0] @@ -254,7 +254,7 @@ parallelize_pipe( ) # Enable memory savings -# pipe.enable_model_cpu_offload() +# pipe.enable_model_cpu_offload(gpu_id=dist.get_rank()) pipe.enable_vae_tiling() # torch._inductor.config.reorder_for_compute_comm_overlap = True diff --git a/focus_attn_examples/run_hunyuan_video.py b/focus_attn_examples/run_hunyuan_video.py index af4aacd..b441615 100644 --- a/focus_attn_examples/run_hunyuan_video.py +++ b/focus_attn_examples/run_hunyuan_video.py @@ -18,12 +18,12 @@ pipe.vae.enable_tiling( # Make it runnable on GPUs with 48GB memory - tile_sample_min_height=128, - tile_sample_stride_height=96, - tile_sample_min_width=128, - tile_sample_stride_width=96, - tile_sample_min_num_frames=32, - tile_sample_stride_num_frames=24, + # tile_sample_min_height=128, + # tile_sample_stride_height=96, + # tile_sample_min_width=128, + # tile_sample_stride_width=96, + # tile_sample_min_num_frames=32, + # tile_sample_stride_num_frames=24, ) from para_attn.focus_attn.diffusers_adapters import apply_focus_attn_on_pipe @@ -36,9 +36,9 @@ output = pipe( prompt="A cat walks on the grass, realistic", - height=320, - width=512, - num_frames=61, + height=720, + width=1280, + num_frames=129, num_inference_steps=30, ).frames[0] diff --git a/parallel_examples/run_cogvideox.py b/parallel_examples/run_cogvideox.py index a8a92c5..42b3b2a 100644 --- a/parallel_examples/run_cogvideox.py +++ b/parallel_examples/run_cogvideox.py @@ -23,8 +23,8 @@ ), ) -# pipe.enable_model_cpu_offload() -# pipe.enable_sequential_cpu_offload() +# pipe.enable_model_cpu_offload(gpu_id=dist.get_rank()) +# pipe.enable_sequential_cpu_offload(gpu_id=dist.get_rank()) pipe.vae.enable_slicing() pipe.vae.enable_tiling() diff --git a/parallel_examples/run_flux.py b/parallel_examples/run_flux.py index 6fc423a..bcb28cb 100644 --- a/parallel_examples/run_flux.py +++ b/parallel_examples/run_flux.py @@ -23,7 +23,7 @@ ) parallelize_vae(pipe.vae, mesh=mesh._flatten()) -# pipe.enable_model_cpu_offload() +# pipe.enable_model_cpu_offload(gpu_id=dist.get_rank()) # torch._inductor.config.reorder_for_compute_comm_overlap = True # pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs") diff --git a/parallel_examples/run_hunyuan_video.py b/parallel_examples/run_hunyuan_video.py index 41afbe2..f23ae0a 100644 --- a/parallel_examples/run_hunyuan_video.py +++ b/parallel_examples/run_hunyuan_video.py @@ -24,12 +24,12 @@ pipe.vae.enable_tiling( # Make it runnable on GPUs with 48GB memory - tile_sample_min_height=128, - tile_sample_stride_height=96, - tile_sample_min_width=128, - tile_sample_stride_width=96, - tile_sample_min_num_frames=32, - tile_sample_stride_num_frames=24, + # tile_sample_min_height=128, + # tile_sample_stride_height=96, + # tile_sample_min_width=128, + # tile_sample_stride_width=96, + # tile_sample_min_num_frames=32, + # tile_sample_stride_num_frames=24, ) from para_attn.context_parallel import init_context_parallel_mesh @@ -45,16 +45,16 @@ ) parallelize_vae(pipe.vae, mesh=mesh._flatten()) -# pipe.enable_model_cpu_offload() +# pipe.enable_model_cpu_offload(gpu_id=dist.get_rank()) # torch._inductor.config.reorder_for_compute_comm_overlap = True # pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs") output = pipe( prompt="A cat walks on the grass, realistic", - height=320, - width=512, - num_frames=61, + height=720, + width=1280, + num_frames=129, num_inference_steps=30, output_type="pil" if dist.get_rank() == 0 else "pt", ).frames[0] diff --git a/parallel_examples/run_mochi.py b/parallel_examples/run_mochi.py index 3cb8795..147726a 100644 --- a/parallel_examples/run_mochi.py +++ b/parallel_examples/run_mochi.py @@ -23,7 +23,7 @@ ) # Enable memory savings -# pipe.enable_model_cpu_offload() +# pipe.enable_model_cpu_offload(gpu_id=dist.get_rank()) pipe.enable_vae_tiling() # torch._inductor.config.reorder_for_compute_comm_overlap = True diff --git a/tests/context_parallel/test_diffusers_adapters.py b/tests/context_parallel/test_diffusers_adapters.py index 125c6d5..8c82d15 100644 --- a/tests/context_parallel/test_diffusers_adapters.py +++ b/tests/context_parallel/test_diffusers_adapters.py @@ -210,12 +210,12 @@ def new_pipe(self, dtype, device): pipe.vae.enable_tiling( # Make it runnable on GPUs with 48GB memory - tile_sample_min_height=128, - tile_sample_stride_height=96, - tile_sample_min_width=128, - tile_sample_stride_width=96, - tile_sample_min_num_frames=32, - tile_sample_stride_num_frames=24, + # tile_sample_min_height=128, + # tile_sample_stride_height=96, + # tile_sample_min_width=128, + # tile_sample_stride_width=96, + # tile_sample_min_num_frames=32, + # tile_sample_stride_num_frames=24, ) # Fix OOM because of awful inductor lowering of attn_bias of _scaled_dot_product_efficient_attention @@ -228,9 +228,9 @@ def new_pipe(self, dtype, device): def call_pipe(self, pipe, *args, **kwargs): return pipe( prompt="A cat walks on the grass, realistic", - height=320, - width=512, - num_frames=61, + height=720, + width=1280, + num_frames=129, num_inference_steps=30, output_type="pil" if self.rank == 0 else "pt", )