CogVideoX 1.5 (#9877)
* CogVideoX1_1PatchEmbed test

* 1360 * 768

* refactor

* make style

* update docs

* add modeling tests for cogvideox 1.5

* update

* make fix-copies

* add ofs embed(for convert)

* add ofs embed(for convert)

* more resolution for cogvideox1.5-5b-i2v

* use even number of latent frames only

* update pipeline implementations

* make style

* set patch_size_t as None by default

* #skip frames 0

* refactor

* make style

* update docs

* fix ofs_embed

* update docs

* invert_scale_latents

* update

* fix

* Update docs/source/en/api/pipelines/cogvideox.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/api/pipelines/cogvideox.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/api/pipelines/cogvideox.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/api/pipelines/cogvideox.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/diffusers/models/transformers/cogvideox_transformer_3d.py

* update conversion script

* remove copied from

* fix test

* Update docs/source/en/api/pipelines/cogvideox.md

* Update docs/source/en/api/pipelines/cogvideox.md

* Update docs/source/en/api/pipelines/cogvideox.md

* Update docs/source/en/api/pipelines/cogvideox.md

---------

Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
3 people authored Nov 18, 2024
1 parent c3c94fe commit 3b28306
Showing 10 changed files with 405 additions and 92 deletions.
33 changes: 23 additions & 10 deletions docs/source/en/api/pipelines/cogvideox.md
@@ -29,16 +29,29 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md)

This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).

There are two models available that can be used with the text-to-video and video-to-video CogVideoX pipelines:
- [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b): The recommended dtype for running this model is `fp16`.
- [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b): The recommended dtype for running this model is `bf16`.

There is one model available that can be used with the image-to-video CogVideoX pipeline:
- [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V): The recommended dtype for running this model is `bf16`.

There are two models that support pose controllable generation (by the [Alibaba-PAI](https://huggingface.co/alibaba-pai) team):
- [`alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose): The recommended dtype for running this model is `bf16`.
- [`alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose): The recommended dtype for running this model is `bf16`.
There are three official CogVideoX checkpoints for text-to-video and video-to-video.
| checkpoints | recommended inference dtype |
|---|---|
| [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b) | torch.float16 |
| [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b) | torch.bfloat16 |
| [`THUDM/CogVideoX1.5-5b`](https://huggingface.co/THUDM/CogVideoX1.5-5b) | torch.bfloat16 |

There are two official CogVideoX checkpoints available for image-to-video.
| checkpoints | recommended inference dtype |
|---|---|
| [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V) | torch.bfloat16 |
| [`THUDM/CogVideoX-1.5-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-1.5-5b-I2V) | torch.bfloat16 |

For the CogVideoX 1.5 series:
- Text-to-video (T2V) works best at a resolution of 1360x768 because it was trained with that specific resolution.
- Image-to-video (I2V) works for multiple resolutions. The width can vary from 768 to 1360, but the height must be 768. The height/width must be divisible by 16.
- Both T2V and I2V models support generation with 81 and 161 frames and work best at these values. Exporting videos at 16 FPS is recommended; a minimal usage sketch follows this list.
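
Below is a minimal sketch of how these recommendations translate into `diffusers` calls. The prompts, conditioning image, output paths, and sampling settings are illustrative, and the checkpoint ids are taken from the tables above.

```python
import torch
from diffusers import CogVideoXPipeline, CogVideoXImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

# Text-to-video at the recommended 1360x768 resolution with 81 frames.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX1.5-5b", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
video = pipe(
    prompt="A panda strumming a guitar in a sunlit bamboo forest",  # illustrative prompt
    height=768,
    width=1360,
    num_frames=81,
    num_inference_steps=50,
    guidance_scale=6.0,
).frames[0]
export_to_video(video, "output_t2v.mp4", fps=16)

# Image-to-video: height fixed at 768, width anywhere from 768 to 1360 (divisible by 16).
pipe_i2v = CogVideoXImageToVideoPipeline.from_pretrained(
    "THUDM/CogVideoX-1.5-5b-I2V", torch_dtype=torch.bfloat16
)
pipe_i2v.enable_model_cpu_offload()
image = load_image("input.jpg")  # illustrative conditioning image
video = pipe_i2v(
    image=image,
    prompt="The scene slowly comes to life",  # illustrative prompt
    height=768,
    width=1360,
    num_frames=81,
).frames[0]
export_to_video(video, "output_i2v.mp4", fps=16)
```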

There are two official CogVideoX checkpoints that support pose controllable generation (by the [Alibaba-PAI](https://huggingface.co/alibaba-pai) team).
| checkpoints | recommended inference dtype |
|---|---|
| [`alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose) | torch.bfloat16 |
| [`alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose) | torch.bfloat16 |

## Inference

76 changes: 66 additions & 10 deletions scripts/convert_cogvideox_to_diffusers.py
@@ -80,6 +80,8 @@ def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
"post_attn1_layernorm": "norm2.norm",
"time_embed.0": "time_embedding.linear_1",
"time_embed.2": "time_embedding.linear_2",
"ofs_embed.0": "ofs_embedding.linear_1",
"ofs_embed.2": "ofs_embedding.linear_2",
"mixins.patch_embed": "patch_embed",
"mixins.final_layer.norm_final": "norm_out.norm",
"mixins.final_layer.linear": "proj_out",
@@ -140,6 +142,7 @@ def convert_transformer(
use_rotary_positional_embeddings: bool,
i2v: bool,
dtype: torch.dtype,
init_kwargs: Dict[str, Any],
):
PREFIX_KEY = "model.diffusion_model."

@@ -149,7 +152,9 @@
num_layers=num_layers,
num_attention_heads=num_attention_heads,
use_rotary_positional_embeddings=use_rotary_positional_embeddings,
use_learned_positional_embeddings=i2v,
ofs_embed_dim=512 if (i2v and init_kwargs["patch_size_t"] is not None) else None, # CogVideoX1.5-5B-I2V
use_learned_positional_embeddings=i2v and init_kwargs["patch_size_t"] is None, # CogVideoX-5B-I2V
**init_kwargs,
).to(dtype=dtype)

for key in list(original_state_dict.keys()):
@@ -163,13 +168,18 @@
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)

transformer.load_state_dict(original_state_dict, strict=True)
return transformer


def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
def convert_vae(ckpt_path: str, scaling_factor: float, version: str, dtype: torch.dtype):
init_kwargs = {"scaling_factor": scaling_factor}
if version == "1.5":
init_kwargs.update({"invert_scale_latents": True})

original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype)
vae = AutoencoderKLCogVideoX(**init_kwargs).to(dtype=dtype)

for key in list(original_state_dict.keys()):
new_key = key[:]
@@ -187,6 +197,34 @@ def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
return vae


def get_transformer_init_kwargs(version: str):
if version == "1.0":
vae_scale_factor_spatial = 8
init_kwargs = {
"patch_size": 2,
"patch_size_t": None,
"patch_bias": True,
"sample_height": 480 // vae_scale_factor_spatial,
"sample_width": 720 // vae_scale_factor_spatial,
"sample_frames": 49,
}

elif version == "1.5":
vae_scale_factor_spatial = 8
init_kwargs = {
"patch_size": 2,
"patch_size_t": 2,
"patch_bias": False,
"sample_height": 300,
"sample_width": 300,
"sample_frames": 81,
}
else:
raise ValueError("Unsupported version of CogVideoX.")

return init_kwargs


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
@@ -202,6 +240,12 @@ def get_args():
parser.add_argument(
"--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
)
parser.add_argument(
"--typecast_text_encoder",
action="store_true",
default=False,
help="Whether or not to apply fp16/bf16 precision to text_encoder",
)
# For CogVideoX-2B, num_layers is 30. For 5B, it is 42
parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
# For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
@@ -214,7 +258,18 @@
parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
# For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
parser.add_argument("--i2v", action="store_true", default=False, help="Whether to save the model weights in fp16")
parser.add_argument(
"--i2v",
action="store_true",
default=False,
help="Whether the model to be converted is the Image-to-Video version of CogVideoX.",
)
parser.add_argument(
"--version",
choices=["1.0", "1.5"],
default="1.0",
help="Which version of CogVideoX to use for initializing default modeling parameters.",
)
return parser.parse_args()


@@ -230,21 +285,27 @@ def get_args():
dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32

if args.transformer_ckpt_path is not None:
init_kwargs = get_transformer_init_kwargs(args.version)
transformer = convert_transformer(
args.transformer_ckpt_path,
args.num_layers,
args.num_attention_heads,
args.use_rotary_positional_embeddings,
args.i2v,
dtype,
init_kwargs,
)
if args.vae_ckpt_path is not None:
vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
# Keep VAE in float32 for better quality
vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, args.version, torch.float32)

text_encoder_id = "google/t5-v1_1-xxl"
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)

if args.typecast_text_encoder:
text_encoder = text_encoder.to(dtype=dtype)

# Apparently, the conversion does not work anymore without this :shrug:
for param in text_encoder.parameters():
param.data = param.data.contiguous()
@@ -276,11 +337,6 @@ def get_args():
scheduler=scheduler,
)

if args.fp16:
pipe = pipe.to(dtype=torch.float16)
if args.bf16:
pipe = pipe.to(dtype=torch.bfloat16)

# We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
# for users to specify variant when the default is not fp32 and they want to run with the correct default (which
# is either fp16/bf16 here).
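The script assembles a full CogVideoX pipeline from the converted transformer, VAE, text encoder, and scheduler; once saved in the `diffusers` layout it loads like any Hub checkpoint. A minimal sketch, assuming the converted 1.5 weights were written to a local folder named `./cogvideox1.5-5b-converted` (the path is illustrative):

```python
import torch
from diffusers import CogVideoXPipeline

# Load a locally converted checkpoint; bf16 is the recommended inference dtype for the 5B models.
pipe = CogVideoXPipeline.from_pretrained("./cogvideox1.5-5b-converted", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
```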
src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
@@ -1057,6 +1057,7 @@ def __init__(
force_upcast: float = True,
use_quant_conv: bool = False,
use_post_quant_conv: bool = False,
invert_scale_latents: bool = False,
):
super().__init__()

76 changes: 61 additions & 15 deletions src/diffusers/models/embeddings.py
@@ -338,6 +338,7 @@ class CogVideoXPatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 2,
patch_size_t: Optional[int] = None,
in_channels: int = 16,
embed_dim: int = 1920,
text_embed_dim: int = 4096,
@@ -355,6 +356,7 @@ def __init__(
super().__init__()

self.patch_size = patch_size
self.patch_size_t = patch_size_t
self.embed_dim = embed_dim
self.sample_height = sample_height
self.sample_width = sample_width
@@ -366,9 +368,15 @@ def __init__(
self.use_positional_embeddings = use_positional_embeddings
self.use_learned_positional_embeddings = use_learned_positional_embeddings

self.proj = nn.Conv2d(
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
)
if patch_size_t is None:
# CogVideoX 1.0 checkpoints
self.proj = nn.Conv2d(
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
)
else:
# CogVideoX 1.5 checkpoints
self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim)

self.text_proj = nn.Linear(text_embed_dim, embed_dim)

if use_positional_embeddings or use_learned_positional_embeddings:
@@ -407,12 +415,24 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
"""
text_embeds = self.text_proj(text_embeds)

batch, num_frames, channels, height, width = image_embeds.shape
image_embeds = image_embeds.reshape(-1, channels, height, width)
image_embeds = self.proj(image_embeds)
image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
batch_size, num_frames, channels, height, width = image_embeds.shape

if self.patch_size_t is None:
image_embeds = image_embeds.reshape(-1, channels, height, width)
image_embeds = self.proj(image_embeds)
image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
else:
p = self.patch_size
p_t = self.patch_size_t

image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
image_embeds = image_embeds.reshape(
batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
)
image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
image_embeds = self.proj(image_embeds)
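# Net effect of the patch_size_t branch above: [B, F, C, H, W] latents become
# [B, (F // p_t) * (H // p) * (W // p), C * p * p * p_t] tokens, i.e. each token covers a
# p_t x p x p spatio-temporal patch before the linear projection to embed_dim.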

embeds = torch.cat(
[text_embeds, image_embeds], dim=1
@@ -497,7 +517,14 @@ def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor


def get_3d_rotary_pos_embed(
embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
embed_dim,
crops_coords,
grid_size,
temporal_size,
theta: int = 10000,
use_real: bool = True,
grid_type: str = "linspace",
max_size: Optional[Tuple[int, int]] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
RoPE for video tokens with 3D structure.
@@ -513,17 +540,30 @@
The size of the temporal dimension.
theta (`float`):
Scaling factor for frequency computation.
grid_type (`str`):
Whether to use "linspace" or "slice" to compute grids.
Returns:
`torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
"""
if use_real is not True:
raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
start, stop = crops_coords
grid_size_h, grid_size_w = grid_size
grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)

if grid_type == "linspace":
start, stop = crops_coords
grid_size_h, grid_size_w = grid_size
grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
grid_t = np.arange(temporal_size, dtype=np.float32)
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
elif grid_type == "slice":
max_h, max_w = max_size
grid_size_h, grid_size_w = grid_size
grid_h = np.arange(max_h, dtype=np.float32)
grid_w = np.arange(max_w, dtype=np.float32)
grid_t = np.arange(temporal_size, dtype=np.float32)
else:
raise ValueError("Invalid value passed for `grid_type`.")

# Compute dimensions for each axis
dim_t = embed_dim // 4
@@ -559,6 +599,12 @@ def combine_time_height_width(freqs_t, freqs_h, freqs_w):
t_cos, t_sin = freqs_t # both t_cos and t_sin has shape: temporal_size, dim_t
h_cos, h_sin = freqs_h # both h_cos and h_sin has shape: grid_size_h, dim_h
w_cos, w_sin = freqs_w # both w_cos and w_sin has shape: grid_size_w, dim_w

if grid_type == "slice":
t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]

cos = combine_time_height_width(t_cos, h_cos, w_cos)
sin = combine_time_height_width(t_sin, h_sin, w_sin)
return cos, sin
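With `grid_type="slice"`, the per-axis rotary frequencies are computed over a fixed maximum grid of integer positions and then sliced down to the actual token grid, instead of being rescaled with `linspace` over the crop coordinates. A rough sketch of calling the helper directly; the grid sizes are illustrative and 64 is assumed for the attention head dimension:

```python
from diffusers.models.embeddings import get_3d_rotary_pos_embed

# Illustrative sizes: a 768x1360 video gives a 96x170 latent grid after 8x VAE downsampling,
# i.e. a 48x85 token grid with patch_size=2; 11 stands in for the number of temporal tokens.
cos, sin = get_3d_rotary_pos_embed(
    embed_dim=64,          # assumed attention head dimension
    crops_coords=None,     # not used by the "slice" branch
    grid_size=(48, 85),
    temporal_size=11,
    grid_type="slice",
    max_size=(48, 85),     # maximum grid over which the frequencies are precomputed
)
# cos and sin each contain one rotary row per (frame, row, column) token: 11 * 48 * 85 rows.
print(cos.shape, sin.shape)
```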
