Skip to content

Commit

Permalink
Merge pull request #835 from motherduckdb/guen/eco-25-sql-editor-mult…
Browse files Browse the repository at this point in the history
…iple-main-schemas-appear-in-schema-dropdown

fix: Support fetching multiple databases and schemas and their associated tables
  • Loading branch information
Mause authored Dec 9, 2023
2 parents 4ef3cfb + d619322 commit 8e538ce
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 1 deletion.
70 changes: 70 additions & 0 deletions duckdb_engine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
import warnings
from typing import (
TYPE_CHECKING,
Expand All @@ -23,6 +24,7 @@
from sqlalchemy.dialects.postgresql.base import PGDialect, PGInspector, PGTypeCompiler
from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
from sqlalchemy.engine.default import DefaultDialect
from sqlalchemy.engine.reflection import cache
from sqlalchemy.engine.url import URL
from sqlalchemy.ext.compiler import compiles

Expand All @@ -31,6 +33,8 @@

__version__ = "0.9.3"
sqlalchemy_version = Version(sqlalchemy.__version__)
duckdb_version: str = duckdb.__version__ # type: ignore[attr-defined]
supports_attach: bool = Version(duckdb_version) >= Version("0.7.0")

if TYPE_CHECKING:
from sqlalchemy.base import Connection
Expand Down Expand Up @@ -265,6 +269,72 @@ def get_view_names(

return [row[0] for row in rs]

@cache # type: ignore[call-arg]
def get_schema_names(self, connection: "Connection", **kw: "Any"): # type: ignore[no-untyped-def]
"""
Return unquoted database_name.schema_name unless either contains spaces or double quotes.
In that case, escape double quotes and then wrap in double quotes.
SQLAlchemy definition of a schema includes database name for databases like SQL Server (Ex: databasename.dbo)
(see https://docs.sqlalchemy.org/en/20/dialects/mssql.html#multipart-schema-names)
"""

if not supports_attach:
return super().get_schema_names(connection, **kw)

s = """
SELECT database_name, schema_name AS npspname
FROM duckdb_schemas()
WHERE schema_name NOT LIKE 'pg\\_%' ESCAPE '\\'
ORDER BY database_name, npspname
"""
rs = connection.execute(text(s))

qs = self.identifier_preparer.quote_schema
return [f"{qs(db)}.{qs(schema)}" for (db, schema) in rs]

@cache # type: ignore[call-arg]
def get_table_names(self, connection: "Connection", schema=None, **kw: "Any"): # type: ignore[no-untyped-def]
"""
Return unquoted database_name.schema_name unless either contains spaces or double quotes.
In that case, escape double quotes and then wrap in double quotes.
SQLAlchemy definition of a schema includes database name for databases like SQL Server (Ex: databasename.dbo)
(see https://docs.sqlalchemy.org/en/20/dialects/mssql.html#multipart-schema-names)
"""

if not supports_attach:
return super().get_table_names(connection, schema, **kw)

s = """
SELECT database_name, schema_name, table_name
FROM duckdb_tables()
WHERE schema_name NOT LIKE 'pg\\_%' ESCAPE '\\'
"""
params = {}
if schema is not None:
params = {"schema_name": schema}
if "." in schema:
# Get database name and schema name from schema if it contains a database name
# Format:
# <db_name>.<schema_name>
# db_name and schema_name are double quoted if contains spaces or double quotes
database_name, schema_name = (
max(s) for s in re.findall(r'"([^.]+)"|([^.]+)', schema)
)
params = {"database_name": database_name, "schema_name": schema_name}
s += "AND database_name = :database_name\n"
s += "AND schema_name = :schema_name"

rs = connection.execute(text(s), params)

return [
table
for (
db,
sc,
table,
) in rs
]

def get_indexes(
self,
connection: "Connection",
Expand Down
47 changes: 46 additions & 1 deletion duckdb_engine/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, relationship, sessionmaker

from .. import DBAPI, Dialect
from .. import DBAPI, Dialect, supports_attach

try:
# sqlalchemy 2
Expand Down Expand Up @@ -166,6 +166,51 @@ def test_get_tables(inspector: Inspector) -> None:
assert inspector.get_view_names() == []


@mark.skipif(
supports_attach is False,
reason="ATTACH is not supported for DuckDB version < 0.7.0",
)
def test_get_schema_names(inspector: Inspector, session: Session) -> None:
# Using multi-line strings because of all the single and double quotes flying around...
cmds = [
"""CREATE SCHEMA "quack quack" """,
"""ATTACH ':memory:' AS "daffy duck" """,
"""CREATE SCHEMA "daffy duck"."quack quack" """,
"""CREATE TABLE "daffy duck"."quack quack"."t1" (i INTEGER, j INTEGER);""",
"""CREATE TABLE "daffy duck"."quack quack"."t2" (i INTEGER, j INTEGER);""",
"""CREATE SCHEMA "daffy duck"."you're "" despicable" """,
]
for cmd in cmds:
session.execute(text(cmd))
session.commit()

# Deliberately excluding pg_catalog schema (to align with Postgres)
names = inspector.get_schema_names()
if supports_attach:
assert names == [
'"daffy duck".information_schema',
'"daffy duck".main',
'"daffy duck"."quack quack"',
'"daffy duck"."you\'re "" despicable"',
"memory.information_schema",
"memory.main",
'memory."quack quack"',
"system.information_schema",
"system.main",
"temp.information_schema",
"temp.main",
]
else:
assert names == ["quack quack", "information_schema", "main", "temp"]

table_names = inspector.get_table_names(schema='"daffy duck"."quack quack"')
assert set(table_names) == {"t1", "t2"}

table_names_all = inspector.get_table_names()
assert "t1" in table_names_all
assert "t2" in table_names_all


def test_get_views(engine: Engine) -> None:
con = engine.connect()
views = engine.dialect.get_view_names(con)
Expand Down

0 comments on commit 8e538ce

Please sign in to comment.