
Merge branch 'main' into shashank/flexattention
ShashankMosaicML authored Dec 17, 2024
2 parents 718d89d + 3269c73 commit 135abd7
Showing 6 changed files with 103 additions and 2 deletions.
28 changes: 28 additions & 0 deletions llmfoundry/data/finetuning/tasks.py
@@ -1161,3 +1161,31 @@ def shareGPT_format_preprocessor(inp: dict) -> ChatFormattedDict:
    except Exception as e:
        raise UnableToProcessPromptResponseError(inp) from e
    return {'messages': messages}


@dataset_constructor.register('math-ai/StackMathQA')
def QA_format_preprocessor(inp: dict) -> ChatFormattedDict:
    """Convert from QA format to our chat format."""
    try:
        Q = inp['Q']
        A = inp['A']
        messages: list[dict[str, str]] = [{
            'role': 'user',
            'content': Q,
        }, {
            'role': 'assistant',
            'content': A,
        }]
    except Exception as e:
        raise UnableToProcessPromptResponseError(inp) from e
    return {'messages': messages}


@dataset_constructor.register('AI-MO/NuminaMath-CoT')
def messages_format_preprocessor(inp: dict) -> ChatFormattedDict:
    """Convert from messages format to our chat format."""
    try:
        messages = inp['messages']
    except Exception as e:
        raise UnableToProcessPromptResponseError(inp) from e
    return {'messages': messages}
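
For reference, a minimal sketch of how the two new preprocessors behave on sample rows; the inputs below are illustrative stand-ins, not actual StackMathQA or NuminaMath-CoT records:

```python
from llmfoundry.data.finetuning.tasks import (
    QA_format_preprocessor,
    messages_format_preprocessor,
)

# QA-style row: separate question and answer fields; extra keys are ignored.
qa_row = {'Q': 'What is 2 + 2?', 'A': '4', 'meta': {'source': 'example'}}
print(QA_format_preprocessor(qa_row))
# {'messages': [{'role': 'user', 'content': 'What is 2 + 2?'},
#               {'role': 'assistant', 'content': '4'}]}

# Chat-style row: already carries a 'messages' list; other keys are dropped.
chat_row = {
    'messages': [{'role': 'user', 'content': 'Hi'},
                 {'role': 'assistant', 'content': 'Hello!'}],
    'source': 'example',
}
print(messages_format_preprocessor(chat_row))
# {'messages': [{'role': 'user', 'content': 'Hi'},
#               {'role': 'assistant', 'content': 'Hello!'}]}
```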
1 change: 1 addition & 0 deletions llmfoundry/models/hf/hf_base.py
@@ -392,6 +392,7 @@ def build_inner_model(
            model = PeftModelForCausalLM.from_pretrained(
                model,
                pretrained_lora_id_or_path,
                is_trainable=True,
            )

        if prepare_for_fsdp:
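
The `is_trainable=True` argument is the substance of this change: `PeftModelForCausalLM.from_pretrained` defaults to `is_trainable=False`, which loads the LoRA adapter frozen for inference, so resuming finetuning from a pretrained adapter would train zero parameters. A minimal sketch of the difference, reusing the model/adapter pair from the test added below in this commit:

```python
from peft import PeftModelForCausalLM
from transformers import AutoModelForCausalLM

# Default (is_trainable=False): adapter weights load with requires_grad=False.
base = AutoModelForCausalLM.from_pretrained('facebook/opt-350m')
frozen = PeftModelForCausalLM.from_pretrained(base, 'ybelkada/opt-350m-lora')
print(frozen.get_nb_trainable_parameters())  # (0, <total param count>)

# With is_trainable=True, the LoRA parameters stay trainable.
base = AutoModelForCausalLM.from_pretrained('facebook/opt-350m')
trainable = PeftModelForCausalLM.from_pretrained(
    base,
    'ybelkada/opt-350m-lora',
    is_trainable=True,
)
print(trainable.get_nb_trainable_parameters())  # (<LoRA param count>, <total>)
```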
4 changes: 3 additions & 1 deletion llmfoundry/utils/config_utils.py
@@ -8,6 +8,7 @@
import os
import warnings
from dataclasses import dataclass, fields
from pathlib import Path
from typing import (
    Any,
    Callable,
@@ -703,6 +704,8 @@ def _process_data_source(
        true_split (str): The split of the dataset to be added (i.e. train or eval)
        data_paths (List[Tuple[str, str, str]]): A list of tuples formatted as (data type, path, split)
    """
    if source_dataset_path:
        source_dataset_path = str(Path(source_dataset_path))
    # Check for Delta table
    if source_dataset_path and len(source_dataset_path.split('.')) == 3:
        data_paths.append(('delta_table', source_dataset_path, true_split))
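
The new `str(Path(...))` round-trip normalizes the path string before it is classified; presumably the intent is that trailing or doubled separators no longer affect checks such as the dot-count test for Delta tables. A small illustration of what the normalization does:

```python
from pathlib import Path

for raw in ('catalog.schema.table', 'my/dataset/', 'my//dataset'):
    print(str(Path(raw)))
# catalog.schema.table
# my/dataset
# my/dataset
```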
@@ -788,7 +791,6 @@ def log_dataset_uri(cfg: dict[str, Any]) -> None:

    # Map data source types to their respective MLFlow DataSource.
    for dataset_type, path, split in data_paths:

        if dataset_type in dataset_source_mapping:
            source_class = dataset_source_mapping[dataset_type]
            if dataset_type == 'delta_table':
2 changes: 1 addition & 1 deletion setup.py
@@ -58,7 +58,7 @@
    'transformers>=4.43.2,<4.47',
    'mosaicml-streaming>=0.10.0,<0.11',
    'torch>=2.5.1,<2.5.2',
-   'datasets>=2.20.0,<2.21',
+   'datasets>=2.20.0,<3.2',
    'fsspec==2023.6.0',  # newer version results in a bug in datasets that duplicates data
    'sentencepiece==0.2.0',
    'einops==0.8.0',
47 changes: 47 additions & 0 deletions tests/data/test_dataset.py
@@ -8,8 +8,10 @@
import pytest

from llmfoundry.data.finetuning.tasks import (
    QA_format_preprocessor,
    _get_num_processes,
    dataset_constructor,
    messages_format_preprocessor,
)
from llmfoundry.utils.exceptions import DatasetTooSmallError

@@ -60,3 +62,48 @@ def get_local_world_size(self):
        new=MockDataset,
    ):
        dataset_constructor.build_from_streaming()


def test_QA_format_preprocessor():
    inp = {
        'Q': 'What is the capital of France?',
        'A': 'Paris',
        'meta': {
            'a': 'b',
        },
    }

    expected_messages = [{
        'role': 'user',
        'content': 'What is the capital of France?',
    }, {
        'role': 'assistant',
        'content': 'Paris',
    }]
    output = QA_format_preprocessor(inp)
    assert len(output) == 1
    assert 'messages' in output
    for i, message in enumerate(output['messages']):
        expected_message = expected_messages[i]
        for k, v in message.items():
            assert k in expected_message
            assert v == expected_message[k]


def test_messages_format_preprocessor():
    messages = [{
        'role': 'user',
        'content': 'What is the capital of France?',
    }, {
        'role': 'assistant',
        'content': 'Paris',
    }]
    inp = {
        'messages': messages,
        'other_key': 'other_value',
    }

    output = messages_format_preprocessor(inp)
    assert len(output) == 1
    assert 'messages' in output
    assert output['messages'] == messages
23 changes: 23 additions & 0 deletions tests/models/hf/test_hf_base.py
@@ -1,6 +1,8 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from peft import PeftModel

from llmfoundry.models.hf.hf_base import BaseHuggingFaceModel


@@ -23,3 +25,24 @@ def test_build_inner_model_fsdp():
    )

    assert model.fsdp_wrap_fn(model.model.layers[0])


def test_pretrained_peft_trainable():
    model = BaseHuggingFaceModel.build_inner_model(
        pretrained_model_name_or_path='facebook/opt-350m',
        pretrained_lora_id_or_path='ybelkada/opt-350m-lora',
        trust_remote_code=False,
        init_device='cpu',
        use_flash_attention_2=False,
        use_auth_token=False,
        config_overrides={},
        load_in_8bit=False,
        pretrained=True,
        prepare_for_fsdp=True,
    )

    assert isinstance(model, PeftModel)

    n_trainable, n_all = model.get_nb_trainable_parameters()
    assert n_all > 0
    assert n_trainable > 0
