diff --git a/duckdb_engine/__init__.py b/duckdb_engine/__init__.py
index 3d521e25..8ca96ffe 100644
--- a/duckdb_engine/__init__.py
+++ b/duckdb_engine/__init__.py
@@ -1,3 +1,4 @@
+import re
 import warnings
 from typing import (
     TYPE_CHECKING,
@@ -23,6 +24,7 @@
 from sqlalchemy.dialects.postgresql.base import PGDialect, PGInspector, PGTypeCompiler
 from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
 from sqlalchemy.engine.default import DefaultDialect
+from sqlalchemy.engine.reflection import cache
 from sqlalchemy.engine.url import URL
 from sqlalchemy.ext.compiler import compiles
 
@@ -31,6 +33,8 @@
 
 __version__ = "0.9.3"
 sqlalchemy_version = Version(sqlalchemy.__version__)
+duckdb_version: str = duckdb.__version__  # type: ignore[attr-defined]
+supports_attach: bool = Version(duckdb_version) >= Version("0.7.0")
 
 if TYPE_CHECKING:
     from sqlalchemy.base import Connection
@@ -265,6 +269,79 @@ def get_view_names(
 
         return [row[0] for row in rs]
 
+    @cache  # type: ignore[call-arg]
+    def get_schema_names(self, connection: "Connection", **kw: "Any"):  # type: ignore[no-untyped-def]
+        """
+        Return every schema as a ``database_name.schema_name`` string.
+
+        Each part is left unquoted unless it contains spaces or double quotes.
+        In that case, double quotes are escaped and the part is wrapped in double quotes.
+        SQLAlchemy definition of a schema includes database name for databases like SQL Server (Ex: databasename.dbo)
+        (see https://docs.sqlalchemy.org/en/20/dialects/mssql.html#multipart-schema-names)
+
+        Schemas whose name starts with ``pg_`` are excluded. DuckDB versions without
+        ATTACH support defer to the parent dialect's implementation.
+        """
+
+        if not supports_attach:
+            return super().get_schema_names(connection, **kw)
+
+        s = """
+            SELECT database_name, schema_name AS npspname
+            FROM duckdb_schemas()
+            WHERE schema_name NOT LIKE 'pg\\_%' ESCAPE '\\'
+            ORDER BY database_name, npspname
+            """
+        rs = connection.execute(text(s))
+
+        qs = self.identifier_preparer.quote_schema
+        return [f"{qs(db)}.{qs(schema)}" for (db, schema) in rs]
+
+    @cache  # type: ignore[call-arg]
+    def get_table_names(self, connection: "Connection", schema=None, **kw: "Any"):  # type: ignore[no-untyped-def]
+        """
+        Return the names of the tables in ``schema``, or in every schema when it is None.
+
+        ``schema`` is either a bare schema name or ``database_name.schema_name`` as
+        returned by ``get_schema_names``: a part may be wrapped in double quotes, with
+        embedded double quotes doubled. Schemas whose name starts with ``pg_`` are
+        always excluded.
+
+        Raises ValueError when a dotted ``schema`` is not exactly two parts.
+        """
+
+        if not supports_attach:
+            return super().get_table_names(connection, schema, **kw)
+
+        s = """
+            SELECT database_name, schema_name, table_name
+            FROM duckdb_tables()
+            WHERE schema_name NOT LIKE 'pg\\_%' ESCAPE '\\'
+            """
+        params = {}
+        if schema is not None:
+            params = {"schema_name": schema}
+            if "." in schema:
+                # Split db_name.schema_name into its two parts. A quoted part may
+                # contain dots and doubled double quotes; undo the quoting so the values
+                # can be compared with the raw names in duckdb_tables().
+                parts = [
+                    m.group(2) if m.group(1) is None else m.group(1).replace('""', '"')
+                    for m in re.finditer(r'"((?:[^"]|"")*)"|([^.]+)', schema)
+                ]
+                if len(parts) != 2:
+                    raise ValueError(
+                        f"Expected schema as db_name.schema_name, got {schema!r}"
+                    )
+                database_name, schema_name = parts
+                params = {"database_name": database_name, "schema_name": schema_name}
+                s += "AND database_name = :database_name\n"
+            s += "AND schema_name = :schema_name"
+
+        rs = connection.execute(text(s), params)
+
+        return [table for (_db, _schema, table) in rs]
+
     def get_indexes(
         self,
         connection: "Connection",
diff --git a/duckdb_engine/tests/test_basic.py b/duckdb_engine/tests/test_basic.py
index e9fe0d3f..500ed4dd 100644
--- a/duckdb_engine/tests/test_basic.py
+++ b/duckdb_engine/tests/test_basic.py
@@ -35,7 +35,7 @@
 from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.orm import Session, relationship, sessionmaker
 
-from .. import DBAPI, Dialect
+from .. import DBAPI, Dialect, supports_attach
 
 try:
     # sqlalchemy 2
@@ -166,6 +166,57 @@ def test_get_tables(inspector: Inspector) -> None:
     assert inspector.get_view_names() == []
 
 
+@mark.skipif(
+    supports_attach is False,
+    reason="ATTACH is not supported for DuckDB version < 0.7.0",
+)
+def test_get_schema_names(inspector: Inspector, session: Session) -> None:
+    # Using triple-quoted strings because of all the single and double quotes flying around...
+    cmds = [
+        """CREATE SCHEMA "quack quack" """,
+        """ATTACH ':memory:' AS "daffy duck" """,
+        """CREATE SCHEMA "daffy duck"."quack quack" """,
+        """CREATE TABLE "daffy duck"."quack quack"."t1" (i INTEGER, j INTEGER);""",
+        """CREATE TABLE "daffy duck"."quack quack"."t2" (i INTEGER, j INTEGER);""",
+        """CREATE SCHEMA "daffy duck"."you're "" despicable" """,
+        """CREATE TABLE "daffy duck"."you're "" despicable"."t3" (i INTEGER);""",
+    ]
+    for cmd in cmds:
+        session.execute(text(cmd))
+        session.commit()
+
+    # Deliberately excluding pg_catalog schema (to align with Postgres).
+    # The skipif above guarantees supports_attach, so only the ATTACH-era layout is checked.
+    names = inspector.get_schema_names()
+    assert names == [
+        '"daffy duck".information_schema',
+        '"daffy duck".main',
+        '"daffy duck"."quack quack"',
+        '"daffy duck"."you\'re "" despicable"',
+        "memory.information_schema",
+        "memory.main",
+        'memory."quack quack"',
+        "system.information_schema",
+        "system.main",
+        "temp.information_schema",
+        "temp.main",
+    ]
+
+    table_names = inspector.get_table_names(schema='"daffy duck"."quack quack"')
+    assert set(table_names) == {"t1", "t2"}
+
+    # A name returned by get_schema_names must round-trip through get_table_names,
+    # including one that contains an escaped double quote.
+    escaped_schema_tables = inspector.get_table_names(
+        schema='"daffy duck"."you\'re "" despicable"'
+    )
+    assert escaped_schema_tables == ["t3"]
+
+    table_names_all = inspector.get_table_names()
+    assert "t1" in table_names_all
+    assert "t2" in table_names_all
+
+
 def test_get_views(engine: Engine) -> None:
     con = engine.connect()
     views = engine.dialect.get_view_names(con)