Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Llama Vision PEFT #1937

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ checkpointer:
output_dir: /tmp/Llama-3.2-11B-Vision-Instruct/
model_type: LLAMA3_VISION
resume_from_checkpoint: False
save_adapter_weights_only: False # PeFT formatting not available yet. This will save it in torchtune format only.
save_adapter_weights_only: False

# Dataset
dataset:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ checkpointer:
output_dir: /tmp/Llama-3.2-11B-Vision-Instruct/
model_type: LLAMA3_VISION
resume_from_checkpoint: False
save_adapter_weights_only: False # PeFT formatting not available yet. This will save it in torchtune format only.
save_adapter_weights_only: False

# Dataset
dataset:
Expand Down
15 changes: 15 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# TODO: (philip) remove after tests

from transformers import AutoModelForCausalLM # , AutoTokenizer

model_id = "meta-llama/Llama-3.2-11B-Vision"
peft_model_id = "/tmp/Llama-3.2-11B-Vision-Instruct/"

model = AutoModelForCausalLM.from_pretrained(model_id)
model.load_adapter(peft_model_id)
102 changes: 102 additions & 0 deletions torchtune/models/llama3_2_vision/_convert_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,3 +427,105 @@ def _permute(t, n_heads):

converted_state_dict[new_key] = value
return converted_state_dict


# Mapping from torchtune LoRA module names to PEFT LoRA module names
_TO_PEFT_KEYS = {
"lora_a": "lora_A",
"lora_b": "lora_B",
"magnitude": "lora_magnitude_vector",
}


def _get_peft_dict(tune_to_hf_dict: Dict[str, str]) -> Dict[str, torch.Tensor]:
"""
Rather than recreate a separate mapping for LoRA adapter weights, we just
re-use the _FROM_HF mapping for base model weights. We iterate over it twice:
once to add mappings for LoRA A matrices and once to add mappings for LoRA B matrices.
"""
new_mapping_dict = {}
for tune_peft, hf_peft in _TO_PEFT_KEYS.items():
for tune_key, hf_key in tune_to_hf_dict.items():
if hf_key is None or tune_key is None:
continue

if tune_peft == "magnitude":
# e.g. attn.q_proj.magnitude -> attn.q_proj.lora_magnitude_vector
tune_adapter = tune_key.replace(".weight", f".{tune_peft}")
hf_adapter = hf_key.replace(".weight", f".{hf_peft}")
else:
# e.g. attn.q_proj.lora_a.weight -> attn.q_proj.lora_A.weight
tune_adapter = tune_key.replace(".weight", f".{tune_peft}.weight")
hf_adapter = hf_key.replace(".weight", f".{hf_peft}.weight")

new_mapping_dict[tune_adapter] = hf_adapter
return new_mapping_dict


def llama3_vision_tune_to_peft_adapter_weights(
state_dict: Dict[str, torch.Tensor],
num_heads: int = 32,
num_kv_heads: int = 32,
dim: int = 4096,
head_dim: int = None,
cross_attention_layers: Optional[List[int]] = None,
) -> Dict[str, torch.Tensor]:
"""
Convertor from Tune state dict to HF state dict. This handles:
- Updateing the cross attention layer numbers
- skip loading the rope embeddings
- reshaping q, k projections
"""
converted_state_dict = {}
inverted_mapping_dict = {v: k for k, v in _FROM_HF.items()}
# missing keys in _FROM_HF due to naming collisions
missing_keys = {
"decoder.layers.{}.fusion_layer.ca_norm.scale": "language_model.model.layers.{}.input_layernorm.weight",
"decoder.layers.{}.fusion_layer.mlp_norm.scale": "language_model.model.layers.{}.post_attention_layernorm.weight",
"decoder.layers.{}.fusion_layer.mlp.w1.weight": "language_model.model.layers.{}.mlp.gate_proj.weight",
"decoder.layers.{}.fusion_layer.mlp.w3.weight": "language_model.model.layers.{}.mlp.up_proj.weight",
"decoder.layers.{}.fusion_layer.mlp.w2.weight": "language_model.model.layers.{}.mlp.down_proj.weight",
"decoder.tok_embeddings.fusion_embedding.weight": None,
}
inverted_mapping_dict.update(missing_keys)
inverted_mapping_dict = _get_peft_dict(inverted_mapping_dict)

if head_dim is None:
head_dim = dim // num_heads
if cross_attention_layers is None:
cross_attention_layers = []
# convert hf layer numbers to tune numbers
cross_attention_layers = [
l - i for i, l in enumerate(sorted(cross_attention_layers))
]

def _permute_lora_matrix(t, n_heads):
rank = t.shape[-1]
return (
t.view(n_heads, head_dim // 2, 2, rank)
.transpose(1, 2)
.reshape((head_dim * n_heads), rank)
)

for key, value in state_dict.items():
# if key == "decoder.layers.3.layer.attn.q_proj.lora_a.weight":
# import pdb; pdb.set_trace()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove

new_key = get_mapped_key(key, inverted_mapping_dict)
if "decoder" in key:
if "layers" in key: # Update layer numbers
layer = int(key.split(".")[2])
num_shifts = sum(layer > l for l in cross_attention_layers)
new_layer = layer + num_shifts
key_lst = new_key.split(".")
if layer in cross_attention_layers and "fusion_layer" not in key:
new_layer += 1 # hf treats the fusion_layer as an additional layer
key_lst[3] = str(new_layer)
new_key = ".".join(key_lst)
if "q_proj" in key and "lora_B" in new_key and "cross_attn" not in new_key:
value = _permute_lora_matrix(value, num_heads)
elif (
"k_proj" in key and "lora_B" in new_key and "cross_attn" not in new_key
):
value = _permute_lora_matrix(value, num_kv_heads)
converted_state_dict["base_model.model." + new_key] = value
return converted_state_dict
76 changes: 43 additions & 33 deletions torchtune/training/checkpointing/_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,7 @@ def save_checkpoint(
)

if training.ADAPTER_KEY in state_dict:
# import pdb; pdb.set_trace()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove

# Save torchtune format adapter weights even if we save PEFT format
# This way we can resume no matter what (and memory footprint of adapter weights is small)
output_path = Path.joinpath(
Expand All @@ -651,20 +652,38 @@ def save_checkpoint(
logger.warning(
"Saving Phi-3 Mini adapter weights to PEFT format is not supported, saving to torchtune format instead"
)
elif self._model_type == ModelType.LLAMA3_VISION:
elif self._model_type == ModelType.QWEN2:
logger.warning(
"Saving Llama3.2 Vision adapter weights to PEFT format is not supported, saving to torchtune format instead"
"Saving QWEN2 adapter weights to PEFT format is not supported, saving to torchtune format instead"
)
else:
state_dict[
training.ADAPTER_KEY
] = convert_weights.tune_to_peft_adapter_weights(
state_dict[training.ADAPTER_KEY],
num_heads=self._config["num_attention_heads"],
num_kv_heads=self._config["num_key_value_heads"],
dim=self._config["hidden_size"],
head_dim=self._config.get("head_dim", None),
)
if self._model_type == ModelType.LLAMA3_VISION:
from torchtune.models.llama3_2_vision._convert_weights import (
llama3_vision_tune_to_peft_adapter_weights,
)

state_dict[
training.ADAPTER_KEY
] = llama3_vision_tune_to_peft_adapter_weights(
state_dict[training.ADAPTER_KEY],
num_heads=text_config["num_attention_heads"],
num_kv_heads=text_config["num_key_value_heads"],
dim=text_config["hidden_size"],
head_dim=text_config.get("head_dim", None),
cross_attention_layers=text_config.get(
"cross_attention_layers", None
),
)
else:
state_dict[
training.ADAPTER_KEY
] = convert_weights.tune_to_peft_adapter_weights(
state_dict[training.ADAPTER_KEY],
num_heads=self._config["num_attention_heads"],
num_kv_heads=self._config["num_key_value_heads"],
dim=self._config["hidden_size"],
head_dim=self._config.get("head_dim", None),
)
peft_output_path = Path.joinpath(
self._output_dir, "adapter_model"
).with_suffix(".bin")
Expand All @@ -680,28 +699,19 @@ def save_checkpoint(
)

if training.ADAPTER_CONFIG in state_dict:
if self._model_type == ModelType.PHI3_MINI:
logger.warning(
"PEFT integration for Phi-3 Mini is not supported, skipping adapter config save"
)
elif self._model_type == ModelType.LLAMA3_VISION:
logger.warning(
"PEFT integration for Llama3.2 Vision is not supported, skipping adapter config save"
)
else:
state_dict[
training.ADAPTER_CONFIG
] = convert_weights.tune_to_peft_adapter_config(
state_dict[training.ADAPTER_CONFIG]
)
output_path = Path.joinpath(self._output_dir, "adapter_config.json")
with open(output_path, "w") as f:
json.dump(state_dict[training.ADAPTER_CONFIG], f)
logger.info(
"Adapter checkpoint of size "
f"{os.path.getsize(output_path) / 1000**3:.2f} GB "
f"saved to {output_path}"
)
state_dict[
training.ADAPTER_CONFIG
] = convert_weights.tune_to_peft_adapter_config(
state_dict[training.ADAPTER_CONFIG]
)
output_path = Path.joinpath(self._output_dir, "adapter_config.json")
with open(output_path, "w") as f:
json.dump(state_dict[training.ADAPTER_CONFIG], f)
logger.info(
"Adapter checkpoint of size "
f"{os.path.getsize(output_path) / 1000**3:.2f} GB "
f"saved to {output_path}"
)

# If the recipe state needs to be output, first remove the model state dict
# and if it exists, remove the adapter state dict as well
Expand Down
Loading