add validation for num_moe
SangbumChoi committed Nov 18, 2024
1 parent 5207b57 commit 96b4da9
Showing 3 changed files with 40 additions and 34 deletions.
55 changes: 25 additions & 30 deletions src/transformers/models/vitpose/convert_vitpose_to_hf.py
@@ -191,44 +191,39 @@ def write_model(model_path, model_name, push_to_hub):
        new_key = new_keys[key]
        value = original_state_dict[key]

        if re.search("qkv", new_key):
            if "weight" in new_key:
                state_dict[new_key.replace("self.qkv", "attention.query")] = value[:dim, :]
                state_dict[new_key.replace("self.qkv", "attention.key")] = value[dim : dim * 2, :]
                state_dict[new_key.replace("self.qkv", "attention.value")] = value[-dim:, :]
            else:
                state_dict[new_key.replace("self.qkv", "attention.query")] = value[:dim]
                state_dict[new_key.replace("self.qkv", "attention.key")] = value[dim : dim * 2]
                state_dict[new_key.replace("self.qkv", "attention.value")] = value[-dim:]
        elif re.search("head", new_key) and not config.use_simple_decoder:
            # Pattern for deconvolution layers
            print(new_key)
            deconv_pattern = r"deconv_layers\.(0|3)\.weight"
            new_key = re.sub(deconv_pattern, lambda m: f"deconv{int(m.group(1))//3 + 1}.weight", new_key)
            # Pattern for batch normalization layers
            bn_patterns = [
                (r"deconv_layers\.(\d+)\.weight", r"batchnorm\1.weight"),
                (r"deconv_layers\.(\d+)\.bias", r"batchnorm\1.bias"),
                (r"deconv_layers\.(\d+)\.running_mean", r"batchnorm\1.running_mean"),
                (r"deconv_layers\.(\d+)\.running_var", r"batchnorm\1.running_var"),
                (r"deconv_layers\.(\d+)\.num_batches_tracked", r"batchnorm\1.num_batches_tracked"),
            ]

            for pattern, replacement in bn_patterns:
                if re.search(pattern, new_key):
                    # Convert the layer number to the correct batch norm index
                    layer_num = int(re.search(pattern, key).group(1))
                    bn_num = layer_num // 3 + 1
                    new_key = re.sub(pattern, replacement.replace(r"\1", str(bn_num)), new_key)
            state_dict[new_key] = value
        else:
            state_dict[new_key] = value

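The qkv branch above assumes the fused projection stores the query, key and value weights stacked along dimension 0, which is why plain row slicing recovers them. A minimal standalone sketch of that split; the tensor sizes here are illustrative, not taken from a real checkpoint:

import torch

# Illustrative size; in the converter, dim presumably comes from the
# backbone hidden size in the model config.
dim = 4
fused_qkv = torch.randn(3 * dim, dim)  # rows stacked as [query | key | value]

query_weight = fused_qkv[:dim, :]
key_weight = fused_qkv[dim : dim * 2, :]
value_weight = fused_qkv[-dim:, :]

# The three slices tile the fused matrix exactly.
assert torch.equal(torch.cat([query_weight, key_weight, value_weight]), fused_qkv)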
print("Loading the checkpoint in a Vitpose model.")
model = VitPoseForPoseEstimation(config)
model.eval()
-    model.load_state_dict(state_dict, strict=False)
+    model.load_state_dict(state_dict)
print("Checkpoint loaded successfully.")

# create image processor
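For checkpoints using the classic (non-simple) decoder, the head branch earlier in this hunk renames deconv_layers.N.* entries onto numbered deconv{1,2} and batchnorm{1,2} modules; the // 3 + 1 arithmetic works because each (deconv, batchnorm, ReLU) triple presumably occupies three consecutive slots in the original nn.Sequential. A self-contained sketch of that remapping, with hypothetical key names:

import re

# Suffixes carried by BatchNorm2d entries in the original state dict.
BN_SUFFIXES = ["weight", "bias", "running_mean", "running_var", "num_batches_tracked"]

def remap_head_key(key: str) -> str:
    # ConvTranspose2d weights sit at Sequential indices 0 and 3 -> deconv1 / deconv2.
    key = re.sub(r"deconv_layers\.(0|3)\.weight", lambda m: f"deconv{int(m.group(1)) // 3 + 1}.weight", key)
    # BatchNorm2d entries sit at indices 1 and 4 -> batchnorm1 / batchnorm2.
    for suffix in BN_SUFFIXES:
        pattern = rf"deconv_layers\.(\d+)\.{suffix}"
        match = re.search(pattern, key)
        if match:
            key = re.sub(pattern, f"batchnorm{int(match.group(1)) // 3 + 1}.{suffix}", key)
    return key

print(remap_head_key("head.deconv_layers.0.weight"))        # head.deconv1.weight
print(remap_head_key("head.deconv_layers.4.running_mean"))  # head.batchnorm2.running_mean

Note that the switch to a plain load_state_dict call (dropping strict=False) means any key these rules fail to remap now raises at load time instead of being silently ignored.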

17 changes: 14 additions & 3 deletions src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py
@@ -271,7 +271,7 @@ def __init__(self, config: VitPoseBackboneConfig) -> None:
        self.activation = ACT2FN[config.hidden_act]
        self.fc2 = nn.Linear(hidden_features, out_features, bias=True)

-    def forward(self, hidden_state: torch.Tensor, indices=None) -> torch.Tensor:
+    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        hidden_state = self.fc1(hidden_state)
        hidden_state = self.activation(hidden_state)
        hidden_state = self.fc2(hidden_state)
@@ -281,8 +281,9 @@ def forward(self, hidden_state: torch.Tensor, indices=None) -> torch.Tensor:
class VitPoseBackboneLayer(nn.Module):
    def __init__(self, config: VitPoseBackboneConfig) -> None:
        super().__init__()
+        self.num_experts = config.num_experts
        self.attention = VitPoseBackboneAttention(config)
-        self.mlp = VitPoseBackboneMLP(config) if config.num_experts == 1 else VitPoseBackboneMoeMLP(config)
+        self.mlp = VitPoseBackboneMLP(config) if self.num_experts == 1 else VitPoseBackboneMoeMLP(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

@@ -293,6 +294,13 @@ def forward(
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+        # Validate dataset_index when using multiple experts
+        if self.num_experts > 1 and dataset_index is None:
+            raise ValueError(
+                "dataset_index must be provided when using multiple experts "
+                f"(num_experts={self.num_experts}). Please provide dataset_index "
+                "to the forward pass."
+            )
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # in VitPoseBackbone, layernorm is applied before self-attention
            head_mask,
@@ -305,7 +313,10 @@
        hidden_states = attention_output + hidden_states

        layer_output = self.layernorm_after(hidden_states)
-        layer_output = self.mlp(layer_output, indices=dataset_index)
+        if self.num_experts == 1:
+            layer_output = self.mlp(layer_output)
+        else:
+            layer_output = self.mlp(layer_output, indices=dataset_index)

        # second residual connection
        layer_output = layer_output + hidden_states
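Because the layer now records num_experts, the guard above fails fast with a clear message when the MoE path is active but no dataset_index is supplied, instead of crashing deep inside the expert MLP. A hypothetical usage sketch; the import paths, the extra config arguments, and the per-sample meaning of dataset_index are assumptions, not taken from this diff:

import torch

from transformers.models.vitpose_backbone.configuration_vitpose_backbone import VitPoseBackboneConfig
from transformers.models.vitpose_backbone.modeling_vitpose_backbone import VitPoseBackboneLayer

# Assumed config: more than one expert activates the MoE MLP branch.
config = VitPoseBackboneConfig(num_experts=4)
layer = VitPoseBackboneLayer(config)
hidden_states = torch.randn(2, 16, config.hidden_size)

try:
    layer(hidden_states)  # no dataset_index -> the new ValueError fires
except ValueError as err:
    print(err)

# Assumed semantics: one expert index per sample in the batch.
dataset_index = torch.tensor([0, 2])
layer(hidden_states, dataset_index=dataset_index)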

2 changes: 1 addition & 1 deletion tests/models/vitpose_backbone/test_modeling_vitpose_backbone.py
@@ -31,7 +31,7 @@


if is_vision_available():
-    from PIL import Image
+    pass


class VitPoseBackboneModelTester: