Skip to content
This repository has been archived by the owner on Nov 13, 2024. It is now read-only.

Commit

Permalink
Merge branch 'dev' into remove-upsert-df-from-kb
Browse files (browse the repository at this point in the history)
  • Loading branch information
acatav authored Oct 15, 2023
2 parents 652c477 + 1ec5e90 commit 18a613d
Show file tree
Hide file tree
Showing 11 changed files with 141 additions and 150 deletions.
1 change: 0 additions & 1 deletion config/config.yaml
Original file line number | Diff line number | Diff line change
Expand Up @@ -45,6 +45,5 @@ chat_engine:
query_builder:
type: FunctionCallingQueryGenerator
params:
top_k: 5
prompt: null # Will use the default prompt
function_description: null # Will use the default function description
6 changes: 1 addition & 5 deletions resin/chat_engine/query_generator/function_calling.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -18,11 +18,9 @@ class FunctionCallingQueryGenerator(QueryGenerator):
def __init__(self,
*,
llm: BaseLLM,
top_k: int = 10,
prompt: Optional[str] = None,
function_description: Optional[str] = None):
super().__init__(llm=llm)
self._top_k = top_k
self._system_prompt = prompt or DEFAULT_SYSTEM_PROMPT
self._function_description = \
function_description or DEFAULT_FUNCTION_DESCRIPTION
Expand All @@ -36,9 +34,7 @@ def generate(self,
arguments = self._llm.enforced_function_call(messages,
function=self._function)

return [Query(text=q,
top_k=self._top_k,
metadata_filter=None)
return [Query(text=q)
for q in arguments["queries"]]

async def agenerate(self,
Expand Down
2 changes: 1 addition & 1 deletion resin/knoweldge_base/base.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -29,7 +29,7 @@ def delete(self,
pass

@abstractmethod
def verify_connection_health(self) -> None:
def verify_index_connection(self) -> None:
pass

@abstractmethod
Expand Down
140 changes: 55 additions & 85 deletions resin/knoweldge_base/knowledge_base.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -25,11 +25,6 @@
from resin.models.data_models import Query, Document


INDEX_DELETED_MESSAGE = (
"index was deleted. "
"Please create it first using `create_with_new_index()`"
)

INDEX_NAME_PREFIX = "resin--"
TIMEOUT_INDEX_CREATE = 300
TIMEOUT_INDEX_PROVISION = 30
Expand All @@ -42,7 +37,6 @@


class KnowledgeBase(BaseKnowledgeBase):

DEFAULT_RECORD_ENCODER = OpenAIRecordEncoder
DEFAULT_CHUNKER = MarkdownChunker
DEFAULT_RERANKER = TransparentReranker
Expand All @@ -54,6 +48,7 @@ def __init__(self,
chunker: Optional[Chunker] = None,
reranker: Optional[Reranker] = None,
default_top_k: int = 5,
index_params: Optional[dict] = None,
):
if default_top_k < 1:
raise ValueError("default_top_k must be greater than 0")
Expand All @@ -64,7 +59,8 @@ def __init__(self,
self._chunker = chunker if chunker is not None else self.DEFAULT_CHUNKER()
self._reranker = reranker if reranker is not None else self.DEFAULT_RERANKER()

self._index: Optional[Index] = self._connect_index(self._index_name)
self._index: Optional[Index] = None
self._index_params = index_params

@staticmethod
def _connect_pinecone():
Expand All @@ -75,67 +71,56 @@ def _connect_pinecone():
raise RuntimeError("Failed to connect to Pinecone. "
"Please check your credentials and try again") from e

@classmethod
def _connect_index(cls,
full_index_name: str,
def _connect_index(self,
connect_pinecone: bool = True
) -> Index:
if connect_pinecone:
cls._connect_pinecone()
self._connect_pinecone()

if full_index_name not in list_indexes():
if self.index_name not in list_indexes():
raise RuntimeError(
f"Index {full_index_name} does not exist. "
"Please create it first using `create_with_new_index()`"
f"The index {self.index_name} does not exist or was deleted. "
"Please create it by calling knowledge_base.create_resin_index() or "
"running the `resin new` command"
)

try:
index = Index(index_name=full_index_name)
index.describe_index_stats()
index = Index(index_name=self.index_name)
except Exception as e:
raise RuntimeError(
f"Unexpected error while connecting to index {full_index_name}. "
f"Unexpected error while connecting to index {self.index_name}. "
f"Please check your credentials and try again."
) from e
return index

def verify_connection_health(self) -> None:
@property
def _connection_error_msg(self) -> str:
return (
f"KnowledgeBase is not connected to index {self.index_name}, "
f"Please call knowledge_base.connect(). "
)

def connect(self) -> None:
if self._index is None:
raise RuntimeError(INDEX_DELETED_MESSAGE)
self._index = self._connect_index()
self.verify_index_connection()

def verify_index_connection(self) -> None:
if self._index is None:
raise RuntimeError(self._connection_error_msg)

try:
self._index.describe_index_stats()
except Exception as e:
try:
pinecone_whoami()
except Exception:
raise RuntimeError(
"Failed to connect to Pinecone. "
"Please check your credentials and try again"
) from e

if self._index_name not in list_indexes():
raise RuntimeError(
f"index {self._index_name} does not exist anymore"
"and was probably deleted. "
"Please create it first using `create_with_new_index()`"
) from e
raise RuntimeError("Index unexpectedly did not respond. "
"Please try again in few moments") from e

@classmethod
def create_with_new_index(cls,
index_name: str,
*,
record_encoder: Optional[RecordEncoder] = None,
chunker: Optional[Chunker] = None,
reranker: Optional[Reranker] = None,
default_top_k: int = 10,
indexed_fields: Optional[List[str]] = None,
dimension: Optional[int] = None,
create_index_params: Optional[dict] = None
) -> 'KnowledgeBase':
raise RuntimeError(
"The index did not respond. Please check your credentials and try again"
) from e

def create_resin_index(self,
indexed_fields: Optional[List[str]] = None,
dimension: Optional[int] = None,
index_params: Optional[dict] = None
):
# validate inputs
if indexed_fields is None:
indexed_fields = ['document_id']
Expand All @@ -147,68 +132,53 @@ def create_with_new_index(cls,
"Please remove it from indexed_fields")

if dimension is None:
record_encoder = record_encoder if record_encoder is not None else cls.DEFAULT_RECORD_ENCODER() # noqa: E501
if record_encoder.dimension is not None:
dimension = record_encoder.dimension
if self._encoder.dimension is not None:
dimension = self._encoder.dimension
else:
raise ValueError("Could not infer dimension from encoder. "
"Please provide the vectors' dimension")

# connect to pinecone and create index
cls._connect_pinecone()
self._connect_pinecone()

full_index_name = cls._get_full_index_name(index_name)

if full_index_name in list_indexes():
if self.index_name in list_indexes():
raise RuntimeError(
f"Index {full_index_name} already exists. "
f"Index {self.index_name} already exists. "
"If you wish to delete it, use `delete_index()`. "
"If you wish to connect to it,"
"directly initialize a `KnowledgeBase` instance"
)

# create index
create_index_params = create_index_params or {}
index_params = index_params or self._index_params or {}
try:
create_index(name=full_index_name,
create_index(name=self.index_name,
dimension=dimension,
metadata_config={
'indexed': indexed_fields
},
timeout=TIMEOUT_INDEX_CREATE,
**create_index_params)
**index_params)
except Exception as e:
raise RuntimeError(
f"Unexpected error while creating index {full_index_name}."
f"Unexpected error while creating index {self.index_name}."
f"Please try again."
) from e

# wait for index to be provisioned
cls._wait_for_index_provision(full_index_name=full_index_name)

# initialize KnowledgeBase
return cls(index_name=index_name,
record_encoder=record_encoder,
chunker=chunker,
reranker=reranker,
default_top_k=default_top_k)

@classmethod
def _wait_for_index_provision(cls,
full_index_name: str):
self._wait_for_index_provision()

def _wait_for_index_provision(self):
start_time = time.time()
while True:
try:
cls._connect_index(full_index_name,
connect_pinecone=False)
self._index = self._connect_index(connect_pinecone=False)
break
except RuntimeError:
pass

time_passed = time.time() - start_time
if time_passed > TIMEOUT_INDEX_PROVISION:
raise RuntimeError(
f"Index {full_index_name} failed to provision "
f"Index {self.index_name} failed to provision "
f"for {time_passed} seconds."
f"Please try creating KnowledgeBase again in a few minutes."
)
Expand All @@ -227,19 +197,19 @@ def index_name(self) -> str:

def delete_index(self):
if self._index is None:
raise RuntimeError(INDEX_DELETED_MESSAGE)
raise RuntimeError(self._connection_error_msg)
delete_index(self._index_name)
self._index = None

def query(self,
queries: List[Query],
global_metadata_filter: Optional[dict] = None
) -> List[QueryResult]:
queries: List[KBQuery] = self._encoder.encode_queries(queries)

results: List[KBQueryResult] = [self._query_index(q, global_metadata_filter)
for q in queries]
if self._index is None:
raise RuntimeError(self._connection_error_msg)

queries = self._encoder.encode_queries(queries)
results = [self._query_index(q, global_metadata_filter) for q in queries]
results = self._reranker.rerank(results)

return [
Expand All @@ -260,7 +230,7 @@ def _query_index(self,
query: KBQuery,
global_metadata_filter: Optional[dict]) -> KBQueryResult:
if self._index is None:
raise RuntimeError(INDEX_DELETED_MESSAGE)
raise RuntimeError(self._connection_error_msg)

metadata_filter = deepcopy(query.metadata_filter)
if global_metadata_filter is not None:
Expand Down Expand Up @@ -296,7 +266,7 @@ def upsert(self,
namespace: str = "",
batch_size: int = 100):
if self._index is None:
raise RuntimeError(INDEX_DELETED_MESSAGE)
raise RuntimeError(self._connection_error_msg)

for doc in documents:
metadata_keys = set(doc.metadata.keys())
Expand Down Expand Up @@ -345,7 +315,7 @@ def delete(self,
document_ids: List[str],
namespace: str = "") -> None:
if self._index is None:
raise RuntimeError(INDEX_DELETED_MESSAGE)
raise RuntimeError(self._connection_error_msg)

if self._is_starter_env():
for i in range(0, len(document_ids), DELETE_STARTER_BATCH_SIZE):
Expand Down
6 changes: 3 additions & 3 deletions resin/models/data_models.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -13,9 +13,9 @@
class Query(BaseModel):
text: str
namespace: str = ""
metadata_filter: Optional[dict]
top_k: Optional[int]
query_params: Optional[dict] = Field(default_factory=dict)
metadata_filter: Optional[dict] = None
top_k: Optional[int] = None
query_params: dict = Field(default_factory=dict)


class Document(BaseModel):
Expand Down
5 changes: 3 additions & 2 deletions resin_cli/app.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -136,7 +136,7 @@ async def delete(
)
async def health_check():
try:
await run_in_threadpool(kb.verify_connection_health)
await run_in_threadpool(kb.verify_index_connection)
except Exception as e:
err_msg = f"Failed connecting to Pinecone Index {kb._index_name}"
logger.exception(err_msg)
Expand Down Expand Up @@ -192,9 +192,10 @@ def _init_engines():
kb = KnowledgeBase(index_name=INDEX_NAME)
context_engine = ContextEngine(knowledge_base=kb)
llm = OpenAILLM()

chat_engine = ChatEngine(context_engine=context_engine, llm=llm)

kb.connect()


def start(host="0.0.0.0", port=8000, reload=False):
uvicorn.run("resin_cli.app:app",
Expand Down
19 changes: 12 additions & 7 deletions resin_cli/cli.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -12,7 +12,6 @@

from resin.knoweldge_base import KnowledgeBase
from resin.models.data_models import Document
from resin.knoweldge_base.knowledge_base import INDEX_NAME_PREFIX
from resin.tokenizer import OpenAITokenizer, Tokenizer
from resin_cli.data_loader import (
load_from_path,
Expand Down Expand Up @@ -101,14 +100,13 @@ def health(host, port, ssl):
@click.argument("index-name", nargs=1, envvar="INDEX_NAME", type=str, required=True)
@click.option("--tokenizer-model", default="gpt-3.5-turbo", help="Tokenizer model")
def new(index_name, tokenizer_model):
kb = KnowledgeBase(index_name=index_name)
click.echo("Resin is going to create a new index: ", nl=False)
click.echo(click.style(f"{INDEX_NAME_PREFIX}{index_name}", fg="green"))
click.echo(click.style(f"{kb.index_name}", fg="green"))
click.confirm(click.style("Do you want to continue?", fg="red"), abort=True)
Tokenizer.initialize(OpenAITokenizer, tokenizer_model)
with spinner:
_ = KnowledgeBase.create_with_new_index(
index_name=index_name
)
kb.create_resin_index()
click.echo(click.style("Success!", fg="green"))
os.environ["INDEX_NAME"] = index_name

Expand All @@ -134,12 +132,19 @@ def upsert(index_name, data_path, tokenizer_model):
" please provide it with --data-path or set it with env var")
click.echo(click.style(msg, fg="red"), err=True)
sys.exit(1)

kb = KnowledgeBase(index_name=index_name)
try:
kb.connect()
except RuntimeError as e:
click.echo(click.style(str(e), fg="red"), err=True)
sys.exit(1)

click.echo("Resin is going to upsert data from ", nl=False)
click.echo(click.style(f'{data_path}', fg='yellow'), nl=False)
click.echo(" to index: ")
click.echo(click.style(f'{INDEX_NAME_PREFIX}{index_name} \n', fg='green'))
click.echo(click.style(f'{kb.index_name} \n', fg='green'))
with spinner:
kb = KnowledgeBase(index_name=index_name)
try:
data = load_from_path(data_path)
except IDsNotUniqueError:
Expand Down
Loading

0 comments on commit 18a613d

Please sign in to comment.