diff --git a/src/connection.c b/src/connection.c index d1ee3a1..b651581 100644 --- a/src/connection.c +++ b/src/connection.c @@ -1410,7 +1410,7 @@ static PyObject* pysqlite_connection_set_busy_handler(pysqlite_Connection* self, static PyObject* pysqlite_connection_set_busy_timeout(pysqlite_Connection* self, PyObject* args, PyObject* kwargs) { - int busy_timeout; + double busy_timeout; static char *kwlist[] = { "timeout", NULL }; @@ -1418,13 +1418,13 @@ static PyObject* pysqlite_connection_set_busy_timeout(pysqlite_Connection* self, return NULL; } - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "i:set_busy_timeout", + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "d:set_busy_timeout", kwlist, &busy_timeout)) { return NULL; } int rc; - rc = sqlite3_busy_timeout(self->db, busy_timeout * 1000); + rc = sqlite3_busy_timeout(self->db, (int)busy_timeout * 1000); if (rc != SQLITE_OK) { PyErr_SetString(pysqlite_OperationalError, "Error setting busy timeout"); return NULL; diff --git a/tests/hooks.py b/tests/hooks.py index 9b9bd89..d0a1ede 100644 --- a/tests/hooks.py +++ b/tests/hooks.py @@ -23,6 +23,7 @@ import os import unittest + from sqlcipher3 import dbapi2 as sqlite @@ -277,12 +278,48 @@ def trace(statement): self.assertEqual(traced_statements, queries) +class TestBusyHandlerTimeout(unittest.TestCase): + def test_busy_handler(self): + accum = [] + def custom_handler(n): + accum.append(n) + return 0 if n == 3 else 1 + + self.addCleanup(os.unlink, 'busy.db') + conn1 = sqlite.connect('busy.db') + conn2 = sqlite.connect('busy.db') + conn2.set_busy_handler(custom_handler) + + conn1.execute('begin exclusive') + with self.assertRaises(sqlite.OperationalError): + conn2.execute('create table test(id)') + self.assertEqual(accum, [0, 1, 2, 3]) + accum.clear() + + conn2.set_busy_handler(None) + with self.assertRaises(sqlite.OperationalError): + conn2.execute('create table test(id)') + self.assertEqual(accum, []) + + conn2.set_busy_handler(custom_handler) + with self.assertRaises(sqlite.OperationalError): + conn2.execute('create table test(id)') + self.assertEqual(accum, [0, 1, 2, 3]) + accum.clear() + + conn2.set_busy_timeout(0.01) # Clears busy handler. + with self.assertRaises(sqlite.OperationalError): + conn2.execute('create table test(id)') + self.assertEqual(accum, []) + + def suite(): loader = unittest.TestLoader() tests = [loader.loadTestsFromTestCase(t) for t in ( CollationTests, ProgressTests, - TraceCallbackTests)] + TraceCallbackTests, + TestBusyHandlerTimeout)] return unittest.TestSuite(tests) def test():