From e8bb2e2788afb233f1ddb8c3f4b05b2472234bf2 Mon Sep 17 00:00:00 2001 From: artitw Date: Mon, 7 Oct 2024 02:57:16 +0000 Subject: [PATCH] RAG Assistant --- README.md | 14 +++++------ demos/Text2Text_LLM.ipynb | 28 ++++++++++++++++++++++ setup.py | 4 ++-- text2text/__init__.py | 1 + text2text/assistant.py | 49 +++++++++++++++++++++++++------------- text2text/handler.py | 1 + text2text/rag_assistant.py | 48 +++++++++++++++++++++++++++++++++++++ 7 files changed, 120 insertions(+), 25 deletions(-) create mode 100644 text2text/rag_assistant.py diff --git a/README.md b/README.md index 041788c..00bdd16 100755 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ Module Importing | `import text2text as t2t` | Libraries imported [Assistant](https://github.com/artitw/text2text#assistant) | `t2t.Assistant().transform("Describe Text2Text in a few words: ")` | `['Text2Text is an AI-powered text generation tool that creates coherent and continuous text based on prompts.']` [Language Model Setting](https://github.com/artitw/text2text#byot-bring-your-own-translator) | `t2t.Transformer.PRETRAINED_TRANSLATOR = "facebook/m2m100_418M"` | Change from default [Tokenization](https://github.com/artitw/text2text#tokenization) | `t2t.Tokenizer().transform(["Hello, World!"])` | `[['▁Hello', ',', '▁World', '!']]` -[Embedding](https://github.com/artitw/text2text#embedding--vectorization) | `t2t.Vectorizer().transform(["Hello, World!"])` | `array([[0.18745188, 0.05658336, ..., 0.6332584 , 0.43805206]], dtype=float32)` +[Embedding](https://github.com/artitw/text2text#embedding--vectorization) | `t2t.Vectorizer().transform(["Hello, World!"])` | `[[0.18745188, 0.05658336, ..., 0.6332584 , 0.43805206]]` [TF-IDF](https://github.com/artitw/text2text#tf-idf) | `t2t.Tfidfer().transform(["Hello, World!"])` | `[{'!': 0.5, ',': 0.5, '▁Hello': 0.5, '▁World': 0.5}]` [BM25](https://github.com/artitw/text2text#bm25) | `t2t.Bm25er().transform(["Hello, World!"])` | `[{'!': 0.3068528194400547, ',': 0.3068528194400547, '▁Hello': 0.3068528194400547, '▁World': 0.3068528194400547}]` [Indexer](https://github.com/artitw/text2text#index) | `index = t2t.Indexer().transform(["Hello, World!"])` | Index object for information retrieval @@ -249,12 +249,12 @@ t2t.Vectorizer().transform([ ]) # Embeddings -array([[-0.00352954, 0.0260059 , 0.00407429, ..., -0.04830331, - -0.02540749, -0.00924972], - [ 0.00043362, 0.00249816, 0.01755436, ..., 0.04451273, - 0.05118701, 0.01895813], - [-0.03563676, -0.04856304, 0.00518898, ..., -0.00311068, - 0.00071953, -0.00216325]]) +[[-0.00352954, 0.0260059 , 0.00407429, ..., -0.04830331, + -0.02540749, -0.00924972], + [ 0.00043362, 0.00249816, 0.01755436, ..., 0.04451273, + 0.05118701, 0.01895813], + [-0.03563676, -0.04856304, 0.00518898, ..., -0.00311068, + 0.00071953, -0.00216325]] ``` ### TF-IDF diff --git a/demos/Text2Text_LLM.ipynb b/demos/Text2Text_LLM.ipynb index c894809..1ccc63d 100644 --- a/demos/Text2Text_LLM.ipynb +++ b/demos/Text2Text_LLM.ipynb @@ -194,6 +194,34 @@ }, "execution_count": null, "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# RAG Assistant\n", + "texts = [\n", + " \"**Get outdoors**: Go for a walk, hike, or bike ride to enjoy nature and fresh air.\",\n", + " \"**Learn something new**: Watch a TED talk, take an online course, or read a book on a topic that interests you.\",\n", + " \"**Practice self-care**: Treat yourself to a relaxing bath, meditation session, or yoga practice.\",\n", + " \"**Connect with others**: Call a friend or family member, meet up with someone for coffee, or join a social event.\",\n", + " \"**Get creative**: Paint, draw, write, or try any other creative activity that brings you joy.\",\n", + "]\n", + "\n", + "asst = t2t.RagAssistant(texts=texts)\n", + "\n", + "chat_history = [\n", + " {\"role\": \"user\", \"content\": \"What should I do today?\"}\n", + "]\n", + "\n", + "result = asst.chat_completion(chat_history, stream=True) #{'role': 'assistant', 'content': '1. Make a list of things to be grateful for.\\n2. Go outside and take a walk in nature.\\n3. Practice mindfulness meditation.\\n4. Connect with a loved one or friend.\\n5. Do something kind for someone else.\\n6. Engage in a creative activity like drawing or writing.\\n7. Read an uplifting book or listen to motivational podcasts.'}\n", + "for chunk in result:\n", + " print(chunk['message']['content'], end='', flush=True)" + ], + "metadata": { + "id": "nnxCHbgU5PVG" + }, + "execution_count": null, + "outputs": [] } ] } \ No newline at end of file diff --git a/setup.py b/setup.py index e428d64..443dfd0 100755 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name="text2text", - version="1.5.9", + version="1.6.0", author="artitw", author_email="artitw@gmail.com", description="Text2Text: Crosslingual NLP/G toolkit", @@ -18,7 +18,7 @@ "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ], - keywords='multilingual crosslingual gpt chatgpt bert natural language processing nlp nlg text generation gpt question answer answering information retrieval tfidf tf-idf bm25 search index summary summarizer summarization tokenizer tokenization translation backtranslation data augmentation science machine learning colab embedding levenshtein sub-word edit distance conversational dialog chatbot mixtral', + keywords='multilingual crosslingual gpt chatgpt bert natural language processing nlp nlg text generation gpt question answer answering information retrieval tfidf tf-idf bm25 search index summary summarizer summarization tokenizer tokenization translation backtranslation data augmentation science machine learning colab embedding levenshtein sub-word edit distance conversational dialog chatbot llama', install_requires=[ 'faiss-cpu', 'flask', diff --git a/text2text/__init__.py b/text2text/__init__.py index c33b04e..3abf6b7 100755 --- a/text2text/__init__.py +++ b/text2text/__init__.py @@ -9,6 +9,7 @@ from .tfidfer import Tfidfer from .bm25er import Bm25er from .indexer import Indexer +from .rag_assistant import RagAssistant from .variator import Variator from .identifier import Identifier from .server import Server diff --git a/text2text/assistant.py b/text2text/assistant.py index e0ad589..d1ca066 100644 --- a/text2text/assistant.py +++ b/text2text/assistant.py @@ -1,41 +1,58 @@ import os import ollama import psutil +import time import subprocess +import warnings from llama_index.llms.ollama import Ollama from llama_index.core.llms import ChatMessage +def ollama_version(): + result = subprocess.check_output(["ollama", "-v"], stderr=subprocess.STDOUT).decode("utf-8") + if result.startswith("ollama version "): + return result.replace("ollama version ", "") + return "" + class Assistant(object): def __init__(self, **kwargs): self.host = kwargs.get("host", "http://localhost") self.port = kwargs.get("port", 11434) self.model_url = f"{self.host}:{self.port}" self.model_name = kwargs.get("model_name", "llama3.2") + self.ollama_serve_proc = None self.load_model() self.client = ollama.Client(host=self.model_url) self.structured_client = Ollama(model=self.model_name, request_timeout=120.0) def __del__(self): ollama.delete(self.model_name) + self.ollama_serve_proc.kill() def load_model(self): - return_code = os.system("sudo apt install -q -y lshw") - if return_code != 0: - raise Exception("Cannot install lshw.") + if not ollama_version(): + if self.ollama_serve_proc: + self.ollama_serve_proc.kill() + self.ollama_serve_proc = None + + return_code = os.system("sudo apt install -q -y lshw") + if return_code != 0: + raise Exception("Cannot install lshw.") + + return_code = os.system("curl -fsSL https://ollama.com/install.sh | sh") + if return_code != 0: + raise Exception("Cannot install ollama.") - return_code = os.system("curl -fsSL https://ollama.com/install.sh | sh") - if return_code != 0: - raise Exception("Cannot install ollama.") + return_code = os.system("sudo systemctl enable ollama") + if return_code != 0: + raise Exception("Cannot enable ollama.") - return_code = os.system("sudo systemctl enable ollama") - if return_code != 0: - raise Exception("Cannot enable ollama.") + self.ollama_serve_proc = subprocess.Popen(["ollama", "serve"]) + time.sleep(1) - sub = subprocess.Popen(["ollama", "serve"]) - return_code = os.system("ollama -v") - if return_code != 0: - raise Exception("Cannot serve ollama.") + result = subprocess.check_output(["ollama", "-v"], stderr=subprocess.STDOUT).decode("utf-8") + if not result.startswith("ollama version"): + raise Exception(result) result = ollama.pull(self.model_name) if result["status"] != "success": @@ -44,14 +61,14 @@ def load_model(self): def chat_completion(self, messages=[{"role": "user", "content": "hello"}], stream=False, schema=None, **kwargs): try: result = ollama.ps() - if not result: + if not result or not result.get("models"): result = ollama.pull(self.model_name) if result["status"] == "success": return self.chat_completion(messages=messages, stream=stream, **kwargs) raise Exception(f"Did not pull {self.model_name}. Try restarting.") except Exception as e: - print(str(e)) - print("Retrying...") + warnings.warn(str(e)) + warnings.warn("Retrying...") self.load_model() return self.chat_completion(messages=messages, stream=stream, **kwargs) diff --git a/text2text/handler.py b/text2text/handler.py index aa568cd..4283ff6 100644 --- a/text2text/handler.py +++ b/text2text/handler.py @@ -7,6 +7,7 @@ class Handler(object): EXPOSED_TRANSFORMERS = { "assist": t2t.Assistant, + "rag_assist": t2t.RagAssistant, "bm25": t2t.Bm25er, "count": t2t.Counter, "identify":t2t.Identifier, diff --git a/text2text/rag_assistant.py b/text2text/rag_assistant.py new file mode 100644 index 0000000..1cc0ccf --- /dev/null +++ b/text2text/rag_assistant.py @@ -0,0 +1,48 @@ +import text2text as t2t + +import urllib.parse +import urllib.request + +import warnings + +def is_valid_url(url): + try: + result = urllib.parse.urlparse(url) + return all([result.scheme, result.netloc]) + except Exception: + return False + +class RagAssistant(t2t.Assistant): + def __init__(self, **kwargs): + super().__init__(**kwargs) + schema = kwargs.get("schema", None) + texts = kwargs.get("texts", []) + urls = kwargs.get("urls", []) + input_lines = [] + for u in urls: + if is_valid_url(u): + try: + with urllib.request.urlopen(u) as f: + texts.append(f.read()) + except Exception as e: + warnings.warn(f"Skipping URL with errors: {u}") + else: + warnings.warn(f"Skipping invalid URL: {u}") + + if schema: + for t in texts: + res = t2t.Assistant.chat_completion(self, [{"role": "user", "content": t}], schema=schema) + res = "\n".join(f'{k}: {v}' for k,v in vars(res).items()) + input_lines.append(res) + else: + input_lines = texts + + self.index = t2t.Indexer().transform(input_lines, encoders=[t2t.Vectorizer()]) + + def chat_completion(self, messages=[{"role": "user", "content": "hello"}], stream=False, schema=None, **kwargs): + k = kwargs.get("k", 3) + query = messages[-1]["content"] + docs = self.index.retrieve([query], k=k)[0] + grounding_information = "\n\n".join(docs) + "\n\n" + messages[-1] = {"role": "user", "content": grounding_information+query} + return t2t.Assistant.chat_completion(self, messages=messages, stream=stream, schema=schema, **kwargs)