diff --git a/neptyne_kernel/json_tools.py b/neptyne_kernel/json_tools.py index 20d9644..2e6c432 100644 --- a/neptyne_kernel/json_tools.py +++ b/neptyne_kernel/json_tools.py @@ -16,14 +16,13 @@ def json_clean(obj: Any) -> dict: def json_default(obj: Any) -> str | int | list | float | None: if isinstance(obj, Empty): return None + if hasattr(obj, "__json__"): + return obj.__json__() try: from jupyter_client.jsonutil import json_default as jupyter_json_default except ImportError: from jupyter_client.jsonutil import date_default as jupyter_json_default - if hasattr(obj, "__json__"): - obj = obj.__json__() - return jupyter_json_default(obj) diff --git a/neptyne_kernel/widgets/register_widget_test.py b/neptyne_kernel/widgets/register_widget_test.py index a17833d..8066a99 100644 --- a/neptyne_kernel/widgets/register_widget_test.py +++ b/neptyne_kernel/widgets/register_widget_test.py @@ -4,6 +4,7 @@ from typing import Callable from ..cell_range import CellRange +from ..json_tools import json_default from ..neptyne_protocol import WidgetParamType from ..util import list_like from .base_widget import BaseWidget @@ -115,7 +116,7 @@ def test_register_widget(): def test_widgets_with_enums_can_serialize(): scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6], trendline=Scatter.TrendlineType.POLY) - s = json.dumps(scatter) + s = json.dumps(scatter, default=json_default) scatter2 = Scatter(**json.loads(s)) assert scatter2.trendline == Scatter.TrendlineType.POLY