diff --git a/tests/test_tsellm.py b/tests/test_tsellm.py index 3e5a8b8..06a918c 100644 --- a/tests/test_tsellm.py +++ b/tests/test_tsellm.py @@ -14,7 +14,9 @@ cli, TsellmConsole, SQLiteConsole, - TsellmConsoleMixin, + DuckDBConsole, + + DBSniffer ) @@ -25,17 +27,33 @@ def new_tempfile(): def new_sqlite_file(): f = new_tempfile() with sqlite3.connect(f) as db: - db.execute("SELECT 1") + db.execute("CREATE TABLE my(x text)") return f def new_duckdb_file(): f = new_tempfile() con = duckdb.connect(f.__str__()) - con.sql("SELECT 1") + con.sql("CREATE TABLE my(x text)") return f +class TestDBSniffer(unittest.TestCase): + def setUp(self): + self.sqlite_fp = new_sqlite_file() + self.duckdb_fp = new_duckdb_file() + + def test_sniff_sqlite(self): + sqlite_sni = DBSniffer(self.sqlite_fp) + self.assertTrue(sqlite_sni.is_sqlite) + self.assertFalse(sqlite_sni.is_duckdb) + + def test_snif_duckdb(self): + duckdb_sni = DBSniffer(self.duckdb_fp) + self.assertFalse(duckdb_sni.is_sqlite) + self.assertTrue(duckdb_sni.is_duckdb) + + class TsellmConsoleTest(unittest.TestCase): def setUp(self): super().setUp() @@ -69,23 +87,15 @@ def expect_failure(self, *args): self.assertEqual(out, "") return err - def test_sniff_sqlite(self): - self.assertTrue(TsellmConsoleMixin().is_sqlite(new_sqlite_file())) - - def test_sniff_duckdb(self): - self.assertTrue(TsellmConsoleMixin().is_duckdb(new_duckdb_file())) - def test_console_factory_sqlite(self): s = new_sqlite_file() - self.assertTrue(TsellmConsoleMixin().is_sqlite(s)) obj = TsellmConsole.create_console(s) self.assertIsInstance(obj, SQLiteConsole) - # def test_console_factory_duckdb(self): - # s = new_duckdb_file() - # self.assertTrue(TsellmConsole.is_duckdb(s)) - # obj = TsellmConsole.create_console(s) - # self.assertIsInstance(obj, DuckDBConsole) + d = new_duckdb_file() + self.assertTrue(TsellmConsole.create_console(d)) + obj = TsellmConsole.create_console(d) + self.assertIsInstance(obj, DuckDBConsole) def test_cli_help(self): out = self.expect_success("-h") @@ -98,11 +108,6 @@ def test_cli_version(self): def test_choose_db(self): self.expect_failure("--sqlite", "--duckdb") - def test_deault_sqlite(self): - f = new_tempfile() - self.expect_success(str(f), "select 1") - self.assertTrue(TsellmConsoleMixin().is_sqlite(f)) - MEMORY_DB_MSG = "Connected to :memory:" PS1 = "tsellm> " PS2 = "... " @@ -266,7 +271,7 @@ def setUp(self): def test_embed_default_hazo_leaves_valid_db_behind(self): # This should probably be called for all test cases super().test_embed_default_hazo() - self.assertTrue(TsellmConsoleMixin().is_sqlite(self.db_fp)) + self.assertTrue(DBSniffer(self.db_fp).is_sqlite) class InMemoryDuckDBTest(InMemorySQLiteTest): diff --git a/tsellm/cli.py b/tsellm/cli.py index de1a640..4098576 100644 --- a/tsellm/cli.py +++ b/tsellm/cli.py @@ -32,43 +32,36 @@ class DatabaseType(Enum): sys.ps2 = " ... " -class TsellmConsoleMixin(InteractiveConsole): - def is_sqlite(self, path): - try: - with sqlite3.connect(path) as conn: - conn.execute("SELECT 1") - return True - except: - return False +@dataclass +class DBSniffer: + fp: Union[str, Path] - def is_duckdb(self, path): - try: - con = duckdb.connect(path.__str__()) - con.sql("SELECT 1") - return True - except: - return False + def sniff(self) -> DatabaseType: + with open(self.fp, 'rb') as f: + header = f.read(16) + if header.startswith(b'SQLite format 3'): + return DatabaseType.SQLITE - def sniff_db(self, path): - """ - Sniffs if the path is a SQLite or DuckDB database. + try: + con = duckdb.connect(str(self.fp)) + con.sql("SELECT 1") + return DatabaseType.DUCKDB + except: + return DatabaseType.UNKNOWN - Args: - path (str): The file path to check. + @property + def is_duckdb(self) -> bool: + return self.sniff() == DatabaseType.DUCKDB - Returns: - DatabaseType: The type of database (DatabaseType.SQLITE, DatabaseType.DUCKDB, - DatabaseType.UNKNOWN, DatabaseType.FILE_NOT_FOUND, DatabaseType.ERROR). - """ + @property + def is_sqlite(self) -> bool: + return self.sniff() == DatabaseType.SQLITE - if TsellmConsole.is_sqlite(path): - return DatabaseType.SQLITE - if TsellmConsole.is_duckdb(path): - return DatabaseType.DUCKDB - return DatabaseType.UNKNOWN + @property + def is_in_memory(self) -> bool: + return self.fp == ':memory:' -@dataclass class TsellmConsole(InteractiveConsole, ABC): _TSELLM_CONFIG_SQL = """ -- tsellm configuration table @@ -91,6 +84,25 @@ class TsellmConsole(InteractiveConsole, ABC): db_type: str = field(init=False) connection: Union[sqlite3.Connection, duckdb.DuckDBPyConnection] = field(init=False) + @staticmethod + def create_console(fp: Union[str, Path], + in_memory_type: DatabaseType = DatabaseType.UNKNOWN): + sniffer = DBSniffer(fp) + if sniffer.is_in_memory: + if sniffer.is_duckdb: + return DuckDBConsole(fp) + elif sniffer.is_sqlite: + return SQLiteConsole(fp) + else: + raise ValueError(f"To create an in-memory db, DatabaseType should be supplied") + + if sniffer.is_duckdb: + return DuckDBConsole(fp) + elif sniffer.is_sqlite: + return SQLiteConsole(fp) + else: + raise ValueError(f"Cannot create console with fp={fp} and in_memory_type={in_memory_type}") + @property def tsellm_version(self) -> str: return __version__.__version__ @@ -154,15 +166,6 @@ def load(self): for func_name, n_args, py_func, deterministic in self._functions: self.connection.create_function(func_name, n_args, py_func) - @staticmethod - def create_console(path): - if TsellmConsoleMixin().is_duckdb(path): - return DuckDBConsole(path) - if TsellmConsoleMixin().is_sqlite(path): - return SQLiteConsole(path) - else: - raise ValueError(f"Database type {path} not supported") - @abstractmethod def execute(self, sql, suppress_errors=True): pass @@ -268,7 +271,7 @@ def is_valid_db(self) -> bool: ] def connect(self): - self.connection = duckdb.connect(self.path) + self.connection = duckdb.connect(str(self.path)) def load(self): self.execute(self._TSELLM_CONFIG_SQL) @@ -357,17 +360,6 @@ def cli(*args): if args.sqlite and args.duckdb: raise ValueError("Only one of --sqlite and --duckdb can be specified.") - if (not args.sqlite) and (not args.duckdb): - if args.filename == ":memory:": - args.sqlite = True - args.duckdb = False - else: - if TsellmConsoleMixin().is_duckdb(args.filename): - args.duckdb = True - args.sqlite = False - else: - args.duckdb = False - args.sqlite = True console = ( DuckDBConsole(args.filename) if args.duckdb else SQLiteConsole(args.filename)