diff --git a/.flake8 b/.flake8 deleted file mode 100644 index 8c8d82b..0000000 --- a/.flake8 +++ /dev/null @@ -1,17 +0,0 @@ -[flake8] -max-line-length = 100 -extend-ignore = - # ignore tabs - W191, - # blank line at end of file - W391, - # white space inside {} - E201, - E202, - # multiple blank lines - ; E303, - # lambda functions assigned to a variable - E731, - # multiple statements on one line (colon) - E701, - diff --git a/.github/workflows/lint-n-static-analysis.yml b/.github/workflows/lint-n-static-analysis.yml new file mode 100644 index 0000000..19a7183 --- /dev/null +++ b/.github/workflows/lint-n-static-analysis.yml @@ -0,0 +1,45 @@ +name: Lint and Static Analysis + +on: + pull_request: + paths: + - main.py + - context_chat_backend/** + - reqs.txt + - reqs.dev + push: + branches: + - master + paths: + - main.py + - context_chat_backend/** + - reqs.txt + - reqs.dev + +jobs: + analysis: + runs-on: ubuntu-latest + + name: Lint and Static Analysis + + steps: + - uses: actions/checkout@v4 + + - name: Setup python 3.11 + uses: actions/setup-python@v5 + with: + python-version: '3.11' + cache: 'pip' + + - name: Install dependencies + run: | + pip install -r reqs.txt + pip install -r reqs.dev + + - name: Lint with Ruff + run: | + ruff --output-format=github context_chat_backend main.py + + - name: Static analysis with pyright + run: | + pyright context_chat_backend main.py diff --git a/context_chat_backend/chain/ingest/doc_loader.py b/context_chat_backend/chain/ingest/doc_loader.py index 3c43021..5b9ab98 100644 --- a/context_chat_backend/chain/ingest/doc_loader.py +++ b/context_chat_backend/chain/ingest/doc_loader.py @@ -1,19 +1,20 @@ -from logging import error as log_error import re import tempfile +from collections.abc import Callable +from logging import error as log_error from typing import BinaryIO from fastapi import UploadFile -from pandas import read_csv, read_excel -from pypandoc import convert_text -from pypdf import PdfReader from langchain.document_loaders import ( - UnstructuredPowerPointLoader, UnstructuredEmailLoader, + UnstructuredPowerPointLoader, ) +from pandas import read_csv, read_excel +from pypandoc import convert_text +from pypdf import PdfReader -def _temp_file_wrapper(file: BinaryIO, loader: callable, sep: str = '\n') -> str: +def _temp_file_wrapper(file: BinaryIO, loader: Callable, sep: str = '\n') -> str: raw_bytes = file.read() tmp = tempfile.NamedTemporaryFile(mode='wb') tmp.write(raw_bytes) @@ -25,7 +26,7 @@ def _temp_file_wrapper(file: BinaryIO, loader: callable, sep: str = '\n') -> str import os os.remove(tmp.name) - return sep.join(map(lambda d: d.page_content, docs)) + return sep.join(d.page_content for d in docs) # -- LOADERS -- # @@ -40,11 +41,11 @@ def _load_csv(file: BinaryIO) -> str: def _load_epub(file: BinaryIO) -> str: - return convert_text(file.read(), 'plain', 'epub').strip() + return convert_text(str(file.read()), 'plain', 'epub').strip() def _load_docx(file: BinaryIO) -> str: - return convert_text(file.read(), 'plain', 'docx').strip() + return convert_text(str(file.read()), 'plain', 'docx').strip() def _load_ppt_x(file: BinaryIO) -> str: @@ -52,11 +53,11 @@ def _load_ppt_x(file: BinaryIO) -> str: def _load_rtf(file: BinaryIO) -> str: - return convert_text(file.read(), 'plain', 'rtf').strip() + return convert_text(str(file.read()), 'plain', 'rtf').strip() def _load_rst(file: BinaryIO) -> str: - return convert_text(file.read(), 'plain', 'rst').strip() + return convert_text(str(file.read()), 'plain', 'rst').strip() def _load_xml(file: BinaryIO) -> str: @@ -70,7 +71,7 @@ def _load_xlsx(file: BinaryIO) -> str: def _load_odt(file: BinaryIO) -> str: - return convert_text(file.read(), 'plain', 'odt').strip() + return convert_text(str(file.read()), 'plain', 'odt').strip() def _load_email(file: BinaryIO, ext: str = 'eml') -> str | None: @@ -95,7 +96,7 @@ def attachment_partitioner( def _load_org(file: BinaryIO) -> str: - return convert_text(file.read(), 'plain', 'org').strip() + return convert_text(str(file.read()), 'plain', 'org').strip() # -- LOADER FUNCTION MAP -- # @@ -124,11 +125,15 @@ def decode_source(source: UploadFile) -> str | None: try: # .pot files are powerpoint templates but also plain text files, # so we skip them to prevent decoding errors - if source.headers.get('title').endswith('.pot'): + if source.headers.get('title', '').endswith('.pot'): + return None + + mimetype = source.headers.get('type') + if mimetype is None: return None - if _loader_map.get(source.headers.get('type')): - return _loader_map[source.headers.get('type')](source.file) + if _loader_map.get(mimetype): + return _loader_map[mimetype](source.file) return source.file.read().decode('utf-8') except Exception as e: diff --git a/context_chat_backend/chain/ingest/doc_splitter.py b/context_chat_backend/chain/ingest/doc_splitter.py index f3a2b17..6d95d04 100644 --- a/context_chat_backend/chain/ingest/doc_splitter.py +++ b/context_chat_backend/chain/ingest/doc_splitter.py @@ -17,7 +17,7 @@ def get_splitter_for(mimetype: str = 'text/plain') -> TextSplitter: mt_map = { 'text/markdown': MarkdownTextSplitter(**kwargs), - 'application/json': RecursiveCharacterTextSplitter(separators=['{', '}', r'\[', r'\]', ',', ''], **kwargs), # noqa: E501 + 'application/json': RecursiveCharacterTextSplitter(separators=['{', '}', r'\[', r'\]', ',', ''], **kwargs), # processed csv, does not contain commas 'text/csv': RecursiveCharacterTextSplitter(separators=['\n', ' ', ''], **kwargs), # remove end tags for less verbosity, and remove all whitespace outside of tags @@ -26,7 +26,7 @@ def get_splitter_for(mimetype: str = 'text/plain') -> TextSplitter: 'application/vnd.ms-excel.sheet.macroEnabled.12': RecursiveCharacterTextSplitter(separators=['\n\n', '\n', ' ', ''], **kwargs), # noqa: E501 } - if mimetype in mt_map.keys(): + if mimetype in mt_map: return mt_map[mimetype] # all other mimetypes diff --git a/context_chat_backend/chain/ingest/injest.py b/context_chat_backend/chain/ingest/injest.py index 7e66853..af08422 100644 --- a/context_chat_backend/chain/ingest/injest.py +++ b/context_chat_backend/chain/ingest/injest.py @@ -1,14 +1,14 @@ -from logging import error as log_error import re +from logging import error as log_error from fastapi.datastructures import UploadFile from langchain.schema import Document +from ...utils import to_int +from ...vectordb import BaseVectorDB from .doc_loader import decode_source from .doc_splitter import get_splitter_for from .mimetype_list import SUPPORTED_MIMETYPES -from ...utils import to_int -from ...vectordb import BaseVectorDB def _allowed_file(file: UploadFile) -> bool: @@ -51,21 +51,22 @@ def _filter_documents( .difference(set(existing_objects)) new_sources.update(set(to_delete.keys())) - filtered_documents = [ + return [ doc for doc in documents if doc.metadata.get('source') in new_sources ] - return filtered_documents - -def _sources_to_documents(sources: list[UploadFile]) -> list[Document]: +def _sources_to_documents(sources: list[UploadFile]) -> dict[str, list[Document]]: + ''' + Converts a list of sources to a dictionary of documents with the user_id as the key. + ''' documents = {} for source in sources: user_id = source.headers.get('userId') if user_id is None: - log_error('userId not found in headers for source: ' + source.filename) + log_error(f'userId not found in headers for source: {source.filename}') continue # transform the source to have text data diff --git a/context_chat_backend/chain/one_shot.py b/context_chat_backend/chain/one_shot.py index 844cff7..105dd79 100644 --- a/context_chat_backend/chain/one_shot.py +++ b/context_chat_backend/chain/one_shot.py @@ -18,22 +18,19 @@ def process_query( ctx_limit: int = 5, template: str = _LLM_TEMPLATE, end_separator: str = '', -) -> tuple[str, list]: +) -> tuple[str, set]: if not use_context: - return llm.predict(query), [] + return llm.predict(query), set() user_client = vectordb.get_user_client(user_id) if user_client is None: - return llm.predict(query), [] + return llm.predict(query), set() context_docs = user_client.similarity_search(query, k=ctx_limit) - context_text = '\n\n'.join(map( - lambda d: f'{d.metadata.get("title")}\n{d.page_content}', - context_docs, - )) + context_text = '\n\n'.join(f'{d.metadata.get("title")}\n{d.page_content}' for d in context_docs) output = llm.predict(template.format(context=context_text, question=query)) \ .strip().rstrip(end_separator).strip() - unique_sources = list(set(map(lambda d: d.metadata.get('source', ''), context_docs))) + unique_sources = {d.metadata.get('source') for d in context_docs} return (output, unique_sources) diff --git a/context_chat_backend/config_parser.py b/context_chat_backend/config_parser.py index a67f043..2eb6ada 100644 --- a/context_chat_backend/config_parser.py +++ b/context_chat_backend/config_parser.py @@ -1,4 +1,5 @@ from pprint import pprint +from typing import TypedDict from ruamel.yaml import YAML @@ -6,6 +7,12 @@ from .vectordb import vector_dbs +class TConfig(TypedDict): + vectordb: tuple[str, dict] + embedding: tuple[str, dict] + llm: tuple[str, dict] + + def _first_in_list( input_dict: dict[str, dict], supported_list: list[str] @@ -21,7 +28,7 @@ def _first_in_list( return None -def get_config(file_path: str = 'config.yaml') -> dict[str, tuple[str, dict]]: +def get_config(file_path: str = 'config.yaml') -> TConfig: ''' Get the config from the given file path (relative to the root directory). ''' @@ -32,27 +39,30 @@ def get_config(file_path: str = 'config.yaml') -> dict[str, tuple[str, dict]]: except Exception as e: raise AssertionError('Error: could not load config from', file_path, 'file') from e - selected_config = { - 'vectordb': _first_in_list(config.get('vectordb', {}), vector_dbs), - 'embedding': _first_in_list(config.get('embedding', {}), models['embedding']), - 'llm': _first_in_list(config.get('llm', {}), models['llm']), - } - - if not selected_config['vectordb']: + vectordb = _first_in_list(config.get('vectordb', {}), vector_dbs) + if not vectordb: raise AssertionError( f'Error: vectordb should be at least one of {vector_dbs} in the config file' ) - if not selected_config['embedding']: + embedding = _first_in_list(config.get('embedding', {}), models['embedding']) + if not embedding: raise AssertionError( f'Error: embedding model should be at least one of {models["embedding"]} in the config file' ) - if not selected_config['llm']: + llm = _first_in_list(config.get('llm', {}), models['llm']) + if not llm: raise AssertionError( f'Error: llm model should be at least one of {models["llm"]} in the config file' ) + selected_config: TConfig = { + 'vectordb': vectordb, + 'embedding': embedding, + 'llm': llm, + } + pprint(f'Selected config: {selected_config}') return selected_config diff --git a/context_chat_backend/controller.py b/context_chat_backend/controller.py index 63df74c..2a078d1 100644 --- a/context_chat_backend/controller.py +++ b/context_chat_backend/controller.py @@ -2,13 +2,13 @@ from typing import Annotated from dotenv import load_dotenv -from fastapi import Body, FastAPI, Request, UploadFile, BackgroundTasks +from fastapi import BackgroundTasks, Body, FastAPI, Request, UploadFile from langchain.llms.base import LLM from .chain import embed_sources, process_query from .download import download_all_models from .ocs_utils import AppAPIAuthMiddleware -from .utils import enabled_guard, JSONResponse, update_progress, value_of +from .utils import JSONResponse, enabled_guard, update_progress, value_of from .vectordb import BaseVectorDB load_dotenv() @@ -34,7 +34,12 @@ def _(request: Request): @app.get('/world') @enabled_guard(app) def _(query: str | None = None): - em = app.extra.get('EMBEDDING_MODEL') + from langchain.schema.embeddings import Embeddings + em: Embeddings | None = app.extra.get('EMBEDDING_MODEL') + + if em is None: + return JSONResponse('Error: Embedding model not initialised', 500) + return em.embed_query(query if query is not None else 'what is an apple?') @@ -42,11 +47,19 @@ def _(query: str | None = None): @app.get('/vectors') @enabled_guard(app) def _(userId: str): - from chromadb import ClientAPI + from chromadb.api import ClientAPI + from .vectordb import COLLECTION_NAME - db: BaseVectorDB = app.extra.get('VECTOR_DB') - client: ClientAPI = db.client + db: BaseVectorDB | None = app.extra.get('VECTOR_DB') + if db is None: + return JSONResponse('Error: VectorDB not initialised', 500) + + client: ClientAPI | None = db.client + + if client is None: + return JSONResponse('Error: VectorDB client not initialised', 500) + db.setup_schema(userId) return JSONResponse( @@ -58,18 +71,19 @@ def _(userId: str): @app.get('/search') @enabled_guard(app) def _(userId: str, sourceNames: str): - sourceNames: list[str] = [source.strip() for source in sourceNames.split(',') if source.strip() != ''] + sourceList = [source.strip() for source in sourceNames.split(',') if source.strip() != ''] - if len(sourceNames) == 0: + if len(sourceList) == 0: return JSONResponse('No sources provided', 400) - db: BaseVectorDB = app.extra.get('VECTOR_DB') + db: BaseVectorDB | None = app.extra.get('VECTOR_DB') if db is None: return JSONResponse('Error: VectorDB not initialised', 500) - source_objs = db.get_objects_from_metadata(userId, 'source', sourceNames) - sources = list(map(lambda s: s.get('id'), source_objs.values())) + source_objs = db.get_objects_from_metadata(userId, 'source', sourceList) + # sources = list(map(lambda s: s.get('id'), source_objs.values())) + sources = [s.get('id') for s in source_objs.values()] return JSONResponse({ 'sources': sources }) @@ -106,7 +120,7 @@ def _(userId: Annotated[str, Body()], sourceNames: Annotated[list[str], Body()]) if len(sourceNames) == 0: return JSONResponse('No sources provided', 400) - db: BaseVectorDB = app.extra.get('VECTOR_DB') + db: BaseVectorDB | None = app.extra.get('VECTOR_DB') if db is None: return JSONResponse('Error: VectorDB not initialised', 500) @@ -125,7 +139,7 @@ def _(userId: Annotated[str, Body()], providerKey: Annotated[str, Body()]): if value_of(providerKey) is None: return JSONResponse('Invalid provider key provided', 400) - db: BaseVectorDB = app.extra.get('VECTOR_DB') + db: BaseVectorDB | None = app.extra.get('VECTOR_DB') if db is None: return JSONResponse('Error: VectorDB not initialised', 500) @@ -145,16 +159,16 @@ def _(sources: list[UploadFile]): return JSONResponse('No sources provided', 400) # TODO: headers validation using pydantic - if not all([ + if not ( value_of(source.headers.get('userId')) and value_of(source.headers.get('type')) and value_of(source.headers.get('modified')) and value_of(source.headers.get('provider')) - for source in sources] + for source in sources ): return JSONResponse('Invaild/missing headers', 400) - db: BaseVectorDB = app.extra.get('VECTOR_DB') + db: BaseVectorDB | None = app.extra.get('VECTOR_DB') if db is None: return JSONResponse('Error: VectorDB not initialised', 500) @@ -168,11 +182,11 @@ def _(sources: list[UploadFile]): @app.get('/query') @enabled_guard(app) def _(userId: str, query: str, useContext: bool = True, ctxLimit: int = 5): - llm: LLM = app.extra.get('LLM_MODEL') + llm: LLM | None = app.extra.get('LLM_MODEL') if llm is None: return JSONResponse('Error: LLM not initialised', 500) - db: BaseVectorDB = app.extra.get('VECTOR_DB') + db: BaseVectorDB | None = app.extra.get('VECTOR_DB') if db is None: return JSONResponse('Error: VectorDB not initialised', 500) diff --git a/context_chat_backend/download.py b/context_chat_backend/download.py index 889d460..dcb3d3e 100644 --- a/context_chat_backend/download.py +++ b/context_chat_backend/download.py @@ -1,19 +1,19 @@ -from hashlib import file_digest -from logging import error as log_error -from pathlib import Path import os import re import shutil import tarfile import zipfile +from hashlib import file_digest +from logging import error as log_error +from pathlib import Path +import requests from dotenv import load_dotenv from fastapi import FastAPI -import requests +from .config_parser import TConfig from .utils import update_progress - load_dotenv() _MODELS_DIR = '' @@ -44,13 +44,13 @@ '.zip', ) -_model_config: dict[str, str | None] = { +_model_config: dict[str, tuple[str, str, str]] = { 'hkunlp/instructor-base': ('hkunlp_instructor-base', '.tar.gz', '19751ec112564f2c568b96a794dd4a16f335ee42b2535a890b577fc5137531eb'), # noqa: E501 'dolphin-2.2.1-mistral-7b.Q5_K_M.gguf': ('dolphin-2.2.1-mistral-7b.Q5_K_M.gguf', '', '591a9b807bfa6dba9a5aed1775563e4364d7b7b3b714fc1f9e427fa0e2bf6ace'), # noqa: E501 } -def _get_model_name_or_path(config: dict, model_type: str) -> str | None: +def _get_model_name_or_path(config: TConfig, model_type: str) -> str | None: if (model_config := config.get(model_type)) is not None: model_config = model_config[1] return ( @@ -60,9 +60,10 @@ def _get_model_name_or_path(config: dict, model_type: str) -> str | None: or model_config.get('model_file') or model_config.get('model') ) + return None -def _set_app_config(app: FastAPI, config: dict[str, tuple[str, dict]]): +def _set_app_config(app: FastAPI, config: TConfig): ''' Sets the app config as an extra attribute to the app object. @@ -76,24 +77,28 @@ def _set_app_config(app: FastAPI, config: dict[str, tuple[str, dict]]): if config.get('embedding'): from .models import init_model - model = init_model('embedding', config.get('embedding')) + model = init_model('embedding', config['embedding']) app.extra['EMBEDDING_MODEL'] = model - if config.get('vectordb'): + if config.get('vectordb') and config.get('embedding'): + from langchain.schema.embeddings import Embeddings + from .vectordb import get_vector_db - client_klass = get_vector_db(config.get('vectordb')[0]) + client_klass = get_vector_db(config['vectordb'][0]) - if app.extra.get('EMBEDDING_MODEL') is not None: - app.extra['VECTOR_DB'] = client_klass(app.extra['EMBEDDING_MODEL'], **config.get('vectordb')[1]) + em: Embeddings | None = app.extra.get('EMBEDDING_MODEL') + if em is not None: + app.extra['VECTOR_DB'] = client_klass(em, **config['vectordb'][1]) # type: ignore else: - app.extra['VECTOR_DB'] = client_klass(**config.get('vectordb')[1]) + app.extra['VECTOR_DB'] = client_klass(**config.get('vectordb')[1]) # type: ignore if config.get('llm'): from .models import init_model - llm_name, llm_config = config.get('llm') + llm_name, llm_config = config['llm'] app.extra['LLM_TEMPLATE'] = llm_config.pop('template', '') + app.extra['LLM_END_SEPARATOR'] = llm_config.pop('end_separator', '') model = init_model('llm', (llm_name, llm_config)) app.extra['LLM_MODEL'] = model @@ -126,7 +131,7 @@ def _download_model(model_name_or_path: str) -> bool: else: model_name = re.sub(r'^.*' + _MODELS_DIR + r'/', '', model_name_or_path) - if model_name in _model_config.keys(): + if model_name in _model_config: model_file = _model_config[model_name][0] + _model_config[model_name][1] url = _BASE_URL + model_file filepath = Path(_MODELS_DIR, model_file).as_posix() @@ -139,7 +144,7 @@ def _download_model(model_name_or_path: str) -> bool: try: f = open(filepath, 'w+b') - r = requests.get(url, stream=True) + r = requests.get(url, stream=True, timeout=(10, 60)) r.raw.decode_content = True # content decompression if r.status_code >= 400: @@ -149,15 +154,19 @@ def _download_model(model_name_or_path: str) -> bool: shutil.copyfileobj(r.raw, f, length=16 * 1024 * 1024) # 16MB chunks # hash check if the config is declared - if model_name in _model_config.keys(): + if model_name in _model_config: f.seek(0) original_digest = _model_config.get(model_name, (None, None, None))[2] - digest = file_digest(f, 'sha256').hexdigest() - if (original_digest != digest): - log_error( - f'Error: Model file ({filepath}) corrupted:\nexpected hash {original_digest}\ngot {digest}' - ) - return False + if original_digest is None: + # warning + log_error(f'Error: Hash not found for model {model_name}, continuing without hash check') + else: + digest = file_digest(f, 'sha256').hexdigest() + if (original_digest != digest): + log_error( + f'Error: Model file ({filepath}) corrupted:\nexpected hash {original_digest}\ngot {digest}' + ) + return False f.close() @@ -173,18 +182,26 @@ def _extract_n_save(model_name: str, filepath: str) -> bool: # extract the model if it is a compressed file if (filepath.endswith(_KNOWN_ARCHIVES)): + tar_archive = None + zip_archive = None + try: if filepath.endswith('.tar.gz'): - tar = tarfile.open(filepath, 'r:gz') + tar_archive = tarfile.open(filepath, 'r:gz') elif filepath.endswith('.tar.bz2'): - tar = tarfile.open(filepath, 'r:bz2') + tar_archive = tarfile.open(filepath, 'r:bz2') elif filepath.endswith('.tar.xz'): - tar = tarfile.open(filepath, 'r:xz') - else: - tar = zipfile.ZipFile(filepath, 'r') + tar_archive = tarfile.open(filepath, 'r:xz') + elif filepath.endswith('.zip'): + zip_archive = zipfile.ZipFile(filepath, 'r') + + if tar_archive: + tar_archive.extractall(_MODELS_DIR, filter='data') + tar_archive.close() + elif zip_archive: + zip_archive.extractall(_MODELS_DIR) + zip_archive.close() - tar.extractall(_MODELS_DIR) - tar.close() os.remove(filepath) except OSError as e: raise OSError('Error: Model extraction failed') from e @@ -208,7 +225,7 @@ def download_all_models(app: FastAPI): ---- app: FastAPI object ''' - config = app.extra['CONFIG'] + config: TConfig = app.extra['CONFIG'] if os.getenv('DISABLE_CUSTOM_DOWNLOAD_URI', '0') == '1': update_progress(100) @@ -218,8 +235,12 @@ def download_all_models(app: FastAPI): progress = 0 for model_type in ('embedding', 'llm'): model_name = _get_model_name_or_path(config, model_type) + if model_name is None: + raise Exception(f'Error: Model name/path not found for {model_type}') + if not _download_model(model_name): raise Exception(f'Error: Model download failed for {model_name}') + update_progress(progress := progress + 50) _set_app_config(app, config) @@ -232,8 +253,13 @@ def model_init(app: FastAPI) -> bool: for model_type in ('embedding', 'llm'): model_name = _get_model_name_or_path(app.extra['CONFIG'], model_type) + if model_name is None: + return False + if not _model_exists(model_name): return False - _set_app_config(app, app.extra['CONFIG']) + config: TConfig = app.extra['CONFIG'] + _set_app_config(app, config) + return True diff --git a/context_chat_backend/models/__init__.py b/context_chat_backend/models/__init__.py index 5a56cf4..696c525 100644 --- a/context_chat_backend/models/__init__.py +++ b/context_chat_backend/models/__init__.py @@ -17,7 +17,7 @@ def init_model(model_type: str, model_info: tuple[str, dict]): the same name as the model in the models dir. ''' model_name, _ = model_info - available_models = models.get(model_type) + available_models = models.get(model_type, []) if model_name not in available_models: raise AssertionError(f'Error: {model_type}_model should be one of {available_models}') diff --git a/context_chat_backend/models/load_model.py b/context_chat_backend/models/load_model.py index 17b39aa..d3ba9ba 100644 --- a/context_chat_backend/models/load_model.py +++ b/context_chat_backend/models/load_model.py @@ -1,5 +1,5 @@ +from collections.abc import Callable from importlib import import_module -from typing import Callable from langchain.llms.base import LLM from langchain.schema.embeddings import Embeddings @@ -9,8 +9,6 @@ def load_model(model_type: str, model_info: tuple[str, dict]) -> Embeddings | LLM | None: model_name, model_config = model_info - model_config.pop('template', '') - model_config.pop('end_separator', '') module = import_module(f'.{model_name}', 'context_chat_backend.models') diff --git a/context_chat_backend/ocs_utils.py b/context_chat_backend/ocs_utils.py index 90ac276..cf18c50 100644 --- a/context_chat_backend/ocs_utils.py +++ b/context_chat_backend/ocs_utils.py @@ -2,11 +2,10 @@ from base64 import b64decode, b64encode from logging import error as log_error from os import getenv -from packaging import version -from typing import Optional, Union import httpx -from starlette.datastructures import Headers, URL +from packaging import version +from starlette.datastructures import URL, Headers from starlette.responses import JSONResponse from starlette.status import HTTP_401_UNAUTHORIZED from starlette.types import ASGIApp, Receive, Scope, Send @@ -19,12 +18,15 @@ def _sign_request(headers: dict, username: str = '') -> None: headers['AUTHORIZATION-APP-API'] = b64encode(f'{username}:{getenv("APP_SECRET")}'.encode('UTF=8')) -def _verify_signature(headers: Headers) -> str: +# We assume that the env variables are set +def _verify_signature(headers: Headers) -> str | None: if headers.get('AA-VERSION') is None: log_error('AppAPI header AA-VERSION not set') return None - if version.parse(headers.get('AA-VERSION')) < version.parse(getenv('AA_VERSION')): + if headers.get('AA-VERSION') is None or \ + getenv('AA_VERSION') is None or \ + version.parse(headers['AA-VERSION']) < version.parse(getenv('AA_VERSION', '')): log_error('AppAPI version should be at least', getenv('AA_VERSION')) return None @@ -38,7 +40,7 @@ def _verify_signature(headers: Headers) -> str: ) return None - auth_aa = b64decode(headers.get('AUTHORIZATION-APP-API')).decode('UTF-8') + auth_aa = b64decode(headers.get('AUTHORIZATION-APP-API', '')).decode('UTF-8') username, app_secret = auth_aa.split(':', maxsplit=1) if app_secret != getenv('APP_SECRET'): @@ -81,14 +83,14 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: def get_nc_url() -> str: - return getenv('NEXTCLOUD_URL').removesuffix('/index.php').removesuffix('/') + return getenv('NEXTCLOUD_URL', '').removesuffix('/index.php').removesuffix('/') def ocs_call( method: str, path: str, - params: Optional[dict] = {}, - json_data: Optional[Union[dict, list]] = None, + params: dict | None = None, + json_data: dict | list | None = None, **kwargs, ): ''' @@ -111,7 +113,8 @@ def ocs_call( The username to use for signing the request. Additional keyword arguments to pass to the httpx.request function. ''' - if not params: params = {} + if params is None: + params = {} params.update({'format': 'json'}) headers = kwargs.pop('headers', {}) diff --git a/context_chat_backend/utils.py b/context_chat_backend/utils.py index b76aa06..956efd2 100644 --- a/context_chat_backend/utils.py +++ b/context_chat_backend/utils.py @@ -1,3 +1,4 @@ +from collections.abc import Callable from functools import wraps from os import getenv from typing import Any, TypeVar @@ -58,7 +59,7 @@ def JSONResponse( def enabled_guard(app: FastAPI): - def decorator(func: callable): + def decorator(func: Callable): ''' Decorator to check if the service is enabled ''' diff --git a/context_chat_backend/vectordb/base.py b/context_chat_backend/vectordb/base.py index 7e16b0e..1437dfc 100644 --- a/context_chat_backend/vectordb/base.py +++ b/context_chat_backend/vectordb/base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List, Optional +from typing import Any, TypedDict from langchain.schema.embeddings import Embeddings from langchain.vectorstores import VectorStore @@ -7,20 +7,27 @@ from ..utils import value_of +class TSearchObject(TypedDict): + id: str + modified: str + +TSearchDict = dict[str, TSearchObject] + + class BaseVectorDB(ABC): - client = None - embedding = None + client: Any = None + embedding: Any = None @abstractmethod - def __init__(self, embedding: Optional[Embeddings] = None, **kwargs): + def __init__(self, embedding: Embeddings | None = None, **kwargs): self.embedding = embedding @abstractmethod def get_user_client( self, user_id: str, - embedding: Optional[Embeddings] = None # Use this embedding if not None or use global embedding - ) -> Optional[VectorStore]: + embedding: Embeddings | None = None # Use this embedding if not None or use global embedding + ) -> VectorStore | None: ''' Creates and returns the langchain vectordb client object for the given user_id. @@ -28,12 +35,12 @@ def get_user_client( ---- user_id: str User ID for which to create the client object. - embedding: Optional[Embeddings] + embedding: Embeddings | None Embeddings object to use for embedding documents. Returns ------- - Optional[VectorStore] + VectorStore | None Client object for the VectorDB or None if error occurs. ''' @@ -57,8 +64,8 @@ def get_objects_from_metadata( self, user_id: str, metadata_key: str, - values: List[str], - ) -> dict: + values: list[str], + ) -> TSearchDict: ''' Get all objects with the given metadata key and values. (Only gets the following fields: [id, 'metadata_key', modified]) @@ -69,22 +76,12 @@ def get_objects_from_metadata( User ID for whose database to get the sources. metadata_key: str Metadata key to get. - values: List[str] + values: list[str] List of metadata names to get. Returns ------- - List[dict] - if error occurs: {} - - otherwise: - - { - [metadata_key: str]: { - 'id': str, - 'modified': str, - } - } + TSearchDict ''' def delete_by_ids(self, user_id: str, ids: list[str]) -> bool: diff --git a/context_chat_backend/vectordb/chroma.py b/context_chat_backend/vectordb/chroma.py index 3c5e8e8..30f7a6f 100644 --- a/context_chat_backend/vectordb/chroma.py +++ b/context_chat_backend/vectordb/chroma.py @@ -1,21 +1,20 @@ from logging import error as log_error from os import getenv -from typing import List, Optional +from chromadb import Client, Where +from chromadb.config import Settings from dotenv import load_dotenv from langchain.schema.embeddings import Embeddings -from langchain.vectorstores import VectorStore, Chroma -from chromadb.config import Settings -from chromadb import Client +from langchain.vectorstores import Chroma, VectorStore -from .base import BaseVectorDB from . import COLLECTION_NAME +from .base import BaseVectorDB, TSearchDict load_dotenv() class VectorDB(BaseVectorDB): - def __init__(self, embedding: Optional[Embeddings] = None, **kwargs): + def __init__(self, embedding: Embeddings | None = None, **kwargs): try: client = Client(Settings( anonymized_telemetry=False, @@ -26,7 +25,7 @@ def __init__(self, embedding: Optional[Embeddings] = None, **kwargs): }, )) except Exception as e: - raise Exception(f'Error: Chromadb instantiation error: {e}') + raise Exception('Error: Chromadb instantiation error') from e if client.heartbeat() <= 0: raise Exception('Error: Chromadb connection error') @@ -44,8 +43,8 @@ def setup_schema(self, user_id: str) -> None: def get_user_client( self, user_id: str, - embedding: Optional[Embeddings] = None # Use this embedding if not None or use global embedding - ) -> Optional[VectorStore]: + embedding: Embeddings | None = None # Use this embedding if not None or use global embedding + ) -> VectorStore | None: self.setup_schema(user_id) em = None @@ -64,8 +63,8 @@ def get_objects_from_metadata( self, user_id: str, metadata_key: str, - values: List[str], - ) -> dict: + values: list[str], + ) -> TSearchDict: # NOTE: the limit of objects returned is not known, maybe it would be better to set one manually if not self.client: @@ -76,7 +75,7 @@ def get_objects_from_metadata( if len(values) == 0: return {} - data_filter = { metadata_key: { '$in': values } } + data_filter: Where = { metadata_key: { '$in': values } } # type: ignore try: results = self.client.get_collection(COLLECTION_NAME(user_id)).get( @@ -87,13 +86,17 @@ def get_objects_from_metadata( log_error(f'Error: Chromadb query error: {e}') return {} - if len(results.get('ids')) == 0: + if len(results.get('ids', [])) == 0: + return {} + + res_metadatas = results.get('metadatas') + if res_metadatas is None: return {} output = {} try: for i, _id in enumerate(results.get('ids')): - meta = results['metadatas'][i] + meta = res_metadatas[i] output[meta[metadata_key]] = { 'id': _id, 'modified': meta['modified'], diff --git a/context_chat_backend/vectordb/weaviate.py b/context_chat_backend/vectordb/weaviate.py index cf78a0f..718b26d 100644 --- a/context_chat_backend/vectordb/weaviate.py +++ b/context_chat_backend/vectordb/weaviate.py @@ -1,15 +1,14 @@ from logging import error as log_error from os import getenv -from typing import List, Optional from dotenv import load_dotenv from langchain.schema.embeddings import Embeddings from langchain.vectorstores import VectorStore, Weaviate from weaviate import AuthApiKey, Client -from .base import BaseVectorDB -from . import COLLECTION_NAME from ..utils import value_of +from . import COLLECTION_NAME +from .base import BaseVectorDB, TSearchDict load_dotenv() @@ -69,20 +68,20 @@ class VectorDB(BaseVectorDB): - def __init__(self, embedding: Optional[Embeddings] = None, **kwargs): + def __init__(self, embedding: Embeddings | None = None, **kwargs): try: client = Client( **({ - 'auth_client_secret': AuthApiKey(getenv('WEAVIATE_APIKEY')), + 'auth_client_secret': AuthApiKey(getenv('WEAVIATE_APIKEY', '')), } if value_of(getenv('WEAVIATE_APIKEY')) is not None else {}), - **{**{ + **{ 'url': getenv('WEAVIATE_URL'), 'timeout_config': (1, 20), **kwargs, - }}, + }, ) except Exception as e: - raise Exception(f'Error: Weaviate connection error: {e}') + raise Exception('Error: Weaviate connection error') from e if not client.is_ready(): raise Exception('Error: Weaviate connection error') @@ -105,8 +104,8 @@ def setup_schema(self, user_id: str) -> None: def get_user_client( self, user_id: str, - embedding: Optional[Embeddings] = None # Use this embedding if not None or use global embedding - ) -> Optional[VectorStore]: + embedding: Embeddings | None = None # Use this embedding if not None or use global embedding + ) -> VectorStore | None: self.setup_schema(user_id) em = None @@ -130,8 +129,8 @@ def get_objects_from_metadata( self, user_id: str, metadata_key: str, - values: List[str], - ) -> dict: + values: list[str], + ) -> TSearchDict: # NOTE: the limit of objects returned is not known, maybe it would be better to set one manually if not self.client: diff --git a/main.py b/main.py index 5ed64a4..e410f61 100755 --- a/main.py +++ b/main.py @@ -8,7 +8,7 @@ if __name__ == '__main__': uvicorn.run( app='context_chat_backend:app', - host=getenv('APP_HOST', '0.0.0.0'), + host=getenv('APP_HOST', '127.0.0.1'), port=to_int(getenv('APP_PORT'), 9000), http='h11', interface='asgi3', diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..4db04f7 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,42 @@ +[project] +name = "context_chat_backend" +version = "1.1.1" +requires-python = ">=3.11" +authors = [ + { name = "Anupam Kumar", email = "kyteinsky@gmail.com" } +] +description = "The Python backend for Context Chat" +readme = { file = "readme.markdown", content-type = "text/markdown" } +license = { file = "LICENSE" } +classifiers = [ + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Private :: Do Not Upload", +] + +[tool.ruff] +target-version = "py311" +include = ["context_chat_backend/**/*.py", "main.py"] +line-length = 120 +fix = true + +[tool.ruff.lint] +select = ["A", "B", "C", "E", "F", "G", "I", "S", "PIE", "RET", "RUF", "UP" , "W"] +ignore = [ + "W191", # Indentation contains tabs + "E201", # Whitespace after opening bracket + "E202", # Whitespace before closing bracket + "E731", # Do not assign a lambda expression, use a def + "C901", # Function is too complex + "G004", # Logging statement uses f-string formatting +] +# remove G004 after better logging solution is implemented +fixable = [ + "F401", # Unused import + "RUF100" # Unused noqa comments +] + +[tool.pyright] +include = ["context_chat_backend/**/*.py", "main.py"] +pythonVersion = "3.11" +pythonPlatform = "Linux" diff --git a/reqs.dev b/reqs.dev new file mode 100644 index 0000000..7af7b9f --- /dev/null +++ b/reqs.dev @@ -0,0 +1,2 @@ +pyright +ruff