From 5b3e42a32ef5b9ac67fa840272860568589cd1b2 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Mon, 23 Oct 2023 17:18:25 +0800 Subject: [PATCH] Fix unit tests Signed-off-by: Jael Gu --- .github/workflows/unit_test.yml | 1 - config.py | 2 +- src_langchain/llm/ernie.py | 21 ++++++++- test_requirements.txt | 15 ------- tests/requirements.txt | 3 +- .../src_langchain/llm/test_ernie.py | 2 +- .../src_towhee/pipelines/test_pipelines.py | 44 ++++++++++++------- 7 files changed, 52 insertions(+), 36 deletions(-) delete mode 100644 test_requirements.txt diff --git a/.github/workflows/unit_test.yml b/.github/workflows/unit_test.yml index 54051fd..4987732 100644 --- a/.github/workflows/unit_test.yml +++ b/.github/workflows/unit_test.yml @@ -31,7 +31,6 @@ jobs: pip install coverage pip install pytest pip install -r requirements.txt - pip install -r test_requirements.txt - name: Install test dependency shell: bash working-directory: tests diff --git a/config.py b/config.py index 42e2bfd..2cdb9f9 100644 --- a/config.py +++ b/config.py @@ -5,7 +5,7 @@ ################## LLM ################## LLM_OPTION = os.getenv('LLM_OPTION', 'openai') # select your LLM service -LANGUAGE = os.getenv('LANGUAGE', 'en') # options: en, zh +LANGUAGE = os.getenv('DOC_LANGUAGE', 'en') # options: en, zh CHAT_CONFIG = { 'openai': { 'openai_model': 'gpt-3.5-turbo', diff --git a/src_langchain/llm/ernie.py b/src_langchain/llm/ernie.py index 09f163b..e64630f 100644 --- a/src_langchain/llm/ernie.py +++ b/src_langchain/llm/ernie.py @@ -1,5 +1,5 @@ from config import CHAT_CONFIG # pylint: disable=C0413 -from typing import Any, List, Dict +from typing import Any, List, Dict, Optional import os import sys @@ -20,7 +20,24 @@ class ChatLLM(BaseChatModel): eb_access_token: str = CHAT_CONFIG['eb_access_token'] or os.getenv('EB_ACCESS_TOKEN') llm_kwargs: dict = llm_kwargs - def _generate(self, messages: List[BaseMessage]) -> ChatResult: + def _generate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None) -> 
ChatResult: + import erniebot # pylint: disable=C0415 + erniebot.api_type = self.eb_api_type + erniebot.access_token = self.eb_access_token + + message_dicts = self._create_message_dicts(messages) + response = erniebot.ChatCompletion.create( + model=self.model_name, + messages=message_dicts, + **self.llm_kwargs + ) + return self._create_chat_result(response) + + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None + ) -> ChatResult: import erniebot # pylint: disable=C0415 erniebot.api_type = self.eb_api_type erniebot.access_token = self.eb_access_token diff --git a/test_requirements.txt b/test_requirements.txt deleted file mode 100644 index 1f50c5a..0000000 --- a/test_requirements.txt +++ /dev/null @@ -1,15 +0,0 @@ -langchain==0.0.230 -unstructured -pexpect -pdf2image -SQLAlchemy>=2.0.15 -psycopg2-binary -openai -gradio>=3.30.0 -fastapi -uvicorn -towhee>=1.1.0 -pymilvus -elasticsearch>=8.0.0 -prometheus-client -erniebot diff --git a/tests/requirements.txt b/tests/requirements.txt index 85c17a9..52ef6e6 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -6,4 +6,5 @@ milvus transformers dashscope zhipuai -sentence-transformers \ No newline at end of file +sentence-transformers +erniebot \ No newline at end of file diff --git a/tests/unit_tests/src_langchain/llm/test_ernie.py b/tests/unit_tests/src_langchain/llm/test_ernie.py index c67209d..9c53822 100644 --- a/tests/unit_tests/src_langchain/llm/test_ernie.py +++ b/tests/unit_tests/src_langchain/llm/test_ernie.py @@ -38,7 +38,7 @@ def test_generate(self): AIMessage(content='hello, can I help you?'), HumanMessage(content='Please give me a mock answer.'), ] - res = chat_llm._generate(messages) + res = chat_llm.generate(messages) self.assertEqual(res.generations[0].text, 'OK, this is a mock answer.') diff --git a/tests/unit_tests/src_towhee/pipelines/test_pipelines.py b/tests/unit_tests/src_towhee/pipelines/test_pipelines.py index 1f84e24..004012f 100644 --- 
a/tests/unit_tests/src_towhee/pipelines/test_pipelines.py +++ b/tests/unit_tests/src_towhee/pipelines/test_pipelines.py @@ -55,11 +55,13 @@ def create_pipelines(llm_src): class TestPipelines(unittest.TestCase): project = 'akcio_ut' - data_src = 'https://towhee.io' + data_src = 'akcio_ut.txt' question = 'test question' @classmethod def setUpClass(cls): + with open(cls.data_src, 'w+', encoding='utf-8') as tmp_f: + tmp_f.write('This is test content.') milvus_server.cleanup() milvus_server.start() @@ -84,7 +86,7 @@ def test_openai(self): token_count = 0 for x in res: token_count += x[0]['token_count'] - assert token_count == 261 + assert token_count == 5 num = pipelines.count_entities(self.project)['vector store'] assert len(res) <= num @@ -121,7 +123,7 @@ def test_chatglm(self): token_count = 0 for x in res: token_count += x[0]['token_count'] - assert token_count == 261 + assert token_count == 5 num = pipelines.count_entities(self.project)['vector store'] assert len(res) <= num @@ -137,13 +139,24 @@ def test_chatglm(self): def test_ernie(self): - class MockRequest: - def json(self): - return {'result': MOCK_ANSWER} - - with patch('requests.request') as mock_llm: - - mock_llm.return_value = MockRequest() + from erniebot.response import EBResponse + + with patch('erniebot.ChatCompletion.create') as mock_post: + mock_res = EBResponse(code=200, + body={'id': 'as-0000000000', 'object': 'chat.completion', 'created': 11111111, + 'result': MOCK_ANSWER, + 'usage': {'prompt_tokens': 1, 'completion_tokens': 13, 'total_tokens': 14}, + 'need_clear_history': False, 'is_truncated': False}, + headers={'Connection': 'keep-alive', + 'Content-Security-Policy': 'frame-ancestors https://*.baidu.com/', + 'Content-Type': 'application/json', 'Date': 'Mon, 23 Oct 2023 03:30:53 GMT', + 'Server': 'nginx', 'Statement': 'AI-generated', + 'Vary': 'Origin, Access-Control-Request-Method, Access-Control-Request-Headers', + 'X-Frame-Options': 'allow-from https://*.baidu.com/', + 'X-Request-Id': '0' * 
32, + 'Transfer-Encoding': 'chunked'} + ) + mock_post.return_value = mock_res pipelines = create_pipelines('ernie') @@ -160,7 +173,7 @@ def json(self): token_count = 0 for x in res: token_count += x[0]['token_count'] - assert token_count == 261 + assert token_count == 5 num = pipelines.count_entities(self.project)['vector store'] assert len(res) <= num @@ -205,7 +218,7 @@ def output(self): token_count = 0 for x in res: token_count += x[0]['token_count'] - assert token_count == 261 + assert token_count == 5 num = pipelines.count_entities(self.project)['vector store'] assert len(res) <= num @@ -244,7 +257,7 @@ def json(self): token_count = 0 for x in res: token_count += x[0]['token_count'] - assert token_count == 261 + assert token_count == 5 num = pipelines.count_entities(self.project)['vector store'] assert len(res) <= num @@ -285,7 +298,7 @@ def iter_lines(self): token_count = 0 for x in res: token_count += x[0]['token_count'] - assert token_count == 261 + assert token_count == 5 num = pipelines.count_entities(self.project)['vector store'] assert len(res) <= num @@ -323,7 +336,7 @@ def __call__(self, *args, **kwargs): token_count = 0 for x in res: token_count += x[0]['token_count'] - assert token_count == 261 + assert token_count == 5 num = pipelines.count_entities(self.project)['vector store'] assert len(res) <= num @@ -339,6 +352,7 @@ def __call__(self, *args, **kwargs): @classmethod def tearDownClass(cls): + os.remove(cls.data_src) milvus_server.stop() milvus_server.cleanup()