Skip to content

Commit

Permalink
RAG Assistant
Browse files Browse the repository at this point in the history
  • Loading branch information
artitw committed Oct 7, 2024
1 parent 3f9baf4 commit e8bb2e2
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 25 deletions.
14 changes: 7 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ Module Importing | `import text2text as t2t` | Libraries imported
[Assistant](https://github.com/artitw/text2text#assistant) | `t2t.Assistant().transform("Describe Text2Text in a few words: ")` | `['Text2Text is an AI-powered text generation tool that creates coherent and continuous text based on prompts.']`
[Language Model Setting](https://github.com/artitw/text2text#byot-bring-your-own-translator) | `t2t.Transformer.PRETRAINED_TRANSLATOR = "facebook/m2m100_418M"` | Change from default
[Tokenization](https://github.com/artitw/text2text#tokenization) | `t2t.Tokenizer().transform(["Hello, World!"])` | `[['▁Hello', ',', '▁World', '!']]`
[Embedding](https://github.com/artitw/text2text#embedding--vectorization) | `t2t.Vectorizer().transform(["Hello, World!"])` | `array([[0.18745188, 0.05658336, ..., 0.6332584 , 0.43805206]], dtype=float32)`
[Embedding](https://github.com/artitw/text2text#embedding--vectorization) | `t2t.Vectorizer().transform(["Hello, World!"])` | `[[0.18745188, 0.05658336, ..., 0.6332584 , 0.43805206]]`
[TF-IDF](https://github.com/artitw/text2text#tf-idf) | `t2t.Tfidfer().transform(["Hello, World!"])` | `[{'!': 0.5, ',': 0.5, '▁Hello': 0.5, '▁World': 0.5}]`
[BM25](https://github.com/artitw/text2text#bm25) | `t2t.Bm25er().transform(["Hello, World!"])` | `[{'!': 0.3068528194400547, ',': 0.3068528194400547, '▁Hello': 0.3068528194400547, '▁World': 0.3068528194400547}]`
[Indexer](https://github.com/artitw/text2text#index) | `index = t2t.Indexer().transform(["Hello, World!"])` | Index object for information retrieval
Expand Down Expand Up @@ -249,12 +249,12 @@ t2t.Vectorizer().transform([
])
# Embeddings
array([[-0.00352954, 0.0260059 , 0.00407429, ..., -0.04830331,
-0.02540749, -0.00924972],
[ 0.00043362, 0.00249816, 0.01755436, ..., 0.04451273,
0.05118701, 0.01895813],
[-0.03563676, -0.04856304, 0.00518898, ..., -0.00311068,
0.00071953, -0.00216325]])
[[-0.00352954, 0.0260059 , 0.00407429, ..., -0.04830331,
-0.02540749, -0.00924972],
[ 0.00043362, 0.00249816, 0.01755436, ..., 0.04451273,
0.05118701, 0.01895813],
[-0.03563676, -0.04856304, 0.00518898, ..., -0.00311068,
0.00071953, -0.00216325]]
```

### TF-IDF
Expand Down
28 changes: 28 additions & 0 deletions demos/Text2Text_LLM.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,34 @@
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# RAG Assistant\n",
"texts = [\n",
" \"**Get outdoors**: Go for a walk, hike, or bike ride to enjoy nature and fresh air.\",\n",
" \"**Learn something new**: Watch a TED talk, take an online course, or read a book on a topic that interests you.\",\n",
" \"**Practice self-care**: Treat yourself to a relaxing bath, meditation session, or yoga practice.\",\n",
" \"**Connect with others**: Call a friend or family member, meet up with someone for coffee, or join a social event.\",\n",
" \"**Get creative**: Paint, draw, write, or try any other creative activity that brings you joy.\",\n",
"]\n",
"\n",
"asst = t2t.RagAssistant(texts=texts)\n",
"\n",
"chat_history = [\n",
" {\"role\": \"user\", \"content\": \"What should I do today?\"}\n",
"]\n",
"\n",
"result = asst.chat_completion(chat_history, stream=True) #{'role': 'assistant', 'content': '1. Make a list of things to be grateful for.\\n2. Go outside and take a walk in nature.\\n3. Practice mindfulness meditation.\\n4. Connect with a loved one or friend.\\n5. Do something kind for someone else.\\n6. Engage in a creative activity like drawing or writing.\\n7. Read an uplifting book or listen to motivational podcasts.'}\n",
"for chunk in result:\n",
" print(chunk['message']['content'], end='', flush=True)"
],
"metadata": {
"id": "nnxCHbgU5PVG"
},
"execution_count": null,
"outputs": []
}
]
}
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="text2text",
version="1.5.9",
version="1.6.0",
author="artitw",
author_email="artitw@gmail.com",
description="Text2Text: Crosslingual NLP/G toolkit",
Expand All @@ -18,7 +18,7 @@
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
],
keywords='multilingual crosslingual gpt chatgpt bert natural language processing nlp nlg text generation gpt question answer answering information retrieval tfidf tf-idf bm25 search index summary summarizer summarization tokenizer tokenization translation backtranslation data augmentation science machine learning colab embedding levenshtein sub-word edit distance conversational dialog chatbot mixtral',
keywords='multilingual crosslingual gpt chatgpt bert natural language processing nlp nlg text generation gpt question answer answering information retrieval tfidf tf-idf bm25 search index summary summarizer summarization tokenizer tokenization translation backtranslation data augmentation science machine learning colab embedding levenshtein sub-word edit distance conversational dialog chatbot llama',
install_requires=[
'faiss-cpu',
'flask',
Expand Down
1 change: 1 addition & 0 deletions text2text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .tfidfer import Tfidfer
from .bm25er import Bm25er
from .indexer import Indexer
from .rag_assistant import RagAssistant
from .variator import Variator
from .identifier import Identifier
from .server import Server
Expand Down
49 changes: 33 additions & 16 deletions text2text/assistant.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,58 @@
import os
import ollama
import psutil
import time
import subprocess
import warnings

from llama_index.llms.ollama import Ollama
from llama_index.core.llms import ChatMessage

def ollama_version():
    """Return the version reported by the ``ollama`` CLI, or ``""`` if unavailable.

    Runs ``ollama -v`` (stderr merged into stdout) and, when the output starts
    with ``"ollama version "``, returns the output with that prefix removed.
    The remainder is returned as-is, without stripping whitespace.

    Returns:
        str: The version text, or an empty string when the binary cannot be
        executed, exits with a non-zero status, or prints anything other than
        the expected ``"ollama version "`` banner.
    """
    try:
        result = subprocess.check_output(
            ["ollama", "-v"], stderr=subprocess.STDOUT
        ).decode("utf-8")
    except (OSError, subprocess.CalledProcessError):
        # A missing/non-executable binary (OSError, e.g. FileNotFoundError) or
        # a failing command means "no usable ollama". Report that as "" so a
        # caller probing for the installation can react instead of crashing.
        return ""
    if result.startswith("ollama version "):
        return result.replace("ollama version ", "")
    return ""

class Assistant(object):
def __init__(self, **kwargs):
    """Configure the Ollama endpoint, load the model and build the clients.

    Keyword Args:
        host: Scheme and host of the Ollama server. Defaults to
            ``"http://localhost"``.
        port: Port of the Ollama server. Defaults to ``11434``.
        model_name: Name of the model to pull and use. Defaults to
            ``"llama3.2"``.
    """
    self.host = kwargs.get("host", "http://localhost")
    self.port = kwargs.get("port", 11434)
    self.model_url = f"{self.host}:{self.port}"
    self.model_name = kwargs.get("model_name", "llama3.2")
    # Must exist before load_model() runs: load_model reads and assigns it.
    self.ollama_serve_proc = None
    self.load_model()
    self.client = ollama.Client(host=self.model_url)
    # Point the llama-index client at the same endpoint as self.client;
    # without base_url it would always talk to its built-in default address
    # and silently ignore a custom host/port.
    self.structured_client = Ollama(
        model=self.model_name,
        base_url=self.model_url,
        request_timeout=120.0,
    )

def __del__(self):
    """Delete the model via ``ollama.delete`` and stop any server we started.

    ``ollama_serve_proc`` is initialised to ``None`` in ``__init__`` and may
    never be replaced by a real process handle, so it is only killed when
    one is actually present. ``getattr`` additionally covers an instance
    whose ``__init__`` failed before the attribute was assigned.
    """
    try:
        ollama.delete(self.model_name)
    finally:
        # Run the kill even if the delete call raised, so a spawned
        # `ollama serve` process is not left behind.
        proc = getattr(self, "ollama_serve_proc", None)
        if proc is not None:
            proc.kill()

def load_model(self):
return_code = os.system("sudo apt install -q -y lshw")
if return_code != 0:
raise Exception("Cannot install lshw.")
if not ollama_version():
if self.ollama_serve_proc:
self.ollama_serve_proc.kill()
self.ollama_serve_proc = None

return_code = os.system("sudo apt install -q -y lshw")
if return_code != 0:
raise Exception("Cannot install lshw.")

return_code = os.system("curl -fsSL https://ollama.com/install.sh | sh")
if return_code != 0:
raise Exception("Cannot install ollama.")

return_code = os.system("curl -fsSL https://ollama.com/install.sh | sh")
if return_code != 0:
raise Exception("Cannot install ollama.")
return_code = os.system("sudo systemctl enable ollama")
if return_code != 0:
raise Exception("Cannot enable ollama.")

return_code = os.system("sudo systemctl enable ollama")
if return_code != 0:
raise Exception("Cannot enable ollama.")
self.ollama_serve_proc = subprocess.Popen(["ollama", "serve"])
time.sleep(1)

sub = subprocess.Popen(["ollama", "serve"])
return_code = os.system("ollama -v")
if return_code != 0:
raise Exception("Cannot serve ollama.")
result = subprocess.check_output(["ollama", "-v"], stderr=subprocess.STDOUT).decode("utf-8")
if not result.startswith("ollama version"):
raise Exception(result)

result = ollama.pull(self.model_name)
if result["status"] != "success":
Expand All @@ -44,14 +61,14 @@ def load_model(self):
def chat_completion(self, messages=[{"role": "user", "content": "hello"}], stream=False, schema=None, **kwargs):
try:
result = ollama.ps()
if not result:
if not result or not result.get("models"):
result = ollama.pull(self.model_name)
if result["status"] == "success":
return self.chat_completion(messages=messages, stream=stream, **kwargs)
raise Exception(f"Did not pull {self.model_name}. Try restarting.")
except Exception as e:
print(str(e))
print("Retrying...")
warnings.warn(str(e))
warnings.warn("Retrying...")
self.load_model()
return self.chat_completion(messages=messages, stream=stream, **kwargs)

Expand Down
1 change: 1 addition & 0 deletions text2text/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ class Handler(object):

EXPOSED_TRANSFORMERS = {
"assist": t2t.Assistant,
"rag_assist": t2t.RagAssistant,
"bm25": t2t.Bm25er,
"count": t2t.Counter,
"identify":t2t.Identifier,
Expand Down
48 changes: 48 additions & 0 deletions text2text/rag_assistant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import text2text as t2t

import urllib.parse
import urllib.request

import warnings

def is_valid_url(url):
    """Tell whether *url* parses into both a scheme and a network location.

    Any exception raised while parsing is treated as "not a valid URL".

    Returns:
        bool: ``True`` only when both the scheme and the netloc components
        are non-empty.
    """
    try:
        parts = urllib.parse.urlparse(url)
    except Exception:
        return False
    has_scheme = bool(parts.scheme)
    has_netloc = bool(parts.netloc)
    return has_scheme and has_netloc

class RagAssistant(t2t.Assistant):
    """Assistant that prepends retrieved documents to the latest user message.

    At construction time a set of documents is indexed with
    ``t2t.Indexer`` using a ``t2t.Vectorizer`` encoder. Each call to
    :meth:`chat_completion` retrieves the top matches for the last message
    and places them in front of that message before delegating to
    ``t2t.Assistant.chat_completion``.

    Keyword Args (all kwargs are also forwarded to ``t2t.Assistant``):
        texts: Documents to index. Defaults to no documents.
        urls: URLs whose response bodies are fetched and indexed together
            with ``texts``. URLs that fail validation or fetching are
            skipped with a warning.
        schema: When truthy, every document is first sent through
            ``t2t.Assistant.chat_completion`` with this schema, and the
            attributes of the returned object are flattened into
            ``"key: value"`` lines that are indexed instead of the raw text.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        schema = kwargs.get("schema", None)
        texts = kwargs.get("texts") or []
        if isinstance(texts, str):
            # A bare string is one document, not a sequence of characters.
            texts = [texts]
        # Work on a copy: fetched pages are appended below and must not leak
        # into the list object owned by the caller.
        texts = list(texts)
        urls = kwargs.get("urls") or []

        for u in urls:
            if not is_valid_url(u):
                warnings.warn(f"Skipping invalid URL: {u}")
                continue
            try:
                with urllib.request.urlopen(u) as f:
                    # read() yields bytes; the documents are later joined
                    # with "\n\n" as str, so decode here using the charset
                    # announced by the response (UTF-8 when none is given).
                    charset = f.headers.get_content_charset() or "utf-8"
                    texts.append(f.read().decode(charset, errors="replace"))
            except Exception:
                # Best effort by design: one bad URL must not abort indexing.
                warnings.warn(f"Skipping URL with errors: {u}")

        if schema:
            input_lines = []
            for t in texts:
                # Call the parent implementation explicitly: the override
                # below needs self.index, which does not exist yet.
                res = super().chat_completion(
                    [{"role": "user", "content": t}], schema=schema
                )
                input_lines.append(
                    "\n".join(f'{k}: {v}' for k, v in vars(res).items())
                )
        else:
            input_lines = texts

        self.index = t2t.Indexer().transform(input_lines, encoders=[t2t.Vectorizer()])

    def chat_completion(self, messages=[{"role": "user", "content": "hello"}], stream=False, schema=None, **kwargs):
        """Answer the last message in *messages*, grounded in retrieved documents.

        Args:
            messages: Chat history as a list of ``{"role", "content"}`` dicts.
                The content of the last entry is used as the retrieval query.
                The list is not modified.
            stream: Forwarded to ``t2t.Assistant.chat_completion``.
            schema: Forwarded to ``t2t.Assistant.chat_completion``.
            **kwargs: ``k`` (default 3) sets how many documents to retrieve;
                all kwargs are forwarded to the parent call.

        Returns:
            Whatever ``t2t.Assistant.chat_completion`` returns.
        """
        # NOTE(review): `k` is read here but also stays in kwargs and is
        # forwarded to the parent, as before -- confirm the parent accepts it.
        k = kwargs.get("k", 3)
        query = messages[-1]["content"]
        docs = self.index.retrieve([query], k=k)[0]
        grounding_information = "\n\n".join(docs) + "\n\n"
        # Build a new list instead of assigning into `messages`: writing to
        # messages[-1] would inject the grounding text into the caller's chat
        # history and into the shared default argument, compounding on every
        # call.
        grounded_messages = list(messages[:-1])
        grounded_messages.append(
            {"role": "user", "content": grounding_information + query}
        )
        return super().chat_completion(
            messages=grounded_messages, stream=stream, schema=schema, **kwargs
        )

0 comments on commit e8bb2e2

Please sign in to comment.