Skip to content

Commit

Permalink
add infernece
Browse files Browse the repository at this point in the history
  • Loading branch information
xrsrke committed Apr 5, 2023
1 parent 345238c commit 8923a8e
Show file tree
Hide file tree
Showing 10 changed files with 202 additions and 82 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.
<!-- [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![docs](https://img.shields.io/github/deployments/vwxyzjn/cleanrl/Production?label=docs&logo=vercel)](https://xrsrke.github.io/instructGOOSE/) -->

![image.png](index_files/figure-commonmark/e5f7b2fa-1-image.png)
![image.png](index_files/figure-commonmark/08f39f23-1-image.png)

Paper: [Toolformer: Language Models Can Teach Themselves to Use
Tools](https://arxiv.org/abs/2302.04761)
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
26 changes: 24 additions & 2 deletions nbs/01_utils.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@
"source": [
"#| export\n",
"import yaml\n",
"import re"
"import re\n",
"from typing import Optional"
]
},
{
Expand All @@ -70,7 +71,7 @@
"outputs": [],
"source": [
"#| export\n",
"def extract_api_request_content(text: str, api_name: str) -> str:\n",
"def extract_api_content(text: str, api_name: str) -> str:\n",
" \"\"\"Extract the content of an API request from a given text.\"\"\"\n",
" start_tag = f\"{api_name}(\"\n",
" end_tag = \")\"\n",
Expand All @@ -97,6 +98,27 @@
" matches = re.findall(pattern, text)\n",
" return matches"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"def extract_api_name(text: str, is_end_token: bool = True) -> Optional[str]:\n",
" if is_end_token:\n",
" pattern = r'\\[(\\w+)\\(.+\\]\\s?'\n",
" else:\n",
" pattern = r'\\[(\\w+)\\(.+\\s?'\n",
" \n",
" match = re.search(pattern, text)\n",
"\n",
" if match:\n",
" return match.group(1)\n",
" else:\n",
" return None"
]
}
],
"metadata": {
Expand Down
14 changes: 1 addition & 13 deletions nbs/03_api.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -101,19 +101,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'BaseAPI' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[1], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[39m#| export\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[39mclass\u001b[39;00m \u001b[39mCalculatorAPI\u001b[39;00m(BaseAPI):\n\u001b[1;32m 3\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mexecute\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39minput\u001b[39m: \u001b[39mstr\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mstr\u001b[39m:\n\u001b[1;32m 4\u001b[0m \u001b[39mtry\u001b[39;00m:\n",
"\u001b[0;31mNameError\u001b[0m: name 'BaseAPI' is not defined"
]
}
],
"outputs": [],
"source": [
"#| export\n",
"class CalculatorAPI(BaseAPI):\n",
Expand Down
87 changes: 50 additions & 37 deletions nbs/05_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/education/DATA/projects/ai/toolformer/env/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"outputs": [],
"source": [
"#| export\n",
"from typing import Optional, List\n",
Expand All @@ -63,23 +54,10 @@
"\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"from torchtyping import TensorType\n",
"from einops import rearrange\n",
"\n",
"from toolformer.api import BaseAPI\n",
"from toolformer.utils import extract_api_request_content"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# detect\n",
"# wait for end of token\n",
"# extract\n",
"# excute\n",
"# add the result to the input\n",
"# continue"
"from toolformer.utils import extract_api_content, extract_api_name"
]
},
{
Expand Down Expand Up @@ -113,20 +91,37 @@
" self.api_start_token_id = tokenizer(f' {start_character}', return_tensors=\"pt\")[\"input_ids\"][0]\n",
" self.api_end_token_id = tokenizer(end_character, return_tensors=\"pt\")[\"input_ids\"][0]\n",
" self.api_output_token_id = tokenizer(f'{output_character}', return_tensors=\"pt\")[\"input_ids\"][0]\n",
" \n",
"\n",
" self.eos_token_ids = tokenizer(\n",
" [\".\", \".\\n\\n\"],\n",
" return_tensors=\"pt\"\n",
" )[\"input_ids\"].squeeze()\n",
"\n",
" # TODO: support batch\n",
" self.api_request_content: torch.Tensor = torch.tensor([])\n",
" \n",
" def _sampling(self, probs: TensorType[\"batch_size\", \"seq_len\"]) -> TensorType[\"batch_size\", \"seq_len\"]:\n",
" return torch.argmax(probs, dim=-1)\n",
" \n",
" def execute_api(self, text_ids: TensorType[\"seq_len\"]) -> TensorType[\"seq_len\"]:\n",
" def execute_api(self, text_ids: TensorType[\"seq_len\"]) -> Optional[TensorType[\"seq_len\"]]:\n",
" \"\"\"Execute an API call.\"\"\"\n",
" # content_ids = extract_api_request_content(text_ids, self.apis)\n",
" pass\n",
" text = self.tokenizer.decode(text_ids, skip_special_tokens=True)\n",
" api_name = extract_api_name(text, is_end_token=False)\n",
"\n",
" if api_name is not None:\n",
" # find does apis contains the api_name\n",
" for api in self.apis:\n",
" if api.name == api_name:\n",
" api_content = extract_api_content(text, api_name=api_name)\n",
" api_output = api(api_content)\n",
" return self.tokenizer(api_output, return_tensors=\"pt\")[\"input_ids\"][0]\n",
" return None\n",
" \n",
" def add_idx_to_api_request_content(self, idx: TensorType[1]):\n",
" self.api_request_content = torch.cat([self.api_request_content, idx.unsqueeze(0)], dim=0)\n",
" self.api_request_content = torch.cat([\n",
" self.api_request_content,\n",
" rearrange(idx, '... -> 1 ...')\n",
" ], dim=-1).long()\n",
" \n",
" def forward(\n",
" self,\n",
Expand All @@ -136,7 +131,6 @@
" **kwargs\n",
" ) -> TensorType[\"batch_size\", \"seq_len\"]:\n",
" # check padding to the left\n",
" \n",
" generated_ids = input_ids\n",
" \n",
" for _ in range(max_new_tokens):\n",
Expand All @@ -148,15 +142,23 @@
" \n",
" logits = output_ids.logits[:, -1, :]\n",
" probs = F.softmax(logits, dim=-1)\n",
" _, top_k_idx = torch.topk(probs, k=5, dim=-1)\n",
" # TODO: k should be a config\n",
" _, top_k_idx = torch.topk(probs, k=1, dim=-1)\n",
" \n",
" if self.is_calling_api is True:\n",
" if self.api_end_token_id in top_k_idx:\n",
" # if the api end token is in the top_k_idx, then we will execute the api\n",
" # and then add api_end_token_id to the generated_ids\n",
" self.add_idx_to_api_request_content(self.api_end_token_id)\n",
" api_output_ids = self.execute_api(self.api_request_content)\n",
" pred_ids = torch.tensor([self.api_end_token_id, api_output_ids])\n",
" # TODO: add support batch\n",
" api_output_ids = self.execute_api(self.api_request_content[0])\n",
" if api_output_ids is not None:\n",
" pred_ids = torch.cat([\n",
" self.api_output_token_id,\n",
" api_output_ids,\n",
" self.api_end_token_id\n",
" ], dim=-1).long()\n",
" else:\n",
" pred_ids = self.api_end_token_id\n",
" self.is_calling_api = False\n",
" else:\n",
" pred_ids = self._sampling(probs)\n",
Expand All @@ -170,8 +172,19 @@
" else:\n",
" pred_ids = self._sampling(probs)\n",
" \n",
" generated_ids = torch.cat([generated_ids, pred_ids.unsqueeze(dim=1)], dim=1)\n",
" attention_mask = torch.cat([attention_mask, torch.ones_like(pred_ids).unsqueeze(dim=1)], dim=1)\n",
" generated_ids = torch.cat([\n",
" generated_ids,\n",
" rearrange(pred_ids, '... -> 1 ...')\n",
" ], dim=1)\n",
" \n",
" attention_mask = torch.cat([\n",
" attention_mask,\n",
" rearrange(torch.ones_like(pred_ids), '... -> 1 ...')\n",
" ], dim=1)\n",
" \n",
" # ignore the case that pred_ids contains api_output\n",
" if len(pred_ids) == 1 and pred_ids in self.eos_token_ids:\n",
" break\n",
" \n",
" return generated_ids"
]
Expand Down
45 changes: 38 additions & 7 deletions tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,53 @@
import torch
import pytest
from langchain import PromptTemplate

from toolformer.model import ToolFormer
from toolformer.api import BaseAPI
from toolformer.prompt import calculator_prompt

@pytest.mark.skip(reason="haven't implemented yet")

class CalculatorAPI(BaseAPI):
def __call__(self, text):
return str(4269)


calculator_api = CalculatorAPI(
name="Calculator",
prompt_template=calculator_prompt
)


# @pytest.mark.skip(reason="haven't implemented yet")
def test_inference(model, tokenizer, default_config):
text = "What is the sum of 42 and 69?"
target_output = 111
text = "From this, we have 10 - 5 minutes = 5 minutes."

encoded_text = tokenizer(text, return_tensors="pt")
toolformer = ToolFormer(model, apis=[], config=default_config)
# After fine-tune a model with augmented data,
# the model should be able to call the API without few-shot learning
prompt_template = PromptTemplate(
input_variables=["input"],
template=calculator_prompt
)
input = prompt_template.format(input=text)
target_output = str(4269) # from the calculator API

encoded_text = tokenizer(input, return_tensors="pt")
toolformer = ToolFormer(
model,
apis=[calculator_api],
config=default_config
)

output_ids = toolformer(
input_ids=encoded_text["input_ids"],
attention_mask=encoded_text["attention_mask"]
attention_mask=encoded_text["attention_mask"],
max_new_tokens=30,
)

assert isinstance(output_ids, torch.Tensor)
assert output_ids.ndim == 2
assert output_ids[0].shape[-1] > len(encoded_text["input_ids"][0])
assert target_output in tokenizer.decode(output_ids[0], skip_special_tokens=True)
assert target_output in tokenizer.decode(
output_ids[0],
skip_special_tokens=True
)
23 changes: 19 additions & 4 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,26 @@
from toolformer.utils import extract_api_request_content
import pytest

def test_extract_api_request_content():
from toolformer.utils import extract_api_content, extract_api_name

def test_extract_api_content():
text = "From this, we have 10 - 5 minutes = [Calculator(10 - 5)] 5 minutes."
# text = "From this, we have 10 - 5 minutes = [Calculator((2+3) - 1)] 5 minutes." # TODO: add test case for this
target = "10 - 5"

output = extract_api_request_content(text, api_name = "Calculator")
output = extract_api_content(text, api_name="Calculator")

assert isinstance(output, str)
assert output == target

@pytest.mark.parametrize(
"text, is_end_token, target",
[
("From this, we have 10 - 5 minutes = [Calculator(10 - 5)] 5 minutes.", True, "Calculator"),
("[Calculator(10 - 5)", False, "Calculator"),
],
)
def test_extract_api_name(text, is_end_token, target):
output = extract_api_name(text, is_end_token=is_end_token)

assert isinstance(output, str)
assert output == target
assert output == target
4 changes: 2 additions & 2 deletions toolformer/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
'toolformer.model.ToolFormer.execute_api': ('model.html#toolformer.execute_api', 'toolformer/model.py'),
'toolformer.model.ToolFormer.forward': ('model.html#toolformer.forward', 'toolformer/model.py')},
'toolformer.prompt': {},
'toolformer.utils': { 'toolformer.utils.extract_api_request_content': ( 'utils.html#extract_api_request_content',
'toolformer/utils.py'),
'toolformer.utils': { 'toolformer.utils.extract_api_content': ('utils.html#extract_api_content', 'toolformer/utils.py'),
'toolformer.utils.extract_api_name': ('utils.html#extract_api_name', 'toolformer/utils.py'),
'toolformer.utils.extract_api_syntax': ('utils.html#extract_api_syntax', 'toolformer/utils.py'),
'toolformer.utils.yaml2dict': ('utils.html#yaml2dict', 'toolformer/utils.py')}}}
Loading

0 comments on commit 8923a8e

Please sign in to comment.