Skip to content

Commit

Permalink
Use util to construct embeddings collection name
Browse files Browse the repository at this point in the history
  • Loading branch information
natoverse committed Nov 5, 2024
1 parent 3f50541 commit f206eb7
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 16 deletions.
2 changes: 1 addition & 1 deletion docs/examples_notebooks/drift_search.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@
"# load description embeddings to an in-memory lancedb vectorstore\n",
"# to connect to a remote db, specify url and port values.\n",
"description_embedding_store = LanceDBVectorStore(\n",
" collection_name=\"entity_description_embeddings\",\n",
" collection_name=\"default-entity-description\",\n",
")\n",
"description_embedding_store.connect(db_uri=LANCEDB_URI)\n",
"entity_description_embeddings = store_entity_semantic_embeddings(\n",
Expand Down
2 changes: 1 addition & 1 deletion docs/examples_notebooks/local_search.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@
"# load description embeddings to an in-memory lancedb vectorstore\n",
"# to connect to a remote db, specify url and port values.\n",
"description_embedding_store = LanceDBVectorStore(\n",
" collection_name=\"entity.description\",\n",
" collection_name=\"default-entity-description\",\n",
")\n",
"description_embedding_store.connect(db_uri=LANCEDB_URI)\n",
"entity_description_embeddings = store_entity_semantic_embeddings(\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@
"entities = read_indexer_entities(entity_df, entity_embedding_df, COMMUNITY_LEVEL)\n",
"\n",
"description_embedding_store = LanceDBVectorStore(\n",
" collection_name=\"entity.description\",\n",
" collection_name=\"default-entity-description\",\n",
")\n",
"description_embedding_store.connect(db_uri=LANCEDB_URI)\n",
"entity_description_embeddings = store_entity_semantic_embeddings(\n",
Expand Down
31 changes: 20 additions & 11 deletions graphrag/api/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
from pydantic import validate_call

from graphrag.config import GraphRagConfig
from graphrag.index.config.embeddings import (
community_full_content_embedding,
entity_description_embedding,
)
from graphrag.logging import PrintProgressReporter
from graphrag.query.factories import (
get_drift_search_engine,
Expand All @@ -41,6 +45,7 @@
)
from graphrag.query.structured_search.base import SearchResult # noqa: TCH001
from graphrag.utils.cli import redact
from graphrag.utils.embeddings import create_collection_name
from graphrag.vector_stores import VectorStoreFactory, VectorStoreType
from graphrag.vector_stores.base import BaseVectorStore

Expand Down Expand Up @@ -203,7 +208,7 @@ async def local_search(

description_embedding_store = _get_embedding_store(
config_args=vector_store_args, # type: ignore
container_suffix="entity-description",
embedding_name=entity_description_embedding,
)

_entities = read_indexer_entities(nodes, entities, community_level)
Expand Down Expand Up @@ -276,8 +281,8 @@ async def local_search_streaming(
reporter.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore

description_embedding_store = _get_embedding_store(
conf_args=vector_store_args, # type: ignore
container_suffix="entity-description",
config_args=vector_store_args, # type: ignore
embedding_name=entity_description_embedding,
)

_entities = read_indexer_entities(nodes, entities, community_level)
Expand Down Expand Up @@ -360,12 +365,12 @@ async def drift_search(

description_embedding_store = _get_embedding_store(
config_args=vector_store_args, # type: ignore
container_suffix="entity-description",
embedding_name=entity_description_embedding,
)

full_content_embedding_store = _get_embedding_store(
config_args=vector_store_args, # type: ignore
container_suffix="community-full_content",
embedding_name=community_full_content_embedding,
)

_entities = read_indexer_entities(nodes, entities, community_level)
Expand Down Expand Up @@ -425,7 +430,9 @@ def _patch_vector_store(
}
description_embedding_store = LanceDBVectorStore(
db_uri=config.embeddings.vector_store["db_uri"],
collection_name="default-entity-description",
collection_name=create_collection_name(
"default", entity_description_embedding
),
overwrite=config.embeddings.vector_store["overwrite"],
)
description_embedding_store.connect(
Expand All @@ -444,7 +451,7 @@ def _patch_vector_store(
from graphrag.vector_stores.lancedb import LanceDBVectorStore

community_reports = with_reports
collection_name = (
container_name = (
config.embeddings.vector_store.get("container_name", "default")
if config.embeddings.vector_store
else "default"
Expand All @@ -460,7 +467,9 @@ def _patch_vector_store(

full_content_embedding_store = LanceDBVectorStore(
db_uri=config.embeddings.vector_store["db_uri"],
collection_name=f"{collection_name}-community-full_content",
collection_name=create_collection_name(
container_name, community_full_content_embedding
),
overwrite=config.embeddings.vector_store["overwrite"],
)
full_content_embedding_store.connect(
Expand All @@ -476,12 +485,12 @@ def _patch_vector_store(

def _get_embedding_store(
config_args: dict,
container_suffix: str,
embedding_name: str,
) -> BaseVectorStore:
"""Get the embedding description store."""
vector_store_type = config_args["type"]
collection_name = (
f"{config_args.get('container_name', 'default')}-{container_suffix}"
collection_name = create_collection_name(
config_args.get("container_name", "default"), embedding_name
)
embedding_store = VectorStoreFactory.get_vector_store(
vector_store_type=vector_store_type,
Expand Down
5 changes: 3 additions & 2 deletions graphrag/index/operations/embed_text/embed_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from datashaper import VerbCallbacks

from graphrag.index.cache import PipelineCache
from graphrag.utils.embeddings import create_collection_name
from graphrag.vector_stores import (
BaseVectorStore,
VectorStoreDocument,
Expand Down Expand Up @@ -229,8 +230,8 @@ def _create_vector_store(


def _get_collection_name(vector_store_config: dict, embedding_name: str) -> str:
container_name = vector_store_config.get("container_name")
collection_name = f"{container_name}.{embedding_name}".replace(".", "-")
container_name = vector_store_config.get("container_name", "default")
collection_name = create_collection_name(container_name, embedding_name)

msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {collection_name}"
log.info(msg)
Expand Down
25 changes: 25 additions & 0 deletions graphrag/utils/embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Utilities for working with embeddings stores."""

from graphrag.index.config.embeddings import all_embeddings


def create_collection_name(
container_name: str, embedding_name: str, validate: bool = True
) -> str:
"""
Create a collection name for the embedding store.
Within any given vector store, we can have multiple sets of embeddings organized into projects.
The `container` param is used for this partitioning, and is added as a prefix to the collection name for differentiation.
The embedding name is fixed, with the available list defined in graphrag.index.config.embeddings
Note that we use dot notation in our names, but many vector stores do not support this - so we convert to dashes.
"""
if validate and embedding_name not in all_embeddings:
msg = f"Invalid embedding name: {embedding_name}"
raise KeyError(msg)
return f"{container_name}-{embedding_name}".replace(".", "-")
2 changes: 2 additions & 0 deletions tests/unit/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
21 changes: 21 additions & 0 deletions tests/unit/utils/test_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

import pytest

from graphrag.utils.embeddings import create_collection_name


def test_create_collection_name():
collection = create_collection_name("default", "entity.name")
assert collection == "default-entity-name"


def test_create_collection_name_invalid_embedding_throws():
with pytest.raises(KeyError):
create_collection_name("default", "invalid.name")


def test_create_collection_name_invalid_embedding_does_not_throw():
collection = create_collection_name("default", "invalid.name", validate=False)
assert collection == "default-invalid-name"

0 comments on commit f206eb7

Please sign in to comment.