From c2a6fd0b379cf479268839be7b15b1058a305abe Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Tue, 6 Aug 2024 10:27:08 +0200 Subject: [PATCH] MAINT: upgrade to SymPy v1.13 (#435) * DX: ignore missing types `sympy` * FIX: adjust simplification code for SymPy v1.13 * MAINT: address `mypy` errors --- .constraints/py3.10.txt | 2 +- .constraints/py3.11.txt | 2 +- .constraints/py3.12.txt | 2 +- .constraints/py3.8.txt | 2 +- .constraints/py3.9.txt | 2 +- docs/usage/dynamics/k-matrix.ipynb | 18 +++++++++--------- pyproject.toml | 4 ++++ src/ampform/dynamics/__init__.py | 4 ++-- src/ampform/dynamics/builder.py | 8 ++++---- src/ampform/dynamics/kmatrix.py | 6 +++--- src/ampform/kinematics/lorentz.py | 7 ++++++- src/ampform/sympy/__init__.py | 2 +- src/ampform/sympy/_decorator.py | 2 +- src/ampform/sympy/deprecated.py | 2 +- src/ampform/sympy/math.py | 2 +- tests/dynamics/test_deprecated.py | 2 +- tests/dynamics/test_dynamics.py | 4 ++-- tests/dynamics/test_kmatrix.py | 11 ++++------- 18 files changed, 44 insertions(+), 38 deletions(-) diff --git a/.constraints/py3.10.txt b/.constraints/py3.10.txt index db84a3ddf..610c6ead3 100644 --- a/.constraints/py3.10.txt +++ b/.constraints/py3.10.txt @@ -183,7 +183,7 @@ sphinxcontrib-serializinghtml==1.1.10 sqlalchemy==2.0.31 stack-data==0.6.3 starlette==0.37.2 -sympy==1.12.1 +sympy==1.13.1 tabulate==0.9.0 terminado==0.18.1 tinycss2==1.3.0 diff --git a/.constraints/py3.11.txt b/.constraints/py3.11.txt index e84da6940..d93c316b2 100644 --- a/.constraints/py3.11.txt +++ b/.constraints/py3.11.txt @@ -182,7 +182,7 @@ sphinxcontrib-serializinghtml==1.1.10 sqlalchemy==2.0.31 stack-data==0.6.3 starlette==0.37.2 -sympy==1.12.1 +sympy==1.13.1 tabulate==0.9.0 terminado==0.18.1 tinycss2==1.3.0 diff --git a/.constraints/py3.12.txt b/.constraints/py3.12.txt index 7e1ee4831..416369118 100644 --- a/.constraints/py3.12.txt +++ b/.constraints/py3.12.txt @@ -182,7 +182,7 @@ sphinxcontrib-serializinghtml==1.1.10 sqlalchemy==2.0.31 stack-data==0.6.3 starlette==0.37.2 -sympy==1.12.1 +sympy==1.13.1 tabulate==0.9.0 terminado==0.18.1 tinycss2==1.3.0 diff --git a/.constraints/py3.8.txt b/.constraints/py3.8.txt index 4ade28364..97a6c41ab 100644 --- a/.constraints/py3.8.txt +++ b/.constraints/py3.8.txt @@ -188,7 +188,7 @@ sphinxcontrib-qthelp==1.0.3 sphinxcontrib-serializinghtml==1.1.5 sqlalchemy==2.0.31 stack-data==0.6.3 -sympy==1.12.1 +sympy==1.13.1 tabulate==0.9.0 terminado==0.18.1 tinycss2==1.3.0 diff --git a/.constraints/py3.9.txt b/.constraints/py3.9.txt index 380c4d859..ae86bbd4e 100644 --- a/.constraints/py3.9.txt +++ b/.constraints/py3.9.txt @@ -184,7 +184,7 @@ sphinxcontrib-serializinghtml==1.1.10 sqlalchemy==2.0.31 stack-data==0.6.3 starlette==0.37.2 -sympy==1.12.1 +sympy==1.13.1 tabulate==0.9.0 terminado==0.18.1 tinycss2==1.3.0 diff --git a/docs/usage/dynamics/k-matrix.ipynb b/docs/usage/dynamics/k-matrix.ipynb index b401194e2..24dd7bd89 100644 --- a/docs/usage/dynamics/k-matrix.ipynb +++ b/docs/usage/dynamics/k-matrix.ipynb @@ -642,9 +642,9 @@ "outputs": [], "source": [ "# reformulate terms\n", - "denominator, nominator = k_matrix.args\n", - "term1 = nominator.args[0] * denominator\n", - "term2 = nominator.args[1] * denominator\n", + "*rest, denominator, nominator = k_matrix.args\n", + "term1 = nominator.args[0] * denominator * sp.Mul(*rest)\n", + "term2 = nominator.args[1] * denominator * sp.Mul(*rest)\n", "k_matrix = term1 + term2\n", "k_matrix" ] @@ -934,9 +934,9 @@ " sp.sqrt(rho): 1,\n", " sp.conjugate(sp.sqrt(rho)): 1,\n", "})\n", - "denominator, nominator = rel_k_matrix_2r.args\n", - "term1 = nominator.args[0] * denominator\n", - "term2 = nominator.args[1] * denominator\n", + "*rest, denominator, nominator = rel_k_matrix_2r.args\n", + "term1 = nominator.args[0] * denominator * sp.Mul(*rest)\n", + "term2 = nominator.args[1] * denominator * sp.Mul(*rest)\n", "rel_k_matrix_2r = term1 + term2\n", "rel_k_matrix_2r" ] @@ -1081,9 +1081,9 @@ }, "outputs": [], "source": [ - "denominator, nominator = f_vector.args\n", - "term1 = nominator.args[0] * denominator\n", - "term2 = nominator.args[1] * denominator\n", + "*rest, denominator, nominator = f_vector.args\n", + "term1 = nominator.args[0] * denominator * sp.Mul(*rest)\n", + "term2 = nominator.args[1] * denominator * sp.Mul(*rest)\n", "f_vector = term1 + term2\n", "f_vector" ] diff --git a/pyproject.toml b/pyproject.toml index ab262546d..18c09ff4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -160,6 +160,10 @@ exclude = "_build" show_error_codes = true warn_unused_configs = true +[[tool.mypy.overrides]] +ignore_missing_imports = true +module = ["sympy.*"] + [[tool.mypy.overrides]] ignore_missing_imports = true module = ["graphviz.*"] diff --git a/src/ampform/dynamics/__init__.py b/src/ampform/dynamics/__init__.py index 20d044774..cb0260e88 100644 --- a/src/ampform/dynamics/__init__.py +++ b/src/ampform/dynamics/__init__.py @@ -54,7 +54,7 @@ class EnergyDependentWidth(sp.Expr): m_b: Any angular_momentum: Any meson_radius: Any - phsp_factor: PhaseSpaceFactorProtocol = argument( + phsp_factor: PhaseSpaceFactorProtocol = argument( # type:ignore[assignment] default=PhaseSpaceFactor, sympify=False ) name: str | None = argument(default=None, sympify=False) @@ -92,7 +92,7 @@ def relativistic_breit_wigner_with_ff( # noqa: PLR0917 m_b, angular_momentum, meson_radius, - phsp_factor: PhaseSpaceFactorProtocol = PhaseSpaceFactor, + phsp_factor: PhaseSpaceFactorProtocol = PhaseSpaceFactor, # type:ignore[assignment] ) -> sp.Expr: """Relativistic Breit-Wigner with `.FormFactor`. diff --git a/src/ampform/dynamics/builder.py b/src/ampform/dynamics/builder.py index 7e717e661..2a5904c53 100644 --- a/src/ampform/dynamics/builder.py +++ b/src/ampform/dynamics/builder.py @@ -123,7 +123,7 @@ def __init__( phsp_factor: PhaseSpaceFactorProtocol | None = None, ) -> None: if phsp_factor is None: - phsp_factor = PhaseSpaceFactor + phsp_factor = PhaseSpaceFactor # type:ignore[arg-type,assignment] self.phsp_factor = phsp_factor self.energy_dependent_width = energy_dependent_width self.form_factor = form_factor @@ -189,7 +189,7 @@ def __energy_dependent_breit_wigner( m_b=m_b, angular_momentum=angular_momentum, meson_radius=meson_radius, - phsp_factor=self.phsp_factor, + phsp_factor=self.phsp_factor, # type:ignore[arg-type] ) breit_wigner_expr = (res_mass * res_width) / ( res_mass**2 - s - mass_dependent_width * res_mass * sp.I @@ -245,7 +245,7 @@ def __create_symbols( create_relativistic_breit_wigner_with_ff = RelativisticBreitWignerBuilder( energy_dependent_width=True, form_factor=True, - phsp_factor=PhaseSpaceFactor, + phsp_factor=PhaseSpaceFactor, # type:ignore[arg-type] ).__call__ """Create a `.relativistic_breit_wigner_with_ff` for a two-body decay. @@ -256,7 +256,7 @@ def __create_symbols( create_analytic_breit_wigner = RelativisticBreitWignerBuilder( energy_dependent_width=True, form_factor=True, - phsp_factor=EqualMassPhaseSpaceFactor, + phsp_factor=EqualMassPhaseSpaceFactor, # type:ignore[arg-type] ).__call__ """Create a `.relativistic_breit_wigner_with_ff` with analytic continuation. diff --git a/src/ampform/dynamics/kmatrix.py b/src/ampform/dynamics/kmatrix.py index da2e3bf63..47bf7dd29 100644 --- a/src/ampform/dynamics/kmatrix.py +++ b/src/ampform/dynamics/kmatrix.py @@ -56,7 +56,7 @@ def formulate( # type: ignore[override] # noqa: D417 n_poles, parametrize: bool = True, return_t_hat: bool = False, - phsp_factor: PhaseSpaceFactorProtocol = PhaseSpaceFactor, + phsp_factor: PhaseSpaceFactorProtocol = PhaseSpaceFactor, # type:ignore[assignment] angular_momentum=0, meson_radius=1, ) -> sp.MutableDenseMatrix: @@ -116,7 +116,7 @@ def parametrization( # noqa: PLR0917 pole_id, angular_momentum=0, meson_radius=1, - phsp_factor: PhaseSpaceFactorProtocol = PhaseSpaceFactor, + phsp_factor: PhaseSpaceFactorProtocol = PhaseSpaceFactor, # type:ignore[assignment] ) -> sp.Expr: def residue_function(pole_id, i) -> sp.Expr: return residue_constant[pole_id, i] * sp.sqrt( @@ -296,7 +296,7 @@ def formulate( # type: ignore[override] # noqa: D417 n_poles, parametrize: bool = True, return_f_hat: bool = False, - phsp_factor: PhaseSpaceFactorProtocol = PhaseSpaceFactor, + phsp_factor: PhaseSpaceFactorProtocol = PhaseSpaceFactor, # type:ignore[assignment] angular_momentum=0, meson_radius=1, ) -> sp.MutableDenseMatrix: diff --git a/src/ampform/kinematics/lorentz.py b/src/ampform/kinematics/lorentz.py index 6d4a14deb..076f50554 100644 --- a/src/ampform/kinematics/lorentz.py +++ b/src/ampform/kinematics/lorentz.py @@ -2,6 +2,7 @@ from __future__ import annotations +import sys from typing import TYPE_CHECKING, Any, Callable, Dict import sympy as sp @@ -17,6 +18,10 @@ ) from ampform.sympy.math import ComplexSqrt +if sys.version_info < (3, 10): + from typing_extensions import TypeAlias +else: + from typing import TypeAlias if TYPE_CHECKING: from qrules.topology import Topology from sympy.printing.latex import LatexPrinter @@ -45,7 +50,7 @@ def create_four_momentum_symbol(index: int) -> FourMomentumSymbol: It's best to create a `dict` of `.FourMomenta` with :func:`create_four_momentum_symbols`. """ -FourMomentumSymbol = ArraySymbol +FourMomentumSymbol: TypeAlias = ArraySymbol r"""Array-`~sympy.core.symbol.Symbol` that represents an array of four-momenta. The array is assumed to be of shape :math:`n\times 4` with :math:`n` the number of diff --git a/src/ampform/sympy/__init__.py b/src/ampform/sympy/__init__.py index c21bedb10..150e7864b 100644 --- a/src/ampform/sympy/__init__.py +++ b/src/ampform/sympy/__init__.py @@ -164,7 +164,7 @@ def free_symbols(self) -> set[sp.Basic]: return super().free_symbols - {s for s, _ in self.indices} @override - def doit(self, deep: bool = True) -> sp.Expr: # type: ignore[override] + def doit(self, deep: bool = True) -> sp.Expr: # type: ignore[misc] expr = self.evaluate() if deep: return expr.doit() diff --git a/src/ampform/sympy/_decorator.py b/src/ampform/sympy/_decorator.py index c7c417d6e..ebd5a27a4 100644 --- a/src/ampform/sympy/_decorator.py +++ b/src/ampform/sympy/_decorator.py @@ -274,7 +274,7 @@ def new_method(cls, *args, evaluate: bool = False, **kwargs) -> type[ExprClass]: return expr.evaluate() return expr - cls.__new__ = new_method # type: ignore[method-assign] + cls.__new__ = new_method # type: ignore[assignment] cls.__getnewargs__ = _get_arguments # type: ignore[assignment,method-assign] cls._hashable_content = _hashable_content_method # type: ignore[method-assign] if non_sympy_fields: diff --git a/src/ampform/sympy/deprecated.py b/src/ampform/sympy/deprecated.py index 34211e113..135caaf34 100644 --- a/src/ampform/sympy/deprecated.py +++ b/src/ampform/sympy/deprecated.py @@ -108,7 +108,7 @@ def __getnewargs_ex__(self) -> tuple[tuple, dict]: kwargs = {"name": self._name} return args, kwargs - @override + @override # type:ignore[misc] def _hashable_content(self) -> tuple: # https://github.com/sympy/sympy/blob/1.10/sympy/core/basic.py#L157-L165 # name is converted to string because unstable hash for None diff --git a/src/ampform/sympy/math.py b/src/ampform/sympy/math.py index eac223aa3..04c40dcd5 100644 --- a/src/ampform/sympy/math.py +++ b/src/ampform/sympy/math.py @@ -37,7 +37,7 @@ class ComplexSqrt(NumPyPrintable): @overload def __new__(cls, x: sp.Number, *args, **kwargs) -> sp.Expr: ... # type: ignore[misc] @overload - def __new__(cls, x: sp.Expr, *args, **kwargs) -> ComplexSqrt: ... + def __new__(cls, x: sp.Expr, *args, **kwargs) -> ComplexSqrt: ... # type:ignore[misc] @override def __new__(cls, x, *args, **kwargs): x = sp.sympify(x) diff --git a/tests/dynamics/test_deprecated.py b/tests/dynamics/test_deprecated.py index 0315b30f6..3bc2c364c 100644 --- a/tests/dynamics/test_deprecated.py +++ b/tests/dynamics/test_deprecated.py @@ -38,7 +38,7 @@ def test_pickle(): m_b=m_a, angular_momentum=0, meson_radius=1, - phsp_factor=EqualMassPhaseSpaceFactor, + phsp_factor=EqualMassPhaseSpaceFactor, # type:ignore[arg-type] name="Gamma_1", ) pickled_obj = pickle.dumps(expr) diff --git a/tests/dynamics/test_dynamics.py b/tests/dynamics/test_dynamics.py index 6b8d83201..06042f4dd 100644 --- a/tests/dynamics/test_dynamics.py +++ b/tests/dynamics/test_dynamics.py @@ -47,7 +47,7 @@ def test_init(): m_b=m_b, angular_momentum=angular_momentum, meson_radius=d, - phsp_factor=EqualMassPhaseSpaceFactor, + phsp_factor=EqualMassPhaseSpaceFactor, # type:ignore[arg-type] name="Gamma_1", ) assert width.phsp_factor is EqualMassPhaseSpaceFactor @@ -70,7 +70,7 @@ def test_doit_and_subs(self, method: str): m_b=m_a, angular_momentum=0, meson_radius=1, - phsp_factor=PhaseSpaceFactorSWave, + phsp_factor=PhaseSpaceFactorSWave, # type:ignore[arg-type] ) subs_first = round_nested(_subs(width, parameters, method).doit(), n_decimals=3) doit_first = round_nested(_subs(width.doit(), parameters, method), n_decimals=3) diff --git a/tests/dynamics/test_kmatrix.py b/tests/dynamics/test_kmatrix.py index 0d1f1ddf2..bd37c3baa 100644 --- a/tests/dynamics/test_kmatrix.py +++ b/tests/dynamics/test_kmatrix.py @@ -1,16 +1,13 @@ from __future__ import annotations import re -from typing import TYPE_CHECKING import pytest +import sympy as sp from ampform.dynamics.kmatrix import NonRelativisticKMatrix from symplot import rename_symbols, substitute_indexed_symbols -if TYPE_CHECKING: - import sympy as sp - class TestNonRelativisticKMatrix: @pytest.mark.parametrize( @@ -35,9 +32,9 @@ def test_interference_single_channel(self): expr = substitute_indexed_symbols(expr) expr = _remove_residue_constants(expr) expr = _rename_widths(expr) - denominator, nominator = expr.args - term1 = nominator.args[0] * denominator - term2 = nominator.args[1] * denominator + *rest, denominator, nominator = expr.args + term1 = nominator.args[0] * denominator * sp.Mul(*rest) + term2 = nominator.args[1] * denominator * sp.Mul(*rest) assert str(term1 / term2) == R"m1*w1*(m2**2 - s)/(m2*w2*(m1**2 - s))"