diff --git a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
index 5642c824..91f747c9 100644
--- a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
+++ b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
@@ -48,8 +48,9 @@
     ClassificationTrainRecord,
     GeneratedTextResult,
     GeneratedTextStreamResult,
+    TokenizationResults,
 )
-from caikit.interfaces.nlp.tasks import TextGenerationTask
+from caikit.interfaces.nlp.tasks import TextGenerationTask, TokenizationTask
 import alog

 # Local
@@ -87,7 +88,7 @@
     id="6655831b-960a-4dc5-8df4-867026e2cd41",
     name="Peft generation",
     version="0.1.0",
-    task=TextGenerationTask,
+    tasks=[TextGenerationTask, TokenizationTask],
 )
 class PeftPromptTuning(ModuleBase):
@@ -274,6 +275,22 @@ def run_stream_out(
             stop_sequences=stop_sequences,
         )

+    @TokenizationTask.taskmethod()
+    def run_tokenizer(
+        self,
+        text: str,
+    ) -> TokenizationResults:
+        """Run tokenization task against the model
+
+        Args:
+            text: str
+                Text to tokenize
+        Returns:
+            TokenizationResults
+                The token count
+        """
+        raise NotImplementedError("Tokenization not implemented for local")
+
     @classmethod
     def train(
         cls,
diff --git a/caikit_nlp/modules/text_generation/text_generation_local.py b/caikit_nlp/modules/text_generation/text_generation_local.py
index e0558d4a..290551ba 100644
--- a/caikit_nlp/modules/text_generation/text_generation_local.py
+++ b/caikit_nlp/modules/text_generation/text_generation_local.py
@@ -31,8 +31,8 @@
 from caikit.core.data_model import DataStream
 from caikit.core.exceptions import error_handler
 from caikit.core.modules import ModuleBase, ModuleConfig, ModuleSaver, module
-from caikit.interfaces.nlp.data_model import GeneratedTextResult
-from caikit.interfaces.nlp.tasks import TextGenerationTask
+from caikit.interfaces.nlp.data_model import GeneratedTextResult, TokenizationResults
+from caikit.interfaces.nlp.tasks import TextGenerationTask, TokenizationTask
 import alog

 # Local
@@ -60,7 +60,7 @@
     id="f9181353-4ccf-4572-bd1e-f12bcda26792",
     name="Text Generation",
     version="0.1.0",
-    task=TextGenerationTask,
+    tasks=[TextGenerationTask, TokenizationTask],
 )
 class TextGeneration(ModuleBase):
     """Module to provide text generation capabilities"""
@@ -521,6 +521,7 @@ def save(self, model_path):
                 json.dump(loss_log, f)
                 f.write("\n")

+    @TextGenerationTask.taskmethod()
     def run(
         self,
         text: str,
@@ -575,6 +576,22 @@ def run(
             **kwargs,
         )

+    @TokenizationTask.taskmethod()
+    def run_tokenizer(
+        self,
+        text: str,
+    ) -> TokenizationResults:
+        """Run tokenization task against the model
+
+        Args:
+            text: str
+                Text to tokenize
+        Returns:
+            TokenizationResults
+                The token count
+        """
+        raise NotImplementedError("Tokenization not implemented for local")
+
     ################################## Private Functions ######################################

     @staticmethod
diff --git a/caikit_nlp/toolkit/torch_run.py b/caikit_nlp/toolkit/torch_run.py
index 3a8879c8..43184f2d 100644
--- a/caikit_nlp/toolkit/torch_run.py
+++ b/caikit_nlp/toolkit/torch_run.py
@@ -24,7 +24,8 @@

 # Third Party
 from torch import cuda
-from torch.distributed.launcher.api import LaunchConfig, Std
+from torch.distributed.elastic.multiprocessing.api import Std
+from torch.distributed.launcher.api import LaunchConfig
 import torch.distributed as dist

 # First Party
diff --git a/pyproject.toml b/pyproject.toml
index e1ce63ec..bcb82db0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,7 +26,7 @@ dependencies = [
     "scipy>=1.8.1",
"sentence-transformers>=2.3.1,<2.4.0", "tokenizers>=0.13.3", - "torch>=2.0.1", + "torch>=2.0.1,<2.3.0", "tqdm>=4.65.0", "transformers>=4.32.0", "peft==0.6.0", diff --git a/tests/modules/text_generation/test_peft_prompt_tuning.py b/tests/modules/text_generation/test_peft_prompt_tuning.py index 5cf82439..74360d36 100644 --- a/tests/modules/text_generation/test_peft_prompt_tuning.py +++ b/tests/modules/text_generation/test_peft_prompt_tuning.py @@ -429,6 +429,14 @@ def test_run_exponential_decay_len_penatly_object(causal_lm_dummy_model): assert isinstance(pred, GeneratedTextResult) +def test_run_tokenizer_not_implemented(causal_lm_dummy_model): + with pytest.raises(NotImplementedError): + causal_lm_dummy_model.run_tokenizer("This text doesn't matter") + + +######################## Test train ############################################### + + def test_train_with_data_validation_raises(causal_lm_train_kwargs, set_cpu_device): """Check if we are able to throw error for when number of examples are more than configured limit""" patch_kwargs = { diff --git a/tests/modules/text_generation/test_text_generation_local.py b/tests/modules/text_generation/test_text_generation_local.py index 5afce777..f76400f1 100644 --- a/tests/modules/text_generation/test_text_generation_local.py +++ b/tests/modules/text_generation/test_text_generation_local.py @@ -209,3 +209,9 @@ def test_zero_epoch_case(disable_wip): } model = TextGeneration.train(**train_kwargs) assert isinstance(model.model, HFAutoSeq2SeqLM) + + +def test_run_tokenizer_not_implemented(): + with pytest.raises(NotImplementedError): + model = TextGeneration.bootstrap(SEQ2SEQ_LM_MODEL) + model.run_tokenizer("This text doesn't matter")