Skip to content

Commit

Permalink
fix: correctly pickle UnevaluatedExpression (#139)
Browse files Browse the repository at this point in the history
* fix: define getnewargs_ex instead of getnewargs
  Fixes pickling problem for UnevaluatedExpression
  https://docs.python.org/3/library/pickle.html#object.__getnewargs_ex__
* test: test pickling of UnevaluatedExpressions
* docs: unpickle helicity module in notebook
  (serves as an additional test)
  • Loading branch information
redeboer authored Aug 30, 2021
1 parent 29651c6 commit f4b740b
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 12 deletions.
15 changes: 5 additions & 10 deletions docs/usage/amplitude.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"There is no special export function to export an {class}`.HelicityModel`. However, we can just use the built-in {mod}`pickle` method to write the model to disk:"
"There is no special export function to export an {class}`.HelicityModel`. However, we can just use the built-in {mod}`pickle` module to write the model to disk and load it back:"
]
},
{
Expand All @@ -229,14 +229,9 @@
"import pickle\n",
"\n",
"with open(\"helicity_model.pickle\", \"wb\") as stream:\n",
" pickle.dump(model, stream)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This model will be imported again in {doc}`/usage/interactive`."
" pickle.dump(model, stream)\n",
"with open(\"helicity_model.pickle\", \"rb\") as stream:\n",
" model = pickle.load(stream)"
]
},
{
Expand Down Expand Up @@ -659,7 +654,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
"version": "3.8.11"
}
},
"nbformat": 4,
Expand Down
5 changes: 3 additions & 2 deletions src/ampform/sympy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# cspell:ignore mhash
# pylint: disable=invalid-getnewargs-ex-returned
"""Tools that facilitate in building :mod:`sympy` expressions."""

import functools
Expand Down Expand Up @@ -34,10 +35,10 @@ def __new__( # pylint: disable=unused-argument
obj._name = name
return obj

def __getnewargs__(self) -> tuple:
def __getnewargs_ex__(self) -> Tuple[tuple, dict]:
# Pickling support, see
# https://github.com/sympy/sympy/blob/1.8/sympy/core/basic.py#L124-L126
return (*self.args, self._name)
return (self.args, {"name": self._name})

@abstractmethod
def evaluate(self) -> sp.Expr:
Expand Down
31 changes: 31 additions & 0 deletions tests/dynamics/test_sympy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pickle

import sympy as sp

from ampform.dynamics import BlattWeisskopfSquared
from ampform.sympy import UnevaluatedExpression


class TestUnevaluatedExpression:
@staticmethod
def test_pickle():
z = sp.Symbol("z")
angular_momentum = sp.Symbol("L", integer=True)

# Pickle simple SymPy expression
expr = z * angular_momentum
pickled_obj = pickle.dumps(expr)
imported_expr = pickle.loads(pickled_obj)
assert expr == imported_expr

# Pickle UnevaluatedExpression
expr = UnevaluatedExpression()
pickled_obj = pickle.dumps(expr)
imported_expr = pickle.loads(pickled_obj)
assert expr == imported_expr

# Pickle class derived from UnevaluatedExpression
expr = BlattWeisskopfSquared(angular_momentum, z=z)
pickled_obj = pickle.dumps(expr)
imported_expr = pickle.loads(pickled_obj)
assert expr == imported_expr

0 comments on commit f4b740b

Please sign in to comment.