From db9223bcc6f482c25e6ac48cb8ecf44ecd4e62af Mon Sep 17 00:00:00 2001 From: Yossi Mosbacher Date: Tue, 13 Feb 2024 19:06:38 +0200 Subject: [PATCH] use seattr override for bidirectional sync --- pydantic_panel/__init__.py | 2 +- pydantic_panel/dispatchers.py | 98 ++++++++------- pydantic_panel/numpy.py | 4 +- pydantic_panel/pandas.py | 10 +- pydantic_panel/pane.py | 3 +- pydantic_panel/widgets.py | 218 ++++++++++++++++++---------------- pyproject.toml | 5 +- 7 files changed, 179 insertions(+), 161 deletions(-) diff --git a/pydantic_panel/__init__.py b/pydantic_panel/__init__.py index ea15e8b..cdb5cde 100644 --- a/pydantic_panel/__init__.py +++ b/pydantic_panel/__init__.py @@ -89,4 +89,4 @@ __all__.append('NPArray') except ImportError: - pass \ No newline at end of file + pass diff --git a/pydantic_panel/dispatchers.py b/pydantic_panel/dispatchers.py index b7a40bd..00e2c97 100644 --- a/pydantic_panel/dispatchers.py +++ b/pydantic_panel/dispatchers.py @@ -1,7 +1,9 @@ import param import datetime -from typing import Dict, List, Any, Optional -from pydantic.fields import ModelField +import annotated_types + +from typing import Any, Optional +from pydantic.fields import FieldInfo try: from typing import _LiteralGenericAlias @@ -33,20 +35,20 @@ def clean_kwargs(obj: param.Parameterized, - kwargs: Dict[str,Any]) -> Dict[str,Any]: + kwargs: dict[str,Any]) -> dict[str,Any]: '''Remove any kwargs that are not explicit parameters of obj. ''' - return {k: v for k, v in kwargs.items() if k in obj.param.params()} + return {k: v for k, v in kwargs.items() if k in obj.param.values()} @dispatch -def infer_widget(value: Any, field: Optional[ModelField] = None, **kwargs) -> Widget: +def infer_widget(value: Any, field: Optional[FieldInfo] = None, **kwargs) -> Widget: """Fallback function when a more specific function was not registered. """ - if field is not None and type(field.outer_type_) == _LiteralGenericAlias: - options = list(field.outer_type_.__args__) + if field is not None and type(field.annotation) == _LiteralGenericAlias: + options = list(field.annotation.__args__) if value not in options: value = options[0] options = kwargs.pop("options", options) @@ -59,54 +61,61 @@ def infer_widget(value: Any, field: Optional[ModelField] = None, **kwargs) -> Wi @dispatch -def infer_widget(value: Integral, field: Optional[ModelField] = None, **kwargs) -> Widget: +def infer_widget(value: Integral, field: Optional[FieldInfo] = None, **kwargs) -> Widget: start = None end = None if field is not None: - if type(field.outer_type_) == _LiteralGenericAlias: - options = list(field.outer_type_.__args__) + if type(field.annotation) == _LiteralGenericAlias: + options = list(field.annotation.__args__) if value not in options: value = options[0] options = kwargs.pop("options", options) kwargs = clean_kwargs(Select, kwargs) return Select(value=value, options=options, **kwargs) - start = getattr(field.field_info, "gt", None) - if start is not None: - start += 1 - else: - start = getattr(field.field_info, "ge") - - end = getattr(field.field_info, "lt", None) - if end is not None: - end -= 1 - else: - end = getattr(field.field_info, "le", None) + for m in field.metadata: + if isinstance(m, annotated_types.Gt): + start = m.gt + 1 + if isinstance(m, annotated_types.Ge): + start = m.ge + if isinstance(m, annotated_types.Lt): + end = m.lt - 1 + if isinstance(m, annotated_types.Le): + end = m.le + kwargs = clean_kwargs(IntInput, kwargs) return IntInput(value=value, start=start, end=end, **kwargs) @dispatch -def infer_widget(value: Number, field: Optional[ModelField] = None, **kwargs) -> Widget: +def infer_widget(value: Number, field: Optional[FieldInfo] = None, **kwargs) -> Widget: start = None end = None if field is not None: - if type(field.outer_type_) == _LiteralGenericAlias: - options = list(field.outer_type_.__args__) + if type(field.annotation) == _LiteralGenericAlias: + options = list(field.annotation.__args__) if value not in options: value = options[0] options = kwargs.pop("options", options) kwargs = clean_kwargs(Select, kwargs) return Select(value=value, options=options, **kwargs) - start = getattr(field.field_info, "gt", None) - end = getattr(field.field_info, "lt", None) + for m in field.metadata: + if isinstance(m, annotated_types.Gt): + start = m.gt + 1 + if isinstance(m, annotated_types.Ge): + start = m.ge + if isinstance(m, annotated_types.Lt): + end = m.lt - 1 + if isinstance(m, annotated_types.Le): + end = m.le + kwargs = clean_kwargs(NumberInput, kwargs) return NumberInput(value=value, start=start, end=end, **kwargs) @dispatch -def infer_widget(value: bool, field: Optional[ModelField] = None, **kwargs) -> Widget: +def infer_widget(value: bool, field: Optional[FieldInfo] = None, **kwargs) -> Widget: if value is None: value = False kwargs = clean_kwargs(Checkbox, kwargs) @@ -114,28 +123,27 @@ def infer_widget(value: bool, field: Optional[ModelField] = None, **kwargs) -> W @dispatch -def infer_widget(value: str, field: Optional[ModelField] = None, **kwargs) -> Widget: +def infer_widget(value: str, field: Optional[FieldInfo] = None, **kwargs) -> Widget: min_length = kwargs.pop("min_length", None) max_length = kwargs.pop("max_length", 100) if field is not None: - if type(field.outer_type_) == _LiteralGenericAlias: - options = list(field.outer_type_.__args__) + if type(field.annotation) == _LiteralGenericAlias: + options = list(field.annotation.__args__) if value not in options: value = options[0] options = kwargs.pop("options", options) kwargs = clean_kwargs(Select, kwargs) return Select(value=value, options=options, **kwargs) - max_length = field.field_info.max_length - min_length = field.field_info.min_length + for m in field.metadata: + if isinstance(m, annotated_types.MinLen): + min_length = m.min_length + if isinstance(m, annotated_types.MaxLen): + max_length = m.max_length kwargs["min_length"] = min_length - if max_length is None: - kwargs = clean_kwargs(TextAreaInput, kwargs) - return TextAreaInput(value=value, **kwargs) - - elif max_length < 100: + if max_length is not None and max_length < 100: kwargs = clean_kwargs(TextInput, kwargs) return TextInput(value=value, max_length=max_length, **kwargs) @@ -144,9 +152,9 @@ def infer_widget(value: str, field: Optional[ModelField] = None, **kwargs) -> Wi @dispatch -def infer_widget(value: List, field: Optional[ModelField] = None, **kwargs) -> Widget: - if field is not None and type(field.type_) == _LiteralGenericAlias: - options = list(field.type_.__args__) +def infer_widget(value: list, field: Optional[FieldInfo] = None, **kwargs) -> Widget: + if field is not None and type(field.annotation) == _LiteralGenericAlias: + options = list(field.annotation.__args__) if value not in options: value = [] kwargs = clean_kwargs(ListInput, kwargs) @@ -158,20 +166,20 @@ def infer_widget(value: List, field: Optional[ModelField] = None, **kwargs) -> W @dispatch -def infer_widget(value: Dict, field: Optional[ModelField] = None, **kwargs) -> Widget: +def infer_widget(value: dict, field: Optional[FieldInfo] = None, **kwargs) -> Widget: kwargs = clean_kwargs(DictInput, kwargs) return DictInput(value=value, **kwargs) @dispatch -def infer_widget(value: tuple, field: Optional[ModelField] = None, **kwargs) -> Widget: +def infer_widget(value: tuple, field: Optional[FieldInfo] = None, **kwargs) -> Widget: kwargs = clean_kwargs(TupleInput, kwargs) return TupleInput(value=value, **kwargs) @dispatch def infer_widget( - value: datetime.datetime, field: Optional[ModelField] = None, **kwargs + value: datetime.datetime, field: Optional[FieldInfo] = None, **kwargs ): kwargs = clean_kwargs(DatetimePicker, kwargs) return DatetimePicker(value=value, **kwargs) @@ -179,7 +187,7 @@ def infer_widget( @dispatch def infer_widget( - value: param.Parameterized, field: Optional[ModelField] = None, **kwargs + value: param.Parameterized, field: Optional[FieldInfo] = None, **kwargs ): kwargs = clean_kwargs(Param, kwargs) return Param(value, **kwargs) @@ -187,7 +195,7 @@ def infer_widget( @dispatch def infer_widget( - value: List[param.Parameterized], field: Optional[ModelField] = None, **kwargs + value: list[param.Parameterized], field: Optional[FieldInfo] = None, **kwargs ): kwargs = clean_kwargs(Param, kwargs) return Column(*[Param(val, **kwargs) for val in value]) diff --git a/pydantic_panel/numpy.py b/pydantic_panel/numpy.py index fc3ee73..5d8c3bc 100644 --- a/pydantic_panel/numpy.py +++ b/pydantic_panel/numpy.py @@ -3,7 +3,7 @@ from typing import Optional from plum import dispatch, parametric, type_of -from pydantic.fields import ModelField +from pydantic.fields import FieldInfo from panel.widgets import Widget, ArrayInput from .dispatchers import clean_kwargs @@ -11,7 +11,7 @@ @dispatch def infer_widget( - value: np.ndarray, field: Optional[ModelField] = None, **kwargs + value: np.ndarray, field: Optional[FieldInfo] = None, **kwargs ) -> Widget: kwargs = clean_kwargs(ArrayInput, kwargs) return ArrayInput(value=value, **kwargs) diff --git a/pydantic_panel/pandas.py b/pydantic_panel/pandas.py index b7d8794..165246e 100644 --- a/pydantic_panel/pandas.py +++ b/pydantic_panel/pandas.py @@ -4,14 +4,14 @@ import param import pandas as pd -from pydantic.fields import ModelField +from pydantic.fields import FieldInfo from panel.widgets import DatetimeRangePicker, EditableRangeSlider from .dispatchers import clean_kwargs class PandasTimeIntervalEditor(DatetimeRangePicker): - value = param.ClassSelector(pd.Interval, default=None) + value = param.ClassSelector(class_=pd.Interval, default=None) def _serialize_value(self, value): value = super()._serialize_value(value) @@ -35,9 +35,9 @@ def _update_value_bounds(self): class PandasIntervalEditor(EditableRangeSlider): - value = param.ClassSelector(pd.Interval, default=None) + value = param.ClassSelector(class_=pd.Interval, default=None) - value_throttled = param.ClassSelector(pd.Interval, default=None) + value_throttled = param.ClassSelector(class_=pd.Interval, default=None) @param.depends("value", watch=True) def _update_value(self): @@ -83,7 +83,7 @@ class PandasIntegerIntervalEditor(PandasIntervalEditor): @dispatch -def infer_widget(value: pd.Interval, field: Optional[ModelField] = None, **kwargs): +def infer_widget(value: pd.Interval, field: Optional[FieldInfo] = None, **kwargs): if isinstance(value.left, pd.Timestamp) or isinstance(value.right, pd.Timestamp): kwargs = clean_kwargs(PandasTimeIntervalEditor, kwargs) return PandasTimeIntervalEditor(value=value, **kwargs) diff --git a/pydantic_panel/pane.py b/pydantic_panel/pane.py index 98e098f..adb22ec 100644 --- a/pydantic_panel/pane.py +++ b/pydantic_panel/pane.py @@ -20,7 +20,6 @@ pyobject = object - class Pydantic(PaneBase): """The Pydantic pane wraps your Pydantic model into a Panel component. @@ -58,7 +57,7 @@ def __init__(self, object=None, default_layout: Panel | None = None, **params): params["default_layout"] = default_layout pane_params = { - name: params[name] for name in Pydantic.param.params() if name in params + name: params[name] for name in Pydantic.param.values() if name in params } super().__init__(object, **pane_params) diff --git a/pydantic_panel/widgets.py b/pydantic_panel/widgets.py index b36add9..501e224 100644 --- a/pydantic_panel/widgets.py +++ b/pydantic_panel/widgets.py @@ -5,8 +5,7 @@ from typing import Dict, List, Any, Optional, Type, ClassVar from pydantic import ValidationError, BaseModel -from pydantic.fields import ModelField -from pydantic.config import inherit_config +from pydantic.fields import FieldInfo from plum import dispatch, NotFoundLookupError @@ -21,7 +20,6 @@ from .dispatchers import infer_widget, clean_kwargs from pydantic_panel import infer_widget - from typing import ClassVar, Type, List, Dict, Tuple, Any # See https://github.com/holoviz/panel/issues/3736 @@ -51,7 +49,7 @@ class pydantic_widgets(param.ParameterizedFunction): of a pydantic model. """ - model = param.ClassSelector(pydantic.BaseModel, is_instance=False) + model = param.ClassSelector(class_=pydantic.BaseModel, is_instance=False) aliases = param.Dict({}) @@ -65,23 +63,23 @@ def __call__(self, **params): p = param.ParamOverrides(self, params) if isinstance(p.model, BaseModel): - self.defaults = {f: getattr(p.model, f, None) for f in p.model.__fields__} + self.defaults = {f: getattr(p.model, f, None) for f in p.model.model_fields} if p.use_model_aliases: default_aliases = { field.name: field.alias.capitalize() - for field in p.model.__fields__.values() + for field in p.model.model_fields.values() } else: default_aliases = { - name: name.replace("_", " ").capitalize() for name in p.model.__fields__ + name: name.replace("_", " ").capitalize() for name in p.model.model_fields } aliases = params.get("aliases", default_aliases) widgets = {} for field_name, alias in aliases.items(): - field = p.model.__fields__[field_name] + field = p.model.model_fields[field_name] value = p.defaults.get(field_name, None) @@ -89,7 +87,7 @@ def __call__(self, **params): value = field.default try: - widget_builder = infer_widget.invoke(field.outer_type_, field.__class__) + widget_builder = infer_widget.invoke(field.annotation, field.__class__) widget = widget_builder( value, field, name=field_name, **p.widget_kwargs ) @@ -104,57 +102,6 @@ def __call__(self, **params): return widgets -class InstanceOverride: - """This allows us to override pydantic class attributes - for specific instance without touching the instance __dict__ - since pydantic expects the instance __dict__ to only hold field - values. We implement the descriptor protocol and lookup the value - based on the id of the instance. - """ - - @classmethod - def override(cls, instance: Any, name: str, value: Any, default: Any = None): - """Override the class attribute `name` with `value` - only when accessed from `instance`. - - Args: - instance (Any): An instance of some class - name (str): the attribute to be overriden - value (Any): the value to override with for this instance - default (Any, optional): Default value to return for other instances. - Only used if attribute doesnt exist on class. - Defaults to None. - - Returns: - Any: the instance that was passed - """ - - class_ = type(instance) - - if not hasattr(class_, name): - setattr(class_, name, cls(default)) - - elif not isinstance(vars(class_)[name], cls): - setattr(class_, name, cls(getattr(class_, name))) - - vars(class_)[name].mapper[id(instance)] = value - - return instance - - def revert_override(self, instance: Any): - return self.mapper.pop(id(instance)) - - def __init__(self, default, mapper=None): - self.default = default - self.mapper = mapper or {} - - def __get__(self, obj, objtype=None): - if id(obj) in self.mapper: - return self.mapper[id(obj)] - else: - return self.default - - class PydanticModelEditor(CompositeWidget): """A composet widget whos value is a pydantic model and whos children widgets are synced with the model attributes @@ -197,7 +144,7 @@ def __init__(self, **params): def widgets(self): fields = self.fields if self.fields else list(self._widgets) return [self._widgets[field] for field in fields if field in self._widgets] - + def _recreate_widgets(self, *events): if self.class_ is None: self.value = None @@ -256,33 +203,30 @@ def _update_value(self, event: param.Event): if self.value is not None and self.bidirectional: # We need to ensure the model validates on assignment - if not self.value.__config__.validate_assignment: - config = inherit_config(Config, self.value.__config__) - InstanceOverride.override(self.value, "__config__", config) + if not self.value.model_config.get("validate_assignment", False): + config = self.value.model_config.copy() + config.update(validate_assignment=True) # Add a callback to the root validators - # to sync widgets to the changes made to - # the model attributes - callback = (False, self._update_widgets) - if callback not in self.value.__post_root_validators__: - validators = self.value.__post_root_validators__ + [callback] - InstanceOverride.override( - self.value, "__post_root_validators__", validators - ) + # sync widgets to the changes made directly + # to the model attributes + add_setattr_callback(self.value, self._update_widget) + # If the previous value was a model - # instance we unlink it by removing - # the instance root validator and config + # instance we unlink it if id(self.value) != id(event.old) and isinstance(event.old, BaseModel): - for var in vars(type(event.old)).values(): - if not isinstance(var, InstanceOverride): - continue - var.revert_override(event.old) + remove_setattr_callback(event.old, self._update_widget) + + def __del__(self): + if self.value is not None and self.bidirectional: + remove_setattr_callback(self.value, self._update_widget) def items(self): if self.value is None: return [] - return [(name, getattr(self.value, name)) for name in self.value.__fields__] + return [(name, getattr(self.value, name)) + for name in self.value.model_fields] def _validate_field(self, event: param.Event): if not event or self._updating: @@ -297,32 +241,40 @@ def _validate_field(self, event: param.Event): pass return + if self.value is None: + return + for name, widget in self._widgets.items(): if event.obj == widget: break else: return - field = self.value.__fields__[name] - data = {k: w.value for k, w in self._widgets.items()} - - val = data.pop(name, None) - val, error = field.validate(val, data, loc=name) - if error: - self.updating = True + try: + self.class_.__pydantic_validator__.validate_assignment(self.value, + name, + event.new) + except ValidationError as e: + self._updating = True try: event.obj.value = event.old + self._updating_field = True + self.param.trigger("value") + self._updating_field = False finally: - self.updating = False - raise ValidationError([error], type(self.value)) + self._updating = False + raise e - if self.value is not None: - setattr(self.value, name, val) - self._updating_field = True + def _update_widget(self, name, value): + if self._updating: + return + + if name in self._widgets: + self._updating = True try: - self.param.trigger("value") + self._widgets[name].value = value finally: - self._updating_field = False + self._updating = False def _update_widgets(self, cls, values): if self.value is None: @@ -363,6 +315,62 @@ def json(self): ) +def add_setattr_callback(model_instance: BaseModel, callback: callable): + """Syncs the fields of a pydantic model with a dictionary of widgets + + Args: + model_instance (BaseModel): The model instance to sync + callback (callable): The callback function to sync the fields + + Returns: + callback: A callback function that can be used to unsync the fields + """ + + class_ = model_instance.__class__ + if hasattr(class_, "__panel_callbacks__"): + class_.__panel_callbacks__ += (callback,) + else: + class ModifiedModel(class_): + __panel_callbacks__ = (callback,) + + def __setattr__(self, name, value): + super().__setattr__(name, value) + if not hasattr(self.__class__, "__panel_callbacks__"): + return + for cb in self.__class__.__panel_callbacks__: + cb(name, value) + + model_instance.__class__ = ModifiedModel + + return callback + +def remove_setattr_callback(model_instance: BaseModel, callback: callable): + """Unsyncs the fields of a pydantic model with a dictionary of widgets + + Args: + model_instance (BaseModel): The model instance to unsync + + Returns: + None + """ + class_ = model_instance.__class__ + + if hasattr(class_, "__panel_callbacks__"): + class_.__panel_callbacks__ = tuple( + cb for cb in class_.__panel_callbacks__ if cb is not callback + ) + else: + return + + if class_.__panel_callbacks__: + return + + for class_ in model_instance.__class__.mro(): + if hasattr(class_, "__panel_callbacks__"): + continue + model_instance.__class__ = class_ + + class PydanticModelEditorCard(PydanticModelEditor): """Same as PydanticModelEditor but uses a Card container to hold the widgets and synces the header with the widget `name` @@ -395,9 +403,9 @@ class BaseCollectionEditor(CompositeWidget): expand = param.Boolean(True) - class_ = param.ClassSelector(object, is_instance=False) + class_ = param.ClassSelector(class_=object, is_instance=False) - item_field = param.ClassSelector(ModelField, default=None, allow_None=True) + item_field = param.ClassSelector(class_=FieldInfo, default=None, allow_None=True) default_item = param.Parameter(default=None) @@ -478,7 +486,7 @@ def keys(self): def values(self): raise NotImplementedError - def items(self) -> List[Tuple[str, Any]]: + def items(self) -> list[Tuple[str, Any]]: raise NotImplementedError def add_item(self, item, name=None): @@ -507,7 +515,7 @@ def keys(self): def values(self): return list(self.value) - def items(self) -> List[Tuple[str, Any]]: + def items(self) -> list[Tuple[str, Any]]: return list(enumerate(self.value)) def add_item(self, item, name=None): @@ -576,7 +584,7 @@ class ItemDictEditor(BaseCollectionEditor): default={}, ) - key_type = param.ClassSelector(object, default=str, is_instance=False) + key_type = param.ClassSelector(class_=object, default=str, is_instance=False) default_key = param.Parameter(default="") @@ -586,7 +594,7 @@ def keys(self): def values(self): return list(self.value.values()) - def items(self) -> List[Tuple[str, Any]]: + def items(self) -> list[tuple[str, Any]]: return list(self.value.items()) def add_item(self, item, name=None): @@ -643,21 +651,21 @@ def cb(event): @dispatch -def infer_widget(value: BaseModel, field: Optional[ModelField] = None, **kwargs): +def infer_widget(value: BaseModel, field: Optional[FieldInfo] = None, **kwargs): if field is None: class_ = kwargs.pop("class_", type(value)) return PydanticModelEditor(value=value, class_=class_, **kwargs) - class_ = kwargs.pop("class_", field.outer_type_) + class_ = kwargs.pop("class_", field.annotation) kwargs = clean_kwargs(PydanticModelEditorCard, kwargs) return PydanticModelEditorCard(value=value, class_=class_, **kwargs) @dispatch -def infer_widget(value: List[BaseModel], field: Optional[ModelField] = None, **kwargs): +def infer_widget(value: list[BaseModel], field: Optional[FieldInfo] = None, **kwargs): if field is not None: - kwargs["class_"] = kwargs.pop("class_", field.type_) + kwargs["class_"] = kwargs.pop("class_", field.annotation) if value is None: value = field.default @@ -669,11 +677,11 @@ def infer_widget(value: List[BaseModel], field: Optional[ModelField] = None, **k @dispatch def infer_widget( - value: Dict[str, BaseModel], field: Optional[ModelField] = None, **kwargs + value: dict[str, BaseModel], field: Optional[FieldInfo] = None, **kwargs ): if field is not None: - kwargs["class_"] = kwargs.pop("class_", field.type_) + kwargs["class_"] = kwargs.pop("class_", field.annotation) if value is None: value = field.default diff --git a/pyproject.toml b/pyproject.toml index 3cdf89b..e5f0b1c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ packages = [ [tool.poetry.dependencies] python = ">=3.8,<4.0" panel = ">=0.13" -pydantic = "*" +pydantic = ">=2.0" plum-dispatch = "*" @@ -46,6 +46,9 @@ twine = "^4.0.1" black = "^22.6.0" pytest-cov = "^3.0.0" +[tool.poetry.plugins."panel.extension"] +pydantic = 'pydantic_panel' + [build-system] requires = ["poetry-core>=1.0.8", "setuptools"] build-backend = "poetry.core.masonry.api"