Fix ut
Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
jaelgu committed Oct 23, 2023
1 parent 0f4893c commit 009463d
Showing 7 changed files with 49 additions and 34 deletions.
1 change: 0 additions & 1 deletion .github/workflows/unit_test.yml
@@ -31,7 +31,6 @@ jobs:
pip install coverage
pip install pytest
pip install -r requirements.txt
-pip install -r test_requirements.txt
- name: Install test dependency
shell: bash
working-directory: tests
2 changes: 1 addition & 1 deletion config.py
@@ -5,7 +5,7 @@

################## LLM ##################
LLM_OPTION = os.getenv('LLM_OPTION', 'openai') # select your LLM service
-LANGUAGE = os.getenv('LANGUAGE', 'en') # options: en, zh
+LANGUAGE = os.getenv('DOC_LANGUAGE', 'en') # options: en, zh
CHAT_CONFIG = {
'openai': {
'openai_model': 'gpt-3.5-turbo',
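For context, the document language is now read from DOC_LANGUAGE instead of LANGUAGE, likely to avoid colliding with the locale-style LANGUAGE variable that many CI environments already set. A minimal usage sketch under the new name — importing config from the repository root is an assumption about how the module is loaded:

import os

# Assumed usage: set DOC_LANGUAGE before config.py is imported;
# it falls back to 'en' when unset.
os.environ['DOC_LANGUAGE'] = 'zh'

import config
print(config.LANGUAGE)  # -> 'zh'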
16 changes: 16 additions & 0 deletions src_langchain/llm/ernie.py
@@ -33,6 +33,22 @@ def _generate(self, messages: List[BaseMessage]) -> ChatResult:
)
return self._create_chat_result(response)

+    async def _agenerate(
+        self,
+        messages: List[BaseMessage],
+    ) -> ChatResult:
+        import erniebot # pylint: disable=C0415
+        erniebot.api_type = self.eb_api_type
+        erniebot.access_token = self.eb_access_token
+
+        message_dicts = self._create_message_dicts(messages)
+        response = erniebot.ChatCompletion.create(
+            model=self.model_name,
+            messages=message_dicts,
+            **self.llm_kwargs
+        )
+        return self._create_chat_result(response)

def _create_message_dicts(
self, messages: List[BaseMessage]
) -> List[Dict[str, Any]]:
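The new _agenerate mirrors _generate but is declared as a coroutine; note that it still calls the blocking erniebot.ChatCompletion.create, so awaiting it does not yield to the event loop. A rough usage sketch — the chat object and its construction are assumptions, not shown in this diff:

import asyncio
from langchain.schema import HumanMessage

async def ask(chat_llm, question: str) -> str:
    # chat_llm is assumed to be the ERNIE chat model defined in src_langchain/llm/ernie.py
    result = await chat_llm._agenerate([HumanMessage(content=question)])
    return result.generations[0].text

# asyncio.run(ask(chat_llm, 'hello'))  # hypothetical driver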
15 changes: 0 additions & 15 deletions test_requirements.txt

This file was deleted.

3 changes: 2 additions & 1 deletion tests/requirements.txt
@@ -6,4 +6,5 @@ milvus
transformers
dashscope
zhipuai
-sentence-transformers
+sentence-transformers
+erniebot
2 changes: 1 addition & 1 deletion tests/unit_tests/src_langchain/llm/test_ernie.py
@@ -38,7 +38,7 @@ def test_generate(self):
AIMessage(content='hello, can I help you?'),
HumanMessage(content='Please give me a mock answer.'),
]
-res = chat_llm._generate(messages)
+res = chat_llm.generate(messages)
self.assertEqual(res.generations[0].text, 'OK, this is a mock answer.')


44 changes: 29 additions & 15 deletions tests/unit_tests/src_towhee/pipelines/test_pipelines.py
@@ -55,11 +55,13 @@ def create_pipelines(llm_src):

class TestPipelines(unittest.TestCase):
project = 'akcio_ut'
-data_src = 'https://towhee.io'
+data_src = 'akcio_ut.txt'
question = 'test question'

@classmethod
def setUpClass(cls):
+with open(cls.data_src, 'w+', encoding='utf-8') as tmp_f:
+    tmp_f.write('This is test content.')
milvus_server.cleanup()
milvus_server.start()

@@ -84,7 +86,7 @@ def test_openai(self):
token_count = 0
for x in res:
token_count += x[0]['token_count']
-assert token_count == 261
+assert token_count == 5
num = pipelines.count_entities(self.project)['vector store']
assert len(res) <= num
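The expected token count drops from 261 to 5 throughout this file because the ingested document is no longer the towhee.io site but the one-line file written in setUpClass; 'This is test content.' tokenizes to roughly five tokens. A quick sanity check of that number — the GPT-2 tokenizer here is only an illustration, not necessarily the counter the pipeline uses:

from transformers import AutoTokenizer  # transformers is already listed in tests/requirements.txt

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # assumed tokenizer, for illustration
print(len(tokenizer.encode('This is test content.')))  # 5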

@@ -121,7 +123,7 @@ def test_chatglm(self):
token_count = 0
for x in res:
token_count += x[0]['token_count']
-assert token_count == 261
+assert token_count == 5
num = pipelines.count_entities(self.project)['vector store']
assert len(res) <= num

@@ -137,13 +139,24 @@ def test_chatglm(self):

def test_ernie(self):

-        class MockRequest:
-            def json(self):
-                return {'result': MOCK_ANSWER}
-
-        with patch('requests.request') as mock_llm:
-
-            mock_llm.return_value = MockRequest()
+        from erniebot.response import EBResponse
+
+        with patch('erniebot.ChatCompletion.create') as mock_post:
+            mock_res = EBResponse(code=200,
+                                  body={'id': 'as-0000000000', 'object': 'chat.completion', 'created': 11111111,
+                                        'result': MOCK_ANSWER,
+                                        'usage': {'prompt_tokens': 1, 'completion_tokens': 13, 'total_tokens': 14},
+                                        'need_clear_history': False, 'is_truncated': False},
+                                  headers={'Connection': 'keep-alive',
+                                           'Content-Security-Policy': 'frame-ancestors https://*.baidu.com/',
+                                           'Content-Type': 'application/json', 'Date': 'Mon, 23 Oct 2023 03:30:53 GMT',
+                                           'Server': 'nginx', 'Statement': 'AI-generated',
+                                           'Vary': 'Origin, Access-Control-Request-Method, Access-Control-Request-Headers',
+                                           'X-Frame-Options': 'allow-from https://*.baidu.com/',
+                                           'X-Request-Id': '0' * 32,
+                                           'Transfer-Encoding': 'chunked'}
+                                  )
+            mock_post.return_value = mock_res

pipelines = create_pipelines('ernie')

token_count = 0
for x in res:
token_count += x[0]['token_count']
-assert token_count == 261
+assert token_count == 5
num = pipelines.count_entities(self.project)['vector store']
assert len(res) <= num
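The ERNIE test now patches erniebot.ChatCompletion.create directly and hands back a realistic EBResponse (instead of the old MockRequest shim around requests.request), so the code under test can read result and usage exactly as it would from the SDK. A stripped-down version of the same pattern, keeping only the fields the pipeline appears to rely on — the minimal body and empty headers are assumptions:

from unittest.mock import patch
from erniebot.response import EBResponse

fake = EBResponse(code=200,
                  body={'result': 'mock answer',  # stand-in for the module's MOCK_ANSWER constant
                        'usage': {'prompt_tokens': 1, 'completion_tokens': 13, 'total_tokens': 14}},
                  headers={})

with patch('erniebot.ChatCompletion.create', return_value=fake):
    pass  # exercise any code path that calls erniebot.ChatCompletion.create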

@@ -205,7 +218,7 @@ def output(self):
token_count = 0
for x in res:
token_count += x[0]['token_count']
-assert token_count == 261
+assert token_count == 5
num = pipelines.count_entities(self.project)['vector store']
assert len(res) <= num

@@ -244,7 +257,7 @@ def json(self):
token_count = 0
for x in res:
token_count += x[0]['token_count']
-assert token_count == 261
+assert token_count == 5
num = pipelines.count_entities(self.project)['vector store']
assert len(res) <= num

@@ -285,7 +298,7 @@ def iter_lines(self):
token_count = 0
for x in res:
token_count += x[0]['token_count']
-assert token_count == 261
+assert token_count == 5
num = pipelines.count_entities(self.project)['vector store']
assert len(res) <= num

@@ -323,7 +336,7 @@ def __call__(self, *args, **kwargs):
token_count = 0
for x in res:
token_count += x[0]['token_count']
-assert token_count == 261
+assert token_count == 5
num = pipelines.count_entities(self.project)['vector store']
assert len(res) <= num


@classmethod
def tearDownClass(cls):
+os.remove(cls.data_src)
milvus_server.stop()
milvus_server.cleanup()

