
Re-network the DIT, fix some parameters, and simplify the model networking code #632

Merged — 28 commits, Aug 28, 2024

Conversation

chang-wenbin (Contributor) commented Jul 29, 2024

Latest optimization: re-network the DIT, simplifying the original dynamic-graph model into a high-performance model network:

  1. For the most time-consuming core, the transformer part, use paddle.incubate.jit.inference for dynamic-to-static conversion, and remove redundant computation from the loop;
  2. Use some Triton operators for manual operator fusion;
  3. Use horizontal fusion operators to merge parallel operators into a single computation;
  4. Use the CUTLASS library for optimization and acceleration.

Currently, facebook-DIT takes 219.936 ms.
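The "removes redundant parts in the loop" in item 1 refers to hoisting loop-invariant work out of the per-block loop. A minimal sketch of the idea in plain Python, using a hypothetical timestep-embedding stand-in rather than the PR's actual code:

```python
import math

def per_block_embeddings_naive(t, num_layers, dim=8):
    # Original dynamic-graph style: every transformer block recomputes the
    # same sinusoidal timestep embedding inside the loop.
    outs = []
    for _ in range(num_layers):
        emb = [math.sin(t / 10000 ** (i / dim)) for i in range(dim)]
        outs.append(emb)
    return outs

def per_block_embeddings_hoisted(t, num_layers, dim=8):
    # Re-networked style: the loop-invariant embedding is computed once and
    # reused by every block, removing the redundant work.
    emb = [math.sin(t / 10000 ** (i / dim)) for i in range(dim)]
    return [emb] * num_layers
```

Both functions return identical results; the second performs the embedding computation once instead of num_layers times.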

paddle-bot bot commented Jul 29, 2024

Thanks for your contribution!

CLAassistant commented Jul 29, 2024

CLA assistant check
All committers have signed the CLA.

@chang-wenbin changed the title from "modified the dit" to "对DIT重新组网,固定部分参数,简化模型组网代码" on Jul 29, 2024
@chang-wenbin changed the title to "Re-network the DIT, fix some parameters, and simplify the model networking code" on Jul 29, 2024
if qkv is not None:
    state_dict[qkv_key_b] = paddle.concat([qkv, state_dict.pop(key)], axis=-1)

for key in list(state_dict.keys()):
Contributor commented:

Change the code below line 518 to:

        map_from_my_dit = {}
        for i in range(28):
            map_from_my_dit[f'tmp_ZKKFacebookDIT.qkv.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_qkv.weight'
            map_from_my_dit[f'tmp_ZKKFacebookDIT.qkv.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_qkv.bias'
            map_from_my_dit[f'tmp_ZKKFacebookDIT.out_proj.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_out.0.weight'
            map_from_my_dit[f'tmp_ZKKFacebookDIT.out_proj.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_out.0.bias'
            map_from_my_dit[f'tmp_ZKKFacebookDIT.ffn1.{i}.weight'] = f'transformer_blocks.{i}.ff.net.0.proj.weight'
            map_from_my_dit[f'tmp_ZKKFacebookDIT.ffn1.{i}.bias'] = f'transformer_blocks.{i}.ff.net.0.proj.bias'
            map_from_my_dit[f'tmp_ZKKFacebookDIT.ffn2.{i}.weight'] = f'transformer_blocks.{i}.ff.net.2.weight'
            map_from_my_dit[f'tmp_ZKKFacebookDIT.ffn2.{i}.bias'] = f'transformer_blocks.{i}.ff.net.2.bias'

            map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs0.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.weight'
            map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs0.{i}.bias'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.bias'

            map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs1.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.weight'
            map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs1.{i}.bias'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.bias'

            map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs2.{i}.weight'] = f'transformer_blocks.{i}.norm1.linear.weight'
            map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs2.{i}.bias'] = f'transformer_blocks.{i}.norm1.linear.bias'

            map_from_my_dit[f'tmp_ZKKFacebookDIT.embs.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.class_embedder.embedding_table.weight'

        for key in map_from_my_dit.keys():
            state_dict[key] = paddle.assign(state_dict[map_from_my_dit[key]])

Contributor Author replied:

Changed! Thanks for the review suggestion.

def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attention_head_dim: int):
    super().__init__()
    self.num_layers = num_layers
    self.dtype = "float16"
Contributor commented:

Make self.dtype = "float16" configurable.

Contributor Author replied:

Changed! Thanks for the review suggestion.

@@ -1130,6 +1134,8 @@ def _find_mismatched_keys(
error_msgs.append(
f"Error size mismatch, {key_name} receives a shape {loaded_shape}, but the expected shape is {model_shape}."
)
if os.getenv('Inference_Optimize'):
Contributor commented:

Remove this check here and do it in transformer_2d.py instead.

Contributor Author replied:

Changed! Thanks for the review suggestion.

@@ -28,11 +28,15 @@
recompute_use_reentrant,
use_old_recompute,
)
from .simplified_facebook_dit import Simplified_FacebookDIT
Contributor commented:

Rename Simplified_FacebookDIT to SimplifiedFacebookDIT.

Contributor Author replied:

Changed! Thanks for the review suggestion.

@@ -213,6 +219,8 @@ def __init__(
for d in range(num_layers)
]
)
if self.Inference_Optimize:
    self.simplified_facebookDIT = SimplifiedFacebookDIT(num_layers, inner_dim, num_attention_heads, attention_head_dim)
Contributor commented:

Add del self.transformer_blocks here.

Contributor Author replied:

Making this change raises errors, because the attribute is still called elsewhere; leaving it unchanged for now. Thanks for the review suggestion.

@@ -114,6 +118,8 @@ def __init__(
self.inner_dim = inner_dim = num_attention_heads * attention_head_dim
self.data_format = data_format

self.Inference_Optimize = bool(os.getenv('Inference_Optimize'))
Contributor commented:

self.Inference_Optimize = os.getenv('Inference_Optimize') == "True"
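The suggested == "True" comparison fixes a real bug: bool(os.getenv(...)) is True for any non-empty string, including the literal string "False". A slightly more tolerant variant, with a hypothetical helper name that is not part of the PR:

```python
import os

def env_flag(name, default=False):
    # bool(os.getenv(name)) is wrong: any non-empty value, even "False" or
    # "0", is truthy. Compare against known truthy spellings instead.
    val = os.getenv(name)
    if val is None:
        return default
    return val.strip().lower() in ("1", "true", "yes", "on")

os.environ["INFERENCE_OPTIMIZE"] = "False"
print(env_flag("INFERENCE_OPTIMIZE"))  # prints False
os.environ["INFERENCE_OPTIMIZE"] = "True"
print(env_flag("INFERENCE_OPTIMIZE"))  # prints True
```

The same helper also gives a sensible default when the variable is unset, which the bare string comparison does for free but less explicitly.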

Contributor Author replied:

Changed! Thanks for the review suggestion.

@vivienfanghuagood left a comment:

The code should be run through a formatter.

return
map_from_my_dit = {}
for i in range(28):
map_from_my_dit[f'simplified_facebookDIT.q.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_q.weight'


Minimize code duplication; for example, factor out the common name prefix so that a later rename needs only one change.
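The duplication in the key-mapping loop can be reduced as the reviewer suggests by factoring out the shared prefixes and the per-sublayer name pairs. A sketch — the prefixes and sub-layer names follow the snippets above, but the helper itself is hypothetical:

```python
def build_key_map(num_layers=28):
    # Shared prefixes, written once so a future rename is a one-line change.
    src = "simplified_facebookDIT"   # re-networked model's parameter prefix
    dst = "transformer_blocks"       # original dynamic-graph prefix
    # new sub-layer name -> original sub-layer name
    pairs = {
        "qkv": "attn1.to_qkv",
        "out_proj": "attn1.to_out.0",
        "ffn1": "ff.net.0.proj",
        "ffn2": "ff.net.2",
    }
    mapping = {}
    for i in range(num_layers):
        for new, old in pairs.items():
            for param in ("weight", "bias"):
                mapping[f"{src}.{new}.{i}.{param}"] = f"{dst}.{i}.{old}.{param}"
    return mapping

key_map = build_key_map()
print(len(key_map))  # 28 layers x 4 sub-layers x 2 params = 224 entries
```

Adding the norm1 and class-embedder entries is then a matter of extending the pairs dict rather than editing many f-strings.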

Contributor Author replied:

> Minimize code duplication; for example, factor out the common name prefix so that a later rename needs only one change.

Changed; folded part of the naming code. Thanks for the review suggestion.

from ppdiffusers import DDIMScheduler, DiTPipeline

dtype = paddle.float32
os.environ["Inference_Optimize"] = "False"
Collaborator commented:

Make the environment variables all uppercase.

Contributor Author replied:

Changed! Thanks for the review suggestion.

pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", paddle_dtype=dtype)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
set_seed(42)

words = ["golden retriever"] # class_ids [207]
class_ids = pipe.get_label_ids(words)

# warmup
for i in range(5):
    image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
Collaborator commented:

This is only here for benchmarking; real users do not need warmup. Consider adding a benchmark switch.
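One way to implement the requested switch is to gate warmup and repeated timing behind a benchmark flag, so a normal user run executes the pipeline exactly once. A minimal sketch; the run_pipeline stand-in and the flag name are hypothetical, not the PR's final code:

```python
import time

def run_pipeline():
    # Stand-in for pipe(class_labels=class_ids, num_inference_steps=25);
    # replaced with a short sleep so the sketch is self-contained.
    time.sleep(0.001)

def generate(benchmark=False, warmup=5, repeat_times=10):
    if not benchmark:
        run_pipeline()          # normal user path: one run, no warmup
        return None
    for _ in range(warmup):     # benchmark path: warm up first
        run_pipeline()
    start = time.perf_counter()
    for _ in range(repeat_times):
        run_pipeline()
    return (time.perf_counter() - start) / repeat_times * 1000  # ms per run
```

In the real script the flag could come from a command-line argument or an environment variable, matching the other switches added in this PR.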

Contributor Author replied:

Changed; added switches for benchmark and inference_optimize. Thanks for the review suggestion.



import datetime
import time
Collaborator commented:

Move the imports to the top of the file.

Contributor Author replied:

Changed! Thanks for the review suggestion.


image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
for i in range(repeat_times):
    image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
Collaborator commented:

Same as above: only needed for benchmarking, not for normal use.

Contributor Author replied:

Changed! Thanks for the review suggestion.

enable_new_ir=True,
cache_static_model=False,
exp_enable_use_cutlass=True,
delete_pass_lists=["add_norm_fuse_pass"],
Collaborator commented:

Follow the code style guide; a line should not exceed 80 characters.

Contributor Author replied:

Adjusted with pre-commit. Thanks for the review suggestion.

@@ -114,6 +118,8 @@ def __init__(
self.inner_dim = inner_dim = num_attention_heads * attention_head_dim
self.data_format = data_format

self.Inference_Optimize = os.getenv('Inference_Optimize') == "True"
Collaborator commented:

Use self.inference_optimize, following the naming convention.

Contributor Author replied:

Changed! Thanks for the review suggestion.

import paddle.nn.functional as F
import math

class SimplifiedFacebookDIT(nn.Layer):
Collaborator commented:

Is it really necessary to simplify this module?

@zhoutianzi666 (Contributor) commented Aug 7, 2024:

> Is it really necessary to simplify this module?

Manual optimization requires it.

Contributor Author replied:

Manual optimization requires a high-performance, streamlined re-networking of the original dynamic-graph model. This module also hoists the redundant computation out of the transformer loop, reducing the overall amount of work. Thanks for the review suggestion.

@@ -221,7 +240,9 @@ def __init__(
if use_linear_projection:
self.proj_out = linear_cls(inner_dim, in_channels)
else:
self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0, data_format=data_format)
self.proj_out = conv_cls(
Contributor commented:

Formatting-only change; please ignore.

Contributor Author replied:

> Formatting-only change; please ignore.

Applied a uniform format with pre-commit. Thanks for the review suggestion.

@@ -154,11 +158,15 @@ def __init__(
if self.is_input_continuous:
self.in_channels = in_channels

self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, epsilon=1e-6, data_format=data_format)
self.norm = nn.GroupNorm(
Contributor commented:

Formatting-only change; please ignore.

if use_linear_projection:
self.proj_in = linear_cls(in_channels, inner_dim)
else:
self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0, data_format=data_format)
self.proj_in = conv_cls(
Contributor commented:

Formatting-only change; please ignore.

@nemonameless nemonameless merged commit aeee830 into PaddlePaddle:develop Aug 28, 2024
3 checks passed