diff --git a/config/config.yaml b/config/config.yaml index fdc728dc..d8c3c758 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -45,6 +45,5 @@ chat_engine: query_builder: type: FunctionCallingQueryGenerator params: - top_k: 5 prompt: null # Will use the default prompt function_description: null # Will use the default function description \ No newline at end of file diff --git a/resin/chat_engine/query_generator/function_calling.py b/resin/chat_engine/query_generator/function_calling.py index 32dddb0d..1aedb16f 100644 --- a/resin/chat_engine/query_generator/function_calling.py +++ b/resin/chat_engine/query_generator/function_calling.py @@ -18,11 +18,9 @@ class FunctionCallingQueryGenerator(QueryGenerator): def __init__(self, *, llm: BaseLLM, - top_k: int = 10, prompt: Optional[str] = None, function_description: Optional[str] = None): super().__init__(llm=llm) - self._top_k = top_k self._system_prompt = prompt or DEFAULT_SYSTEM_PROMPT self._function_description = \ function_description or DEFAULT_FUNCTION_DESCRIPTION @@ -36,9 +34,7 @@ def generate(self, arguments = self._llm.enforced_function_call(messages, function=self._function) - return [Query(text=q, - top_k=self._top_k, - metadata_filter=None) + return [Query(text=q) for q in arguments["queries"]] async def agenerate(self, diff --git a/resin/knoweldge_base/base.py b/resin/knoweldge_base/base.py index 18a7021f..ceaff038 100644 --- a/resin/knoweldge_base/base.py +++ b/resin/knoweldge_base/base.py @@ -29,7 +29,7 @@ def delete(self, pass @abstractmethod - def verify_connection_health(self) -> None: + def verify_index_connection(self) -> None: pass @abstractmethod diff --git a/resin/knoweldge_base/knowledge_base.py b/resin/knoweldge_base/knowledge_base.py index d0675d77..7846cc33 100644 --- a/resin/knoweldge_base/knowledge_base.py +++ b/resin/knoweldge_base/knowledge_base.py @@ -25,11 +25,6 @@ from resin.models.data_models import Query, Document -INDEX_DELETED_MESSAGE = ( - "index was deleted. " - "Please create it first using `create_with_new_index()`" -) - INDEX_NAME_PREFIX = "resin--" TIMEOUT_INDEX_CREATE = 300 TIMEOUT_INDEX_PROVISION = 30 @@ -42,7 +37,6 @@ class KnowledgeBase(BaseKnowledgeBase): - DEFAULT_RECORD_ENCODER = OpenAIRecordEncoder DEFAULT_CHUNKER = MarkdownChunker DEFAULT_RERANKER = TransparentReranker @@ -54,6 +48,7 @@ def __init__(self, chunker: Optional[Chunker] = None, reranker: Optional[Reranker] = None, default_top_k: int = 5, + index_params: Optional[dict] = None, ): if default_top_k < 1: raise ValueError("default_top_k must be greater than 0") @@ -64,7 +59,8 @@ def __init__(self, self._chunker = chunker if chunker is not None else self.DEFAULT_CHUNKER() self._reranker = reranker if reranker is not None else self.DEFAULT_RERANKER() - self._index: Optional[Index] = self._connect_index(self._index_name) + self._index: Optional[Index] = None + self._index_params = index_params @staticmethod def _connect_pinecone(): @@ -75,67 +71,56 @@ def _connect_pinecone(): raise RuntimeError("Failed to connect to Pinecone. " "Please check your credentials and try again") from e - @classmethod - def _connect_index(cls, - full_index_name: str, + def _connect_index(self, connect_pinecone: bool = True ) -> Index: if connect_pinecone: - cls._connect_pinecone() + self._connect_pinecone() - if full_index_name not in list_indexes(): + if self.index_name not in list_indexes(): raise RuntimeError( - f"Index {full_index_name} does not exist. " - "Please create it first using `create_with_new_index()`" + f"The index {self.index_name} does not exist or was deleted. " + "Please create it by calling knowledge_base.create_resin_index() or " + "running the `resin new` command" ) try: - index = Index(index_name=full_index_name) - index.describe_index_stats() + index = Index(index_name=self.index_name) except Exception as e: raise RuntimeError( - f"Unexpected error while connecting to index {full_index_name}. " + f"Unexpected error while connecting to index {self.index_name}. " f"Please check your credentials and try again." ) from e return index - def verify_connection_health(self) -> None: + @property + def _connection_error_msg(self) -> str: + return ( + f"KnowledgeBase is not connected to index {self.index_name}, " + f"Please call knowledge_base.connect(). " + ) + + def connect(self) -> None: if self._index is None: - raise RuntimeError(INDEX_DELETED_MESSAGE) + self._index = self._connect_index() + self.verify_index_connection() + + def verify_index_connection(self) -> None: + if self._index is None: + raise RuntimeError(self._connection_error_msg) try: self._index.describe_index_stats() except Exception as e: - try: - pinecone_whoami() - except Exception: - raise RuntimeError( - "Failed to connect to Pinecone. " - "Please check your credentials and try again" - ) from e - - if self._index_name not in list_indexes(): - raise RuntimeError( - f"index {self._index_name} does not exist anymore" - "and was probably deleted. " - "Please create it first using `create_with_new_index()`" - ) from e - raise RuntimeError("Index unexpectedly did not respond. " - "Please try again in few moments") from e - - @classmethod - def create_with_new_index(cls, - index_name: str, - *, - record_encoder: Optional[RecordEncoder] = None, - chunker: Optional[Chunker] = None, - reranker: Optional[Reranker] = None, - default_top_k: int = 10, - indexed_fields: Optional[List[str]] = None, - dimension: Optional[int] = None, - create_index_params: Optional[dict] = None - ) -> 'KnowledgeBase': + raise RuntimeError( + "The index did not respond. Please check your credentials and try again" + ) from e + def create_resin_index(self, + indexed_fields: Optional[List[str]] = None, + dimension: Optional[int] = None, + index_params: Optional[dict] = None + ): # validate inputs if indexed_fields is None: indexed_fields = ['document_id'] @@ -147,60 +132,45 @@ def create_with_new_index(cls, "Please remove it from indexed_fields") if dimension is None: - record_encoder = record_encoder if record_encoder is not None else cls.DEFAULT_RECORD_ENCODER() # noqa: E501 - if record_encoder.dimension is not None: - dimension = record_encoder.dimension + if self._encoder.dimension is not None: + dimension = self._encoder.dimension else: raise ValueError("Could not infer dimension from encoder. " "Please provide the vectors' dimension") # connect to pinecone and create index - cls._connect_pinecone() + self._connect_pinecone() - full_index_name = cls._get_full_index_name(index_name) - - if full_index_name in list_indexes(): + if self.index_name in list_indexes(): raise RuntimeError( - f"Index {full_index_name} already exists. " + f"Index {self.index_name} already exists. " "If you wish to delete it, use `delete_index()`. " - "If you wish to connect to it," - "directly initialize a `KnowledgeBase` instance" ) # create index - create_index_params = create_index_params or {} + index_params = index_params or self._index_params or {} try: - create_index(name=full_index_name, + create_index(name=self.index_name, dimension=dimension, metadata_config={ 'indexed': indexed_fields }, timeout=TIMEOUT_INDEX_CREATE, - **create_index_params) + **index_params) except Exception as e: raise RuntimeError( - f"Unexpected error while creating index {full_index_name}." + f"Unexpected error while creating index {self.index_name}." f"Please try again." ) from e # wait for index to be provisioned - cls._wait_for_index_provision(full_index_name=full_index_name) - - # initialize KnowledgeBase - return cls(index_name=index_name, - record_encoder=record_encoder, - chunker=chunker, - reranker=reranker, - default_top_k=default_top_k) - - @classmethod - def _wait_for_index_provision(cls, - full_index_name: str): + self._wait_for_index_provision() + + def _wait_for_index_provision(self): start_time = time.time() while True: try: - cls._connect_index(full_index_name, - connect_pinecone=False) + self._index = self._connect_index(connect_pinecone=False) break except RuntimeError: pass @@ -208,7 +178,7 @@ def _wait_for_index_provision(cls, time_passed = time.time() - start_time if time_passed > TIMEOUT_INDEX_PROVISION: raise RuntimeError( - f"Index {full_index_name} failed to provision " + f"Index {self.index_name} failed to provision " f"for {time_passed} seconds." f"Please try creating KnowledgeBase again in a few minutes." ) @@ -227,7 +197,7 @@ def index_name(self) -> str: def delete_index(self): if self._index is None: - raise RuntimeError(INDEX_DELETED_MESSAGE) + raise RuntimeError(self._connection_error_msg) delete_index(self._index_name) self._index = None @@ -235,11 +205,11 @@ def query(self, queries: List[Query], global_metadata_filter: Optional[dict] = None ) -> List[QueryResult]: - queries: List[KBQuery] = self._encoder.encode_queries(queries) - - results: List[KBQueryResult] = [self._query_index(q, global_metadata_filter) - for q in queries] + if self._index is None: + raise RuntimeError(self._connection_error_msg) + queries = self._encoder.encode_queries(queries) + results = [self._query_index(q, global_metadata_filter) for q in queries] results = self._reranker.rerank(results) return [ @@ -260,7 +230,7 @@ def _query_index(self, query: KBQuery, global_metadata_filter: Optional[dict]) -> KBQueryResult: if self._index is None: - raise RuntimeError(INDEX_DELETED_MESSAGE) + raise RuntimeError(self._connection_error_msg) metadata_filter = deepcopy(query.metadata_filter) if global_metadata_filter is not None: @@ -296,7 +266,7 @@ def upsert(self, namespace: str = "", batch_size: int = 100): if self._index is None: - raise RuntimeError(INDEX_DELETED_MESSAGE) + raise RuntimeError(self._connection_error_msg) for doc in documents: metadata_keys = set(doc.metadata.keys()) @@ -345,7 +315,7 @@ def delete(self, document_ids: List[str], namespace: str = "") -> None: if self._index is None: - raise RuntimeError(INDEX_DELETED_MESSAGE) + raise RuntimeError(self._connection_error_msg) if self._is_starter_env(): for i in range(0, len(document_ids), DELETE_STARTER_BATCH_SIZE): diff --git a/resin/models/data_models.py b/resin/models/data_models.py index f94e7cf1..a1a66c4b 100644 --- a/resin/models/data_models.py +++ b/resin/models/data_models.py @@ -13,9 +13,9 @@ class Query(BaseModel): text: str namespace: str = "" - metadata_filter: Optional[dict] - top_k: Optional[int] - query_params: Optional[dict] = Field(default_factory=dict) + metadata_filter: Optional[dict] = None + top_k: Optional[int] = None + query_params: dict = Field(default_factory=dict) class Document(BaseModel): diff --git a/resin_cli/app.py b/resin_cli/app.py index 010ff45a..ab1c743f 100644 --- a/resin_cli/app.py +++ b/resin_cli/app.py @@ -136,7 +136,7 @@ async def delete( ) async def health_check(): try: - await run_in_threadpool(kb.verify_connection_health) + await run_in_threadpool(kb.verify_index_connection) except Exception as e: err_msg = f"Failed connecting to Pinecone Index {kb._index_name}" logger.exception(err_msg) @@ -192,9 +192,10 @@ def _init_engines(): kb = KnowledgeBase(index_name=INDEX_NAME) context_engine = ContextEngine(knowledge_base=kb) llm = OpenAILLM() - chat_engine = ChatEngine(context_engine=context_engine, llm=llm) + kb.connect() + def start(host="0.0.0.0", port=8000, reload=False): uvicorn.run("resin_cli.app:app", diff --git a/resin_cli/cli.py b/resin_cli/cli.py index 67bf68ab..a14f31d4 100644 --- a/resin_cli/cli.py +++ b/resin_cli/cli.py @@ -12,7 +12,6 @@ from resin.knoweldge_base import KnowledgeBase from resin.models.data_models import Document -from resin.knoweldge_base.knowledge_base import INDEX_NAME_PREFIX from resin.tokenizer import OpenAITokenizer, Tokenizer from resin_cli.data_loader import ( load_from_path, @@ -101,14 +100,13 @@ def health(host, port, ssl): @click.argument("index-name", nargs=1, envvar="INDEX_NAME", type=str, required=True) @click.option("--tokenizer-model", default="gpt-3.5-turbo", help="Tokenizer model") def new(index_name, tokenizer_model): + kb = KnowledgeBase(index_name=index_name) click.echo("Resin is going to create a new index: ", nl=False) - click.echo(click.style(f"{INDEX_NAME_PREFIX}{index_name}", fg="green")) + click.echo(click.style(f"{kb.index_name}", fg="green")) click.confirm(click.style("Do you want to continue?", fg="red"), abort=True) Tokenizer.initialize(OpenAITokenizer, tokenizer_model) with spinner: - _ = KnowledgeBase.create_with_new_index( - index_name=index_name - ) + kb.create_resin_index() click.echo(click.style("Success!", fg="green")) os.environ["INDEX_NAME"] = index_name @@ -134,12 +132,19 @@ def upsert(index_name, data_path, tokenizer_model): " please provide it with --data-path or set it with env var") click.echo(click.style(msg, fg="red"), err=True) sys.exit(1) + + kb = KnowledgeBase(index_name=index_name) + try: + kb.connect() + except RuntimeError as e: + click.echo(click.style(str(e), fg="red"), err=True) + sys.exit(1) + click.echo("Resin is going to upsert data from ", nl=False) click.echo(click.style(f'{data_path}', fg='yellow'), nl=False) click.echo(" to index: ") - click.echo(click.style(f'{INDEX_NAME_PREFIX}{index_name} \n', fg='green')) + click.echo(click.style(f'{kb.index_name} \n', fg='green')) with spinner: - kb = KnowledgeBase(index_name=index_name) try: data = load_from_path(data_path) except IDsNotUniqueError: diff --git a/tests/e2e/test_app.py b/tests/e2e/test_app.py index 81450850..5cb50599 100644 --- a/tests/e2e/test_app.py +++ b/tests/e2e/test_app.py @@ -49,9 +49,10 @@ def index_name(testrun_uid): @pytest.fixture(scope="module", autouse=True) def knowledge_base(index_name): pinecone.init() - KnowledgeBase.create_with_new_index(index_name=index_name,) + kb = KnowledgeBase(index_name=index_name) + kb.create_resin_index() - return KnowledgeBase(index_name=index_name) + return kb @pytest.fixture(scope="module") diff --git a/tests/system/knowledge_base/test_knowledge_base.py b/tests/system/knowledge_base/test_knowledge_base.py index e5398995..8ad06aae 100644 --- a/tests/system/knowledge_base/test_knowledge_base.py +++ b/tests/system/knowledge_base/test_knowledge_base.py @@ -65,13 +65,12 @@ def knowledge_base(index_full_name, index_name, chunker, encoder): if index_full_name in pinecone.list_indexes(): pinecone.delete_index(index_full_name) - KnowledgeBase.create_with_new_index(index_name=index_name, - record_encoder=encoder, - chunker=chunker) + kb = KnowledgeBase(index_name=index_name, + record_encoder=encoder, + chunker=chunker) + kb.create_resin_index() - return KnowledgeBase(index_name=index_name, - record_encoder=encoder, - chunker=chunker) + return kb def total_vectors_in_index(knowledge_base): @@ -169,8 +168,8 @@ def test_create_index(index_full_name, knowledge_base): assert knowledge_base._index.describe_index_stats() -def test_is_verify_connection_health_happy_path(knowledge_base): - knowledge_base.verify_connection_health() +def test_is_verify_index_connection_happy_path(knowledge_base): + knowledge_base.verify_index_connection() def test_init_with_context_engine_prefix(index_full_name, chunker, encoder): @@ -255,6 +254,7 @@ def test_update_documents(encoder, kb = KnowledgeBase(index_name=index_name, record_encoder=encoder, chunker=chunker) + kb.connect() docs = documents[:2] doc_ids = [doc.id for doc in docs] chunk_ids = [chunk.id for chunk in encoded_chunks @@ -299,25 +299,45 @@ def test_delete_large_df_happy_path(knowledge_base, for chunk in chunks_for_validation]) -def test_create_existing_index(index_full_name, index_name): +def test_create_existing_index_no_connect(index_full_name, index_name): + kb = KnowledgeBase( + index_name=index_name, + record_encoder=StubRecordEncoder(StubDenseEncoder(dimension=3)), + chunker=StubChunker(num_chunks_per_doc=2)) with pytest.raises(RuntimeError) as e: - KnowledgeBase.create_with_new_index(index_name=index_name, - record_encoder=StubRecordEncoder( - StubDenseEncoder(dimension=3)), - chunker=StubChunker(num_chunks_per_doc=2)) + kb.create_resin_index() assert f"Index {index_full_name} already exists" in str(e.value) -def test_init_kb_non_existing_index(index_name, chunker, encoder): +def test_kb_non_existing_index(index_name, chunker, encoder): + kb = KnowledgeBase(index_name="non-existing-index", + record_encoder=encoder, + chunker=chunker) + assert kb._index is None with pytest.raises(RuntimeError) as e: - KnowledgeBase(index_name="non-existing-index", - record_encoder=encoder, - chunker=chunker) - expected_msg = f"Index {INDEX_NAME_PREFIX}non-existing-index does not exist" + kb.connect() + expected_msg = f"index {INDEX_NAME_PREFIX}non-existing-index does not exist" assert expected_msg in str(e.value) +@pytest.mark.parametrize("operation", ["upsert", "delete", "query", + "verify_index_connection", "delete_index"]) +def test_error_not_connected(operation, index_name): + kb = KnowledgeBase( + index_name=index_name, + record_encoder=StubRecordEncoder(StubDenseEncoder(dimension=3)), + chunker=StubChunker(num_chunks_per_doc=2)) + + method = getattr(kb, operation) + with pytest.raises(RuntimeError) as e: + if operation == "verify_index_connection" or operation == "delete_index": + method() + else: + method("dummy_input") + assert "KnowledgeBase is not connected to index" in str(e.value) + + def test_delete_index_happy_path(knowledge_base): knowledge_base.delete_index() @@ -325,43 +345,43 @@ def test_delete_index_happy_path(knowledge_base): assert knowledge_base._index is None with pytest.raises(RuntimeError) as e: knowledge_base.delete(["doc_0"]) - - assert "index was deleted." in str(e.value) + assert "KnowledgeBase is not connected" in str(e.value) def test_delete_index_for_non_existing(knowledge_base): with pytest.raises(RuntimeError) as e: knowledge_base.delete_index() - assert "index was deleted." in str(e.value) + assert "KnowledgeBase is not connected" in str(e.value) -def test_verify_connection_health_raise_for_deleted_index(knowledge_base): +def test_connect_after_delete(knowledge_base): with pytest.raises(RuntimeError) as e: - knowledge_base.verify_connection_health() + knowledge_base.connect() - assert "index was deleted" in str(e.value) + assert "does not exist or was deleted" in str(e.value) def test_create_with_text_in_indexed_field_raise(index_name, chunker, encoder): with pytest.raises(ValueError) as e: - KnowledgeBase.create_with_new_index(index_name=index_name, - record_encoder=encoder, - chunker=chunker, - indexed_fields=["id", "text", "metadata"]) + kb = KnowledgeBase(index_name=index_name, + record_encoder=encoder, + chunker=chunker) + kb.create_resin_index(indexed_fields=["id", "text", "metadata"]) assert "The 'text' field cannot be used for metadata filtering" in str(e.value) -def test_create_with_new_index_encoder_dimension_none(index_name, chunker): +def test_create_with_index_encoder_dimension_none(index_name, chunker): encoder = StubRecordEncoder(StubDenseEncoder(dimension=3)) encoder._dense_encoder.dimension = None with pytest.raises(ValueError) as e: - KnowledgeBase.create_with_new_index(index_name=index_name, - record_encoder=encoder, - chunker=chunker) + kb = KnowledgeBase(index_name=index_name, + record_encoder=encoder, + chunker=chunker) + kb.create_resin_index() assert "Could not infer dimension from encoder" in str(e.value) @@ -379,18 +399,20 @@ def set_bad_credentials(): def test_create_bad_credentials(set_bad_credentials, index_name, chunker, encoder): + kb = KnowledgeBase(index_name=index_name, + record_encoder=encoder, + chunker=chunker) with pytest.raises(RuntimeError) as e: - KnowledgeBase.create_with_new_index(index_name=index_name, - record_encoder=encoder, - chunker=chunker) + kb.create_resin_index() assert "Please check your credentials" in str(e.value) def test_init_bad_credentials(set_bad_credentials, index_name, chunker, encoder): + kb = KnowledgeBase(index_name=index_name, + record_encoder=encoder, + chunker=chunker) with pytest.raises(RuntimeError) as e: - KnowledgeBase(index_name=index_name, - record_encoder=encoder, - chunker=chunker) + kb.connect() assert "Please check your credentials and try again" in str(e.value) diff --git a/tests/system/query_generator/test_query_generator_integration.py b/tests/system/query_generator/test_query_generator_integration.py index aba87e12..51a4d275 100644 --- a/tests/system/query_generator/test_query_generator_integration.py +++ b/tests/system/query_generator/test_query_generator_integration.py @@ -26,7 +26,6 @@ def prompt_builder(): def query_generator(openai_llm, prompt_builder): query_gen = FunctionCallingQueryGenerator( llm=openai_llm, - top_k=5, ) query_gen._prompt_builder = prompt_builder return query_gen diff --git a/tests/unit/query_generators/test_function_calling_query_generator.py b/tests/unit/query_generators/test_function_calling_query_generator.py index 470e7961..2d47a71f 100644 --- a/tests/unit/query_generators/test_function_calling_query_generator.py +++ b/tests/unit/query_generators/test_function_calling_query_generator.py @@ -36,7 +36,6 @@ def mock_model_params(): def query_generator(mock_llm, mock_prompt_builder, mock_model_params): query_gen = FunctionCallingQueryGenerator( llm=mock_llm, - top_k=5, ) query_gen._prompt_builder = mock_prompt_builder return query_gen @@ -86,8 +85,8 @@ def test_generate_with_default_params(query_generator, # Ensure the result is correct assert isinstance(result, List) assert len(result) == 2 - assert result[0] == Query(text="query1", top_k=5) - assert result[1] == Query(text="query2", top_k=5) + assert result[0] == Query(text="query1") + assert result[1] == Query(text="query2") @staticmethod def test_generate_with_non_defaults(query_generator, @@ -100,7 +99,6 @@ def test_generate_with_non_defaults(query_generator, gen_custom = FunctionCallingQueryGenerator( llm=mock_llm, - top_k=5, prompt=custom_system_prompt, function_description=custom_function_description, ) @@ -112,7 +110,7 @@ def test_generate_with_non_defaults(query_generator, result = gen_custom.generate(messages=sample_messages, max_prompt_tokens=100) - expected_result = [Query(text="query1", top_k=5)] + expected_result = [Query(text="query1")] assert result == expected_result mock_prompt_builder.build.assert_called_once_with(