From 7802049865c85a90fdd70dab8cefa78073e1184f Mon Sep 17 00:00:00 2001 From: Florents Tselai Date: Sun, 7 Jul 2024 12:10:05 +0300 Subject: [PATCH] Add support for DuckDB (#27) --- setup.py | 9 +- tests/test_tsellm.py | 246 +++++++++++++++++++++++++---- tsellm/__main__.py | 1 + tsellm/__version__.py | 3 + tsellm/cli.py | 358 +++++++++++++++++++++++++++++++++++------- tsellm/core.py | 10 +- 6 files changed, 529 insertions(+), 98 deletions(-) create mode 100644 tsellm/__version__.py diff --git a/setup.py b/setup.py index a1d3f63..f7bb703 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,6 @@ from setuptools import setup import os - -VERSION = "0.1.0a10" +from tsellm import __version__ def get_long_description(): @@ -14,7 +13,7 @@ def get_long_description(): setup( name="tsellm", - description="Interactive SQLite shell with LLM support", + description=__version__.__description__, long_description=get_long_description(), long_description_content_type="text/markdown", author="Florents Tselai", @@ -29,9 +28,9 @@ def get_long_description(): "Changelog": "https://github.com/Florents-Tselai/tsellm/releases", }, license="BSD License", - version=VERSION, + version=__version__.__version__, packages=["tsellm"], - install_requires=["llm", "setuptools", "pip"], + install_requires=["llm", "setuptools", "pip", "duckdb"], extras_require={ "test": [ "pytest", diff --git a/tests/test_tsellm.py b/tests/test_tsellm.py index 51a1f51..5a15e87 100644 --- a/tests/test_tsellm.py +++ b/tests/test_tsellm.py @@ -1,15 +1,46 @@ -import llm.cli -from sqlite_utils import Database -from tsellm.cli import cli +import sqlite3 +import tempfile import unittest -from test.support import captured_stdout, captured_stderr, captured_stdin, os_helper +from pathlib import Path +from test.support import captured_stdout, captured_stderr, captured_stdin from test.support.os_helper import TESTFN, unlink -from llm import models -import sqlite3 + +import duckdb +import llm.cli from llm import cli as llm_cli +from tsellm.__version__ import __version__ +from tsellm.cli import ( + cli, + TsellmConsole, + SQLiteConsole, + TsellmConsoleMixin, +) + + +def new_tempfile(): + return Path(tempfile.mkdtemp()) / "test" + + +def new_sqlite_file(): + f = new_tempfile() + with sqlite3.connect(f) as db: + db.execute("SELECT 1") + return f + -class CommandLineInterface(unittest.TestCase): +def new_duckdb_file(): + f = new_tempfile() + con = duckdb.connect(f.__str__()) + con.sql("SELECT 1") + return f + + +class TsellmConsoleTest(unittest.TestCase): + def setUp(self): + super().setUp() + llm_cli.set_default_model("markov") + llm_cli.set_default_embedding_model("hazo") def _do_test(self, *args, expect_success=True): with ( @@ -38,25 +69,132 @@ def expect_failure(self, *args): self.assertEqual(out, "") return err + def test_sniff_sqlite(self): + self.assertTrue(TsellmConsoleMixin().is_sqlite(new_sqlite_file())) + + def test_sniff_duckdb(self): + self.assertTrue(TsellmConsoleMixin().is_duckdb(new_duckdb_file())) + + def test_console_factory_sqlite(self): + s = new_sqlite_file() + self.assertTrue(TsellmConsoleMixin().is_sqlite(s)) + obj = TsellmConsole.create_console(s) + self.assertIsInstance(obj, SQLiteConsole) + + # def test_console_factory_duckdb(self): + # s = new_duckdb_file() + # self.assertTrue(TsellmConsole.is_duckdb(s)) + # obj = TsellmConsole.create_console(s) + # self.assertIsInstance(obj, DuckDBConsole) + def test_cli_help(self): out = self.expect_success("-h") self.assertIn("usage: python -m tsellm", out) def test_cli_version(self): out = self.expect_success("-v") + self.assertIn(__version__, out) + + def test_choose_db(self): + self.expect_failure("--sqlite", "--duckdb") + + def test_deault_sqlite(self): + f = new_tempfile() + self.expect_success(str(f), "select 1") + self.assertTrue(TsellmConsoleMixin().is_sqlite(f)) + + MEMORY_DB_MSG = "Connected to :memory:" + PS1 = "tsellm> " + PS2 = "... " + + def run_cli(self, *args, commands=()): + with ( + captured_stdin() as stdin, + captured_stdout() as stdout, + captured_stderr() as stderr, + self.assertRaises(SystemExit) as cm + ): + for cmd in commands: + stdin.write(cmd + "\n") + stdin.seek(0) + cli(args) + + out = stdout.getvalue() + err = stderr.getvalue() + self.assertEqual(cm.exception.code, 0, + f"Unexpected failure: {args=}\n{out}\n{err}") + return out, err + + def test_interact(self): + out, err = self.run_cli() + self.assertIn(self.MEMORY_DB_MSG, err) + self.assertIn(self.MEMORY_DB_MSG, err) + self.assertTrue(out.endswith(self.PS1)) + self.assertEqual(out.count(self.PS1), 1) + self.assertEqual(out.count(self.PS2), 0) + + def test_interact_quit(self): + out, err = self.run_cli(commands=(".quit",)) + self.assertIn(self.MEMORY_DB_MSG, err) + self.assertTrue(out.endswith(self.PS1)) + self.assertEqual(out.count(self.PS1), 1) + self.assertEqual(out.count(self.PS2), 0) + + def test_interact_version(self): + out, err = self.run_cli(commands=(".version",)) + self.assertIn(self.MEMORY_DB_MSG, err) + self.assertIn(sqlite3.sqlite_version + "\n", out) + self.assertTrue(out.endswith(self.PS1)) + self.assertEqual(out.count(self.PS1), 2) + self.assertEqual(out.count(self.PS2), 0) self.assertIn(sqlite3.sqlite_version, out) + def test_interact_valid_sql(self): + out, err = self.run_cli(commands=("SELECT 1;",)) + self.assertIn(self.MEMORY_DB_MSG, err) + self.assertIn("(1,)\n", out) + self.assertTrue(out.endswith(self.PS1)) + self.assertEqual(out.count(self.PS1), 2) + self.assertEqual(out.count(self.PS2), 0) + + def test_interact_incomplete_multiline_sql(self): + out, err = self.run_cli(commands=("SELECT 1",)) + self.assertIn(self.MEMORY_DB_MSG, err) + self.assertTrue(out.endswith(self.PS2)) + self.assertEqual(out.count(self.PS1), 1) + self.assertEqual(out.count(self.PS2), 1) + + def test_interact_valid_multiline_sql(self): + out, err = self.run_cli(commands=("SELECT 1\n;",)) + self.assertIn(self.MEMORY_DB_MSG, err) + self.assertIn(self.PS2, out) + self.assertIn("(1,)\n", out) + self.assertTrue(out.endswith(self.PS1)) + self.assertEqual(out.count(self.PS1), 2) + self.assertEqual(out.count(self.PS2), 1) + + +class InMemorySQLiteTest(TsellmConsoleTest): + path_args = None + + def setUp(self): + super().setUp() + self.path_args = ( + "--sqlite", + ":memory:", + ) + def test_cli_execute_sql(self): - out = self.expect_success(":memory:", "select 1") + out = self.expect_success(*self.path_args, "select 1") self.assertIn("(1,)", out) def test_cli_execute_too_much_sql(self): - stderr = self.expect_failure(":memory:", "select 1; select 2") + stderr = self.expect_failure(*self.path_args, "select 1; select 2") err = "ProgrammingError: You can only execute one statement at a time" self.assertIn(err, stderr) def test_cli_execute_incomplete_sql(self): - stderr = self.expect_failure(":memory:", "sel") + stderr = self.expect_failure(*self.path_args, "sel") self.assertIn("OperationalError (SQLITE_ERROR)", stderr) def test_cli_on_disk_db(self): @@ -66,30 +204,26 @@ def test_cli_on_disk_db(self): out = self.expect_success(TESTFN, "select count(t) from t") self.assertIn("(0,)", out) - -class SQLiteLLMFunction(CommandLineInterface): - - def setUp(self): - super().setUp() - llm_cli.set_default_model("markov") - llm_cli.set_default_embedding_model("hazo") - def assertMarkovResult(self, prompt, generated): # Every word should be one of the original prompt (see https://github.com/simonw/llm-markov/blob/657ca504bcf9f0bfc1c6ee5fe838cde9a8976381/tests/test_llm_markov.py#L20) for w in prompt.split(" "): self.assertIn(w, generated) def test_prompt_markov(self): - out = self.expect_success(":memory:", "select prompt('hello world', 'markov')") + out = self.expect_success( + *self.path_args, "select prompt('hello world', 'markov')" + ) self.assertMarkovResult("hello world", out) def test_prompt_default_markov(self): self.assertEqual(llm_cli.get_default_model(), "markov") - out = self.expect_success(":memory:", "select prompt('hello world')") + out = self.expect_success(*self.path_args, "select prompt('hello world')") self.assertMarkovResult("hello world", out) def test_embed_hazo(self): - out = self.expect_success(":memory:", "select embed('hello world', 'hazo')") + out = self.expect_success( + *self.path_args, "select embed('hello world', 'hazo')" + ) self.assertEqual( "('[5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]',)\n", out, @@ -97,16 +231,72 @@ def test_embed_hazo(self): def test_embed_hazo_binary(self): self.assertTrue(llm.get_embedding_model("hazo").supports_binary) - self.expect_success(":memory:", "select embed(randomblob(16), 'hazo')") + self.expect_success(*self.path_args, "select embed(randomblob(16), 'hazo')") + + def test_embed_default_hazo(self): + self.assertEqual(llm_cli.get_default_embedding_model(), "hazo") + out = self.expect_success(*self.path_args, "select embed('hello world')") + self.assertEqual( + "('[5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]',)\n", + out, + ) + + +class DefaultInMemorySQLiteTest(InMemorySQLiteTest): + """--sqlite is omitted and should be the default, so all test cases remain the same""" + + def setUp(self): + super().setUp() + self.path_args = (":memory:",) +class DiskSQLiteTest(InMemorySQLiteTest): + db_fp = None + path_args = () + + def setUp(self): + super().setUp() + self.db_fp = str(new_tempfile()) + self.path_args = ( + "--sqlite", + self.db_fp, + ) + + def test_embed_default_hazo_leaves_valid_db_behind(self): + # This should probably be called for all test cases + super().test_embed_default_hazo() + self.assertTrue(TsellmConsoleMixin().is_sqlite(self.db_fp)) + + +class InMemoryDuckDBTest(InMemorySQLiteTest): + def setUp(self): + super().setUp() + self.path_args = ( + "--duckdb", + ":memory:", + ) + + def test_duckdb_execute(self): + out = self.expect_success(*self.path_args, "select 'Hello World!'") + self.assertIn("('Hello World!',)", out) + + def test_cli_execute_incomplete_sql(self): + pass + + def test_cli_execute_too_much_sql(self): + pass + def test_embed_default_hazo(self): - self.assertEqual(llm_cli.get_default_embedding_model(), "hazo") - out = self.expect_success(":memory:", "select embed('hello world')") - self.assertEqual( - "('[5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]',)\n", - out, - ) + # See https://github.com/Florents-Tselai/tsellm/issues/24 + pass + + def test_prompt_default_markov(self): + # See https://github.com/Florents-Tselai/tsellm/issues/24 + pass + + def test_embed_hazo_binary(self): + # See https://github.com/Florents-Tselai/tsellm/issues/25 + pass if __name__ == "__main__": diff --git a/tsellm/__main__.py b/tsellm/__main__.py index 189a4fd..c1ea504 100644 --- a/tsellm/__main__.py +++ b/tsellm/__main__.py @@ -1,4 +1,5 @@ import sys + from .cli import cli if __name__ == "__main__": diff --git a/tsellm/__version__.py b/tsellm/__version__.py new file mode 100644 index 0000000..22a9227 --- /dev/null +++ b/tsellm/__version__.py @@ -0,0 +1,3 @@ +__title__ = "tsellm" +__description__ = "Use LLMs in SQLite and DuckDB" +__version__ = "0.1.0a10" diff --git a/tsellm/cli.py b/tsellm/cli.py index aa320ee..36e16f5 100644 --- a/tsellm/cli.py +++ b/tsellm/cli.py @@ -1,41 +1,164 @@ import sqlite3 import sys - +from abc import ABC, abstractmethod from argparse import ArgumentParser from code import InteractiveConsole +from dataclasses import dataclass, field +from enum import Enum, auto +from pathlib import Path from textwrap import dedent -from .core import _tsellm_init +from typing import Union +import duckdb -def execute(c, sql, suppress_errors=True): - """Helper that wraps execution of SQL code. +from . import __version__ +from .core import ( + _prompt_model, + _prompt_model_default, + _embed_model, + _embed_model_default, +) - This is used both by the REPL and by direct execution from the CLI. - 'c' may be a cursor or a connection. - 'sql' is the SQL string to execute. - """ +class DatabaseType(Enum): + SQLITE = auto() + DUCKDB = auto() + UNKNOWN = auto() + FILE_NOT_FOUND = auto() + ERROR = auto() - try: - for row in c.execute(sql): - print(row) - except sqlite3.Error as e: - tp = type(e).__name__ + +sys.ps1 = "tsellm> " +sys.ps2 = " ... " + + +class TsellmConsoleMixin(InteractiveConsole): + def is_sqlite(self, path): try: - print(f"{tp} ({e.sqlite_errorname}): {e}", file=sys.stderr) - except AttributeError: - print(f"{tp}: {e}", file=sys.stderr) - if not suppress_errors: - sys.exit(1) + with sqlite3.connect(path) as conn: + conn.execute("SELECT 1") + return True + except: + return False + def is_duckdb(self, path): + try: + con = duckdb.connect(path.__str__()) + con.sql("SELECT 1") + return True + except: + return False -class SqliteInteractiveConsole(InteractiveConsole): - """A simple SQLite REPL.""" + def sniff_db(self, path): + """ + Sniffs if the path is a SQLite or DuckDB database. - def __init__(self, connection): - super().__init__() - self._con = connection - self._cur = connection.cursor() + Args: + path (str): The file path to check. + + Returns: + DatabaseType: The type of database (DatabaseType.SQLITE, DatabaseType.DUCKDB, + DatabaseType.UNKNOWN, DatabaseType.FILE_NOT_FOUND, DatabaseType.ERROR). + """ + + if TsellmConsole.is_sqlite(path): + return DatabaseType.SQLITE + if TsellmConsole.is_duckdb(path): + return DatabaseType.DUCKDB + return DatabaseType.UNKNOWN + + +@dataclass +class TsellmConsole(InteractiveConsole, ABC): + _TSELLM_CONFIG_SQL = """ +-- tsellm configuration table +-- need to be taken care of accross migrations and versions. + +CREATE TABLE IF NOT EXISTS __tsellm ( +x text +); + +""" + + _functions = [ + ("prompt", 2, _prompt_model, False), + ("prompt", 1, _prompt_model_default, False), + ("embed", 2, _embed_model, False), + ("embed", 1, _embed_model_default, False), + ] + + error_class = None + db_type: str = field(init=False) + connection: Union[sqlite3.Connection, duckdb.DuckDBPyConnection] = field(init=False) + + @property + def tsellm_version(self) -> str: + return __version__.__version__ + + @property + def eofkey(self): + if sys.platform == "win32" and "idlelib.run" not in sys.modules: + return "CTRL-Z" + else: + return "CTRL-D" + + @property + def db_name(self): + return self.path + + @property + def banner(self) -> str: + return dedent( + f""" + tsellm shell version {self.tsellm_version}, running on {self.db_type} version {self.db_version} + Connected to {self.db_name} + + Each command will be run using execute() on the cursor. + Type ".help" for more information; type ".quit" or {self.eofkey} to quit. + """ + ).strip() + + @property + @abstractmethod + def db_version(self) -> str: + pass + + @property + @abstractmethod + def is_valid_db(self) -> bool: + pass + + @abstractmethod + def complete_statement(self, source) -> bool: + pass + + @property + def version(self): + return " ".join([ + "tsellm version", + self.tsellm_version, + self.db_type, + "version", + self.db_version] + ) + + def load(self): + self.execute(self._TSELLM_CONFIG_SQL) + for func_name, n_args, py_func, deterministic in self._functions: + self.connection.create_function(func_name, n_args, py_func) + + @staticmethod + def create_console(path): + if TsellmConsoleMixin().is_duckdb(path): + return DuckDBConsole(path) + if TsellmConsoleMixin().is_sqlite(path): + return SQLiteConsole(path) + else: + raise ValueError(f"Database type {path} not supported") + + @abstractmethod + def execute(self, sql, suppress_errors=True): + pass def runsource(self, source, filename="", symbol="single"): """Override runsource, the core of the InteractiveConsole REPL. @@ -45,19 +168,125 @@ def runsource(self, source, filename="", symbol="single"): """ match source: case ".version": - print(f"{sqlite3.sqlite_version}") + print(f"{self.version}") case ".help": print("Enter SQL code and press enter.") case ".quit": sys.exit(0) case _: - if not sqlite3.complete_statement(source): + if not self.complete_statement(source): return True - execute(self._cur, source) + self.execute(source) return False + @abstractmethod + def connect(self): + pass -def cli(*args): + def __post_init__(self): + super().__init__() + self.connect() + self._cur = self.connection.cursor() + self.load() + + +@dataclass +class SQLiteConsole(TsellmConsole): + db_type = "SQLite" + + def connect(self): + self.connection = sqlite3.connect(self.path, isolation_level=None) + + path: Union[Path, str, sqlite3.Connection, duckdb.DuckDBPyConnection] + error_class = sqlite3.Error + + def complete_statement(self, source) -> bool: + return sqlite3.complete_statement(source) + + @property + def is_valid_db(self) -> bool: + pass + + def execute(self, sql, suppress_errors=True): + """Helper that wraps execution of SQL code. + + This is used both by the REPL and by direct execution from the CLI. + + 'c' may be a cursor or a connection. + 'sql' is the SQL string to execute. + """ + + try: + for row in self._cur.execute(sql): + print(row) + except self.error_class as e: + tp = type(e).__name__ + try: + print(f"{tp} ({e.sqlite_errorname}): {e}", file=sys.stderr) + except AttributeError: + print(f"{tp}: {e}", file=sys.stderr) + if not suppress_errors: + sys.exit(1) + + @property + def db_version(self): + return sqlite3.sqlite_version + + +@dataclass +class DuckDBConsole(TsellmConsole): + db_type = "DuckDB" + path: Union[Path, str, sqlite3.Connection, duckdb.DuckDBPyConnection] + + def complete_statement(self, source) -> bool: + return sqlite3.complete_statement(source) + + @property + def is_valid_db(self) -> bool: + pass + + error_class = sqlite3.Error + + _functions = [ + ("prompt", 2, _prompt_model, False), + ("embed", 2, _embed_model, False), + ] + + def connect(self): + self.connection = duckdb.connect(self.path) + + def load(self): + self.execute(self._TSELLM_CONFIG_SQL) + for func_name, _, py_func, _ in self._functions: + self.connection.create_function(func_name, py_func) + + @property + def db_version(self): + return duckdb.__version__ + + def execute(self, sql, suppress_errors=True): + """Helper that wraps execution of SQL code. + + This is used both by the REPL and by direct execution from the CLI. + + 'c' may be a cursor or a connection. + 'sql' is the SQL string to execute. + """ + + try: + for row in self.connection.execute(sql).fetchall(): + print(row) + except self.error_class as e: + tp = type(e).__name__ + try: + print(f"{tp} ({e.sqlite_errorname}): {e}", file=sys.stderr) + except AttributeError: + print(f"{tp}: {e}", file=sys.stderr) + if not suppress_errors: + sys.exit(1) + + +def make_parser(): parser = ArgumentParser( description="tsellm sqlite3 CLI", prog="python -m tsellm", @@ -68,7 +297,7 @@ def cli(*args): default=":memory:", nargs="?", help=( - "SQLite database to open (defaults to ':memory:'). " + "SQLite/DuckDB database to open (defaults to SQLite ':memory:'). " "A new database is created if the file does not previously exist." ), ) @@ -78,52 +307,61 @@ def cli(*args): nargs="?", help=("An SQL query to execute. " "Any returned rows are printed to stdout."), ) + + # Create a mutually exclusive group + group = parser.add_mutually_exclusive_group() + + # Add the SQLite argument + group.add_argument( + "--sqlite", + action="store_true", + default=False, # Change the default to False to ensure only one can be true + help="SQLite mode", + ) + + # Add the DuckDB argument + group.add_argument( + "--duckdb", + action="store_true", + default=False, # Set the default to False + help="DuckDB mode", + ) + parser.add_argument( "-v", "--version", action="version", - version=f"SQLite version {sqlite3.sqlite_version}", + version=f"tsellm version {__version__.__version__}", help="Print underlying SQLite library version", ) - args = parser.parse_args(*args) - - if args.filename == ":memory:": - db_name = "a transient in-memory database" - else: - db_name = repr(args.filename) - - # Prepare REPL banner and prompts. - if sys.platform == "win32" and "idlelib.run" not in sys.modules: - eofkey = "CTRL-Z" - else: - eofkey = "CTRL-D" - banner = dedent( - f""" - tsellm shell, running on SQLite version {sqlite3.sqlite_version} - Connected to {db_name} + return parser + + +def cli(*args): + args = make_parser().parse_args(*args) + + if args.sqlite and args.duckdb: + raise ValueError("Only one of --sqlite and --duckdb can be specified.") + + if (not args.sqlite) and (not args.duckdb) and args.filename == ":memory:": + args.sqlite = True + args.duckdb = False + + console = ( + DuckDBConsole(args.filename) if args.duckdb else SQLiteConsole(args.filename) + ) - Each command will be run using execute() on the cursor. - Type ".help" for more information; type ".quit" or {eofkey} to quit. - """ - ).strip() - sys.ps1 = "tsellm> " - sys.ps2 = " ... " - - con = sqlite3.connect(args.filename, isolation_level=None) - _tsellm_init(con) try: if args.sql: # SQL statement provided on the command-line; execute it directly. - execute(con, args.sql, suppress_errors=False) + console.execute(args.sql, suppress_errors=False) else: - # No SQL provided; start the REPL. - console = SqliteInteractiveConsole(con) try: import readline except ImportError: pass - console.interact(banner, exitmsg="") + console.interact(console.banner, exitmsg="") finally: - con.close() + console.connection.close() sys.exit(0) diff --git a/tsellm/core.py b/tsellm/core.py index f844207..5f87e1f 100644 --- a/tsellm/core.py +++ b/tsellm/core.py @@ -1,6 +1,6 @@ -import llm import json +import llm from llm import cli as llm_cli TSELLM_CONFIG_SQL = """ @@ -14,19 +14,19 @@ """ -def _prompt_model(prompt, model): +def _prompt_model(prompt: str, model: str) -> str: return llm.get_model(model).prompt(prompt).text() -def _prompt_model_default(prompt): +def _prompt_model_default(prompt: str) -> str: return llm.get_model("markov").prompt(prompt).text() -def _embed_model(text, model): +def _embed_model(text: str, model: str) -> str: return json.dumps(llm.get_embedding_model(model).embed(text)) -def _embed_model_default(text): +def _embed_model_default(text: str) -> str: return json.dumps( llm.get_embedding_model(llm_cli.get_default_embedding_model()).embed(text) )