Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-network the DIT, fix some parameters, and simplify the model networking code #632

Merged
merged 28 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
59f23a0
modified the dit
chang-wenbin Jul 29, 2024
5fee64b
add zkk_facebook
chang-wenbin Jul 29, 2024
f653a66
update zkk_facebook_dit.py
chang-wenbin Jul 29, 2024
3b29d9d
update transformer_2d
chang-wenbin Jul 30, 2024
a88caea
update dit optimize
chang-wenbin Jul 31, 2024
54eeec2
update transformer_2d
chang-wenbin Jul 31, 2024
28a62c0
rename facebook_dit
chang-wenbin Aug 1, 2024
884e29a
merge PR
chang-wenbin Aug 5, 2024
15d08b6
merge from develop
chang-wenbin Aug 5, 2024
7d49c49
Fixed the original dynamic image bug
chang-wenbin Aug 5, 2024
b03aa8e
update triton op import paddlemix
chang-wenbin Aug 5, 2024
cb86d17
update dit
chang-wenbin Aug 7, 2024
dc0c45c
update transformer_2d & simplified_facebook_dit
chang-wenbin Aug 7, 2024
42f61bc
update demo & implified_facebook_dit & transformer_2d
chang-wenbin Aug 7, 2024
000dd80
update Inference_Optimize
chang-wenbin Aug 7, 2024
9bb9cde
update demo & simplified_facebook_dit
chang-wenbin Aug 7, 2024
d3de838
update demo
chang-wenbin Aug 7, 2024
400ab19
update demo simplified_facebook_dit transformer_2d
chang-wenbin Aug 7, 2024
bfe8c41
update demo transformer_2d & simplified_facebook_dit
chang-wenbin Aug 7, 2024
8896057
test
chang-wenbin Aug 7, 2024
e9aa47d
add format
chang-wenbin Aug 7, 2024
c8916f7
add format
chang-wenbin Aug 7, 2024
a87f81b
add Argument to the demo
chang-wenbin Aug 8, 2024
0a09bf2
update Argument to the demo
chang-wenbin Aug 8, 2024
10e8c1f
Merge remote-tracking branch 'upstream/develop' into DIT_PaddleMIX_729
chang-wenbin Aug 8, 2024
10953b5
update transformer_2d
chang-wenbin Aug 9, 2024
922d7d0
update DIT_demo
chang-wenbin Aug 19, 2024
c4f8242
Merge branch 'develop' into DIT_PaddleMIX_729
nemonameless Aug 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,74 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import datetime
import os

import paddle
from paddlenlp.trainer import set_seed

from ppdiffusers import DDIMScheduler, DiTPipeline

dtype = paddle.float32

def parse_args():
parser = argparse.ArgumentParser(
description=" Use PaddleMIX to accelerate the Diffusion Transformer image generation model."
)
parser.add_argument(
"--benchmark",
type=(lambda x: str(x).lower() in ["true", "1", "yes"]),
default=False,
help="if benchmark is set to True, measure inference performance",
)
parser.add_argument(
"--inference_optimize",
type=(lambda x: str(x).lower() in ["true", "1", "yes"]),
default=False,
help="If inference_optimize is set to True, all optimizations except Triton are enabled.",
)
parser.add_argument(
"--inference_optimize_triton",
type=(lambda x: str(x).lower() in ["true", "1", "yes"]),
default=True,
help="If inference_optimize_triton is set to True, Triton operator optimized inference is enabled.",
)
return parser.parse_args()


args = parse_args()

if args.inference_optimize:
os.environ["INFERENCE_OPTIMIZE"] = "True"
if args.inference_optimize_triton:
os.environ["INFERENCE_OPTIMIZE_TRITON"] = "True"

dtype = paddle.float16
pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", paddle_dtype=dtype)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
set_seed(42)

words = ["golden retriever"] # class_ids [207]
class_ids = pipe.get_label_ids(words)
image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]

if args.benchmark:

# warmup
for i in range(5):
image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]

repeat_times = 5

paddle.device.synchronize()
starttime = datetime.datetime.now()
for i in range(repeat_times):
image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
paddle.device.synchronize()
endtime = datetime.datetime.now()

duringtime = endtime - starttime
time_ms = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0
print("The ave end to end time : ", time_ms / repeat_times, "ms")

image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
image.save("class_conditional_image_generation-dit-result.png")
5 changes: 5 additions & 0 deletions ppdiffusers/ppdiffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1050,6 +1050,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

return model

@classmethod
def custom_modify_weight(cls, state_dict):
pass

@classmethod
def _load_pretrained_model(
cls,
Expand Down Expand Up @@ -1130,6 +1134,7 @@ def _find_mismatched_keys(
error_msgs.append(
f"Error size mismatch, {key_name} receives a shape {loaded_shape}, but the expected shape is {model_shape}."
)
cls.custom_modify_weight(state_dict)
faster_set_state_dict(model_to_load, state_dict)

missing_keys = sorted(list(set(expected_keys) - set(loaded_keys)))
Expand Down
137 changes: 137 additions & 0 deletions ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os

import paddle
import paddle.nn.functional as F
from paddle import nn


class SimplifiedFacebookDIT(nn.Layer):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

必须一定要简化这个模块吗?

Copy link
Contributor

@zhoutianzi666 zhoutianzi666 Aug 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

必须一定要简化这个模块吗?

手工优化需要

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

手工优化需要对原动态图模型组网 做高性能精简重组,这一模块还将transformer循环中的冗余计算部分提出,减少了部分计算量。
感谢提供修改意见,辛苦!

def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attention_head_dim: int):
super().__init__()
self.num_layers = num_layers
self.dim = dim
self.heads_num = num_attention_heads
self.head_dim = attention_head_dim
self.timestep_embedder_in_channels = 256
self.timestep_embedder_time_embed_dim = 1152
self.timestep_embedder_time_embed_dim_out = self.timestep_embedder_time_embed_dim
self.LabelEmbedding_num_classes = 1001
self.LabelEmbedding_num_hidden_size = 1152

self.fcs0 = nn.LayerList(
[
nn.Linear(self.timestep_embedder_in_channels, self.timestep_embedder_time_embed_dim)
for i in range(num_layers)
]
)

self.fcs1 = nn.LayerList(
[
nn.Linear(self.timestep_embedder_time_embed_dim, self.timestep_embedder_time_embed_dim_out)
for i in range(num_layers)
]
)

self.fcs2 = nn.LayerList(
[
nn.Linear(self.timestep_embedder_time_embed_dim, 6 * self.timestep_embedder_time_embed_dim)
for i in range(num_layers)
]
)

self.embs = nn.LayerList(
[
nn.Embedding(self.LabelEmbedding_num_classes, self.LabelEmbedding_num_hidden_size)
for i in range(num_layers)
]
)

self.q = nn.LayerList([nn.Linear(dim, dim) for i in range(num_layers)])
self.k = nn.LayerList([nn.Linear(dim, dim) for i in range(num_layers)])
self.v = nn.LayerList([nn.Linear(dim, dim) for i in range(num_layers)])
self.out_proj = nn.LayerList([nn.Linear(dim, dim) for i in range(num_layers)])
self.ffn1 = nn.LayerList([nn.Linear(dim, dim * 4) for i in range(num_layers)])
self.ffn2 = nn.LayerList([nn.Linear(dim * 4, dim) for i in range(num_layers)])
self.norm = nn.LayerNorm(1152, epsilon=1e-06, weight_attr=False, bias_attr=False)
self.norm1 = nn.LayerNorm(1152, epsilon=1e-05, weight_attr=False, bias_attr=False)

def forward(self, hidden_states, timesteps, class_labels):

# below code are copied from PaddleMIX/ppdiffusers/ppdiffusers/models/embeddings.py
num_channels = 256
max_period = 10000
downscale_freq_shift = 1
half_dim = num_channels // 2
exponent = -math.log(max_period) * paddle.arange(start=0, end=half_dim, dtype="float32")
exponent = exponent / (half_dim - downscale_freq_shift)
emb = paddle.exp(exponent)
emb = timesteps[:, None].cast("float32") * emb[None, :]
emb = paddle.concat([paddle.cos(emb), paddle.sin(emb)], axis=-1)
common_emb = emb.cast(hidden_states.dtype)

for i in range(self.num_layers):
emb = self.fcs0[i](common_emb)
emb = F.silu(emb)
emb = self.fcs1[i](emb)
emb = emb + self.embs[i](class_labels)
emb = F.silu(emb)
emb = self.fcs2[i](emb)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, axis=1)
import paddlemix

if os.getenv("INFERENCE_OPTIMIZE_TRITON"):
norm_hidden_states = paddlemix.triton_ops.adaptive_layer_norm(
hidden_states, scale_msa, shift_msa, epsilon=1e-06
)
else:
norm_hidden_states = self.norm(
hidden_states,
)
norm_hidden_states = norm_hidden_states * (1 + scale_msa[:, None]) + shift_msa[:, None]

q = self.q[i](norm_hidden_states).reshape([0, 0, self.heads_num, self.head_dim])
k = self.k[i](norm_hidden_states).reshape([0, 0, self.heads_num, self.head_dim])
v = self.v[i](norm_hidden_states).reshape([0, 0, self.heads_num, self.head_dim])

norm_hidden_states = F.scaled_dot_product_attention_(q, k, v, scale=self.head_dim**-0.5)
norm_hidden_states = norm_hidden_states.reshape(
[norm_hidden_states.shape[0], norm_hidden_states.shape[1], self.dim]
)
norm_hidden_states = self.out_proj[i](norm_hidden_states)
if os.getenv("INFERENCE_OPTIMIZE_TRITON"):
hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual(
hidden_states, norm_hidden_states, gate_msa, scale_mlp, shift_mlp, epsilon=1e-05
)
else:
hidden_states = hidden_states + norm_hidden_states * gate_msa.reshape(
[norm_hidden_states.shape[0], 1, self.dim]
)
norm_hidden_states = self.norm1(
hidden_states,
)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

norm_hidden_states = self.ffn1[i](norm_hidden_states)
norm_hidden_states = F.gelu(norm_hidden_states, approximate=True)
norm_hidden_states = self.ffn2[i](norm_hidden_states)

hidden_states = hidden_states + norm_hidden_states * gate_mlp.reshape(
[norm_hidden_states.shape[0], 1, self.dim]
)

return hidden_states
115 changes: 81 additions & 34 deletions ppdiffusers/ppdiffusers/models/transformer_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass
from typing import Any, Dict, Optional

Expand All @@ -33,6 +34,7 @@
from .lora import LoRACompatibleConv, LoRACompatibleLinear
from .modeling_utils import ModelMixin
from .normalization import AdaLayerNormSingle
from .simplified_facebook_dit import SimplifiedFacebookDIT


@dataclass
Expand Down Expand Up @@ -114,6 +116,8 @@ def __init__(
self.inner_dim = inner_dim = num_attention_heads * attention_head_dim
self.data_format = data_format

self.inference_optimize = os.getenv("INFERENCE_OPTIMIZE") == "True"

conv_cls = nn.Conv2D if USE_PEFT_BACKEND else LoRACompatibleConv
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear

Expand Down Expand Up @@ -217,6 +221,17 @@ def __init__(
for d in range(num_layers)
]
)
if self.inference_optimize:
self.simplified_facebookdit = SimplifiedFacebookDIT(
num_layers, inner_dim, num_attention_heads, attention_head_dim
)
self.simplified_facebookdit = paddle.incubate.jit.inference(
self.simplified_facebookdit,
enable_new_ir=True,
cache_static_model=False,
exp_enable_use_cutlass=True,
delete_pass_lists=["add_norm_fuse_pass"],
)

# 4. Define output layers
self.out_channels = in_channels if out_channels is None else out_channels
Expand Down Expand Up @@ -392,40 +407,43 @@ def forward(
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.reshape([batch_size, -1, hidden_states.shape[-1]])

for block in self.transformer_blocks:
if self.gradient_checkpointing and not hidden_states.stop_gradient and not use_old_recompute():

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)

return custom_forward

ckpt_kwargs = {} if recompute_use_reentrant() else {"use_reentrant": False}
hidden_states = recompute(
create_custom_forward(block),
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
timestep,
cross_attention_kwargs,
class_labels,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=class_labels,
)
if self.inference_optimize:
hidden_states = self.simplified_facebookdit(hidden_states, timestep, class_labels)
else:
for block in self.transformer_blocks:
if self.gradient_checkpointing and not hidden_states.stop_gradient and not use_old_recompute():

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)

return custom_forward

ckpt_kwargs = {} if recompute_use_reentrant() else {"use_reentrant": False}
hidden_states = recompute(
create_custom_forward(block),
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
timestep,
cross_attention_kwargs,
class_labels,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=class_labels,
)

# 3. Output
if self.is_input_continuous:
Expand Down Expand Up @@ -489,3 +507,32 @@ def custom_forward(*inputs):
return (output,)

return Transformer2DModelOutput(sample=output)

@classmethod
def custom_modify_weight(cls, state_dict):
if os.getenv("INFERENCE_OPTIMIZE") != "True":
return
for i in range(28):
map_from_my_dit = [
(f"q.{i}.weight", f"{i}.attn1.to_q.weight"),
(f"k.{i}.weight", f"{i}.attn1.to_k.weight"),
(f"v.{i}.weight", f"{i}.attn1.to_v.weight"),
(f"q.{i}.bias", f"{i}.attn1.to_q.bias"),
(f"k.{i}.bias", f"{i}.attn1.to_k.bias"),
(f"v.{i}.bias", f"{i}.attn1.to_v.bias"),
(f"out_proj.{i}.weight", f"{i}.attn1.to_out.0.weight"),
(f"out_proj.{i}.bias", f"{i}.attn1.to_out.0.bias"),
(f"ffn1.{i}.weight", f"{i}.ff.net.0.proj.weight"),
(f"ffn1.{i}.bias", f"{i}.ff.net.0.proj.bias"),
(f"ffn2.{i}.weight", f"{i}.ff.net.2.weight"),
(f"ffn2.{i}.bias", f"{i}.ff.net.2.bias"),
(f"fcs0.{i}.weight", f"{i}.norm1.emb.timestep_embedder.linear_1.weight"),
(f"fcs0.{i}.bias", f"{i}.norm1.emb.timestep_embedder.linear_1.bias"),
(f"fcs1.{i}.weight", f"{i}.norm1.emb.timestep_embedder.linear_2.weight"),
(f"fcs1.{i}.bias", f"{i}.norm1.emb.timestep_embedder.linear_2.bias"),
(f"fcs2.{i}.weight", f"{i}.norm1.linear.weight"),
(f"fcs2.{i}.bias", f"{i}.norm1.linear.bias"),
(f"embs.{i}.weight", f"{i}.norm1.emb.class_embedder.embedding_table.weight"),
]
for to_, from_ in map_from_my_dit:
state_dict["simplified_facebookdit." + to_] = paddle.assign(state_dict["transformer_blocks." + from_])