Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Sep 4, 2023
1 parent 280f646 commit 83cdd55
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 26 deletions.
28 changes: 22 additions & 6 deletions duckdb_engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@ def Binary(x: Any) -> Any:

class DuckDBInspector(PGInspector):
def get_check_constraints(
self, table_name: str, schema: Optional[str] = None, **kw: Any,
self,
table_name: str,
schema: Optional[str] = None,
**kw: Any,
) -> List[Dict[str, Any]]:
try:
return super().get_check_constraints(table_name, schema, **kw)
Expand Down Expand Up @@ -253,7 +256,8 @@ def get_view_names(
) -> Any:
s = "SELECT table_name FROM information_schema.tables WHERE table_type='VIEW' and table_schema=:schema_name"
rs = connection.execute(
text(s), {"schema_name": schema if schema is not None else "main"},
text(s),
{"schema_name": schema if schema is not None else "main"},
)

return [row[0] for row in rs]
Expand Down Expand Up @@ -292,10 +296,18 @@ def import_dbapi(cls: Type["Dialect"]) -> Type[DBAPI]:
return cls.dbapi()

def do_executemany(
self, cursor: Any, statement: Any, parameters: Any, context: Optional[Any] = ...,
self,
cursor: Any,
statement: Any,
parameters: Any,
context: Optional[Any] = ...,
) -> None:
return DefaultDialect.do_executemany(
self, cursor, statement, parameters, context,
self,
cursor,
statement,
parameters,
context,
)

# FIXME: this method is a hack around the fact that we use a single cursor for all queries inside a connection,
Expand Down Expand Up @@ -338,7 +350,9 @@ def get_multi_columns(
domains = {
((d["schema"], d["name"]) if not d["visible"] else (d["name"],)): d
for d in self._load_domains( # type: ignore[attr-defined]
connection, schema="*", info_cache=kw.get("info_cache"),
connection,
schema="*",
info_cache=kw.get("info_cache"),
)
}

Expand All @@ -349,7 +363,9 @@ def get_multi_columns(
if rec["visible"]
else ((rec["schema"], rec["name"]), rec)
for rec in self._load_enums( # type: ignore[attr-defined]
connection, schema="*", info_cache=kw.get("info_cache"),
connection,
schema="*",
info_cache=kw.get("info_cache"),
)
)

Expand Down
26 changes: 12 additions & 14 deletions duckdb_engine/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@ def compile_uint(element: Integer, compiler: PGTypeCompiler, **kw: Any) -> str:


class Struct(TypeEngine):
"""
Represents a STRUCT type in DuckDB
"""Represents a STRUCT type in DuckDB.
```python
from duckdb_engine.datatypes import Struct
Expand All @@ -110,13 +109,12 @@ class Struct(TypeEngine):

__visit_name__ = "struct"

def __init__(self, fields: Optional[Dict[str, TV]] = None):
def __init__(self, fields: Optional[Dict[str, TV]] = None) -> None:
self.fields = fields


class Map(TypeEngine):
"""
Represents a MAP type in DuckDB
"""Represents a MAP type in DuckDB.
```python
from duckdb_engine.datatypes import Map
Expand All @@ -133,26 +131,25 @@ class Map(TypeEngine):
key_type: TV
value_type: TV

def __init__(self, key_type: TV, value_type: TV):
def __init__(self, key_type: TV, value_type: TV) -> None:
self.key_type = key_type
self.value_type = value_type

def bind_processor(
self, dialect: Dialect
self, dialect: Dialect,
) -> Optional[Callable[[Optional[dict]], Optional[dict]]]:
return lambda value: (
{"key": list(value), "value": list(value.values())} if value else None
)

def result_processor(
self, dialect: Dialect, coltype: str
self, dialect: Dialect, coltype: str,
) -> Optional[Callable[[Optional[dict]], Optional[dict]]]:
return lambda value: dict(zip(value["key"], value["value"])) if value else {}


class Union(TypeEngine):
"""
Represents a UNION type in DuckDB
"""Represents a UNION type in DuckDB.
```python
from duckdb_engine.datatypes import Union
Expand All @@ -168,7 +165,7 @@ class Union(TypeEngine):
__visit_name__ = "union"
fields: Dict[str, TV]

def __init__(self, fields: Dict[str, TV]):
def __init__(self, fields: Dict[str, TV]) -> None:
self.fields = fields


Expand Down Expand Up @@ -218,14 +215,15 @@ def struct_or_union(
) -> str:
fields = instance.fields
if fields is None:
raise exc.CompileError(f"DuckDB {repr(instance)} type requires fields")
msg = f"DuckDB {repr(instance)} type requires fields"
raise exc.CompileError(msg)
return "({})".format(
", ".join(
"{} {}".format(
identifier_preparer.quote_identifier(key), process_type(value, compiler)
identifier_preparer.quote_identifier(key), process_type(value, compiler),
)
for key, value in fields.items()
)
),
)


Expand Down
8 changes: 5 additions & 3 deletions duckdb_engine/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,8 +385,8 @@ def test_url_config_and_dict_config() -> None:
with eng.connect() as conn:
res = conn.execute(
text(
"select current_setting('worker_threads'), current_setting('memory_limit')"
)
"select current_setting('worker_threads'), current_setting('memory_limit')",
),
)
row = res.first()
assert row is not None
Expand All @@ -395,7 +395,9 @@ def test_url_config_and_dict_config() -> None:

def test_do_ping(tmp_path: Path, caplog: LogCaptureFixture) -> None:
engine = create_engine(
"duckdb:///" + str(tmp_path / "db"), pool_pre_ping=True, pool_size=1,
"duckdb:///" + str(tmp_path / "db"),
pool_pre_ping=True,
pool_size=1,
)

logger = cast(logging.Logger, engine.pool.logger) # type: ignore
Expand Down
5 changes: 3 additions & 2 deletions duckdb_engine/tests/test_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

@mark.parametrize("coltype", types)
def test_unsigned_integer_type(
engine: Engine, session: Session, coltype: Type[Integer],
engine: Engine,
session: Session,
coltype: Type[Integer],
) -> None:
Base = declarative_base()

Expand Down Expand Up @@ -126,7 +128,6 @@ class Entry(base):
id = Column(Integer, primary_key=True, default=0)
struct = Column(Struct(fields={"name": String}))
map = Column(Map(String, Integer))
# union = Column(Union(fields={"name": String, "age": Integer}))

base.metadata.create_all(bind=engine)

Expand Down
2 changes: 1 addition & 1 deletion duckdb_engine/tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_integration(engine: Engine) -> None:

@mark.remote_data
@mark.skipif(
"dev" in duckdb.__version__, reason="md extension not available for dev builds" # type: ignore[attr-defined]
"dev" in duckdb.__version__, reason="md extension not available for dev builds", # type: ignore[attr-defined]
)
def test_motherduck() -> None:
importorskip("duckdb", "0.7.1")
Expand Down

0 comments on commit 83cdd55

Please sign in to comment.