Skip to content

Commit

Permalink
feat(row-object): support to Row object for row_factory
Browse files Browse the repository at this point in the history
  • Loading branch information
Daniele Briggi committed Aug 16, 2024
1 parent 368d050 commit 9be6be2
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 14 deletions.
2 changes: 2 additions & 0 deletions src/sqlitecloud/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
PARSE_DECLTYPES,
Connection,
Cursor,
Row,
adapters,
connect,
converters,
Expand All @@ -25,6 +26,7 @@
"PARSE_COLNAMES",
"adapters",
"converters",
"Row",
]

VERSION = "0.0.79"
61 changes: 58 additions & 3 deletions src/sqlitecloud/dbapi2.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,15 +341,14 @@ class Cursor(Iterator[Any]):

arraysize: int = 1

row_factory: Optional[Callable[["Cursor", Tuple], object]] = None

def __init__(self, connection: Connection) -> None:
self._driver = Driver()
self.row_factory = None
self._connection = connection
self._iter_row: int = 0
self._resultset: SQLiteCloudResult = None

self.row_factory: Optional[Callable[["Cursor", Tuple], object]] = None

@property
def connection(self) -> Connection:
"""
Expand Down Expand Up @@ -577,6 +576,9 @@ def _call_row_factory(self, row: Tuple) -> object:
if self.row_factory is None:
return row

if self.row_factory is Row:
return Row(row, [col[0] for col in self.description])

return self.row_factory(self, row)

def _is_result_rowset(self) -> bool:
Expand Down Expand Up @@ -697,6 +699,59 @@ def __next__(self) -> Optional[Tuple[Any]]:
raise StopIteration


class Row:
def __init__(self, data: Tuple[Any], column_names: List[str]):
"""
Initialize the Row object with data and column names.
Args:
data (Tuple[Any]): A tuple containing the row data.
column_names (List[str]): A list of column names corresponding to the data.
"""
self._data = data
self._column_names = column_names
self._column_map = {name.lower(): idx for idx, name in enumerate(column_names)}

def keys(self) -> List[str]:
"""Return the column names."""
return self._column_names

def __getitem__(self, key):
"""Support indexing by both column name and index."""
if isinstance(key, int):
return self._data[key]
elif isinstance(key, str):
return self._data[self._column_map[key.lower()]]
else:
raise TypeError("Invalid key type. Must be int or str.")

def __len__(self) -> int:
return len(self._data)

def __iter__(self) -> Iterator[Any]:
return iter(self._data)

def __repr__(self) -> str:
return "\n".join(
f"{name}: {self._data[idx]}" for idx, name in enumerate(self._column_names)
)

def __hash__(self) -> int:
return hash((self._data, tuple(self._column_map)))

def __eq__(self, other) -> bool:
"""Check if both have the same data and column names."""
if not isinstance(other, Row):
return NotImplemented

return self._data == other._data and self._column_map == other._column_map

def __ne__(self, other):
if not isinstance(other, Row):
return NotImplemented
return not self.__eq__(other)


class MissingDecltypeException(Exception):
def __init__(self, message: str) -> None:
super().__init__(message)
Expand Down
13 changes: 13 additions & 0 deletions src/tests/integration/test_dbapi2.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,3 +246,16 @@ def test_row_factory(self, sqlitecloud_dbapi2_connection):
assert row["AlbumId"] == 1
assert row["Title"] == "For Those About To Rock We Salute You"
assert row["ArtistId"] == 1

def test_row_object_for_factory_string_representation(
self, sqlitecloud_dbapi2_connection
):
connection = sqlitecloud_dbapi2_connection

connection.row_factory = sqlitecloud.Row

cursor = connection.execute('SELECT "foo" as Bar, "john" Doe')

row = cursor.fetchone()

assert str(row) == "Bar: foo\nDoe: john"
71 changes: 60 additions & 11 deletions src/tests/integration/test_sqlite3_parity.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,27 +131,76 @@ def test_close_cursor_raises_exception(
with pytest.raises(sqlite3.ProgrammingError) as e:
sqlite3_cursor.fetchall()

def test_row_factory(self, sqlitecloud_dbapi2_connection, sqlite3_connection):
sqlitecloud_connection = sqlitecloud_dbapi2_connection
@pytest.mark.parametrize(
"connection", ["sqlitecloud_dbapi2_connection", "sqlite3_connection"]
)
def test_row_factory(self, connection, request):
connection = request.getfixturevalue(connection)

def simple_factory(cursor, row):
return {
description[0]: row[i]
for i, description in enumerate(cursor.description)
}

sqlitecloud_connection.row_factory = simple_factory
sqlite3_connection.row_factory = simple_factory
connection.row_factory = simple_factory

select_query = "SELECT * FROM albums WHERE AlbumId = 1"
sqlitecloud_cursor = sqlitecloud_connection.execute(select_query)
sqlite3_cursor = sqlite3_connection.execute(select_query)
select_query = "SELECT AlbumId, Title, ArtistId FROM albums WHERE AlbumId = 1"
cursor = connection.execute(select_query)

sqlitecloud_results = sqlitecloud_cursor.fetchall()
sqlite3_results = sqlite3_cursor.fetchall()
results = cursor.fetchall()

assert sqlitecloud_results == sqlite3_results
assert sqlitecloud_results[0]["Title"] == sqlite3_results[0]["Title"]
assert results[0]["AlbumId"] == 1
assert results[0]["Title"] == "For Those About To Rock We Salute You"
assert results[0]["ArtistId"] == 1
assert connection.row_factory == cursor.row_factory

@pytest.mark.parametrize(
"connection", ["sqlitecloud_dbapi2_connection", "sqlite3_connection"]
)
def test_cursor_row_factory_as_instance_variable(self, connection, request):
connection = request.getfixturevalue(connection)

cursor = connection.execute("SELECT 1")
cursor.row_factory = lambda c, r: list(r)

cursor2 = connection.execute("SELECT 1")

assert cursor.row_factory != cursor2.row_factory

@pytest.mark.parametrize(
"connection, module",
[
("sqlitecloud_dbapi2_connection", sqlitecloud),
("sqlite3_connection", sqlite3),
],
)
def test_row_factory_with_row_object(self, connection, module, request):
connection = request.getfixturevalue(connection)

connection.row_factory = module.Row

select_query = "SELECT AlbumId, Title, ArtistId FROM albums WHERE AlbumId = 1"
cursor = connection.execute(select_query)

row = cursor.fetchone()

assert row["AlbumId"] == 1
assert row["Title"] == "For Those About To Rock We Salute You"
assert row[1] == row["Title"]
assert row["Title"] == row["title"]
assert row.keys() == ["AlbumId", "Title", "ArtistId"]
assert len(row) == 3
assert next(iter(row)) == 1 # AlbumId
assert not row != row
assert row == row

cursor = connection.execute(
"SELECT AlbumId, Title, ArtistId FROM albums WHERE AlbumId = 2"
)
other_row = cursor.fetchone()

assert row != other_row

@pytest.mark.parametrize(
"connection",
Expand Down

0 comments on commit 9be6be2

Please sign in to comment.