diff --git a/src/decaylanguage/_compat/typing.py b/src/decaylanguage/_compat/typing.py new file mode 100644 index 00000000..c99b7471 --- /dev/null +++ b/src/decaylanguage/_compat/typing.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +import sys +import typing + +if sys.version_info < (3, 8): + from typing_extensions import TypedDict +else: + from typing import TypedDict + +if sys.version_info < (3, 11): + if typing.TYPE_CHECKING: + from typing_extensions import Self + else: + Self = object +else: + from typing import Self + + +__all__ = ["TypedDict", "Self"] diff --git a/src/decaylanguage/dec/dec.py b/src/decaylanguage/dec/dec.py index 237edc91..c2bfebef 100644 --- a/src/decaylanguage/dec/dec.py +++ b/src/decaylanguage/dec/dec.py @@ -47,14 +47,15 @@ from io import StringIO from itertools import zip_longest from pathlib import Path -from typing import Any, TypeVar +from typing import Any from lark import Lark, Token, Transformer, Tree, Visitor from particle import Particle from particle.converters import PDG2EvtGenNameMap from .. import data -from ..decay.decay import _expand_decay_modes +from .._compat.typing import Self +from ..decay.decay import DecayModeDict, _expand_decay_modes from ..utils import charge_conjugate_name from .enums import PhotosEnum @@ -67,9 +68,6 @@ class DecayNotFound(RuntimeError): pass -Self_DecFileParser = TypeVar("Self_DecFileParser", bound="DecFileParser") - - class DecFileParser: """ The class to parse a .dec decay file. @@ -142,9 +140,7 @@ def __init__(self, *filenames: str | os.PathLike[str]) -> None: self._include_ccdecays = True @classmethod - def from_string( - cls: type[Self_DecFileParser], filecontent: str - ) -> Self_DecFileParser: + def from_string(cls, filecontent: str) -> Self: """ Constructor from a .dec decay file provided as a multi-line string. @@ -751,7 +747,7 @@ def list_decay_modes(self, mother: str, pdg_name: bool = False) -> list[list[str def _decay_mode_details( self, decay_mode: Tree, display_photos_keyword: bool = True - ) -> tuple[float, list[str], str, str | list[str | Any]]: + ) -> DecayModeDict: """ Parse a decay mode (Tree instance) and return the relevant bits of information in it. @@ -772,7 +768,9 @@ def _decay_mode_details( if display_photos_keyword and list(decay_mode.find_data("photos")): model = "PHOTOS " + model - return (bf, fsp_names, model, model_params) + return DecayModeDict( + bf=bf, fs=fsp_names, model=model, model_params=model_params + ) def print_decay_modes( self, @@ -868,11 +866,13 @@ def print_decay_modes( ls_dict = {} for dm in dms: - bf, fsp_names, model, model_params = self._decay_mode_details( - dm, display_photos_keyword + dmdict = self._decay_mode_details(dm, display_photos_keyword) + model_params = [str(i) for i in dmdict["model_params"]] + ls_dict[dmdict["bf"]] = ( + dmdict["fs"], + dmdict["model"], + model_params, ) - model_params = [str(i) for i in model_params] - ls_dict[bf] = (fsp_names, model, model_params) dec_details = list(ls_dict.values()) ls_attrs_aligned = list( @@ -941,7 +941,7 @@ def build_decay_chains( self, mother: str, stable_particles: list[str] | set[str] | tuple[str] | tuple[()] = (), - ) -> dict[str, list[dict[str, float | str | list[Any]]]]: + ) -> dict[str, list[DecayModeDict]]: """ Iteratively build the entire decay chains of a given mother particle, optionally considering, on the fly, certain particles as stable. @@ -996,14 +996,12 @@ def build_decay_chains( >>> p.build_decay_chains('D+', stable_particles=['pi0']) # doctest: +SKIP {'D+': [{'bf': 1.0, 'fs': ['K-', 'pi+', 'pi+', 'pi0'], 'model': 'PHSP', 'model_params': ''}]} """ - keys = ("bf", "fs", "model", "model_params") info = [] for dm in self._find_decay_modes(mother): - list_dm_details = self._decay_mode_details(dm, display_photos_keyword=False) - d = dict(zip(keys, list_dm_details)) + d = self._decay_mode_details(dm, display_photos_keyword=False) - for i, fs in enumerate(d["fs"]): # type: ignore[arg-type] + for i, fs in enumerate(d["fs"]): if fs in stable_particles: continue @@ -1012,8 +1010,9 @@ def build_decay_chains( # if fs does not have decays defined in the parsed file # _n_dms = len(self._find_decay_modes(fs)) + assert isinstance(fs, str) _info = self.build_decay_chains(fs, stable_particles) - d["fs"][i] = _info # type: ignore[index,call-overload] + d["fs"][i] = _info # type: ignore[index] except DecayNotFound: pass diff --git a/src/decaylanguage/decay/decay.py b/src/decaylanguage/decay/decay.py index b9c462f6..728762fc 100644 --- a/src/decaylanguage/decay/decay.py +++ b/src/decaylanguage/decay/decay.py @@ -7,24 +7,34 @@ from __future__ import annotations from collections import Counter +from collections.abc import Sequence from copy import deepcopy from itertools import product -from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, TypeVar, Union +from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List from particle import PDGID, ParticleNotFound from particle.converters import EvtGenName2PDGIDBiMap from particle.exceptions import MatchingIDNotFound +from .._compat.typing import Self, TypedDict from ..utils import DescriptorFormat, charge_conjugate_name -Self_DaughtersDict = TypeVar("Self_DaughtersDict", bound="DaughtersDict") - if TYPE_CHECKING: CounterStr = Counter[str] # pragma: no cover else: CounterStr = Counter +class DecayModeDict(TypedDict): + bf: float + fs: Sequence[str | DecayChainDict] + model: str + model_params: str | Sequence[str | Any] + + +DecayChainDict = Dict[str, list[DecayModeDict]] + + class DaughtersDict(CounterStr): """ Class holding a decay final state as a dictionary. @@ -107,9 +117,7 @@ def to_list(self) -> list[str]: """ return sorted(self.elements()) - def charge_conjugate( - self: Self_DaughtersDict, pdg_name: bool = False - ) -> Self_DaughtersDict: + def charge_conjugate(self, pdg_name: bool = False) -> Self: """ Return the charge-conjugate final state. @@ -156,7 +164,7 @@ def __len__(self) -> int: """ return sum(self.values()) - def __add__(self: Self_DaughtersDict, other: Self_DaughtersDict) -> Self_DaughtersDict: # type: ignore[override] + def __add__(self, other: Self) -> Self: # type: ignore[override] """ Add two final states, particle-type-wise. """ @@ -167,9 +175,6 @@ def __iter__(self) -> Iterator[str]: return self.elements() -Self_DecayMode = TypeVar("Self_DecayMode", bound="DecayMode") - - class DecayMode: """ Class holding a particle decay mode, which is typically a branching fraction @@ -254,9 +259,9 @@ def __init__( @classmethod def from_dict( - cls: type[Self_DecayMode], - decay_mode_dict: dict[str, int | float | str | list[str]], - ) -> Self_DecayMode: + cls, + decay_mode_dict: DecayModeDict, + ) -> Self: """ Constructor from a dictionary of the form {'bf': , 'fs': [...], ...}. @@ -291,21 +296,18 @@ def from_dict( dm = deepcopy(decay_mode_dict) # Ensure the input dict has the 2 required keys 'bf' and 'fs' - try: - bf = dm.pop("bf") - daughters = dm.pop("fs") - except KeyError as e: - raise RuntimeError("Input not in the expected format!") from e + if not dm.keys() >= {"bf", "fs"}: + raise RuntimeError("Input not in the expected format! Needs 'bf' and 'fs'") - return cls(bf=bf, daughters=daughters, **dm) # type: ignore[arg-type] + return cls(**dm) @classmethod def from_pdgids( - cls: type[Self_DecayMode], + cls, bf: float = 0, daughters: list[int] | tuple[int] | None = None, **info: Any, - ) -> Self_DecayMode: + ) -> Self: """ Constructor for a final state given as a list of particle PDG IDs. @@ -397,9 +399,7 @@ def to_dict(self) -> dict[str, int | float | str | list[str]]: d["model_params"] = "" return d # type: ignore[return-value] - def charge_conjugate( - self: Self_DecayMode, pdg_name: bool = False - ) -> Self_DecayMode: + def charge_conjugate(self, pdg_name: bool = False) -> Self: """ Return the charge-conjugate decay mode. @@ -444,11 +444,6 @@ def __str__(self) -> str: return repr(self) -Self_DecayChain = TypeVar("Self_DecayChain", bound="DecayChain") -DecayModeDict = Dict[str, Union[float, str, List[Any]]] -DecayChainDict = Dict[str, List[DecayModeDict]] - - def _has_no_subdecay(ds: list[Any]) -> bool: """ Internal function to check whether the input list @@ -772,9 +767,7 @@ def __init__(self, mother: str, decays: dict[str, DecayMode]) -> None: self.decays = decays @classmethod - def from_dict( - cls: type[Self_DecayChain], decay_chain_dict: DecayChainDict - ) -> Self_DecayChain: + def from_dict(cls, decay_chain_dict: DecayChainDict) -> Self: """ Constructor from a decay chain represented as a dictionary. The format is the same as that returned by @@ -905,9 +898,9 @@ def _print( for i_decay in decay_dict[mother]: print(prefix, arrow if depth > 0 else "", mother, sep="") # noqa: T201 fsps = i_decay["fs"] - n = len(list(fsps)) # type: ignore[arg-type] + n = len(list(fsps)) depth += 1 - for j, fsp in enumerate(fsps): # type: ignore[arg-type] + for j, fsp in enumerate(fsps): prefix = bar if (link and depth > 1) else "" if last: prefix = prefix + " " * indent * (depth - 1) + " " @@ -966,9 +959,9 @@ def recursively_replace(mother: str) -> DecayDict: return recursively_replace(self.mother) def flatten( - self: Self_DecayChain, + self, stable_particles: Iterable[dict[str, int] | list[str] | str] = (), - ) -> Self_DecayChain: + ) -> Self: """ Flatten the decay chain replacing all intermediate, decaying particles, with their final states.