From a90d2104976683a0d2953cb3e80c3ebfb9849781 Mon Sep 17 00:00:00 2001 From: Josh Bradley Date: Mon, 26 Aug 2024 17:31:46 -0400 Subject: [PATCH] Improve search type hint (#1031) * update get_local_search_engine and get_global_search_engine return annotation * add semversioner file * reorder imports * fix pyright errors * revert change and ignore previous pyright error --------- Co-authored-by: wanhua.gu Co-authored-by: longyunfeigu <2514553187@qq.com> Co-authored-by: Alonso Guevara --- .semversioner/next-release/patch-20240826152927762829.json | 4 ++++ graphrag/index/verbs/graph/merge/merge_graphs.py | 4 +--- graphrag/query/api.py | 4 ++-- graphrag/query/factories.py | 5 +++-- 4 files changed, 10 insertions(+), 7 deletions(-) create mode 100644 .semversioner/next-release/patch-20240826152927762829.json diff --git a/.semversioner/next-release/patch-20240826152927762829.json b/.semversioner/next-release/patch-20240826152927762829.json new file mode 100644 index 0000000000..84c032135b --- /dev/null +++ b/.semversioner/next-release/patch-20240826152927762829.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Update query type hints." +} diff --git a/graphrag/index/verbs/graph/merge/merge_graphs.py b/graphrag/index/verbs/graph/merge/merge_graphs.py index c93037baaf..a551e4cee7 100644 --- a/graphrag/index/verbs/graph/merge/merge_graphs.py +++ b/graphrag/index/verbs/graph/merge/merge_graphs.py @@ -130,9 +130,7 @@ def merge_edges( target_graph.add_edge(source, target, **(edge_data or {})) else: merge_attributes( - target_graph.edges[ - (source, target) # noqa: RUF031 Parenthesis needed, false positive - ], + target_graph.edges[(source, target)], # noqa edge_data, edge_ops, ) diff --git a/graphrag/query/api.py b/graphrag/query/api.py index c374260218..57f5a12305 100644 --- a/graphrag/query/api.py +++ b/graphrag/query/api.py @@ -140,7 +140,7 @@ async def global_search_streaming( get_context_data = True async for stream_chunk in search_result: if get_context_data: - context_data = _reformat_context_data(stream_chunk) + context_data = _reformat_context_data(stream_chunk) # type: ignore yield context_data get_context_data = False else: @@ -301,7 +301,7 @@ async def local_search_streaming( get_context_data = True async for stream_chunk in search_result: if get_context_data: - context_data = _reformat_context_data(stream_chunk) + context_data = _reformat_context_data(stream_chunk) # type: ignore yield context_data get_context_data = False else: diff --git a/graphrag/query/factories.py b/graphrag/query/factories.py index 8b6d58fb7e..f976efdae9 100644 --- a/graphrag/query/factories.py +++ b/graphrag/query/factories.py @@ -21,6 +21,7 @@ from graphrag.query.llm.oai.chat_openai import ChatOpenAI from graphrag.query.llm.oai.embedding import OpenAIEmbedding from graphrag.query.llm.oai.typing import OpenaiApiType +from graphrag.query.structured_search.base import BaseSearch from graphrag.query.structured_search.global_search.community_context import ( GlobalCommunityContext, ) @@ -108,7 +109,7 @@ def get_local_search_engine( covariates: dict[str, list[Covariate]], response_type: str, description_embedding_store: BaseVectorStore, -) -> LocalSearch: +) -> BaseSearch: """Create a local search engine based on data + configuration.""" llm = get_llm(config) text_embedder = get_text_embedder(config) @@ -159,7 +160,7 @@ def get_global_search_engine( reports: list[CommunityReport], entities: list[Entity], response_type: str, -): +) -> BaseSearch: """Create a global search engine based on data + configuration.""" token_encoder = tiktoken.get_encoding(config.encoding_model) gs_config = config.global_search