Merge pull request #264 from gkumbhat/add_global_training_data_limit
Update training data validation to consider global and module level defaults
gkumbhat authored Nov 7, 2023
2 parents 4a5b2f8 + b387a77 commit 4be54cf
Showing 5 changed files with 166 additions and 13 deletions.
1 change: 1 addition & 0 deletions caikit_nlp/config/config.yml
@@ -31,6 +31,7 @@ master_addr: localhost
 master_port: 29550
 
 training_data_limit:
+  __default__: -1
   # Configuration for PeftPromptTuning module
   6655831b-960a-4dc5-8df4-867026e2cd41:
     add_model_name_here: 10000
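The lookup order this change introduces is: model-specific limit, then the module's `__default__`, then the global `__default__`, with values at or below zero disabling the check. A minimal sketch of that resolution, using a plain dict in place of caikit's config object (the UUID key is the PeftPromptTuning MODULE_ID from the config above; the model names are hypothetical):

# Plain-dict stand-in for get_config().training_data_limit
training_data_limit = {
    "__default__": -1,  # global default: <= 0 disables the size check
    "6655831b-960a-4dc5-8df4-867026e2cd41": {
        "__default__": 10000,   # module-level default for this module
        "some-model": 2048,     # model-specific override (hypothetical name)
    },
}


def resolve_limit(module_id: str, model_name: str) -> int:
    """Mirror the lookup in validate_training_data: model -> module -> global."""
    module_cfg = training_data_limit.get(module_id, {})
    module_default = module_cfg.get("__default__", training_data_limit["__default__"])
    return module_cfg.get(model_name, module_default)


assert resolve_limit("6655831b-960a-4dc5-8df4-867026e2cd41", "some-model") == 2048
assert resolve_limit("6655831b-960a-4dc5-8df4-867026e2cd41", "other-model") == 10000
assert resolve_limit("unrelated-module-id", "any-model") == -1  # falls to global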
17 changes: 5 additions & 12 deletions caikit_nlp/modules/text_generation/peft_prompt_tuning.py
@@ -41,7 +41,6 @@
 import transformers
 
 # First Party
-from caikit import get_config
 from caikit.core.data_model import DataStream
 from caikit.core.exceptions import error_handler
 from caikit.core.modules import ModuleBase, ModuleConfig, ModuleSaver, module
@@ -73,6 +72,7 @@
     generate_text_func,
     generate_text_func_stream,
 )
+from ...toolkit.trainer_utils import validate_training_data
 from ...toolkit.verbalizer_utils import render_verbalizer
 from .peft_config import TuningType, get_peft_config, resolve_base_model
 
@@ -368,19 +368,12 @@ def train(
         )
 
         # Check if data is within limit allowed for this module and model
-        max_num_examples = (
-            get_config()
-            .training_data_limit.get(cls.MODULE_ID, {})
-            .get(base_model_name, -1)
+        validate_training_data(
+            train_stream,
+            base_model_name,
+            cls.MODULE_ID,
         )
-
-        if max_num_examples > 0:
-            error.value_check(
-                "<NLP77627434E>",
-                len(train_stream) <= max_num_examples,
-                "Number of examples larger than maximum number of examples allowed for this model",
-            )
 
         # Coerce the passed model into a resource; if we have one, this is a noop
         # TODO: When splitting up this mono-module, use the configured resource
         # type of the concrete class to bootstrap
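With the inline check replaced by the shared helper, other training modules can enforce the same limits with a single call. A rough sketch of what that reuse could look like in another module; only the call signature is taken from this diff, the class and module ID are hypothetical:

# Hypothetical module reusing the shared validator
from caikit_nlp.toolkit.trainer_utils import validate_training_data


class SomeTrainableModule:
    MODULE_ID = "00000000-0000-0000-0000-000000000000"  # placeholder ID

    @classmethod
    def train(cls, base_model_name, train_stream, **kwargs):
        # Raises ValueError via error.value_check when the stream is too large
        validate_training_data(train_stream, base_model_name, cls.MODULE_ID)
        # ... rest of training ...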
27 changes: 27 additions & 0 deletions caikit_nlp/toolkit/trainer_utils.py
@@ -19,9 +19,36 @@
 import torch
 
 # First Party
+from caikit import get_config
+from caikit.core.data_model import DataStream
 from caikit.core.exceptions import error_handler
 import alog
 
+log = alog.use_channel("TRNR_UTILS")
+error = error_handler.get(log)
+
+
+def validate_training_data(train_stream: DataStream, model_name: str, module_id: str):
+    """Check training data size against the model, module, and global limits."""
+    global_default = get_config().training_data_limit.__default__
+    module_default = (
+        get_config()
+        .training_data_limit.get(module_id, {})
+        .get("__default__", global_default)
+    )
+
+    max_num_examples = (
+        get_config()
+        .training_data_limit.get(module_id, {})
+        .get(model_name, module_default)
+    )
+
+    if max_num_examples > 0:
+        error.value_check(
+            "<NLP77627434E>",
+            len(train_stream) <= max_num_examples,
+            "Number of examples larger than maximum number of examples allowed for this model",
+        )
+
+
 def log_step(state, logs):
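A hedged sketch of how validate_training_data could be exercised in isolation, stubbing out get_config with a minimal attribute-accessible mapping; AttrDict and the limit values below are invented for the example, and caikit's real config object behaves similarly:

import pytest

import caikit_nlp.toolkit.trainer_utils as trainer_utils


class AttrDict(dict):
    """Minimal stand-in for caikit's attribute-accessible config mapping."""

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc


def test_validate_training_data_raises_over_limit(monkeypatch):
    # Global default of 1 example, no module/model overrides (invented values)
    fake_config = AttrDict(training_data_limit=AttrDict({"__default__": 1}))
    monkeypatch.setattr(trainer_utils, "get_config", lambda: fake_config)

    # Two records against a limit of 1: error.value_check raises ValueError
    with pytest.raises(ValueError):
        trainer_utils.validate_training_data(["a", "b"], "any-model", "any-module")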
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -14,7 +14,7 @@ classifiers=[
   "License :: OSI Approved :: Apache Software License"
 ]
 dependencies = [
-    "caikit[runtime-grpc,runtime-http]>=0.23.2,<0.25.0",
+    "caikit[runtime-grpc,runtime-http]>=0.24.0,<0.25.0",
     "caikit-tgis-backend>=0.1.17,<0.2.0",
     # TODO: loosen dependencies
     "accelerate>=0.22.0",
132 changes: 132 additions & 0 deletions tests/modules/text_generation/test_peft_prompt_tuning.py
@@ -505,3 +505,135 @@ def test_train_with_no_limit_for_module(causal_lm_train_kwargs, set_cpu_device):
 
     model = module.train(**causal_lm_train_kwargs)
     assert model
+
+
+def test_train_module_level_data_validation_raises(
+    causal_lm_train_kwargs, set_cpu_device
+):
+    """Check that train raises with a module-level default limit
+    when the training data exceeds it and no model-specific limit matches
+    """
+    patch_kwargs = {
+        "num_epochs": 1,
+        "verbalizer": "Tweet text : {{input}} Label : ",
+        "train_stream": caikit.core.data_model.DataStream.from_iterable(
+            [
+                ClassificationTrainRecord(
+                    text="@foo what a cute dog!", labels=["no complaint"]
+                ),
+                ClassificationTrainRecord(
+                    text="@bar this is the worst idea ever.", labels=["complaint"]
+                ),
+            ]
+        ),
+        "torch_dtype": torch.bfloat16,
+        "device": "cpu",
+    }
+    causal_lm_train_kwargs.update(patch_kwargs)
+
+    module = caikit_nlp.modules.text_generation.PeftPromptTuning
+    with temp_config(
+        training_data_limit={module.MODULE_ID: {"__default__": 1, "foo": 2}}
+    ):
+        with pytest.raises(ValueError):
+            module.train(**causal_lm_train_kwargs)
+
+
+def test_train_module_level_data_validation_success(
+    causal_lm_train_kwargs, set_cpu_device
+):
+    """Check that train succeeds with a module-level default limit
+    when a model-specific limit is configured and the training data is within it
+    """
+    patch_kwargs = {
+        "num_epochs": 1,
+        "verbalizer": "Tweet text : {{input}} Label : ",
+        "train_stream": caikit.core.data_model.DataStream.from_iterable(
+            [
+                ClassificationTrainRecord(
+                    text="@foo what a cute dog!", labels=["no complaint"]
+                ),
+                ClassificationTrainRecord(
+                    text="@bar this is the worst idea ever.", labels=["complaint"]
+                ),
+            ]
+        ),
+        "torch_dtype": torch.bfloat16,
+        "device": "cpu",
+    }
+    causal_lm_train_kwargs.update(patch_kwargs)
+
+    model_name = causal_lm_train_kwargs["base_model"]._model_name
+    module = caikit_nlp.modules.text_generation.PeftPromptTuning
+    with temp_config(
+        training_data_limit={module.MODULE_ID: {"__default__": 1, model_name: 2}}
+    ):
+
+        model = module.train(**causal_lm_train_kwargs)
+        assert model
+
+
+def test_train_global_default_data_validation_raises(
+    causal_lm_train_kwargs, set_cpu_device
+):
+    """Check that train raises with a global default limit
+    when the training data exceeds it and no model-specific limit matches
+    """
+    patch_kwargs = {
+        "num_epochs": 1,
+        "verbalizer": "Tweet text : {{input}} Label : ",
+        "train_stream": caikit.core.data_model.DataStream.from_iterable(
+            [
+                ClassificationTrainRecord(
+                    text="@foo what a cute dog!", labels=["no complaint"]
+                ),
+                ClassificationTrainRecord(
+                    text="@bar this is the worst idea ever.", labels=["complaint"]
+                ),
+            ]
+        ),
+        "torch_dtype": torch.bfloat16,
+        "device": "cpu",
+    }
+    causal_lm_train_kwargs.update(patch_kwargs)
+
+    module = caikit_nlp.modules.text_generation.PeftPromptTuning
+    with temp_config(
+        training_data_limit={"__default__": 1, module.MODULE_ID: {"foo": 2}}
+    ):
+        with pytest.raises(ValueError):
+            module.train(**causal_lm_train_kwargs)
+
+
+def test_train_global_default_data_validation_success(
+    causal_lm_train_kwargs, set_cpu_device
+):
+    """Check that train succeeds with a global default limit
+    when a model-specific limit is configured and the training data is within it
+    """
+    patch_kwargs = {
+        "num_epochs": 1,
+        "verbalizer": "Tweet text : {{input}} Label : ",
+        "train_stream": caikit.core.data_model.DataStream.from_iterable(
+            [
+                ClassificationTrainRecord(
+                    text="@foo what a cute dog!", labels=["no complaint"]
+                ),
+                ClassificationTrainRecord(
+                    text="@bar this is the worst idea ever.", labels=["complaint"]
+                ),
+            ]
+        ),
+        "torch_dtype": torch.bfloat16,
+        "device": "cpu",
+    }
+    causal_lm_train_kwargs.update(patch_kwargs)
+
+    model_name = causal_lm_train_kwargs["base_model"]._model_name
+    module = caikit_nlp.modules.text_generation.PeftPromptTuning
+    with temp_config(
+        training_data_limit={"__default__": 1, module.MODULE_ID: {model_name: 2}}
+    ):
+
+        model = module.train(**causal_lm_train_kwargs)
+        assert model
