Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pydantic v2 #67

Merged
merged 14 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ repos:
- types-all==1.0.0
- typer==0.3.2
- rich==10.16.2
- pydantic==1.9.0
- pydantic==2.3.*
- GitPython==3.1.26
- typing-extensions==4.0.1
- typing-extensions==4.6.*
- tomli==2.0.1
- repo: https://github.com/codespell-project/codespell
rev: v2.2.4
Expand All @@ -55,7 +55,7 @@ repos:
- typer-cli==0.0.12
- typer==0.3.2
- rich==10.16.2
- pydantic==1.9.0
- pydantic==2.3.*
- GitPython==3.1.26
- typing-extensions==4.0.1
- typing-extensions==4.6.*
- tomli==2.0.1
10 changes: 4 additions & 6 deletions databooks/data_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections import UserList
from typing import Any, Dict, Generic, Iterable, List, TypeVar, cast, overload

from pydantic import BaseModel, Extra, create_model
from pydantic import BaseModel, ConfigDict, create_model
from typing_extensions import Protocol, runtime_checkable

T = TypeVar("T")
Expand Down Expand Up @@ -89,10 +89,7 @@ def resolve(
class DatabooksBase(BaseModel):
"""Base Pydantic class with extras on managing fields."""

class Config:
"""Default configuration for base class."""

extra = Extra.allow
model_config = ConfigDict(extra="allow")

def remove_fields(
self,
Expand Down Expand Up @@ -158,7 +155,8 @@ def __sub__(self, other: DatabooksBase) -> DiffModel:
"Diff" + type(self).__name__,
__base__=type(self),
resolve=resolve,
is_diff=True,
is_diff=(bool, True),
**fields_d,
)

return cast(DiffModel, DiffInstance()) # it'll be filled in with the defaults
38 changes: 22 additions & 16 deletions databooks/data_models/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Union

from pydantic import PositiveInt, validator
from pydantic import PositiveInt, RootModel, field_validator
from rich.console import Console, ConsoleOptions, ConsoleRenderable, RenderResult
from rich.markdown import Markdown
from rich.panel import Panel
Expand Down Expand Up @@ -63,9 +63,7 @@ def remove_fields(

if self.cell_type == "code":
self.outputs: CellOutputs = (
CellOutputs(__root__=[])
if "outputs" not in dict(self)
else self.outputs
CellOutputs([]) if "outputs" not in dict(self) else self.outputs
)
self.execution_count: Optional[PositiveInt] = (
None if "execution_count" not in dict(self) else self.execution_count
Expand Down Expand Up @@ -118,14 +116,16 @@ def __rich__(
"""Rich display of cell stream outputs."""
return Text("".join(self.text))

@validator("output_type")
@field_validator("output_type")
@classmethod
def output_type_must_be_stream(cls, v: str) -> str:
"""Check if stream has `stream` type."""
if v != "stream":
raise ValueError(f"Invalid output type. Expected `stream`, got {v}.")
return v

@validator("name")
@field_validator("name")
@classmethod
def stream_name_must_match(cls, v: str) -> str:
"""Check if stream name is either `stdout` or `stderr`."""
valid_names = ("stdout", "stderr")
Expand Down Expand Up @@ -178,7 +178,8 @@ def __rich_console__(
"""Rich display of data display outputs."""
yield from self.rich_output

@validator("output_type")
@field_validator("output_type")
@classmethod
def output_type_must_match(cls, v: str) -> str:
"""Check if stream has `display_data` type."""
if v != "display_data":
Expand All @@ -198,7 +199,8 @@ def __rich_console__(
yield Text(f"Out [{self.execution_count or ' '}]:", style="out_count")
yield from self.rich_output

@validator("output_type")
@field_validator("output_type")
@classmethod
def output_type_must_match(cls, v: str) -> str:
"""Check if stream has `execute_result` type."""
if v != "execute_result":
Expand All @@ -222,7 +224,8 @@ def __rich__(
"""Rich display of error outputs."""
return Text.from_ansi("\n".join(self.traceback))

@validator("output_type")
@field_validator("output_type")
@classmethod
def output_type_must_match(cls, v: str) -> str:
"""Check if stream has `error` type."""
if v != "error":
Expand All @@ -235,10 +238,10 @@ def output_type_must_match(cls, v: str) -> str:
]


class CellOutputs(DatabooksBase):
class CellOutputs(RootModel):
"""Outputs of notebook code cells."""

__root__: List[CellOutputType]
root: List[CellOutputType]

def __rich_console__(
self, console: Console, options: ConsoleOptions
Expand All @@ -250,8 +253,8 @@ def __rich_console__(
def values(
self,
) -> List[CellOutputType]:
"""Alias `__root__` with outputs for easy referencing."""
return self.__root__
"""Alias `root` with outputs for easy referencing."""
return self.root


class CodeCell(BaseCell):
Expand All @@ -273,7 +276,8 @@ def __rich_console__(
)
yield self.outputs

@validator("cell_type")
@field_validator("cell_type")
@classmethod
def cell_has_code_type(cls, v: str) -> str:
"""Extract the list values from the __root__ attribute of `CellOutputs`."""
if v != "code":
Expand All @@ -292,7 +296,8 @@ def __rich__(
"""Rich display of markdown cells."""
return Panel(Markdown("".join(self.source)))

@validator("cell_type")
@field_validator("cell_type")
@classmethod
def cell_has_md_type(cls, v: str) -> str:
"""Extract the list values from the __root__ attribute of `CellOutputs`."""
if v != "markdown":
Expand All @@ -311,7 +316,8 @@ def __rich__(
"""Rich display of raw cells."""
return Panel(Text("".join(self.source)))

@validator("cell_type")
@field_validator("cell_type")
@classmethod
def cell_has_md_type(cls, v: str) -> str:
"""Extract the list values from the __root__ attribute of `CellOutputs`."""
if v != "raw":
Expand Down
28 changes: 10 additions & 18 deletions databooks/data_models/notebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
cast,
)

from pydantic import Extra, validate_model
from pydantic.generics import GenericModel
from pydantic import Extra, RootModel
from rich import box
from rich.columns import Columns
from rich.console import Console, ConsoleOptions, Group, RenderableType, RenderResult
Expand All @@ -39,19 +38,15 @@
T = TypeVar("T", Cell, CellsPair)


class Cells(GenericModel, BaseCells[T]):
class Cells(RootModel[Sequence[T]], BaseCells[T]):
"""Similar to `list`, with `-` operator using `difflib.SequenceMatcher`."""

__root__: Sequence[T] = ()

def __init__(self, elements: Sequence[T] = ()) -> None:
"""Allow passing data as a positional argument when instantiating class."""
super(Cells, self).__init__(__root__=elements)
root: Sequence[T]

@property
def data(self) -> List[T]: # type: ignore
"""Define property `data` required for `collections.UserList` class."""
return list(self.__root__)
return list(self.root)

def __iter__(self) -> Generator[Any, None, None]:
"""Use list property as iterable."""
Expand Down Expand Up @@ -81,6 +76,7 @@ def __sub__(self: Cells[Cell], other: Cells[Cell]) -> Cells[CellsPair]:
f" {n_context} for {len(self)} and {len(other)} cells in"
" notebooks."
)

return Cells[CellsPair](
[
# https://github.com/python/mypy/issues/9459
Expand Down Expand Up @@ -141,19 +137,16 @@ def wrap_git(
MarkdownCell(
metadata=CellMetadata(git_hash=hash_first),
source=[f"`<<<<<<< {hash_first}`"],
cell_type="markdown",
),
*first_cells,
MarkdownCell(
source=["`=======`"],
cell_type="markdown",
metadata=CellMetadata(),
),
*last_cells,
MarkdownCell(
metadata=CellMetadata(git_hash=hash_last),
source=[f"`>>>>>>> {hash_last}`"],
cell_type="markdown",
),
]

Expand Down Expand Up @@ -250,9 +243,9 @@ def parse_file(cls, path: Path | str, **parse_kwargs: Any) -> JupyterNotebook:
raise ValueError(
f"Value of `content_type` must be `json` (default), got `{content_arg}`"
)
return super(JupyterNotebook, cls).parse_file(
path=path, content_type="json", **parse_kwargs
)

path = Path(path) if not isinstance(path, Path) else path
return JupyterNotebook.model_validate_json(json_data=path.read_text())

def write(
self, path: Path | str, overwrite: bool = False, **json_kwargs: Any
Expand All @@ -265,9 +258,8 @@ def write(
f"File exists at {path} exists. Specify `overwrite = True`."
)

_, _, validation_error = validate_model(self.__class__, self.dict())
if validation_error:
raise validation_error
self.__class__.model_validate(self.dict())

with path.open("w") as f:
json.dump(self.dict(), fp=f, **json_kwargs)

Expand Down
Loading
Loading