Skip to content

Commit

Permalink
feat(adapters): text factory and adapters
Browse files Browse the repository at this point in the history
  • Loading branch information
Daniele Briggi committed Aug 13, 2024
1 parent af80e21 commit 1c4745a
Show file tree
Hide file tree
Showing 7 changed files with 360 additions and 39 deletions.
4 changes: 2 additions & 2 deletions src/sqlitecloud/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
# the classes and functions from the dbapi2 module.
# eg: sqlite3.connect() -> sqlitecloud.connect()
#
from .dbapi2 import Connection, Cursor, connect
from .dbapi2 import Connection, Cursor, connect, register_adapter

__all__ = ["VERSION", "Connection", "Cursor", "connect"]
__all__ = ["VERSION", "Connection", "Cursor", "connect", "register_adapter"]

VERSION = "0.0.79"
4 changes: 2 additions & 2 deletions src/sqlitecloud/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ def __init__(self, connection_str: Optional[str] = None) -> None:
# Socket connection timeout
self.connect_timeout = SQLITECLOUD_DEFAULT.TIMEOUT.value

# Enable compression
self.compression = False
# Compression enabled by default
self.compression = True
# Tell the server to zero-terminate strings
self.zerotext = False
# Database will be created in memory
Expand Down
125 changes: 121 additions & 4 deletions src/sqlitecloud/dbapi2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# https://peps.python.org/pep-0249/
#
import logging
from datetime import date, datetime
from typing import (
Any,
Callable,
Expand All @@ -13,6 +14,7 @@
List,
Optional,
Tuple,
Type,
Union,
overload,
)
Expand All @@ -25,7 +27,14 @@
SQLiteCloudException,
)
from sqlitecloud.driver import Driver
from sqlitecloud.resultset import SQLITECLOUD_RESULT_TYPE, SQLiteCloudResult
from sqlitecloud.resultset import (
SQLITECLOUD_RESULT_TYPE,
SQLITECLOUD_VALUE_TYPE,
SQLiteCloudResult,
)

# SQLite supported types
SQLiteTypes = Union[int, float, str, bytes, None]

# Question mark style, e.g. ...WHERE name=?
# Module also supports Named style, e.g. ...WHERE name=:name
Expand All @@ -37,6 +46,14 @@
# DB API level
apilevel = "2.0"

# These constants are meant to be used with the detect_types
# parameter of the connect() function
PARSE_DECLTYPES = 1
PARSE_COLNAMES = 2

# Adapter registry to convert Python types to SQLite types
adapters = {}


@overload
def connect(connection_str: str) -> "Connection":
Expand Down Expand Up @@ -80,6 +97,7 @@ def connect(
def connect(
connection_info: Union[str, SQLiteCloudAccount],
config: Optional[SQLiteCloudConfig] = None,
detect_types: int = 0,
) -> "Connection":
"""
Establishes a connection to the SQLite Cloud database.
Expand Down Expand Up @@ -110,6 +128,21 @@ def connect(
)


def register_adapter(
pytype: Type, adapter_callable: Callable[[object], SQLiteTypes]
) -> None:
"""
Registers a callable to convert the type into one of the supported SQLite types.
Args:
type (Type): The type to convert.
callable (Callable): The callable that converts the type into a supported
SQLite supported type.
"""
global adapters
adapters[pytype] = adapter_callable


class Connection:
"""
Represents a DB-APi 2.0 connection to the SQLite Cloud database.
Expand All @@ -123,11 +156,13 @@ class Connection:
"""

row_factory: Optional[Callable[["Cursor", Tuple], object]] = None
text_factory: Union[Type[Union[str, bytes]], Callable[[bytes], object]] = str

def __init__(self, sqlitecloud_connection: SQLiteCloudConnect) -> None:
self._driver = Driver()
self.row_factory = None
self.sqlitecloud_connection = sqlitecloud_connection
self.detect_types = 0

@property
def sqlcloud_connection(self) -> SQLiteCloudConnect:
Expand Down Expand Up @@ -243,6 +278,21 @@ def cursor(self):
cursor.row_factory = self.row_factory
return cursor

def _apply_adapter(self, value: object) -> SQLiteTypes:
"""
Applies the adapter to convert the Python type into a SQLite supported type.
Args:
value (object): The Python type to convert.
Returns:
SQLiteTypes: The SQLite supported type.
"""
if type(value) in adapters:
return adapters[type(value)](value)

return value

def __del__(self) -> None:
self.close()

Expand Down Expand Up @@ -364,6 +414,8 @@ def execute(
"""
self._ensure_connection()

parameters = self._adapt_parameters(parameters)

prepared_statement = self._driver.prepare_statement(sql, parameters)
result = self._driver.execute(
prepared_statement, self.connection.sqlcloud_connection
Expand Down Expand Up @@ -492,12 +544,37 @@ def _ensure_connection(self):
if not self._connection:
raise SQLiteCloudException("The cursor is closed.")

def _adapt_parameters(self, parameters: Union[Dict, Tuple]) -> Union[Dict, Tuple]:
if isinstance(parameters, dict):
params = {}
for i in parameters.keys():
params[i] = self._connection._apply_adapter(parameters[i])
return params

return tuple(self._connection._apply_adapter(p) for p in parameters)

def _get_value(self, row: int, col: int) -> Optional[Any]:
if not self._is_result_rowset():
return None

# Convert TEXT type with text_factory
decltype = self._resultset.get_decltype(col)
if decltype is None or decltype == SQLITECLOUD_VALUE_TYPE.TEXT.value:
value = self._resultset.get_value(row, col, False)

if self._connection.text_factory is bytes:
return value.encode("utf-8")
if self._connection.text_factory is str:
return value
# callable
return self._connection.text_factory(value.encode("utf-8"))

return self._resultset.get_value(row, col)

def __iter__(self) -> "Cursor":
return self

def __next__(self) -> Optional[Tuple[Any]]:
self._ensure_connection()

if (
not self._resultset.is_result
and self._resultset.data
Expand All @@ -506,9 +583,49 @@ def __next__(self) -> Optional[Tuple[Any]]:
out: Tuple[Any] = ()

for col in range(self._resultset.ncols):
out += (self._resultset.get_value(self._iter_row, col),)
out += (self._get_value(self._iter_row, col),)
self._iter_row += 1

return self._call_row_factory(out)

raise StopIteration


def register_adapters_and_converters():
"""
sqlite3 default adapters and converters.
This code is adapted from the Python standard library's sqlite3 module.
The Python standard library is licensed under the Python Software Foundation License.
Source: https://github.com/python/cpython/blob/3.6/Lib/sqlite3/dbapi2.py
"""

def adapt_date(val):
return val.isoformat()

def adapt_datetime(val):
return val.isoformat(" ")

def convert_date(val):
return datetime.date(*map(int, val.split(b"-")))

def convert_timestamp(val):
datepart, timepart = val.split(b" ")
year, month, day = map(int, datepart.split(b"-"))
timepart_full = timepart.split(b".")
hours, minutes, seconds = map(int, timepart_full[0].split(b":"))
if len(timepart_full) == 2:
microseconds = int("{:0<6.6}".format(timepart_full[1].decode()))
else:
microseconds = 0

val = datetime.datetime(year, month, day, hours, minutes, seconds, microseconds)
return val

register_adapter(date, adapt_date)
register_adapter(datetime, adapt_datetime)
# register_converter("date", convert_date)
# register_converter("timestamp", convert_timestamp)


register_adapters_and_converters()
6 changes: 6 additions & 0 deletions src/sqlitecloud/resultset.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ def get_name(self, col: int) -> Optional[str]:
return None
return self.colname[col]

def get_decltype(self, col: int) -> Optional[str]:
if col < 0 or col >= self.ncols or col >= len(self.decltype):
return None

return self.decltype[col]

def _convert(self, value: str, col: int) -> any:
if col < 0 or col >= len(self.decltype):
return value
Expand Down
12 changes: 12 additions & 0 deletions src/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import sqlite3

import pytest
from dotenv import load_dotenv
Expand Down Expand Up @@ -59,3 +60,14 @@ def get_sqlitecloud_dbapi2_connection():
yield connection

connection.close()


def get_sqlite3_connection():
# set isolation_level=None to enable autocommit
# and to be aligned with the behavior of SQLite Cloud
connection = sqlite3.connect(
os.path.join(os.path.dirname(__file__), "./assets/chinook.sqlite"),
isolation_level=None,
)
yield connection
connection.close()
18 changes: 0 additions & 18 deletions src/tests/integration/test_dbapi2.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,21 +246,3 @@ def test_row_factory(self, sqlitecloud_dbapi2_connection):
assert row["AlbumId"] == 1
assert row["Title"] == "For Those About To Rock We Salute You"
assert row["ArtistId"] == 1

def test_commit_without_any_transaction_does_not_raise_exception(
self, sqlitecloud_dbapi2_connection
):
connection = sqlitecloud_dbapi2_connection

connection.commit()

assert True

def test_rollback_without_any_transaction_does_not_raise_exception(
self, sqlitecloud_dbapi2_connection
):
connection = sqlitecloud_dbapi2_connection

connection.rollback()

assert True
Loading

0 comments on commit 1c4745a

Please sign in to comment.