diff --git a/setup.py b/setup.py
index 2118c5b..82c0eab 100755
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
 
 setuptools.setup(
   name="text2text",
-  version="1.5.6",
+  version="1.5.7",
   author="artitw",
   author_email="artitw@gmail.com",
   description="Text2Text: Crosslingual NLP/G toolkit",
diff --git a/text2text/assistant.py b/text2text/assistant.py
index 887bbee..652b2e2 100644
--- a/text2text/assistant.py
+++ b/text2text/assistant.py
@@ -2,7 +2,6 @@
 import ollama
 import psutil
 import subprocess
-import time
 
 from llama_index.llms.ollama import Ollama
 from llama_index.core.llms import ChatMessage
@@ -23,33 +22,37 @@ def __del__(self):
   def load_model(self):
     return_code = os.system("sudo apt install -q -y lshw")
     if return_code != 0:
-      print("Cannot install lshw.")
+      raise Exception("Cannot install lshw.")
 
     return_code = os.system("curl -fsSL https://ollama.com/install.sh | sh")
     if return_code != 0:
-      print("Cannot install ollama.")
+      raise Exception("Cannot install ollama.")
 
     return_code = os.system("sudo systemctl enable ollama")
     if return_code != 0:
-      print("Cannot enable ollama.")
+      raise Exception("Cannot enable ollama.")
 
     sub = subprocess.Popen(["ollama", "serve"])
     return_code = os.system("ollama -v")
     if return_code != 0:
-      print("Cannot serve ollama.")
+      raise Exception("Cannot serve ollama.")
 
     result = ollama.pull(self.model_name)
     if result["status"] != "success":
-      print(f"Did not pull {self.model_name}.")
-
-    time.sleep(10)
-
+      raise Exception(f"Did not pull {self.model_name}.")
+
   def chat_completion(self, messages=[{"role": "user", "content": "hello"}], stream=False, schema=None, **kwargs):
     try:
-      ollama.ps()
-      result = ollama.pull(self.model_name)
-      if result["status"] == "success":
-        time.sleep(10)
+      result = ollama.ps()
+      if not result:
+        result = ollama.pull(self.model_name)
+        if result["status"] == "success":
+          return self.chat_completion(messages=messages, stream=stream, **kwargs)
+        raise Exception(f"Did not pull {self.model_name}. Try restarting.")
+    except Exception as e:
+      print(str(e))
+      print("Retrying...")
+      self.load_model()
       return self.chat_completion(messages=messages, stream=stream, **kwargs)
 
     if schema:
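
For reference, a minimal usage sketch of the behavior after this change. It assumes text2text's published Assistant API; the default-model behavior is inferred from this diff and is not shown here verbatim.

import text2text as t2t

# Constructing the assistant runs load_model(), which after this change
# raises on any failed setup step (lshw install, ollama install/enable/
# serve, or the model pull) instead of printing and continuing in a
# broken state.
asst = t2t.Assistant()

# chat_completion() now probes ollama.ps() first; if the server reports
# nothing running, it pulls the model and retries itself, and if the probe
# raises, it reruns load_model() before retrying.
result = asst.chat_completion(messages=[{"role": "user", "content": "hello"}])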
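
Two things worth flagging on the new retry path: the recursive calls pass only messages, stream, and **kwargs, so the schema argument is dropped on retry, and nothing bounds the recursion if ollama.ps() keeps failing while load_model() keeps succeeding. A possible follow-up, sketched below with a hypothetical max_retries parameter that is not part of this diff:

# Hypothetical bounded-retry variant; `max_retries` and the loop shape are
# assumptions for illustration, not part of this diff.
def chat_completion(self, messages=[{"role": "user", "content": "hello"}],
                    stream=False, schema=None, max_retries=2, **kwargs):
  for attempt in range(max_retries + 1):
    try:
      if not ollama.ps():
        result = ollama.pull(self.model_name)
        if result["status"] != "success":
          raise Exception(f"Did not pull {self.model_name}. Try restarting.")
      break  # server is up and the model is available
    except Exception as e:
      print(str(e))
      if attempt == max_retries:
        raise  # give up instead of recursing forever
      print("Retrying...")
      self.load_model()
  # ... proceed to the schema/streaming handling below, with schema intact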