Skip to content

Commit

Permalink
Merge pull request #894 from Mause/json
Browse files Browse the repository at this point in the history
fix: allow parsing json in dynamic queries
  • Loading branch information
Mause authored Mar 18, 2024
2 parents 4d6de0b + 2b2863f commit e2055f9
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 4 deletions.
21 changes: 20 additions & 1 deletion duckdb_engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
)
from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
from sqlalchemy.engine.default import DefaultDialect
from sqlalchemy.engine.interfaces import Dialect as RootDialect
from sqlalchemy.engine.reflection import cache
from sqlalchemy.engine.url import URL
from sqlalchemy.exc import NoSuchTableError
Expand All @@ -47,7 +48,7 @@
if TYPE_CHECKING:
from sqlalchemy.base import Connection
from sqlalchemy.engine.interfaces import _IndexDict

from sqlalchemy.sql.type_api import _ResultProcessor

register_extension_types()

Expand Down Expand Up @@ -215,6 +216,16 @@ def quote_schema(self, schema: str, force: Any = None) -> str:
return self.format_schema(schema)


class DuckDBNullType(sqltypes.NullType):
def result_processor(
self, dialect: RootDialect, coltype: sqltypes.TypeEngine
) -> Optional["_ResultProcessor"]:
if coltype == "JSON":
return sqltypes.JSON().result_processor(dialect, coltype)
else:
return super().result_processor(dialect, coltype)


class Dialect(PGDialect_psycopg2):
name = "duckdb"
driver = "duckdb_engine"
Expand Down Expand Up @@ -247,6 +258,14 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
kwargs["use_native_hstore"] = False
super().__init__(*args, **kwargs)

def type_descriptor(self, typeobj: Type[sqltypes.TypeEngine]) -> Any: # type: ignore[override]
res = super().type_descriptor(typeobj)

if isinstance(res, sqltypes.NullType):
return DuckDBNullType()

return res

def connect(self, *cargs: Any, **cparams: Any) -> "Connection":
core_keys = get_core_config()
preload_extensions = cparams.pop("preload_extensions", [])
Expand Down
1 change: 1 addition & 0 deletions duckdb_engine/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def __init__(self, fields: Dict[str, TV]):
"timestamp_ms": sqltypes.TIMESTAMP,
"timestamp_ns": sqltypes.TIMESTAMP,
"enum": sqltypes.Enum,
"json": sqltypes.JSON,
}


Expand Down
70 changes: 67 additions & 3 deletions duckdb_engine/tests/test_datatypes.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,24 @@
import decimal
import json
import warnings
from typing import Type
from typing import Any, Dict, Type
from uuid import uuid4

import duckdb
from pytest import importorskip, mark
from sqlalchemy import Column, Integer, MetaData, String, Table, inspect, text
from sqlalchemy import (
Column,
Integer,
MetaData,
Sequence,
String,
Table,
inspect,
select,
text,
)
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.engine import Engine
from sqlalchemy.engine import Engine, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session
from sqlalchemy.sql import sqltypes
Expand Down Expand Up @@ -45,6 +57,58 @@ def test_unsigned_integer_type(
assert session.query(table).one()


@mark.remote_data()
def test_raw_json(engine: Engine) -> None:
importorskip("duckdb", "0.9.3.dev4040")

with engine.connect() as conn:
assert conn.execute(text("load json"))

assert conn.execute(text("select {'Hello': 'world'}::JSON")).fetchone() == (
{"Hello": "world"},
)


@mark.remote_data()
def test_custom_json_serializer() -> None:
def default(o: Any) -> Any:
if isinstance(o, decimal.Decimal):
return {"__tag": "decimal", "value": str(o)}

def object_hook(pairs: Dict[str, Any]) -> Any:
if pairs.get("__tag", None) == "decimal":
return decimal.Decimal(pairs["value"])
else:
return pairs

engine = create_engine(
"duckdb://",
json_serializer=json.JSONEncoder(default=default).encode,
json_deserializer=json.JSONDecoder(object_hook=object_hook).decode,
)

Base = declarative_base()

class Entry(Base):
__tablename__ = "test_json"
id = Column(Integer, Sequence("id_seq"), primary_key=True)
data = Column(JSON, nullable=False)

Base.metadata.create_all(engine)

with engine.connect() as conn:
session = Session(bind=conn)

data = {"hello": decimal.Decimal("42")}

session.add(Entry(data=data)) # type: ignore[call-arg]
session.commit()

(res,) = session.execute(select(Entry)).one()

assert res.data == data


def test_json(engine: Engine, session: Session) -> None:
base = declarative_base()

Expand Down

0 comments on commit e2055f9

Please sign in to comment.