Skip to content

Commit

Permalink
feat(param-style): workaround to support param-style in statements
Browse files Browse the repository at this point in the history
  • Loading branch information
Daniele Briggi committed Aug 20, 2024
1 parent 547baf8 commit f18126e
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 42 deletions.
28 changes: 27 additions & 1 deletion src/sqlitecloud/dbapi2.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,8 @@ def rowcount(self) -> int:
"""
The number of rows that the last .execute*() returned for DQL statements like SELECT or
the number rows affected by DML statements like UPDATE, INSERT and DELETE.
For the executemany() it returns the number of changes only for the last operation.
"""
if self._is_result_rowset():
return self._resultset.nrows
Expand Down Expand Up @@ -490,7 +492,9 @@ def execute(

parameters = self._adapt_parameters(parameters)

# TODO: convert parameters from :name to `?` style
if isinstance(parameters, dict):
parameters = self._named_to_question_mark_parameters(sql, parameters)

result = self._driver.execute_statement(
sql, parameters, self.connection.sqlcloud_connection
)
Expand Down Expand Up @@ -529,6 +533,9 @@ def executemany(
commands = ""
params = []
for parameters in seq_of_parameters:
if isinstance(parameters, dict):
parameters = self._named_to_question_mark_parameters(sql, parameters)

params += list(parameters)

if not sql.endswith(";"):
Expand Down Expand Up @@ -726,6 +733,25 @@ def _apply_text_factory(self, value: Any) -> Any:

return value

def _named_to_question_mark_parameters(
self, sql: str, params: Dict[str, Any]
) -> Tuple[Any]:
"""
Convert named placeholders parameters from a dictionary to a list of
parameters for question mark style.
SCSP protocol does not support named placeholders yet.
"""
pattern = r":(\w+)"
matches = re.findall(pattern, sql)

params_list = ()
for match in matches:
if match in params:
params_list += (params[match],)

return params_list

def _get_value(self, row: int, col: int) -> Optional[Any]:
if not self._is_result_rowset():
return None
Expand Down
46 changes: 17 additions & 29 deletions src/tests/integration/test_dbapi2.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,21 @@ def test_execute_with_named_placeholder(self, sqlitecloud_dbapi2_connection):
assert cursor.rowcount == 1
assert cursor.fetchone() == (1, "For Those About To Rock We Salute You", 1)

def test_execute_with_named_placeholder_and_a_fake_one_which_is_not_given(
self, sqlitecloud_dbapi2_connection
):
""" "Expect the converter from name to qmark placeholder to not be fooled by the
fake name with the colon in it."""
connection = sqlitecloud_dbapi2_connection

cursor = connection.execute(
"SELECT * FROM albums WHERE AlbumId = :id and Title != 'special:name'",
{"id": 1},
)

assert cursor.rowcount == 1
assert cursor.fetchone() == (1, "For Those About To Rock We Salute You", 1)

def test_execute_with_qmarks(self, sqlitecloud_dbapi2_connection):
connection = sqlitecloud_dbapi2_connection

Expand Down Expand Up @@ -408,7 +423,7 @@ def test_last_rowid_and_rowcount_with_executemany_deletes(
new_name1 = "Jazz" + str(uuid.uuid4())
new_name2 = "Jazz" + str(uuid.uuid4())

cursor_select = connection.executemany(
cursor_insert = connection.executemany(
"INSERT INTO genres (Name) VALUES (?)",
[(new_name1,), (new_name2,)],
)
Expand All @@ -418,32 +433,5 @@ def test_last_rowid_and_rowcount_with_executemany_deletes(
)

assert cursor.fetchone() is None
assert cursor.lastrowid == cursor_select.lastrowid
assert cursor.lastrowid == cursor_insert.lastrowid
assert cursor.rowcount == 1

def test_connection_total_changes(self, sqlitecloud_dbapi2_connection):
connection = sqlitecloud_dbapi2_connection

new_name1 = "Jazz" + str(uuid.uuid4())
new_name2 = "Jazz" + str(uuid.uuid4())
new_name3 = "Jazz" + str(uuid.uuid4())

connection.executemany(
"INSERT INTO genres (Name) VALUES (?)",
[(new_name1,), (new_name2,)],
)
assert connection.total_changes == 2

connection.execute("SELECT * FROM genres")
assert connection.total_changes == 2

connection.execute(
"UPDATE genres SET Name = ? WHERE Name = ?", (new_name3, new_name1)
)
assert connection.total_changes == 3

connection.execute(
"DELETE FROM genres WHERE Name in (?, ?, ?)",
(new_name1, new_name2, new_name3),
)
assert connection.total_changes == 5
135 changes: 123 additions & 12 deletions src/tests/integration/test_sqlite3_parity.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import random
import sqlite3
import string
import sys
import time
import uuid
from datetime import date, datetime

import pytest
Expand Down Expand Up @@ -69,6 +71,62 @@ def test_create_table_and_insert_many(

assert sqlitecloud_results == sqlite3_results

@pytest.mark.parametrize(
"connection", ["sqlitecloud_dbapi2_connection", "sqlite3_connection"]
)
def test_executemany_with_a_iterator(self, connection, request):
connection = request.getfixturevalue(connection)

class IterChars:
def __init__(self):
self.count = ord("a")

def __iter__(self):
return self

def __next__(self):
if self.count > ord("z"):
raise StopIteration
self.count += 1
return (chr(self.count - 1),)

try:
connection.execute("DROP TABLE IF EXISTS characters")
cursor = connection.execute("CREATE TABLE IF NOT EXISTS characters(c)")

theIter = IterChars()
cursor.executemany("INSERT INTO characters(c) VALUES (?)", theIter)

cursor.execute("SELECT c FROM characters")

results = cursor.fetchall()
assert len(results) == 26
finally:
connection.execute("DROP TABLE IF EXISTS characters")

@pytest.mark.parametrize(
"connection", ["sqlitecloud_dbapi2_connection", "sqlite3_connection"]
)
def test_executemany_with_yield_generator(self, connection, request):
connection = request.getfixturevalue(connection)

def char_generator():
for c in string.ascii_lowercase:
yield (c,)

try:
connection.execute("DROP TABLE IF EXISTS characters")
cursor = connection.execute("CREATE TABLE IF NOT EXISTS characters(c)")

cursor.executemany("INSERT INTO characters(c) VALUES (?)", char_generator())

cursor.execute("SELECT c FROM characters")

results = cursor.fetchall()
assert len(results) == 26
finally:
connection.execute("DROP TABLE IF EXISTS characters")

def test_execute_with_question_mark_style(
self, sqlitecloud_dbapi2_connection, sqlite3_connection
):
Expand All @@ -84,20 +142,37 @@ def test_execute_with_question_mark_style(

assert sqlitecloud_results == sqlite3_results

def test_execute_with_named_param_style(
self, sqlitecloud_dbapi2_connection, sqlite3_connection
):
sqlitecloud_connection = sqlitecloud_dbapi2_connection
@pytest.mark.parametrize(
"connection", ["sqlitecloud_dbapi2_connection", "sqlite3_connection"]
)
def test_execute_with_named_param_style(self, connection, request):
connection = request.getfixturevalue(connection)

select_query = "SELECT * FROM albums WHERE AlbumId = :id"
params = {"id": 1}
sqlitecloud_cursor = sqlitecloud_connection.execute(select_query, params)
sqlite3_cursor = sqlite3_connection.execute(select_query, params)
select_query = "SELECT * FROM albums WHERE AlbumId = :id and Title = :title and AlbumId = :id"
params = {"id": 1, "title": "For Those About To Rock We Salute You"}

sqlitecloud_results = sqlitecloud_cursor.fetchall()
sqlite3_results = sqlite3_cursor.fetchall()
cursor = connection.execute(select_query, params)

assert sqlitecloud_results == sqlite3_results
results = cursor.fetchall()

assert len(results) == 1
assert results[0][0] == 1

@pytest.mark.parametrize(
"connection", ["sqlitecloud_dbapi2_connection", "sqlite3_connection"]
)
def test_executemany_with_named_param_style(self, connection, request):
connection = request.getfixturevalue(connection)

select_query = "INSERT INTO customers (FirstName, Email, LastName) VALUES (:name, :email, :name)"
params = [
{"name": "pippo", "email": "pippo@disney.com"},
{"name": "pluto", "email": "pluto@disney.com"},
]

connection.executemany(select_query, params)

assert connection.total_changes == 2

@pytest.mark.skip(
reason="Rowcount does not contain the number of inserted rows yet"
Expand Down Expand Up @@ -1151,11 +1226,47 @@ def test_transaction_context_manager_on_failure(self, connection, request):
"INSERT INTO albums (Title, ArtistId) VALUES ('Test Album 1', 1)"
)
id1 = cursor.lastrowid
connection.execute("INVALID COMMAND")
connection.execute("insert into pippodd (p) values (1)")
except Exception:
assert True

cursor = connection.execute("SELECT * FROM albums WHERE AlbumId = ?", (id1,))
result = cursor.fetchone()

assert result is None

@pytest.mark.parametrize(
"connection",
[
"sqlitecloud_dbapi2_connection",
"sqlite3_connection",
],
)
def test_connection_total_changes(self, connection, request):
connection = request.getfixturevalue(connection)

new_name1 = "Jazz" + str(uuid.uuid4())
new_name2 = "Jazz" + str(uuid.uuid4())
new_name3 = "Jazz" + str(uuid.uuid4())

connection.executemany(
"INSERT INTO genres (Name) VALUES (?)",
[(new_name1,), (new_name2,)],
)
assert connection.total_changes == 2

connection.execute(
"SELECT * FROM genres WHERE Name IN (?, ?)", (new_name1, new_name2)
)
assert connection.total_changes == 2

connection.execute(
"UPDATE genres SET Name = ? WHERE Name = ?", (new_name3, new_name1)
)
assert connection.total_changes == 3

connection.execute(
"DELETE FROM genres WHERE Name in (?, ?, ?)",
(new_name1, new_name2, new_name3),
)
assert connection.total_changes == 5

0 comments on commit f18126e

Please sign in to comment.