Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Florents-Tselai committed Jul 7, 2024
1 parent d5b2dc3 commit 9f80cdb
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 13 deletions.
18 changes: 15 additions & 3 deletions tests/test_tsellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def run_cli(self, *args, commands=()):
captured_stdin() as stdin,
captured_stdout() as stdout,
captured_stderr() as stderr,
self.assertRaises(SystemExit) as cm
self.assertRaises(SystemExit) as cm,
):
for cmd in commands:
stdin.write(cmd + "\n")
Expand All @@ -121,8 +121,9 @@ def run_cli(self, *args, commands=()):

out = stdout.getvalue()
err = stderr.getvalue()
self.assertEqual(cm.exception.code, 0,
f"Unexpected failure: {args=}\n{out}\n{err}")
self.assertEqual(
cm.exception.code, 0, f"Unexpected failure: {args=}\n{out}\n{err}"
)
return out, err

def test_interact(self):
Expand Down Expand Up @@ -299,5 +300,16 @@ def test_embed_hazo_binary(self):
pass


class DiskDuckDBTest(InMemoryDuckDBTest):
db_fp = None
path_args = ()

def setUp(self):
super().setUp()
self.db_fp = str(new_duckdb_file())
print(self.db_fp)
self.path_args = (self.db_fp,)


if __name__ == "__main__":
unittest.main()
42 changes: 32 additions & 10 deletions tsellm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@ class TsellmConsole(InteractiveConsole, ABC):
def tsellm_version(self) -> str:
return __version__.__version__

@property
@abstractmethod
def is_in_memory(self) -> bool:
pass

@property
def eofkey(self):
if sys.platform == "win32" and "idlelib.run" not in sys.modules:
Expand Down Expand Up @@ -134,12 +139,14 @@ def complete_statement(self, source) -> bool:

@property
def version(self):
return " ".join([
"tsellm version",
self.tsellm_version,
self.db_type,
"version",
self.db_version]
return " ".join(
[
"tsellm version",
self.tsellm_version,
self.db_type,
"version",
self.db_version,
]
)

def load(self):
Expand Down Expand Up @@ -192,6 +199,10 @@ def __post_init__(self):

@dataclass
class SQLiteConsole(TsellmConsole):
@property
def is_in_memory(self) -> bool:
return self.path == ":memory:"

db_type = "SQLite"

def connect(self):
Expand Down Expand Up @@ -235,6 +246,10 @@ def db_version(self):

@dataclass
class DuckDBConsole(TsellmConsole):
@property
def is_in_memory(self) -> bool:
return self.path == ":memory:"

db_type = "DuckDB"
path: Union[Path, str, sqlite3.Connection, duckdb.DuckDBPyConnection]

Expand Down Expand Up @@ -342,10 +357,17 @@ def cli(*args):

if args.sqlite and args.duckdb:
raise ValueError("Only one of --sqlite and --duckdb can be specified.")

if (not args.sqlite) and (not args.duckdb) and args.filename == ":memory:":
args.sqlite = True
args.duckdb = False
if (not args.sqlite) and (not args.duckdb):
if args.filename == ":memory:":
args.sqlite = True
args.duckdb = False
else:
if TsellmConsoleMixin().is_duckdb(args.filename):
args.duckdb = True
args.sqlite = False
else:
args.duckdb = False
args.sqlite = True

console = (
DuckDBConsole(args.filename) if args.duckdb else SQLiteConsole(args.filename)
Expand Down

0 comments on commit 9f80cdb

Please sign in to comment.