Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate ICL classes to foundry #936

Merged
merged 100 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from 84 commits
Commits
Show all changes
100 commits
Select commit Hold shift + click to select a range
cd18e74
start
bmosaicml Jan 23, 2024
1fffbad
still need to migrate fixtures
bmosaicml Feb 2, 2024
5a6e81c
Merge branch 'main' into migrate_subclasses_to_foundry
bmosaicml Feb 2, 2024
4aac81e
wip onboarding tests
bmosaicml Feb 2, 2024
946a4af
still workin'
bmosaicml Feb 2, 2024
289ca55
still wip
bmosaicml Feb 2, 2024
3696f8d
maybe done; test out on mcli now
bmosaicml Feb 2, 2024
a20877d
mcli
bmosaicml Feb 2, 2024
53da3ea
remove calibration error
bmosaicml Feb 2, 2024
16b8e32
merge
bmosaicml Feb 7, 2024
a90766e
migration
bmosaicml Feb 7, 2024
72ce793
migration
bmosaicml Feb 7, 2024
667bdec
Merge branch 'migrate_subclasses_to_foundry' of github.com:mosaicml/l…
bmosaicml Feb 9, 2024
ceff0c4
full migration
bmosaicml Feb 9, 2024
5bb06cc
precommit
bmosaicml Feb 12, 2024
fe83828
fix
bmosaicml Feb 12, 2024
b54a12b
fix pytests
bmosaicml Feb 12, 2024
71e8391
refactor QA
bmosaicml Feb 16, 2024
414153e
update
bmosaicml Feb 22, 2024
a3f5a31
restore
bmosaicml Feb 23, 2024
820069a
Merge branch 'main' into migrate_subclasses_to_foundry
bmosaicml Feb 23, 2024
4a1cd79
add
bmosaicml Feb 23, 2024
d265979
Merge branch 'migrate_subclasses_to_foundry' of github.com:mosaicml/l…
bmosaicml Feb 23, 2024
ddfd7b5
Merge branch 'main' into migrate_subclasses_to_foundry
bmosaicml Feb 26, 2024
71f77e3
fix
bmosaicml Feb 26, 2024
cb3725b
wip
bmosaicml Feb 27, 2024
5135152
update readme
bmosaicml Feb 27, 2024
18bae17
Merge branch 'main' into migrate_subclasses_to_foundry
bmosaicml Feb 27, 2024
c6162dd
final pyright
bmosaicml Feb 27, 2024
25d431e
Merge branch 'migrate_subclasses_to_foundry' of github.com:mosaicml/l…
bmosaicml Feb 27, 2024
f1b334d
done
bmosaicml Feb 27, 2024
c4ed644
pass prelimiter into ALL the ICL task datasets
eitanturok Feb 27, 2024
2516c24
Merge branch 'main' into migrate_subclasses_to_foundry
bmosaicml Mar 4, 2024
f213a40
allow QA task name stil lfor backward compatibility
bmosaicml Mar 4, 2024
35fd2f1
Merge branch 'migrate_subclasses_to_foundry' of github.com:mosaicml/l…
bmosaicml Mar 4, 2024
d570e5d
fix
bmosaicml Mar 5, 2024
a5cd308
fix test
bmosaicml Mar 5, 2024
0fb37cd
Merge branch 'main' into migrate_subclasses_to_foundry
bmosaicml Mar 5, 2024
901fc69
add generation length
bmosaicml Mar 5, 2024
a313499
Merge branch 'migrate_subclasses_to_foundry' of github.com:mosaicml/l…
bmosaicml Mar 5, 2024
df19c0d
remove max_new_tokens
bmosaicml Mar 5, 2024
54bb4c7
fix cpu trsts
bmosaicml Mar 6, 2024
9ebeaa0
Merge branch 'main' into migrate_subclasses_to_foundry
bmosaicml Mar 6, 2024
ca9816c
Merge branch 'main' into migrate_subclasses_to_foundry
maxisawesome Mar 6, 2024
b9d6cd1
try and fix lm eval test
bmosaicml Mar 7, 2024
691ab20
Merge branch 'migrate_subclasses_to_foundry' of github.com:mosaicml/l…
bmosaicml Mar 7, 2024
c207cd9
temp disable lm task eval test
bmosaicml Mar 7, 2024
c85813b
fix test?
bmosaicml Mar 8, 2024
08ef908
fix tet
bmosaicml Mar 11, 2024
aca0e63
finish
bmosaicml Mar 11, 2024
30fcedd
fix
bmosaicml Mar 12, 2024
59daa26
Merge branch 'main' into migrate_subclasses_to_foundry
maxisawesome Mar 13, 2024
4217a78
Update scripts/eval/README.md
bmosaicml Mar 13, 2024
6f597a9
fix comments
bmosaicml Mar 13, 2024
8c6e622
Merge branch 'migrate_subclasses_to_foundry' of github.com:mosaicml/l…
bmosaicml Mar 13, 2024
f387a73
fix bug with seq len
bmosaicml Mar 14, 2024
cbfa3da
Merge branch 'main' into migrate_subclasses_to_foundry
bmosaicml Mar 14, 2024
2f405d9
restore mcli
bmosaicml Mar 14, 2024
76e600a
Merge branch 'migrate_subclasses_to_foundry' of github.com:mosaicml/l…
bmosaicml Mar 14, 2024
898928e
Merge branch 'main' into migrate_subclasses_to_foundry
bmosaicml Mar 15, 2024
07fb59e
Merge branch 'main' into migrate_subclasses_to_foundry
bmosaicml Mar 18, 2024
587971f
Merge branch 'main' into migrate_subclasses_to_foundry
bmosaicml Mar 18, 2024
18efa86
Merge branch 'main' into migrate_subclasses_to_foundry
bmosaicml Mar 29, 2024
7faeb78
merge
bmosaicml Apr 1, 2024
343e115
fix builder
bmosaicml Apr 1, 2024
bf6231e
add deprecation warning
bmosaicml Apr 1, 2024
58859a3
Merge branch 'main' into migrate_subclasses_to_foundry
bmosaicml Apr 1, 2024
501d4cc
add deprecation warning
bmosaicml Apr 1, 2024
22f6759
Merge branch 'migrate_subclasses_to_foundry' of github.com:mosaicml/l…
bmosaicml Apr 1, 2024
414467a
merge
bmosaicml Apr 1, 2024
65fbbed
merge
bmosaicml Apr 1, 2024
5696f09
add logging necessities to nlp.py
maxisawesome Apr 1, 2024
91a2b18
add attention_mask test update
maxisawesome Apr 1, 2024
79877ee
fix generation_length in tests
maxisawesome Apr 1, 2024
9c50795
Merge branch 'main' into migrate_subclasses_to_foundry
maxisawesome Apr 1, 2024
eac919a
fix bug
maxisawesome Apr 1, 2024
965c20c
Merge branch 'migrate_subclasses_to_foundry' of github.com:mosaicml/l…
maxisawesome Apr 1, 2024
57e902a
Merge branch 'main' into migrate_subclasses_to_foundry
bmosaicml Apr 2, 2024
7c0996f
Merge branch 'main' into migrate_subclasses_to_foundry
maxisawesome Apr 2, 2024
599695c
Merge branch 'main' into migrate_subclasses_to_foundry
maxisawesome Apr 2, 2024
e10086f
restore yamls
bmosaicml Apr 3, 2024
a60ef1d
Merge branch 'main' into migrate_subclasses_to_foundry
bmosaicml Apr 4, 2024
d78d783
Merge branch 'main' into migrate_subclasses_to_foundry
maxisawesome Apr 4, 2024
1ddf194
Merge branch 'main' into migrate_subclasses_to_foundry
maxisawesome Apr 9, 2024
d5aebc8
fix typos
bmosaicml Apr 10, 2024
d7272b1
Merge branch 'migrate_subclasses_to_foundry' of github.com:mosaicml/l…
bmosaicml Apr 10, 2024
a5082b0
add deprecation warning for code
maxisawesome Apr 10, 2024
3c8ac56
Merge branch 'migrate_subclasses_to_foundry' of github.com:mosaicml/l…
maxisawesome Apr 10, 2024
642ad40
pyright wip
maxisawesome Apr 10, 2024
f30db14
Merge branch 'main' into migrate_subclasses_to_foundry
maxisawesome Apr 11, 2024
de321b2
fix pyright
bmosaicml Apr 11, 2024
019c58a
fix pyright error again
bmosaicml Apr 11, 2024
779f490
fix pyright
bmosaicml Apr 11, 2024
03f7e91
fix pyright
bmosaicml Apr 11, 2024
e81823d
Merge branch 'main' into migrate_subclasses_to_foundry
bmosaicml Apr 12, 2024
709fc80
Merge branch 'main' into migrate_subclasses_to_foundry
bmosaicml Apr 12, 2024
eb494d8
Merge branch 'main' into migrate_subclasses_to_foundry
bmosaicml Apr 12, 2024
3cd226d
Merge branch 'main' into migrate_subclasses_to_foundry
bmosaicml Apr 12, 2024
02308df
update version
maxisawesome Apr 12, 2024
1a36d3f
Merge branch 'migrate_subclasses_to_foundry' of github.com:mosaicml/l…
maxisawesome Apr 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
my-copy-c4*/
my-copy-arxiv*/
*.jsonl*
!tests/eval/local_data/*.jsonl

# WandB
wandb/
Expand Down
34 changes: 34 additions & 0 deletions llmfoundry/eval/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Natively supported datasets."""

from llmfoundry.eval.datasets.in_context_learning_evaluation import (
InContextLearningCodeEvalDataset, InContextLearningDataset,
InContextLearningGenerationTaskWithAnswersDataset,
InContextLearningLMTaskDataset, InContextLearningMultipleChoiceTaskDataset,
InContextLearningSchemaTaskDataset, get_icl_task_dataloader)
bmosaicml marked this conversation as resolved.
Show resolved Hide resolved
from llmfoundry.eval.datasets.utils import (get_continuation_span,
get_fewshot_sample_idxs,
make_padded_input, strip_data,
tokenizer_needs_prefix_space,
trim_context)

__all__ = [
'InContextLearningDataset',
'InContextLearningGenerationTaskWithAnswersDataset',
'InContextLearningLMTaskDataset',
'InContextLearningCodeEvalDataset',
'InContextLearningMultipleChoiceTaskDataset',
'InContextLearningSchemaTaskDataset',
'get_icl_task_dataloader',
'strip_data',
'tokenizer_needs_prefix_space',
'trim_context',
'get_continuation_span',
'get_fewshot_sample_idxs',
'make_padded_input',
]
1,782 changes: 1,782 additions & 0 deletions llmfoundry/eval/datasets/in_context_learning_evaluation.py

Large diffs are not rendered by default.

280 changes: 280 additions & 0 deletions llmfoundry/eval/datasets/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,280 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Utility and helper functions for datasets."""
from __future__ import annotations

import logging
import random
from typing import Any, Dict, List, Optional, Set

import torch
import transformers

__all__ = [
'MultiTokenEOSCriteria',
]

log = logging.getLogger(__name__)


def strip_data(example: Dict) -> Dict:
"""Remove white space from the begging and end of string values in a.

dictionary.

Args:
example: Dictionary to be stripped

Returns:
dict: The same dictionary with .strip() applied to any value in the dict that is a string
"""
return {
k: v.strip() if isinstance(v, str) else v for k, v in example.items()
}


def tokenizer_needs_prefix_space(
tokenizer: transformers.PreTrainedTokenizerBase) -> bool:
"""Test for whether a prefix space is needed before the continuation.

Sentencepiece tokenization should not have a prefix space, but gpt2 style
BPE should.

Args:
tokenizer: Tokenizer to test

Returns:
bool: Whether or not the tokenizer needs a prefix space
"""
test_tokens = tokenizer(' a', add_special_tokens=False)['input_ids']
assert isinstance(test_tokens, list)
return len(test_tokens) == 1


def trim_context(context_enc: List, continuation_enc: List,
max_seq_len: int) -> List:
"""Trims a list of tokens down to `max_seq_len` if the length of the list.

plus the continuation is more than `max_seq_len`. It will always trim tokens
from the left, i.e. tokens at the beginning of the context will be removed.

Args:
context_enc (list): List of tokens in the context
continuation_enc (lsit): List of tokens in the continuation
max_seq_len (int): Maximum length the model can ingest

Returns:
list: The encoded context trimmed from the left
"""
if len(continuation_enc) + len(context_enc) > max_seq_len:
context_max_subseq_len = max_seq_len - len(continuation_enc)

if context_max_subseq_len < 0:
# can't support continuations which are longer than the max seq len
raise Exception(
f'Dataset included continuation longer than the max seq len')

# clip from the end
context_enc = context_enc[-(context_max_subseq_len):]
return context_enc


def get_continuation_span(context_enc: List,
continuation_enc: List) -> torch.Tensor:
"""Gets the list of indices of the continuation tokens for language.

modeling.

or generation tasks.

Args:
context_enc (list): List of context tokens
continuation_enc (list): List of continuation tokens

Returns:
torch.tensor: A tensor containing indices corresponding to continuation tokens
"""
return torch.tensor(
range(len(context_enc),
len(context_enc) + len(continuation_enc)))


def make_padded_input(context_enc: List,
continuation_enc: List,
max_seq_len: int,
pad_tok_id: int,
padding_side: str = 'right') -> torch.Tensor:
"""Takes an encoded context and continuation and clips the beginning of the.

context if they're too long. Adds the padding token to the specified side.

Args:
context_enc (List): The encoded input to the model
continuation_enc (List): The encoded desired output for the example
max_seq_list (int): Maximum length sequences can be
pad_tok_id (int): The token id we pad with
padding_side (str): Which side to pad the context on. Can be 'right' or 'left

Returns:
input (torch.tensor): The padded and encoded context
continuation_span (torch.tensor): The _inclusive_ range of indices corresponding to the continuation
"""
inp = torch.tensor(
(context_enc + continuation_enc),
dtype=torch.long,
)
(inp_len,) = inp.shape

# Sometimes tokenizers that have neither a pad_tok_id or eos_tok_id will pass None in as the padding
# token and cause errors
if not isinstance(pad_tok_id, int):
raise ValueError(
f'`pad_tok_id` must be an integer. Found {type(pad_tok_id)} instead'
)
# pad length from seq to padding_length
if padding_side == 'right':
inp = torch.cat(
[
inp, # [seq]
torch.LongTensor((max_seq_len - inp_len) * [pad_tok_id]),
],
dim=0,
)
elif padding_side == 'left':
inp = torch.cat(
[
torch.LongTensor((max_seq_len - inp_len) * [pad_tok_id]),
inp, # [seq]
],
dim=0,
)
else:
raise ValueError(
f"Unknown padding_side {padding_side}. padding_side must be either 'left' or 'right'"
)

return inp


def convert_tokens_to_tensors(batch: Dict,
tokenize_labels: bool) -> Dict[str, Any]:
"""HF Datasets converts tensors into lists when we store them, and we don't.

want to use `type='torch'` because some content in the dataset, like
generation args or single ints, should not be converted.

Here, we convert those lists of tokens back into tensors in order to feed them into the model.

Args:
batch (dict): A dictionary of batched inputs
tokenize_labels (bool): Whether or not the labels are tokenized (and need to be stacked)

Returns:
dict: The batch with torch tensors in the corresponding keys instead of lists of lists
"""
batch['input_ids'] = torch.stack(list(map(torch.tensor,
batch['input_ids'])))
if tokenize_labels:
batch['labels'] = torch.stack(list(map(torch.tensor, batch['labels'])))
batch['continuation_indices'] = list(
map(torch.tensor, batch['continuation_indices']))
return batch


def get_fewshot_sample_idxs(dataset_size: int, num_fewshot: int,
example_idx: int, rng: random.Random) -> Set[int]:
"""Samples indices without replacement. If num_fewshot exceeds the number.

of unique examples in the dataset, then we will have fewer than num_fewshot examples in context.

Args:
dataset_size (int): Length of the dataset
num_fewshot (int): Number of examples to prepend
example_idx (int): Current example's index (excluded from fewshot choices)
rng (random.Random): RNG for repeatable sample selection

Returns:
list: Indices of the examples chosen for fewshot selection
"""
num_fewshot = min(dataset_size - 1, num_fewshot)
fewshot_idxs = set(rng.sample(range(0, dataset_size), num_fewshot))

if example_idx in fewshot_idxs:
fewshot_idxs.remove(example_idx)
if len(fewshot_idxs) >= dataset_size - 1:
return fewshot_idxs

replacement_sample = rng.choice(range(0, dataset_size))
while replacement_sample in fewshot_idxs or replacement_sample == example_idx:
replacement_sample = rng.choice(range(0, dataset_size))
fewshot_idxs.add(replacement_sample)
return fewshot_idxs


class MultiTokenEOSCriteria(transformers.StoppingCriteria):
"""Criteria to stop on the specified multi-token sequence.

Slightly modified from: https://github.com/EleutherAI/lm-evaluation-harness/blob/78545d42f2ca95c6fe0ed220d456eeb94f4485e9/lm_eval/utils.py#L614-L649
"""

def __init__(
self,
stop_sequence: str,
tokenizer: transformers.PreTrainedTokenizerBase,
batch_size: int,
) -> None:
self.done_tracker = [False] * batch_size
self.stop_sequence = stop_sequence
self.stop_sequence_ids = tokenizer.encode(stop_sequence,
add_special_tokens=False)

# sentence piece tokenizers add a superflous underline token before string-initial \n
# that throws off our calculation of the stop sequence length
# so we remove any token ids that produce empty strings
self.stop_sequence_ids = [
id for id in self.stop_sequence_ids if tokenizer.decode(id) != ''
]

# we look back for 1 more token than it takes to encode our stop sequence
# because tokenizers suck, and a model might generate `['\n', '\n']` but our `sequence` is `['\n\n']`
# and we don't want to mistakenly not stop a generation because our
# (string) stop sequence was output in a different tokenization

self.stop_sequence_id_len = len(self.stop_sequence_ids) + 1
self.tokenizer = tokenizer

def __call__(self,
input_ids: torch.LongTensor,
scores: Optional[torch.FloatTensor] = None,
**kwargs: Dict[str, Any]) -> bool:
# For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
lookback_ids_batch = input_ids[:, :][:, -self.stop_sequence_id_len:]
lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
for i, done in enumerate(self.done_tracker):
if i >= len(lookback_tokens_batch):
# The last batch of a dataset may be smaller than `batch_size`
# Automatically set those indices in the done_tracker to True
# since those indices don't show up in the current batch
self.done_tracker[i] = True
break
elif not done:
self.done_tracker[
i] = self.stop_sequence in lookback_tokens_batch[i]
return False not in self.done_tracker


def stop_sequences_criteria(
tokenizer: transformers.PreTrainedTokenizerBase,
stop_sequences: List[str],
batch_size: int,
) -> transformers.StoppingCriteriaList:
return transformers.StoppingCriteriaList([
*[
MultiTokenEOSCriteria(sequence, tokenizer, batch_size)
for sequence in stop_sequences
],
])
24 changes: 24 additions & 0 deletions llmfoundry/eval/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""A collection of common torchmetrics."""

from llmfoundry.eval.metrics.nlp import (
InContextLearningCodeEvalAccuracy,
InContextLearningGenerationExactMatchAccuracy, InContextLearningLMAccuracy,
InContextLearningLMExpectedCalibrationError,
InContextLearningMCExpectedCalibrationError, InContextLearningMetric,
InContextLearningMultipleChoiceAccuracy)
bmosaicml marked this conversation as resolved.
Show resolved Hide resolved

__all__ = [
'InContextLearningLMAccuracy',
'InContextLearningMultipleChoiceAccuracy',
'InContextLearningGenerationExactMatchAccuracy',
'InContextLearningMCExpectedCalibrationError',
'InContextLearningLMExpectedCalibrationError',
'InContextLearningMetric',
'InContextLearningCodeEvalAccuracy',
]
Loading
Loading