From a066e2a9bf991dea1d4ba2b201d4ab28d6716759 Mon Sep 17 00:00:00 2001 From: Oliver Rice Date: Tue, 12 Sep 2023 13:54:31 -0500 Subject: [PATCH] hnsw support --- .github/workflows/tests.yml | 2 +- src/tests/test_client.py | 6 ++ src/vecs/client.py | 16 ++++ src/vecs/collection.py | 154 +++++++++++++++++++++++------------- 4 files changed, 121 insertions(+), 57 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 5cd7172..8fb83fb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -8,7 +8,7 @@ jobs: strategy: matrix: python-version: ['3.8', '3.9', '3.10', '3.11'] - postgres-version: ['15.1.0.118'] + postgres-version: ['15.1.0.87', '15.1.0.118'] services: diff --git a/src/tests/test_client.py b/src/tests/test_client.py index 02c8094..6e8694f 100644 --- a/src/tests/test_client.py +++ b/src/tests/test_client.py @@ -3,6 +3,12 @@ import vecs +def test_extracts_vector_version(client: vecs.Client) -> None: + # pgvector version is sucessfully extracted + assert client.vector_version != "" + assert client.vector_version.count(".") >= 2 + + def test_create_collection(client: vecs.Client) -> None: with pytest.warns(DeprecationWarning): client.create_collection(name="docs", dimension=384) diff --git a/src/vecs/client.py b/src/vecs/client.py index 1701aa8..281586a 100644 --- a/src/vecs/client.py +++ b/src/vecs/client.py @@ -7,6 +7,7 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING, List, Optional from deprecated import deprecated @@ -65,6 +66,21 @@ def __init__(self, connection_string): with sess.begin(): sess.execute(text("create schema if not exists vecs;")) sess.execute(text("create extension if not exists vector;")) + self.vector_version: str = sess.execute( + text( + "select installed_version from pg_available_extensions where name = 'vector' limit 1;" + ) + ).scalar_one() + + if self._supports_hnsw(): + warnings.warn( + UserWarning( + f"vecs will drop support for pgvector < 0.5.0 in version 1.0. Consider updating to latest postgres" + ) + ) + + def _supports_hnsw(self): + return not self.vector_version.startswith("0.4") def get_or_create_collection( self, diff --git a/src/vecs/collection.py b/src/vecs/collection.py index 853780c..26f0883 100644 --- a/src/vecs/collection.py +++ b/src/vecs/collection.py @@ -57,10 +57,14 @@ class IndexMethod(str, Enum): expand in the future. Attributes: + auto (str): Automatically choose the best available index method. ivfflat (str): The ivfflat index method. + hnsw (str): The hnsw index method. """ + auto = "auto" ivfflat = "ivfflat" + hnsw = "hnsw" class IndexMeasure(str, Enum): @@ -594,7 +598,7 @@ def is_indexed_for_measure(self, measure: IndexMeasure): def create_index( self, measure: IndexMeasure = IndexMeasure.cosine_distance, - method: IndexMethod = IndexMethod.ivfflat, + method: IndexMethod = IndexMethod.auto, replace=True, ) -> None: """ @@ -621,16 +625,28 @@ def create_index( Args: measure (IndexMeasure, optional): The measure to index for. Defaults to 'cosine_distance'. - method (IndexMethod, optional): The indexing method to use. Defaults to 'ivfflat'. + method (IndexMethod, optional): The indexing method to use. Defaults to 'auto'. replace (bool, optional): Whether to replace the existing index. Defaults to True. Raises: ArgError: If an invalid index method is used, or if *replace* is False and an index already exists. """ - if not method == IndexMethod.ivfflat: - # at time of writing, no other methods are supported by pgvector + if not method in (IndexMethod.ivfflat, IndexMethod.hnsw, IndexMethod.auto): raise ArgError("invalid index method") + if method == IndexMethod.auto: + if self.client._supports_hnsw(): + method = IndexMethod.hnsw + else: + method = IndexMethod.ivfflat + + if method == IndexMethod.ivfflat: + warnings.warn( + UserWarning( + f"vecs will drop support for ivfflat indexes in version 1.0. upgrade to pgvector >= 0.5.0 and use IndexMethod.hnsw" + ) + ) + if replace: self._index = None else: @@ -641,70 +657,96 @@ def create_index( if ops is None: raise ArgError("Unknown index measure") - # Clone the table - clone_table = build_table(f"_{self.name}", self.client.meta, self.dimension) + if method == IndexMethod.ivfflat: + # Clone the table + clone_table = build_table(f"_{self.name}", self.client.meta, self.dimension) - # hacky - try: - clone_table.drop(self.client.engine) - except Exception: - pass + # hacky + try: + clone_table.drop(self.client.engine) + except Exception: + pass - with self.client.Session() as sess: - n_records: int = sess.execute(func.count(self.table.c.id)).scalar() # type: ignore + with self.client.Session() as sess: + n_records: int = sess.execute(func.count(self.table.c.id)).scalar() # type: ignore - with self.client.Session() as sess: - with sess.begin(): - n_index_seed = min(5000, n_records) - clone_table.create(sess.connection()) - stmt_seed_table = clone_table.insert().from_select( - self.table.c, - select(self.table).order_by(func.random()).limit(n_index_seed), - ) - sess.execute(stmt_seed_table) + with self.client.Session() as sess: + with sess.begin(): + n_index_seed = min(5000, n_records) + clone_table.create(sess.connection()) + stmt_seed_table = clone_table.insert().from_select( + self.table.c, + select(self.table).order_by(func.random()).limit(n_index_seed), + ) + sess.execute(stmt_seed_table) - n_lists = ( - int(max(n_records / 1000, 30)) - if n_records < 1_000_000 - else int(math.sqrt(n_records)) - ) + n_lists = ( + int(max(n_records / 1000, 30)) + if n_records < 1_000_000 + else int(math.sqrt(n_records)) + ) - unique_string = str(uuid.uuid4()).replace("-", "_")[0:7] + unique_string = str(uuid.uuid4()).replace("-", "_")[0:7] - sess.execute( - text( - f""" - create index ix_{ops}_{n_lists}_{unique_string} - on vecs."{clone_table.name}" - using ivfflat (vec {ops}) with (lists={n_lists}) - """ + sess.execute( + text( + f""" + create index ix_{ops}_{n_lists}_{unique_string} + on vecs."{clone_table.name}" + using ivfflat (vec {ops}) with (lists={n_lists}) + """ + ) ) - ) - sess.execute( - text( - f""" - create index ix_meta_{unique_string} - on vecs."{clone_table.name}" - using gin ( metadata jsonb_path_ops ) - """ + sess.execute( + text( + f""" + create index ix_meta_{unique_string} + on vecs."{clone_table.name}" + using gin ( metadata jsonb_path_ops ) + """ + ) ) - ) - # Fully populate the table - stmt = postgresql.insert(clone_table).from_select( - self.table.c, select(self.table) - ) - stmt = stmt.on_conflict_do_nothing() - sess.execute(stmt) + # Fully populate the table + stmt = postgresql.insert(clone_table).from_select( + self.table.c, select(self.table) + ) + stmt = stmt.on_conflict_do_nothing() + sess.execute(stmt) - # Replace the table - sess.execute(text(f"drop table vecs.{self.table.name};")) - sess.execute( - text( - f"alter table vecs._{self.table.name} rename to {self.table.name};" + # Replace the table + sess.execute(text(f"drop table vecs.{self.table.name};")) + sess.execute( + text( + f"alter table vecs._{self.table.name} rename to {self.table.name};" + ) ) - ) + + if method == IndexMethod.hnsw: + unique_string = str(uuid.uuid4()).replace("-", "_")[0:7] + with self.client.Session() as sess: + with sess.begin(): + sess.execute( + text( + f""" + create index ix_{ops}_{unique_string} + on vecs."{self.table.name}" + using hnsw (vec {ops}); + """ + ) + ) + + sess.execute( + text( + f""" + create index ix_meta_{unique_string} + on vecs."{self.table.name}" + using gin ( metadata jsonb_path_ops ) + """ + ) + ) + return None