Skip to content

Commit

Permalink
Merge pull request #67 from astronautas/pydantic-v2
Browse files Browse the repository at this point in the history
Pydantic v2
  • Loading branch information
murilo-cunha authored Sep 19, 2023
2 parents d6884d2 + 7aa698b commit 9d46c23
Show file tree
Hide file tree
Showing 11 changed files with 769 additions and 702 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ repos:
- types-all==1.0.0
- typer==0.3.2
- rich==10.16.2
- pydantic==1.9.0
- pydantic==2.3.*
- GitPython==3.1.26
- typing-extensions==4.0.1
- typing-extensions==4.6.*
- tomli==2.0.1
- repo: https://github.com/codespell-project/codespell
rev: v2.2.4
Expand All @@ -55,7 +55,7 @@ repos:
- typer-cli==0.0.12
- typer==0.3.2
- rich==10.16.2
- pydantic==1.9.0
- pydantic==2.3.*
- GitPython==3.1.26
- typing-extensions==4.0.1
- typing-extensions==4.6.*
- tomli==2.0.1
10 changes: 4 additions & 6 deletions databooks/data_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections import UserList
from typing import Any, Dict, Generic, Iterable, List, TypeVar, cast, overload

from pydantic import BaseModel, Extra, create_model
from pydantic import BaseModel, ConfigDict, create_model
from typing_extensions import Protocol, runtime_checkable

T = TypeVar("T")
Expand Down Expand Up @@ -89,10 +89,7 @@ def resolve(
class DatabooksBase(BaseModel):
"""Base Pydantic class with extras on managing fields."""

class Config:
"""Default configuration for base class."""

extra = Extra.allow
model_config = ConfigDict(extra="allow")

def remove_fields(
self,
Expand Down Expand Up @@ -158,7 +155,8 @@ def __sub__(self, other: DatabooksBase) -> DiffModel:
"Diff" + type(self).__name__,
__base__=type(self),
resolve=resolve,
is_diff=True,
is_diff=(bool, True),
**fields_d,
)

return cast(DiffModel, DiffInstance()) # it'll be filled in with the defaults
38 changes: 22 additions & 16 deletions databooks/data_models/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Union

from pydantic import PositiveInt, validator
from pydantic import PositiveInt, RootModel, field_validator
from rich.console import Console, ConsoleOptions, ConsoleRenderable, RenderResult
from rich.markdown import Markdown
from rich.panel import Panel
Expand Down Expand Up @@ -63,9 +63,7 @@ def remove_fields(

if self.cell_type == "code":
self.outputs: CellOutputs = (
CellOutputs(__root__=[])
if "outputs" not in dict(self)
else self.outputs
CellOutputs([]) if "outputs" not in dict(self) else self.outputs
)
self.execution_count: Optional[PositiveInt] = (
None if "execution_count" not in dict(self) else self.execution_count
Expand Down Expand Up @@ -118,14 +116,16 @@ def __rich__(
"""Rich display of cell stream outputs."""
return Text("".join(self.text))

@validator("output_type")
@field_validator("output_type")
@classmethod
def output_type_must_be_stream(cls, v: str) -> str:
"""Check if stream has `stream` type."""
if v != "stream":
raise ValueError(f"Invalid output type. Expected `stream`, got {v}.")
return v

@validator("name")
@field_validator("name")
@classmethod
def stream_name_must_match(cls, v: str) -> str:
"""Check if stream name is either `stdout` or `stderr`."""
valid_names = ("stdout", "stderr")
Expand Down Expand Up @@ -178,7 +178,8 @@ def __rich_console__(
"""Rich display of data display outputs."""
yield from self.rich_output

@validator("output_type")
@field_validator("output_type")
@classmethod
def output_type_must_match(cls, v: str) -> str:
"""Check if stream has `display_data` type."""
if v != "display_data":
Expand All @@ -198,7 +199,8 @@ def __rich_console__(
yield Text(f"Out [{self.execution_count or ' '}]:", style="out_count")
yield from self.rich_output

@validator("output_type")
@field_validator("output_type")
@classmethod
def output_type_must_match(cls, v: str) -> str:
"""Check if stream has `execute_result` type."""
if v != "execute_result":
Expand All @@ -222,7 +224,8 @@ def __rich__(
"""Rich display of error outputs."""
return Text.from_ansi("\n".join(self.traceback))

@validator("output_type")
@field_validator("output_type")
@classmethod
def output_type_must_match(cls, v: str) -> str:
"""Check if stream has `error` type."""
if v != "error":
Expand All @@ -235,10 +238,10 @@ def output_type_must_match(cls, v: str) -> str:
]


class CellOutputs(DatabooksBase):
class CellOutputs(RootModel):
"""Outputs of notebook code cells."""

__root__: List[CellOutputType]
root: List[CellOutputType]

def __rich_console__(
self, console: Console, options: ConsoleOptions
Expand All @@ -250,8 +253,8 @@ def __rich_console__(
def values(
self,
) -> List[CellOutputType]:
"""Alias `__root__` with outputs for easy referencing."""
return self.__root__
"""Alias `root` with outputs for easy referencing."""
return self.root


class CodeCell(BaseCell):
Expand All @@ -273,7 +276,8 @@ def __rich_console__(
)
yield self.outputs

@validator("cell_type")
@field_validator("cell_type")
@classmethod
def cell_has_code_type(cls, v: str) -> str:
"""Extract the list values from the __root__ attribute of `CellOutputs`."""
if v != "code":
Expand All @@ -292,7 +296,8 @@ def __rich__(
"""Rich display of markdown cells."""
return Panel(Markdown("".join(self.source)))

@validator("cell_type")
@field_validator("cell_type")
@classmethod
def cell_has_md_type(cls, v: str) -> str:
"""Extract the list values from the __root__ attribute of `CellOutputs`."""
if v != "markdown":
Expand All @@ -311,7 +316,8 @@ def __rich__(
"""Rich display of raw cells."""
return Panel(Text("".join(self.source)))

@validator("cell_type")
@field_validator("cell_type")
@classmethod
def cell_has_md_type(cls, v: str) -> str:
"""Extract the list values from the __root__ attribute of `CellOutputs`."""
if v != "raw":
Expand Down
28 changes: 10 additions & 18 deletions databooks/data_models/notebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
cast,
)

from pydantic import Extra, validate_model
from pydantic.generics import GenericModel
from pydantic import Extra, RootModel
from rich import box
from rich.columns import Columns
from rich.console import Console, ConsoleOptions, Group, RenderableType, RenderResult
Expand All @@ -39,19 +38,15 @@
T = TypeVar("T", Cell, CellsPair)


class Cells(GenericModel, BaseCells[T]):
class Cells(RootModel[Sequence[T]], BaseCells[T]):
"""Similar to `list`, with `-` operator using `difflib.SequenceMatcher`."""

__root__: Sequence[T] = ()

def __init__(self, elements: Sequence[T] = ()) -> None:
"""Allow passing data as a positional argument when instantiating class."""
super(Cells, self).__init__(__root__=elements)
root: Sequence[T]

@property
def data(self) -> List[T]: # type: ignore
"""Define property `data` required for `collections.UserList` class."""
return list(self.__root__)
return list(self.root)

def __iter__(self) -> Generator[Any, None, None]:
"""Use list property as iterable."""
Expand Down Expand Up @@ -81,6 +76,7 @@ def __sub__(self: Cells[Cell], other: Cells[Cell]) -> Cells[CellsPair]:
f" {n_context} for {len(self)} and {len(other)} cells in"
" notebooks."
)

return Cells[CellsPair](
[
# https://github.com/python/mypy/issues/9459
Expand Down Expand Up @@ -141,19 +137,16 @@ def wrap_git(
MarkdownCell(
metadata=CellMetadata(git_hash=hash_first),
source=[f"`<<<<<<< {hash_first}`"],
cell_type="markdown",
),
*first_cells,
MarkdownCell(
source=["`=======`"],
cell_type="markdown",
metadata=CellMetadata(),
),
*last_cells,
MarkdownCell(
metadata=CellMetadata(git_hash=hash_last),
source=[f"`>>>>>>> {hash_last}`"],
cell_type="markdown",
),
]

Expand Down Expand Up @@ -250,9 +243,9 @@ def parse_file(cls, path: Path | str, **parse_kwargs: Any) -> JupyterNotebook:
raise ValueError(
f"Value of `content_type` must be `json` (default), got `{content_arg}`"
)
return super(JupyterNotebook, cls).parse_file(
path=path, content_type="json", **parse_kwargs
)

path = Path(path) if not isinstance(path, Path) else path
return JupyterNotebook.model_validate_json(json_data=path.read_text())

def write(
self, path: Path | str, overwrite: bool = False, **json_kwargs: Any
Expand All @@ -265,9 +258,8 @@ def write(
f"File exists at {path} exists. Specify `overwrite = True`."
)

_, _, validation_error = validate_model(self.__class__, self.dict())
if validation_error:
raise validation_error
self.__class__.model_validate(self.dict())

with path.open("w") as f:
json.dump(self.dict(), fp=f, **json_kwargs)

Expand Down
Loading

0 comments on commit 9d46c23

Please sign in to comment.