I'm seeing random, sudden loss spikes during training. If there is a simpler way of debugging this, I'm open to a new approach; for now, I attempted to reproduce the training loop in plain PyTorch so that I could log abnormal gradients during training and detect any erroneous examples in my training data.
However, I'm always getting `AttributeError: 'NoneType' object has no attribute 'device'` in the forward pass (full stack trace below).
I built the model exactly as it's done in `train.py`, and my training loop looks like this:
```python
# Define a threshold for outlier detection
gradient_threshold = 10.0

# Create a DataLoader for iterating through the dataset
train_dataloader = torch.utils.data.DataLoader(data_module['train_dataset'], batch_size=1, shuffle=True)

for batch_idx, batch in enumerate(train_dataloader):
    input_ids = batch["input_ids"]        # torch.Size([1, 200])
    labels = batch["labels"]              # torch.Size([1, 200])
    image_tensor = batch["image"].half()  # torch.Size([1, 3, 336, 336])

    # Zero the gradients
    optimizer.zero_grad()

    # Always errors out here
    output = model.forward(input_ids=input_ids, images=image_tensor)
    ....
```
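For context, the gradient check that the `....` elides is along these lines — a minimal sketch, assuming `labels` are passed to `forward` so that `output.loss` is populated:

```python
    # ... inside the batch loop, after a successful forward pass
    loss = output.loss
    loss.backward()

    # Flag any parameter whose gradient norm exceeds the threshold
    for name, param in model.named_parameters():
        if param.grad is not None and param.grad.norm() > gradient_threshold:
            print(f"Batch {batch_idx}: abnormal gradient in {name}: {param.grad.norm().item():.2f}")

    optimizer.step()
```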
`model.forward` always fails with the stack trace below. I've tried the forward pass both with and without `labels`, with similar results, and I've inspected the inputs after the `prepare_inputs_labels_for_multimodal` call. What am I missing? Full stack trace:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[36], line 46
28 # (_input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels_embeds) = model.prepare_inputs_labels_for_multimodal(input_ids=input_ids, position_ids=None, attention_mask=None, past_key_values=None, labels=labels, images=image_tensor)
(...)
44
45 # 4
---> 46 output = model.forward(input_ids=input_ids, images=image_tensor, labels=labels)
47 loss = compute_loss(output.logits, labels)
48 print("LOSS: ", loss.item())
File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/peft/peft_model.py:1129, in PeftModelForCausalLM.forward(self, input_ids, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
1127 with self._enable_peft_forward_hooks(**kwargs):
1128 kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
-> 1129 return self.base_model(
1130 input_ids=input_ids,
1131 attention_mask=attention_mask,
1132 inputs_embeds=inputs_embeds,
1133 labels=labels,
1134 output_attentions=output_attentions,
1135 output_hidden_states=output_hidden_states,
1136 return_dict=return_dict,
1137 **kwargs,
1138 )
1140 batch_size = _get_batch_size(input_ids, inputs_embeds)
1141 if attention_mask is not None:
1142 # concat prompt attention mask
File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/peft/tuners/tuners_utils.py:161, in BaseTuner.forward(self, *args, **kwargs)
160 def forward(self, *args: Any, **kwargs: Any):
--> 161 return self.model.forward(*args, **kwargs)
File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
163 output = old_forward(*args, **kwargs)
164 else:
--> 165 output = old_forward(*args, **kwargs)
166 return module._hf_hook.post_forward(module, output)
File ~/LLaVA-pp/LLaVA/llava/model/language_model/llava_llama.py:103, in LlavaLlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, images, image_sizes, return_dict, cache_position)
101 print("inputs_embeds: ", inputs_embeds.shape)
102 print("labels: ", labels)
--> 103 return super().forward(
104 input_ids=input_ids,
105 attention_mask=attention_mask,
106 position_ids=position_ids,
107 past_key_values=past_key_values,
108 inputs_embeds=inputs_embeds,
109 labels=labels,
110 use_cache=use_cache,
111 output_attentions=output_attentions,
112 output_hidden_states=output_hidden_states,
113 return_dict=return_dict
114 )
File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:1183, in LlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
1180 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1182 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1183 outputs = self.model(
1184 input_ids=input_ids,
1185 attention_mask=attention_mask,
1186 position_ids=position_ids,
1187 past_key_values=past_key_values,
1188 inputs_embeds=inputs_embeds,
1189 use_cache=use_cache,
1190 output_attentions=output_attentions,
1191 output_hidden_states=output_hidden_states,
1192 return_dict=return_dict,
1193 )
1195 hidden_states = outputs[0]
1196 if self.config.pretraining_tp > 1:
File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
163 output = old_forward(*args, **kwargs)
164 else:
--> 165 output = old_forward(*args, **kwargs)
166 return module._hf_hook.post_forward(module, output)
File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:1070, in LlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
1060 layer_outputs = self._gradient_checkpointing_func(
1061 decoder_layer.__call__,
1062 hidden_states,
(...)
1067 use_cache,
1068 )
1069 else:
-> 1070 layer_outputs = decoder_layer(
1071 hidden_states,
1072 attention_mask=attention_mask,
1073 position_ids=position_ids,
1074 past_key_value=past_key_values,
1075 output_attentions=output_attentions,
1076 use_cache=use_cache,
1077 )
1079 hidden_states = layer_outputs[0]
1081 if use_cache:
File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
163 output = old_forward(*args, **kwargs)
164 else:
--> 165 output = old_forward(*args, **kwargs)
166 return module._hf_hook.post_forward(module, output)
File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:798, in LlamaDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs)
795 hidden_states = self.input_layernorm(hidden_states)
797 # Self Attention
--> 798 hidden_states, self_attn_weights, present_key_value = self.self_attn(
799 hidden_states=hidden_states,
800 attention_mask=attention_mask,
801 position_ids=position_ids,
802 past_key_value=past_key_value,
803 output_attentions=output_attentions,
804 use_cache=use_cache,
805 **kwargs,
806 )
807 hidden_states = residual + hidden_states
809 # Fully Connected
File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
163 output = old_forward(*args, **kwargs)
164 else:
--> 165 output = old_forward(*args, **kwargs)
166 return module._hf_hook.post_forward(module, output)
File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:494, in LlamaFlashAttention2.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs)
490 output_attentions = False
492 bsz, q_len, _ = hidden_states.size()
--> 494 query_states = self.q_proj(hidden_states)
495 key_states = self.k_proj(hidden_states)
496 value_states = self.v_proj(hidden_states)
File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/peft/tuners/lora/bnb.py:217, in Linear8bitLt.forward(self, x, *args, **kwargs)
215 result = self.base_layer(x, *args, **kwargs)
216 else:
--> 217 result = self.base_layer(x, *args, **kwargs)
218 for active_adapter in self.active_adapters:
219 if active_adapter not in self.lora_A.keys():
File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
163 output = old_forward(*args, **kwargs)
164 else:
--> 165 output = old_forward(*args, **kwargs)
166 return module._hf_hook.post_forward(module, output)
File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/bitsandbytes/nn/modules.py:797, in Linear8bitLt.forward(self, x)
794 if self.bias is not None and self.bias.dtype != x.dtype:
795 self.bias.data = self.bias.data.to(x.dtype)
--> 797 out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
799 if not self.state.has_fp16_weights:
800 if self.state.CB is not None and self.state.CxB is not None:
801 # we converted 8-bit row major to turing/ampere format in the first inference pass
802 # we no longer need the row-major weight
File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:556, in matmul(A, B, out, state, threshold, bias)
554 if threshold > 0.0:
555 state.threshold = threshold
--> 556 return MatMul8bitLt.apply(A, B, out, bias, state)
File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/autograd/function.py:539, in Function.apply(cls, *args, **kwargs)
536 if not torch._C._are_functorch_transforms_active():
537 # See NOTE: [functorch vjp and autograd interaction]
538 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 539 return super().apply(*args, **kwargs) # type: ignore[misc]
541 if cls.setup_context == _SingleLevelFunction.setup_context:
542 raise RuntimeError(
543 "In order to use an autograd.Function with functorch transforms "
544 "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
545 "staticmethod. For more details, please see "
546 "https://pytorch.org/docs/master/notes/extending.func.html"
547 )
File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:335, in MatMul8bitLt.forward(ctx, A, B, out, bias, state)
331 else:
332 if state.CxB is None and using_igemmlt:
333 # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
334 # we also need to convert it to the turing/ampere format
--> 335 state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
336 else:
337 if not state.has_fp16_weights and state.CxB is None and using_igemmlt:
File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/bitsandbytes/functional.py:2597, in transform(A, to_order, from_order, out, transpose, state, ld)
2596 def transform(A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None):
-> 2597 prev_device = pre_call(A.device)
2598 if state is None:
2599 state = (A.shape, from_order)
AttributeError: 'NoneType' object has no attribute 'device'
Thank you for your interest in our work. Did you try upgrading transformers to the latest version? Please note that LLaMA-3 based trainings are only supported with transformers==4.41+, which you can install as follows:
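For example, assuming the 4.41 release is available on PyPI:

```bash
pip install --upgrade "transformers>=4.41.0"
```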