It's very unlikely that the task type is the issue here; there is probably something else going on. Would it be possible for you to share the code? If not, could you share the output of calling `model.print_trainable_parameters()`, `model.get_model_status()`, and `model.get_layer_status()`?
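As a rough cross-check of what `print_trainable_parameters()` reports, you can count the parameters with `requires_grad=True` yourself. Below is a minimal sketch in plain PyTorch. The toy model and the helper name `count_parameters` are hypothetical, and PEFT's own count may differ slightly in special cases such as quantized weights.

```python
import torch
from torch import nn


def count_parameters(model: nn.Module) -> tuple[int, int, float]:
    """Return (trainable, total, trainable percentage) for a module."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    pct = 100.0 * trainable / total if total else 0.0
    return trainable, total, pct


# Toy stand-in: a frozen "base" layer followed by a small trainable layer.
toy = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 2))
for p in toy[0].parameters():
    p.requires_grad_(False)

trainable, total, pct = count_parameters(toy)
print(f"trainable params: {trainable:,} || all params: {total:,} || trainable%: {pct:.4f}")
```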
-
Hi, sorry for the delay! Here are the requested outputs:
Do you have any hints? Thank you!
-
From these outputs I don't see anything wrong, although they're a bit garbled. Would it be possible for you to share the training code?
Also, are the losses 100% identical, or are they just moving very slowly?
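To answer the "identical or just slow" question, you can take a snapshot of the trainable parameters, run a few optimizer steps, and measure how far each parameter moved. If a parameter never moves, the cause is more likely a wiring problem than a learning-rate problem: it is missing from the optimizer, or it receives no gradient. Below is a minimal sketch, assuming a standard PyTorch loop. The toy model and the helper names `snapshot_trainable` and `max_change` are hypothetical.

```python
import torch
from torch import nn


def snapshot_trainable(model: nn.Module) -> dict[str, torch.Tensor]:
    """Copy every trainable parameter so later values can be compared."""
    return {n: p.detach().clone() for n, p in model.named_parameters() if p.requires_grad}


def max_change(model: nn.Module, before: dict[str, torch.Tensor]) -> dict[str, float]:
    """Largest absolute change of each trainable parameter since the snapshot."""
    params = dict(model.named_parameters())
    return {n: (params[n].detach() - old).abs().max().item() for n, old in before.items()}


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
model[0].requires_grad_(False)  # freeze the first layer, as PEFT freezes base weights

optimizer = torch.optim.AdamW((p for p in model.parameters() if p.requires_grad), lr=1e-3)
x, y = torch.randn(32, 8), torch.randn(32, 1)

before = snapshot_trainable(model)
for _ in range(5):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()

for name, delta in max_change(model, before).items():
    print(f"{name}: max |change| = {delta:.2e}")
```

With a real PEFT model, you would expect the entries to be the LoRA weights. Values that stay at exactly zero would suggest those weights are never updated.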
-
Hi, sorry, I can't share the whole code, but here is a snippet of what I do with the git-base-vatex model from Hugging Face. The losses are not 100% identical, but they barely change. For comparison, when I fine-tune the base model, the loss decreases from ~11 to 0.1 within 5-10 epochs, depending on the learning rate. The behaviour is essentially the same for any learning rate in the range 10^-2 to 10^-5 (!).

```python
# Imports used by this method; LoRaException is a custom exception defined elsewhere (not shown)
import logging
import re

import torch
from peft import LoraConfig, get_peft_model


def lora(self, lora_config):
    # Apply low-rank decomposition using PEFT - documentation about module selection and parameters:
    # https://stackoverflow.com/questions/76768226/target-modules-for-applying-peft-lora-on-different-models
    # https://medium.com/@manyi.yim/more-about-loraconfig-from-peft-581cf54643db
    # Supported modules: `torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `transformers.pytorch_utils.Conv1D`
    try:
        # Print info about model parameters
        self._info()
        # Predefined layer types and regular expression
        predefined_layer_types = ['Linear', 'Conv2d']  # Add more types as needed
        regex_patterns = 'self_attn|mlp'
        # List to store matching layer names
        matching_layer_names = []
        # Iterate through the named modules
        for name, module in self.model.git.named_modules():
            if any(isinstance(module, getattr(torch.nn, layer_type)) for layer_type in predefined_layer_types) and re.search(regex_patterns, name):
                matching_layer_names.append(name)
        logging.info(f"{matching_layer_names}")

        def get_num_parameters(layer_names):
            total_parameters = 0
            for name, param in self.model.git.named_parameters():
                if any(layer_name in name for layer_name in layer_names):
                    total_parameters += param.numel()
            return total_parameters

        total_parameters = get_num_parameters(matching_layer_names)
        logging.info(f"total parameters: {total_parameters}")
        self.selected_layers = matching_layer_names
    except Exception as err:
        logging.error(f"selecting layers for LoRA failed: {err}")
        # sys.exit(1)
    try:
        peft_config = LoraConfig(task_type="CAUSAL_LM",
                                 target_modules=self.selected_layers,
                                 inference_mode=False,
                                 r=lora_config["rank"],
                                 lora_alpha=lora_config["alpha"],
                                 lora_dropout=lora_config["dropout"],
                                 bias="none")
        # modules_to_save=[""]  #? ensures that these modules are serialized alongside the LoRA trainable parameters when using utilities like save_pretrained()
        self.peft_model = get_peft_model(self.model, peft_config)
        # print_trainable_parameters() prints its summary and returns None, so log the counts instead
        trainable, total = self.peft_model.get_nb_trainable_parameters()
        logging.info(f"parameters after LoRA: {trainable} trainable / {total} total")
        return self.peft_model
    except Exception as err:
        logging.error(f"LoRA decomposition failed: {err}")
        raise LoRaException("LoRA decomposition failed") from err
```

And the config is:

Thank you in advance for your ideas!
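Check which parts of the model the regex actually selects. Suppose the module names follow the layout of the `transformers` GIT implementation. In that layout, vision-encoder layers live under `image_encoder...` and use `self_attn`/`mlp`, while text layers are named like `encoder.layer.N.attention.self.query`. If that is the case, `'self_attn|mlp'` would select only the image encoder.

Below is a minimal stdlib sketch that groups the regex matches by top-level submodule. The module names are hypothetical examples modeled on that layout, and the helper `matches_by_submodule` is also hypothetical. Verify against `named_modules()` for your own checkpoint.

```python
import re
from collections import Counter

# Hypothetical leaf-layer names modeled on a GIT-style layout; in practice use
# [name for name, _ in model.git.named_modules()] from the real model.
module_names = [
    "image_encoder.vision_model.encoder.layers.0.self_attn.q_proj",
    "image_encoder.vision_model.encoder.layers.0.mlp.fc1",
    "encoder.layer.0.attention.self.query",
    "encoder.layer.0.intermediate.dense",
    "visual_projection.visual_projection.0",
]

pattern = re.compile(r"self_attn|mlp")  # same pattern as `regex_patterns` in the snippet


def matches_by_submodule(names, pattern, depth=1):
    """Count regex matches grouped by the first `depth` dot-separated components."""
    counts = Counter()
    for name in names:
        if pattern.search(name):
            counts[".".join(name.split(".")[:depth])] += 1
    return counts


print(matches_by_submodule(module_names, pattern))
```

Suppose every match lands under a single submodule when you run this against the real names. In that case, the LoRA adapters only cover that part of the network.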
-
Thanks for the additional context. I did a quick check to see whether the model is trainable at all, and it appears that it is. Here is a reproducer that closely follows your code for model initialization and then trains on a dummy objective:

```python
import re

import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
from peft import LoraConfig, get_peft_model

processor = AutoProcessor.from_pretrained("microsoft/git-base-vatex")
model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-vatex")

predefined_layer_types = ['Linear', 'Conv2d']  # Add more types as needed
regex_patterns = 'self_attn|mlp'
# List to store matching layer names
matching_layer_names = []
# Iterate through the named modules
for name, module in model.git.named_modules():
    if any(isinstance(module, getattr(torch.nn, layer_type)) for layer_type in predefined_layer_types) and re.search(regex_patterns, name):
        matching_layer_names.append(name)

lora_config = {"rank": 8, "alpha": 16, "dropout": 0.05}
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=matching_layer_names,
    inference_mode=False,
    r=lora_config["rank"],
    lora_alpha=lora_config["alpha"],
    lora_dropout=lora_config["dropout"],
    bias="none",
)
peft_model = get_peft_model(model, peft_config)

# dummy task to train on
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
text = "this is an image of two cats"
inputs = processor(text, images=image, return_tensors="pt")

if torch.cuda.is_available():  # must be called; the bare function object is always truthy
    peft_model.to(0)
    inputs = inputs.to(0)

# get_peft_model injects the LoRA layers into `model` in place, so `model` and
# `peft_model` share the same trainable LoRA weights here.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)
for i in range(30):
    optimizer.zero_grad()
    outputs = model(**inputs)
    loss = (outputs.logits ** 2).sum()  # dummy objective: push all logits towards zero
    loss.backward()
    optimizer.step()
    print(i, loss.item())
```

For me, the loss starts at 11,141,211 and decreases monotonically to 2,626,418 after 30 steps. Could you check whether you can reproduce this? If so, the model itself is not fundamentally broken, and the problem is more likely in the training process. If you can share the training code, I can try to help there.
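One training-loop issue produces a loss that barely moves regardless of learning rate: an optimizer built before `get_peft_model` was called, or built from a different model object. In that case the newly injected LoRA weights are never handed to the optimizer. Meanwhile the frozen base weights receive no gradients, so nothing is updated.

Below is a minimal check, assuming a standard `torch.optim` optimizer. The helper name `trainable_params_missing_from_optimizer` and the toy model are hypothetical. The added `adapter` layer only stands in for injected LoRA weights.

```python
import torch
from torch import nn


def trainable_params_missing_from_optimizer(model: nn.Module, optimizer: torch.optim.Optimizer) -> list[str]:
    """Names of trainable parameters that no optimizer param group contains."""
    in_optimizer = {id(p) for group in optimizer.param_groups for p in group["params"]}
    return [n for n, p in model.named_parameters() if p.requires_grad and id(p) not in in_optimizer]


# Simulate the pitfall: the optimizer is created first, then the base weights are
# frozen and new trainable weights are added to the model afterwards.
model = nn.Sequential(nn.Linear(4, 4))
stale_optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model[0].requires_grad_(False)                # base weights frozen, as PEFT does
model.add_module("adapter", nn.Linear(4, 4))  # stand-in for injected LoRA weights

print(trainable_params_missing_from_optimizer(model, stale_optimizer))

# Building the optimizer after the model is final covers every trainable parameter.
fresh_optimizer = torch.optim.SGD((p for p in model.parameters() if p.requires_grad), lr=1e-3)
print(trainable_params_missing_from_optimizer(model, fresh_optimizer))
```

Running the same helper on the real `peft_model` and the optimizer used in training should return an empty list. If it doesn't, the optimizer is probably the first thing to look at.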
-
Hi, thanks a lot for the example. I ran it with one of my own images (for some reason I couldn't access the one you used), and the loss does indeed decrease rapidly. I will look into it and get back to you with parts of the training code. Thanks for now!
-
Hello, I have successfully fine-tuned the Generative Image-to-text Transformer for Vision and Language (git-base-vatex from Hugging Face, loaded with `AutoModelForCausalLM.from_pretrained`). I use it for video captioning. I then applied LoRA with the PEFT framework (a PEFT config and `get_peft_model`) and trained it in plain PyTorch. The model doesn't learn anything: the loss stays roughly the same for any learning rate in the range 10^-2 to 10^-5. This happens both when I select the layers myself and when I use the defaults. I wonder whether the task type is the cause: I selected CAUSAL_LM, but the behaviour is the same if I don't set it. Is my kind of model perhaps not supported? Does anyone have suggestions? Thank you very much!