diff --git a/setup.py b/setup.py
index a1d3f63..f7bb703 100644
--- a/setup.py
+++ b/setup.py
@@ -1,7 +1,6 @@
from setuptools import setup
import os
-
-VERSION = "0.1.0a10"
+from tsellm import __version__
def get_long_description():
@@ -14,7 +13,7 @@ def get_long_description():
setup(
name="tsellm",
- description="Interactive SQLite shell with LLM support",
+ description=__version__.__description__,
long_description=get_long_description(),
long_description_content_type="text/markdown",
author="Florents Tselai",
@@ -29,9 +28,9 @@ def get_long_description():
"Changelog": "https://github.com/Florents-Tselai/tsellm/releases",
},
license="BSD License",
- version=VERSION,
+ version=__version__.__version__,
packages=["tsellm"],
- install_requires=["llm", "setuptools", "pip"],
+ install_requires=["llm", "setuptools", "pip", "duckdb"],
extras_require={
"test": [
"pytest",
diff --git a/tests/test_tsellm.py b/tests/test_tsellm.py
index 51a1f51..5a15e87 100644
--- a/tests/test_tsellm.py
+++ b/tests/test_tsellm.py
@@ -1,15 +1,46 @@
-import llm.cli
-from sqlite_utils import Database
-from tsellm.cli import cli
+import sqlite3
+import tempfile
import unittest
-from test.support import captured_stdout, captured_stderr, captured_stdin, os_helper
+from pathlib import Path
+from test.support import captured_stdout, captured_stderr, captured_stdin
from test.support.os_helper import TESTFN, unlink
-from llm import models
-import sqlite3
+
+import duckdb
+import llm.cli
from llm import cli as llm_cli
+from tsellm.__version__ import __version__
+from tsellm.cli import (
+ cli,
+ TsellmConsole,
+ SQLiteConsole,
+ TsellmConsoleMixin,
+)
+
+
+def new_tempfile():
+ return Path(tempfile.mkdtemp()) / "test"
+
+
+def new_sqlite_file():
+ f = new_tempfile()
+ with sqlite3.connect(f) as db:
+ db.execute("SELECT 1")
+ return f
+
-class CommandLineInterface(unittest.TestCase):
+def new_duckdb_file():
+ f = new_tempfile()
+ con = duckdb.connect(f.__str__())
+ con.sql("SELECT 1")
+ return f
+
+
+class TsellmConsoleTest(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ llm_cli.set_default_model("markov")
+ llm_cli.set_default_embedding_model("hazo")
def _do_test(self, *args, expect_success=True):
with (
@@ -38,25 +69,132 @@ def expect_failure(self, *args):
self.assertEqual(out, "")
return err
+ def test_sniff_sqlite(self):
+ self.assertTrue(TsellmConsoleMixin().is_sqlite(new_sqlite_file()))
+
+ def test_sniff_duckdb(self):
+ self.assertTrue(TsellmConsoleMixin().is_duckdb(new_duckdb_file()))
+
+ def test_console_factory_sqlite(self):
+ s = new_sqlite_file()
+ self.assertTrue(TsellmConsoleMixin().is_sqlite(s))
+ obj = TsellmConsole.create_console(s)
+ self.assertIsInstance(obj, SQLiteConsole)
+
+ # def test_console_factory_duckdb(self):
+ # s = new_duckdb_file()
+ # self.assertTrue(TsellmConsole.is_duckdb(s))
+ # obj = TsellmConsole.create_console(s)
+ # self.assertIsInstance(obj, DuckDBConsole)
+
def test_cli_help(self):
out = self.expect_success("-h")
self.assertIn("usage: python -m tsellm", out)
def test_cli_version(self):
out = self.expect_success("-v")
+ self.assertIn(__version__, out)
+
+ def test_choose_db(self):
+ self.expect_failure("--sqlite", "--duckdb")
+
+ def test_deault_sqlite(self):
+ f = new_tempfile()
+ self.expect_success(str(f), "select 1")
+ self.assertTrue(TsellmConsoleMixin().is_sqlite(f))
+
+ MEMORY_DB_MSG = "Connected to :memory:"
+ PS1 = "tsellm> "
+ PS2 = "... "
+
+ def run_cli(self, *args, commands=()):
+ with (
+ captured_stdin() as stdin,
+ captured_stdout() as stdout,
+ captured_stderr() as stderr,
+ self.assertRaises(SystemExit) as cm
+ ):
+ for cmd in commands:
+ stdin.write(cmd + "\n")
+ stdin.seek(0)
+ cli(args)
+
+ out = stdout.getvalue()
+ err = stderr.getvalue()
+ self.assertEqual(cm.exception.code, 0,
+ f"Unexpected failure: {args=}\n{out}\n{err}")
+ return out, err
+
+ def test_interact(self):
+ out, err = self.run_cli()
+ self.assertIn(self.MEMORY_DB_MSG, err)
+ self.assertIn(self.MEMORY_DB_MSG, err)
+ self.assertTrue(out.endswith(self.PS1))
+ self.assertEqual(out.count(self.PS1), 1)
+ self.assertEqual(out.count(self.PS2), 0)
+
+ def test_interact_quit(self):
+ out, err = self.run_cli(commands=(".quit",))
+ self.assertIn(self.MEMORY_DB_MSG, err)
+ self.assertTrue(out.endswith(self.PS1))
+ self.assertEqual(out.count(self.PS1), 1)
+ self.assertEqual(out.count(self.PS2), 0)
+
+ def test_interact_version(self):
+ out, err = self.run_cli(commands=(".version",))
+ self.assertIn(self.MEMORY_DB_MSG, err)
+ self.assertIn(sqlite3.sqlite_version + "\n", out)
+ self.assertTrue(out.endswith(self.PS1))
+ self.assertEqual(out.count(self.PS1), 2)
+ self.assertEqual(out.count(self.PS2), 0)
self.assertIn(sqlite3.sqlite_version, out)
+ def test_interact_valid_sql(self):
+ out, err = self.run_cli(commands=("SELECT 1;",))
+ self.assertIn(self.MEMORY_DB_MSG, err)
+ self.assertIn("(1,)\n", out)
+ self.assertTrue(out.endswith(self.PS1))
+ self.assertEqual(out.count(self.PS1), 2)
+ self.assertEqual(out.count(self.PS2), 0)
+
+ def test_interact_incomplete_multiline_sql(self):
+ out, err = self.run_cli(commands=("SELECT 1",))
+ self.assertIn(self.MEMORY_DB_MSG, err)
+ self.assertTrue(out.endswith(self.PS2))
+ self.assertEqual(out.count(self.PS1), 1)
+ self.assertEqual(out.count(self.PS2), 1)
+
+ def test_interact_valid_multiline_sql(self):
+ out, err = self.run_cli(commands=("SELECT 1\n;",))
+ self.assertIn(self.MEMORY_DB_MSG, err)
+ self.assertIn(self.PS2, out)
+ self.assertIn("(1,)\n", out)
+ self.assertTrue(out.endswith(self.PS1))
+ self.assertEqual(out.count(self.PS1), 2)
+ self.assertEqual(out.count(self.PS2), 1)
+
+
+class InMemorySQLiteTest(TsellmConsoleTest):
+ path_args = None
+
+ def setUp(self):
+ super().setUp()
+ self.path_args = (
+ "--sqlite",
+ ":memory:",
+ )
+
def test_cli_execute_sql(self):
- out = self.expect_success(":memory:", "select 1")
+ out = self.expect_success(*self.path_args, "select 1")
self.assertIn("(1,)", out)
def test_cli_execute_too_much_sql(self):
- stderr = self.expect_failure(":memory:", "select 1; select 2")
+ stderr = self.expect_failure(*self.path_args, "select 1; select 2")
err = "ProgrammingError: You can only execute one statement at a time"
self.assertIn(err, stderr)
def test_cli_execute_incomplete_sql(self):
- stderr = self.expect_failure(":memory:", "sel")
+ stderr = self.expect_failure(*self.path_args, "sel")
self.assertIn("OperationalError (SQLITE_ERROR)", stderr)
def test_cli_on_disk_db(self):
@@ -66,30 +204,26 @@ def test_cli_on_disk_db(self):
out = self.expect_success(TESTFN, "select count(t) from t")
self.assertIn("(0,)", out)
-
-class SQLiteLLMFunction(CommandLineInterface):
-
- def setUp(self):
- super().setUp()
- llm_cli.set_default_model("markov")
- llm_cli.set_default_embedding_model("hazo")
-
def assertMarkovResult(self, prompt, generated):
# Every word should be one of the original prompt (see https://github.com/simonw/llm-markov/blob/657ca504bcf9f0bfc1c6ee5fe838cde9a8976381/tests/test_llm_markov.py#L20)
for w in prompt.split(" "):
self.assertIn(w, generated)
def test_prompt_markov(self):
- out = self.expect_success(":memory:", "select prompt('hello world', 'markov')")
+ out = self.expect_success(
+ *self.path_args, "select prompt('hello world', 'markov')"
+ )
self.assertMarkovResult("hello world", out)
def test_prompt_default_markov(self):
self.assertEqual(llm_cli.get_default_model(), "markov")
- out = self.expect_success(":memory:", "select prompt('hello world')")
+ out = self.expect_success(*self.path_args, "select prompt('hello world')")
self.assertMarkovResult("hello world", out)
def test_embed_hazo(self):
- out = self.expect_success(":memory:", "select embed('hello world', 'hazo')")
+ out = self.expect_success(
+ *self.path_args, "select embed('hello world', 'hazo')"
+ )
self.assertEqual(
"('[5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]',)\n",
out,
@@ -97,16 +231,72 @@ def test_embed_hazo(self):
def test_embed_hazo_binary(self):
self.assertTrue(llm.get_embedding_model("hazo").supports_binary)
- self.expect_success(":memory:", "select embed(randomblob(16), 'hazo')")
+ self.expect_success(*self.path_args, "select embed(randomblob(16), 'hazo')")
+
+ def test_embed_default_hazo(self):
+ self.assertEqual(llm_cli.get_default_embedding_model(), "hazo")
+ out = self.expect_success(*self.path_args, "select embed('hello world')")
+ self.assertEqual(
+ "('[5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]',)\n",
+ out,
+ )
+
+
+class DefaultInMemorySQLiteTest(InMemorySQLiteTest):
+ """--sqlite is omitted and should be the default, so all test cases remain the same"""
+
+ def setUp(self):
+ super().setUp()
+ self.path_args = (":memory:",)
+class DiskSQLiteTest(InMemorySQLiteTest):
+ db_fp = None
+ path_args = ()
+
+ def setUp(self):
+ super().setUp()
+ self.db_fp = str(new_tempfile())
+ self.path_args = (
+ "--sqlite",
+ self.db_fp,
+ )
+
+ def test_embed_default_hazo_leaves_valid_db_behind(self):
+ # This should probably be called for all test cases
+ super().test_embed_default_hazo()
+ self.assertTrue(TsellmConsoleMixin().is_sqlite(self.db_fp))
+
+
+class InMemoryDuckDBTest(InMemorySQLiteTest):
+ def setUp(self):
+ super().setUp()
+ self.path_args = (
+ "--duckdb",
+ ":memory:",
+ )
+
+ def test_duckdb_execute(self):
+ out = self.expect_success(*self.path_args, "select 'Hello World!'")
+ self.assertIn("('Hello World!',)", out)
+
+ def test_cli_execute_incomplete_sql(self):
+ pass
+
+ def test_cli_execute_too_much_sql(self):
+ pass
+
def test_embed_default_hazo(self):
- self.assertEqual(llm_cli.get_default_embedding_model(), "hazo")
- out = self.expect_success(":memory:", "select embed('hello world')")
- self.assertEqual(
- "('[5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]',)\n",
- out,
- )
+ # See https://github.com/Florents-Tselai/tsellm/issues/24
+ pass
+
+ def test_prompt_default_markov(self):
+ # See https://github.com/Florents-Tselai/tsellm/issues/24
+ pass
+
+ def test_embed_hazo_binary(self):
+ # See https://github.com/Florents-Tselai/tsellm/issues/25
+ pass
if __name__ == "__main__":
diff --git a/tsellm/__main__.py b/tsellm/__main__.py
index 189a4fd..c1ea504 100644
--- a/tsellm/__main__.py
+++ b/tsellm/__main__.py
@@ -1,4 +1,5 @@
import sys
+
from .cli import cli
if __name__ == "__main__":
diff --git a/tsellm/__version__.py b/tsellm/__version__.py
new file mode 100644
index 0000000..22a9227
--- /dev/null
+++ b/tsellm/__version__.py
@@ -0,0 +1,3 @@
+__title__ = "tsellm"
+__description__ = "Use LLMs in SQLite and DuckDB"
+__version__ = "0.1.0a10"
diff --git a/tsellm/cli.py b/tsellm/cli.py
index aa320ee..36e16f5 100644
--- a/tsellm/cli.py
+++ b/tsellm/cli.py
@@ -1,41 +1,164 @@
import sqlite3
import sys
-
+from abc import ABC, abstractmethod
from argparse import ArgumentParser
from code import InteractiveConsole
+from dataclasses import dataclass, field
+from enum import Enum, auto
+from pathlib import Path
from textwrap import dedent
-from .core import _tsellm_init
+from typing import Union
+import duckdb
-def execute(c, sql, suppress_errors=True):
- """Helper that wraps execution of SQL code.
+from . import __version__
+from .core import (
+ _prompt_model,
+ _prompt_model_default,
+ _embed_model,
+ _embed_model_default,
+)
- This is used both by the REPL and by direct execution from the CLI.
- 'c' may be a cursor or a connection.
- 'sql' is the SQL string to execute.
- """
+class DatabaseType(Enum):
+ SQLITE = auto()
+ DUCKDB = auto()
+ UNKNOWN = auto()
+ FILE_NOT_FOUND = auto()
+ ERROR = auto()
- try:
- for row in c.execute(sql):
- print(row)
- except sqlite3.Error as e:
- tp = type(e).__name__
+
+sys.ps1 = "tsellm> "
+sys.ps2 = " ... "
+
+
+class TsellmConsoleMixin(InteractiveConsole):
+ def is_sqlite(self, path):
try:
- print(f"{tp} ({e.sqlite_errorname}): {e}", file=sys.stderr)
- except AttributeError:
- print(f"{tp}: {e}", file=sys.stderr)
- if not suppress_errors:
- sys.exit(1)
+ with sqlite3.connect(path) as conn:
+ conn.execute("SELECT 1")
+ return True
+ except:
+ return False
+ def is_duckdb(self, path):
+ try:
+ con = duckdb.connect(path.__str__())
+ con.sql("SELECT 1")
+ return True
+ except:
+ return False
-class SqliteInteractiveConsole(InteractiveConsole):
- """A simple SQLite REPL."""
+ def sniff_db(self, path):
+ """
+ Sniffs if the path is a SQLite or DuckDB database.
- def __init__(self, connection):
- super().__init__()
- self._con = connection
- self._cur = connection.cursor()
+ Args:
+ path (str): The file path to check.
+
+ Returns:
+ DatabaseType: The type of database (DatabaseType.SQLITE, DatabaseType.DUCKDB,
+ DatabaseType.UNKNOWN, DatabaseType.FILE_NOT_FOUND, DatabaseType.ERROR).
+ """
+
+ if TsellmConsole.is_sqlite(path):
+ return DatabaseType.SQLITE
+ if TsellmConsole.is_duckdb(path):
+ return DatabaseType.DUCKDB
+ return DatabaseType.UNKNOWN
+
+
+@dataclass
+class TsellmConsole(InteractiveConsole, ABC):
+ _TSELLM_CONFIG_SQL = """
+-- tsellm configuration table
+-- need to be taken care of accross migrations and versions.
+
+CREATE TABLE IF NOT EXISTS __tsellm (
+x text
+);
+
+"""
+
+ _functions = [
+ ("prompt", 2, _prompt_model, False),
+ ("prompt", 1, _prompt_model_default, False),
+ ("embed", 2, _embed_model, False),
+ ("embed", 1, _embed_model_default, False),
+ ]
+
+ error_class = None
+ db_type: str = field(init=False)
+ connection: Union[sqlite3.Connection, duckdb.DuckDBPyConnection] = field(init=False)
+
+ @property
+ def tsellm_version(self) -> str:
+ return __version__.__version__
+
+ @property
+ def eofkey(self):
+ if sys.platform == "win32" and "idlelib.run" not in sys.modules:
+ return "CTRL-Z"
+ else:
+ return "CTRL-D"
+
+ @property
+ def db_name(self):
+ return self.path
+
+ @property
+ def banner(self) -> str:
+ return dedent(
+ f"""
+ tsellm shell version {self.tsellm_version}, running on {self.db_type} version {self.db_version}
+ Connected to {self.db_name}
+
+ Each command will be run using execute() on the cursor.
+ Type ".help" for more information; type ".quit" or {self.eofkey} to quit.
+ """
+ ).strip()
+
+ @property
+ @abstractmethod
+ def db_version(self) -> str:
+ pass
+
+ @property
+ @abstractmethod
+ def is_valid_db(self) -> bool:
+ pass
+
+ @abstractmethod
+ def complete_statement(self, source) -> bool:
+ pass
+
+ @property
+ def version(self):
+ return " ".join([
+ "tsellm version",
+ self.tsellm_version,
+ self.db_type,
+ "version",
+ self.db_version]
+ )
+
+ def load(self):
+ self.execute(self._TSELLM_CONFIG_SQL)
+ for func_name, n_args, py_func, deterministic in self._functions:
+ self.connection.create_function(func_name, n_args, py_func)
+
+ @staticmethod
+ def create_console(path):
+ if TsellmConsoleMixin().is_duckdb(path):
+ return DuckDBConsole(path)
+ if TsellmConsoleMixin().is_sqlite(path):
+ return SQLiteConsole(path)
+ else:
+ raise ValueError(f"Database type {path} not supported")
+
+ @abstractmethod
+ def execute(self, sql, suppress_errors=True):
+ pass
def runsource(self, source, filename="", symbol="single"):
"""Override runsource, the core of the InteractiveConsole REPL.
@@ -45,19 +168,125 @@ def runsource(self, source, filename="", symbol="single"):
"""
match source:
case ".version":
- print(f"{sqlite3.sqlite_version}")
+ print(f"{self.version}")
case ".help":
print("Enter SQL code and press enter.")
case ".quit":
sys.exit(0)
case _:
- if not sqlite3.complete_statement(source):
+ if not self.complete_statement(source):
return True
- execute(self._cur, source)
+ self.execute(source)
return False
+ @abstractmethod
+ def connect(self):
+ pass
-def cli(*args):
+ def __post_init__(self):
+ super().__init__()
+ self.connect()
+ self._cur = self.connection.cursor()
+ self.load()
+
+
+@dataclass
+class SQLiteConsole(TsellmConsole):
+ db_type = "SQLite"
+
+ def connect(self):
+ self.connection = sqlite3.connect(self.path, isolation_level=None)
+
+ path: Union[Path, str, sqlite3.Connection, duckdb.DuckDBPyConnection]
+ error_class = sqlite3.Error
+
+ def complete_statement(self, source) -> bool:
+ return sqlite3.complete_statement(source)
+
+ @property
+ def is_valid_db(self) -> bool:
+ pass
+
+ def execute(self, sql, suppress_errors=True):
+ """Helper that wraps execution of SQL code.
+
+ This is used both by the REPL and by direct execution from the CLI.
+
+ 'c' may be a cursor or a connection.
+ 'sql' is the SQL string to execute.
+ """
+
+ try:
+ for row in self._cur.execute(sql):
+ print(row)
+ except self.error_class as e:
+ tp = type(e).__name__
+ try:
+ print(f"{tp} ({e.sqlite_errorname}): {e}", file=sys.stderr)
+ except AttributeError:
+ print(f"{tp}: {e}", file=sys.stderr)
+ if not suppress_errors:
+ sys.exit(1)
+
+ @property
+ def db_version(self):
+ return sqlite3.sqlite_version
+
+
+@dataclass
+class DuckDBConsole(TsellmConsole):
+ db_type = "DuckDB"
+ path: Union[Path, str, sqlite3.Connection, duckdb.DuckDBPyConnection]
+
+ def complete_statement(self, source) -> bool:
+ return sqlite3.complete_statement(source)
+
+ @property
+ def is_valid_db(self) -> bool:
+ pass
+
+ error_class = sqlite3.Error
+
+ _functions = [
+ ("prompt", 2, _prompt_model, False),
+ ("embed", 2, _embed_model, False),
+ ]
+
+ def connect(self):
+ self.connection = duckdb.connect(self.path)
+
+ def load(self):
+ self.execute(self._TSELLM_CONFIG_SQL)
+ for func_name, _, py_func, _ in self._functions:
+ self.connection.create_function(func_name, py_func)
+
+ @property
+ def db_version(self):
+ return duckdb.__version__
+
+ def execute(self, sql, suppress_errors=True):
+ """Helper that wraps execution of SQL code.
+
+ This is used both by the REPL and by direct execution from the CLI.
+
+ 'c' may be a cursor or a connection.
+ 'sql' is the SQL string to execute.
+ """
+
+ try:
+ for row in self.connection.execute(sql).fetchall():
+ print(row)
+ except self.error_class as e:
+ tp = type(e).__name__
+ try:
+ print(f"{tp} ({e.sqlite_errorname}): {e}", file=sys.stderr)
+ except AttributeError:
+ print(f"{tp}: {e}", file=sys.stderr)
+ if not suppress_errors:
+ sys.exit(1)
+
+
+def make_parser():
parser = ArgumentParser(
description="tsellm sqlite3 CLI",
prog="python -m tsellm",
@@ -68,7 +297,7 @@ def cli(*args):
default=":memory:",
nargs="?",
help=(
- "SQLite database to open (defaults to ':memory:'). "
+ "SQLite/DuckDB database to open (defaults to SQLite ':memory:'). "
"A new database is created if the file does not previously exist."
),
)
@@ -78,52 +307,61 @@ def cli(*args):
nargs="?",
help=("An SQL query to execute. " "Any returned rows are printed to stdout."),
)
+
+ # Create a mutually exclusive group
+ group = parser.add_mutually_exclusive_group()
+
+ # Add the SQLite argument
+ group.add_argument(
+ "--sqlite",
+ action="store_true",
+ default=False, # Change the default to False to ensure only one can be true
+ help="SQLite mode",
+ )
+
+ # Add the DuckDB argument
+ group.add_argument(
+ "--duckdb",
+ action="store_true",
+ default=False, # Set the default to False
+ help="DuckDB mode",
+ )
+
parser.add_argument(
"-v",
"--version",
action="version",
- version=f"SQLite version {sqlite3.sqlite_version}",
+ version=f"tsellm version {__version__.__version__}",
help="Print underlying SQLite library version",
)
- args = parser.parse_args(*args)
-
- if args.filename == ":memory:":
- db_name = "a transient in-memory database"
- else:
- db_name = repr(args.filename)
-
- # Prepare REPL banner and prompts.
- if sys.platform == "win32" and "idlelib.run" not in sys.modules:
- eofkey = "CTRL-Z"
- else:
- eofkey = "CTRL-D"
- banner = dedent(
- f"""
- tsellm shell, running on SQLite version {sqlite3.sqlite_version}
- Connected to {db_name}
+ return parser
+
+
+def cli(*args):
+ args = make_parser().parse_args(*args)
+
+ if args.sqlite and args.duckdb:
+ raise ValueError("Only one of --sqlite and --duckdb can be specified.")
+
+ if (not args.sqlite) and (not args.duckdb) and args.filename == ":memory:":
+ args.sqlite = True
+ args.duckdb = False
+
+ console = (
+ DuckDBConsole(args.filename) if args.duckdb else SQLiteConsole(args.filename)
+ )
- Each command will be run using execute() on the cursor.
- Type ".help" for more information; type ".quit" or {eofkey} to quit.
- """
- ).strip()
- sys.ps1 = "tsellm> "
- sys.ps2 = " ... "
-
- con = sqlite3.connect(args.filename, isolation_level=None)
- _tsellm_init(con)
try:
if args.sql:
# SQL statement provided on the command-line; execute it directly.
- execute(con, args.sql, suppress_errors=False)
+ console.execute(args.sql, suppress_errors=False)
else:
- # No SQL provided; start the REPL.
- console = SqliteInteractiveConsole(con)
try:
import readline
except ImportError:
pass
- console.interact(banner, exitmsg="")
+ console.interact(console.banner, exitmsg="")
finally:
- con.close()
+ console.connection.close()
sys.exit(0)
diff --git a/tsellm/core.py b/tsellm/core.py
index f844207..5f87e1f 100644
--- a/tsellm/core.py
+++ b/tsellm/core.py
@@ -1,6 +1,6 @@
-import llm
import json
+import llm
from llm import cli as llm_cli
TSELLM_CONFIG_SQL = """
@@ -14,19 +14,19 @@
"""
-def _prompt_model(prompt, model):
+def _prompt_model(prompt: str, model: str) -> str:
return llm.get_model(model).prompt(prompt).text()
-def _prompt_model_default(prompt):
+def _prompt_model_default(prompt: str) -> str:
return llm.get_model("markov").prompt(prompt).text()
-def _embed_model(text, model):
+def _embed_model(text: str, model: str) -> str:
return json.dumps(llm.get_embedding_model(model).embed(text))
-def _embed_model_default(text):
+def _embed_model_default(text: str) -> str:
return json.dumps(
llm.get_embedding_model(llm_cli.get_default_embedding_model()).embed(text)
)