diff --git a/.semversioner/next-release/patch-20240826152927762829.json b/.semversioner/next-release/patch-20240826152927762829.json new file mode 100644 index 0000000000..84c032135b --- /dev/null +++ b/.semversioner/next-release/patch-20240826152927762829.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Update query type hints." +} diff --git a/graphrag/index/verbs/graph/merge/merge_graphs.py b/graphrag/index/verbs/graph/merge/merge_graphs.py index c93037baaf..a551e4cee7 100644 --- a/graphrag/index/verbs/graph/merge/merge_graphs.py +++ b/graphrag/index/verbs/graph/merge/merge_graphs.py @@ -130,9 +130,7 @@ def merge_edges( target_graph.add_edge(source, target, **(edge_data or {})) else: merge_attributes( - target_graph.edges[ - (source, target) # noqa: RUF031 Parenthesis needed, false positive - ], + target_graph.edges[(source, target)], # noqa edge_data, edge_ops, ) diff --git a/graphrag/query/api.py b/graphrag/query/api.py index c374260218..57f5a12305 100644 --- a/graphrag/query/api.py +++ b/graphrag/query/api.py @@ -140,7 +140,7 @@ async def global_search_streaming( get_context_data = True async for stream_chunk in search_result: if get_context_data: - context_data = _reformat_context_data(stream_chunk) + context_data = _reformat_context_data(stream_chunk) # type: ignore yield context_data get_context_data = False else: @@ -301,7 +301,7 @@ async def local_search_streaming( get_context_data = True async for stream_chunk in search_result: if get_context_data: - context_data = _reformat_context_data(stream_chunk) + context_data = _reformat_context_data(stream_chunk) # type: ignore yield context_data get_context_data = False else: diff --git a/graphrag/query/factories.py b/graphrag/query/factories.py index 8b6d58fb7e..f976efdae9 100644 --- a/graphrag/query/factories.py +++ b/graphrag/query/factories.py @@ -21,6 +21,7 @@ from graphrag.query.llm.oai.chat_openai import ChatOpenAI from graphrag.query.llm.oai.embedding import OpenAIEmbedding from graphrag.query.llm.oai.typing import OpenaiApiType +from graphrag.query.structured_search.base import BaseSearch from graphrag.query.structured_search.global_search.community_context import ( GlobalCommunityContext, ) @@ -108,7 +109,7 @@ def get_local_search_engine( covariates: dict[str, list[Covariate]], response_type: str, description_embedding_store: BaseVectorStore, -) -> LocalSearch: +) -> BaseSearch: """Create a local search engine based on data + configuration.""" llm = get_llm(config) text_embedder = get_text_embedder(config) @@ -159,7 +160,7 @@ def get_global_search_engine( reports: list[CommunityReport], entities: list[Entity], response_type: str, -): +) -> BaseSearch: """Create a global search engine based on data + configuration.""" token_encoder = tiktoken.get_encoding(config.encoding_model) gs_config = config.global_search