Remove "generation_length" in favor of "generation_kwargs" #3014

Merged
merged 26 commits into from
Feb 28, 2024
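In short, the ICL eval dataloaders no longer emit a top-level generation_length batch key: the generation cap now travels inside generation_kwargs as max_new_tokens, and eval_forward in composer/models/huggingface.py simply unpacks that dictionary into generate (see the diffs below). A minimal before/after sketch of the batch layout follows; the tensor values and token ids are illustrative, only the key structure follows this PR.

import torch

input_ids = torch.tensor([[2, 11, 42]])        # toy inputs, not taken from the PR
attention_mask = torch.ones_like(input_ids)

# Before this PR: the generation cap lived in its own batch key.
old_batch = {
    'mode': 'generate',
    'input_ids': input_ids,
    'attention_mask': attention_mask,
    'generation_length': 64,
    'generation_kwargs': {'pad_token_id': 0, 'use_cache': True},
}

# After this PR: max_new_tokens is just another generation kwarg.
new_batch = {
    'mode': 'generate',
    'input_ids': input_ids,
    'attention_mask': attention_mask,
    'generation_kwargs': {
        'max_new_tokens': 64,   # replaces the old top-level 'generation_length'
        'pad_token_id': 0,
        'use_cache': True,
    },
}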
Changes from 6 commits
Commits
26 commits
469d31f
kill generation_length
maxisawesome Feb 15, 2024
ef4b6a1
Merge branch 'mosaicml_dev' into kill_generation_length
maxisawesome Feb 15, 2024
a441c1e
fix tests
maxisawesome Feb 15, 2024
474ac42
Merge branch 'mosaicml_dev' into kill_generation_length
maxisawesome Feb 16, 2024
1d0a7fe
fix test
maxisawesome Feb 16, 2024
ada4368
Merge branch 'mosaicml_dev' into kill_generation_length
maxisawesome Feb 16, 2024
ea3a63e
add deprecation warning
maxisawesome Feb 16, 2024
4e98bfa
fix test
maxisawesome Feb 20, 2024
1d3b7cd
Merge branch 'dev' into kill_generation_length
maxisawesome Feb 20, 2024
74ce928
Merge branch 'dev' into kill_generation_length
maxisawesome Feb 20, 2024
de47894
add gen_len back into static_keys
maxisawesome Feb 21, 2024
4599ae7
simplify setting variable in forward and add test
maxisawesome Feb 21, 2024
fe8a1c6
Merge branch 'dev' into kill_generation_length
dakinggg Feb 21, 2024
a17d3fb
simply test
maxisawesome Feb 21, 2024
06c352a
trailing comma
maxisawesome Feb 21, 2024
faefc3d
trailing comma
maxisawesome Feb 21, 2024
3de388a
Merge branch 'mosaicml_dev' into kill_generation_length
maxisawesome Feb 22, 2024
2ba9dbd
linting
maxisawesome Feb 22, 2024
c9b06a7
Merge branch 'dev' into kill_generation_length
maxisawesome Feb 22, 2024
c6ade3c
Merge branch 'dev' into kill_generation_length
maxisawesome Feb 22, 2024
d66f7c2
Merge branch 'dev' into kill_generation_length
maxisawesome Feb 22, 2024
87809ef
Merge branch 'dev' into kill_generation_length
maxisawesome Feb 23, 2024
06dd518
Merge branch 'dev' into kill_generation_length
maxisawesome Feb 23, 2024
27c1a81
Merge branch 'dev' into kill_generation_length
maxisawesome Feb 23, 2024
53b1388
Merge branch 'dev' into kill_generation_length
maxisawesome Feb 26, 2024
5649b20
Merge branch 'dev' into kill_generation_length
dakinggg Feb 27, 2024
15 changes: 3 additions & 12 deletions composer/datasets/in_context_learning_evaluation.py
@@ -689,14 +689,7 @@ def __init__(
self.cot_delimiter = cot_delimiter
self.has_cot = False
self.max_answer_length = 0
static_keys = [
'mode',
'cot_delimiter',
'generation_length',
'generation_kwargs',
'do_normalization',
'stopping_criteria',
]
static_keys = ['mode', 'cot_delimiter', 'generation_kwargs', 'do_normalization', 'stopping_criteria']
tensor_keys = ['input_ids', 'attention_mask']
list_keys = ['labels']
super().__init__(
@@ -715,10 +708,10 @@ def __init__(
'mode': 'generate',
'labels': [],
'cot_delimiter': self.cot_delimiter,
'generation_length': self.max_answer_length,
'stopping_criteria': early_stopping_criteria,
'do_normalization': do_normalization,
'generation_kwargs': {
'max_new_tokens': self.max_answer_length,
'pad_token_id': self.pad_tok_id,
'use_cache': True,
'eos_token_id': self.tokenizer.eos_token_id,
@@ -1260,7 +1253,6 @@ class InContextLearningCodeEvalDataset(InContextLearningDataset):
- test_outputs: List of test outputs
- languages: List of languages
- pass_at_k: Passed value for pass_at_k
- generation_length: Derived maximum generation length
- generation_kwargs: Dictionary of kwargs needed for generation. Includes the following, which will be individually overwritten
by keys in generation_kwargs if set (see https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig
for more details):
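For reference, these are the defaults that now flow through this single dictionary for the code-eval dataset, as assembled in the base_batch further down in this file; the pad and eos token ids below are illustrative placeholders, only the keys mirror the diff.

# Sketch of the code-eval generation_kwargs after this PR.
generation_kwargs = {
    'pad_token_id': 0,       # placeholder for self.pad_tok_id
    'num_beams': 1,          # single beam
    'do_sample': True,
    'temperature': 0.2,      # good default for code
    'use_cache': True,
    'eos_token_id': 2,       # placeholder for tokenizer.eos_token_id
    'max_new_tokens': 122,   # min(max_answer_length, max_seq_len - max_prompt_length); 122 matches the test data below
}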
@@ -1305,7 +1297,6 @@ def __init__(
static_keys = [
'mode',
'pass_at_k',
'generation_length',
'generation_kwargs',
'generations_per_sample',
'dataset_size',
@@ -1349,14 +1340,14 @@ def __init__(
'test_outputs': [],
'languages': [],
'pass_at_k': pass_at_k,
'generation_length': min(self.max_answer_length, self.max_seq_len - self.max_prompt_length),
'generation_kwargs': {
'pad_token_id': self.pad_tok_id,
'num_beams': 1, # single beam
'do_sample': True,
'temperature': 0.2, # good default for code
'use_cache': True,
'eos_token_id': self.tokenizer.eos_token_id,
'max_new_tokens': min(self.max_answer_length, self.max_seq_len - self.max_prompt_length),
},
'sample_id': [],
'pass_at_k': list(pass_at_k),
1 change: 0 additions & 1 deletion composer/models/huggingface.py
@@ -469,7 +469,6 @@ def eval_forward(self, batch, outputs: Optional[Any] = None):
self.labels = batch.pop('labels')
generation = self.generate(batch['input_ids'],
attention_mask=batch['attention_mask'],
max_new_tokens=batch['generation_length'],
synced_gpus=dist.get_world_size() > 1,
**batch.get('generation_kwargs', {}))
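Since eval_forward now unpacks everything from generation_kwargs, the cap reaches generate the same way any other kwarg does. A standalone sketch of that unpacking pattern, using gpt2 purely as a small stand-in model (the model choice and prompt are not part of this PR):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')

inputs = tokenizer('def add(a, b):', return_tensors='pt')
generation_kwargs = {'max_new_tokens': 5, 'pad_token_id': tokenizer.eos_token_id}

# Equivalent to spelling max_new_tokens out as an explicit argument, which is
# what the removed batch['generation_length'] line above used to do.
output_ids = model.generate(inputs['input_ids'],
                            attention_mask=inputs['attention_mask'],
                            **generation_kwargs)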

23 changes: 12 additions & 11 deletions tests/datasets/test_in_context_learning_datasets.py
@@ -296,7 +296,7 @@ def test_update_generation_kwargs_no_kwargs_qa_dataset(tmp_path):
continuation_delimiter=': ',
destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'),
generation_kwargs=None)
assert len(dl.base_batch['generation_kwargs']) == 3
assert len(dl.base_batch['generation_kwargs']) == 4


def test_update_generation_kwargs_with_kwargs_qa_dataset(tmp_path):
@@ -321,7 +321,7 @@ def test_update_generation_kwargs_with_kwargs_qa_dataset(tmp_path):
generation_kwargs={'temperature': 0.9})
assert 'generation_kwargs' in dl.base_batch
assert dl.base_batch['generation_kwargs']['temperature'] == 0.9
assert len(dl.base_batch['generation_kwargs']) == 4
assert len(dl.base_batch['generation_kwargs']) == 5
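The two assertions above move from 3 and 4 to 4 and 5 because max_new_tokens now counts as one of the default generation kwargs, and user-supplied kwargs are layered on top of it. A hedged illustration of that merge; the default values are placeholders, only the key counts mirror the tests:

defaults = {
    'max_new_tokens': 64,   # derived max answer length, value illustrative
    'pad_token_id': 0,
    'use_cache': True,
    'eos_token_id': 2,
}
user_kwargs = {'temperature': 0.9}

merged = {**defaults, **user_kwargs}
assert len(defaults) == 4   # matches the first test above
assert len(merged) == 5     # matches the second test above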


@pytest.mark.filterwarnings(
@@ -1255,8 +1255,8 @@ def test_qa_split_batch(tiny_opt_tokenizer, dataset_uri, tmp_path):
assert len(split2['labels']) == 1
assert all(isinstance(v, list) for v in split1['labels'] + split2['labels'])

assert isinstance(split1['generation_length'], int)
assert isinstance(split2['generation_length'], int)
assert isinstance(split1['generation_kwargs']['max_new_tokens'], int)
assert isinstance(split2['generation_kwargs']['max_new_tokens'], int)

assert isinstance(split1['generation_kwargs'], dict)
assert isinstance(split2['generation_kwargs'], dict)
@@ -1326,7 +1326,7 @@ def test_qa_task_dataloader(dataset_uri, tiny_gpt2_tokenizer, tmp_path, num_fews
assert batch['mode'] == 'generate'
# the maximum generation length from the small test data

assert batch['generation_length'] == maximum_answer_length
assert batch['generation_kwargs']['max_new_tokens'] == maximum_answer_length
assert all(item[0] == tokenizer.eos_token_id for item in batch['input_ids'])

decoded_batch = tokenizer.batch_decode(batch['input_ids'])
@@ -1376,7 +1376,7 @@ def test_qa_task_with_cot_dataloader(dataset_uri, tiny_gpt2_tokenizer, tmp_path,
assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen - maximum_answer_length)
assert batch['mode'] == 'generate'
# the maximum generation length from the small test data
assert batch['generation_length'] == maximum_answer_length
assert batch['generation_kwargs']['max_new_tokens'] == maximum_answer_length
assert all(item[0] == tokenizer.eos_token_id for item in batch['input_ids'])
decoded_batch = tokenizer.batch_decode(batch['input_ids'])
assert all(item.count('Q: ') == num_fewshot + 1 for item in decoded_batch)
@@ -1491,10 +1491,11 @@ def test_code_eval_split_batch(dataset_uri, tmp_path):
assert len(batch[field]) == size
assert all(isinstance(val, type_) for val in batch[field])

static_keys = {'pass_at_k': (int, list), 'generation_length': int, 'generation_kwargs': dict}
static_keys = {'pass_at_k': (int, list), 'generation_kwargs': dict}
for batch in batches:
for field, type_ in static_keys.items():
assert isinstance(batch[field], type_)
assert batch['generation_kwargs']['max_new_tokens'] == 122


@pytest.mark.parametrize('dataset_uri', ['human_eval_small.jsonl'])
@@ -1544,7 +1545,7 @@ def test_code_eval_sentpiece_dataloader(dataset_uri, tmp_path, num_fewshot, prom
assert tuple(batch['attention_mask'].shape) == (bs, max_prompt_length)
assert batch['mode'] == 'generate'
# the maximum generation length from the small test data
assert batch['generation_length'] == 129
assert batch['generation_kwargs']['max_new_tokens'] == 129
has_left_padding.extend([item[0] == tokenizer.eos_token_id for item in batch['input_ids']])
assert not all(has_left_padding) # longest should be pushed left

@@ -1613,7 +1614,7 @@ def test_code_eval_test_cases(dataset_uri, tmp_path, tiny_llama_tokenizer):
assert tuple(batch['attention_mask'].shape) == (batch_size, max_prompt_length)
assert batch['mode'] == 'generate'
# the maximum generation length from the small test data
assert batch['generation_length'] == 129
assert batch['generation_kwargs']['max_new_tokens'] == 129
assert any(item[0] != tokenizer.eos_token_id for item in batch['input_ids']) # longest should be pushed left

mod = types.ModuleType('test_module')
@@ -1703,7 +1704,7 @@ def test_code_eval_task_dataloader(dataset_uri, tmp_path, num_fewshot, prompt_st
assert tuple(batch['attention_mask'].shape) == (bs, max_prompt_length)
assert batch['mode'] == 'generate'
# the maximum generation length from the small test data
assert batch['generation_length'] == 122
assert batch['generation_kwargs']['max_new_tokens'] == 122
has_left_padding.extend([item[0] == tokenizer.eos_token_id for item in batch['input_ids']])
assert not all(has_left_padding) # longest should be pushed left

@@ -2459,7 +2460,7 @@ def test_hf_dataloading_custom_parsing(dataset_uri, tiny_gpt2_tokenizer, tmp_pat
assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen - maximum_answer_length)
assert batch['mode'] == 'generate'
# the maximum generation length from the small test data
assert batch['generation_length'] == maximum_answer_length
assert batch['generation_kwargs']['max_new_tokens'] == maximum_answer_length
assert all(item[0] == tokenizer.eos_token_id for item in batch['input_ids'])

decoded_batch = tokenizer.batch_decode(batch['input_ids'])
5 changes: 3 additions & 2 deletions tests/models/test_hf_model.py
@@ -1195,11 +1195,12 @@ def test_eval_forward_generate(device, world_size, hf_model, hf_tokenizer, use_f
for k, v in input_dict.items():
input_dict[k] = device.tensor_to_device(v)
input_dict['mode'] = 'generate'
input_dict['generation_kwargs'] = {}

input_dict['generation_length'] = 5
input_dict['generation_kwargs']['max_new_tokens'] = 5
input_dict['labels'] = [['answer1'], ['answer2']]
generation1 = model.eval_forward(input_dict, None)
input_dict['generation_length'] = 3
input_dict['generation_kwargs']['max_new_tokens'] = 3
input_dict['labels'] = [['answer1'], ['answer2']]
generation2 = model.eval_forward(input_dict, None)
