diff --git a/.github/workflows/pre-release-CI.yml b/.github/workflows/pre-release-CI.yml new file mode 100644 index 00000000..50df3692 --- /dev/null +++ b/.github/workflows/pre-release-CI.yml @@ -0,0 +1,60 @@ +name: Pre Release CI + +on: + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + build-and-test: + name: Build & Test on ${{ matrix.os }}-py${{ matrix.python-version }} + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, windows-latest, macos-latest] + python-version: [3.9, '3.10', 3.11] + defaults: + run: + shell: bash + + steps: + - uses: actions/checkout@v3 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Poetry + uses: snok/install-poetry@v1 + with: + version: 1.3.2 + + - name: Build wheel + run: | + poetry build + + - name: Install the wheel + run: | + pip install dist/pinecone_resin*.whl + + - name: Create dev requirements file + run: | + poetry export -f requirements.txt --without-hashes --only dev -o only-dev.txt + + - name: Install dev requirements + run: | + pip install -r only-dev.txt + + - name: Run tests + run: pytest --html=report.html --self-contained-html tests/unit + + - name: Upload pytest reports + if: always() + uses: actions/upload-artifact@v3 + with: + name: pytest-report-${{ matrix.os }}-py${{ matrix.python-version }} + path: report.html + diff --git a/pyproject.toml b/pyproject.toml index 529c3e87..17e812c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,13 +17,12 @@ tiktoken = "^0.3.3" pinecone-datasets = "^0.6.1" pydantic = "^1.10.7" pinecone-text = { version = "^0.6.0", extras = ["openai"] } -flake8-pyproject = "^1.2.3" pandas-stubs = "^2.0.3.230814" -langchain = "^0.0.188" fastapi = "^0.92.0" uvicorn = "^0.20.0" tenacity = "^8.2.1" sse-starlette = "^1.6.5" +types-tqdm = "^4.61.0" 
[tool.poetry.group.dev.dependencies] diff --git a/src/resin/__init__.py b/src/resin/__init__.py index e69de29b..4a604304 100644 --- a/src/resin/__init__.py +++ b/src/resin/__init__.py @@ -0,0 +1,4 @@ +import importlib.metadata + +# Taken from https://stackoverflow.com/a/67097076 +__version__ = importlib.metadata.version("pinecone-resin") diff --git a/src/resin/context_engine/context_builder/base.py b/src/resin/context_engine/context_builder/base.py index 81c1d281..84f5283a 100644 --- a/src/resin/context_engine/context_builder/base.py +++ b/src/resin/context_engine/context_builder/base.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import List -from resin.knoweldge_base.models import QueryResult +from resin.knowledge_base.models import QueryResult from resin.models.data_models import Context from resin.utils.config import ConfigurableMixin diff --git a/src/resin/context_engine/context_builder/stuffing.py b/src/resin/context_engine/context_builder/stuffing.py index 4bed0cef..c7ebb70e 100644 --- a/src/resin/context_engine/context_builder/stuffing.py +++ b/src/resin/context_engine/context_builder/stuffing.py @@ -3,7 +3,7 @@ from resin.context_engine.context_builder.base import ContextBuilder from resin.context_engine.models import ContextQueryResult, ContextSnippet -from resin.knoweldge_base.models import QueryResult, DocumentWithScore +from resin.knowledge_base.models import QueryResult, DocumentWithScore from resin.tokenizer import Tokenizer from resin.models.data_models import Context diff --git a/src/resin/context_engine/context_engine.py b/src/resin/context_engine/context_engine.py index 1507e926..ae99c46c 100644 --- a/src/resin/context_engine/context_engine.py +++ b/src/resin/context_engine/context_engine.py @@ -4,8 +4,8 @@ from resin.context_engine.context_builder import StuffingContextBuilder from resin.context_engine.context_builder.base import ContextBuilder -from resin.knoweldge_base import KnowledgeBase -from resin.knoweldge_base.base 
import BaseKnowledgeBase +from resin.knowledge_base import KnowledgeBase +from resin.knowledge_base.base import BaseKnowledgeBase from resin.models.data_models import Context, Query from resin.utils.config import ConfigurableMixin diff --git a/src/resin/knoweldge_base/__init__.py b/src/resin/knowledge_base/__init__.py similarity index 100% rename from src/resin/knoweldge_base/__init__.py rename to src/resin/knowledge_base/__init__.py diff --git a/src/resin/knoweldge_base/base.py b/src/resin/knowledge_base/base.py similarity index 96% rename from src/resin/knoweldge_base/base.py rename to src/resin/knowledge_base/base.py index 32ca4152..f639b8a5 100644 --- a/src/resin/knoweldge_base/base.py +++ b/src/resin/knowledge_base/base.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import List, Optional -from resin.knoweldge_base.models import QueryResult +from resin.knowledge_base.models import QueryResult from resin.models.data_models import Query, Document from resin.utils.config import ConfigurableMixin diff --git a/src/resin/knoweldge_base/chunker/__init__.py b/src/resin/knowledge_base/chunker/__init__.py similarity index 100% rename from src/resin/knoweldge_base/chunker/__init__.py rename to src/resin/knowledge_base/chunker/__init__.py diff --git a/src/resin/knoweldge_base/chunker/base.py b/src/resin/knowledge_base/chunker/base.py similarity index 95% rename from src/resin/knoweldge_base/chunker/base.py rename to src/resin/knowledge_base/chunker/base.py index a0ccee5a..a375073d 100644 --- a/src/resin/knoweldge_base/chunker/base.py +++ b/src/resin/knowledge_base/chunker/base.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import List -from resin.knoweldge_base.models import KBDocChunk +from resin.knowledge_base.models import KBDocChunk from resin.models.data_models import Document from resin.utils.config import ConfigurableMixin diff --git a/src/resin/knoweldge_base/chunker/langchain_text_splitter.py 
b/src/resin/knowledge_base/chunker/langchain_text_splitter.py similarity index 100% rename from src/resin/knoweldge_base/chunker/langchain_text_splitter.py rename to src/resin/knowledge_base/chunker/langchain_text_splitter.py diff --git a/src/resin/knoweldge_base/chunker/markdown.py b/src/resin/knowledge_base/chunker/markdown.py similarity index 94% rename from src/resin/knoweldge_base/chunker/markdown.py rename to src/resin/knowledge_base/chunker/markdown.py index cedf2bc2..ad8976c3 100644 --- a/src/resin/knoweldge_base/chunker/markdown.py +++ b/src/resin/knowledge_base/chunker/markdown.py @@ -2,7 +2,7 @@ from .langchain_text_splitter import Language, RecursiveCharacterTextSplitter from .recursive_character import RecursiveCharacterChunker -from resin.knoweldge_base.models import KBDocChunk +from resin.knowledge_base.models import KBDocChunk from resin.models.data_models import Document diff --git a/src/resin/knoweldge_base/chunker/recursive_character.py b/src/resin/knowledge_base/chunker/recursive_character.py similarity index 92% rename from src/resin/knoweldge_base/chunker/recursive_character.py rename to src/resin/knowledge_base/chunker/recursive_character.py index ab942355..73651fd0 100644 --- a/src/resin/knoweldge_base/chunker/recursive_character.py +++ b/src/resin/knowledge_base/chunker/recursive_character.py @@ -3,8 +3,8 @@ from .langchain_text_splitter import RecursiveCharacterTextSplitter -from resin.knoweldge_base.chunker.base import Chunker -from resin.knoweldge_base.models import KBDocChunk +from resin.knowledge_base.chunker.base import Chunker +from resin.knowledge_base.models import KBDocChunk from resin.tokenizer import Tokenizer from resin.models.data_models import Document diff --git a/src/resin/knoweldge_base/chunker/token_chunker.py b/src/resin/knowledge_base/chunker/token_chunker.py similarity index 100% rename from src/resin/knoweldge_base/chunker/token_chunker.py rename to src/resin/knowledge_base/chunker/token_chunker.py diff --git 
a/src/resin/knoweldge_base/knowledge_base.py b/src/resin/knowledge_base/knowledge_base.py similarity index 93% rename from src/resin/knoweldge_base/knowledge_base.py rename to src/resin/knowledge_base/knowledge_base.py index c0871341..a6b68e59 100644 --- a/src/resin/knoweldge_base/knowledge_base.py +++ b/src/resin/knowledge_base/knowledge_base.py @@ -6,6 +6,7 @@ import pandas as pd from pinecone import list_indexes, delete_index, create_index, init \ as pinecone_init, whoami as pinecone_whoami +from pinecone import ApiException as PineconeApiException try: from pinecone import GRPCIndex as Index @@ -15,13 +16,13 @@ from pinecone_datasets import Dataset from pinecone_datasets import DenseModelMetadata, DatasetMetadata -from resin.knoweldge_base.base import BaseKnowledgeBase -from resin.knoweldge_base.chunker import Chunker, MarkdownChunker -from resin.knoweldge_base.record_encoder import (RecordEncoder, +from resin.knowledge_base.base import BaseKnowledgeBase +from resin.knowledge_base.chunker import Chunker, MarkdownChunker +from resin.knowledge_base.record_encoder import (RecordEncoder, OpenAIRecordEncoder) -from resin.knoweldge_base.models import (KBQueryResult, KBQuery, QueryResult, +from resin.knowledge_base.models import (KBQueryResult, KBQuery, QueryResult, KBDocChunkWithScore, DocumentWithScore) -from resin.knoweldge_base.reranker import Reranker, TransparentReranker +from resin.knowledge_base.reranker import Reranker, TransparentReranker from resin.models.data_models import Query, Document @@ -52,7 +53,7 @@ class KnowledgeBase(BaseKnowledgeBase): This is a one-time setup process - the index will exist on Pinecone's managed service until it is deleted. 
Example: - >>> from resin.knoweldge_base.knowledge_base import KnowledgeBase + >>> from resin.knowledge_base.knowledge_base import KnowledgeBase >>> from tokenizer import Tokenizer >>> Tokenizer.initialize() >>> kb = KnowledgeBase(index_name="my_index") @@ -89,7 +90,7 @@ def __init__(self, Example: create a new index: - >>> from resin.knoweldge_base.knowledge_base import KnowledgeBase + >>> from resin.knowledge_base.knowledge_base import KnowledgeBase >>> from tokenizer import Tokenizer >>> Tokenizer.initialize() >>> kb = KnowledgeBase(index_name="my_index") @@ -168,7 +169,7 @@ def _connect_pinecone(): def _connect_index(self, connect_pinecone: bool = True - ) -> Index: + ) -> None: if connect_pinecone: self._connect_pinecone() @@ -180,13 +181,14 @@ def _connect_index(self, ) try: - index = Index(index_name=self.index_name) + self._index = Index(index_name=self.index_name) + self.verify_index_connection() except Exception as e: + self._index = None raise RuntimeError( f"Unexpected error while connecting to index {self.index_name}. " f"Please check your credentials and try again." ) from e - return index @property def _connection_error_msg(self) -> str: @@ -210,8 +212,7 @@ def connect(self) -> None: RuntimeError: If the knowledge base failed to connect to the underlying Pinecone index. """ # noqa: E501 if self._index is None: - self._index = self._connect_index() - self.verify_index_connection() + self._connect_index() def verify_index_connection(self) -> None: """ @@ -282,8 +283,15 @@ def create_resin_index(self, "Please remove it from indexed_fields") if dimension is None: - if self._encoder.dimension is not None: - dimension = self._encoder.dimension + try: + encoder_dimension = self._encoder.dimension + except Exception as e: + raise RuntimeError( + f"Failed to infer vectors' dimension from encoder due to error: " + f"{e}. 
Please fix the error or provide the dimension manually" + ) from e + if encoder_dimension is not None: + dimension = encoder_dimension else: raise ValueError("Could not infer dimension from encoder. " "Please provide the vectors' dimension") @@ -307,10 +315,10 @@ def create_resin_index(self, }, timeout=TIMEOUT_INDEX_CREATE, **index_params) - except Exception as e: + except (Exception, PineconeApiException) as e: raise RuntimeError( - f"Unexpected error while creating index {self.index_name}." - f"Please try again." + f"Failed to create index {self.index_name} due to error: " + f"{e.body if isinstance(e, PineconeApiException) else e}" ) from e # wait for index to be provisioned @@ -320,7 +328,7 @@ def _wait_for_index_provision(self): start_time = time.time() while True: try: - self._index = self._connect_index(connect_pinecone=False) + self._connect_index(connect_pinecone=False) break except RuntimeError: pass @@ -387,7 +395,7 @@ def query(self, A list of QueryResult objects. Examples: - >>> from resin.knoweldge_base.knowledge_base import KnowledgeBase + >>> from resin.knowledge_base.knowledge_base import KnowledgeBase >>> from tokenizer import Tokenizer >>> Tokenizer.initialize() >>> kb = KnowledgeBase(index_name="my_index") @@ -432,13 +440,16 @@ def _query_index(self, metadata_filter.update(global_metadata_filter) top_k = query.top_k if query.top_k else self._default_top_k + query_params = deepcopy(query.query_params) + _check_return_type = query_params.pop('_check_return_type', False) result = self._index.query(vector=query.values, sparse_vector=query.sparse_values, top_k=top_k, namespace=query.namespace, metadata_filter=metadata_filter, include_metadata=True, - **query.query_params) + _check_return_type=_check_return_type, + **query_params) documents: List[KBDocChunkWithScore] = [] for match in result['matches']: metadata = match['metadata'] @@ -480,7 +491,7 @@ def upsert(self, None Example: - >>> from resin.knoweldge_base.knowledge_base import 
KnowledgeBase + >>> from resin.knowledge_base.knowledge_base import KnowledgeBase >>> from tokenizer import Tokenizer >>> Tokenizer.initialize() >>> kb = KnowledgeBase(index_name="my_index") @@ -558,7 +569,7 @@ def delete(self, None Example: - >>> from resin.knoweldge_base.knowledge_base import KnowledgeBase + >>> from resin.knowledge_base.knowledge_base import KnowledgeBase >>> from tokenizer import Tokenizer >>> Tokenizer.initialize() >>> kb = KnowledgeBase(index_name="my_index") diff --git a/src/resin/knoweldge_base/models.py b/src/resin/knowledge_base/models.py similarity index 100% rename from src/resin/knoweldge_base/models.py rename to src/resin/knowledge_base/models.py diff --git a/src/resin/knoweldge_base/record_encoder/__init__.py b/src/resin/knowledge_base/record_encoder/__init__.py similarity index 100% rename from src/resin/knoweldge_base/record_encoder/__init__.py rename to src/resin/knowledge_base/record_encoder/__init__.py diff --git a/src/resin/knoweldge_base/record_encoder/base.py b/src/resin/knowledge_base/record_encoder/base.py similarity index 97% rename from src/resin/knoweldge_base/record_encoder/base.py rename to src/resin/knowledge_base/record_encoder/base.py index 98db5e8f..47284781 100644 --- a/src/resin/knoweldge_base/record_encoder/base.py +++ b/src/resin/knowledge_base/record_encoder/base.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import List, Optional -from resin.knoweldge_base.models import KBEncodedDocChunk, KBQuery, KBDocChunk +from resin.knowledge_base.models import KBEncodedDocChunk, KBQuery, KBDocChunk from resin.models.data_models import Query from resin.utils.config import ConfigurableMixin diff --git a/src/resin/knoweldge_base/record_encoder/dense.py b/src/resin/knowledge_base/record_encoder/dense.py similarity index 95% rename from src/resin/knoweldge_base/record_encoder/dense.py rename to src/resin/knowledge_base/record_encoder/dense.py index 77a10aeb..8db12001 100644 --- 
a/src/resin/knoweldge_base/record_encoder/dense.py +++ b/src/resin/knowledge_base/record_encoder/dense.py @@ -3,7 +3,7 @@ from pinecone_text.dense.base_dense_ecoder import BaseDenseEncoder from .base import RecordEncoder -from resin.knoweldge_base.models import KBQuery, KBEncodedDocChunk, KBDocChunk +from resin.knowledge_base.models import KBQuery, KBEncodedDocChunk, KBDocChunk from resin.models.data_models import Query diff --git a/src/resin/knoweldge_base/record_encoder/openai.py b/src/resin/knowledge_base/record_encoder/openai.py similarity index 91% rename from src/resin/knoweldge_base/record_encoder/openai.py rename to src/resin/knowledge_base/record_encoder/openai.py index d42df14f..b56cb743 100644 --- a/src/resin/knoweldge_base/record_encoder/openai.py +++ b/src/resin/knowledge_base/record_encoder/openai.py @@ -6,8 +6,8 @@ retry_if_exception_type, ) from pinecone_text.dense.openai_encoder import OpenAIEncoder -from resin.knoweldge_base.models import KBDocChunk, KBEncodedDocChunk, KBQuery -from resin.knoweldge_base.record_encoder.dense import DenseRecordEncoder +from resin.knowledge_base.models import KBDocChunk, KBEncodedDocChunk, KBQuery +from resin.knowledge_base.record_encoder.dense import DenseRecordEncoder from resin.models.data_models import Query from resin.utils.openai_exceptions import OPEN_AI_TRANSIENT_EXCEPTIONS diff --git a/src/resin/knoweldge_base/reranker/__init__.py b/src/resin/knowledge_base/reranker/__init__.py similarity index 100% rename from src/resin/knoweldge_base/reranker/__init__.py rename to src/resin/knowledge_base/reranker/__init__.py diff --git a/src/resin/knoweldge_base/reranker/reranker.py b/src/resin/knowledge_base/reranker/reranker.py similarity index 91% rename from src/resin/knoweldge_base/reranker/reranker.py rename to src/resin/knowledge_base/reranker/reranker.py index 3ffc9554..04fe9f1f 100644 --- a/src/resin/knoweldge_base/reranker/reranker.py +++ b/src/resin/knowledge_base/reranker/reranker.py @@ -1,7 +1,7 @@ from abc 
import ABC, abstractmethod from typing import List -from resin.knoweldge_base.models import KBQueryResult +from resin.knowledge_base.models import KBQueryResult from resin.utils.config import ConfigurableMixin diff --git a/src/resin/llm/openai.py b/src/resin/llm/openai.py index d9ea5f96..af7ca672 100644 --- a/src/resin/llm/openai.py +++ b/src/resin/llm/openai.py @@ -27,8 +27,8 @@ def __init__(self, self.available_models = [k["id"] for k in openai.Model.list().data] if model_name not in self.available_models: raise ValueError( - f"Model {model_name} not found. " + - " Available models: {self.available_models}" + f"Model {model_name} not found. " + f" Available models: {self.available_models}" ) @retry( diff --git a/src/resin_cli/app.py b/src/resin_cli/app.py index ab1c743f..bb84a229 100644 --- a/src/resin_cli/app.py +++ b/src/resin_cli/app.py @@ -1,13 +1,17 @@ import os import logging +import signal import sys import uuid + +import openai +from multiprocessing import current_process from dotenv import load_dotenv from resin.llm import BaseLLM from resin.llm.models import UserMessage from resin.tokenizer import OpenAITokenizer, Tokenizer -from resin.knoweldge_base import KnowledgeBase +from resin.knowledge_base import KnowledgeBase from resin.context_engine import ContextEngine from resin.chat_engine import ChatEngine from starlette.concurrency import run_in_threadpool @@ -23,9 +27,10 @@ ChatRequest, ContextQueryRequest, \ ContextUpsertRequest, HealthStatus, ContextDeleteRequest -load_dotenv() # load env vars before import of openai -from resin.llm.openai import OpenAILLM # noqa: E402 +from resin.llm.openai import OpenAILLM +load_dotenv() # load env vars before import of openai +openai.api_key = os.getenv("OPENAI_API_KEY") app = FastAPI() @@ -157,6 +162,17 @@ async def health_check(): return HealthStatus(pinecone_status="OK", llm_status="OK") +@app.get( + "/shutdown" +) +async def shutdown(): + logger.info("Shutting down") + proc = current_process() + pid = 
proc._parent_pid if "SpawnProcess" in proc.name else proc.pid + os.kill(pid, signal.SIGINT) + return {"message": "Shutting down"} + + @app.on_event("startup") async def startup(): _init_logging() @@ -197,9 +213,9 @@ def _init_engines(): kb.connect() -def start(host="0.0.0.0", port=8000, reload=False): +def start(host="0.0.0.0", port=8000, reload=False, workers=1): uvicorn.run("resin_cli.app:app", - host=host, port=port, reload=reload) + host=host, port=port, reload=reload, workers=workers) if __name__ == "__main__": diff --git a/src/resin_cli/cli.py b/src/resin_cli/cli.py index 04796c6b..5a8b208b 100644 --- a/src/resin_cli/cli.py +++ b/src/resin_cli/cli.py @@ -1,23 +1,30 @@ import os +from typing import List, Optional import click import time -import sys import requests from dotenv import load_dotenv +from tenacity import retry, stop_after_attempt, wait_fixed +from tqdm import tqdm import pandas as pd import openai +from openai.error import APIError as OpenAI_APIError +from urllib.parse import urljoin -from resin.knoweldge_base import KnowledgeBase +from resin.knowledge_base import KnowledgeBase from resin.models.data_models import Document -from resin.tokenizer import OpenAITokenizer, Tokenizer +from resin.tokenizer import Tokenizer from resin_cli.data_loader import ( load_from_path, + CLIError, IDsNotUniqueError, DocumentsValidationError) +from resin import __version__ + from .app import start as start_service from .cli_spinner import Spinner from .api_models import ChatDebugInfo @@ -25,48 +32,79 @@ dotenv_path = os.path.join(os.path.dirname(__file__), ".env") load_dotenv(dotenv_path) - +if os.getenv("OPENAI_API_KEY"): + openai.api_key = os.getenv("OPENAI_API_KEY") spinner = Spinner() +CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help']) -def is_healthy(url: str): +def check_service_health(url: str): try: - health_url = os.path.join(url, "health") - res = requests.get(health_url) + res = requests.get(urljoin(url, "/health")) res.raise_for_status() return 
res.ok - except Exception: - return False + except requests.exceptions.ConnectionError: + msg = f""" + Resin service is not running on {url}. + please run `resin start` + """ + raise CLIError(msg) + + except requests.exceptions.HTTPError as e: + if e.response is not None: + error = e.response.json().get("detail", None) or e.response.text + else: + error = str(e) + msg = ( + f"Resin service on {url} is not healthy, failed with error: {error}" + ) + raise CLIError(msg) + + +@retry(wait=wait_fixed(5), stop=stop_after_attempt(6)) +def wait_for_service(chat_service_url: str): + check_service_health(chat_service_url) def validate_connection(): try: KnowledgeBase._connect_pinecone() - except Exception: + except RuntimeError as e: msg = ( - "Failed to connect to Pinecone index, please make sure" - + " you have set the right env vars" + f"{str(e)}\n" + "Credentials should be set by the PINECONE_API_KEY and PINECONE_ENVIRONMENT" + " environment variables. " + "Please visit https://www.pinecone.io/docs/quick-start/ for more details." ) - click.echo(click.style(msg, fg="red"), err=True) - sys.exit(1) + raise CLIError(msg) try: openai.Model.list() except Exception: msg = ( - "Failed to connect to OpenAI, please make sure" - + " you have set the right env vars" + "Failed to connect to OpenAI, please make sure that the OPENAI_API_KEY " + "environment variable is set correctly.\n" + "Please visit https://platform.openai.com/account/api-keys for more details" ) - click.echo(click.style(msg, fg="red"), err=True) - sys.exit(1) + raise CLIError(msg) click.echo("Resin: ", nl=False) click.echo(click.style("Ready\n", bold=True, fg="green")) -@click.group(invoke_without_command=True) +def _initialize_tokenizer(): + try: + Tokenizer.initialize() + except Exception as e: + msg = f"Failed to initialize tokenizer. 
Reason:\n{e}" + raise CLIError(msg) + + +@click.group(invoke_without_command=True, context_settings=CONTEXT_SETTINGS) +@click.version_option(__version__, "-v", "--version", prog_name="Resin") @click.pass_context def cli(ctx): """ + \b CLI for Pinecone Resin. Actively developed by Pinecone. To use the CLI, you need to have a Pinecone account. Visit https://www.pinecone.io/ to sign up for free. @@ -74,74 +112,97 @@ def cli(ctx): if ctx.invoked_subcommand is None: validate_connection() click.echo(ctx.get_help()) - # click.echo(command.get_help(ctx)) -@cli.command(help="Check if Resin service is running") -@click.option("--host", default="0.0.0.0", help="Host") -@click.option("--port", default=8000, help="Port") -@click.option("--ssl/--no-ssl", default=False, help="SSL") -def health(host, port, ssl): - ssl_str = "s" if ssl else "" - service_url = f"http{ssl_str}://{host}:{port}" - if not is_healthy(service_url): - msg = ( - f"Resin service is not running! on {service_url}" - + " please run `resin start`" - ) - click.echo(click.style(msg, fg="red"), err=True) - sys.exit(1) - else: - click.echo(click.style("Resin service is healthy!", fg="green")) - return +@cli.command(help="Check if resin service is running and healthy.") +@click.option("--url", default="http://0.0.0.0:8000", + help="Resin's service url. Defaults to http://0.0.0.0:8000") +def health(url): + check_service_health(url) + click.echo(click.style("Resin service is healthy!", fg="green")) + return -@cli.command() +@cli.command( + help=( + """Create a new Pinecone index that will be used by Resin. + \b + A Resin service can not be started without a Pinecone index which is configured to work with Resin. + This command will create a new Pinecone index and configure it in the right schema. + + If the embedding vectors' dimension is not explicitly configured in + the config file - the embedding model will be tapped with a single token to + infer the dimensionality of the embedding space. 
+ """ # noqa: E501 + ) +) @click.argument("index-name", nargs=1, envvar="INDEX_NAME", type=str, required=True) -@click.option("--tokenizer-model", default="gpt-3.5-turbo", help="Tokenizer model") -def new(index_name, tokenizer_model): - Tokenizer.initialize(OpenAITokenizer, model_name=tokenizer_model) +def new(index_name): + _initialize_tokenizer() kb = KnowledgeBase(index_name=index_name) click.echo("Resin is going to create a new index: ", nl=False) click.echo(click.style(f"{kb.index_name}", fg="green")) click.confirm(click.style("Do you want to continue?", fg="red"), abort=True) with spinner: - kb.create_resin_index() + try: + kb.create_resin_index() + # TODO: kb should throw a specific exception for each case + except Exception as e: + msg = f"Failed to create a new index. Reason:\n{e}" + raise CLIError(msg) click.echo(click.style("Success!", fg="green")) os.environ["INDEX_NAME"] = index_name -@cli.command() +@cli.command( + help=( + """ + \b + Upload local data files containing documents to the Resin service. + + Load all the documents from data file or a directory containing multiple data files. + The allowed formats are .jsonl and .parquet. + """ # noqa: E501 + ) +) @click.argument("data-path", type=click.Path(exists=True)) @click.option( "--index-name", default=os.environ.get("INDEX_NAME"), - help="Index name", + help="The name of the index to upload the data to. " + "Inferred from INDEX_NAME env var if not provided." ) -@click.option("--tokenizer-model", default="gpt-3.5-turbo", help="Tokenizer model") -def upsert(index_name, data_path, tokenizer_model): +@click.option("--batch-size", default=10, + help="Number of documents to upload in each batch. Defaults to 10.") +@click.option("--allow-failures/--dont-allow-failures", default=False, + help="On default, the upsert process will stop if any document fails to " + "be uploaded. 
" + "When set to True, the upsert process will continue on failure, as " + "long as less than 10% of the documents have failed to be uploaded.") +def upsert(index_name: str, data_path: str, batch_size: int, allow_failures: bool): if index_name is None: - msg = ("Index name is not provided, please provide it with" + - ' --index-name or set it with env var + ' - '`export INDEX_NAME="MY_INDEX_NAME`') - click.echo(click.style(msg, fg="red"), err=True) - sys.exit(1) - Tokenizer.initialize(OpenAITokenizer, model_name=tokenizer_model) - if data_path is None: - msg = ("Data path is not provided," + - " please provide it with --data-path or set it with env var") - click.echo(click.style(msg, fg="red"), err=True) - sys.exit(1) + msg = ( + "No index name provided. Please set --index-name or INDEX_NAME environment " + "variable." + ) + raise CLIError(msg) + + _initialize_tokenizer() kb = KnowledgeBase(index_name=index_name) try: kb.connect() except RuntimeError as e: - click.echo(click.style(str(e), fg="red"), err=True) - sys.exit(1) + # TODO: kb should throw a specific exception for each case + msg = str(e) + if "credentials" in msg: + msg += ("\nCredentials should be set by the PINECONE_API_KEY and " + "PINECONE_ENVIRONMENT environment variables. 
Please visit " + "https://www.pinecone.io/docs/quick-start/ for more details.") + raise CLIError(msg) click.echo("Resin is going to upsert data from ", nl=False) - click.echo(click.style(f'{data_path}', fg='yellow'), nl=False) + click.echo(click.style(f"{data_path}", fg="yellow"), nl=False) click.echo(" to index: ") click.echo(click.style(f'{kb.index_name} \n', fg='green')) with spinner: @@ -149,36 +210,58 @@ def upsert(index_name, data_path, tokenizer_model): data = load_from_path(data_path) except IDsNotUniqueError: msg = ( - "Error: the id field on the data is not unique" - + " this will cause records to override each other on upsert" - + " please make sure the id field is unique" + "The data contains duplicate IDs, please make sure that each document" + " has a unique ID, otherwise documents with the same ID will overwrite" + " each other" ) - click.echo(click.style(msg, fg="red"), err=True) - sys.exit(1) + raise CLIError(msg) except DocumentsValidationError: msg = ( - "Error: one or more rows have not passed validation" - + " data should agree with the Document Schema" - + f" on {Document.__annotations__}" - + " please make sure the data is valid" + f"One or more rows have failed data validation. The rows in the " + f"data file should be in the schema: {Document.__annotations__}." ) - click.echo(click.style(msg, fg="red"), err=True) - sys.exit(1) + raise CLIError(msg) except Exception: msg = ( - "Error: an unexpected error has occured in loading data from files" - + " it may be due to issue with the data format" - + " please make sure the data is valid, and can load with pandas" + f"An unexpected error occurred while loading the data from files in {data_path}. " + "Please make sure the data is in valid `jsonl` or `parquet` format." 
) - click.echo(click.style(msg, fg="red"), err=True) - sys.exit(1) + raise CLIError(msg) pd.options.display.max_colwidth = 20 - click.echo(pd.DataFrame([doc.dict(exclude_none=True) for doc in data[:5]])) click.echo(click.style(f"\nTotal records: {len(data)}")) click.confirm(click.style("\nDoes this data look right?", fg="red"), abort=True) - kb.upsert(data) + + pbar = tqdm(total=len(data), desc="Upserting documents") + failed_docs: List[str] = [] + first_error: Optional[str] = None + for i in range(0, len(data), batch_size): + batch = data[i:i + batch_size] + try: + kb.upsert(batch) + except Exception as e: + if allow_failures and len(failed_docs) < len(data) // 10: + failed_docs.extend([_.id for _ in batch]) + if first_error is None: + first_error = str(e) + else: + msg = ( + f"Failed to upsert data to index {kb.index_name}. " + f"Underlying error: {e}\n" + f"You can allow partial failures by setting --allow-failures. " + ) + raise CLIError(msg) + + pbar.update(len(batch)) + + if failed_docs: + msg = ( + f"Failed to upsert the following documents to index {kb.index_name}: " + f"{failed_docs}. The first encountered error was: {first_error}" + ) + raise CLIError(msg) + click.echo(click.style("Success!", fg="green")) @@ -195,12 +278,17 @@ def _chat( output = "" history += [{"role": "user", "content": message}] start = time.time() - openai_response = openai.ChatCompletion.create( - model=model, messages=history, stream=stream, api_base=api_base - ) + try: + openai_response = openai.ChatCompletion.create( + model=model, messages=history, stream=stream, api_base=api_base + ) + except (Exception, OpenAI_APIError) as e: + err = e.http_body if isinstance(e, OpenAI_APIError) else str(e) + msg = f"Oops... something went wrong. 
The error I got is: {err}" + raise CLIError(msg) end = time.time() duration_in_sec = end - start - click.echo(click.style(f"\n {speaker}:\n", fg=speaker_color)) + click.echo(click.style(f"\n> AI {speaker}:\n", fg=speaker_color)) if stream: for chunk in openai_response: openai_response_id = chunk.id @@ -235,28 +323,51 @@ def _chat( return debug_info -@cli.command() -@click.option("--stream/--no-stream", default=True, help="Stream") -@click.option("--debug/--no-debug", default=False, help="Print debug info") -@click.option( - "--rag/--no-rag", - default=True, - help="Direct chat with the model", -) -@click.option("--chat-service-url", default="http://0.0.0.0:8000") -@click.option( - "--index-name", - default=os.environ.get("INDEX_NAME"), - help="Index name suffix", +@cli.command( + help=( + """ + Debugging tool for chatting with the Resin RAG service. + + Run an interactive chat with the Resin RAG service, for debugging and demo + purposes. A prompt is provided for the user to enter a message, and the + RAG-infused ChatBot will respond. You can continue the conversation by entering + more messages. Hit Ctrl+C to exit. + + To compare RAG-infused ChatBot with the original LLM, run with the `--baseline` + flag, which would display both models' responses side by side. + """ + + ) ) -def chat(index_name, chat_service_url, rag, debug, stream): - if not is_healthy(chat_service_url): - msg = ( - f"Resin service is not running! 
on {chat_service_url}" - + " please run `resin start`" - ) - click.echo(click.style(msg, fg="red"), err=True) - sys.exit(1) +@click.option("--stream/--no-stream", default=True, + help="Stream the response from the RAG chatbot word by word") +@click.option("--debug/--no-debug", default=False, + help="Print additional debugging information") +@click.option("--baseline/--no-baseline", default=False, + help="Compare RAG-infused Chatbot with baseline LLM",) +@click.option("--chat-service-url", default="http://0.0.0.0:8000", + help="URL of the Resin service to use. Defaults to http://0.0.0.0:8000") +def chat(chat_service_url, baseline, debug, stream): + check_service_health(chat_service_url) + note_msg = ( + "🚨 Note 🚨\n" + "Chat is a debugging tool, it is not meant to be used for production!" + ) + for c in note_msg: + click.echo(click.style(c, fg="red"), nl=False) + time.sleep(0.01) + click.echo() + note_white_message = ( + "This method should be used by developers to test the RAG data and model" + "during development. " + "When you are ready to deploy, run the Resin service as a REST API " + "backend for your chatbot UI. \n\n" + "Let's Chat!" 
+ ) + for c in note_white_message: + click.echo(click.style(c, fg="white"), nl=False) + time.sleep(0.01) + click.echo() history_with_pinecone = [] history_without_pinecone = [] @@ -276,7 +387,7 @@ def chat(index_name, chat_service_url, rag, debug, stream): print_debug_info=debug, ) - if not rag: + if baseline: _ = _chat( speaker="Without Context (No RAG)", speaker_color="yellow", @@ -298,71 +409,51 @@ def chat(index_name, chat_service_url, rag, debug, stream): click.echo(click.style("˙", fg="bright_black", bold=True)) -@cli.command() -@click.option("--host", default="0.0.0.0", help="Host") -@click.option("--port", default=8000, help="Port") -@click.option("--ssl/--no-ssl", default=False, help="SSL") -@click.option("--reload/--no-reload", default=False, help="Reload") -def start(host, port, ssl, reload): - click.echo(f"Starting Resin service on {host}:{port}") - start_service(host, port, reload) - +@cli.command( + help=( + """ + \b + Start the Resin service. + This command will launch a uvicorn server that will serve the Resin API. -@cli.command() -@click.option("--host", default="0.0.0.0", help="Host") -@click.option("--port", default=8000, help="Port") -@click.option("--ssl/--no-ssl", default=False, help="SSL") -def stop(host, port, ssl): - ssl_str = "s" if ssl else "" - service_url = f"http{ssl_str}://{host}:{port}" - - if not is_healthy(service_url): - msg = ( - f"Resin service is not running! on {service_url}" - + " please run `resin start`" - ) - click.echo(click.style(msg, fg="red"), err=True) - sys.exit(1) - - import subprocess + If you like to try out the chatbot, run `resin chat` in a separate terminal + window. + """ + ) +) +@click.option("--host", default="0.0.0.0", + help="Hostname or ip address to bind the server to. Defaults to 0.0.0.0") +@click.option("--port", default=8000, + help="TCP port to bind the server to. Defaults to 8000") +@click.option("--reload/--no-reload", default=False, + help="Set the server to reload on code changes. 
Defaults to False") +@click.option("--workers", default=1, help="Number of worker processes. Defaults to 1") +def start(host, port, reload, workers): + click.echo(f"Starting Resin service on {host}:{port}") + start_service(host, port=port, reload=reload, workers=workers) - p1 = subprocess.Popen(["lsof", "-t", "-i", f"tcp:{port}"], stdout=subprocess.PIPE) - running_server_id = p1.stdout.read().decode("utf-8").strip() - if running_server_id == "": - click.echo( - click.style( - "Did not find active process for Resin service" + f" on {host}:{port}", - fg="red", - ) - ) - sys.exit(1) - msg = ( - "Warning, this will invoke in process kill" - + " to the PID of the service, this method is not recommended!" - + " We recommend ctrl+c on the terminal where you started the service" - + " as this will allow the service to gracefully shutdown" - ) - click.echo(click.style(msg, fg="yellow")) - - click.confirm( - click.style( - f"Stopping Resin service on {host}:{port} with pid " f"{running_server_id}", - fg="red", - ), - abort=True, +@cli.command( + help=( + """ + \b + Stop the Resin service. + This command will send a shutdown request to the Resin service. + """ ) - p2 = subprocess.Popen( - ["kill", "-9", running_server_id], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - kill_result = p2.stderr.read().decode("utf-8").strip() - if kill_result == "": - click.echo(click.style("Success!", fg="green")) - else: - click.echo(click.style(kill_result, fg="red")) - click.echo(click.style("Failed!", fg="red")) +) +@click.option("url", "--url", default="http://0.0.0.0:8000", + help="URL of the Resin service to use. Defaults to http://0.0.0.0:8000") +def stop(url): + try: + res = requests.get(urljoin(url, "/shutdown")) + res.raise_for_status() + return res.ok + except requests.exceptions.ConnectionError: + msg = f""" + Could not find Resin service on {url}. 
+ """ + raise CLIError(msg) if __name__ == "__main__": diff --git a/src/resin_cli/data_loader/__init__.py b/src/resin_cli/data_loader/__init__.py index 8a298030..85464408 100644 --- a/src/resin_cli/data_loader/__init__.py +++ b/src/resin_cli/data_loader/__init__.py @@ -1,5 +1,6 @@ from .data_loader import ( load_from_path, + CLIError, IDsNotUniqueError, DocumentsValidationError ) diff --git a/src/resin_cli/data_loader/data_loader.py b/src/resin_cli/data_loader/data_loader.py index 99f90d6a..b84434b6 100644 --- a/src/resin_cli/data_loader/data_loader.py +++ b/src/resin_cli/data_loader/data_loader.py @@ -3,9 +3,12 @@ import glob from collections.abc import Iterable from typing import List +from textwrap import dedent +import click import numpy as np import pandas as pd +from click import ClickException from pydantic import ValidationError @@ -13,13 +16,20 @@ class IDsNotUniqueError(ValueError): - def __init__(self, message): - super().__init__(message) + pass class DocumentsValidationError(ValueError): - def __init__(self, message): - super().__init__(message) + pass + + +def format_multiline(msg): + return dedent(msg).strip() + + +class CLIError(ClickException): + def format_message(self) -> str: + return click.style(format_multiline(self.message), fg='red') def _process_metadata(value): diff --git a/tests/e2e/test_app.py b/tests/e2e/test_app.py index 5cb50599..19a16b06 100644 --- a/tests/e2e/test_app.py +++ b/tests/e2e/test_app.py @@ -8,7 +8,7 @@ from fastapi.testclient import TestClient from tenacity import retry, stop_after_attempt, wait_fixed -from resin.knoweldge_base import KnowledgeBase +from resin.knowledge_base import KnowledgeBase from resin_cli.app import app from resin_cli.api_models import HealthStatus, ContextUpsertRequest, ContextQueryRequest diff --git a/tests/system/knowledge_base/test_knowledge_base.py b/tests/system/knowledge_base/test_knowledge_base.py index 3ede466a..88d57ec5 100644 --- a/tests/system/knowledge_base/test_knowledge_base.py +++ 
b/tests/system/knowledge_base/test_knowledge_base.py @@ -12,12 +12,12 @@ ) from dotenv import load_dotenv from datetime import datetime -from resin.knoweldge_base import KnowledgeBase -from resin.knoweldge_base.chunker import Chunker -from resin.knoweldge_base.knowledge_base import INDEX_NAME_PREFIX -from resin.knoweldge_base.models import DocumentWithScore -from resin.knoweldge_base.record_encoder import RecordEncoder -from resin.knoweldge_base.reranker import Reranker +from resin.knowledge_base import KnowledgeBase +from resin.knowledge_base.chunker import Chunker +from resin.knowledge_base.knowledge_base import INDEX_NAME_PREFIX +from resin.knowledge_base.models import DocumentWithScore +from resin.knowledge_base.record_encoder import RecordEncoder +from resin.knowledge_base.reranker import Reranker from resin.models.data_models import Document, Query from tests.unit.stubs.stub_record_encoder import StubRecordEncoder from tests.unit.stubs.stub_dense_encoder import StubDenseEncoder @@ -59,7 +59,7 @@ def chunker(): @pytest.fixture(scope="module") def encoder(): return StubRecordEncoder( - StubDenseEncoder(dimension=3)) + StubDenseEncoder()) @pytest.fixture(scope="module", autouse=True) @@ -117,6 +117,28 @@ def assert_ids_not_in_index(knowledge_base, ids): assert len(fetch_result) == 0, f"Found unexpected ids: {len(fetch_result.keys())}" +@retry_decorator() +def execute_and_assert_queries(knowledge_base, chunks_to_query): + queries = [Query(text=chunk.text, top_k=2) for chunk in chunks_to_query] + + query_results = knowledge_base.query(queries) + + assert len(query_results) == len(queries) + + for i, q_res in enumerate(query_results): + assert queries[i].text == q_res.query + assert len(q_res.documents) == 2 + q_res.documents[0].score = round(q_res.documents[0].score, 1) + assert q_res.documents[0] == DocumentWithScore( + id=chunks_to_query[i].id, + text=chunks_to_query[i].text, + metadata=chunks_to_query[i].metadata, + source=chunks_to_query[i].source, + 
score=1.0), \ + f"query {i} - expected: {chunks_to_query[i]}, " \ + f"actual: {q_res.documents}" + + @pytest.fixture(scope="module", autouse=True) def teardown_knowledge_base(index_full_name, knowledge_base): yield @@ -158,6 +180,27 @@ def encoded_chunks_large(documents_large, chunker, encoder): return encoder.encode_documents(chunks) +@pytest.fixture +def documents_with_datetime_metadata(): + return [Document(id="doc_1_metadata", + text="document with datetime metadata", + source="source_1", + metadata={"datetime": "2021-01-01T00:00:00Z", + "datetime_other_format": "January 1, 2021 00:00:00", + "datetime_other_format_2": "2210.03945"}), + Document(id="2021-01-01T00:00:00Z", + text="id is datetime", + source="source_1")] + + +@pytest.fixture +def datetime_metadata_encoded_chunks(documents_with_datetime_metadata, + chunker, + encoder): + chunks = chunker.chunk_documents(documents_with_datetime_metadata) + return encoder.encode_documents(chunks) + + @pytest.fixture def encoded_chunks(documents, chunker, encoder): chunks = chunker.chunk_documents(documents) @@ -203,28 +246,7 @@ def test_upsert_forbidden_metadata(knowledge_base, documents, key): def test_query(knowledge_base, encoded_chunks): - queries = [Query(text=encoded_chunks[0].text), - Query(text=encoded_chunks[1].text, top_k=2)] - query_results = knowledge_base.query(queries) - - assert len(query_results) == 2 - - expected_top_k = [5, 2] - expected_first_results = [DocumentWithScore(id=chunk.id, - text=chunk.text, - metadata=chunk.metadata, - source=chunk.source, - score=1.0) - for chunk in encoded_chunks[:2]] - for i, q_res in enumerate(query_results): - assert queries[i].text == q_res.query - assert len(q_res.documents) == expected_top_k[i] - q_res.documents[0].score = round(q_res.documents[0].score, 2) - assert q_res.documents[0] == expected_first_results[i] - q_res.documents[0].score = round(q_res.documents[0].score, 2) - assert q_res.documents[0] == expected_first_results[i], \ - f"query {i} - expected: 
{expected_first_results[i]}, " \ - f"actual: {q_res.documents[0]}" + execute_and_assert_queries(knowledge_base, encoded_chunks) def test_delete_documents(knowledge_base, encoded_chunks): @@ -302,6 +324,20 @@ def test_delete_large_df_happy_path(knowledge_base, for chunk in chunks_for_validation]) +def test_upsert_documents_with_datetime_metadata(knowledge_base, + documents_with_datetime_metadata, + datetime_metadata_encoded_chunks): + knowledge_base.upsert(documents_with_datetime_metadata) + + assert_ids_in_index(knowledge_base, [chunk.id + for chunk in datetime_metadata_encoded_chunks]) + + +def test_query_edge_case_documents(knowledge_base, + datetime_metadata_encoded_chunks): + execute_and_assert_queries(knowledge_base, datetime_metadata_encoded_chunks) + + def test_create_existing_index_no_connect(index_full_name, index_name): kb = KnowledgeBase( index_name=index_name, diff --git a/tests/unit/chunker/test_markdown_chunker.py b/tests/unit/chunker/test_markdown_chunker.py index 9231b7b1..ada9b2a3 100644 --- a/tests/unit/chunker/test_markdown_chunker.py +++ b/tests/unit/chunker/test_markdown_chunker.py @@ -1,7 +1,7 @@ import pytest -from resin.knoweldge_base.chunker import MarkdownChunker -from resin.knoweldge_base.models import KBDocChunk +from resin.knowledge_base.chunker import MarkdownChunker +from resin.knowledge_base.models import KBDocChunk from resin.models.data_models import Document from tests.unit.chunker.base_test_chunker import BaseTestChunker diff --git a/tests/unit/chunker/test_recursive_character_chunker.py b/tests/unit/chunker/test_recursive_character_chunker.py index 282e4875..3a3cc8fe 100644 --- a/tests/unit/chunker/test_recursive_character_chunker.py +++ b/tests/unit/chunker/test_recursive_character_chunker.py @@ -1,7 +1,7 @@ import pytest -from resin.knoweldge_base.chunker.recursive_character \ +from resin.knowledge_base.chunker.recursive_character \ import RecursiveCharacterChunker -from resin.knoweldge_base.models import KBDocChunk +from 
resin.knowledge_base.models import KBDocChunk from tests.unit.chunker.base_test_chunker import BaseTestChunker diff --git a/tests/unit/chunker/test_stub_chunker.py b/tests/unit/chunker/test_stub_chunker.py index bb84e418..78eb91a3 100644 --- a/tests/unit/chunker/test_stub_chunker.py +++ b/tests/unit/chunker/test_stub_chunker.py @@ -1,6 +1,6 @@ import pytest -from resin.knoweldge_base.models import KBDocChunk +from resin.knowledge_base.models import KBDocChunk from .base_test_chunker import BaseTestChunker from ..stubs.stub_chunker import StubChunker diff --git a/tests/unit/chunker/test_token_chunker.py b/tests/unit/chunker/test_token_chunker.py index b0de1040..463dc99f 100644 --- a/tests/unit/chunker/test_token_chunker.py +++ b/tests/unit/chunker/test_token_chunker.py @@ -1,9 +1,9 @@ import pytest -from resin.knoweldge_base.models import KBDocChunk +from resin.knowledge_base.models import KBDocChunk from resin.models.data_models import Document from .base_test_chunker import BaseTestChunker -from resin.knoweldge_base.chunker.token_chunker import TokenChunker +from resin.knowledge_base.chunker.token_chunker import TokenChunker class TestTokenChunker(BaseTestChunker): diff --git a/tests/unit/cli/test_data_loader.py b/tests/unit/cli/test_data_loader.py index 3e4eb0fc..ff11c651 100644 --- a/tests/unit/cli/test_data_loader.py +++ b/tests/unit/cli/test_data_loader.py @@ -27,6 +27,25 @@ ] ) + +good_df_all_good_metadata_permutations = ( + pd.DataFrame( + [ + {"id": 1, "text": "foo", "metadata": {"string": "string"}}, + {"id": 2, "text": "bar", "metadata": {"int": 1}}, + {"id": 3, "text": "baz", "metadata": {"float": 1.0}}, + {"id": 4, "text": "foo", "metadata": {"list": ["list", "another"]}}, + ] + ), + [ + Document(id=1, text="foo", metadata={"string": "string"}), + Document(id=2, text="bar", metadata={"int": 1}), + Document(id=3, text="baz", metadata={"float": 1.0}), + Document(id=4, text="foo", metadata={"list": ["list", "another"]}), + ] +) + + good_df_maximal = ( 
pd.DataFrame( [ @@ -102,6 +121,20 @@ DocumentsValidationError, ) +bad_df_metadata_not_allowed_all_permutations = ( + pd.DataFrame( + [ + {"id": 1, "text": "foo", "metadata": {"list_of_int": [1, 2, 3]}}, + {"id": 2, "text": "bar", "metadata": {"list_of_float": [1.0, 2.0, 3.0]}}, + {"id": 3, "text": "baz", "metadata": {"dict": {"key": "value"}}}, + {"id": 4, "text": "foo", "metadata": {"list_of_dict": [{"key": "value"}]}}, + {"id": 5, "text": "bar", "metadata": {"list_of_list": [["value"]]}}, + {"id": 6, "text": "baz", "metadata": {1: "foo"}}, + ] + ), + DocumentsValidationError +) + bad_df_has_excess_field = ( pd.DataFrame( @@ -132,6 +165,17 @@ DocumentsValidationError, ) +bad_df_missppelled_optional_field = ( + pd.DataFrame( + [ + {"id": 1, "text": "foo", "sorce": "foo_source"}, + {"id": 2, "text": "bar", "metdata": {"key": "value"}}, + {"id": 3, "text": "baz", "sorce": "baz_source"}, + ] + ), + DocumentsValidationError +) + bad_df_missing_mandatory_field = ( pd.DataFrame( [ @@ -166,6 +210,8 @@ ("bad_df_has_excess_field", bad_df_has_excess_field), ("bad_df_missing_mandatory_field", bad_df_missing_mandatory_field), ("bad_df_duplicate_ids", bad_df_duplicate_ids), + ("bad_df_missppelled_optional_field", bad_df_missppelled_optional_field), + ("good_df_all_good_metadata_permutations", good_df_all_good_metadata_permutations), ] diff --git a/tests/unit/context_builder/test_stuffing_context_builder.py b/tests/unit/context_builder/test_stuffing_context_builder.py index 3940df8b..f8ec3220 100644 --- a/tests/unit/context_builder/test_stuffing_context_builder.py +++ b/tests/unit/context_builder/test_stuffing_context_builder.py @@ -2,7 +2,7 @@ ContextSnippet, ContextQueryResult from resin.models.data_models import Context from ..stubs.stub_tokenizer import StubTokenizer -from resin.knoweldge_base.models import \ +from resin.knowledge_base.models import \ QueryResult, DocumentWithScore from resin.context_engine.context_builder import StuffingContextBuilder diff --git 
a/tests/unit/context_engine/test_context_engine.py b/tests/unit/context_engine/test_context_engine.py index 10255b93..41e87079 100644 --- a/tests/unit/context_engine/test_context_engine.py +++ b/tests/unit/context_engine/test_context_engine.py @@ -3,8 +3,8 @@ from resin.context_engine import ContextEngine from resin.context_engine.context_builder.base import ContextBuilder -from resin.knoweldge_base.base import BaseKnowledgeBase -from resin.knoweldge_base.models import QueryResult, DocumentWithScore +from resin.knowledge_base.base import BaseKnowledgeBase +from resin.knowledge_base.models import QueryResult, DocumentWithScore from resin.models.data_models import Query, Context, ContextContent diff --git a/tests/unit/record_encoder/base_test_record_encoder.py b/tests/unit/record_encoder/base_test_record_encoder.py index 7b0ab6fb..0c1c2008 100644 --- a/tests/unit/record_encoder/base_test_record_encoder.py +++ b/tests/unit/record_encoder/base_test_record_encoder.py @@ -2,7 +2,7 @@ import math from abc import ABC, abstractmethod -from resin.knoweldge_base.models import KBDocChunk, KBEncodedDocChunk, KBQuery +from resin.knowledge_base.models import KBDocChunk, KBEncodedDocChunk, KBQuery from resin.models.data_models import Query diff --git a/tests/unit/record_encoder/test_dense_record_encoder.py b/tests/unit/record_encoder/test_dense_record_encoder.py index a5eb2eb4..c2b7cd8a 100644 --- a/tests/unit/record_encoder/test_dense_record_encoder.py +++ b/tests/unit/record_encoder/test_dense_record_encoder.py @@ -1,6 +1,6 @@ import pytest -from resin.knoweldge_base.record_encoder import DenseRecordEncoder +from resin.knowledge_base.record_encoder import DenseRecordEncoder from .base_test_record_encoder import BaseTestRecordEncoder from ..stubs.stub_dense_encoder import StubDenseEncoder diff --git a/tests/unit/record_encoder/test_openai_record_encoder.py b/tests/unit/record_encoder/test_openai_record_encoder.py index 8f213b81..fd5333ec 100644 --- 
a/tests/unit/record_encoder/test_openai_record_encoder.py +++ b/tests/unit/record_encoder/test_openai_record_encoder.py @@ -1,7 +1,7 @@ import pytest from pinecone_text.dense.openai_encoder import OpenAIEncoder -from resin.knoweldge_base.record_encoder.openai import OpenAIRecordEncoder +from resin.knowledge_base.record_encoder.openai import OpenAIRecordEncoder from .base_test_record_encoder import BaseTestRecordEncoder from unittest.mock import Mock diff --git a/tests/unit/stubs/stub_chunker.py b/tests/unit/stubs/stub_chunker.py index 6f0dd331..aceb214f 100644 --- a/tests/unit/stubs/stub_chunker.py +++ b/tests/unit/stubs/stub_chunker.py @@ -1,6 +1,6 @@ from typing import List -from resin.knoweldge_base.chunker.base import Chunker -from resin.knoweldge_base.models import KBDocChunk +from resin.knowledge_base.chunker.base import Chunker +from resin.knowledge_base.models import KBDocChunk from resin.models.data_models import Document diff --git a/tests/unit/stubs/stub_dense_encoder.py b/tests/unit/stubs/stub_dense_encoder.py index d0e02ff2..9d55bb7f 100644 --- a/tests/unit/stubs/stub_dense_encoder.py +++ b/tests/unit/stubs/stub_dense_encoder.py @@ -1,15 +1,47 @@ -import hashlib +import mmh3 import numpy as np +from collections import defaultdict from typing import Union, List from pinecone_text.dense.base_dense_ecoder import BaseDenseEncoder class StubDenseEncoder(BaseDenseEncoder): - - def __init__(self, dimension: int = 3): + """ + Bag-of-words encoder that uses a random projection matrix to + project sparse vectors to dense vectors. + uses Johnson–Lindenstrauss lemma to project BOW sparse vectors to dense vectors. 
+ """ + + def __init__(self, + dimension: int = 8, + vocab_size: int = 2 ** 12): + self.input_dim = vocab_size self.dimension = dimension + def _text_to_word_counts(self, text: str) -> defaultdict: + words = text.split() + word_counts = defaultdict(int) + for word in words: + hashed_word = mmh3.hash(word) % self.input_dim + word_counts[hashed_word] += 1 + return word_counts + + def _encode_text(self, text: str) -> List[float]: + word_counts = self._text_to_word_counts(text) + + # This will hold the result of word_counts * random_matrix + projected_embedding = np.zeros(self.dimension, dtype=np.float32) + + for hashed_word, count in word_counts.items(): + rng = np.random.default_rng(hashed_word) + # Seed the RNG with the hashed word index for consistency + random_vector = rng.standard_normal(self.dimension) + projected_embedding += count * random_vector + + projected_embedding = projected_embedding.astype(np.float32) + return list(projected_embedding / np.linalg.norm(projected_embedding)) + def encode_documents(self, texts: Union[str, List[str]] ) -> Union[List[float], List[List[float]]]: @@ -20,23 +52,10 @@ def encode_queries(self, ) -> Union[List[float], List[List[float]]]: return self._encode(texts) - def consistent_embedding(self, text: str) -> List[float]: - # consistent embedding function that project each text to a unique angle - embedding = [] - for i in range(self.dimension): - sha256_hash = hashlib.sha256(f"{text} {i}".encode()).hexdigest() - int_value = int(sha256_hash, 16) - embedding.append(int_value / float(1 << 256)) - - l2_norm = np.linalg.norm(embedding) - normalized_embedding = [float(value / l2_norm) for value in embedding] - - return normalized_embedding - def _encode(self, texts: Union[str, List[str]] ) -> Union[List[float], List[List[float]]]: if isinstance(texts, str): - return self.consistent_embedding(texts) + return self._encode_text(texts) else: - return [self.consistent_embedding(text) for text in texts] + return [self._encode_text(text) 
for text in texts] diff --git a/tests/unit/stubs/stub_record_encoder.py b/tests/unit/stubs/stub_record_encoder.py index 3bd49454..9af5d387 100644 --- a/tests/unit/stubs/stub_record_encoder.py +++ b/tests/unit/stubs/stub_record_encoder.py @@ -1,7 +1,7 @@ from typing import List -from resin.knoweldge_base.record_encoder import RecordEncoder -from resin.knoweldge_base.models import KBQuery, KBDocChunk, KBEncodedDocChunk +from resin.knowledge_base.record_encoder import RecordEncoder +from resin.knowledge_base.models import KBQuery, KBDocChunk, KBEncodedDocChunk from resin.models.data_models import Query from .stub_dense_encoder import StubDenseEncoder