diff --git a/.bumpversion.cfg b/.bumpversion.cfg index d07fe05..87fcbcb 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.4.0-alpha-6 +current_version = 0.4.0-alpha-7 parse = (?P\d+)\.(?P\d+)\.(?P\d+)(-(?P.*)-(?P\d+))? serialize = {major}.{minor}.{patch}-{release}-{build} diff --git a/pyproject.toml b/pyproject.toml index 7f09d3b..c4e7880 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "sql-athame" -version = "0.4.0-alpha-6" +version = "0.4.0-alpha-7" description = "Python tool for slicing and dicing SQL" authors = ["Brian Downing "] license = "MIT" diff --git a/sql_athame/dataclasses.py b/sql_athame/dataclasses.py index 2463515..90a13ed 100644 --- a/sql_athame/dataclasses.py +++ b/sql_athame/dataclasses.py @@ -1,4 +1,5 @@ import datetime +import functools import uuid from collections.abc import AsyncGenerator, Iterable, Mapping from dataclasses import Field, InitVar, dataclass, fields @@ -29,27 +30,56 @@ @dataclass class ColumnInfo: - type: str - create_type: str = "" - nullable: bool = False + type: Optional[str] = None + create_type: Optional[str] = None + nullable: Optional[bool] = None _constraints: tuple[str, ...] = () constraints: InitVar[Union[str, Iterable[str], None]] = None def __post_init__(self, constraints: Union[str, Iterable[str], None]) -> None: - if self.create_type == "": - self.create_type = self.type - self.type = sql_create_type_map.get(self.type.upper(), self.type) if constraints is not None: if type(constraints) is str: constraints = (constraints,) self._constraints = tuple(constraints) + @staticmethod + def merge(a: "ColumnInfo", b: "ColumnInfo") -> "ColumnInfo": + return ColumnInfo( + type=b.type if b.type is not None else a.type, + create_type=b.create_type if b.create_type is not None else a.create_type, + nullable=b.nullable if b.nullable is not None else a.nullable, + _constraints=(*a._constraints, *b._constraints), + ) + + +@dataclass +class ConcreteColumnInfo: + type: str + create_type: str + nullable: bool + constraints: tuple[str, ...] + + @staticmethod + def from_column_info(name: str, *args: ColumnInfo) -> "ConcreteColumnInfo": + info = functools.reduce(ColumnInfo.merge, args, ColumnInfo()) + if info.create_type is None and info.type is not None: + info.create_type = info.type + info.type = sql_create_type_map.get(info.type.upper(), info.type) + if type(info.type) is not str or type(info.create_type) is not str: + raise ValueError(f"Missing SQL type for column {name!r}") + return ConcreteColumnInfo( + type=info.type, + create_type=info.create_type, + nullable=bool(info.nullable), + constraints=info._constraints, + ) + def create_table_string(self) -> str: parts = ( self.create_type, *(() if self.nullable else ("NOT NULL",)), - *self._constraints, + *self.constraints, ) return " ".join(parts) @@ -86,7 +116,7 @@ def create_table_string(self) -> str: class ModelBase: - _column_info: Optional[dict[str, ColumnInfo]] + _column_info: Optional[dict[str, ConcreteColumnInfo]] _cache: dict[tuple, Any] table_name: str primary_key_names: tuple[str, ...] @@ -138,19 +168,23 @@ def type_hints(cls) -> dict[str, type]: return cls._type_hints @classmethod - def column_info_for_field(cls, field: Field) -> ColumnInfo: + def column_info_for_field(cls, field: Field) -> ConcreteColumnInfo: type_info = cls.type_hints()[field.name] base_type = type_info if get_origin(type_info) is Annotated: base_type = type_info.__origin__ # type: ignore + info = [] + if base_type in sql_type_map: + _type, nullable = sql_type_map[base_type] + info.append(ColumnInfo(type=_type, nullable=nullable)) + if get_origin(type_info) is Annotated: for md in type_info.__metadata__: # type: ignore if isinstance(md, ColumnInfo): - return md - type, nullable = sql_type_map[base_type] - return ColumnInfo(type=type, nullable=nullable) + info.append(md) + return ConcreteColumnInfo.from_column_info(field.name, *info) @classmethod - def column_info(cls, column: str) -> ColumnInfo: + def column_info(cls, column: str) -> ConcreteColumnInfo: try: return cls._column_info[column] # type: ignore except AttributeError: diff --git a/tests/test_dataclasses.py b/tests/test_dataclasses.py index c1cc9bd..da71114 100644 --- a/tests/test_dataclasses.py +++ b/tests/test_dataclasses.py @@ -6,6 +6,8 @@ from dataclasses import dataclass from typing import Annotated, Optional +import pytest + from sql_athame import sql from sql_athame.dataclasses import ColumnInfo, ModelBase @@ -67,16 +69,33 @@ class Test(ModelBase, table_name="table", primary_key="foo"): foo: int bar: str baz: Optional[uuid.UUID] + quux: Annotated[int, ColumnInfo(constraints="REFERENCES foobar")] + quuux: Annotated[ + int, + ColumnInfo(constraints="REFERENCES foobar"), + ColumnInfo(constraints="BLAH", nullable=True), + ] assert list(Test.create_table_sql()) == [ 'CREATE TABLE IF NOT EXISTS "table" (' '"foo" INTEGER NOT NULL, ' '"bar" TEXT NOT NULL, ' '"baz" UUID, ' + '"quux" INTEGER NOT NULL REFERENCES foobar, ' + '"quuux" INTEGER REFERENCES foobar BLAH, ' 'PRIMARY KEY ("foo"))' ] +def test_modelclass_missing_type(): + @dataclass + class Test(ModelBase, table_name="table", primary_key="foo"): + foo: dict + + with pytest.raises(ValueError, match="Missing SQL type for column 'foo'"): + Test.create_table_sql() + + def test_upsert(): @dataclass class Test(ModelBase, table_name="table", primary_key="id"):