Skip to content

Commit

Permalink
fix: allow creating UnevaluatedExpr with doit arg (#113)
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer committed Aug 3, 2021
1 parent c7e3df3 commit da323a7
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
11 changes: 6 additions & 5 deletions src/ampform/dynamics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
import sympy as sp
from sympy.printing.latex import LatexPrinter

from .decorator import UnevaluatedExpression, implement_doit_method
from .decorator import (
UnevaluatedExpression,
create_expression,
implement_doit_method,
)
from .math import ComplexSqrt

try:
Expand Down Expand Up @@ -56,10 +60,7 @@ def __new__( # pylint: disable=arguments-differ
**hints: Any,
) -> "BlattWeisskopfSquared":
args = sp.sympify((angular_momentum, z))
if evaluate:
# pylint: disable=no-member
return sp.Expr.__new__(cls, *args, **hints).evaluate()
return sp.Expr.__new__(cls, *args, **hints)
return create_expression(cls, evaluate, *args, **hints)

def evaluate(self) -> sp.Expr:
angular_momentum, z = self.args
Expand Down
19 changes: 18 additions & 1 deletion src/ampform/dynamics/decorator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Tools for defining lineshapes with `sympy`."""

from abc import abstractmethod
from typing import Any, Callable, Type
from typing import Any, Callable, Optional, Type

import sympy as sp
from sympy.printing.latex import LatexPrinter
Expand Down Expand Up @@ -99,3 +99,20 @@ def doit_method(self: Any, **hints: Any) -> sp.Expr:
return decorated_class

return decorator


def create_expression(
cls: Type[UnevaluatedExpression], evaluate: bool, *args: Any, **kwargs: Any
) -> sp.Expr:
"""Helper function for implementing :code:`Expr.__new__`.
See e.g. source code of `.BlattWeisskopfSquared`.
"""
# pylint: disable=no-member
deep: Optional[bool] = kwargs.pop("deep", None)
expr = sp.Expr.__new__(cls, *args, **kwargs)
if evaluate:
expr = expr.evaluate()
if deep:
expr = expr.doit(deep=deep)
return expr

0 comments on commit da323a7

Please sign in to comment.