Skip to content

Commit

Permalink
json_embed implementation for SQLite
Browse files Browse the repository at this point in the history
  • Loading branch information
Florents-Tselai committed Aug 15, 2024
1 parent 31f77b8 commit b45100c
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 2 deletions.
35 changes: 34 additions & 1 deletion tests/test_tsellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import unittest
from pathlib import Path
from test.support import captured_stdout, captured_stderr, captured_stdin
from test.support.os_helper import TESTFN, unlink

import duckdb
import llm.cli
Expand Down Expand Up @@ -225,6 +224,40 @@ def test_embed_hazo_binary(self):
self.assertTrue(llm.get_embedding_model("hazo").supports_binary)
self.expect_success(*self.path_args, "select embed(randomblob(16), 'hazo')")

def test_embed_json_recursive(self):
example_json = """{
\"name\": \"Alice\",
\"details\": {
\"age\": 30,
\"hobbies\": [\"reading\", \"cycling\"],
\"location\": \"Wonderland\"
},
\"greeting\": \"Hello, World!\"
}"""
out = self.expect_success(
*self.path_args,
f"select json_extract('{example_json}', '$.name')",
)
self.assertEqual(
"('Alice',)\n",
out,
)

out = self.expect_success(
*self.path_args,
f"select json_embed('{example_json}', 'hazo')",
)
self.assertEqual(
('(\'{"name": [5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, '
'0.0, 0.0, 0.0, 0.0], "details": {"age": 30, "hobbies": [[7.0, 0.0, 0.0, 0.0, '
'0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [7.0, 0.0, 0.0, '
'0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], '
'"location": [10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, '
'0.0, 0.0, 0.0, 0.0]}, "greeting": [6.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, '
"0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}',)\n"),
out,
)

def test_embed_default_hazo(self):
self.assertEqual(llm_cli.get_default_embedding_model(), "hazo")
out = self.expect_success(*self.path_args, "select embed('hello world')")
Expand Down
4 changes: 3 additions & 1 deletion tsellm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
_prompt_model,
_prompt_model_default,
_embed_model,
_json_embed_model,
_embed_model_default,
)

Expand Down Expand Up @@ -79,6 +80,7 @@ class TsellmConsole(InteractiveConsole, ABC):
("prompt", 1, _prompt_model_default, False),
("embed", 2, _embed_model, False),
("embed", 1, _embed_model_default, False),
("json_embed", 2, _json_embed_model, False),
]

error_class = None
Expand All @@ -87,7 +89,7 @@ class TsellmConsole(InteractiveConsole, ABC):

@staticmethod
def create_console(
fp: Union[str, Path], in_memory_type: DatabaseType = DatabaseType.UNKNOWN
fp: Union[str, Path], in_memory_type: DatabaseType = DatabaseType.UNKNOWN
):
sniffer = DBSniffer(fp)
if sniffer.is_in_memory:
Expand Down
21 changes: 21 additions & 0 deletions tsellm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,21 @@
"""



def json_recurse_apply(json_obj, f):
if isinstance(json_obj, dict):
# Recursively apply the function to dictionary values
return {k: json_recurse_apply(v, f) for k, v in json_obj.items()}
elif isinstance(json_obj, list):
# Recursively apply the function to list elements
return [json_recurse_apply(item, f) for item in json_obj]
elif isinstance(json_obj, str):
# Apply the function to string values, which returns a list of floats
return f(json_obj)
else:
# Return the object as is if it's neither a dictionary, list, or string
return json_obj

def _prompt_model(prompt: str, model: str) -> str:
return llm.get_model(model).prompt(prompt).text()

Expand All @@ -26,6 +41,12 @@ def _embed_model(text: str, model: str) -> str:
return json.dumps(llm.get_embedding_model(model).embed(text))


def _json_embed_model(js: str, model: str) -> str:
return json.dumps(
json_recurse_apply(json.loads(js), lambda v: json.loads(_embed_model(v, model)))
)


def _embed_model_default(text: str) -> str:
return json.dumps(
llm.get_embedding_model(llm_cli.get_default_embedding_model()).embed(text)
Expand Down

0 comments on commit b45100c

Please sign in to comment.