Skip to content

Commit

Permalink
fix: nicer typing
Browse files Browse the repository at this point in the history
Signed-off-by: Henry Schreiner <henryschreineriii@gmail.com>
  • Loading branch information
henryiii committed Nov 17, 2023
1 parent b682fac commit a65741d
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 56 deletions.
20 changes: 20 additions & 0 deletions src/decaylanguage/_compat/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from __future__ import annotations

import sys
import typing

if sys.version_info < (3, 8):
from typing_extensions import TypedDict
else:
from typing import TypedDict

if sys.version_info < (3, 11):
if typing.TYPE_CHECKING:
from typing_extensions import Self
else:
Self = object
else:
from typing import Self


__all__ = ["TypedDict", "Self"]
39 changes: 19 additions & 20 deletions src/decaylanguage/dec/dec.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,15 @@
from io import StringIO
from itertools import zip_longest
from pathlib import Path
from typing import Any, TypeVar
from typing import Any

from lark import Lark, Token, Transformer, Tree, Visitor
from particle import Particle
from particle.converters import PDG2EvtGenNameMap

from .. import data
from ..decay.decay import _expand_decay_modes
from .._compat.typing import Self
from ..decay.decay import DecayModeDict, _expand_decay_modes
from ..utils import charge_conjugate_name
from .enums import PhotosEnum

Expand All @@ -67,9 +68,6 @@ class DecayNotFound(RuntimeError):
pass


Self_DecFileParser = TypeVar("Self_DecFileParser", bound="DecFileParser")


class DecFileParser:
"""
The class to parse a .dec decay file.
Expand Down Expand Up @@ -142,9 +140,7 @@ def __init__(self, *filenames: str | os.PathLike[str]) -> None:
self._include_ccdecays = True

@classmethod
def from_string(
cls: type[Self_DecFileParser], filecontent: str
) -> Self_DecFileParser:
def from_string(cls, filecontent: str) -> Self:
"""
Constructor from a .dec decay file provided as a multi-line string.
Expand Down Expand Up @@ -751,7 +747,7 @@ def list_decay_modes(self, mother: str, pdg_name: bool = False) -> list[list[str

def _decay_mode_details(
self, decay_mode: Tree, display_photos_keyword: bool = True
) -> tuple[float, list[str], str, str | list[str | Any]]:
) -> DecayModeDict:
"""
Parse a decay mode (Tree instance)
and return the relevant bits of information in it.
Expand All @@ -772,7 +768,9 @@ def _decay_mode_details(
if display_photos_keyword and list(decay_mode.find_data("photos")):
model = "PHOTOS " + model

return (bf, fsp_names, model, model_params)
return DecayModeDict(
bf=bf, fs=fsp_names, model=model, model_params=model_params
)

def print_decay_modes(
self,
Expand Down Expand Up @@ -868,11 +866,13 @@ def print_decay_modes(

ls_dict = {}
for dm in dms:
bf, fsp_names, model, model_params = self._decay_mode_details(
dm, display_photos_keyword
dmdict = self._decay_mode_details(dm, display_photos_keyword)
model_params = [str(i) for i in dmdict["model_params"]]
ls_dict[dmdict["bf"]] = (
dmdict["fs"],
dmdict["model"],
model_params,
)
model_params = [str(i) for i in model_params]
ls_dict[bf] = (fsp_names, model, model_params)

dec_details = list(ls_dict.values())
ls_attrs_aligned = list(
Expand Down Expand Up @@ -941,7 +941,7 @@ def build_decay_chains(
self,
mother: str,
stable_particles: list[str] | set[str] | tuple[str] | tuple[()] = (),
) -> dict[str, list[dict[str, float | str | list[Any]]]]:
) -> dict[str, list[DecayModeDict]]:
"""
Iteratively build the entire decay chains of a given mother particle,
optionally considering, on the fly, certain particles as stable.
Expand Down Expand Up @@ -996,14 +996,12 @@ def build_decay_chains(
>>> p.build_decay_chains('D+', stable_particles=['pi0']) # doctest: +SKIP
{'D+': [{'bf': 1.0, 'fs': ['K-', 'pi+', 'pi+', 'pi0'], 'model': 'PHSP', 'model_params': ''}]}
"""
keys = ("bf", "fs", "model", "model_params")

info = []
for dm in self._find_decay_modes(mother):
list_dm_details = self._decay_mode_details(dm, display_photos_keyword=False)
d = dict(zip(keys, list_dm_details))
d = self._decay_mode_details(dm, display_photos_keyword=False)

for i, fs in enumerate(d["fs"]): # type: ignore[arg-type]
for i, fs in enumerate(d["fs"]):
if fs in stable_particles:
continue

Expand All @@ -1012,8 +1010,9 @@ def build_decay_chains(
# if fs does not have decays defined in the parsed file
# _n_dms = len(self._find_decay_modes(fs))

assert isinstance(fs, str)
_info = self.build_decay_chains(fs, stable_particles)
d["fs"][i] = _info # type: ignore[index,call-overload]
d["fs"][i] = _info # type: ignore[index]
except DecayNotFound:
pass

Expand Down
65 changes: 29 additions & 36 deletions src/decaylanguage/decay/decay.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,34 @@
from __future__ import annotations

from collections import Counter
from collections.abc import Sequence
from copy import deepcopy
from itertools import product
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, TypeVar, Union
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List

from particle import PDGID, ParticleNotFound
from particle.converters import EvtGenName2PDGIDBiMap
from particle.exceptions import MatchingIDNotFound

from .._compat.typing import Self, TypedDict
from ..utils import DescriptorFormat, charge_conjugate_name

Self_DaughtersDict = TypeVar("Self_DaughtersDict", bound="DaughtersDict")

if TYPE_CHECKING:
CounterStr = Counter[str] # pragma: no cover
else:
CounterStr = Counter


class DecayModeDict(TypedDict):
bf: float
fs: Sequence[str | DecayChainDict]
model: str
model_params: str | Sequence[str | Any]


DecayChainDict = Dict[str, list[DecayModeDict]]


class DaughtersDict(CounterStr):
"""
Class holding a decay final state as a dictionary.
Expand Down Expand Up @@ -107,9 +117,7 @@ def to_list(self) -> list[str]:
"""
return sorted(self.elements())

def charge_conjugate(
self: Self_DaughtersDict, pdg_name: bool = False
) -> Self_DaughtersDict:
def charge_conjugate(self, pdg_name: bool = False) -> Self:
"""
Return the charge-conjugate final state.
Expand Down Expand Up @@ -156,7 +164,7 @@ def __len__(self) -> int:
"""
return sum(self.values())

def __add__(self: Self_DaughtersDict, other: Self_DaughtersDict) -> Self_DaughtersDict: # type: ignore[override]
def __add__(self, other: Self) -> Self: # type: ignore[override]
"""
Add two final states, particle-type-wise.
"""
Expand All @@ -167,9 +175,6 @@ def __iter__(self) -> Iterator[str]:
return self.elements()


Self_DecayMode = TypeVar("Self_DecayMode", bound="DecayMode")


class DecayMode:
"""
Class holding a particle decay mode, which is typically a branching fraction
Expand Down Expand Up @@ -254,9 +259,9 @@ def __init__(

@classmethod
def from_dict(
cls: type[Self_DecayMode],
decay_mode_dict: dict[str, int | float | str | list[str]],
) -> Self_DecayMode:
cls,
decay_mode_dict: DecayModeDict,
) -> Self:
"""
Constructor from a dictionary of the form
{'bf': <float>, 'fs': [...], ...}.
Expand Down Expand Up @@ -291,21 +296,18 @@ def from_dict(
dm = deepcopy(decay_mode_dict)

# Ensure the input dict has the 2 required keys 'bf' and 'fs'
try:
bf = dm.pop("bf")
daughters = dm.pop("fs")
except KeyError as e:
raise RuntimeError("Input not in the expected format!") from e
if not dm.keys() >= {"bf", "fs"}:
raise RuntimeError("Input not in the expected format! Needs 'bf' and 'fs'")

return cls(bf=bf, daughters=daughters, **dm) # type: ignore[arg-type]
return cls(**dm)

@classmethod
def from_pdgids(
cls: type[Self_DecayMode],
cls,
bf: float = 0,
daughters: list[int] | tuple[int] | None = None,
**info: Any,
) -> Self_DecayMode:
) -> Self:
"""
Constructor for a final state given as a list of particle PDG IDs.
Expand Down Expand Up @@ -397,9 +399,7 @@ def to_dict(self) -> dict[str, int | float | str | list[str]]:
d["model_params"] = ""
return d # type: ignore[return-value]

def charge_conjugate(
self: Self_DecayMode, pdg_name: bool = False
) -> Self_DecayMode:
def charge_conjugate(self, pdg_name: bool = False) -> Self:
"""
Return the charge-conjugate decay mode.
Expand Down Expand Up @@ -444,11 +444,6 @@ def __str__(self) -> str:
return repr(self)


Self_DecayChain = TypeVar("Self_DecayChain", bound="DecayChain")
DecayModeDict = Dict[str, Union[float, str, List[Any]]]
DecayChainDict = Dict[str, List[DecayModeDict]]


def _has_no_subdecay(ds: list[Any]) -> bool:
"""
Internal function to check whether the input list
Expand Down Expand Up @@ -772,9 +767,7 @@ def __init__(self, mother: str, decays: dict[str, DecayMode]) -> None:
self.decays = decays

@classmethod
def from_dict(
cls: type[Self_DecayChain], decay_chain_dict: DecayChainDict
) -> Self_DecayChain:
def from_dict(cls, decay_chain_dict: DecayChainDict) -> Self:
"""
Constructor from a decay chain represented as a dictionary.
The format is the same as that returned by
Expand Down Expand Up @@ -905,9 +898,9 @@ def _print(
for i_decay in decay_dict[mother]:
print(prefix, arrow if depth > 0 else "", mother, sep="") # noqa: T201
fsps = i_decay["fs"]
n = len(list(fsps)) # type: ignore[arg-type]
n = len(list(fsps))
depth += 1
for j, fsp in enumerate(fsps): # type: ignore[arg-type]
for j, fsp in enumerate(fsps):
prefix = bar if (link and depth > 1) else ""
if last:
prefix = prefix + " " * indent * (depth - 1) + " "
Expand Down Expand Up @@ -966,9 +959,9 @@ def recursively_replace(mother: str) -> DecayDict:
return recursively_replace(self.mother)

def flatten(
self: Self_DecayChain,
self,
stable_particles: Iterable[dict[str, int] | list[str] | str] = (),
) -> Self_DecayChain:
) -> Self:
"""
Flatten the decay chain replacing all intermediate, decaying particles,
with their final states.
Expand Down

0 comments on commit a65741d

Please sign in to comment.