Skip to content

Commit

Permalink
cleaner dbsniffer
Browse files Browse the repository at this point in the history
  • Loading branch information
Florents-Tselai committed Jul 8, 2024
1 parent a19fc5c commit a9015dd
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 72 deletions.
47 changes: 26 additions & 21 deletions tests/test_tsellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
cli,
TsellmConsole,
SQLiteConsole,
TsellmConsoleMixin,
DuckDBConsole,

DBSniffer
)


Expand All @@ -25,17 +27,33 @@ def new_tempfile():
def new_sqlite_file():
f = new_tempfile()
with sqlite3.connect(f) as db:
db.execute("SELECT 1")
db.execute("CREATE TABLE my(x text)")
return f


def new_duckdb_file():
f = new_tempfile()
con = duckdb.connect(f.__str__())
con.sql("SELECT 1")
con.sql("CREATE TABLE my(x text)")
return f


class TestDBSniffer(unittest.TestCase):
def setUp(self):
self.sqlite_fp = new_sqlite_file()
self.duckdb_fp = new_duckdb_file()

def test_sniff_sqlite(self):
sqlite_sni = DBSniffer(self.sqlite_fp)
self.assertTrue(sqlite_sni.is_sqlite)
self.assertFalse(sqlite_sni.is_duckdb)

def test_snif_duckdb(self):
duckdb_sni = DBSniffer(self.duckdb_fp)
self.assertFalse(duckdb_sni.is_sqlite)
self.assertTrue(duckdb_sni.is_duckdb)


class TsellmConsoleTest(unittest.TestCase):
def setUp(self):
super().setUp()
Expand Down Expand Up @@ -69,23 +87,15 @@ def expect_failure(self, *args):
self.assertEqual(out, "")
return err

def test_sniff_sqlite(self):
self.assertTrue(TsellmConsoleMixin().is_sqlite(new_sqlite_file()))

def test_sniff_duckdb(self):
self.assertTrue(TsellmConsoleMixin().is_duckdb(new_duckdb_file()))

def test_console_factory_sqlite(self):
s = new_sqlite_file()
self.assertTrue(TsellmConsoleMixin().is_sqlite(s))
obj = TsellmConsole.create_console(s)
self.assertIsInstance(obj, SQLiteConsole)

# def test_console_factory_duckdb(self):
# s = new_duckdb_file()
# self.assertTrue(TsellmConsole.is_duckdb(s))
# obj = TsellmConsole.create_console(s)
# self.assertIsInstance(obj, DuckDBConsole)
d = new_duckdb_file()
self.assertTrue(TsellmConsole.create_console(d))
obj = TsellmConsole.create_console(d)
self.assertIsInstance(obj, DuckDBConsole)

def test_cli_help(self):
out = self.expect_success("-h")
Expand All @@ -98,11 +108,6 @@ def test_cli_version(self):
def test_choose_db(self):
self.expect_failure("--sqlite", "--duckdb")

def test_deault_sqlite(self):
f = new_tempfile()
self.expect_success(str(f), "select 1")
self.assertTrue(TsellmConsoleMixin().is_sqlite(f))

MEMORY_DB_MSG = "Connected to :memory:"
PS1 = "tsellm> "
PS2 = "... "
Expand Down Expand Up @@ -266,7 +271,7 @@ def setUp(self):
def test_embed_default_hazo_leaves_valid_db_behind(self):
# This should probably be called for all test cases
super().test_embed_default_hazo()
self.assertTrue(TsellmConsoleMixin().is_sqlite(self.db_fp))
self.assertTrue(DBSniffer(self.db_fp).is_sqlite)


class InMemoryDuckDBTest(InMemorySQLiteTest):
Expand Down
94 changes: 43 additions & 51 deletions tsellm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,43 +32,36 @@ class DatabaseType(Enum):
sys.ps2 = " ... "


class TsellmConsoleMixin(InteractiveConsole):
def is_sqlite(self, path):
try:
with sqlite3.connect(path) as conn:
conn.execute("SELECT 1")
return True
except:
return False
@dataclass
class DBSniffer:
fp: Union[str, Path]

def is_duckdb(self, path):
try:
con = duckdb.connect(path.__str__())
con.sql("SELECT 1")
return True
except:
return False
def sniff(self) -> DatabaseType:
with open(self.fp, 'rb') as f:
header = f.read(16)
if header.startswith(b'SQLite format 3'):
return DatabaseType.SQLITE

def sniff_db(self, path):
"""
Sniffs if the path is a SQLite or DuckDB database.
try:
con = duckdb.connect(str(self.fp))
con.sql("SELECT 1")
return DatabaseType.DUCKDB
except:
return DatabaseType.UNKNOWN

Args:
path (str): The file path to check.
@property
def is_duckdb(self) -> bool:
return self.sniff() == DatabaseType.DUCKDB

Returns:
DatabaseType: The type of database (DatabaseType.SQLITE, DatabaseType.DUCKDB,
DatabaseType.UNKNOWN, DatabaseType.FILE_NOT_FOUND, DatabaseType.ERROR).
"""
@property
def is_sqlite(self) -> bool:
return self.sniff() == DatabaseType.SQLITE

if TsellmConsole.is_sqlite(path):
return DatabaseType.SQLITE
if TsellmConsole.is_duckdb(path):
return DatabaseType.DUCKDB
return DatabaseType.UNKNOWN
@property
def is_in_memory(self) -> bool:
return self.fp == ':memory:'


@dataclass
class TsellmConsole(InteractiveConsole, ABC):
_TSELLM_CONFIG_SQL = """
-- tsellm configuration table
Expand All @@ -91,6 +84,25 @@ class TsellmConsole(InteractiveConsole, ABC):
db_type: str = field(init=False)
connection: Union[sqlite3.Connection, duckdb.DuckDBPyConnection] = field(init=False)

@staticmethod
def create_console(fp: Union[str, Path],
in_memory_type: DatabaseType = DatabaseType.UNKNOWN):
sniffer = DBSniffer(fp)
if sniffer.is_in_memory:
if sniffer.is_duckdb:
return DuckDBConsole(fp)
elif sniffer.is_sqlite:
return SQLiteConsole(fp)
else:
raise ValueError(f"To create an in-memory db, DatabaseType should be supplied")

if sniffer.is_duckdb:
return DuckDBConsole(fp)
elif sniffer.is_sqlite:
return SQLiteConsole(fp)
else:
raise ValueError(f"Cannot create console with fp={fp} and in_memory_type={in_memory_type}")

@property
def tsellm_version(self) -> str:
return __version__.__version__
Expand Down Expand Up @@ -154,15 +166,6 @@ def load(self):
for func_name, n_args, py_func, deterministic in self._functions:
self.connection.create_function(func_name, n_args, py_func)

@staticmethod
def create_console(path):
if TsellmConsoleMixin().is_duckdb(path):
return DuckDBConsole(path)
if TsellmConsoleMixin().is_sqlite(path):
return SQLiteConsole(path)
else:
raise ValueError(f"Database type {path} not supported")

@abstractmethod
def execute(self, sql, suppress_errors=True):
pass
Expand Down Expand Up @@ -268,7 +271,7 @@ def is_valid_db(self) -> bool:
]

def connect(self):
self.connection = duckdb.connect(self.path)
self.connection = duckdb.connect(str(self.path))

def load(self):
self.execute(self._TSELLM_CONFIG_SQL)
Expand Down Expand Up @@ -357,17 +360,6 @@ def cli(*args):

if args.sqlite and args.duckdb:
raise ValueError("Only one of --sqlite and --duckdb can be specified.")
if (not args.sqlite) and (not args.duckdb):
if args.filename == ":memory:":
args.sqlite = True
args.duckdb = False
else:
if TsellmConsoleMixin().is_duckdb(args.filename):
args.duckdb = True
args.sqlite = False
else:
args.duckdb = False
args.sqlite = True

console = (
DuckDBConsole(args.filename) if args.duckdb else SQLiteConsole(args.filename)
Expand Down

0 comments on commit a9015dd

Please sign in to comment.