Skip to content

Commit

Permalink
Update reranker limits (#203)
Browse files — browse the repository at this point in the history
* update reranker limits

* update versions

* format

* update names

* fix: voyage linter

---------

Co-authored-by: paulpaliychuk <pavlo.paliychuk.ca@gmail.com>
Authored by prasmussen15 and paul-paliychuk on Oct 28, 2024
1 parent ceb60a3 commit 7bb0c78
Show file tree
Hide file tree
Showing 6 changed files with 1,062 additions and 854 deletions.
6 changes: 4 additions & 2 deletions graphiti_core/embedder/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ def __init__(self, config: OpenAIEmbedderConfig | None = None):
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)

async def create(
self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
) -> list[float]:
result = await self.client.embeddings.create(input=input, model=self.config.embedding_model)
result = await self.client.embeddings.create(
input=input_data, model=self.config.embedding_model
)
return result.data[0].embedding[: self.config.embedding_dim]
15 changes: 13 additions & 2 deletions graphiti_core/embedder/voyage.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,18 @@ def __init__(self, config: VoyageAIEmbedderConfig | None = None):
self.client = voyageai.AsyncClient(api_key=config.api_key)

async def create(
self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
) -> list[float]:
result = await self.client.embed(input, model=self.config.embedding_model)
if isinstance(input_data, str):
input_list = [input_data]
elif isinstance(input_data, List):
input_list = [str(i) for i in input_data if i]
else:
input_list = [str(i) for i in input_data if i is not None]

input_list = [i for i in input_list if i]
if len(input_list) == 0:
return []

result = await self.client.embed(input_list, model=self.config.embedding_model)
return result.embeddings[0][: self.config.embedding_dim]
20 changes: 15 additions & 5 deletions graphiti_core/search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
edge_similarity_search,
episode_mentions_reranker,
maximal_marginal_relevance,
node_bfs_search,
node_distance_reranker,
node_fulltext_search,
node_similarity_search,
Expand Down Expand Up @@ -138,7 +139,7 @@ async def edge_search(
edge_similarity_search(
driver, query_vector, None, None, group_ids, 2 * limit, config.sim_min_score
),
edge_bfs_search(driver, bfs_origin_node_uuids, config.bfs_max_depth),
edge_bfs_search(driver, bfs_origin_node_uuids, config.bfs_max_depth, 2 * limit),
]
)
)
Expand All @@ -160,7 +161,12 @@ async def edge_search(
query_vector, search_result_uuids_and_vectors, config.mmr_lambda
)
elif config.reranker == EdgeReranker.cross_encoder:
fact_to_uuid_map = {edge.fact: edge.uuid for result in search_results for edge in result}
search_result_uuids = [[edge.uuid for edge in result] for result in search_results]

rrf_result_uuids = rrf(search_result_uuids)
rrf_edges = [edge_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]

fact_to_uuid_map = {edge.fact: edge.uuid for edge in rrf_edges}
reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys()))
reranked_uuids = [fact_to_uuid_map[fact] for fact, _ in reranked_facts]
elif config.reranker == EdgeReranker.node_distance:
Expand Down Expand Up @@ -212,6 +218,7 @@ async def node_search(
node_similarity_search(
driver, query_vector, group_ids, 2 * limit, config.sim_min_score
),
node_bfs_search(driver, bfs_origin_node_uuids, config.bfs_max_depth, 2 * limit),
]
)
)
Expand All @@ -232,9 +239,12 @@ async def node_search(
query_vector, search_result_uuids_and_vectors, config.mmr_lambda
)
elif config.reranker == NodeReranker.cross_encoder:
summary_to_uuid_map = {
node.summary: node.uuid for result in search_results for node in result
}
# use rrf as a preliminary reranker
rrf_result_uuids = rrf(search_result_uuids)
rrf_results = [node_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]

summary_to_uuid_map = {node.summary: node.uuid for node in rrf_results}

reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
reranked_uuids = [summary_to_uuid_map[fact] for fact, _ in reranked_summaries]
elif config.reranker == NodeReranker.episode_mentions:
Expand Down
7 changes: 6 additions & 1 deletion graphiti_core/search/search_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ async def edge_bfs_search(
driver: AsyncDriver,
bfs_origin_node_uuids: list[str] | None,
bfs_max_depth: int,
limit: int,
) -> list[EntityEdge]:
# vector similarity search over embedded facts
if bfs_origin_node_uuids is None:
Expand All @@ -256,12 +257,14 @@ async def edge_bfs_search(
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
LIMIT $limit
""")

records, _, _ = await driver.execute_query(
query,
bfs_origin_node_uuids=bfs_origin_node_uuids,
depth=bfs_max_depth,
limit=limit,
database_=DEFAULT_DATABASE,
routing_='r',
)
Expand Down Expand Up @@ -348,6 +351,7 @@ async def node_bfs_search(
driver: AsyncDriver,
bfs_origin_node_uuids: list[str] | None,
bfs_max_depth: int,
limit: int,
) -> list[EntityNode]:
# vector similarity search over entity names
if bfs_origin_node_uuids is None:
Expand All @@ -368,6 +372,7 @@ async def node_bfs_search(
""",
bfs_origin_node_uuids=bfs_origin_node_uuids,
depth=bfs_max_depth,
limit=limit,
database_=DEFAULT_DATABASE,
routing_='r',
)
Expand Down Expand Up @@ -690,4 +695,4 @@ def maximal_marginal_relevance(

candidates_with_mmr.sort(reverse=True, key=lambda c: c[1])

return [candidate[0] for candidate in candidates_with_mmr]
return list(set([candidate[0] for candidate in candidates_with_mmr]))
Loading

0 comments on commit 7bb0c78

Please sign in to comment.