-
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c885f4a
commit 92bbc37
Showing
8 changed files
with
225 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from sqlite_utils import Database | ||
from sqlite_utils.utils import sqlite3 | ||
import pytest | ||
|
||
|
||
def pytest_configure(config): | ||
import sys | ||
|
||
sys._called_from_test = True | ||
|
||
|
||
@pytest.fixture | ||
def fresh_db(): | ||
return Database(memory=True) | ||
|
||
|
||
@pytest.fixture | ||
def existing_db(db_path): | ||
database = Database(db_path) | ||
database.executescript( | ||
""" | ||
CREATE TABLE foo (text TEXT); | ||
INSERT INTO foo (text) values ("one"); | ||
INSERT INTO foo (text) values ("two"); | ||
INSERT INTO foo (text) values ("three"); | ||
""" | ||
) | ||
return database | ||
|
||
|
||
@pytest.fixture | ||
def db_path(tmpdir): | ||
path = str(tmpdir / "test.db") | ||
db = sqlite3.connect(path) | ||
return path |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,29 @@ | ||
from tsellm import example_function | ||
from sqlite_utils import Database | ||
|
||
from tsellm.cli import cli | ||
import pytest | ||
import datetime | ||
from click.testing import CliRunner | ||
|
||
def test_example_function(): | ||
assert example_function() == 2 | ||
|
||
def test_cli(db_path): | ||
db = Database(db_path) | ||
assert [] == db.table_names() | ||
table = db.create_table( | ||
"prompts", | ||
{ | ||
"prompt": str, | ||
"generated": str, | ||
"model": str, | ||
"embedding": dict, | ||
}, | ||
) | ||
|
||
assert ["prompts"] == db.table_names() | ||
|
||
table.insert({"prompt": "hello"}) | ||
table.insert({"prompt": "world"}) | ||
|
||
assert db.execute( | ||
"select prompt from prompts" | ||
).fetchall() == [("hello",), ("world",)] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,16 @@ | ||
def example_function(): | ||
return 1 + 1 | ||
def _prompt(p): | ||
return p * 2 | ||
|
||
|
||
TSELLM_CONFIG_SQL = """ | ||
CREATE TABLE IF NOT EXISTS __tsellm ( | ||
data | ||
); | ||
""" | ||
|
||
|
||
def _tsellm_init(con): | ||
"""Entry-point for tsellm initialization.""" | ||
con.execute(TSELLM_CONFIG_SQL) | ||
con.create_function("prompt", 1, _prompt) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
import sys | ||
from .cli import cli | ||
|
||
if __name__ == "__main__": | ||
cli() | ||
cli(sys.argv[1:]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,124 @@ | ||
import click | ||
@click.group() | ||
@click.version_option() | ||
def cli(): | ||
""" CLI for tsellm """ | ||
pass | ||
|
||
@cli.command() | ||
@click.argument( | ||
"name", | ||
type=str, | ||
required=True, | ||
) | ||
def hello(name): | ||
print(f"Hello, {name}") | ||
import sqlite3 | ||
import sys | ||
|
||
from argparse import ArgumentParser | ||
from code import InteractiveConsole | ||
from textwrap import dedent | ||
from . import _prompt, _tsellm_init | ||
|
||
|
||
def execute(c, sql, suppress_errors=True): | ||
"""Helper that wraps execution of SQL code. | ||
This is used both by the REPL and by direct execution from the CLI. | ||
'c' may be a cursor or a connection. | ||
'sql' is the SQL string to execute. | ||
""" | ||
|
||
try: | ||
for row in c.execute(sql): | ||
print(row) | ||
except sqlite3.Error as e: | ||
tp = type(e).__name__ | ||
try: | ||
print(f"{tp} ({e.sqlite_errorname}): {e}", file=sys.stderr) | ||
except AttributeError: | ||
print(f"{tp}: {e}", file=sys.stderr) | ||
if not suppress_errors: | ||
sys.exit(1) | ||
|
||
|
||
class SqliteInteractiveConsole(InteractiveConsole): | ||
"""A simple SQLite REPL.""" | ||
|
||
def __init__(self, connection): | ||
super().__init__() | ||
self._con = connection | ||
self._cur = connection.cursor() | ||
|
||
def runsource(self, source, filename="<input>", symbol="single"): | ||
"""Override runsource, the core of the InteractiveConsole REPL. | ||
Return True if more input is needed; buffering is done automatically. | ||
Return False is input is a complete statement ready for execution. | ||
""" | ||
match source: | ||
case ".version": | ||
print(f"{sqlite3.sqlite_version}") | ||
case ".help": | ||
print("Enter SQL code and press enter.") | ||
case ".quit": | ||
sys.exit(0) | ||
case _: | ||
if not sqlite3.complete_statement(source): | ||
return True | ||
execute(self._cur, source) | ||
return False | ||
|
||
|
||
def cli(*args): | ||
print(args) | ||
parser = ArgumentParser( | ||
description="tsellm sqlite3 CLI", | ||
prog="python -m tsellm", | ||
) | ||
parser.add_argument( | ||
"filename", type=str, default=":memory:", nargs="?", | ||
help=( | ||
"SQLite database to open (defaults to ':memory:'). " | ||
"A new database is created if the file does not previously exist." | ||
), | ||
) | ||
parser.add_argument( | ||
"sql", type=str, nargs="?", | ||
help=( | ||
"An SQL query to execute. " | ||
"Any returned rows are printed to stdout." | ||
), | ||
) | ||
parser.add_argument( | ||
"-v", "--version", action="version", | ||
version=f"SQLite version {sqlite3.sqlite_version}", | ||
help="Print underlying SQLite library version", | ||
) | ||
args = parser.parse_args(*args) | ||
|
||
if args.filename == ":memory:": | ||
db_name = "a transient in-memory database" | ||
else: | ||
db_name = repr(args.filename) | ||
|
||
# Prepare REPL banner and prompts. | ||
if sys.platform == "win32" and "idlelib.run" not in sys.modules: | ||
eofkey = "CTRL-Z" | ||
else: | ||
eofkey = "CTRL-D" | ||
banner = dedent(f""" | ||
tsellm shell, running on SQLite version {sqlite3.sqlite_version} | ||
Connected to {db_name} | ||
Each command will be run using execute() on the cursor. | ||
Type ".help" for more information; type ".quit" or {eofkey} to quit. | ||
""").strip() | ||
sys.ps1 = "tsellm> " | ||
sys.ps2 = " ... " | ||
|
||
con = sqlite3.connect(args.filename, isolation_level=None) | ||
_tsellm_init(con) | ||
try: | ||
if args.sql: | ||
# SQL statement provided on the command-line; execute it directly. | ||
execute(con, args.sql, suppress_errors=False) | ||
else: | ||
# No SQL provided; start the REPL. | ||
console = SqliteInteractiveConsole(con) | ||
try: | ||
import readline | ||
except ImportError: | ||
pass | ||
console.interact(banner, exitmsg="") | ||
finally: | ||
con.close() | ||
|
||
sys.exit(0) |