Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Support fetching multiple databases and schemas and their associated tables #835

Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions duckdb_engine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
import warnings
from typing import (
TYPE_CHECKING,
Expand All @@ -23,6 +24,7 @@
from sqlalchemy.dialects.postgresql.base import PGDialect, PGInspector, PGTypeCompiler
from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
from sqlalchemy.engine.default import DefaultDialect
from sqlalchemy.engine.reflection import cache
from sqlalchemy.engine.url import URL
from sqlalchemy.ext.compiler import compiles

Expand All @@ -31,6 +33,8 @@

__version__ = "0.9.3"
sqlalchemy_version = Version(sqlalchemy.__version__)
duckdb_version: str = duckdb.__version__ # type: ignore[attr-defined]
supports_attach: bool = Version(duckdb_version) >= Version("0.7.0")

if TYPE_CHECKING:
from sqlalchemy.base import Connection
Expand Down Expand Up @@ -265,6 +269,72 @@ def get_view_names(

return [row[0] for row in rs]

@cache # type: ignore[call-arg]
def get_schema_names(self, connection: "Connection", **kw: "Any"): # type: ignore[no-untyped-def]
"""
Return unquoted database_name.schema_name unless either contains spaces or double quotes.
In that case, escape double quotes and then wrap in double quotes.
SQLAlchemy definition of a schema includes database name for databases like SQL Server (Ex: databasename.dbo)
(see https://docs.sqlalchemy.org/en/20/dialects/mssql.html#multipart-schema-names)
"""

if not supports_attach:
return super().get_schema_names(connection, **kw)

s = """
SELECT database_name, schema_name AS npspname
FROM duckdb_schemas()
WHERE schema_name NOT LIKE 'pg\\_%' ESCAPE '\\'
ORDER BY database_name, npspname
"""
rs = connection.execute(text(s))

qs = self.identifier_preparer.quote_schema
return [f"{qs(db)}.{qs(schema)}" for (db, schema) in rs]

@cache # type: ignore[call-arg]
def get_table_names(self, connection: "Connection", schema=None, **kw: "Any"): # type: ignore[no-untyped-def]
"""
Return unquoted database_name.schema_name unless either contains spaces or double quotes.
In that case, escape double quotes and then wrap in double quotes.
SQLAlchemy definition of a schema includes database name for databases like SQL Server (Ex: databasename.dbo)
(see https://docs.sqlalchemy.org/en/20/dialects/mssql.html#multipart-schema-names)
"""

if not supports_attach:
return super().get_table_names(connection, schema, **kw)

s = """
SELECT database_name, schema_name, table_name
FROM duckdb_tables()
WHERE schema_name NOT LIKE 'pg\\_%' ESCAPE '\\'
"""
params = {}
if schema is not None:
params = {"schema_name": schema}
if "." in schema:
# Get database name and schema name from schema if it contains a database name
# Format:
# <db_name>.<schema_name>
# db_name and schema_name are double quoted if contains spaces or double quotes
database_name, schema_name = (
max(s) for s in re.findall(r'"([^.]+)"|([^.]+)', schema)
)
params = {"database_name": database_name, "schema_name": schema_name}
s += "AND database_name = :database_name\n"
s += "AND schema_name = :schema_name"

rs = connection.execute(text(s), params)

return [
table
for (
db,
sc,
table,
) in rs
]

def get_indexes(
self,
connection: "Connection",
Expand Down
47 changes: 46 additions & 1 deletion duckdb_engine/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, relationship, sessionmaker

from .. import DBAPI, Dialect
from .. import DBAPI, Dialect, supports_attach

try:
# sqlalchemy 2
Expand Down Expand Up @@ -166,6 +166,51 @@ def test_get_tables(inspector: Inspector) -> None:
assert inspector.get_view_names() == []


@mark.skipif(
supports_attach is False,
reason="ATTACH is not supported for DuckDB version < 0.7.0",
)
def test_get_schema_names(inspector: Inspector, engine: Engine) -> None:
# Run each command in a separate transaction
for cmd in [
"""CREATE SCHEMA "quack quack" """,
"""ATTACH ':memory:' AS "daffy duck" """,
"""CREATE SCHEMA "daffy duck"."quack quack" """,
"""CREATE TABLE "daffy duck"."quack quack"."t1" (i INTEGER, j INTEGER);""",
"""CREATE TABLE "daffy duck"."quack quack"."t2" (i INTEGER, j INTEGER);""",
"""CREATE SCHEMA "daffy duck"."you're "" despicable" """,
]:
with engine.connect() as conn:
guenp marked this conversation as resolved.
Show resolved Hide resolved
# Using multi-line strings because of all the single and double quotes flying around...
conn.execute(text(cmd))

# Deliberately excluding pg_catalog schema (to align with Postgres)
names = inspector.get_schema_names()
if supports_attach:
assert set(names) == {
'memory."quack quack"',
"memory.information_schema",
"memory.main",
'"daffy duck"."you\'re "" despicable"',
'"daffy duck"."quack quack"',
'"daffy duck".information_schema',
'"daffy duck".main',
"system.information_schema",
"system.main",
"temp.information_schema",
"temp.main",
}
else:
assert names == ["quack quack", "information_schema", "main", "temp"]

table_names = inspector.get_table_names(schema='"daffy duck"."quack quack"')
assert set(table_names) == {"t1", "t2"}

table_names_all = inspector.get_table_names()
assert "t1" in table_names_all
assert "t2" in table_names_all


def test_get_views(engine: Engine) -> None:
con = engine.connect()
views = engine.dialect.get_view_names(con)
Expand Down
Loading