Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] authored and Mause committed Sep 7, 2023
1 parent 82523f1 commit 4d5cb5d
Show file tree
Hide file tree
Showing 10 changed files with 55 additions and 65 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,4 @@ repos:
- id: ruff
args:
- --fix
- --exit-non-zero-on-fix
- --exit-non-zero-on-fix
27 changes: 12 additions & 15 deletions duckdb_engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@

import duckdb
import sqlalchemy
from sqlalchemy import pool, text
from sqlalchemy import pool, text, util
from sqlalchemy import types as sqltypes
from sqlalchemy import util
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.dialects.postgresql.base import PGDialect, PGInspector, PGTypeCompiler
from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
Expand Down Expand Up @@ -55,12 +54,12 @@ def Binary(x: Any) -> Any:

class DuckDBInspector(PGInspector):
def get_check_constraints(
self, table_name: str, schema: Optional[str] = None, **kw: Any
self, table_name: str, schema: Optional[str] = None, **kw: Any,
) -> List[Dict[str, Any]]:
try:
return super().get_check_constraints(table_name, schema, **kw)
except Exception as e:
raise NotImplementedError() from e
raise NotImplementedError from e

Check warning on line 62 in duckdb_engine/__init__.py

View check run for this annotation

Codecov / codecov/patch

duckdb_engine/__init__.py#L62

Added line #L62 was not covered by tests


class ConnectionWrapper:
Expand All @@ -71,7 +70,7 @@ class ConnectionWrapper:

def __init__(self, c: duckdb.DuckDBPyConnection) -> None:
self.__c = c
self.notices = list()
self.notices = []

def cursor(self) -> "Connection":
return self
Expand All @@ -88,7 +87,7 @@ def fetchmany(self, size: Optional[int] = None) -> List:
return cast(list, self.__c.fetch_df_chunk().values.tolist())
except RuntimeError as e:
if e.args[0].startswith(
"Invalid Input Error: Attempting to fetch from an unsuccessful or closed streaming query result"
"Invalid Input Error: Attempting to fetch from an unsuccessful or closed streaming query result",
):
return []
else:
Expand Down Expand Up @@ -230,7 +229,7 @@ def _get_server_version_info(self, connection: "Connection") -> Tuple[int, int]:
return (8, 0)

def get_default_isolation_level(self, connection: "Connection") -> None:
raise NotImplementedError()
raise NotImplementedError

def do_rollback(self, connection: "Connection") -> None:
try:
Expand All @@ -254,7 +253,7 @@ def get_view_names(
) -> Any:
s = "SELECT table_name FROM information_schema.tables WHERE table_type='VIEW' and table_schema=:schema_name"
rs = connection.execute(
text(s), {"schema_name": schema if schema is not None else "main"}
text(s), {"schema_name": schema if schema is not None else "main"},
)

return [row[0] for row in rs]
Expand Down Expand Up @@ -293,10 +292,10 @@ def import_dbapi(cls: Type["Dialect"]) -> Type[DBAPI]:
return cls.dbapi()

def do_executemany(
self, cursor: Any, statement: Any, parameters: Any, context: Optional[Any] = ...
self, cursor: Any, statement: Any, parameters: Any, context: Optional[Any] = ...,
) -> None:
return DefaultDialect.do_executemany(
self, cursor, statement, parameters, context
self, cursor, statement, parameters, context,
)

# FIXME: this method is a hack around the fact that we use a single cursor for all queries inside a connection,
Expand All @@ -310,8 +309,7 @@ def get_multi_columns(
kind: Optional[Tuple[str, ...]] = None,
**kw: Any,
) -> List:
"""
Copyright 2005-2023 SQLAlchemy authors and contributors <see AUTHORS file>.
"""Copyright 2005-2023 SQLAlchemy authors and contributors <see AUTHORS file>.
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
Expand All @@ -331,7 +329,6 @@ def get_multi_columns(
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

has_filter_names, params = self._prepare_filter_names(filter_names) # type: ignore[attr-defined]
query = self._columns_query(schema, has_filter_names, scope, kind) # type: ignore[attr-defined]
rows = list(connection.execute(query, params).mappings())
Expand All @@ -341,7 +338,7 @@ def get_multi_columns(
domains = {
((d["schema"], d["name"]) if not d["visible"] else (d["name"],)): d
for d in self._load_domains( # type: ignore[attr-defined]
connection, schema="*", info_cache=kw.get("info_cache")
connection, schema="*", info_cache=kw.get("info_cache"),
)
}

Expand All @@ -352,7 +349,7 @@ def get_multi_columns(
if rec["visible"]
else ((rec["schema"], rec["name"]), rec)
for rec in self._load_enums( # type: ignore[attr-defined]
connection, schema="*", info_cache=kw.get("info_cache")
connection, schema="*", info_cache=kw.get("info_cache"),
)
)

Expand Down
37 changes: 18 additions & 19 deletions duckdb_engine/datatypes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""
See https://duckdb.org/docs/sql/data_types/numeric for more information
"""See https://duckdb.org/docs/sql/data_types/numeric for more information.
Also
```sql
Expand Down Expand Up @@ -32,21 +31,23 @@ class UInt32(Integer):


class UInt16(Integer):
"AKA USMALLINT"
"AKA USMALLINT."


class UInt8(Integer):
pass


class UTinyInteger(Integer):
"AKA UInt1"
"AKA UInt1."

name = "UTinyInt"
# UTINYINT - 0 255


class TinyInteger(Integer):
"AKA Int1"
"AKA Int1."

name = "TinyInt"
# TINYINT INT1 -128 127

Expand Down Expand Up @@ -91,8 +92,7 @@ def compile_uint(element: Integer, compiler: PGTypeCompiler, **kw: Any) -> str:


class Struct(TypeEngine):
"""
Represents a STRUCT type in DuckDB
"""Represents a STRUCT type in DuckDB.
```python
from duckdb_engine.datatypes import Struct
Expand All @@ -109,13 +109,12 @@ class Struct(TypeEngine):

__visit_name__ = "struct"

def __init__(self, fields: Optional[Dict[str, TV]] = None):
def __init__(self, fields: Optional[Dict[str, TV]] = None) -> None:
self.fields = fields


class Map(TypeEngine):
"""
Represents a MAP type in DuckDB
"""Represents a MAP type in DuckDB.
```python
from duckdb_engine.datatypes import Map
Expand All @@ -132,26 +131,25 @@ class Map(TypeEngine):
key_type: TV
value_type: TV

def __init__(self, key_type: TV, value_type: TV):
def __init__(self, key_type: TV, value_type: TV) -> None:
self.key_type = key_type
self.value_type = value_type

def bind_processor(
self, dialect: Dialect
self, dialect: Dialect,
) -> Optional[Callable[[Optional[dict]], Optional[dict]]]:
return lambda value: (
{"key": list(value), "value": list(value.values())} if value else None
)

def result_processor(
self, dialect: Dialect, coltype: str
self, dialect: Dialect, coltype: str,
) -> Optional[Callable[[Optional[dict]], Optional[dict]]]:
return lambda value: dict(zip(value["key"], value["value"])) if value else {}


class Union(TypeEngine):
"""
Represents a UNION type in DuckDB
"""Represents a UNION type in DuckDB.
```python
from duckdb_engine.datatypes import Union
Expand All @@ -167,7 +165,7 @@ class Union(TypeEngine):
__visit_name__ = "union"
fields: Dict[str, TV]

def __init__(self, fields: Dict[str, TV]):
def __init__(self, fields: Dict[str, TV]) -> None:
self.fields = fields


Expand Down Expand Up @@ -217,14 +215,15 @@ def struct_or_union(
) -> str:
fields = instance.fields
if fields is None:
raise exc.CompileError(f"DuckDB {repr(instance)} type requires fields")
msg = f"DuckDB {repr(instance)} type requires fields"
raise exc.CompileError(msg)

Check warning on line 219 in duckdb_engine/datatypes.py

View check run for this annotation

Codecov / codecov/patch

duckdb_engine/datatypes.py#L218-L219

Added lines #L218 - L219 were not covered by tests
return "({})".format(
", ".join(
"{} {}".format(
identifier_preparer.quote_identifier(key), process_type(value, compiler)
identifier_preparer.quote_identifier(key), process_type(value, compiler),
)
for key, value in fields.items()
)
),
)


Expand Down
4 changes: 2 additions & 2 deletions duckdb_engine/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@
FuncT = TypeVar("FuncT", bound=Callable[..., Any])


@fixture
@fixture()
def engine() -> Engine:
registry.register("duckdb", "duckdb_engine", "Dialect")

return create_engine("duckdb:///:memory:")


@fixture
@fixture()
def session(engine: Engine) -> Session:
return sessionmaker(bind=engine)()

Expand Down
2 changes: 1 addition & 1 deletion duckdb_engine/tests/snapshots/snap_test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@
"name": "id",
"nullable": True,
"type": GenericRepr("INTEGER()"),
}
},
]
22 changes: 11 additions & 11 deletions duckdb_engine/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, relationship, sessionmaker

from .. import DBAPI, Dialect
from duckdb_engine import DBAPI, Dialect

try:
# sqlalchemy 2
Expand All @@ -48,7 +48,7 @@ class Mapped(Generic[T]): # type: ignore[no-redef]
pass


@fixture
@fixture()
def engine() -> Engine:
registry.register("duckdb", "duckdb_engine", "Dialect")

Expand All @@ -61,7 +61,7 @@ def engine() -> Engine:


class CompressedString(types.TypeDecorator):
"""Custom Column Type"""
"""Custom Column Type."""

impl = types.BLOB

Expand Down Expand Up @@ -107,7 +107,7 @@ class IntervalModel(Base):
field = Column(Interval)


@fixture
@fixture()
def session(engine: Engine) -> Session:
return sessionmaker(bind=engine)()

Expand Down Expand Up @@ -172,7 +172,7 @@ def test_get_views(engine: Engine) -> None:

con.execute(text("create view test as select 1"))
con.execute(
text("create schema scheme; create view scheme.schema_test as select 1")
text("create schema scheme; create view scheme.schema_test as select 1"),
)

views = engine.dialect.get_view_names(con)
Expand All @@ -197,11 +197,11 @@ def test_preload_extension() -> None:
# check that we get an error indicating that the extension was loaded
with engine.connect() as conn, raises(Exception, match="HTTP HEAD"):
conn.execute(
text("SELECT * FROM read_parquet('https://domain/path/to/file.parquet');")
text("SELECT * FROM read_parquet('https://domain/path/to/file.parquet');"),
)


@fixture
@fixture()
def inspector(engine: Engine, session: Session) -> Inspector:
session.execute(text("create table test (id int);"))
session.commit()
Expand Down Expand Up @@ -309,7 +309,7 @@ def test_binary(session: Session) -> None:


def test_comment_support() -> None:
"comments not yet supported by duckdb"
"Comments not yet supported by duckdb."
with raises(DBAPI.ParserException, match="syntax error"):
duckdb.default_connection.execute('comment on sqlite_master is "hello world";')

Expand Down Expand Up @@ -385,8 +385,8 @@ def test_url_config_and_dict_config() -> None:
with eng.connect() as conn:
res = conn.execute(
text(
"select current_setting('worker_threads'), current_setting('memory_limit')"
)
"select current_setting('worker_threads'), current_setting('memory_limit')",
),
)
row = res.first()
assert row is not None
Expand All @@ -395,7 +395,7 @@ def test_url_config_and_dict_config() -> None:

def test_do_ping(tmp_path: Path, caplog: LogCaptureFixture) -> None:
engine = create_engine(
"duckdb:///" + str(tmp_path / "db"), pool_pre_ping=True, pool_size=1
"duckdb:///" + str(tmp_path / "db"), pool_pre_ping=True, pool_size=1,
)

logger = cast(logging.Logger, engine.pool.logger) # type: ignore
Expand Down
5 changes: 2 additions & 3 deletions duckdb_engine/tests/test_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
from sqlalchemy.sql import sqltypes
from sqlalchemy.types import JSON

from ..datatypes import Map, Struct, types
from duckdb_engine.datatypes import Map, Struct, types


@mark.parametrize("coltype", types)
def test_unsigned_integer_type(
engine: Engine, session: Session, coltype: Type[Integer]
engine: Engine, session: Session, coltype: Type[Integer],
) -> None:
Base = declarative_base()

Expand Down Expand Up @@ -126,7 +126,6 @@ class Entry(base):
id = Column(Integer, primary_key=True, default=0)
struct = Column(Struct(fields={"name": String}))
map = Column(Map(String, Integer))
# union = Column(Union(fields={"name": String, "age": Integer}))

base.metadata.create_all(bind=engine)

Expand Down
6 changes: 2 additions & 4 deletions duckdb_engine/tests/test_ibis.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
"""
these are largely just smoke tests
"""
"""these are largely just smoke tests."""

from csv import DictWriter
from pathlib import Path
Expand All @@ -15,7 +13,7 @@
from ibis.backends.duckdb import Backend


@fixture
@fixture()
def ibis_conn() -> "Backend":
import ibis

Expand Down
2 changes: 1 addition & 1 deletion duckdb_engine/tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_integration(engine: Engine) -> None:

@mark.remote_data
@mark.skipif(
"dev" in duckdb.__version__, reason="md extension not available for dev builds" # type: ignore[attr-defined]
"dev" in duckdb.__version__, reason="md extension not available for dev builds", # type: ignore[attr-defined]
)
def test_motherduck() -> None:
importorskip("duckdb", "0.7.1")
Expand Down
Loading

0 comments on commit 4d5cb5d

Please sign in to comment.