diff --git a/setup.py b/setup.py index 787d1fe..7efc5aa 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name="text2text", - version="1.8.3", + version="1.8.4", author="artitw", author_email="artitw@gmail.com", description="Text2Text Language Modeling Toolkit", diff --git a/text2text/assistant.py b/text2text/assistant.py index 05b97d1..97e9801 100644 --- a/text2text/assistant.py +++ b/text2text/assistant.py @@ -49,6 +49,23 @@ def run_sh(script_string): except Exception as e: return str(e) +def apt_install_packages(packages, sudo=True): + try: + # Update the package list + cmds = ['apt', 'update'] + if sudo: + cmds = ['sudo']+cmds + subprocess.run(cmds, check=True) + + # Install the packages + cmds = ['apt', 'install', '-q', '-y'] + packages + if sudo: + cmds = ['sudo']+cmds + subprocess.run(cmds, check=True) + + except subprocess.CalledProcessError as e: + raise Exception(str(e)) + class Assistant(object): def __init__(self, **kwargs): self.host = kwargs.get("host", "http://localhost") @@ -57,6 +74,7 @@ def __init__(self, **kwargs): self.model_name = kwargs.get("model_name", "llama3.2") self.schema_timeout = kwargs.get("schema_timeout", 120.0) self.ollama_serve_proc = None + self.sudo = kwargs.get("sudo", True) self.load_model() def __del__(self): @@ -76,9 +94,7 @@ def load_model(self): pbar.update(1) if can_use_apt(): - return_code = os.system("apt install -q -y lshw curl") - if return_code != 0: - raise Exception("Cannot install lshw and/or curl.") + apt_install_packages(['lshw', 'curl'], self.sudo) pbar.update(1) elif platform.system() == "Windows": raise Exception("Windows not supported.") @@ -122,9 +138,9 @@ def model_loading(self): ps_result = ollama.ps() ls_result = ollama.list() if ps_result and ls_result and \ - ps_result.get("models", []) and ls_result.get("models", []) and \ - ps_result.get("models")[0].get("name", "").startswith(self.model_name) and \ - ls_result.get("models")[0].get("name", "").startswith(self.model_name): + ps_result.models and ls_result.models and \ + ps_result.models[0].model.startswith(self.model_name) and \ + ls_result.models[0].model.startswith(self.model_name): return False except Exception as e: warnings.warn(str(e))