From b026b125d8c00d9a50e33858e7ffc684ee1c5ccf Mon Sep 17 00:00:00 2001
From: "Alexie (Boyong) Madolid"
Date: Tue, 29 Aug 2023 14:47:51 +0800
Subject: [PATCH] [FEATURE-REQUEST]: Elastic Retrieval Action

---
 jaseci_ai_kit/install.sh                      |   2 +-
 .../jac_misc/elastic_retrieval/__init__.py    |   1 +
 .../elastic_retrieval/elastic_retrieval.py    | 241 ++++++++++++++++++
 .../elastic_retrieval/requirements.txt        |   7 +
 .../jac_misc/elastic_retrieval/utils.py       | 147 +++++++++++
 jaseci_ai_kit/jac_misc/setup.py               |   1 +
 6 files changed, 398 insertions(+), 1 deletion(-)
 create mode 100644 jaseci_ai_kit/jac_misc/jac_misc/elastic_retrieval/__init__.py
 create mode 100644 jaseci_ai_kit/jac_misc/jac_misc/elastic_retrieval/elastic_retrieval.py
 create mode 100644 jaseci_ai_kit/jac_misc/jac_misc/elastic_retrieval/requirements.txt
 create mode 100644 jaseci_ai_kit/jac_misc/jac_misc/elastic_retrieval/utils.py

diff --git a/jaseci_ai_kit/install.sh b/jaseci_ai_kit/install.sh
index 36526e08bb..157ea5e092 100644
--- a/jaseci_ai_kit/install.sh
+++ b/jaseci_ai_kit/install.sh
@@ -2,7 +2,7 @@
 JAC_NLP_MODULES=("bart_sum" "cl_summer" "ent_ext" "fast_enc" "sbert_sim" "t5_sum" "text_seg" "tfm_ner" "use_enc" "use_qa" "zs_classifier" "bi_enc" "topic_ext" "gpt2" "gpt3" "dolly" "llm")
 JAC_SPEECH_MODULES=("stt" "vc_tts")
-JAC_MISC_MODULES=("pdf_ext" "translator" "cluster" "ph" "openai" "huggingface" "langchain")
+JAC_MISC_MODULES=("pdf_ext" "translator" "cluster" "ph" "openai" "elastic_retrieval" "huggingface" "langchain")
 JAC_VISION_MODULES=("detr" "rftm" "yolos" "dpt")
 
 install_modules() {
diff --git a/jaseci_ai_kit/jac_misc/jac_misc/elastic_retrieval/__init__.py b/jaseci_ai_kit/jac_misc/jac_misc/elastic_retrieval/__init__.py
new file mode 100644
index 0000000000..0d3bc3d407
--- /dev/null
+++ b/jaseci_ai_kit/jac_misc/jac_misc/elastic_retrieval/__init__.py
@@ -0,0 +1 @@
+from .elastic_retrieval import * # noqa
diff --git a/jaseci_ai_kit/jac_misc/jac_misc/elastic_retrieval/elastic_retrieval.py b/jaseci_ai_kit/jac_misc/jac_misc/elastic_retrieval/elastic_retrieval.py
new file mode 100644
index 0000000000..5476b36bc1
--- /dev/null
+++ b/jaseci_ai_kit/jac_misc/jac_misc/elastic_retrieval/elastic_retrieval.py
@@ -0,0 +1,241 @@
+import openai
+from os import environ, unlink
+from datetime import datetime
+from requests import get
+from uuid import uuid4
+
+from .utils import extract_text_from_file, get_embeddings, generate_chunks
+from jaseci.jsorc.live_actions import jaseci_action
+from elasticsearch import Elasticsearch, NotFoundError
+
+CLIENT = None
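+# Default configuration. Elastic connection and index-template settings are
+# sourced from ELASTICSEARCH_* environment variables; the OpenAI fields fall
+# back to whatever the openai module is already configured with at import time.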
"base": openai.api_base, + "version": openai.api_version, + "embedding": { + "deployment_id": environ.get("OPENAI_EMBEDDING_DEPLOYMENT_ID"), + "model": environ.get("OPENAI_EMBEDDING_MODEL", "text-embedding-ada-002"), + }, + }, + "chunk_config": { + "chunk_size": environ.get("CHUNK_SIZE", 200), + "min_chunk_size_chars": environ.get("MIN_CHUNK_SIZE_CHARS", 350), + "min_chunk_length_to_embed": environ.get("MIN_CHUNK_LENGTH_TO_EMBED", 5), + "max_num_chunks": environ.get("MAX_NUM_CHUNKS", 10000), + }, + "batch_size": int(environ.get("ELASTICSEARCH_UPSERT_BATCH_SIZE", "100")), +} + + +@jaseci_action(allow_remote=True) +def setup(config: dict = CONFIG, rebuild: bool = False, reindex_template: bool = False): + """ """ + global CONFIG, CLIENT + CONFIG = config + + if rebuild: + CLIENT = None + + if reindex_template: + reapply_index_template() + + openai_config = CONFIG["openai"] + openai.api_key = openai_config.get("key") or openai.api_key + openai.api_type = openai_config.get("type") or openai.api_type + openai.api_base = openai_config.get("base") or openai.api_base + openai.api_version = openai_config.get("version") or openai.api_version + + +@jaseci_action(allow_remote=True) +def upsert(index: str, data: dict, reset: bool = False, refresh=None, meta: dict = {}): + """ """ + bs = CONFIG["batch_size"] + + doc_u = data.get("url", []) + doc_t = data.get("text", []) + + # only works if not remote + doc_f = data.get("file", []) + + if reset: + reset_index(index) + else: + delete(index, [doc["id"] for doc in doc_t] + [doc["id"] for doc in doc_u]) + + doc_a = [] + for doc in doc_u: + file_name: str = "/tmp/" + (doc.pop("name", None) or str(uuid4())) + with get(doc.pop("url"), stream=True) as res, open(file_name, "wb") as buffer: + res.raise_for_status() + for chunk in res.iter_content(chunk_size=8192): + buffer.write(chunk) + + doc["text"] = extract_text_from_file(file_name) + + unlink(file_name) + doc_a += generate_chunks(doc, CONFIG["chunk_config"]) + + for doc in doc_t: + doc_a += generate_chunks(doc, CONFIG["chunk_config"]) + + hook = meta.get("h") + if hasattr(hook, "get_file_handler"): + for doc in doc_f: + fh = hook.get_file_handler(doc["file"]) + doc["text"] = extract_text_from_file(fh.absolute_path) + doc_a += generate_chunks(doc, CONFIG["chunk_config"]) + + ops_index = {"index": {"_index": index}} + ops_t = [] + for docs in [doc_a[x : x + bs] for x in range(0, len(doc_a), bs)]: + for i, emb in enumerate( + get_embeddings([doc["text"] for doc in docs], CONFIG["openai"]) + ): + docs[i]["embedding"] = emb + docs[i]["created_time"] = int( + datetime.fromisoformat(docs[i]["created_time"]).timestamp() + ) + ops_t.append(ops_index) + ops_t.append(docs[i]) + + elastic().bulk(operations=ops_t, index=index, refresh=refresh) + + return True + + +@jaseci_action(allow_remote=True) +def delete(index: str, ids: [], all: bool = False): + """ """ + if all: + return reset_index(index) + elif ids: + return ( + elastic() + .delete_by_query( + index=index, + query={"terms": {"id": ids}}, + ignore_unavailable=True, + ) + .body + ) + + +@jaseci_action(allow_remote=True) +def query(index: str, data: list): + """ """ + bs = CONFIG["batch_size"] + + search_index = {"index": index} + searches = [] + for queries in [data[x : x + bs] for x in range(0, len(data), bs)]: + for i, emb in enumerate( + get_embeddings([query["query"] for query in queries], CONFIG["openai"]) + ): + top = queries[i].get("top") or 3 + query = { + "knn": { + "field": "embedding", + "query_vector": emb, + "k": top, + "num_candidates": 
+@jaseci_action(allow_remote=True)
+def upsert(index: str, data: dict, reset: bool = False, refresh=None, meta: dict = {}):
+    """Chunk, embed, and index documents from text, URLs, or uploaded files."""
+    bs = CONFIG["batch_size"]
+
+    doc_u = data.get("url", [])
+    doc_t = data.get("text", [])
+
+    # file handles can only be resolved when the action runs locally
+    doc_f = data.get("file", [])
+
+    if reset:
+        reset_index(index)
+    else:
+        delete(index, [doc["id"] for doc in doc_t] + [doc["id"] for doc in doc_u])
+
+    doc_a = []
+    for doc in doc_u:
+        file_name: str = "/tmp/" + (doc.pop("name", None) or str(uuid4()))
+        with get(doc.pop("url"), stream=True) as res, open(file_name, "wb") as buffer:
+            res.raise_for_status()
+            for chunk in res.iter_content(chunk_size=8192):
+                buffer.write(chunk)
+
+        try:
+            doc["text"] = extract_text_from_file(file_name)
+        finally:
+            # always clean up the downloaded temp file
+            unlink(file_name)
+        doc_a += generate_chunks(doc, CONFIG["chunk_config"])
+
+    for doc in doc_t:
+        doc_a += generate_chunks(doc, CONFIG["chunk_config"])
+
+    hook = meta.get("h")
+    if hasattr(hook, "get_file_handler"):
+        for doc in doc_f:
+            fh = hook.get_file_handler(doc["file"])
+            doc["text"] = extract_text_from_file(fh.absolute_path)
+            doc_a += generate_chunks(doc, CONFIG["chunk_config"])
+
+    ops_index = {"index": {"_index": index}}
+    ops_t = []
+    for docs in [doc_a[x : x + bs] for x in range(0, len(doc_a), bs)]:
+        for i, emb in enumerate(
+            get_embeddings([doc["text"] for doc in docs], CONFIG["openai"])
+        ):
+            docs[i]["embedding"] = emb
+            # store created_time as an epoch integer
+            docs[i]["created_time"] = int(
+                datetime.fromisoformat(docs[i]["created_time"]).timestamp()
+            )
+            ops_t.append(ops_index)
+            ops_t.append(docs[i])
+
+    # an empty bulk request is an error, so skip it when nothing was chunked
+    if ops_t:
+        elastic().bulk(operations=ops_t, index=index, refresh=refresh)
+
+    return True
+
+
+@jaseci_action(allow_remote=True)
+def delete(index: str, ids: list, all: bool = False):
+    """Delete documents by id, or drop the whole index when all is True."""
+    if all:
+        return reset_index(index)
+    elif ids:
+        return (
+            elastic()
+            .delete_by_query(
+                index=index,
+                query={"terms": {"id": ids}},
+                ignore_unavailable=True,
+            )
+            .body
+        )
+
+
+@jaseci_action(allow_remote=True)
+def query(index: str, data: list):
+    """Run batched kNN searches against an index of embedded chunks."""
+    bs = CONFIG["batch_size"]
+
+    search_index = {"index": index}
+    searches = []
+    for queries in [data[x : x + bs] for x in range(0, len(data), bs)]:
+        for i, emb in enumerate(
+            get_embeddings([query["query"] for query in queries], CONFIG["openai"])
+        ):
+            top = queries[i].get("top") or 3
+            search = {
+                "knn": {
+                    "field": "embedding",
+                    "query_vector": emb,
+                    "k": top,
+                    "num_candidates": queries[i].get("num_candidates") or (top * 10),
+                    "filter": queries[i].get("filter") or [],
+                }
+            }
+
+            min_score = queries[i].get("min_score")
+            if min_score:
+                search["min_score"] = min_score
+
+            searches.append(search_index)
+            searches.append(search)
+
+    return [
+        {
+            "query": query,
+            "results": [
+                {
+                    "id": hit["_source"]["id"],
+                    "text": hit["_source"]["text"],
+                    "score": hit["_score"],
+                }
+                for hit in result["hits"]["hits"]
+            ],
+        }
+        # searches spans every batch, so zip the responses against the full
+        # input list rather than only the last batch
+        for query, result in zip(
+            data,
+            elastic().msearch(searches=searches, ignore_unavailable=True)["responses"],
+        )
+    ]
+
+
+@jaseci_action(allow_remote=True)
+def reset_index(index: str):
+    """Drop the index entirely."""
+    return elastic().indices.delete(index=index, ignore_unavailable=True).body
+
+
+@jaseci_action(allow_remote=True)
+def reapply_index_template():
+    """Create the configured index template if it is not already present."""
+    index_template = CONFIG["elastic"]["index_template"]
+    try:
+        return elastic().indices.get_index_template(name=index_template["name"]).body
+    except NotFoundError:
+        return elastic().indices.put_index_template(**index_template).body
+
+
+def elastic() -> Elasticsearch:
+    global CONFIG, CLIENT
+    if not CLIENT:
+        config = CONFIG.get("elastic")
+        client = Elasticsearch(
+            hosts=[config["url"]],
+            api_key=config["key"],
+            request_timeout=config.get("request_timeout"),
+        )
+        # fail fast if the cluster is unreachable or the credentials are bad
+        client.info()
+        CLIENT = client
+    return CLIENT
diff --git a/jaseci_ai_kit/jac_misc/jac_misc/elastic_retrieval/requirements.txt b/jaseci_ai_kit/jac_misc/jac_misc/elastic_retrieval/requirements.txt
new file mode 100644
index 0000000000..a9d2cc83da
--- /dev/null
+++ b/jaseci_ai_kit/jac_misc/jac_misc/elastic_retrieval/requirements.txt
@@ -0,0 +1,7 @@
+openai>=0.27.9
+elasticsearch==8.9.0
+tenacity>=8.2.1
+pypdf==3.15.4
+docx2txt>=0.8
+python-pptx>=0.6.21
+python-magic>=0.4.27
+tiktoken>=0.2.0
\ No newline at end of file
diff --git a/jaseci_ai_kit/jac_misc/jac_misc/elastic_retrieval/utils.py b/jaseci_ai_kit/jac_misc/jac_misc/elastic_retrieval/utils.py
new file mode 100644
index 0000000000..b1dff37390
--- /dev/null
+++ b/jaseci_ai_kit/jac_misc/jac_misc/elastic_retrieval/utils.py
@@ -0,0 +1,147 @@
+import openai
+import tiktoken
+import docx2txt
+import csv
+import pptx
+from pypdf import PdfReader
+from magic import from_buffer
+from tenacity import retry, wait_random_exponential, stop_after_attempt
+
+tokenizer = tiktoken.get_encoding("cl100k_base")
+
+
+@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(3))
+def get_embeddings(texts: list, config: dict):
+    """Embed a batch of texts with the configured OpenAI model, with retries."""
+    embedding = config["embedding"]
+    deployment = embedding.get("deployment_id")
+
+    kwargs = {"input": texts}
+
+    if deployment:
+        # Azure OpenAI deployments are addressed by deployment_id
+        kwargs["deployment_id"] = deployment
+    else:
+        kwargs["model"] = embedding.get("model")
+
+    return [result["embedding"] for result in openai.Embedding.create(**kwargs)["data"]]
+
+
+def extract_text_from_file(file: str) -> str:
+    with open(file, "rb") as buff:
+        mimetype = from_buffer(buff.read(), mime=True)
+        buff.seek(0)
+
+        if mimetype == "application/pdf":
+            # Extract text from pdf using pypdf
+            reader = PdfReader(buff)
+            extracted_text = " ".join([page.extract_text() for page in reader.pages])
+        elif mimetype == "text/plain" or mimetype == "text/markdown":
+            # Read text from plain text buff
+            extracted_text = buff.read().decode("utf-8")
+        elif (
+            mimetype
+            == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
+        ):
+            # Extract text from docx using docx2txt
+            extracted_text = docx2txt.process(buff)
+        elif mimetype == "text/csv":
+            # Extract text from csv using csv module
+            extracted_text = ""
+            decoded_buffer = (line.decode("utf-8") for line in buff)
+            reader = csv.reader(decoded_buffer)
+            for row in reader:
+                extracted_text += " ".join(row) + "\n"
+        elif (
+            mimetype
+            == "application/vnd.openxmlformats-officedocument.presentationml.presentation"
+        ):
+            # Extract text from pptx using python-pptx
+            extracted_text = ""
+            presentation = pptx.Presentation(buff)
+            for slide in presentation.slides:
+                for shape in slide.shapes:
+                    if shape.has_text_frame:
+                        for paragraph in shape.text_frame.paragraphs:
+                            for run in paragraph.runs:
+                                extracted_text += run.text + " "
+                        extracted_text += "\n"
+        else:
+            # Unsupported file type
+            raise ValueError("Unsupported file type: {}".format(mimetype))
+
+    return extracted_text
+
+
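+# Chunking strategy: consume the token stream in windows of roughly
+# chunk_size tokens, then pull each cut back to the last sentence-ending
+# punctuation mark whenever that mark falls beyond min_chunk_size_chars,
+# so chunks tend to end on sentence boundaries.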
+def generate_chunks(doc: dict, config: dict) -> list:
+    # Check if the document text is empty or whitespace
+    text: str = doc.get("text")
+    if not text or text.isspace():
+        return []
+
+    tokens = tokenizer.encode(text, disallowed_special=())
+
+    # Initialize an empty list of chunks
+    chunks = []
+
+    # Use the provided chunk settings or the defaults
+    chunk_size = config.get("chunk_size") or 200
+    max_num_chunks = config.get("max_num_chunks") or 10000
+    min_chunk_size_chars = config.get("min_chunk_size_chars") or 350
+    min_chunk_length_to_embed = config.get("min_chunk_length_to_embed") or 5
+
+    # Initialize a counter for the number of chunks
+    num_chunks = 0
+
+    # Loop until all tokens are consumed
+    while tokens and num_chunks < max_num_chunks:
+        # Take the first chunk_size tokens as a chunk
+        chunk = tokens[:chunk_size]
+
+        # Decode the chunk into text
+        chunk_text = tokenizer.decode(chunk)
+
+        # Skip the chunk if it is empty or whitespace
+        if not chunk_text or chunk_text.isspace():
+            # Remove the tokens corresponding to the chunk text from the remaining tokens
+            tokens = tokens[len(chunk) :]
+            # Continue to the next iteration of the loop
+            continue
+
+        # Find the last period or punctuation mark in the chunk
+        last_punctuation = max(
+            chunk_text.rfind("."),
+            chunk_text.rfind("?"),
+            chunk_text.rfind("!"),
+            chunk_text.rfind("\n"),
+        )
+
+        # If the last punctuation mark falls after min_chunk_size_chars,
+        # cut the chunk there so it ends on a sentence boundary
+        if last_punctuation != -1 and last_punctuation > min_chunk_size_chars:
+            # Truncate the chunk text at the punctuation mark
+            chunk_text = chunk_text[: last_punctuation + 1]
+
+        # Remove any newline characters and strip any leading or trailing whitespace
+        chunk_text_to_append = chunk_text.replace("\n", " ").strip()
+
+        if len(chunk_text_to_append) > min_chunk_length_to_embed:
+            # Append the cloned doc with chunk text
+            chunks.append(clone_doc(doc, chunk_text_to_append))
+
+        # Remove the tokens corresponding to the chunk text from the remaining tokens
+        tokens = tokens[len(tokenizer.encode(chunk_text, disallowed_special=())) :]
+
+        # Increment the number of chunks
+        num_chunks += 1
+
+    # Handle the remaining tokens
+    if tokens:
+        remaining_text = tokenizer.decode(tokens).replace("\n", " ").strip()
+        if len(remaining_text) > min_chunk_length_to_embed:
+            chunks.append(clone_doc(doc, remaining_text))
+
+    return chunks
+
+
+def clone_doc(doc: dict, text: str) -> dict:
+    _doc = doc.copy()
+    _doc["text"] = text
+    return _doc
diff --git a/jaseci_ai_kit/jac_misc/setup.py b/jaseci_ai_kit/jac_misc/setup.py
index a5e6f5fdd8..73f0e7043e 100644
--- a/jaseci_ai_kit/jac_misc/setup.py
+++ b/jaseci_ai_kit/jac_misc/setup.py
@@ -7,6 +7,7 @@
     "cluster",
     "ph",
     "openai",
+    "elastic_retrieval",
     "huggingface",
     "langchain",
 ]
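
---

For reviewers, a minimal local usage sketch. The index name, document values,
and environment are illustrative, not part of the patch; it assumes
OPENAI_API_KEY and the relevant ELASTICSEARCH_* variables are set:

    from jac_misc.elastic_retrieval import setup, upsert, query

    # apply the index template once so "oai-emb-*" indices get the
    # dense_vector mapping before any documents are written
    setup(reindex_template=True)

    upsert(
        "oai-emb-demo",
        {
            "text": [
                {
                    "id": "doc-1",
                    "text": "Jaseci is an end-to-end AI development stack.",
                    "created_time": "2023-08-29T00:00:00",
                }
            ]
        },
        refresh=True,
    )

    print(query("oai-emb-demo", [{"query": "What is Jaseci?", "top": 3}]))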