DataCollatorForCompletionOnlyLM does not seem to be working as expected, and I'm not sure whether this affects finetuning performance. From the doc Fine Tuning Your Own ChatGPT-like Model we can see that the purpose of DataCollatorForCompletionOnlyLM in dataset preprocessing is the following: its torch_call method encodes the response key plus a newline into token IDs using the tokenizer, and searches for the start position of the response key in each example's label tensor. Once the start position is found, the label tensor is modified to mask out all tokens up through the end of the response key. This is done by setting the label IDs of those tokens to -100, a special value that tells the PyTorch loss function to ignore them.
class DataCollatorForCompletionOnlyLM(transformers.DataCollatorForLanguageModeling):
    def torch_call(self, examples):
        batch = super().torch_call(examples)

        # The prompt ends with the response key plus a newline. We encode this and then try to find it in the
        # sequence of tokens. This should just be a single token.
        response_token_ids = self.tokenizer.encode(RESPONSE_KEY_NL)

        labels = batch["labels"].clone()

        for i in range(len(examples)):  # one batch; batch["labels"][i] gets each prompt's labels
            response_token_ids_start_idx = None
            for idx in np.where(batch["labels"][i] == response_token_ids[0])[0]:
                response_token_ids_start_idx = idx
                break

            if response_token_ids_start_idx is not None:
                response_token_ids_end_idx = response_token_ids_start_idx + 1

                # Make pytorch loss function ignore all tokens up through the end of the response key
                labels[i, :response_token_ids_end_idx] = -100

        batch["labels"] = labels
        return batch
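For context, the -100 value works because PyTorch's cross-entropy loss skips any label position equal to its ignore_index, which defaults to -100. A minimal sketch with made-up tensors:

import torch
import torch.nn.functional as F

# Fake logits for a sequence of 4 positions over a vocabulary of 5 tokens.
logits = torch.randn(4, 5)
labels = torch.tensor([2, 4, 1, 3])

# Mask the first two positions the same way the collator masks prompt tokens.
masked_labels = labels.clone()
masked_labels[:2] = -100

# F.cross_entropy skips positions whose label equals ignore_index (default -100),
# so only the last two positions contribute to masked_loss.
full_loss = F.cross_entropy(logits, labels)
masked_loss = F.cross_entropy(logits, masked_labels)
print(full_loss.item(), masked_loss.item())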
I think this is intended to mask every prompt token, from the beginning of the sequence up through the response key, with a special value so that the model can focus its training on the response content. But I found that this function cannot achieve this purpose in llm-on-ray right now.
For example:
When parameter group is false
Prompt content:
Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
Which is a species of fish? Tope or Rope
### Response:
Tope
### End
Expected preprocessing results
-100 -100 -100 -100 -100 -100 -100…………
Tope
### End
Actual preprocessing results
-100 -100 -100 -100 -100 -100 -100…………
-100 Instruction:
Which is a species of fish? Tope or Rope
### Response:
Tope
### End
This is because np.where(batch["labels"][i] == response_token_ids[0]) only compares against response_token_ids[0] (the token for "###"), so "### Instruction" is matched first instead of "### Response".
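One way to avoid this (only a sketch, not a patch that exists in llm-on-ray; the helper name find_response_start is made up here, and it assumes the response key is encoded into several tokens rather than a single special token) would be to verify the full encoded response key at each candidate position instead of only its first token:

import torch

def find_response_start(label_ids, response_token_ids):
    # Scan every position where the first token of the encoded response key
    # matches, then check that the whole key follows, so that "### Instruction"
    # is not mistaken for "### Response".
    key_len = len(response_token_ids)
    for idx in torch.where(label_ids == response_token_ids[0])[0].tolist():
        if label_ids[idx : idx + key_len].tolist() == response_token_ids:
            return idx
    return None

With the full-sequence check, the shared leading "###" token no longer causes the mask boundary to land on "### Instruction".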
When parameter group is true
When group is true, multiple prompts are concatenated into a single sequence. Here is an example that concatenates two prompts. Prompt content:
Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
Which is a species of fish? Tope or Rope
### Response:
Tope
### End
Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
Alice's parents have three daughters: Amy, Jessy, and what’s the name of the third daughter?
### Response:
The name of the third daughter is Alice
### End
Expected preprocessing results
-100 -100 -100 -100 -100 -100 -100…………
Tope
### End
-100 -100 -100 -100 -100 -100 -100…………
The name of the third daughter is Alice
### End
Actual preprocessing results
-100 -100 -100 -100 -100 -100 -100…………
-100 Instruction:
Which is a species of fish? Tope or Rope
### Response:
Tope
### End
Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
Alice's parents have three daughters: Amy, Jessy, and what’s the name of the third daughter?
### Response:
The name of the third daughter is Alice
### End
This is because the loop breaks as soon as the first response key is found, so the second prompt in the concatenated sequence is never processed.
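A possible way to handle the grouped case (again only a sketch, not code from llm-on-ray; the helper name and the end_token_ids parameter for the encoded "### End" marker are assumptions) is to drop the early break and instead mask everything except each response span:

import torch

def mask_prompts_in_grouped_sequence(label_ids, response_token_ids, end_token_ids):
    # Start with every position masked, then restore the labels of each response
    # span, i.e. the tokens between a "### Response:" key and the "### End" key
    # that closes it. This handles any number of concatenated examples instead
    # of stopping at the first response key.
    def find_all(key):
        positions = []
        for idx in torch.where(label_ids == key[0])[0].tolist():
            if label_ids[idx : idx + len(key)].tolist() == key:
                positions.append(idx)
        return positions

    masked = torch.full_like(label_ids, -100)
    end_positions = find_all(end_token_ids)
    for start in find_all(response_token_ids):
        span_start = start + len(response_token_ids)
        # Take the first end key at or after this response key; fall back to the
        # end of the sequence if the example was truncated.
        later_ends = [e for e in end_positions if e >= span_start]
        span_end = later_ends[0] + len(end_token_ids) if later_ends else len(label_ids)
        masked[span_start:span_end] = label_ids[span_start:span_end]
    return masked

Starting from a fully masked tensor and un-masking the response spans avoids having to locate where each prompt begins inside the concatenated sequence, and it matches the expected results above: only the response text and its "### End" marker keep their labels.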