Skip to content

Commit

Permalink
hnsw support
Browse files Browse the repository at this point in the history
  • Loading branch information
olirice committed Sep 12, 2023
1 parent 6c0bb85 commit a066e2a
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 57 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ jobs:
strategy:
matrix:
python-version: ['3.8', '3.9', '3.10', '3.11']
postgres-version: ['15.1.0.118']
postgres-version: ['15.1.0.87', '15.1.0.118']

services:

Expand Down
6 changes: 6 additions & 0 deletions src/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
import vecs


def test_extracts_vector_version(client: vecs.Client) -> None:
# pgvector version is sucessfully extracted
assert client.vector_version != ""
assert client.vector_version.count(".") >= 2


def test_create_collection(client: vecs.Client) -> None:
with pytest.warns(DeprecationWarning):
client.create_collection(name="docs", dimension=384)
Expand Down
16 changes: 16 additions & 0 deletions src/vecs/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, List, Optional

from deprecated import deprecated
Expand Down Expand Up @@ -65,6 +66,21 @@ def __init__(self, connection_string):
with sess.begin():
sess.execute(text("create schema if not exists vecs;"))
sess.execute(text("create extension if not exists vector;"))
self.vector_version: str = sess.execute(
text(
"select installed_version from pg_available_extensions where name = 'vector' limit 1;"
)
).scalar_one()

if self._supports_hnsw():
warnings.warn(
UserWarning(
f"vecs will drop support for pgvector < 0.5.0 in version 1.0. Consider updating to latest postgres"
)
)

def _supports_hnsw(self):
return not self.vector_version.startswith("0.4")

def get_or_create_collection(
self,
Expand Down
154 changes: 98 additions & 56 deletions src/vecs/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,14 @@ class IndexMethod(str, Enum):
expand in the future.
Attributes:
auto (str): Automatically choose the best available index method.
ivfflat (str): The ivfflat index method.
hnsw (str): The hnsw index method.
"""

auto = "auto"
ivfflat = "ivfflat"
hnsw = "hnsw"


class IndexMeasure(str, Enum):
Expand Down Expand Up @@ -594,7 +598,7 @@ def is_indexed_for_measure(self, measure: IndexMeasure):
def create_index(
self,
measure: IndexMeasure = IndexMeasure.cosine_distance,
method: IndexMethod = IndexMethod.ivfflat,
method: IndexMethod = IndexMethod.auto,
replace=True,
) -> None:
"""
Expand All @@ -621,16 +625,28 @@ def create_index(
Args:
measure (IndexMeasure, optional): The measure to index for. Defaults to 'cosine_distance'.
method (IndexMethod, optional): The indexing method to use. Defaults to 'ivfflat'.
method (IndexMethod, optional): The indexing method to use. Defaults to 'auto'.
replace (bool, optional): Whether to replace the existing index. Defaults to True.
Raises:
ArgError: If an invalid index method is used, or if *replace* is False and an index already exists.
"""
if not method == IndexMethod.ivfflat:
# at time of writing, no other methods are supported by pgvector
if not method in (IndexMethod.ivfflat, IndexMethod.hnsw, IndexMethod.auto):
raise ArgError("invalid index method")

if method == IndexMethod.auto:
if self.client._supports_hnsw():
method = IndexMethod.hnsw
else:
method = IndexMethod.ivfflat

if method == IndexMethod.ivfflat:
warnings.warn(
UserWarning(
f"vecs will drop support for ivfflat indexes in version 1.0. upgrade to pgvector >= 0.5.0 and use IndexMethod.hnsw"
)
)

if replace:
self._index = None
else:
Expand All @@ -641,70 +657,96 @@ def create_index(
if ops is None:
raise ArgError("Unknown index measure")

# Clone the table
clone_table = build_table(f"_{self.name}", self.client.meta, self.dimension)
if method == IndexMethod.ivfflat:
# Clone the table
clone_table = build_table(f"_{self.name}", self.client.meta, self.dimension)

# hacky
try:
clone_table.drop(self.client.engine)
except Exception:
pass
# hacky
try:
clone_table.drop(self.client.engine)
except Exception:
pass

with self.client.Session() as sess:
n_records: int = sess.execute(func.count(self.table.c.id)).scalar() # type: ignore
with self.client.Session() as sess:
n_records: int = sess.execute(func.count(self.table.c.id)).scalar() # type: ignore

with self.client.Session() as sess:
with sess.begin():
n_index_seed = min(5000, n_records)
clone_table.create(sess.connection())
stmt_seed_table = clone_table.insert().from_select(
self.table.c,
select(self.table).order_by(func.random()).limit(n_index_seed),
)
sess.execute(stmt_seed_table)
with self.client.Session() as sess:
with sess.begin():
n_index_seed = min(5000, n_records)
clone_table.create(sess.connection())
stmt_seed_table = clone_table.insert().from_select(
self.table.c,
select(self.table).order_by(func.random()).limit(n_index_seed),
)
sess.execute(stmt_seed_table)

n_lists = (
int(max(n_records / 1000, 30))
if n_records < 1_000_000
else int(math.sqrt(n_records))
)
n_lists = (
int(max(n_records / 1000, 30))
if n_records < 1_000_000
else int(math.sqrt(n_records))
)

unique_string = str(uuid.uuid4()).replace("-", "_")[0:7]
unique_string = str(uuid.uuid4()).replace("-", "_")[0:7]

sess.execute(
text(
f"""
create index ix_{ops}_{n_lists}_{unique_string}
on vecs."{clone_table.name}"
using ivfflat (vec {ops}) with (lists={n_lists})
"""
sess.execute(
text(
f"""
create index ix_{ops}_{n_lists}_{unique_string}
on vecs."{clone_table.name}"
using ivfflat (vec {ops}) with (lists={n_lists})
"""
)
)
)

sess.execute(
text(
f"""
create index ix_meta_{unique_string}
on vecs."{clone_table.name}"
using gin ( metadata jsonb_path_ops )
"""
sess.execute(
text(
f"""
create index ix_meta_{unique_string}
on vecs."{clone_table.name}"
using gin ( metadata jsonb_path_ops )
"""
)
)
)

# Fully populate the table
stmt = postgresql.insert(clone_table).from_select(
self.table.c, select(self.table)
)
stmt = stmt.on_conflict_do_nothing()
sess.execute(stmt)
# Fully populate the table
stmt = postgresql.insert(clone_table).from_select(
self.table.c, select(self.table)
)
stmt = stmt.on_conflict_do_nothing()
sess.execute(stmt)

# Replace the table
sess.execute(text(f"drop table vecs.{self.table.name};"))
sess.execute(
text(
f"alter table vecs._{self.table.name} rename to {self.table.name};"
# Replace the table
sess.execute(text(f"drop table vecs.{self.table.name};"))
sess.execute(
text(
f"alter table vecs._{self.table.name} rename to {self.table.name};"
)
)
)

if method == IndexMethod.hnsw:
unique_string = str(uuid.uuid4()).replace("-", "_")[0:7]
with self.client.Session() as sess:
with sess.begin():
sess.execute(
text(
f"""
create index ix_{ops}_{unique_string}
on vecs."{self.table.name}"
using hnsw (vec {ops});
"""
)
)

sess.execute(
text(
f"""
create index ix_meta_{unique_string}
on vecs."{self.table.name}"
using gin ( metadata jsonb_path_ops )
"""
)
)

return None


Expand Down

0 comments on commit a066e2a

Please sign in to comment.