Skip to content

Commit

Permalink
BREAK: deprecate UnevaluatedExpression templates (#383)
Browse files Browse the repository at this point in the history
* BREAK: issue deprecation warnings from `deprecated` module
* MAINT: move expression classes to `ampform.sympy.deprecated`
* MAINT: remove remaining `UnevaluatedExpresssion` calls and related
  • Loading branch information
redeboer committed Dec 22, 2023
1 parent f6ad4ce commit 17f383e
Show file tree
Hide file tree
Showing 7 changed files with 339 additions and 266 deletions.
13 changes: 7 additions & 6 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,30 +49,28 @@
add_module_names = False
api_github_repo = f"{ORGANIZATION}/{REPO_NAME}"
api_target_substitutions: dict[str, str | tuple[str, str]] = {
"T": "TypeVar",
"BuilderReturnType": ("obj", "ampform.dynamics.builder.BuilderReturnType"),
"DecoratedClass": ("obj", "ampform.sympy.DecoratedClass"),
"DecoratedExpr": ("obj", "ampform.sympy.DecoratedExpr"),
"ExprClass": "ampform.sympy.ExprClass",
"DecoratedClass": ("obj", "ampform.sympy.deprecated.DecoratedClass"),
"DecoratedExpr": ("obj", "ampform.sympy.deprecated.DecoratedExpr"),
"FourMomenta": ("obj", "ampform.kinematics.FourMomenta"),
"FourMomentumSymbol": ("obj", "ampform.kinematics.FourMomentumSymbol"),
"InteractionProperties": "qrules.quantum_numbers.InteractionProperties",
"LatexPrinter": "sympy.printing.printer.Printer",
"Literal[(-1, 1)]": "typing.Literal",
"Literal[-1, 1]": "typing.Literal",
"NumPyPrintable": ("class", "ampform.sympy.NumPyPrintable"),
"NumPyPrinter": "sympy.printing.printer.Printer",
"ParameterValue": ("obj", "ampform.helicity.ParameterValue"),
"Particle": "qrules.particle.Particle",
"ReactionInfo": "qrules.transition.ReactionInfo",
"Slider": ("obj", "symplot.Slider"),
"State": "qrules.transition.State",
"StateTransition": "qrules.transition.StateTransition",
"T": "TypeVar",
"Topology": "qrules.topology.Topology",
"WignerD": "sympy.physics.quantum.spin.WignerD",
"ampform.helicity._T": "typing.TypeVar",
"ampform.sympy._decorator.ExprClass": ("obj", "ampform.sympy.ExprClass"),
"ampform.sympy._decorator.SymPyAssumptions": "ampform.sympy.SymPyAssumptions",
"an object providing a view on D's values": "typing.ValuesView",
"sp.Basic": "sympy.core.basic.Basic",
"sp.Expr": "sympy.core.expr.Expr",
"sp.Float": "sympy.core.numbers.Float",
Expand Down Expand Up @@ -287,6 +285,9 @@
nb_execution_timeout = -1
nb_output_stderr = "remove"
nitpick_ignore = [
("py:class", "ArraySum"),
("py:class", "ExprClass"),
("py:class", "MatrixMultiplication"),
("py:class", "ampform.sympy._array_expressions.ArraySum"),
("py:class", "ampform.sympy._array_expressions.MatrixMultiplication"),
]
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ filterwarnings = [
"error",
"ignore:.*invalid value encountered in sqrt.*:RuntimeWarning",
"ignore:.*is deprecated and slated for removal in Python 3.14:DeprecationWarning",
"ignore:.*the @ampform.sympy.unevaluated_expression decorator instead( with commutative=True)?:DeprecationWarning",
"ignore:Passing a schema to Validator.iter_errors is deprecated.*:DeprecationWarning",
"ignore:The .* argument to NotebookFile is deprecated.*:pytest.PytestRemovedIn8Warning",
"ignore:The distutils package is deprecated.*:DeprecationWarning",
Expand Down
271 changes: 28 additions & 243 deletions src/ampform/sympy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
"""

# cspell:ignore mhash
from __future__ import annotations

import functools
Expand All @@ -22,7 +21,7 @@
from abc import abstractmethod
from os.path import abspath, dirname, expanduser
from textwrap import dedent
from typing import TYPE_CHECKING, Callable, Iterable, Sequence, SupportsFloat, TypeVar
from typing import TYPE_CHECKING, Iterable, Sequence, SupportsFloat

import sympy as sp
from sympy.printing.precedence import PRECEDENCE
Expand All @@ -33,6 +32,13 @@
argument, # noqa: F401 # pyright: ignore[reportUnusedImport]
unevaluated, # noqa: F401 # pyright: ignore[reportUnusedImport]
)
from .deprecated import (
UnevaluatedExpression, # noqa: F401 # pyright: ignore[reportUnusedImport]
create_expression, # noqa: F401 # pyright: ignore[reportUnusedImport]
implement_doit_method, # noqa: F401 # pyright: ignore[reportUnusedImport]
implement_expr, # pyright: ignore[reportUnusedImport] # noqa: F401
make_commutative, # pyright: ignore[reportUnusedImport] # noqa: F401
)

if TYPE_CHECKING:
from sympy.printing.latex import LatexPrinter
Expand All @@ -41,133 +47,13 @@
_LOGGER = logging.getLogger(__name__)


class UnevaluatedExpression(sp.Expr):
"""Base class for expression classes with an :meth:`evaluate` method.
Deriving from `~sympy.core.expr.Expr` allows us to keep expression trees condense
before unfolding them with their `~sympy.core.basic.Basic.doit` method. This allows
us to:
1. condense the LaTeX representation of an expression tree by providing a custom
:meth:`_latex` method.
2. overwrite its printer methods (see `NumPyPrintable` and e.g.
:doc:`compwa-org:report/001`).
The `UnevaluatedExpression` base class makes implementations of its derived classes
more secure by enforcing the developer to provide implementations for these methods,
so that SymPy mechanisms work correctly. Decorators like :func:`implement_expr` and
:func:`implement_doit_method` provide convenient means to implement the missing
methods.
.. autolink-preface::
import sympy as sp
from ampform.sympy import UnevaluatedExpression, create_expression
.. automethod:: __new__
.. automethod:: evaluate
.. automethod:: _latex
"""

# https://github.com/sympy/sympy/blob/1.8/sympy/core/basic.py#L74-L77
__slots__: tuple[str] = ("_name",)
_name: str | None
"""Optional instance attribute that can be used in LaTeX representations."""

def __new__(
cls: type[DecoratedClass],
*args,
name: str | None = None,
**hints,
) -> DecoratedClass:
"""Constructor for a class derived from `UnevaluatedExpression`.
This :meth:`~object.__new__` method correctly sets the
`~sympy.core.basic.Basic.args`, assumptions etc. Overwrite it in order to
further specify its signature. The function :func:`create_expression` can be
used in its implementation, like so:
>>> class MyExpression(UnevaluatedExpression):
... def __new__(
... cls, x: sp.Symbol, y: sp.Symbol, n: int, **hints
... ) -> "MyExpression":
... return create_expression(cls, x, y, n, **hints)
...
... def evaluate(self) -> sp.Expr:
... x, y, n = self.args
... return (x + y)**n
...
>>> x, y = sp.symbols("x y")
>>> expr = MyExpression(x, y, n=3)
>>> expr
MyExpression(x, y, 3)
>>> expr.evaluate()
(x + y)**3
"""
# https://github.com/sympy/sympy/blob/1.8/sympy/core/basic.py#L113-L119
obj = object.__new__(cls)
obj._args = args
obj._assumptions = cls.default_assumptions # type: ignore[attr-defined]
obj._mhash = None
obj._name = name
return obj

def __getnewargs_ex__(self) -> tuple[tuple, dict]:
# Pickling support, see
# https://github.com/sympy/sympy/blob/1.8/sympy/core/basic.py#L124-L126
args = tuple(self.args)
kwargs = {"name": self._name}
return args, kwargs

def _hashable_content(self) -> tuple:
# https://github.com/sympy/sympy/blob/1.10/sympy/core/basic.py#L157-L165
# name is converted to string because unstable hash for None
return (*super()._hashable_content(), str(self._name))

@abstractmethod
def evaluate(self) -> sp.Expr:
"""Evaluate and 'unfold' this `UnevaluatedExpression` by one level.
>>> from ampform.dynamics import BreakupMomentumSquared
>>> s, m1, m2 = sp.symbols("s m1 m2")
>>> expr = BreakupMomentumSquared(s, m1, m2)
>>> expr
BreakupMomentumSquared(s, m1, m2)
>>> expr.evaluate()
(s - (m1 - m2)**2)*(s - (m1 + m2)**2)/(4*s)
>>> expr.doit(deep=False)
(s - (m1 - m2)**2)*(s - (m1 + m2)**2)/(4*s)
.. note:: When decorating this class with :func:`implement_doit_method`,
its :meth:`evaluate` method is equivalent to
:meth:`~sympy.core.basic.Basic.doit` with :code:`deep=False`.
"""

def _latex(self, printer: LatexPrinter, *args) -> str:
r"""Provide a mathematical Latex representation for pretty printing.
>>> from ampform.dynamics import BreakupMomentumSquared
>>> s, m1 = sp.symbols("s m1")
>>> expr = BreakupMomentumSquared(s, m1, m1)
>>> print(sp.latex(expr))
q^2\left(s\right)
>>> print(sp.latex(expr.doit()))
- m_{1}^{2} + \frac{s}{4}
"""
args = tuple(map(printer._print, self.args))
name = type(self).__name__
if self._name is not None:
name = self._name
return f"{name}{args}"


class NumPyPrintable(sp.Expr):
r"""`~sympy.core.expr.Expr` class that can lambdify to NumPy code.
This interface for classes that derive from `sympy.Expr <sympy.core.expr.Expr>`
enforce the implementation of a :meth:`_numpycode` method in case the class does not
correctly :func:`~sympy.utilities.lambdify.lambdify` to NumPy code. For more info on
SymPy printers, see :doc:`sympy:modules/printing`.
This interface is for classes that derive from `sympy.Expr <sympy.core.expr.Expr>`
and that require a :meth:`_numpycode` method in case the class does not correctly
:func:`~sympy.utilities.lambdify.lambdify` to NumPy code. For more info on SymPy
printers, see :doc:`sympy:modules/printing`.
Several computational frameworks try to converge their interface to that of NumPy.
See for instance `TensorFlow's NumPy API
Expand All @@ -177,9 +63,9 @@ class NumPyPrintable(sp.Expr):
:func:`~sympy.utilities.lambdify.lambdify` SymPy expressions to these different
backends with the same lambdification code.
.. note:: This interface differs from `UnevaluatedExpression` in that it **should
not** implement an :meth:`.evaluate` (and therefore a
:meth:`~sympy.core.basic.Basic.doit`) method.
.. warning:: If you decorate this class with :func:`unevaluated`, you usually want
to do so with :code:`implement_doit=False`, because you do not want the class
to be 'unfolded' with :meth:`~sympy.core.basic.Basic.doit` before lambdification.
.. warning:: The implemented :meth:`_numpycode` method should countain as little
Expand All @@ -199,117 +85,6 @@ def _numpycode(self, printer: NumPyPrinter, *args) -> str:
"""Lambdify this `NumPyPrintable` class to NumPy code."""


DecoratedClass = TypeVar("DecoratedClass", bound=UnevaluatedExpression)
"""`~typing.TypeVar` for decorators like :func:`implement_doit_method`."""


def implement_expr(
n_args: int,
) -> Callable[[type[DecoratedClass]], type[DecoratedClass]]:
"""Decorator for classes that derive from `UnevaluatedExpression`.
Implement a :meth:`~object.__new__` and :meth:`~sympy.core.basic.Basic.doit` method
for a class that derives from `~sympy.core.expr.Expr` (via `UnevaluatedExpression`).
"""

def decorator(
decorated_class: type[DecoratedClass],
) -> type[DecoratedClass]:
decorated_class = implement_new_method(n_args)(decorated_class)
return implement_doit_method(decorated_class)

return decorator


def implement_new_method(
n_args: int,
) -> Callable[[type[DecoratedClass]], type[DecoratedClass]]:
"""Implement :meth:`UnevaluatedExpression.__new__` on a derived class.
Implement a :meth:`~object.__new__` method for a class that derives from
`~sympy.core.expr.Expr` (via `UnevaluatedExpression`).
"""

def decorator(
decorated_class: type[DecoratedClass],
) -> type[DecoratedClass]:
def new_method(
cls: type[DecoratedClass],
*args: sp.Symbol,
evaluate: bool = False,
**hints,
) -> DecoratedClass:
if len(args) != n_args:
msg = f"{n_args} parameters expected, got {len(args)}"
raise ValueError(msg)
args = sp.sympify(args)
expr = UnevaluatedExpression.__new__(cls, *args)
if evaluate:
return expr.evaluate() # type: ignore[return-value]
return expr

decorated_class.__new__ = new_method # type: ignore[assignment]
return decorated_class

return decorator


def implement_doit_method(
decorated_class: type[DecoratedClass],
) -> type[DecoratedClass]:
"""Implement ``doit()`` method for an `UnevaluatedExpression` class.
Implement a :meth:`~sympy.core.basic.Basic.doit` method for a class that derives
from `~sympy.core.expr.Expr` (via `UnevaluatedExpression`). A
:meth:`~sympy.core.basic.Basic.doit` method is an extension of an
:meth:`~.UnevaluatedExpression.evaluate` method in the sense that it can work
recursively on deeper expression trees.
"""

@functools.wraps(decorated_class.doit) # type: ignore[attr-defined]
def doit_method(self: UnevaluatedExpression, deep: bool = True) -> sp.Expr:
expr = self.evaluate()
if deep:
return expr.doit()
return expr

decorated_class.doit = doit_method # type: ignore[assignment]
return decorated_class


DecoratedExpr = TypeVar("DecoratedExpr", bound=sp.Expr)
"""`~typing.TypeVar` for decorators like :func:`make_commutative`."""


def make_commutative(
decorated_class: type[DecoratedExpr],
) -> type[DecoratedExpr]:
"""Set commutative and 'extended real' assumptions on expression class.
.. seealso:: :doc:`sympy:guides/assumptions`
"""
decorated_class.is_commutative = True # type: ignore[attr-defined]
decorated_class.is_extended_real = True # type: ignore[attr-defined]
return decorated_class


def create_expression(
cls: type[DecoratedExpr],
*args,
evaluate: bool = False,
name: str | None = None,
**kwargs,
) -> DecoratedExpr:
"""Helper function for implementing `UnevaluatedExpression.__new__`."""
args = sp.sympify(args)
if issubclass(cls, UnevaluatedExpression):
expr = UnevaluatedExpression.__new__(cls, *args, name=name, **kwargs)
if evaluate:
return expr.evaluate() # type: ignore[return-value]
return expr # type: ignore[return-value]
return sp.Expr.__new__(cls, *args, **kwargs) # type: ignore[return-value]


def create_symbol_matrix(name: str, m: int, n: int) -> sp.MutableDenseMatrix:
"""Create a `~sympy.matrices.dense.Matrix` with symbols as elements.
Expand All @@ -330,8 +105,7 @@ def create_symbol_matrix(name: str, m: int, n: int) -> sp.MutableDenseMatrix:
return sp.Matrix([[symbol[i, j] for j in range(n)] for i in range(m)])


@implement_doit_method
class PoolSum(UnevaluatedExpression):
class PoolSum(sp.Expr):
r"""Sum over indices where the values are taken from a domain set.
>>> i, j, m, n = sp.symbols("i j m n")
Expand All @@ -350,6 +124,7 @@ def __new__(
cls,
expression,
*indices: tuple[sp.Symbol, Iterable[sp.Basic]],
evaluate: bool = False,
**hints,
) -> PoolSum:
converted_indices = []
Expand All @@ -359,7 +134,11 @@ def __new__(
msg = f"No values provided for index {idx_symbol}"
raise ValueError(msg)
converted_indices.append((idx_symbol, values))
return create_expression(cls, expression, *converted_indices, **hints)
args = sp.sympify((expression, *converted_indices))
expr: PoolSum = sp.Expr.__new__(cls, *args, **hints)
if evaluate:
return expr.evaluate() # type: ignore[return-value]
return expr

@property
def expression(self) -> sp.Expr:
Expand All @@ -373,6 +152,12 @@ def indices(self) -> list[tuple[sp.Symbol, tuple[sp.Float, ...]]]:
def free_symbols(self) -> set[sp.Basic]:
return super().free_symbols - {s for s, _ in self.indices}

def doit(self, deep: bool = True) -> sp.Expr: # type: ignore[override]
expr = self.evaluate()
if deep:
return expr.doit()
return expr

def evaluate(self) -> sp.Expr:
indices = {symbol: tuple(values) for symbol, values in self.indices}
return sp.Add(*[
Expand Down
Loading

0 comments on commit 17f383e

Please sign in to comment.