Skip to content

Commit

Permalink
feat(decltypes): support to parse_decltypes
Browse files Browse the repository at this point in the history
  • Loading branch information
Daniele Briggi committed Aug 14, 2024
1 parent f872db6 commit 7833b28
Show file tree
Hide file tree
Showing 9 changed files with 671 additions and 125 deletions.
321 changes: 288 additions & 33 deletions bandit-baseline.json

Large diffs are not rendered by default.

21 changes: 19 additions & 2 deletions src/sqlitecloud/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,25 @@
# the classes and functions from the dbapi2 module.
# eg: sqlite3.connect() -> sqlitecloud.connect()
#
from .dbapi2 import Connection, Cursor, connect, register_adapter
from .dbapi2 import (
PARSE_COLNAMES,
PARSE_DECLTYPES,
Connection,
Cursor,
connect,
register_adapter,
register_converter,
)

__all__ = ["VERSION", "Connection", "Cursor", "connect", "register_adapter"]
__all__ = [
"VERSION",
"Connection",
"Cursor",
"connect",
"register_adapter",
"register_converter",
"PARSE_DECLTYPES",
"PARSE_COLNAMES",
]

VERSION = "0.0.79"
113 changes: 81 additions & 32 deletions src/sqlitecloud/dbapi2.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@
PARSE_DECLTYPES = 1
PARSE_COLNAMES = 2

# Adapter registry to convert Python types to SQLite types
adapters = {}
# Adapters registry to convert Python types to SQLite types
_adapters = {}
# Converters registry to convert SQLite types to Python types
_converters = {}


@overload
Expand Down Expand Up @@ -106,6 +108,11 @@ def connect(
It can be either a connection string or a `SqliteCloudAccount` object.
config (Optional[SQLiteCloudConfig]): The configuration options for the connection.
Defaults to None.
detect_types (int): Default (0), disabled. How data types not natively supported
by SQLite are looked up to be converted to Python types, using the converters
registered with register_converter().
Accepts any combination (using |, bitwise or) of PARSE_DECLTYPES and PARSE_COLNAMES.
Column names takes precedence over declared types if both flags are set.
Returns:
Connection: A DB-API 2.0 connection object representing the connection to the database.
Expand All @@ -122,13 +129,16 @@ def connect(
else:
config = SQLiteCloudConfig(connection_info)

return Connection(
driver.connect(config.account.hostname, config.account.port, config)
connection = Connection(
driver.connect(config.account.hostname, config.account.port, config),
detect_types=detect_types,
)

return connection


def register_adapter(
pytype: Type, adapter_callable: Callable[[object], SQLiteTypes]
pytype: Type, adapter_callable: Callable[[Any], SQLiteTypes]
) -> None:
"""
Registers a callable to convert the type into one of the supported SQLite types.
Expand All @@ -138,8 +148,21 @@ def register_adapter(
callable (Callable): The callable that converts the type into a supported
SQLite supported type.
"""
global adapters
adapters[pytype] = adapter_callable
global _adapters
_adapters[pytype] = adapter_callable


def register_converter(type_name: str, converter: Callable[[bytes], Any]) -> None:
"""
Registers a callable to convert a bytestring from the database into a custom Python type.
Args:
type_name (str): The name of the type to convert.
The match with the name of the type in the query is case-insensitive.
converter (Callable): The callable that converts the bytestring into the custom Python type.
"""
global _converters
_converters[type_name.lower()] = converter


class Connection:
Expand All @@ -154,16 +177,16 @@ class Connection:
SQLiteCloud_connection (SQLiteCloudConnect): The SQLite Cloud connection object.
"""

def __init__(self, sqlitecloud_connection: SQLiteCloudConnect) -> None:
def __init__(
self, sqlitecloud_connection: SQLiteCloudConnect, detect_types: int = 0
) -> None:
self._driver = Driver()
self.sqlitecloud_connection = sqlitecloud_connection

self.row_factory: Optional[Callable[["Cursor", Tuple], object]] = None
self.text_factory: Union[
Type[Union[str, bytes]], Callable[[bytes], object]
] = str
self.text_factory: Union[Type[Union[str, bytes]], Callable[[bytes], Any]] = str

self.detect_types = 0
self.detect_types = detect_types

@property
def sqlcloud_connection(self) -> SQLiteCloudConnect:
Expand Down Expand Up @@ -273,19 +296,19 @@ def cursor(self):
cursor.row_factory = self.row_factory
return cursor

def _apply_adapter(self, value: object) -> SQLiteTypes:
def _apply_adapter(self, value: Any) -> SQLiteTypes:
"""
Applies the registered adapter to convert the Python type into a SQLite supported type.
In the case there is no registered adapter, it calls the __conform__() method when the value object implements it.
Args:
value (object): The Python type to convert.
value (Any): The Python type to convert.
Returns:
SQLiteTypes: The SQLite supported type or the given value when no adapter is found.
"""
if type(value) in adapters:
return adapters[type(value)](value)
if type(value) in _adapters:
return _adapters[type(value)](value)

if hasattr(value, "__conform__"):
# we don't support sqlite3.PrepareProtocol
Expand Down Expand Up @@ -445,6 +468,8 @@ def executemany(

commands = ""
for parameters in seq_of_parameters:
parameters = self._adapt_parameters(parameters)

prepared_statement = self._driver.prepare_statement(sql, parameters)
commands += prepared_statement + ";"

Expand Down Expand Up @@ -547,24 +572,48 @@ def _adapt_parameters(self, parameters: Union[Dict, Tuple]) -> Union[Dict, Tuple

return tuple(self._connection._apply_adapter(p) for p in parameters)

def _convert_value(self, value: Any, decltype: Optional[str]) -> Any:
# todo: parse columns first

if (self.connection.detect_types & PARSE_DECLTYPES) == PARSE_DECLTYPES:
return self._parse_decltypes(value, decltype)

if decltype == SQLITECLOUD_VALUE_TYPE.TEXT.value or (
decltype is None and isinstance(value, str)
):
return self._apply_text_factory(value)

return value

def _parse_decltypes(self, value: Any, decltype: str) -> Any:
decltype = decltype.lower()
if decltype in _converters:
# sqlite3 always passes value as bytes
return _converters[decltype](str(value).encode("utf-8"))

return value

def _apply_text_factory(self, value: Any) -> Any:
"""Use Connection.text_factory to convert value with TEXT column or
string value with undleclared column type."""

if self._connection.text_factory is bytes:
return value.encode("utf-8")
if self._connection.text_factory is not str and callable(
self._connection.text_factory
):
return self._connection.text_factory(value.encode("utf-8"))

return value

def _get_value(self, row: int, col: int) -> Optional[Any]:
if not self._is_result_rowset():
return None

# Convert TEXT type with text_factory
value = self._resultset.get_value(row, col)
decltype = self._resultset.get_decltype(col)
if decltype is None or decltype == SQLITECLOUD_VALUE_TYPE.TEXT.value:
value = self._resultset.get_value(row, col, False)

if self._connection.text_factory is bytes:
return value.encode("utf-8")
if self._connection.text_factory is not str and callable(
self._connection.text_factory
):
return self._connection.text_factory(value.encode("utf-8"))
return value

return self._resultset.get_value(row, col)
return self._convert_value(value, decltype)

def __iter__(self) -> "Cursor":
return self
Expand Down Expand Up @@ -602,7 +651,7 @@ def adapt_datetime(val):
return val.isoformat(" ")

def convert_date(val):
return datetime.date(*map(int, val.split(b"-")))
return date(*map(int, val.split(b"-")))

def convert_timestamp(val):
datepart, timepart = val.split(b" ")
Expand All @@ -614,13 +663,13 @@ def convert_timestamp(val):
else:
microseconds = 0

val = datetime.datetime(year, month, day, hours, minutes, seconds, microseconds)
val = datetime(year, month, day, hours, minutes, seconds, microseconds)
return val

register_adapter(date, adapt_date)
register_adapter(datetime, adapt_datetime)
# register_converter("date", convert_date)
# register_converter("timestamp", convert_timestamp)
register_converter("date", convert_date)
register_converter("timestamp", convert_timestamp)


register_adapters_and_converters()
22 changes: 2 additions & 20 deletions src/sqlitecloud/resultset.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,12 @@ def _compute_index(self, row: int, col: int) -> int:
return -1
return row * self.ncols + col

def get_value(self, row: int, col: int, convert: bool = True) -> Optional[any]:
def get_value(self, row: int, col: int) -> Optional[any]:
index = self._compute_index(row, col)
if index < 0 or not self.data or index >= len(self.data):
return None

value = self.data[index]
return self._convert(value, col) if convert else value
return self.data[index]

def get_name(self, col: int) -> Optional[str]:
if col < 0 or col >= self.ncols:
Expand All @@ -79,23 +78,6 @@ def get_decltype(self, col: int) -> Optional[str]:

return self.decltype[col]

def _convert(self, value: str, col: int) -> any:
if col < 0 or col >= len(self.decltype):
return value

decltype = self.decltype[col]
if decltype == SQLITECLOUD_VALUE_TYPE.INTEGER.value:
return int(value)
if decltype == SQLITECLOUD_VALUE_TYPE.FLOAT.value:
return float(value)
if decltype == SQLITECLOUD_VALUE_TYPE.BLOB.value:
# values are received as bytes before being strings
return bytes(value)
if decltype == SQLITECLOUD_VALUE_TYPE.NULL.value:
return None

return value


class SQLiteCloudResultSet:
def __init__(self, result: SQLiteCloudResult) -> None:
Expand Down
7 changes: 4 additions & 3 deletions src/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,15 @@ def sqlitecloud_dbapi2_connection():
yield next(get_sqlitecloud_dbapi2_connection())


def get_sqlitecloud_dbapi2_connection():
def get_sqlitecloud_dbapi2_connection(detect_types: int = 0):
account = SQLiteCloudAccount()
account.username = os.getenv("SQLITE_USER")
account.password = os.getenv("SQLITE_PASSWORD")
account.dbname = os.getenv("SQLITE_DB")
account.hostname = os.getenv("SQLITE_HOST")
account.port = int(os.getenv("SQLITE_PORT"))

connection = sqlitecloud.connect(account)
connection = sqlitecloud.connect(account, detect_types=detect_types)

assert isinstance(connection, sqlitecloud.Connection)

Expand All @@ -62,12 +62,13 @@ def get_sqlitecloud_dbapi2_connection():
connection.close()


def get_sqlite3_connection():
def get_sqlite3_connection(detect_types: int = 0):
# set isolation_level=None to enable autocommit
# and to be aligned with the behavior of SQLite Cloud
connection = sqlite3.connect(
os.path.join(os.path.dirname(__file__), "./assets/chinook.sqlite"),
isolation_level=None,
detect_types=detect_types,
)
yield connection
connection.close()
7 changes: 4 additions & 3 deletions src/tests/integration/test_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import random
import time

import pytest
Expand Down Expand Up @@ -641,10 +642,10 @@ def test_big_rowset(self):

connection = client.open_connection()

table_name = "TestCompress" + str(int(time.time()))
table_name = "TestCompress" + str(random.randint(0, 99999))
try:
client.exec_query(
f"CREATE TABLE IF NOT EXISTS {table_name} (id INTEGER PRIMARY KEY, name TEXT)",
f"CREATE TABLE {table_name} (id INTEGER PRIMARY KEY, name TEXT)",
connection,
)

Expand All @@ -663,7 +664,7 @@ def test_big_rowset(self):

assert rowset.nrows == nRows
finally:
client.exec_query(f"DROP TABLE {table_name}", connection)
client.exec_query(f"DROP TABLE IF EXISTS {table_name}", connection)
client.disconnect(connection)

def test_compression_single_column(self):
Expand Down
Loading

0 comments on commit 7833b28

Please sign in to comment.