From aaa9c8120f291ae8ea5f68860175f1f694047769 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Wed, 25 Oct 2023 15:39:31 +0800 Subject: [PATCH] Replace postgresql with sql in langchain Signed-off-by: Jael Gu --- README.md | 11 ++- config.py | 8 +-- src_langchain/store/README.md | 2 +- src_langchain/store/__init__.py | 2 +- src_langchain/store/memory_store/sql.py | 84 ++++++++++++++++++++++ src_langchain/store/vector_store/milvus.py | 67 ++++++++++++++++- src_towhee/pipelines/__init__.py | 43 +++++++---- 7 files changed, 186 insertions(+), 31 deletions(-) create mode 100644 src_langchain/store/memory_store/sql.py diff --git a/README.md b/README.md index b5f2bec..fb52e2f 100644 --- a/README.md +++ b/README.md @@ -159,22 +159,21 @@ The option using LangChain employs the use of [Agent](https://python.langchain.c - Vector Store: You need to prepare the service of vector database in advance. For example, you can refer to [Milvus Documents](https://milvus.io/docs) or [Zilliz Cloud](https://zilliz.com/doc/quick_start) to learn about how to start a Milvus service. - Scalar Store (Optional): This is optional, only work when `USE_SCALAR` is true in [configuration](config.py). If this is enabled (i.e. USE_SCALAR=True), the default scalar store will use [Elastic](https://www.elastic.co/). In this case, you need to prepare the Elasticsearch service in advance. - - Memory Store: You need to prepare the database for memory storage as well. By default, LangChain mode supports [Postgresql](https://www.postgresql.org/) and Towhee mode allows interaction with any database supported by [SQLAlchemy 2.0](https://docs.sqlalchemy.org/en/20/dialects/). + - Memory Store: You need to prepare the database for memory storage as well. By default, both LangChain and Towhee mode allow interaction with any database supported by [SQLAlchemy 2.0](https://docs.sqlalchemy.org/en/20/dialects/). The system will use default store configs. To set up your special connections for each database, you can also export environment variables instead of modifying the configuration file. - For the Vector Store, set **MILVUS_URI**: + For the Vector Store, set **ZILLIZ_URI**: ```shell - $ export MILVUS_URI=https://localhost:19530 + $ export ZILLIZ_URI=your_zilliz_cloud_endpoint + $ export ZILLIZ_TOKEN=your_zilliz_cloud_api_key # skip this if using Milvus instance ``` For the Memory Store, set **SQL_URI**: ```shell $ export SQL_URI={database_type}://{user}:{password}@{host}/{database_name} - ``` - > LangChain mode only supports [Postgresql](https://www.postgresql.org/) as database type. - + ```
By default, scalar store (elastic) is disabled. diff --git a/config.py b/config.py index dcc7b90..4f778c5 100644 --- a/config.py +++ b/config.py @@ -72,10 +72,8 @@ # Vector db configs VECTORDB_CONFIG = { 'connection_args': { - 'uri': os.getenv('MILVUS_URI', 'http://localhost:19530'), - 'user': os.getenv('MILVUS_USER', ''), - 'password': os.getenv('MILVUS_PASSWORD', ''), - 'secure': True if os.getenv('MILVUS_SECURE', 'False').lower() == 'true' else False + 'uri': os.getenv('ZILLIZ_URI', 'http://localhost:19530'), + 'token': os.getenv('ZILLIZ_TOKEN') }, 'top_k': 5, 'threshold': 0, @@ -104,7 +102,7 @@ # Memory db configs MEMORYDB_CONFIG = { - 'connect_str': os.getenv('SQL_URI', 'postgresql://postgres:postgres@localhost/chat_history') + 'connect_str': os.getenv('SQL_URI', 'sqlite:///./sqlite.db') } diff --git a/src_langchain/store/README.md b/src_langchain/store/README.md index c00152b..e6289dc 100644 --- a/src_langchain/store/README.md +++ b/src_langchain/store/README.md @@ -23,7 +23,7 @@ The default module also works with [Zilliz Cloud](https://zilliz.com) by setting # Vector db configs VECTORDB_CONFIG = { 'connection_args': { - 'uri': os.getenv('MILVUS_URI', 'your_endpoint'), + 'uri': os.getenv('ZILLIZ_URI', 'your_endpoint'), 'user': os.getenv('MILVUS_USER', 'user_name'), 'password': os.getenv('MILVUS_PASSWORD', 'password_goes_here'), 'secure': True diff --git a/src_langchain/store/__init__.py b/src_langchain/store/__init__.py index df7d31d..6e97af3 100644 --- a/src_langchain/store/__init__.py +++ b/src_langchain/store/__init__.py @@ -3,7 +3,7 @@ from typing import Optional, List from .vector_store.milvus import VectorStore, Embeddings -from .memory_store.pg import MemoryStore +from .memory_store.sql import MemoryStore sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) diff --git a/src_langchain/store/memory_store/sql.py b/src_langchain/store/memory_store/sql.py new file mode 100644 index 0000000..4a22037 --- /dev/null +++ b/src_langchain/store/memory_store/sql.py @@ -0,0 +1,84 @@ +import os +import sys +from typing import List + +from sqlalchemy import create_engine, inspect, MetaData, Table + +from langchain.schema import HumanMessage, AIMessage +from langchain.memory import SQLChatMessageHistory, ConversationBufferMemory + +sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) + +from config import MEMORYDB_CONFIG # pylint: disable=C0413 + + +CONNECT_STR = MEMORYDB_CONFIG.get( + 'connect_str', 'sqlite:///./sqlite.db') + + +class MemoryStore: + '''Memory database APIs: add_history, get_history''' + + def __init__(self, table_name: str, session_id: str): + '''Initialize memory storage: e.g. history_db''' + self.table_name = table_name + self.session_id = session_id + + self.history_db = SQLChatMessageHistory( + table_name=self.table_name, + session_id=self.session_id, + connection_string=CONNECT_STR, + ) + self.memory = ConversationBufferMemory( + memory_key='chat_history', + chat_memory=self.history_db, + return_messages=True + ) + + def add_history(self, messages: List[dict]): + for qa in messages: + if 'question' in qa: + self.history_db.add_user_message(qa['question']) + if 'answer' in qa: + self.history_db.add_ai_message(qa['answer']) + + def get_history(self): + history = self.history_db.messages + messages = [] + for x in history: + if isinstance(x, HumanMessage): + if len(messages) > 0 and messages[-1][0] is None: + a = messages[-1][-1] + del messages[-1] + else: + a = None + messages.append((x.content, a)) + if isinstance(x, AIMessage): + if len(messages) > 0 and messages[-1][-1] is None: + q = messages[-1][0] + del messages[-1] + else: + q = None + messages.append((q, x.content)) + return messages + + @staticmethod + def drop(table_name, connect_str: str = CONNECT_STR, session_id: str = None): + engine = create_engine(connect_str, echo=False) + existence = MemoryStore.check(table_name) + + if existence: + project_table = Table(table_name, MetaData(), + autoload_with=engine, extend_existing=True) + if session_id and len(session_id) > 0: + query = project_table.delete().where(project_table.c.session_id == session_id) + with engine.connect() as conn: + conn.execute(query) + conn.commit() + else: + query = project_table.drop(engine) + + @staticmethod + def check(table_name, connect_str: str = CONNECT_STR): + engine = create_engine(connect_str, echo=False) + return inspect(engine).has_table(table_name) diff --git a/src_langchain/store/vector_store/milvus.py b/src_langchain/store/vector_store/milvus.py index 958dcd7..82d321f 100644 --- a/src_langchain/store/vector_store/milvus.py +++ b/src_langchain/store/vector_store/milvus.py @@ -1,7 +1,8 @@ import os import sys import logging -from typing import Optional, Any, Tuple, List, Dict +from typing import Optional, Any, Tuple, List, Dict, Union +from uuid import uuid4 from langchain.vectorstores import Milvus from langchain.embeddings.base import Embeddings @@ -14,7 +15,7 @@ logger = logging.getLogger('vector_store') -CONNECTION_ARGS = VECTORDB_CONFIG.get('connection_args', {'host': 'localhost', 'port': 19530}) +CONNECTION_ARGS = VECTORDB_CONFIG.get('connection_args', {'uri': 'http://localhost:19530'}) TOP_K = VECTORDB_CONFIG.get('top_k', 3) INDEX_PARAMS = VECTORDB_CONFIG.get('index_params', None) SEARCH_PARAMS = VECTORDB_CONFIG.get('search_params', None) @@ -26,7 +27,12 @@ class VectorStore(Milvus): ''' def __init__(self, table_name: str, embedding_func: Embeddings = None, connection_args: dict = CONNECTION_ARGS): - '''Initialize vector db''' + '''Initialize vector db + + connection_args: + uri: milvus or zilliz uri + token: zilliz token + ''' # assert isinstance( # embedding_func, Embeddings), 'Invalid embedding function. Only accept langchain.embeddings.' self.embedding_func = embedding_func @@ -40,6 +46,61 @@ def __init__(self, table_name: str, embedding_func: Embeddings = None, connectio search_params=SEARCH_PARAMS ) + def _create_connection_alias(self, connection_args: dict) -> str: + """Create the connection to the Milvus server.""" + from pymilvus import MilvusException, connections # pylint: disable = C0415 + + # Grab the connection arguments that are used for checking existing connection + host: str = connection_args.get('host', None) + port: Union[str, int] = connection_args.get('port', None) + uri: str = connection_args.get('uri', None) + user = connection_args.get('user', None) + password = connection_args.get('password', None) + token = connection_args.get('token', None) + + + _connection_args = {} # pylint: disable = C0103 + # Order of use is uri > host/port + if uri is not None: + _connection_args['uri'] = uri + given_address = uri.split('://')[1] + elif host is not None and port is not None: + _connection_args['host'] = host + _connection_args['port'] = port + given_address = f'{host}:{port}' + else: + logger.debug('Missing standard address type for reuse attempt') + given_address = None + + # Order of use is token > user/password + if token is not None: + _connection_args['token'] = token + _connection_args['secure'] = True + elif user is not None and password is not None: + _connection_args['user'] = user + _connection_args['password'] = password + _connection_args['secure'] = True + else: + _connection_args['secure'] = False + + # If a valid address was given, then check if a connection exists + if given_address is not None: + for con in connections.list_connections(): + addr = connections.get_connection_addr(con[0]) + if addr == given_address: + logger.debug('Using previous connection: %s', con[0]) + return con[0] + + # Generate a new connection if one doesn't exist + alias = uuid4().hex + try: + connections.connect(alias=alias, **_connection_args) + logger.debug('Created new connection using: %s', alias) + return alias + except MilvusException as e: + logger.error('Failed to create new connection using: %s', alias) + raise e + def similarity_search_with_score_by_vector( self, embedding: List[float], diff --git a/src_towhee/pipelines/__init__.py b/src_towhee/pipelines/__init__.py index afee7ea..c9d8136 100644 --- a/src_towhee/pipelines/__init__.py +++ b/src_towhee/pipelines/__init__.py @@ -43,25 +43,38 @@ def __init__(self, self.rerank_config = rerank_config self.chunk_size = chunk_size - self.milvus_uri = vectordb_config['connection_args']['uri'] - self.milvus_host = self.milvus_uri.split('https://')[1].split(':')[0] - self.milvus_port = self.milvus_uri.split('https://')[1].split(':')[1] - milvus_user = vectordb_config['connection_args'].get('user') - self.milvus_secure = vectordb_config['connection_args'].get('secure', False) - self.milvus_user = None if milvus_user == '' else milvus_user - milvus_password = vectordb_config['connection_args'].get('password') - self.milvus_password = None if milvus_password == '' else milvus_password self.milvus_topk = vectordb_config.get('top_k', 5) self.milvus_threshold = vectordb_config.get('threshold', 0) self.milvus_index_params = vectordb_config.get('index_params', {}) - connections.connect( - host=self.milvus_host, - port=self.milvus_port, - user=self.milvus_user, - secure=self.milvus_secure, - password=self.milvus_password - ) + self.connection_args = vectordb_config['connection_args'] + for k, v in self.connection_args.items(): + if v is None or len(v) == 0: + del self.connection_args[k] + + if 'uri' in self.connection_args: + [self.milvus_host, self.milvus_port] = self.connection_args['uri'].split('://')[1].split(':') + self.connection_args.pop('host', None) + self.connection_args.pop('port', None) + elif 'host' in self.connection_args and 'port' in self.connection_args: + self.milvus_host = self.connection_args.get('host') + self.milvus_port = self.connection_args.get('port') + else: + raise AttributeError('Invalid connection args for milvus.') + + if 'token' in self.connection_args: + self.milvus_token = self.connection_args.get('token') + self.connection_args.pop('user', None) + self.connection_args.pop('password', None) + else: + self.milvus_user = self.connection_args.get('user') + self.milvus_password = self.connection_args.get('password') + + if self.milvus_token or self.milvus_user: + self.connection_args['secure'] = True + self.milvus_secure = True + + connections.connect(**self.connection_args) if self.use_scalar: from elasticsearch import Elasticsearch # pylint: disable=C0415