diff --git a/duckdb_engine/__init__.py b/duckdb_engine/__init__.py index 8cea1a91..3516bdb4 100644 --- a/duckdb_engine/__init__.py +++ b/duckdb_engine/__init__.py @@ -15,6 +15,7 @@ import duckdb import sqlalchemy +from packaging.version import Version from sqlalchemy import pool, text from sqlalchemy import types as sqltypes from sqlalchemy import util @@ -29,6 +30,7 @@ from .datatypes import ISCHEMA_NAMES, register_extension_types __version__ = "0.9.2" +sqlalchemy_version = Version(sqlalchemy.__version__) if TYPE_CHECKING: from sqlalchemy.base import Connection @@ -39,7 +41,7 @@ class DBAPI: - paramstyle = duckdb.paramstyle + paramstyle = "numeric_dollar" if sqlalchemy_version >= Version("2.0.0") else "qmark" apilevel = duckdb.apilevel threadsafety = duckdb.threadsafety @@ -134,7 +136,11 @@ def execute( try: if statement.lower() == "commit": # this is largely for ipython-sql self.__c.commit() - elif statement.lower() in ("register", "register(?, ?)"): + elif statement.lower() in ( + "register", + "register(?, ?)", + "register($1, $2)", + ): assert parameters and len(parameters) == 2, parameters view_name, df = parameters self.__c.register(view_name, df) diff --git a/duckdb_engine/tests/test_basic.py b/duckdb_engine/tests/test_basic.py index b474de0f..e9fe0d3f 100644 --- a/duckdb_engine/tests/test_basic.py +++ b/duckdb_engine/tests/test_basic.py @@ -430,6 +430,8 @@ def test_params(engine: Engine) -> None: def test_361(engine: Engine) -> None: + importorskip("sqlalchemy", "2.0.0") + with engine.connect() as conn: conn.execute(text("create table test (dt date);")) conn.execute(text("insert into test values ('2022-01-01');")) @@ -440,10 +442,5 @@ def test_361(engine: Engine) -> None: part = "year" date_part = func.date_part(part, test.c.dt) - stmt = ( - select(date_part) - .select_from(test) - .group_by(date_part) - .compile(dialect=engine.dialect, compile_kwargs={"literal_binds": True}) - ) - conn.execute(stmt).fetchall() + stmt = select(date_part).select_from(test).group_by(date_part) + assert conn.execute(stmt).fetchall() == [(2022,)]