Skip to content

Commit

Permalink
Merge pull request #906 from Mause/feature/view-reflection
Browse files Browse the repository at this point in the history
fix: support views in has_table
  • Loading branch information
Mause authored Mar 1, 2024
2 parents ca820ff + 52d6a43 commit da69a33
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 12 deletions.
9 changes: 6 additions & 3 deletions duckdb_engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,10 +411,13 @@ def get_table_oid( # type: ignore[no-untyped-def]
In the latter scenario the schema associated with the default database is used.
"""
s = """
SELECT table_oid
FROM duckdb_tables()
SELECT oid, table_name
FROM (
SELECT table_oid AS oid, table_name, database_name, schema_name FROM duckdb_tables()
UNION ALL BY NAME
SELECT view_oid AS oid , view_name AS table_name, database_name, schema_name FROM duckdb_views()
)
WHERE schema_name NOT LIKE 'pg\\_%' ESCAPE '\\'
AND table_name = :table_name
"""
sql, params = self._build_query_where(table_name=table_name, schema_name=schema)
s += sql
Expand Down
7 changes: 6 additions & 1 deletion duckdb_engine/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pytest import fixture, raises
from sqlalchemy import create_engine
from sqlalchemy.dialects import registry # type: ignore
from sqlalchemy.engine import Engine
from sqlalchemy.engine import Dialect, Engine
from sqlalchemy.engine.base import Connection
from sqlalchemy.orm import Session, sessionmaker
from typing_extensions import ParamSpec
Expand Down Expand Up @@ -33,6 +33,11 @@ def conn(engine: Engine) -> Generator[Connection, None, None]:
yield conn


@fixture()
def dialect(engine: Engine) -> Dialect:
return engine.dialect


@fixture
def session(engine: Engine) -> Session:
return sessionmaker(bind=engine)()
Expand Down
18 changes: 10 additions & 8 deletions duckdb_engine/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
types,
)
from sqlalchemy.dialects import registry # type: ignore
from sqlalchemy.engine import Engine
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.exc import DBAPIError
from sqlalchemy.ext.declarative import declarative_base
Expand Down Expand Up @@ -233,22 +233,24 @@ def test_get_table_names(inspector: Inspector, session: Session) -> None:
assert inspector.has_table(table_name)


def test_get_views(engine: Engine) -> None:
con = engine.connect()
views = engine.dialect.get_view_names(con)
def test_get_views(conn: Connection, dialect: Dialect) -> None:
views = dialect.get_view_names(conn)
assert views == []

con.execute(text("create view test as select 1"))
con.execute(
conn.execute(text("create view test as select 1"))
conn.execute(
text("create schema scheme; create view scheme.schema_test as select 1")
)

views = engine.dialect.get_view_names(con)
views = dialect.get_view_names(conn)
assert views == ["test"]

views = engine.dialect.get_view_names(con, schema="scheme")
views = dialect.get_view_names(conn, schema="scheme")
assert views == ["schema_test"]

assert dialect.has_table(conn, table_name="test")
assert dialect.has_table(conn, table_name="schema_test", schema="scheme")


@mark.skipif(os.uname().machine == "aarch64", reason="not supported on aarch64")
@mark.remote_data
Expand Down

0 comments on commit da69a33

Please sign in to comment.