From 0268df20f8f65c3c28160ee49fb9edcee6729007 Mon Sep 17 00:00:00 2001 From: hunteritself <104769634+hunteritself@users.noreply.github.com> Date: Tue, 21 Nov 2023 13:24:39 -0500 Subject: [PATCH] Add Weaviate integration (#1360) Add Weaviate integration, features include: 1. Initiate the required environment and connect to Weaviate vector database; 2. Create a class; 3. Delete a class; 4. Add data; 5. Make similarity-based queries. --------- Co-authored-by: Andy Xu --- docs/_toc.yml | 1 + .../source/reference/databases/hackernews.rst | 2 +- .../reference/vector_databases/weaviate.rst | 31 +++++ evadb/catalog/catalog_type.py | 1 + evadb/evadb_config.py | 2 + evadb/executor/executor_utils.py | 11 ++ evadb/interfaces/relational/db.py | 3 +- evadb/parser/evadb.lark | 3 +- .../parser/lark_visitor/_create_statements.py | 2 + evadb/third_party/vector_stores/utils.py | 8 ++ evadb/third_party/vector_stores/weaviate.py | 115 ++++++++++++++++++ evadb/utils/generic_utils.py | 18 +++ script/formatting/spelling.txt | 5 + setup.py | 4 + .../integration_tests/long/test_similarity.py | 39 ++++++ test/markers.py | 6 + 16 files changed, 248 insertions(+), 3 deletions(-) create mode 100644 docs/source/reference/vector_databases/weaviate.rst create mode 100644 evadb/third_party/vector_stores/weaviate.py diff --git a/docs/_toc.yml b/docs/_toc.yml index eb57363f4e..38309dbcfc 100644 --- a/docs/_toc.yml +++ b/docs/_toc.yml @@ -90,6 +90,7 @@ parts: - file: source/reference/vector_databases/pgvector - file: source/reference/vector_databases/pinecone - file: source/reference/vector_databases/milvus + - file: source/reference/vector_databases/weaviate - file: source/reference/ai/index title: AI Engines diff --git a/docs/source/reference/databases/hackernews.rst b/docs/source/reference/databases/hackernews.rst index d96112e815..cc5bc97dba 100644 --- a/docs/source/reference/databases/hackernews.rst +++ b/docs/source/reference/databases/hackernews.rst @@ -18,7 +18,7 @@ Required: Optional: -* ``tags`` is the tag used for filtering the query results. Check `available tags `_ to see a list of available filter tags. +* ``tags`` is the tag used for filtering the query results. Check `available tags `_ to see a list of available filter tags. Create Connection ----------------- diff --git a/docs/source/reference/vector_databases/weaviate.rst b/docs/source/reference/vector_databases/weaviate.rst new file mode 100644 index 0000000000..c964d7e878 --- /dev/null +++ b/docs/source/reference/vector_databases/weaviate.rst @@ -0,0 +1,31 @@ +Weaviate +========== + +Weaviate is an open-source vector database designed for scalability and rich querying capabilities. It allows for semantic search, automated vectorization, and supports large language model (LLM) integration. +The connection to Weaviate is based on the `weaviate-client `_ library. + +Dependency +---------- + +* weaviate-client + +Parameters +---------- + +To use Weaviate, you need an API key and a URL of your Weaviate instance. Here are the `instructions for setting up a Weaviate instance `_. After setting up your instance, you will find the API key and URL on the Details tab in Weaviate Cloud Services (WCS) dashboard. These details are essential for establishing a connection to the Weaviate server. + +* `WEAVIATE_API_KEY` is the API key for your Weaviate instance. +* `WEAVIATE_API_URL` is the URL of your Weaviate instance. + +The above values can either be set via the ``SET`` statement, or in the os environment fields "WEAVIATE_API_KEY", "WEAVIATE_API_URL" + +Create Collection +----------------- + +Weaviate uses collections (similar to 'classes') to store data. To create a collection in Weaviate, use the following SQL command in EvaDB: + +.. code-block:: sql + + CREATE INDEX collection_name ON table_name (data) USING WEAVIATE; + +This command creates a collection in Weaviate with the specified name, linked to the table in EvaDB. You can also specify vectorizer settings and other configurations for the collection as needed. \ No newline at end of file diff --git a/evadb/catalog/catalog_type.py b/evadb/catalog/catalog_type.py index d6c052126a..5da568779f 100644 --- a/evadb/catalog/catalog_type.py +++ b/evadb/catalog/catalog_type.py @@ -117,6 +117,7 @@ class VectorStoreType(EvaDBEnum): PINECONE # noqa: F821 PGVECTOR # noqa: F821 CHROMADB # noqa: F821 + WEAVIATE # noqa: F821 MILVUS # noqa: F821 diff --git a/evadb/evadb_config.py b/evadb/evadb_config.py index 9c209c0128..6117514b25 100644 --- a/evadb/evadb_config.py +++ b/evadb/evadb_config.py @@ -41,4 +41,6 @@ "MILVUS_PASSWORD": "", "MILVUS_DB_NAME": "", "MILVUS_TOKEN": "", + "WEAVIATE_API_KEY": "", + "WEAVIATE_API_URL": "", } diff --git a/evadb/executor/executor_utils.py b/evadb/executor/executor_utils.py index 40945c1e5a..26f9d14f8e 100644 --- a/evadb/executor/executor_utils.py +++ b/evadb/executor/executor_utils.py @@ -185,6 +185,17 @@ def handle_vector_store_params( ), "PINECONE_ENV": catalog().get_configuration_catalog_value("PINECONE_ENV"), } + elif vector_store_type == VectorStoreType.WEAVIATE: + # Weaviate Configuration + # Weaviate API key and URL Can be obtained from cluster details on Weaviate Cloud Services (WCS) dashboard + return { + "WEAVIATE_API_KEY": catalog().get_configuration_catalog_value( + "WEAVIATE_API_KEY" + ), + "WEAVIATE_API_URL": catalog().get_configuration_catalog_value( + "WEAVIATE_API_URL" + ), + } elif vector_store_type == VectorStoreType.MILVUS: return { "MILVUS_URI": catalog().get_configuration_catalog_value("MILVUS_URI"), diff --git a/evadb/interfaces/relational/db.py b/evadb/interfaces/relational/db.py index 428d0878f5..714593a8a8 100644 --- a/evadb/interfaces/relational/db.py +++ b/evadb/interfaces/relational/db.py @@ -268,7 +268,8 @@ def create_vector_index( index_name (str): Name of the index. table_name (str): Name of the table. expr (str): Expression used to build the vector index. - using (str): Method used for indexing, can be `FAISS` or `QDRANT` or `PINECONE` or `CHROMADB` or `MILVUS`. + + using (str): Method used for indexing, can be `FAISS` or `QDRANT` or `PINECONE` or `CHROMADB` or `WEAVIATE` or `MILVUS`. Returns: EvaDBCursor: The EvaDBCursor object. diff --git a/evadb/parser/evadb.lark b/evadb/parser/evadb.lark index 86798df6c0..4b96bf647b 100644 --- a/evadb/parser/evadb.lark +++ b/evadb/parser/evadb.lark @@ -71,7 +71,7 @@ function_metadata_key: uid function_metadata_value: constant -vector_store_type: USING (FAISS | QDRANT | PINECONE | PGVECTOR | CHROMADB | MILVUS) +vector_store_type: USING (FAISS | QDRANT | PINECONE | PGVECTOR | CHROMADB | WEAVIATE | MILVUS) index_elem: ("(" uid_list ")" | "(" function_call ")") @@ -448,6 +448,7 @@ QDRANT: "QDRANT"i PINECONE: "PINECONE"i PGVECTOR: "PGVECTOR"i CHROMADB: "CHROMADB"i +WEAVIATE: "WEAVIATE"i MILVUS: "MILVUS"i // Computer vision tasks diff --git a/evadb/parser/lark_visitor/_create_statements.py b/evadb/parser/lark_visitor/_create_statements.py index 72066b294c..175f0087e9 100644 --- a/evadb/parser/lark_visitor/_create_statements.py +++ b/evadb/parser/lark_visitor/_create_statements.py @@ -300,6 +300,8 @@ def vector_store_type(self, tree): vector_store_type = VectorStoreType.PGVECTOR elif str.upper(token) == "CHROMADB": vector_store_type = VectorStoreType.CHROMADB + elif str.upper(token) == "WEAVIATE": + vector_store_type = VectorStoreType.WEAVIATE elif str.upper(token) == "MILVUS": vector_store_type = VectorStoreType.MILVUS return vector_store_type diff --git a/evadb/third_party/vector_stores/utils.py b/evadb/third_party/vector_stores/utils.py index 2a01d57e68..9c12fc6fbd 100644 --- a/evadb/third_party/vector_stores/utils.py +++ b/evadb/third_party/vector_stores/utils.py @@ -18,6 +18,7 @@ from evadb.third_party.vector_stores.milvus import MilvusVectorStore from evadb.third_party.vector_stores.pinecone import PineconeVectorStore from evadb.third_party.vector_stores.qdrant import QdrantVectorStore +from evadb.third_party.vector_stores.weaviate import WeaviateVectorStore from evadb.utils.generic_utils import validate_kwargs @@ -51,6 +52,12 @@ def init_vector_store( validate_kwargs(kwargs, required_params, required_params) return ChromaDBVectorStore(index_name, **kwargs) + elif vector_store_type == VectorStoreType.WEAVIATE: + from evadb.third_party.vector_stores.weaviate import required_params + + validate_kwargs(kwargs, required_params, required_params) + return WeaviateVectorStore(index_name, **kwargs) + elif vector_store_type == VectorStoreType.MILVUS: from evadb.third_party.vector_stores.milvus import ( allowed_params, @@ -59,5 +66,6 @@ def init_vector_store( validate_kwargs(kwargs, allowed_params, required_params) return MilvusVectorStore(index_name, **kwargs) + else: raise Exception(f"Vector store {vector_store_type} not supported") diff --git a/evadb/third_party/vector_stores/weaviate.py b/evadb/third_party/vector_stores/weaviate.py new file mode 100644 index 0000000000..073d530312 --- /dev/null +++ b/evadb/third_party/vector_stores/weaviate.py @@ -0,0 +1,115 @@ +# coding=utf-8 +# Copyright 2018-2023 EvaDB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import List + +from evadb.third_party.vector_stores.types import ( + FeaturePayload, + VectorIndexQuery, + VectorIndexQueryResult, + VectorStore, +) +from evadb.utils.generic_utils import try_to_import_weaviate_client + +required_params = [] +_weaviate_init_done = False + + +class WeaviateVectorStore(VectorStore): + def __init__(self, collection_name: str, **kwargs) -> None: + try_to_import_weaviate_client() + global _weaviate_init_done + + self._collection_name = collection_name + + # Get the API key. + self._api_key = kwargs.get("WEAVIATE_API_KEY") + + if not self._api_key: + self._api_key = os.environ.get("WEAVIATE_API_KEY") + + assert ( + self._api_key + ), "Please set your `WEAVIATE_API_KEY` using set command or environment variable (WEAVIATE_API_KEY). It can be found at the Details tab in WCS Dashboard." + + # Get the API Url. + self._api_url = kwargs.get("WEAVIATE_API_URL") + + if not self._api_url: + self._api_url = os.environ.get("WEAVIATE_API_URL") + + assert ( + self._api_url + ), "Please set your `WEAVIATE_API_URL` using set command or environment variable (WEAVIATE_API_URL). It can be found at the Details tab in WCS Dashboard." + + if not _weaviate_init_done: + # Initialize weaviate client + import weaviate + + client = weaviate.Client( + url=self._api_url, + auth_client_secret=weaviate.AuthApiKey(api_key=self._api_key), + ) + client.schema.get() + + _weaviate_init_done = True + + self._client = client + + def create( + self, + vectorizer: str = "text2vec-openai", + properties: list = None, + module_config: dict = None, + ): + properties = properties or [] + module_config = module_config or {} + + collection_obj = { + "class": self._collection_name, + "properties": properties, + "vectorizer": vectorizer, + "moduleConfig": module_config, + } + + if self._client.schema.exists(self._collection_name): + self._client.schema.delete_class(self._collection_name) + + self._client.schema.create_class(collection_obj) + + def add(self, payload: List[FeaturePayload]) -> None: + with self._client.batch as batch: + for item in payload: + data_object = {"id": item.id, "vector": item.embedding} + batch.add_data_object(data_object, self._collection_name) + + def delete(self) -> None: + self._client.schema.delete_class(self._collection_name) + + def query(self, query: VectorIndexQuery) -> VectorIndexQueryResult: + response = ( + self._client.query.get(self._collection_name, ["*"]) + .with_near_vector({"vector": query.embedding}) + .with_limit(query.top_k) + .do() + ) + + data = response.get("data", {}) + results = data.get("Get", {}).get(self._collection_name, []) + + similarities = [item["_additional"]["distance"] for item in results] + ids = [item["id"] for item in results] + + return VectorIndexQueryResult(similarities, ids) diff --git a/evadb/utils/generic_utils.py b/evadb/utils/generic_utils.py index 8f362e8cb5..426719f87c 100644 --- a/evadb/utils/generic_utils.py +++ b/evadb/utils/generic_utils.py @@ -573,6 +573,16 @@ def try_to_import_chromadb_client(): ) +def try_to_import_weaviate_client(): + try: + import weaviate # noqa: F401 + except ImportError: + raise ValueError( + """Could not import weaviate python package. + Please install it with 'pip install weaviate-client`.""" + ) + + def try_to_import_milvus_client(): try: import pymilvus # noqa: F401 @@ -607,6 +617,14 @@ def is_chromadb_available() -> bool: return False +def is_weaviate_available() -> bool: + try: + try_to_import_weaviate_client() + return True + except ValueError: # noqa: E722 + return False + + def is_milvus_available() -> bool: try: try_to_import_milvus_client() diff --git a/script/formatting/spelling.txt b/script/formatting/spelling.txt index 239b4411da..1dd5566caf 100644 --- a/script/formatting/spelling.txt +++ b/script/formatting/spelling.txt @@ -975,10 +975,13 @@ VideoFormat VideoStorageEngineTest VideoWriter VisionEncoderDecoderModel +WEAVIATE WH WIP WMV WeakValueDictionary +Weaviate +WeaviateVectorStore XGBoost XdistTests Xeon @@ -1731,6 +1734,7 @@ testRayErrorHandling testSimilarityFeatureTable testSimilarityImageDataset testSimilarityTable +testWeaviateIndexImageDataset testcase testcases testdeleteone @@ -1814,6 +1818,7 @@ wal warmup wb weakref +weaviate westbrae wget whitespaces diff --git a/setup.py b/setup.py index 8e809e2a32..e3d211ece8 100644 --- a/setup.py +++ b/setup.py @@ -112,8 +112,11 @@ def read(path, encoding="utf-8"): chromadb_libs = ["chromadb"] +weaviate_libs = ["weaviate-client"] + milvus_libs = ["pymilvus>=2.3.0"] + postgres_libs = [ "psycopg2", ] @@ -173,6 +176,7 @@ def read(path, encoding="utf-8"): "pinecone": pinecone_libs, "chromadb": chromadb_libs, "milvus": milvus_libs, + "weaviate": weaviate_libs, "postgres": postgres_libs, "ludwig": ludwig_libs, "sklearn": sklearn_libs, diff --git a/test/integration_tests/long/test_similarity.py b/test/integration_tests/long/test_similarity.py index 81d6054fe8..2a8d52cf8d 100644 --- a/test/integration_tests/long/test_similarity.py +++ b/test/integration_tests/long/test_similarity.py @@ -20,6 +20,7 @@ milvus_skip_marker, pinecone_skip_marker, qdrant_skip_marker, + weaviate_skip_marker, ) from test.util import ( create_sample_image, @@ -142,6 +143,14 @@ def setUp(self): # use default Milvus database for testing os.environ["MILVUS_DB_NAME"] = "default" + self.original_weaviate_key = os.environ.get("WEAVIATE_API_KEY") + self.original_weaviate_env = os.environ.get("WEAVIATE_API_URL") + + os.environ["WEAVIATE_API_KEY"] = "NM4adxLmhtJDF1dPXDiNhEGTN7hhGDpymmO0" + os.environ[ + "WEAVIATE_API_URL" + ] = "https://cs6422-test2-zn83syib.weaviate.network" + def tearDown(self): shutdown_ray() @@ -580,3 +589,33 @@ def test_end_to_end_index_scan_should_work_correctly_on_image_dataset_milvus( # Cleanup drop_query = "DROP INDEX testMilvusIndexImageDataset" execute_query_fetch_all(self.evadb, drop_query) + + @pytest.mark.skip(reason="Requires running Weaviate instance") + @weaviate_skip_marker + def test_end_to_end_index_scan_should_work_correctly_on_image_dataset_weaviate( + self, + ): + for _ in range(2): + create_index_query = """CREATE INDEX testWeaviateIndexImageDataset + ON testSimilarityImageDataset (DummyFeatureExtractor(data)) + USING WEAVIATE;""" + execute_query_fetch_all(self.evadb, create_index_query) + + select_query = """SELECT _row_id FROM testSimilarityImageDataset + ORDER BY Similarity(DummyFeatureExtractor(Open("{}")), DummyFeatureExtractor(data)) + LIMIT 1;""".format( + self.img_path + ) + explain_batch = execute_query_fetch_all( + self.evadb, f"EXPLAIN {select_query}" + ) + self.assertTrue("VectorIndexScan" in explain_batch.frames[0][0]) + + res_batch = execute_query_fetch_all(self.evadb, select_query) + self.assertEqual( + res_batch.frames["testsimilarityimagedataset._row_id"][0], 5 + ) + + # Cleanup + drop_query = "DROP INDEX testWeaviateIndexImageDataset" + execute_query_fetch_all(self.evadb, drop_query) diff --git a/test/markers.py b/test/markers.py index 8273f5f0fb..deefadb294 100644 --- a/test/markers.py +++ b/test/markers.py @@ -28,6 +28,7 @@ is_pinecone_available, is_qdrant_available, is_replicate_available, + is_weaviate_available, ) asyncio_skip_marker = pytest.mark.skipif( @@ -54,6 +55,11 @@ reason="Skipping since pymilvus is not installed", ) +weaviate_skip_marker = pytest.mark.skipif( + is_weaviate_available() is False, + reason="Skipping since weaviate is not installed", +) + windows_skip_marker = pytest.mark.skipif( sys.platform == "win32", reason="Test case not supported on Windows" )