Skip to content

Commit

Permalink
json_embed(json,model) for DuckDB
Browse files Browse the repository at this point in the history
  • Loading branch information
Florents-Tselai committed Aug 15, 2024
1 parent b45100c commit c5bb51a
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 19 deletions.
65 changes: 47 additions & 18 deletions tests/test_tsellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,15 @@ def test_interact_valid_multiline_sql(self):

class InMemorySQLiteTest(TsellmConsoleTest):
path_args = None
alice_json = """{
\"name\": \"Alice\",
\"details\": {
\"age\": 30,
\"hobbies\": [\"reading\", \"cycling\"],
\"location\": \"Wonderland\"
},
\"greeting\": \"Hello, World!\"
}"""

def setUp(self):
super().setUp()
Expand Down Expand Up @@ -225,18 +234,9 @@ def test_embed_hazo_binary(self):
self.expect_success(*self.path_args, "select embed(randomblob(16), 'hazo')")

def test_embed_json_recursive(self):
example_json = """{
\"name\": \"Alice\",
\"details\": {
\"age\": 30,
\"hobbies\": [\"reading\", \"cycling\"],
\"location\": \"Wonderland\"
},
\"greeting\": \"Hello, World!\"
}"""
out = self.expect_success(
*self.path_args,
f"select json_extract('{example_json}', '$.name')",
f"select json_extract('{self.alice_json}', '$.name')",
)
self.assertEqual(
"('Alice',)\n",
Expand All @@ -245,16 +245,18 @@ def test_embed_json_recursive(self):

out = self.expect_success(
*self.path_args,
f"select json_embed('{example_json}', 'hazo')",
f"select json_embed('{self.alice_json}', 'hazo')",
)
self.assertEqual(
('(\'{"name": [5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, '
'0.0, 0.0, 0.0, 0.0], "details": {"age": 30, "hobbies": [[7.0, 0.0, 0.0, 0.0, '
'0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [7.0, 0.0, 0.0, '
'0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], '
'"location": [10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, '
'0.0, 0.0, 0.0, 0.0]}, "greeting": [6.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, '
"0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}',)\n"),
(
'(\'{"name": [5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, '
'0.0, 0.0, 0.0, 0.0], "details": {"age": 30, "hobbies": [[7.0, 0.0, 0.0, 0.0, '
"0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [7.0, 0.0, 0.0, "
"0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], "
'"location": [10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, '
'0.0, 0.0, 0.0, 0.0]}, "greeting": [6.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, '
"0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}',)\n"
),
out,
)

Expand Down Expand Up @@ -323,6 +325,33 @@ def test_embed_hazo_binary(self):
# See https://github.com/Florents-Tselai/tsellm/issues/25
pass

def test_embed_json_recursive(self):
out = self.expect_success(
*self.path_args,
f"select '{self.alice_json}'::json -> 'name'",
)
self.assertEqual(
"('\"Alice\"',)\n",
out,
)

out = self.expect_success(
*self.path_args,
f"select json_embed('{self.alice_json}'::json, 'hazo')",
)
self.assertEqual(
(
'(\'{"name": [5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, '
'0.0, 0.0, 0.0, 0.0], "details": {"age": 30, "hobbies": [[7.0, 0.0, 0.0, 0.0, '
"0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [7.0, 0.0, 0.0, "
"0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], "
'"location": [10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, '
'0.0, 0.0, 0.0, 0.0]}, "greeting": [6.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, '
"0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}',)\n"
),
out,
)


class DiskDuckDBTest(InMemoryDuckDBTest):
db_fp = None
Expand Down
1 change: 1 addition & 0 deletions tsellm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ def is_valid_db(self) -> bool:
_functions = [
("prompt", 2, _prompt_model, False),
("embed", 2, _embed_model, False),
("json_embed", 2, _json_embed_model, False),
]

def connect(self):
Expand Down
2 changes: 1 addition & 1 deletion tsellm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"""



def json_recurse_apply(json_obj, f):
if isinstance(json_obj, dict):
# Recursively apply the function to dictionary values
Expand All @@ -29,6 +28,7 @@ def json_recurse_apply(json_obj, f):
# Return the object as is if it's neither a dictionary, list, or string
return json_obj


def _prompt_model(prompt: str, model: str) -> str:
return llm.get_model(model).prompt(prompt).text()

Expand Down

0 comments on commit c5bb51a

Please sign in to comment.