Skip to content

Commit

Permalink
Merge pull request #928 from Mause/cursor-wrapper
Browse files Browse the repository at this point in the history
fix: allow connections to be properly closed
  • Loading branch information
Mause authored Apr 18, 2024
2 parents 6dbeb82 + 4e0aeb4 commit e10374a
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 50 deletions.
70 changes: 31 additions & 39 deletions duckdb_engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
Set,
Tuple,
Type,
cast,
)

import duckdb
Expand Down Expand Up @@ -88,49 +87,26 @@ def __init__(self, c: duckdb.DuckDBPyConnection) -> None:
self.__c = c
self.notices = list()

def cursor(self) -> "Connection":
return self

def fetchmany(self, size: Optional[int] = None) -> List:
if hasattr(self.__c, "fetchmany"):
# fetchmany was only added in 0.5.0
if size is None:
return self.__c.fetchmany()
else:
return self.__c.fetchmany(size)

try:
return cast(list, self.__c.fetch_df_chunk().values.tolist())
except RuntimeError as e:
if e.args[0].startswith(
"Invalid Input Error: Attempting to fetch from an unsuccessful or closed streaming query result"
):
return []
else:
raise e

@property
def c(self) -> duckdb.DuckDBPyConnection:
warnings.warn(
"Directly accessing the internal connection object is deprecated (please go via the __getattr__ impl)",
DeprecationWarning,
)
return self.__c
def cursor(self) -> "CursorWrapper":
return CursorWrapper(self.__c, self)

def __getattr__(self, name: str) -> Any:
return getattr(self.__c, name)

@property
def connection(self) -> "Connection":
return self

def close(self) -> None:
# duckdb doesn't support 'soft closes'
pass
self.__c.close()
self.closed = True

@property
def rowcount(self) -> int:
return -1

class CursorWrapper:
__c: duckdb.DuckDBPyConnection
__connection_wrapper: "ConnectionWrapper"

def __init__(
self, c: duckdb.DuckDBPyConnection, connection_wrapper: "ConnectionWrapper"
) -> None:
self.__c = c
self.__connection_wrapper = connection_wrapper

def executemany(
self,
Expand Down Expand Up @@ -172,6 +148,22 @@ def execute(
else:
raise e

@property
def connection(self) -> "Connection":
return self.__connection_wrapper

def close(self) -> None:
pass # closing cursors is not supported in duckdb

def __getattr__(self, name: str) -> Any:
return getattr(self.__c, name)

def fetchmany(self, size: Optional[int] = None) -> List:
if size is None:
return self.__c.fetchmany()
else:
return self.__c.fetchmany(size)


class DuckDBEngineWarning(Warning):
pass
Expand Down Expand Up @@ -319,7 +311,7 @@ def do_rollback(self, connection: "Connection") -> None:
raise e

def do_begin(self, connection: "Connection") -> None:
connection.execute("begin")
connection.begin()

def get_view_names(
self,
Expand Down
12 changes: 12 additions & 0 deletions duckdb_engine/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,12 @@ def test_fetch_df_chunks() -> None:
duckdb.connect(":memory:").execute("select 1").fetch_df_chunk(1)


def test_fetchmany(engine: Engine) -> None:
with engine.connect() as conn:
res = conn.execute(text("select 1"))
assert res.fetchmany(1) == [(1,)]


def test_description() -> None:
import duckdb

Expand Down Expand Up @@ -594,3 +600,9 @@ def test_361(engine: Engine) -> None:

stmt = select(date_part).select_from(test).group_by(date_part)
assert conn.execute(stmt).fetchall() == [(2022,)]


def test_close(engine: Engine) -> None:
with engine.connect() as conn:
res = conn.execute(text("select 1"))
res.close()
11 changes: 0 additions & 11 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit e10374a

Please sign in to comment.