-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
xrsrke
committed
Apr 5, 2023
1 parent
345238c
commit 8923a8e
Showing
10 changed files
with
202 additions
and
82 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,53 @@ | ||
import torch | ||
import pytest | ||
from langchain import PromptTemplate | ||
|
||
from toolformer.model import ToolFormer | ||
from toolformer.api import BaseAPI | ||
from toolformer.prompt import calculator_prompt | ||
|
||
@pytest.mark.skip(reason="haven't implemented yet") | ||
|
||
class CalculatorAPI(BaseAPI): | ||
def __call__(self, text): | ||
return str(4269) | ||
|
||
|
||
calculator_api = CalculatorAPI( | ||
name="Calculator", | ||
prompt_template=calculator_prompt | ||
) | ||
|
||
|
||
# @pytest.mark.skip(reason="haven't implemented yet") | ||
def test_inference(model, tokenizer, default_config): | ||
text = "What is the sum of 42 and 69?" | ||
target_output = 111 | ||
text = "From this, we have 10 - 5 minutes = 5 minutes." | ||
|
||
encoded_text = tokenizer(text, return_tensors="pt") | ||
toolformer = ToolFormer(model, apis=[], config=default_config) | ||
# After fine-tune a model with augmented data, | ||
# the model should be able to call the API without few-shot learning | ||
prompt_template = PromptTemplate( | ||
input_variables=["input"], | ||
template=calculator_prompt | ||
) | ||
input = prompt_template.format(input=text) | ||
target_output = str(4269) # from the calculator API | ||
|
||
encoded_text = tokenizer(input, return_tensors="pt") | ||
toolformer = ToolFormer( | ||
model, | ||
apis=[calculator_api], | ||
config=default_config | ||
) | ||
|
||
output_ids = toolformer( | ||
input_ids=encoded_text["input_ids"], | ||
attention_mask=encoded_text["attention_mask"] | ||
attention_mask=encoded_text["attention_mask"], | ||
max_new_tokens=30, | ||
) | ||
|
||
assert isinstance(output_ids, torch.Tensor) | ||
assert output_ids.ndim == 2 | ||
assert output_ids[0].shape[-1] > len(encoded_text["input_ids"][0]) | ||
assert target_output in tokenizer.decode(output_ids[0], skip_special_tokens=True) | ||
assert target_output in tokenizer.decode( | ||
output_ids[0], | ||
skip_special_tokens=True | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,26 @@ | ||
from toolformer.utils import extract_api_request_content | ||
import pytest | ||
|
||
def test_extract_api_request_content(): | ||
from toolformer.utils import extract_api_content, extract_api_name | ||
|
||
def test_extract_api_content(): | ||
text = "From this, we have 10 - 5 minutes = [Calculator(10 - 5)] 5 minutes." | ||
# text = "From this, we have 10 - 5 minutes = [Calculator((2+3) - 1)] 5 minutes." # TODO: add test case for this | ||
target = "10 - 5" | ||
|
||
output = extract_api_request_content(text, api_name = "Calculator") | ||
output = extract_api_content(text, api_name="Calculator") | ||
|
||
assert isinstance(output, str) | ||
assert output == target | ||
|
||
@pytest.mark.parametrize( | ||
"text, is_end_token, target", | ||
[ | ||
("From this, we have 10 - 5 minutes = [Calculator(10 - 5)] 5 minutes.", True, "Calculator"), | ||
("[Calculator(10 - 5)", False, "Calculator"), | ||
], | ||
) | ||
def test_extract_api_name(text, is_end_token, target): | ||
output = extract_api_name(text, is_end_token=is_end_token) | ||
|
||
assert isinstance(output, str) | ||
assert output == target | ||
assert output == target |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.