diff --git a/README.md b/README.md
index db0fa83..1f5cb0b 100755
--- a/README.md
+++ b/README.md
@@ -225,9 +225,11 @@ class Song(BaseModel):
 
 result = asst.chat_completion([
   {"role": "user", "content": "What is Britney Spears's best song?"}
-], schema=Song, max_new_tokens=16)
-
+], schema=Song)
 # Song(name='Toxic', artist='Britney Spears')
+
+# Embeddings
+asst.embed(["hello, world!", "this will be embedded"])
 ```
 
 ### Tokenization
diff --git a/demos/Text2Text_LLM.ipynb b/demos/Text2Text_LLM.ipynb
index f1ecdd2..c894809 100644
--- a/demos/Text2Text_LLM.ipynb
+++ b/demos/Text2Text_LLM.ipynb
@@ -174,13 +174,26 @@
         "\n",
         "result = asst.chat_completion([\n",
         "  {\"role\": \"user\", \"content\": \"What is Britney Spears's best song?\"}\n",
-        "], schema=Song, max_new_tokens=16)"
+        "], schema=Song)\n",
+        "print(result) #Song(name='Toxic', artist='Britney Spears')"
       ],
       "metadata": {
         "id": "e5khHlNQZ0FD"
       },
       "execution_count": null,
       "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Embeddings\n",
+        "asst.embed([\"hello, world!\", \"this will be embedded\"])"
+      ],
+      "metadata": {
+        "id": "WJX2klusQR9q"
+      },
+      "execution_count": null,
+      "outputs": []
     }
   ]
 }
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 32a57dd..b033308 100755
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
 setuptools.setup(
     name="text2text",
-    version="1.5.2",
+    version="1.5.3",
     author="artitw",
     author_email="artitw@gmail.com",
     description="Text2Text: Crosslingual NLP/G toolkit",
diff --git a/text2text/assistant.py b/text2text/assistant.py
index a777dc8..e2ebe3a 100644
--- a/text2text/assistant.py
+++ b/text2text/assistant.py
@@ -22,6 +22,7 @@ def __init__(self, **kwargs):
     self.model_name = kwargs.get("model_name", "llama3.1")
     self.load_model()
     self.client = ollama.Client(host=self.model_url)
+    self.llama_index_client = Ollama(model=self.model_name, request_timeout=120.0)
 
   def __del__(self):
     ollama.delete(self.model_name)
@@ -46,12 +47,14 @@ def chat_completion(self, messages=[{"role": "user", "content": "hello"}], stream=False, schema=None, **kwargs):
     if is_port_in_use(self.port):
       if schema:
         msgs = [ChatMessage(**m) for m in messages]
-        llama_index_client = Ollama(model=self.model_name, request_timeout=120.0)
-        return llama_index_client.as_structured_llm(schema).chat(messages=msgs).raw
+        return self.llama_index_client.as_structured_llm(schema).chat(messages=msgs).raw
       return self.client.chat(model=self.model_name, messages=messages, stream=stream)
     self.load_model()
     return self.chat_completion(messages=messages, stream=stream, **kwargs)
 
+  def embed(self, texts):
+    return ollama.embed(model=self.model_name, input=texts)
+
   def transform(self, input_lines, src_lang='en', **kwargs):
     return self.chat_completion([{"role": "user", "content": input_lines}])["message"]["content"]
 