Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Support fetching multiple databases and schemas and their associated tables #835

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions duckdb_engine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
import warnings
from typing import (
TYPE_CHECKING,
Expand All @@ -23,6 +24,7 @@
from sqlalchemy.dialects.postgresql.base import PGDialect, PGInspector, PGTypeCompiler
from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
from sqlalchemy.engine.default import DefaultDialect
from sqlalchemy.engine.reflection import cache
from sqlalchemy.engine.url import URL
from sqlalchemy.ext.compiler import compiles

Expand All @@ -31,6 +33,8 @@

# Version string of this package.
__version__ = "0.9.3"
# Parsed form of the installed SQLAlchemy's version string.
sqlalchemy_version = Version(sqlalchemy.__version__)
# Version string reported by the installed duckdb module.
duckdb_version: str = duckdb.__version__ # type: ignore[attr-defined]
# True when the installed DuckDB is 0.7.0 or newer. The schema/table reflection
# methods check this flag before querying across databases and otherwise defer
# to the parent dialect's implementation.
supports_attach: bool = Version(duckdb_version) >= Version("0.7.0")

if TYPE_CHECKING:
from sqlalchemy.base import Connection
Expand Down Expand Up @@ -265,6 +269,72 @@

return [row[0] for row in rs]

@cache  # type: ignore[call-arg]
def get_schema_names(self, connection: "Connection", **kw: "Any"):  # type: ignore[no-untyped-def]
    """
    List every schema as a ``database_name.schema_name`` string.

    Both halves are left unquoted unless they contain spaces or double quotes;
    in that case embedded double quotes are escaped and the name is wrapped in
    double quotes.

    SQLAlchemy's notion of a schema may include the database name, as it does
    for databases like SQL Server (Ex: databasename.dbo)
    (see https://docs.sqlalchemy.org/en/20/dialects/mssql.html#multipart-schema-names)
    """
    if supports_attach:
        # pg_* schemas are filtered out; results are ordered by database, then schema.
        query = """
SELECT database_name, schema_name AS npspname
FROM duckdb_schemas()
WHERE schema_name NOT LIKE 'pg\\_%' ESCAPE '\\'
ORDER BY database_name, npspname
"""
        rows = connection.execute(text(query))

        quote = self.identifier_preparer.quote_schema
        return [f"{quote(db)}.{quote(schema)}" for (db, schema) in rows]

    # Without ATTACH support, defer to the parent dialect's implementation.
    # NOTE(review): Codecov reports this fallback line as not covered by tests.
    return super().get_schema_names(connection, **kw)

@cache  # type: ignore[call-arg]
def get_table_names(self, connection: "Connection", schema=None, **kw: "Any"):  # type: ignore[no-untyped-def]
    """
    Return the names of the tables found by ``duckdb_tables()``.

    :param connection: connection used to run the catalog query.
    :param schema: optional filter. Either a plain schema name, or a
        ``<db_name>.<schema_name>`` pair in the format returned by
        ``get_schema_names``: each part is double quoted when it contains
        spaces or double quotes, with embedded double quotes doubled (``""``).
        When ``None``, tables from every database and schema are returned
        (``pg_*`` schemas are always excluded).
    :raises ValueError: if a dotted ``schema`` does not split into one or two
        identifier parts.

    SQLAlchemy's notion of a schema may include the database name, as it does
    for databases like SQL Server (Ex: databasename.dbo)
    (see https://docs.sqlalchemy.org/en/20/dialects/mssql.html#multipart-schema-names)
    """

    if not supports_attach:
        # NOTE(review): Codecov reports this fallback line as not covered by tests.
        return super().get_table_names(connection, schema, **kw)

    s = """
SELECT database_name, schema_name, table_name
FROM duckdb_tables()
WHERE schema_name NOT LIKE 'pg\\_%' ESCAPE '\\'
"""
    params = {}
    if schema is not None:
        params["schema_name"] = schema
        if "." in schema:
            # Split into identifier parts. A part is either a double-quoted
            # identifier (which may itself contain dots and doubled "" escapes)
            # or a bare run of non-dot characters. Quoted parts are unescaped so
            # they compare equal to the names stored in the catalog.
            parts = [
                bare if quoted is None else quoted.replace('""', '"')
                for quoted, bare in (
                    match.groups()
                    for match in re.finditer(r'"((?:[^"]|"")*)"|([^.]+)', schema)
                )
            ]
            if len(parts) == 2:
                params["database_name"], params["schema_name"] = parts
                s += "AND database_name = :database_name\n"
            elif len(parts) == 1:
                # A single quoted schema name that happens to contain a dot.
                params["schema_name"] = parts[0]
            else:
                raise ValueError(
                    f"Cannot parse schema {schema!r}: expected <schema_name> "
                    "or <db_name>.<schema_name>"
                )
        s += "AND schema_name = :schema_name"

    rs = connection.execute(text(s), params)

    # Only the table name is returned; database and schema are used for filtering only.
    return [table for (_db, _schema, table) in rs]

def get_indexes(
self,
connection: "Connection",
Expand Down
47 changes: 46 additions & 1 deletion duckdb_engine/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, relationship, sessionmaker

from .. import DBAPI, Dialect
from .. import DBAPI, Dialect, supports_attach

try:
# sqlalchemy 2
Expand Down Expand Up @@ -166,6 +166,51 @@
assert inspector.get_view_names() == []


@mark.skipif(
    not supports_attach,
    reason="ATTACH is not supported for DuckDB version < 0.7.0",
)
def test_get_schema_names(inspector: Inspector, session: Session) -> None:
    """Schemas and tables are reflected across the default and an attached database."""
    # Using multi-line strings because of all the single and double quotes flying around...
    cmds = [
        """CREATE SCHEMA "quack quack" """,
        """ATTACH ':memory:' AS "daffy duck" """,
        """CREATE SCHEMA "daffy duck"."quack quack" """,
        """CREATE TABLE "daffy duck"."quack quack"."t1" (i INTEGER, j INTEGER);""",
        """CREATE TABLE "daffy duck"."quack quack"."t2" (i INTEGER, j INTEGER);""",
        """CREATE SCHEMA "daffy duck"."you're "" despicable" """,
    ]
    for cmd in cmds:
        session.execute(text(cmd))
    session.commit()

    # Deliberately excluding pg_catalog schema (to align with Postgres).
    # There is no pre-ATTACH branch here: the skipif above means this body only
    # runs when ATTACH is supported, so such a branch could never execute.
    names = inspector.get_schema_names()
    assert names == [
        '"daffy duck".information_schema',
        '"daffy duck".main',
        '"daffy duck"."quack quack"',
        '"daffy duck"."you\'re "" despicable"',
        "memory.information_schema",
        "memory.main",
        'memory."quack quack"',
        "system.information_schema",
        "system.main",
        "temp.information_schema",
        "temp.main",
    ]

    # Filtering by a quoted <db_name>.<schema_name> pair returns only that schema's tables.
    table_names = inspector.get_table_names(schema='"daffy duck"."quack quack"')
    assert set(table_names) == {"t1", "t2"}

    # Without a schema filter, tables from the attached database are still listed.
    table_names_all = inspector.get_table_names()
    assert "t1" in table_names_all
    assert "t2" in table_names_all


def test_get_views(engine: Engine) -> None:
con = engine.connect()
views = engine.dialect.get_view_names(con)
Expand Down