Skip to content

Commit

Permalink
Merge pull request caikit#262 from dtrifiro/improve-test-times
Browse files Browse the repository at this point in the history
tests: make models fixtures session-scoped
  • Loading branch information
gkumbhat authored Nov 9, 2023
2 parents 4be54cf + 84a7924 commit 2682ef3
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 10 deletions.
10 changes: 5 additions & 5 deletions tests/fixtures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def models_cache_dir(request):

### Fixtures for grabbing a randomly initialized model to test interfaces against
## Causal LM
@pytest.fixture
@pytest.fixture(scope="session")
def causal_lm_train_kwargs():
"""Get the kwargs for a valid train call to a Causal LM."""
model_kwargs = {
Expand All @@ -124,15 +124,15 @@ def causal_lm_train_kwargs():
return model_kwargs


@pytest.fixture
@pytest.fixture(scope="session")
def causal_lm_dummy_model(causal_lm_train_kwargs):
"""Train a Causal LM dummy model."""
return caikit_nlp.modules.text_generation.PeftPromptTuning.train(
**causal_lm_train_kwargs
)


@pytest.fixture
@pytest.fixture(scope="session")
def saved_causal_lm_dummy_model(causal_lm_dummy_model):
"""Give a path to a saved dummy model that can be loaded"""
with tempfile.TemporaryDirectory() as workdir:
Expand All @@ -142,7 +142,7 @@ def saved_causal_lm_dummy_model(causal_lm_dummy_model):


## Seq2seq
@pytest.fixture
@pytest.fixture(scope="session")
def seq2seq_lm_train_kwargs():
"""Get the kwargs for a valid train call to a Causal LM."""
model_kwargs = {
Expand All @@ -158,7 +158,7 @@ def seq2seq_lm_train_kwargs():
return model_kwargs


@pytest.fixture
@pytest.fixture(scope="session")
def seq2seq_lm_dummy_model(seq2seq_lm_train_kwargs):
"""Train a Seq2Seq LM dummy model."""
return caikit_nlp.modules.text_generation.PeftPromptTuning.train(
Expand Down
14 changes: 9 additions & 5 deletions tests/modules/text_generation/test_peft_prompt_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,17 @@ def test_run_stream_out_model(causal_lm_dummy_model):
assert isinstance(pred, GeneratedTextStreamResult)


def test_verbalizer_rendering(causal_lm_dummy_model):
def test_verbalizer_rendering(causal_lm_dummy_model, monkeypatch):
"""Ensure that our model renders its verbalizer text correctly before calling tokenizer."""
# Mock the tokenizer; we want to make sure its inputs are rendered properly
causal_lm_dummy_model.tokenizer = mock.Mock(
side_effect=RuntimeError("Tokenizer is a mock!"),
# Set eos token property to be attribute of tokenizer
eos_token="</s>",
monkeypatch.setattr(
causal_lm_dummy_model,
"tokenizer",
mock.Mock(
side_effect=RuntimeError("Tokenizer is a mock!"),
# Set eos token property to be attribute of tokenizer
eos_token="</s>",
),
)
input_text = "This text doesn't matter"
causal_lm_dummy_model.verbalizer = " | {{input}} |"
Expand Down

0 comments on commit 2682ef3

Please sign in to comment.