Skip to content

Commit

Permalink
Add support for DuckDB (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
Florents-Tselai authored Jul 7, 2024
1 parent b68ec48 commit 7802049
Show file tree
Hide file tree
Showing 6 changed files with 529 additions and 98 deletions.
9 changes: 4 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from setuptools import setup
import os

VERSION = "0.1.0a10"
from tsellm import __version__


def get_long_description():
Expand All @@ -14,7 +13,7 @@ def get_long_description():

setup(
name="tsellm",
description="Interactive SQLite shell with LLM support",
description=__version__.__description__,
long_description=get_long_description(),
long_description_content_type="text/markdown",
author="Florents Tselai",
Expand All @@ -29,9 +28,9 @@ def get_long_description():
"Changelog": "https://github.com/Florents-Tselai/tsellm/releases",
},
license="BSD License",
version=VERSION,
version=__version__.__version__,
packages=["tsellm"],
install_requires=["llm", "setuptools", "pip"],
install_requires=["llm", "setuptools", "pip", "duckdb"],
extras_require={
"test": [
"pytest",
Expand Down
246 changes: 218 additions & 28 deletions tests/test_tsellm.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,46 @@
import llm.cli
from sqlite_utils import Database
from tsellm.cli import cli
import sqlite3
import tempfile
import unittest
from test.support import captured_stdout, captured_stderr, captured_stdin, os_helper
from pathlib import Path
from test.support import captured_stdout, captured_stderr, captured_stdin
from test.support.os_helper import TESTFN, unlink
from llm import models
import sqlite3

import duckdb
import llm.cli
from llm import cli as llm_cli

from tsellm.__version__ import __version__
from tsellm.cli import (
cli,
TsellmConsole,
SQLiteConsole,
TsellmConsoleMixin,
)


def new_tempfile():
return Path(tempfile.mkdtemp()) / "test"


def new_sqlite_file():
f = new_tempfile()
with sqlite3.connect(f) as db:
db.execute("SELECT 1")
return f


class CommandLineInterface(unittest.TestCase):
def new_duckdb_file():
f = new_tempfile()
con = duckdb.connect(f.__str__())
con.sql("SELECT 1")
return f


class TsellmConsoleTest(unittest.TestCase):
def setUp(self):
super().setUp()
llm_cli.set_default_model("markov")
llm_cli.set_default_embedding_model("hazo")

def _do_test(self, *args, expect_success=True):
with (
Expand Down Expand Up @@ -38,25 +69,132 @@ def expect_failure(self, *args):
self.assertEqual(out, "")
return err

def test_sniff_sqlite(self):
self.assertTrue(TsellmConsoleMixin().is_sqlite(new_sqlite_file()))

def test_sniff_duckdb(self):
self.assertTrue(TsellmConsoleMixin().is_duckdb(new_duckdb_file()))

def test_console_factory_sqlite(self):
s = new_sqlite_file()
self.assertTrue(TsellmConsoleMixin().is_sqlite(s))
obj = TsellmConsole.create_console(s)
self.assertIsInstance(obj, SQLiteConsole)

# def test_console_factory_duckdb(self):
# s = new_duckdb_file()
# self.assertTrue(TsellmConsole.is_duckdb(s))
# obj = TsellmConsole.create_console(s)
# self.assertIsInstance(obj, DuckDBConsole)

def test_cli_help(self):
out = self.expect_success("-h")
self.assertIn("usage: python -m tsellm", out)

def test_cli_version(self):
out = self.expect_success("-v")
self.assertIn(__version__, out)

def test_choose_db(self):
self.expect_failure("--sqlite", "--duckdb")

def test_deault_sqlite(self):
f = new_tempfile()
self.expect_success(str(f), "select 1")
self.assertTrue(TsellmConsoleMixin().is_sqlite(f))

MEMORY_DB_MSG = "Connected to :memory:"
PS1 = "tsellm> "
PS2 = "... "

def run_cli(self, *args, commands=()):
with (
captured_stdin() as stdin,
captured_stdout() as stdout,
captured_stderr() as stderr,
self.assertRaises(SystemExit) as cm
):
for cmd in commands:
stdin.write(cmd + "\n")
stdin.seek(0)
cli(args)

out = stdout.getvalue()
err = stderr.getvalue()
self.assertEqual(cm.exception.code, 0,
f"Unexpected failure: {args=}\n{out}\n{err}")
return out, err

def test_interact(self):
out, err = self.run_cli()
self.assertIn(self.MEMORY_DB_MSG, err)
self.assertIn(self.MEMORY_DB_MSG, err)
self.assertTrue(out.endswith(self.PS1))
self.assertEqual(out.count(self.PS1), 1)
self.assertEqual(out.count(self.PS2), 0)

def test_interact_quit(self):
out, err = self.run_cli(commands=(".quit",))
self.assertIn(self.MEMORY_DB_MSG, err)
self.assertTrue(out.endswith(self.PS1))
self.assertEqual(out.count(self.PS1), 1)
self.assertEqual(out.count(self.PS2), 0)

def test_interact_version(self):
out, err = self.run_cli(commands=(".version",))
self.assertIn(self.MEMORY_DB_MSG, err)
self.assertIn(sqlite3.sqlite_version + "\n", out)
self.assertTrue(out.endswith(self.PS1))
self.assertEqual(out.count(self.PS1), 2)
self.assertEqual(out.count(self.PS2), 0)
self.assertIn(sqlite3.sqlite_version, out)

def test_interact_valid_sql(self):
out, err = self.run_cli(commands=("SELECT 1;",))
self.assertIn(self.MEMORY_DB_MSG, err)
self.assertIn("(1,)\n", out)
self.assertTrue(out.endswith(self.PS1))
self.assertEqual(out.count(self.PS1), 2)
self.assertEqual(out.count(self.PS2), 0)

def test_interact_incomplete_multiline_sql(self):
out, err = self.run_cli(commands=("SELECT 1",))
self.assertIn(self.MEMORY_DB_MSG, err)
self.assertTrue(out.endswith(self.PS2))
self.assertEqual(out.count(self.PS1), 1)
self.assertEqual(out.count(self.PS2), 1)

def test_interact_valid_multiline_sql(self):
out, err = self.run_cli(commands=("SELECT 1\n;",))
self.assertIn(self.MEMORY_DB_MSG, err)
self.assertIn(self.PS2, out)
self.assertIn("(1,)\n", out)
self.assertTrue(out.endswith(self.PS1))
self.assertEqual(out.count(self.PS1), 2)
self.assertEqual(out.count(self.PS2), 1)


class InMemorySQLiteTest(TsellmConsoleTest):
path_args = None

def setUp(self):
super().setUp()
self.path_args = (
"--sqlite",
":memory:",
)

def test_cli_execute_sql(self):
out = self.expect_success(":memory:", "select 1")
out = self.expect_success(*self.path_args, "select 1")
self.assertIn("(1,)", out)

def test_cli_execute_too_much_sql(self):
stderr = self.expect_failure(":memory:", "select 1; select 2")
stderr = self.expect_failure(*self.path_args, "select 1; select 2")
err = "ProgrammingError: You can only execute one statement at a time"
self.assertIn(err, stderr)

def test_cli_execute_incomplete_sql(self):
stderr = self.expect_failure(":memory:", "sel")
stderr = self.expect_failure(*self.path_args, "sel")
self.assertIn("OperationalError (SQLITE_ERROR)", stderr)

def test_cli_on_disk_db(self):
Expand All @@ -66,47 +204,99 @@ def test_cli_on_disk_db(self):
out = self.expect_success(TESTFN, "select count(t) from t")
self.assertIn("(0,)", out)


class SQLiteLLMFunction(CommandLineInterface):

def setUp(self):
super().setUp()
llm_cli.set_default_model("markov")
llm_cli.set_default_embedding_model("hazo")

def assertMarkovResult(self, prompt, generated):
# Every word should be one of the original prompt (see https://github.com/simonw/llm-markov/blob/657ca504bcf9f0bfc1c6ee5fe838cde9a8976381/tests/test_llm_markov.py#L20)
for w in prompt.split(" "):
self.assertIn(w, generated)

def test_prompt_markov(self):
out = self.expect_success(":memory:", "select prompt('hello world', 'markov')")
out = self.expect_success(
*self.path_args, "select prompt('hello world', 'markov')"
)
self.assertMarkovResult("hello world", out)

def test_prompt_default_markov(self):
self.assertEqual(llm_cli.get_default_model(), "markov")
out = self.expect_success(":memory:", "select prompt('hello world')")
out = self.expect_success(*self.path_args, "select prompt('hello world')")
self.assertMarkovResult("hello world", out)

def test_embed_hazo(self):
out = self.expect_success(":memory:", "select embed('hello world', 'hazo')")
out = self.expect_success(
*self.path_args, "select embed('hello world', 'hazo')"
)
self.assertEqual(
"('[5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]',)\n",
out,
)

def test_embed_hazo_binary(self):
self.assertTrue(llm.get_embedding_model("hazo").supports_binary)
self.expect_success(":memory:", "select embed(randomblob(16), 'hazo')")
self.expect_success(*self.path_args, "select embed(randomblob(16), 'hazo')")

def test_embed_default_hazo(self):
self.assertEqual(llm_cli.get_default_embedding_model(), "hazo")
out = self.expect_success(*self.path_args, "select embed('hello world')")
self.assertEqual(
"('[5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]',)\n",
out,
)


class DefaultInMemorySQLiteTest(InMemorySQLiteTest):
"""--sqlite is omitted and should be the default, so all test cases remain the same"""

def setUp(self):
super().setUp()
self.path_args = (":memory:",)


class DiskSQLiteTest(InMemorySQLiteTest):
db_fp = None
path_args = ()

def setUp(self):
super().setUp()
self.db_fp = str(new_tempfile())
self.path_args = (
"--sqlite",
self.db_fp,
)

def test_embed_default_hazo_leaves_valid_db_behind(self):
# This should probably be called for all test cases
super().test_embed_default_hazo()
self.assertTrue(TsellmConsoleMixin().is_sqlite(self.db_fp))


class InMemoryDuckDBTest(InMemorySQLiteTest):
def setUp(self):
super().setUp()
self.path_args = (
"--duckdb",
":memory:",
)

def test_duckdb_execute(self):
out = self.expect_success(*self.path_args, "select 'Hello World!'")
self.assertIn("('Hello World!',)", out)

def test_cli_execute_incomplete_sql(self):
pass

def test_cli_execute_too_much_sql(self):
pass

def test_embed_default_hazo(self):
self.assertEqual(llm_cli.get_default_embedding_model(), "hazo")
out = self.expect_success(":memory:", "select embed('hello world')")
self.assertEqual(
"('[5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]',)\n",
out,
)
# See https://github.com/Florents-Tselai/tsellm/issues/24
pass

def test_prompt_default_markov(self):
# See https://github.com/Florents-Tselai/tsellm/issues/24
pass

def test_embed_hazo_binary(self):
# See https://github.com/Florents-Tselai/tsellm/issues/25
pass


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions tsellm/__main__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys

from .cli import cli

if __name__ == "__main__":
Expand Down
3 changes: 3 additions & 0 deletions tsellm/__version__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
__title__ = "tsellm"
__description__ = "Use LLMs in SQLite and DuckDB"
__version__ = "0.1.0a10"
Loading

0 comments on commit 7802049

Please sign in to comment.