Skip to content

Commit

Permalink
release v0.2.3
Browse files Browse the repository at this point in the history
  • Loading branch information
hiyouga committed Feb 24, 2024
1 parent 0e64c7b commit ab95159
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 15 deletions.
6 changes: 2 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
numpy
sse-starlette
infinity-emb[torch]==0.0.17
openai>=1.5.0
transformers>=4.37.2
vllm>=0.3.0
sse-starlette
vllm==0.3.2
2 changes: 1 addition & 1 deletion src/imitater/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.2.2"
__version__ = "0.2.3"
17 changes: 7 additions & 10 deletions src/imitater/model/chat_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, fields
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Tuple, Union

from transformers import AutoTokenizer, GenerationConfig
from typing_extensions import Self
Expand All @@ -12,8 +12,6 @@
if TYPE_CHECKING:
from argparse import ArgumentParser, Namespace

from vllm import RequestOutput


@dataclass
class ChatConfig:
Expand Down Expand Up @@ -120,9 +118,7 @@ def _load_generation_config(self) -> None:
{"additional_special_tokens": extra_special_tokens}, replace_additional_special_tokens=False
)

async def _generate(
self, messages: List[Dict[str, str]], request_id: str, **gen_kwargs
) -> AsyncIterator["RequestOutput"]:
async def _generate(self, messages: List[Dict[str, str]], request_id: str, **gen_kwargs):
input_ids = self._tokenizer.apply_chat_template(
conversation=messages, tokenize=True, add_generation_prompt=True
)
Expand Down Expand Up @@ -157,7 +153,7 @@ async def chat(self, messages: List[Dict[str, str]], request_id: str, **gen_kwar
generated_text, prompt_tokens, completion_tokens = "", 0, 0
generator = await self._generate(messages, request_id, **gen_kwargs)
async for result in generator:
if result.finished:
if not result.finished:
generated_text = result.outputs[0].text
prompt_tokens = len(result.prompt_token_ids)
completion_tokens = len(result.outputs[0].token_ids)
Expand All @@ -184,9 +180,10 @@ async def stream_chat(
generated_text = ""
generator = await self._generate(messages, request_id, **gen_kwargs)
async for result in generator:
delta_text = result.outputs[0].text[len(generated_text) :]
generated_text = result.outputs[0].text
yield delta_text
if not result.finished:
delta_text = result.outputs[0].text[len(generated_text) :]
generated_text = result.outputs[0].text
yield delta_text

async def function_call(
self, messages: List[Dict[str, str]], tools: List[Dict[str, Any]], request_id: str, **gen_kwargs
Expand Down

0 comments on commit ab95159

Please sign in to comment.