Skip to content

Commit

Permalink
assistant fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
artitw committed Dec 1, 2024
1 parent 3472562 commit e9f2ef4
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 7 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="text2text",
version="1.8.3",
version="1.8.4",
author="artitw",
author_email="artitw@gmail.com",
description="Text2Text Language Modeling Toolkit",
Expand Down
28 changes: 22 additions & 6 deletions text2text/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,23 @@ def run_sh(script_string):
except Exception as e:
return str(e)

def apt_install_packages(packages, sudo=True):
try:
# Update the package list
cmds = ['apt', 'update']
if sudo:
cmds = ['sudo']+cmds
subprocess.run(cmds, check=True)

# Install the packages
cmds = ['apt', 'install', '-q', '-y'] + packages
if sudo:
cmds = ['sudo']+cmds
subprocess.run(cmds, check=True)

except subprocess.CalledProcessError as e:
raise Exception(str(e))

class Assistant(object):
def __init__(self, **kwargs):
self.host = kwargs.get("host", "http://localhost")
Expand All @@ -57,6 +74,7 @@ def __init__(self, **kwargs):
self.model_name = kwargs.get("model_name", "llama3.2")
self.schema_timeout = kwargs.get("schema_timeout", 120.0)
self.ollama_serve_proc = None
self.sudo = kwargs.get("sudo", True)
self.load_model()

def __del__(self):
Expand All @@ -76,9 +94,7 @@ def load_model(self):
pbar.update(1)

if can_use_apt():
return_code = os.system("apt install -q -y lshw curl")
if return_code != 0:
raise Exception("Cannot install lshw and/or curl.")
apt_install_packages(['lshw', 'curl'], self.sudo)
pbar.update(1)
elif platform.system() == "Windows":
raise Exception("Windows not supported.")
Expand Down Expand Up @@ -122,9 +138,9 @@ def model_loading(self):
ps_result = ollama.ps()
ls_result = ollama.list()
if ps_result and ls_result and \
ps_result.get("models", []) and ls_result.get("models", []) and \
ps_result.get("models")[0].get("name", "").startswith(self.model_name) and \
ls_result.get("models")[0].get("name", "").startswith(self.model_name):
ps_result.models and ls_result.models and \
ps_result.models[0].model.startswith(self.model_name) and \
ls_result.models[0].model.startswith(self.model_name):
return False
except Exception as e:
warnings.warn(str(e))
Expand Down

0 comments on commit e9f2ef4

Please sign in to comment.