A question about finetune dataset processing #234

Open
KepingYan opened this issue May 22, 2024 · 1 comment

@KepingYan
Contributor

DataCollatorForCompletionOnlyLM does not seem to work as expected; I'm not sure whether this affects the performance of fine-tuning.

According to the doc Fine Tuning Your Own ChatGPT-like Model, the purpose of DataCollatorForCompletionOnlyLM in dataset preprocessing is:

The class method encodes the response key new line into token IDs using the tokenizer, and searches for the start position of the response key in each example's label tensor. Once the start position of the response key is found, the label tensor is modified to mask out all tokens before the end of the response key. This is done by setting the label IDs for those tokens to -100, which is a special value that tells the PyTorch loss function to ignore them.

import numpy as np
import transformers

# RESPONSE_KEY_NL ("### Response:" plus a newline) is defined elsewhere in the training script.
class DataCollatorForCompletionOnlyLM(transformers.DataCollatorForLanguageModeling):
    def torch_call(self, examples):
        batch = super().torch_call(examples)
        # The prompt ends with the response key plus a newline.  We encode this and then try to find it in the
        # sequence of tokens.  This should just be a single token.
        response_token_ids = self.tokenizer.encode(RESPONSE_KEY_NL)
        labels = batch["labels"].clone()
        for i in range(len(examples)):  # iterate over the batch; batch["labels"][i] is the label tensor for one example
            response_token_ids_start_idx = None
            for idx in np.where(batch["labels"][i] == response_token_ids[0])[0]:
                response_token_ids_start_idx = idx
                break
            if response_token_ids_start_idx is not None:
                response_token_ids_end_idx = response_token_ids_start_idx + 1
                # Make pytorch loss function ignore all tokens up through the end of the response key
                labels[i, :response_token_ids_end_idx] = -100
        batch["labels"] = labels
        return batch
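
For context, the reason -100 works as the mask value is that PyTorch's cross-entropy loss ignores targets equal to its ignore_index, which defaults to -100, so masked positions contribute nothing to the loss. A minimal sketch (not code from the repo):

import torch

loss_fn = torch.nn.CrossEntropyLoss()      # ignore_index defaults to -100
logits = torch.randn(4, 10)                # 4 token positions, vocab size 10
labels = torch.tensor([-100, -100, 3, 7])  # first two positions are masked out
print(loss_fn(logits, labels))             # loss is computed only over the last two positions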

I think the intent is to mask every token from the beginning of each prompt up through the response key with this special value, so that the model focuses training on the response content. But I found that this function does not achieve this in llm-on-ray at the moment.

For example:

When the parameter group is false

Prompt content:

Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Which is a species of fish? Tope or Rope

### Response:
Tope

### End

Expected preprocessing results

  -100 -100 -100 -100 -100 -100 -100…………
  Tope
  
  ### End

Actual preprocessing results

  -100 -100 -100 -100 -100 -100 -100…………
  -100 Instruction:
  Which is a species of fish? Tope or Rope
  
  ### Response:
  Tope
  
  ### End

This is because np.where(batch["labels"][i] == response_token_ids[0]) compares only response_token_ids[0] (the token for "###"), so "### Instruction" is found before "### Response".
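
One possible direction for a fix (just a sketch, not the current llm-on-ray code) is to verify the full encoded response key at each candidate position instead of matching only its first token. This assumes RESPONSE_KEY_NL tokenizes to the same token IDs when it appears mid-sequence, which may not hold for every tokenizer:

import numpy as np

def find_response_start(labels_row, response_token_ids):
    # Candidate positions share only the first token ("###"); accept one only if the
    # whole response-key token sequence matches at that position.
    for idx in np.where(labels_row == response_token_ids[0])[0]:
        if labels_row[idx : idx + len(response_token_ids)].tolist() == response_token_ids:
            return idx
    return None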

When the parameter group is true

When group is true, multiple prompts are concatenated into a single new prompt. Here is an example that concatenates two prompts.
Prompt content:

Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Which is a species of fish? Tope or Rope

### Response:
Tope

### End
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Alice's parents have three daughters: Amy, Jessy, and what’s the name of the third daughter?

### Response:
The name of the third daughter is Alice

### End

Expected preprocessing results

-100 -100 -100 -100 -100 -100 -100…………
Tope

### End
-100 -100 -100 -100 -100 -100 -100…………
The name of the third daughter is Alice

### End

Actual preprocessing results

-100 -100 -100 -100 -100 -100 -100…………
-100 Instruction:
Which is a species of fish? Tope or Rope

### Response:
Tope

### End
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Alice's parents have three daughters: Amy, Jessy, and what’s the name of the third daughter?

### Response:
The name of the third daughter is Alice

### End

This is because the loop breaks as soon as the first match is found, so the second concatenated prompt is never processed.
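
For the grouped case, a possible sketch (again, not the current llm-on-ray code) is to keep scanning instead of breaking after the first match, masking each concatenated prompt up through its own response key while leaving the response text and end marker visible. It assumes the end key ("### End") is also encoded to token IDs and that both keys tokenize identically mid-sequence; the helper below and its end_token_ids argument are hypothetical, and labels_row is one example's label tensor from the collator:

def mask_grouped_labels(labels_row, response_token_ids, end_token_ids, ignore_index=-100):
    masked = labels_row.clone()
    tokens = labels_row.tolist()
    segment_start = 0  # start of the current concatenated example
    i = 0
    while i < len(tokens):
        if tokens[i : i + len(response_token_ids)] == response_token_ids:
            # Mask this example's prompt, including the response key itself.
            masked[segment_start : i + len(response_token_ids)] = ignore_index
            i += len(response_token_ids)
        elif tokens[i : i + len(end_token_ids)] == end_token_ids:
            # The next concatenated example starts right after "### End".
            i += len(end_token_ids)
            segment_start = i
        else:
            i += 1
    return masked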

@KepingYan
Contributor Author

@harborn @minmingzhu
