Skip to content

Commit

Permalink
fix(tests): expected errors, escaping
Browse files Browse the repository at this point in the history
  • Loading branch information
Daniele Briggi committed Aug 23, 2024
1 parent 9d793f8 commit 168b5bd
Show file tree
Hide file tree
Showing 13 changed files with 160 additions and 287 deletions.
1 change: 1 addition & 0 deletions src/sqlitecloud/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ def __init__(self) -> None:
self.value: Optional[int] = None
self.cstart: int = 0
self.extcode: int = None
self.offcode: int = None


class SQLiteCloudValue:
Expand Down
2 changes: 1 addition & 1 deletion src/sqlitecloud/dbapi2.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,7 @@ def _ensure_connection(self):
SQLiteCloudException: If the cursor is closed.
"""
if not self._connection:
raise SQLiteCloudOperationalError("The cursor is closed.")
raise SQLiteCloudProgrammingError("The cursor is closed.")

def _adapt_parameters(self, parameters: Union[Dict, Tuple]) -> Union[Dict, Tuple]:
if isinstance(parameters, dict):
Expand Down
51 changes: 31 additions & 20 deletions src/sqlitecloud/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
)
from sqlitecloud.exceptions import (
SQLiteCloudException,
raise_sqlitecloud_error_with_extended_code,
get_sqlitecloud_error_with_extended_code,
)
from sqlitecloud.resultset import (
SQLITECLOUD_RESULT_TYPE,
Expand Down Expand Up @@ -314,7 +314,9 @@ def upload_database(
command = f"UPLOAD DATABASE '{dbname}' {keyarg}{keyvalue}"

# execute command on server side
result = self._internal_run_command(connection, command)
result = self._internal_run_command(
connection, self._internal_serialize_command(command)
)
if not result.data[0]:
raise SQLiteCloudException(
"An error occurred while initializing the upload of the database."
Expand Down Expand Up @@ -350,7 +352,9 @@ def upload_database(
# Upload completed
break
except Exception as e:
self._internal_run_command(connection, "UPLOAD ABORT")
self._internal_run_command(
connection, self._internal_serialize_command("UPLOAD ABORT")
)
raise e

def download_database(
Expand All @@ -377,7 +381,10 @@ def download_database(
"""
exists_cmd = " IF EXISTS" if if_exists else ""
result = self._internal_run_command(
connection, f"DOWNLOAD DATABASE {dbname}{exists_cmd};"
connection,
self._internal_serialize_command(
f"DOWNLOAD DATABASE {dbname}{exists_cmd};"
),
)

if result.nrows == 0:
Expand All @@ -394,7 +401,9 @@ def download_database(

try:
while progress_size < db_size:
result = self._internal_run_command(connection, "DOWNLOAD STEP")
result = self._internal_run_command(
connection, self._internal_serialize_command("DOWNLOAD STEP")
)

# res is BLOB, decode it
data = result.data[0]
Expand All @@ -408,7 +417,9 @@ def download_database(
if data_len == 0:
break
except Exception as e:
self._internal_run_command(connection, "DOWNLOAD ABORT")
self._internal_run_command(
connection, self._internal_serialize_command("DOWNLOAD ABORT")
)
raise e

def _internal_config_apply(
Expand Down Expand Up @@ -488,12 +499,6 @@ def _internal_socket_write(
command (bytes): The command to send.
main_socket (bool): If True, write to the main socket, otherwise write to the pubsub socket.
"""
# try:
# if "ATTACH DATABASE" in command.decode() or '"test_schema".table_info' in command.decode():
# pdb.set_trace()
# except:
# pass

# write buffer
if len(command) == 0:
return
Expand Down Expand Up @@ -594,30 +599,36 @@ def _internal_parse_number(
sqlitecloud_number = SQLiteCloudNumber()
sqlitecloud_number.value = 0
extvalue = 0
isext = False
offcode = 0
isext = 0
blen = len(buffer)

# from 1 to skip the first command type character
for i in range(index, blen):
c = chr(buffer[i])

# check for optional extended error code (ERRCODE:EXTERRCODE)
# check for optional extended error code (ERRCODE:EXTERRCODE:OFFCODE)
if c == ":":
isext = True
isext += 1
continue

# check for end of value
if c == " ":
sqlitecloud_number.cstart = i + 1
sqlitecloud_number.extcode = extvalue
sqlitecloud_number.offcode = offcode
return sqlitecloud_number

val = int(c) if c.isdigit() else 0

# compute numeric value
if isext:
if isext == 1:
# XERRCODE
extvalue = (extvalue * 10) + val
elif isext == 2:
# OFFCODE
offcode = (offcode * 10) + val
else:
# generic value or ERRCODE
sqlitecloud_number.value = (sqlitecloud_number.value * 10) + val

sqlitecloud_number.value = 0
Expand Down Expand Up @@ -706,7 +717,7 @@ def _internal_parse_buffer(
return SQLiteCloudResult(tag, clone)

elif cmd == SQLITECLOUD_CMD.ERROR.value:
# -LEN ERRCODE:EXTCODE ERRMSG
# -LEN ERRCODE:EXTCODE:OFFCODE ERRMSG
sqlite_number = self._internal_parse_number(buffer)
len_ = sqlite_number.value
cstart = sqlite_number.cstart
Expand All @@ -721,9 +732,9 @@ def _internal_parse_buffer(
len_ -= cstart2
errmsg = clone[cstart2:]

raise raise_sqlitecloud_error_with_extended_code(
raise get_sqlitecloud_error_with_extended_code(
errmsg.decode(), errcode, xerrcode
)
)(errmsg.decode(), errcode, xerrcode)

elif cmd in [SQLITECLOUD_CMD.ROWSET.value, SQLITECLOUD_CMD.ROWSET_CHUNK.value]:
# CMD_ROWSET: *LEN 0:VERSION ROWS COLS DATA
Expand Down
15 changes: 8 additions & 7 deletions src/sqlitecloud/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,12 @@ def __init__(self, message: str, code: int = -1, xerrcode: int = 0) -> None:
super().__init__(message, code, xerrcode)


def raise_sqlitecloud_error_with_extended_code(
def get_sqlitecloud_error_with_extended_code(
message: str, code: int, xerrcode: int
) -> None:
# Define base error codes and their corresponding exceptions
"""Mapping of sqlite error codes: https://www.sqlite.org/rescode.html"""

# define base error codes and their corresponding exceptions
base_error_mapping = {
1: SQLiteCloudOperationalError, # SQLITE_ERROR
2: SQLiteCloudInternalError, # SQLITE_INTERNAL
Expand Down Expand Up @@ -113,7 +115,7 @@ def raise_sqlitecloud_error_with_extended_code(
101: SQLiteCloudWarning, # SQLITE_DONE (not an error)
}

# Define extended error codes and their corresponding exceptions
# define extended error codes and their corresponding exceptions
extended_error_mapping = {
257: SQLiteCloudOperationalError, # SQLITE_ERROR_MISSING_COLLSEQ
279: SQLiteCloudOperationalError, # SQLITE_AUTH_USER
Expand Down Expand Up @@ -176,11 +178,10 @@ def raise_sqlitecloud_error_with_extended_code(
283: SQLiteCloudWarning, # SQLITE_NOTICE_RECOVER_ROLLBACK
284: SQLiteCloudWarning, # SQLITE_WARNING_AUTOINDEX
}
# Combine base and extended mappings

error_mapping = {**base_error_mapping, **extended_error_mapping}

# Retrieve the corresponding exception based on the error code
# retrieve the corresponding exception based on the error code
exception = error_mapping.get(xerrcode, error_mapping.get(code, SQLiteCloudError))

# Raise the corresponding exception
raise exception(message, code, xerrcode)
return exception
17 changes: 9 additions & 8 deletions src/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,7 @@ def sqlitecloud_dbapi2_connection():

yield connection

try:
next(connection_generator)
except StopIteration:
pass
close_generator(connection_generator)


def get_sqlitecloud_dbapi2_connection(detect_types: int = 0):
Expand Down Expand Up @@ -103,10 +100,7 @@ def sqlite3_connection():

yield connection

try:
next(connection_generator)
except StopIteration:
pass
close_generator(connection_generator)


def get_sqlite3_connection(detect_types: int = 0):
Expand All @@ -120,3 +114,10 @@ def get_sqlite3_connection(detect_types: int = 0):
yield connection

connection.close()


def close_generator(generator):
try:
next(generator)
except StopIteration:
pass
28 changes: 10 additions & 18 deletions src/tests/integration/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
SQLiteCloudAccount,
SQLiteCloudConnect,
)
from sqlitecloud.exceptions import SQLiteCloudException
from sqlitecloud.exceptions import (
SQLiteCloudError,
SQLiteCloudException,
SQLiteCloudOperationalError,
)
from sqlitecloud.resultset import SQLITECLOUD_RESULT_TYPE


Expand Down Expand Up @@ -56,7 +60,7 @@ def test_connection_without_credentials_and_apikey(self):

client = SQLiteCloudClient(cloud_account=account)

with pytest.raises(SQLiteCloudException):
with pytest.raises(SQLiteCloudError):
client.open_connection()

def test_connect_with_string(self):
Expand Down Expand Up @@ -124,7 +128,7 @@ def test_select(self, sqlitecloud_connection):

def test_column_not_found(self, sqlitecloud_connection):
connection, client = sqlitecloud_connection
with pytest.raises(SQLiteCloudException) as e:
with pytest.raises(SQLiteCloudOperationalError) as e:
client.exec_query("SELECT not_a_column FROM albums", connection)

assert e.value.errcode == 1
Expand Down Expand Up @@ -267,7 +271,7 @@ def test_blob_zero_length(self, sqlitecloud_connection):
def test_error(self, sqlitecloud_connection):
connection, client = sqlitecloud_connection

with pytest.raises(SQLiteCloudException) as e:
with pytest.raises(SQLiteCloudError) as e:
client.exec_query("TEST ERROR", connection)

assert e.value.errcode == 66666
Expand All @@ -276,7 +280,7 @@ def test_error(self, sqlitecloud_connection):
def test_ext_error(self, sqlitecloud_connection):
connection, client = sqlitecloud_connection

with pytest.raises(SQLiteCloudException) as e:
with pytest.raises(SQLiteCloudError) as e:
client.exec_query("TEST EXTERROR", connection)

assert e.value.errcode == 66666
Expand Down Expand Up @@ -354,7 +358,7 @@ def test_max_rowset_option_to_fail_when_rowset_is_bigger(self):

connection = client.open_connection()

with pytest.raises(SQLiteCloudException) as e:
with pytest.raises(SQLiteCloudError) as e:
client.exec_query("SELECT * FROM albums", connection)

client.disconnect(connection)
Expand Down Expand Up @@ -737,18 +741,6 @@ def test_rowset_chunk_compressed(self, sqlitecloud_connection):
assert 147 == len(rowset.data)
assert "key" == rowset.get_name(0)

def test_exec_statement_with_named_placeholder(self, sqlitecloud_connection):
connection, client = sqlitecloud_connection

result = client.exec_statement(
"SELECT * FROM albums WHERE AlbumId = :id and Title = :title",
{"id": 1, "title": "For Those About To Rock We Salute You"},
connection,
)

assert result.nrows == 1
assert result.get_value(0, 0) == 1

def test_exec_statement_with_qmarks(self, sqlitecloud_connection):
connection, client = sqlitecloud_connection

Expand Down
6 changes: 3 additions & 3 deletions src/tests/integration/test_dbapi2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import sqlitecloud
from sqlitecloud.datatypes import SQLITECLOUD_INTERNAL_ERRCODE, SQLiteCloudAccount
from sqlitecloud.exceptions import SQLiteCloudException
from sqlitecloud.exceptions import SQLiteCloudError, SQLiteCloudException


class TestDBAPI2:
Expand Down Expand Up @@ -74,7 +74,7 @@ def test_connection_execute(self, sqlitecloud_dbapi2_connection):
def test_column_not_found(self, sqlitecloud_dbapi2_connection):
connection = sqlitecloud_dbapi2_connection

with pytest.raises(SQLiteCloudException) as e:
with pytest.raises(SQLiteCloudError) as e:
connection.execute("SELECT not_a_column FROM albums")

assert e.value.errcode == 1
Expand Down Expand Up @@ -123,7 +123,7 @@ def test_integer(self, sqlitecloud_dbapi2_connection):
def test_error(self, sqlitecloud_dbapi2_connection):
connection = sqlitecloud_dbapi2_connection

with pytest.raises(SQLiteCloudException) as e:
with pytest.raises(SQLiteCloudError) as e:
connection.execute("TEST ERROR")

assert e.value.errcode == 66666
Expand Down
6 changes: 2 additions & 4 deletions src/tests/integration/test_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
import pytest

from sqlitecloud import download
from sqlitecloud.datatypes import SQLITECLOUD_ERRCODE
from sqlitecloud.exceptions import SQLiteCloudException
from sqlitecloud.exceptions import SQLiteCloudError


class TestDownload:
Expand All @@ -26,8 +25,7 @@ def test_download_missing_database(self, sqlitecloud_connection):

temp_file = tempfile.mkstemp(prefix="missing")[1]

with pytest.raises(SQLiteCloudException) as e:
with pytest.raises(SQLiteCloudError) as e:
download.download_db(connection, "missing.sqlite", temp_file)

assert e.value.errcode == SQLITECLOUD_ERRCODE.COMMAND.value
assert e.value.errmsg == "Database missing.sqlite does not exist."
3 changes: 2 additions & 1 deletion src/tests/integration/test_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ def test_insert_from_dataframe(self, sqlitecloud_dbapi2_connection):
}
)

conn.executemany("DROP TABLE IF EXISTS ?", [("PRICES",), ("TICKER_MAPPING",)])
for table in ["PRICES", "TICKER_MAPPING"]:
conn.execute(f"DROP TABLE IF EXISTS {table}")

# arg if_exists="replace" raises the error
dfprices.to_sql("PRICES", conn, index=False)
Expand Down
4 changes: 2 additions & 2 deletions src/tests/integration/test_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

from sqlitecloud.datatypes import SQLITECLOUD_ERRCODE, SQLITECLOUD_PUBSUB_SUBJECT
from sqlitecloud.exceptions import SQLiteCloudException
from sqlitecloud.exceptions import SQLiteCloudError
from sqlitecloud.pubsub import SQLiteCloudPubSub
from sqlitecloud.resultset import SQLITECLOUD_RESULT_TYPE, SQLiteCloudResultSet

Expand Down Expand Up @@ -78,7 +78,7 @@ def test_create_channel_to_fail_if_exists(self, sqlitecloud_connection):

pubsub.create_channel(connection, channel_name, if_not_exists=True)

with pytest.raises(SQLiteCloudException) as e:
with pytest.raises(SQLiteCloudError) as e:
pubsub.create_channel(connection, channel_name, if_not_exists=False)

assert (
Expand Down
Loading

0 comments on commit 168b5bd

Please sign in to comment.