From dccf886b547a85fe99559efac9e2894d8e0780c1 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Mon, 23 Oct 2023 17:18:25 +0800 Subject: [PATCH] Fix ut Signed-off-by: Jael Gu --- .github/workflows/unit_test.yml | 1 - config.py | 2 +- src_langchain/llm/ernie.py | 26 +++++++++-- test_requirements.txt | 15 ------- tests/requirements.txt | 3 +- .../src_langchain/llm/test_ernie.py | 2 +- .../src_towhee/pipelines/test_pipelines.py | 44 ++++++++++++------- 7 files changed, 56 insertions(+), 37 deletions(-) delete mode 100644 test_requirements.txt diff --git a/.github/workflows/unit_test.yml b/.github/workflows/unit_test.yml index 54051fd..4987732 100644 --- a/.github/workflows/unit_test.yml +++ b/.github/workflows/unit_test.yml @@ -31,7 +31,6 @@ jobs: pip install coverage pip install pytest pip install -r requirements.txt - pip install -r test_requirements.txt - name: Install test dependency shell: bash working-directory: tests diff --git a/config.py b/config.py index 42e2bfd..2cdb9f9 100644 --- a/config.py +++ b/config.py @@ -5,7 +5,7 @@ ################## LLM ################## LLM_OPTION = os.getenv('LLM_OPTION', 'openai') # select your LLM service -LANGUAGE = os.getenv('LANGUAGE', 'en') # options: en, zh +LANGUAGE = os.getenv('DOC_LANGUAGE', 'en') # options: en, zh CHAT_CONFIG = { 'openai': { 'openai_model': 'gpt-3.5-turbo', diff --git a/src_langchain/llm/ernie.py b/src_langchain/llm/ernie.py index 09f163b..21eb43a 100644 --- a/src_langchain/llm/ernie.py +++ b/src_langchain/llm/ernie.py @@ -1,5 +1,4 @@ -from config import CHAT_CONFIG # pylint: disable=C0413 -from typing import Any, List, Dict +from typing import Any, List, Dict, Optional import os import sys @@ -9,6 +8,9 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) +from config import CHAT_CONFIG # pylint: disable=C0413 + + CHAT_CONFIG = CHAT_CONFIG['ernie'] llm_kwargs = CHAT_CONFIG.get('llm_kwargs', {}) @@ -20,7 +22,25 @@ class ChatLLM(BaseChatModel): eb_access_token: str = CHAT_CONFIG['eb_access_token'] or os.getenv('EB_ACCESS_TOKEN') llm_kwargs: dict = llm_kwargs - def _generate(self, messages: List[BaseMessage]) -> ChatResult: + def _generate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None) -> ChatResult: # pylint: disable=W0613 + import erniebot # pylint: disable=C0415 + erniebot.api_type = self.eb_api_type + erniebot.access_token = self.eb_access_token + + message_dicts = self._create_message_dicts(messages) + + response = erniebot.ChatCompletion.create( + model=self.model_name, + messages=message_dicts, + **self.llm_kwargs + ) + return self._create_chat_result(response) + + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None # pylint: disable=W0613 + ) -> ChatResult: import erniebot # pylint: disable=C0415 erniebot.api_type = self.eb_api_type erniebot.access_token = self.eb_access_token diff --git a/test_requirements.txt b/test_requirements.txt deleted file mode 100644 index 1f50c5a..0000000 --- a/test_requirements.txt +++ /dev/null @@ -1,15 +0,0 @@ -langchain==0.0.230 -unstructured -pexpect -pdf2image -SQLAlchemy>=2.0.15 -psycopg2-binary -openai -gradio>=3.30.0 -fastapi -uvicorn -towhee>=1.1.0 -pymilvus -elasticsearch>=8.0.0 -prometheus-client -erniebot diff --git a/tests/requirements.txt b/tests/requirements.txt index 85c17a9..52ef6e6 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -6,4 +6,5 @@ milvus transformers dashscope zhipuai -sentence-transformers \ No newline at end of file +sentence-transformers +erniebot \ No newline at end of file diff --git a/tests/unit_tests/src_langchain/llm/test_ernie.py b/tests/unit_tests/src_langchain/llm/test_ernie.py index c67209d..d3d8ab3 100644 --- a/tests/unit_tests/src_langchain/llm/test_ernie.py +++ b/tests/unit_tests/src_langchain/llm/test_ernie.py @@ -4,7 +4,7 @@ from unittest.mock import patch from langchain.schema import HumanMessage, AIMessage -sys.path.append(os.path.join(os.path.dirname(__file__), '../../../../..')) +sys.path.append(os.path.join(os.path.dirname(__file__), '../../../..')) class TestERNIE(unittest.TestCase): diff --git a/tests/unit_tests/src_towhee/pipelines/test_pipelines.py b/tests/unit_tests/src_towhee/pipelines/test_pipelines.py index 1f84e24..004012f 100644 --- a/tests/unit_tests/src_towhee/pipelines/test_pipelines.py +++ b/tests/unit_tests/src_towhee/pipelines/test_pipelines.py @@ -55,11 +55,13 @@ def create_pipelines(llm_src): class TestPipelines(unittest.TestCase): project = 'akcio_ut' - data_src = 'https://towhee.io' + data_src = 'akcio_ut.txt' question = 'test question' @classmethod def setUpClass(cls): + with open(cls.data_src, 'w+', encoding='utf-8') as tmp_f: + tmp_f.write('This is test content.') milvus_server.cleanup() milvus_server.start() @@ -84,7 +86,7 @@ def test_openai(self): token_count = 0 for x in res: token_count += x[0]['token_count'] - assert token_count == 261 + assert token_count == 5 num = pipelines.count_entities(self.project)['vector store'] assert len(res) <= num @@ -121,7 +123,7 @@ def test_chatglm(self): token_count = 0 for x in res: token_count += x[0]['token_count'] - assert token_count == 261 + assert token_count == 5 num = pipelines.count_entities(self.project)['vector store'] assert len(res) <= num @@ -137,13 +139,24 @@ def test_chatglm(self): def test_ernie(self): - class MockRequest: - def json(self): - return {'result': MOCK_ANSWER} - - with patch('requests.request') as mock_llm: - - mock_llm.return_value = MockRequest() + from erniebot.response import EBResponse + + with patch('erniebot.ChatCompletion.create') as mock_post: + mock_res = EBResponse(code=200, + body={'id': 'as-0000000000', 'object': 'chat.completion', 'created': 11111111, + 'result': MOCK_ANSWER, + 'usage': {'prompt_tokens': 1, 'completion_tokens': 13, 'total_tokens': 14}, + 'need_clear_history': False, 'is_truncated': False}, + headers={'Connection': 'keep-alive', + 'Content-Security-Policy': 'frame-ancestors https://*.baidu.com/', + 'Content-Type': 'application/json', 'Date': 'Mon, 23 Oct 2023 03:30:53 GMT', + 'Server': 'nginx', 'Statement': 'AI-generated', + 'Vary': 'Origin, Access-Control-Request-Method, Access-Control-Request-Headers', + 'X-Frame-Options': 'allow-from https://*.baidu.com/', + 'X-Request-Id': '0' * 32, + 'Transfer-Encoding': 'chunked'} + ) + mock_post.return_value = mock_res pipelines = create_pipelines('ernie') @@ -160,7 +173,7 @@ def json(self): token_count = 0 for x in res: token_count += x[0]['token_count'] - assert token_count == 261 + assert token_count == 5 num = pipelines.count_entities(self.project)['vector store'] assert len(res) <= num @@ -205,7 +218,7 @@ def output(self): token_count = 0 for x in res: token_count += x[0]['token_count'] - assert token_count == 261 + assert token_count == 5 num = pipelines.count_entities(self.project)['vector store'] assert len(res) <= num @@ -244,7 +257,7 @@ def json(self): token_count = 0 for x in res: token_count += x[0]['token_count'] - assert token_count == 261 + assert token_count == 5 num = pipelines.count_entities(self.project)['vector store'] assert len(res) <= num @@ -285,7 +298,7 @@ def iter_lines(self): token_count = 0 for x in res: token_count += x[0]['token_count'] - assert token_count == 261 + assert token_count == 5 num = pipelines.count_entities(self.project)['vector store'] assert len(res) <= num @@ -323,7 +336,7 @@ def __call__(self, *args, **kwargs): token_count = 0 for x in res: token_count += x[0]['token_count'] - assert token_count == 261 + assert token_count == 5 num = pipelines.count_entities(self.project)['vector store'] assert len(res) <= num @@ -339,6 +352,7 @@ def __call__(self, *args, **kwargs): @classmethod def tearDownClass(cls): + os.remove(cls.data_src) milvus_server.stop() milvus_server.cleanup()